To process variable-length sequences in TensorFlow, we can use dynamic RNNs. There are two basic dynamic RNN functions in TensorFlow:

- tf.nn.bidirectional_dynamic_rnn
- tf.nn.dynamic_rnn

We have already introduced how to use tf.nn.dynamic_rnn() in TensorFlow; here is the tutorial:

Understand tf.nn.dynamic_rnn() for TensorFlow Beginners – TensorFlow Tutorial

When we use tf.nn.bidirectional_dynamic_rnn() to build a BiLSTM, it calls tf.nn.dynamic_rnn() internally. In this tutorial, we will explain how tf.nn.bidirectional_dynamic_rnn() processes variable-length sequences, which will help you understand this function and build your own custom model.

**1. tf.nn.bidirectional_dynamic_rnn()**

tf.nn.bidirectional_dynamic_rnn() is defined as:

@tf_export("nn.bidirectional_dynamic_rnn") def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, initial_state_fw=None, initial_state_bw=None, dtype=None, parallel_iterations=None, swap_memory=False, time_major=False, scope=None):

The source code is here:

https://github.com/tensorflow/tensorflow/blob/r1.8/tensorflow/python/ops/rnn.py

To make tf.nn.bidirectional_dynamic_rnn() process variable-length sequences, we should pass a sequence_length, which is a list like [1, 2, 3, 4]. You should notice:

len(sequence_length) = batch_size
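
Here is a minimal sketch of how the function is usually called. The cell type, sizes, placeholder shapes and length values below are only assumptions for illustration:

import tensorflow as tf

batch_size, max_time, input_dim, num_units = 4, 5, 8, 16

# padded input batch: [batch_size, max_time, input_dim]
inputs = tf.placeholder(tf.float32, [batch_size, max_time, input_dim])
# one true length per example, so len(sequence_length) == batch_size
sequence_length = tf.constant([1, 2, 3, 4], dtype=tf.int32)

cell_fw = tf.nn.rnn_cell.LSTMCell(num_units)
cell_bw = tf.nn.rnn_cell.LSTMCell(num_units)

(output_fw, output_bw), (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, inputs,
    sequence_length=sequence_length,
    dtype=tf.float32)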

**2. Run cell_fw**

tf.nn.bidirectional_dynamic_rnn() uses tf.nn.dynamic_rnn() to run cell_fw and get output_fw and output_state_fw. Here is the source code:

with vs.variable_scope("fw") as fw_scope: output_fw, output_state_fw = dynamic_rnn( cell=cell_fw, inputs=inputs, sequence_length=sequence_length, initial_state=initial_state_fw, dtype=dtype, parallel_iterations=parallel_iterations, swap_memory=swap_memory, time_major=time_major, scope=fw_scope)

Running cell_fw is simple; however, we should pay attention to cell_bw.

**3. Run cell_bw**

Before running cell_bw, the input sequence must be reversed according to sequence_length. tf.nn.bidirectional_dynamic_rnn() uses tf.reverse_sequence() to do this.

To understand how to use tf.reverse_sequence(), you can read this tutorial.

Understand TensorFlow tf.reverse_sequence(): Reverse a Tensor by Length
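
To get a quick feel for what tf.reverse_sequence() does, here is a small sketch (the values are made up for illustration):

import tensorflow as tf

# each row is reversed only over its first seq_lengths[i] elements,
# the padded tail stays where it is
x = tf.constant([[1, 2, 3, 0, 0],
                 [4, 5, 6, 7, 0]])
x_reversed = tf.reverse_sequence(x, seq_lengths=[3, 4], seq_dim=1, batch_dim=0)

with tf.Session() as sess:
    print(sess.run(x_reversed))
    # [[3 2 1 0 0]
    #  [7 6 5 4 0]]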

Here is the source code:

def _reverse(input_, seq_lengths, seq_dim, batch_dim):
  if seq_lengths is not None:
    return array_ops.reverse_sequence(
        input=input_, seq_lengths=seq_lengths,
        seq_dim=seq_dim, batch_dim=batch_dim)
  else:
    return array_ops.reverse(input_, axis=[seq_dim])

with vs.variable_scope("bw") as bw_scope:
  inputs_reverse = _reverse(
      inputs, seq_lengths=sequence_length,
      seq_dim=time_dim, batch_dim=batch_dim)

After the sequence has been reversed, cell_bw is run.

tmp, output_state_bw = dynamic_rnn(
    cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
    initial_state=initial_state_bw, dtype=dtype,
    parallel_iterations=parallel_iterations, swap_memory=swap_memory,
    time_major=time_major, scope=bw_scope)

**4. Understand dynamic_rnn() function**

The dynamic_rnn() function is defined as:

def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
                dtype=None, parallel_iterations=None, swap_memory=False,
                time_major=False, scope=None):

It also receives a sequence_length argument so that it can process variable-length sequences.

The dynamic_rnn() function then calls _dynamic_rnn_loop().

Here is the source code:

(outputs, final_state) = _dynamic_rnn_loop(
    cell, inputs, state,
    parallel_iterations=parallel_iterations,
    swap_memory=swap_memory,
    sequence_length=sequence_length,
    dtype=dtype)

**5. Understand _dynamic_rnn_loop() function**

_dynamic_rnn_loop() is defined as:

def _dynamic_rnn_loop(cell, inputs, initial_state, parallel_iterations,
                      swap_memory, sequence_length=None, dtype=None):

It receives an RNN cell, the inputs and a sequence_length to process variable-length sequences.

In this function, tf.while_loop() is run to process each time step:

_, output_final_ta, final_state = control_flow_ops.while_loop(
    cond=lambda time, *_: time < loop_bound,
    body=_time_step,
    loop_vars=(time, output_ta, state),
    parallel_iterations=parallel_iterations,
    maximum_iterations=time_steps,
    swap_memory=swap_memory)

We should notice that the _time_step() function is called as the loop body.
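
The following toy sketch (not TensorFlow's real implementation) only illustrates the structure of such a loop: the body is called once per time step until time reaches loop_bound.

import tensorflow as tf

time = tf.constant(0)
loop_bound = tf.constant(5)

def body(time, acc):
    # in dynamic_rnn the body is _time_step(); here we just sum the step indices
    return time + 1, acc + time

_, total = tf.while_loop(
    cond=lambda time, acc: time < loop_bound,
    body=body,
    loop_vars=(time, tf.constant(0)))

with tf.Session() as sess:
    print(sess.run(total))  # 0 + 1 + 2 + 3 + 4 = 10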

**6. Understand _time_step()**

The _time_step() function is defined as:

def _time_step(time, output_ta_t, state):

In this function, if sequence_length is not None, _rnn_step() will be called.

Here is the source code:

if sequence_length is not None:
  (output, new_state) = _rnn_step(
      time=time,
      sequence_length=sequence_length,
      min_sequence_length=min_sequence_length,
      max_sequence_length=max_sequence_length,
      zero_output=zero_output,
      state=state,
      call_cell=call_cell,
      state_size=state_size,
      skip_conditionals=True)
else:
  (output, new_state) = call_cell()

We should notice: skip_conditionals=True

**7. Understand _rnn_step()**

_rnn_step() is defined as:

def _rnn_step(
    time, sequence_length, min_sequence_length, max_sequence_length,
    zero_output, state, call_cell, state_size, skip_conditionals=False):

This function processes each time step via call_cell(). We should notice that skip_conditionals=True here, as shown in the _time_step() code above.

In this function, because skip_conditionals=True, _copy_some_through() will be called.

Here is the source code:

if skip_conditionals:
  # Instead of using conditionals, perform the selective copy at all time
  # steps. This is faster when max_seq_len is equal to the number of unrolls
  # (which is typical for dynamic_rnn).
  new_output, new_state = call_cell()
  nest.assert_same_structure(state, new_state)
  new_state = nest.flatten(new_state)
  new_output = nest.flatten(new_output)
  final_output_and_state = _copy_some_through(new_output, new_state)
else:
  empty_update = lambda: flat_zero_output + flat_state
  final_output_and_state = control_flow_ops.cond(
      # if t >= max_seq_len: copy all state through, output zeros
      time >= max_sequence_length, empty_update,
      # otherwise calculation is required: copy some or all of it through
      _maybe_copy_some_through)

From this line of code

new_output, new_state = call_cell()

we can see that the cell is executed at every time step of the loop, no matter whether the step is still valid for a given example (time < sequence_length) or already invalid (time >= sequence_length); the results of invalid steps are then corrected by _copy_some_through().

**8. Understand _copy_some_through()**

Look at the key code below:

copy_cond = time >= sequence_length

def _copy_one_through(output, new_output):
  # TensorArray and scalar get passed through.
  if isinstance(output, tensor_array_ops.TensorArray):
    return new_output
  if output.shape.ndims == 0:
    return new_output
  # Otherwise propagate the old or the new value.
  with ops.colocate_with(new_output):
    return array_ops.where(copy_cond, output, new_output)

def _copy_some_through(flat_new_output, flat_new_state):
  # Use broadcasting select to determine which values should get
  # the previous state & zero output, and which values should get
  # a calculated state & output.
  flat_new_output = [
      _copy_one_through(zero_output, new_output)
      for zero_output, new_output in zip(flat_zero_output, flat_new_output)]
  flat_new_state = [
      _copy_one_through(state, new_state)
      for state, new_state in zip(flat_state, flat_new_state)]
  return flat_new_output + flat_new_state

_copy_some_through() will call _copy_one_through().

From these two functions we can find:

**(1) If a time step is valid for an example, which means time < sequence_length for that example:**

new_output will be the output calculated by call_cell()

new_state is also the state calculated by call_cell()

**(2) If a time step is invalid for an example, which means time >= sequence_length for that example:**

new_output will be the zero output, i.e. that example's row of the [batch_size, num_units] output tensor becomes a zero vector

new_state will be the state of the last valid time step.
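
The following toy sketch mimics this selective copy with tf.where(); the batch size, unit count and values are assumptions for illustration. The condition vector picks, per example, whether to keep the freshly computed value or fall back to the zero output / previous state.

import tensorflow as tf

time = tf.constant(2)
sequence_length = tf.constant([1, 2, 4])
copy_cond = time >= sequence_length           # [True, True, False]

zero_output = tf.zeros([3, 2])
new_output = tf.fill([3, 2], 7.0)             # pretend this came from call_cell()
old_state = tf.fill([3, 2], 1.0)              # state of the last valid time step
new_state = tf.fill([3, 2], 9.0)

# finished examples get zeros as output and keep their old state
output = tf.where(copy_cond, zero_output, new_output)
state = tf.where(copy_cond, old_state, new_state)

with tf.Session() as sess:
    print(sess.run(output))  # rows 0 and 1 are zeros, row 2 is 7.0
    print(sess.run(state))   # rows 0 and 1 are 1.0, row 2 is 9.0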

Here is an example.

In this example, there are 5 time steps; however, the sequence length of this example is 2, which means the last three time steps are invalid.

x_{2}, x_{3} and x_{4} are still passed into the RNN cell, so h_{2}, h_{3} and h_{4} are still calculated.

However, because x_{2}, x_{3} and x_{4} are invalid time steps, h_{2}, h_{3} and h_{4} will be overwritten with zero tensors. Meanwhile, c_{2}, c_{3} and c_{4} will be set to c_{1}, the cell state of the last valid time step.
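
We can check this behaviour directly with tf.nn.dynamic_rnn(); the shapes, cell size and length value below are just assumptions for illustration:

import numpy as np
import tensorflow as tf

# one example of true length 2, padded to 5 time steps
inputs = tf.constant(np.random.randn(1, 5, 3), dtype=tf.float32)
cell = tf.nn.rnn_cell.LSTMCell(4)
outputs, state = tf.nn.dynamic_rnn(
    cell, inputs, sequence_length=[2], dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out, st = sess.run([outputs, state])
    print(out[0, 2:])        # the outputs of time steps 2, 3, 4 are zero vectors
    print(out[0, 1], st.h)   # the final state h equals the output of the last valid step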

**As for the backward cell cell_bw, we must reverse its output again by sequence length.**

output_bw = _reverse(
    tmp, seq_lengths=sequence_length,
    seq_dim=time_dim, batch_dim=batch_dim)
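
After this second reverse, the backward outputs are back in the original time order, so they can be combined with the forward outputs. Here is a typical usage sketch, reusing the (assumed) names from the example in section 1:

outputs, states = tf.nn.bidirectional_dynamic_rnn(
    cell_fw, cell_bw, inputs,
    sequence_length=sequence_length, dtype=tf.float32)
output_fw, output_bw = outputs
# output_bw has already been reversed back, so both tensors are aligned in time
# and can be concatenated on the feature axis
bilstm_output = tf.concat([output_fw, output_bw], axis=-1)  # [batch_size, max_time, 2 * num_units]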