An Introduction to How TensorFlow's Bidirectional Dynamic RNN Processes Variable-Length Sequences – LSTM Tutorial

July 22, 2020

To process variable-length sequences in TensorFlow, we can use dynamic RNNs. TensorFlow provides two basic dynamic RNN functions:

  • tf.nn.bidirectional_dynamic_rnn
  • tf.nn.dynamic_rnn

We have introduced how to use tf.nn.dynamic_rnn() in a previous tutorial:

Understand tf.nn.dynamic_rnn() for TensorFlow Beginners – TensorFlow Tutorial

When we use tf.nn.bidirectional_dynamic_rnn() to build a BiLSTM, it calls tf.nn.dynamic_rnn() internally, once for the forward direction and once for the backward direction. In this tutorial, we will introduce how tf.nn.bidirectional_dynamic_rnn() processes variable-length sequences, which will help you understand this function and build your own custom models.

1. tf.nn.bidirectional_dynamic_rnn()

tf.nn.bidirectional_dynamic_rnn() is defined as:

def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None,
                              initial_state_fw=None, initial_state_bw=None,
                              dtype=None, parallel_iterations=None,
                              swap_memory=False, time_major=False, scope=None):


To make tf.nn.bidirectional_dynamic_rnn() process variable-length sequences, we should pass sequence_length, an int32/int64 vector (for example, a list like [1, 2, 3, 4]) that holds the actual length of each sequence in the batch. Note that:

len(sequence_length) = batch_size
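In practice, sequence_length is usually derived from the padded input batch. The sketch below is plain Python (no TensorFlow). It assumes that padding uses id 0 and appears only at the end of each sequence, and compute_sequence_length is a hypothetical helper, not a TensorFlow function.

```python
from typing import List

PAD_ID = 0  # assumption: padding id is 0 and padding appears only at the end


def compute_sequence_length(batch: List[List[int]], pad_id: int = PAD_ID) -> List[int]:
    """Return the true (unpadded) length of every sequence in a padded batch."""
    lengths = []
    for seq in batch:
        length = len(seq)
        # Step back over trailing padding to find the last real token.
        while length > 0 and seq[length - 1] == pad_id:
            length -= 1
        lengths.append(length)
    return lengths


batch = [
    [7, 0, 0, 0],
    [4, 9, 0, 0],
    [3, 5, 8, 0],
    [2, 6, 1, 5],
]
sequence_length = compute_sequence_length(batch)
assert len(sequence_length) == len(batch)  # len(sequence_length) == batch_size
print(sequence_length)  # [1, 2, 3, 4]
```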

2. Run cell_fw

tf.nn.bidirectional_dynamic_rnn() uses tf.nn.dynamic_rnn() to run cell_fw on the original inputs and obtain output_fw and output_state_fw. Here is the source code:

    with vs.variable_scope("fw") as fw_scope:
      output_fw, output_state_fw = dynamic_rnn(
          cell=cell_fw, inputs=inputs, sequence_length=sequence_length,
          initial_state=initial_state_fw, dtype=dtype,
          parallel_iterations=parallel_iterations, swap_memory=swap_memory,
          time_major=time_major, scope=fw_scope)

Running cell_fw is simple; however, we should pay attention to cell_bw.

3. Run cell_bw

Before running cell_bw, tf.nn.bidirectional_dynamic_rnn() reverses each input sequence according to sequence_length, using tf.reverse_sequence(). Only the first sequence_length[i] time steps of example i are reversed, so the padding stays at the end.

To understand how to use tf.reverse_sequence(), you can read this tutorial.

Understand TensorFlow tf.reverse_sequence(): Reverse a Tensor by Length

Here is the source code:

    def _reverse(input_, seq_lengths, seq_dim, batch_dim):
      if seq_lengths is not None:
        return array_ops.reverse_sequence(
            input=input_, seq_lengths=seq_lengths,
            seq_dim=seq_dim, batch_dim=batch_dim)
      else:
        # Without sequence_length, the whole time axis is reversed.
        return array_ops.reverse(input_, axis=[seq_dim])

    with vs.variable_scope("bw") as bw_scope:
      inputs_reverse = _reverse(
          inputs, seq_lengths=sequence_length,
          seq_dim=time_dim, batch_dim=batch_dim)
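The effect of _reverse() with seq_lengths can be sketched in plain Python. The reverse_sequence helper below is hypothetical and only mimics tf.reverse_sequence() on a batch-major list of lists (time_major=False, so batch_dim=0 and time_dim=1). It is not TensorFlow's implementation.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def reverse_sequence(batch: Sequence[Sequence[T]], seq_lengths: Sequence[int]) -> List[List[T]]:
    """Reverse each row up to its own length (batch-major layout: batch[i][t])."""
    result = []
    for row, length in zip(batch, seq_lengths):
        values = list(row)
        # Reverse only the valid prefix; padding after it stays where it is.
        result.append(values[:length][::-1] + values[length:])
    return result


inputs = [
    ["a0", "a1", "a2", "pad"],
    ["b0", "b1", "pad", "pad"],
]
inputs_reverse = reverse_sequence(inputs, [3, 2])
print(inputs_reverse)
# [['a2', 'a1', 'a0', 'pad'], ['b1', 'b0', 'pad', 'pad']]
```

Because only the valid prefix is reversed, cell_bw still sees the real tokens first. A plain full reverse (the array_ops.reverse() fallback) would instead move the padding to the front. Reversing twice with the same lengths restores the original order.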

After the inputs have been reversed, cell_bw is run on inputs_reverse:

    tmp, output_state_bw = dynamic_rnn(
          cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length,
          initial_state=initial_state_bw, dtype=dtype,
          parallel_iterations=parallel_iterations, swap_memory=swap_memory,
          time_major=time_major, scope=bw_scope)

4. Understand the dynamic_rnn() function

dynamic_rnn() is defined as:

def dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None,
                dtype=None, parallel_iterations=None, swap_memory=False,
                time_major=False, scope=None):

It also receives sequence_length so that it can process variable-length sequences.

dynamic_rnn() then calls _dynamic_rnn_loop(), which runs the RNN over the time steps.

Here is the source code:

(outputs, final_state) = _dynamic_rnn_loop(

5. Understand the _dynamic_rnn_loop() function

_dynamic_rnn_loop() is defined as:

def _dynamic_rnn_loop(cell,

It also receives an RNN cell, the inputs and sequence_length, which it uses to process variable-length sequences.

In this function, tf.while_loop() is run to process the sequence one time step per iteration.

 _, output_final_ta, final_state = control_flow_ops.while_loop(
      cond=lambda time, *_: time < loop_bound,
      loop_vars=(time, output_ta, state),

Note that _time_step() is the function called in this loop: it runs once for each time step.
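To make the loop structure concrete, here is a plain-Python stand-in for tf.while_loop(). It is a simplification under stated assumptions: toy_cell is a hypothetical cell (a running sum), there is a single sequence, a Python list plays the role of the output TensorArray, and the loop bound is simply the number of time steps.

```python
from typing import Callable, List, Tuple


def while_loop(cond: Callable[..., bool], body: Callable[..., tuple], loop_vars: tuple) -> tuple:
    """Plain-Python stand-in for tf.while_loop: call body while cond holds."""
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars


def toy_cell(x: float, state: float) -> Tuple[float, float]:
    """Hypothetical cell: the new state is a running sum; the output equals it."""
    new_state = state + x
    return new_state, new_state


inputs = [1.0, 2.0, 3.0]  # one sequence, one feature per time step
loop_bound = len(inputs)  # simplification: iterate over every time step


def _time_step(time: int, outputs: List[float], state: float):
    output, new_state = toy_cell(inputs[time], state)
    # outputs plays the role of output_ta: one entry is written per step.
    return time + 1, outputs + [output], new_state


_, outputs, final_state = while_loop(
    cond=lambda time, *_: time < loop_bound,
    body=_time_step,
    loop_vars=(0, [], 0.0),
)
print(outputs, final_state)  # [1.0, 3.0, 6.0] 6.0
```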

6. Understand _time_step()

_time_step() function is defined as:

def _time_step(time, output_ta_t, state):

In this function, if sequence_length is not None, _rnn_step() is called; otherwise call_cell() is called directly.

Here is the source code:

    if sequence_length is not None:
      (output, new_state) = _rnn_step(
          ...,  # time, sequence_length, state, call_cell, ... (omitted here)
          skip_conditionals=True)
    else:
      (output, new_state) = call_cell()

Note that _time_step() passes skip_conditionals=True to _rnn_step().

7. Understand _rnn_step()

_rnn_step() is defined as:

def _rnn_step(
    time, sequence_length, min_sequence_length, max_sequence_length,
    zero_output, state, call_cell, state_size, skip_conditionals=False):

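This function processes a single time step by calling call_cell(). Remember that skip_conditionals=True here, because that is what _time_step() passes in when tf.nn.bidirectional_dynamic_rnn() runs each direction through dynamic_rnn().

Because skip_conditionals=True, the tf.cond() branch is skipped and _copy_some_through() is called at every time step.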

Here is the source code:

    if skip_conditionals:
      # Instead of using conditionals, perform the selective copy at all time
      # steps.  This is faster when max_seq_len is equal to the number of unrolls
      # (which is typical for dynamic_rnn).
      new_output, new_state = call_cell()
      nest.assert_same_structure(state, new_state)
      new_state = nest.flatten(new_state)
      new_output = nest.flatten(new_output)
      final_output_and_state = _copy_some_through(new_output, new_state)
    else:
      # Not taken here: _time_step() passes skip_conditionals=True.
      empty_update = lambda: flat_zero_output + flat_state
      final_output_and_state = control_flow_ops.cond(
          # if t >= max_seq_len: copy all state through, output zeros
          time >= max_sequence_length, empty_update,
          # otherwise calculation is required: copy some or all of it through
          ...)  # remaining argument omitted in this excerpt

From this code

new_output, new_state = call_cell()

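we can see that call_cell() runs at every iteration of the loop, whether the current time step is valid for a given example i (time < sequence_length[i]) or invalid (time >= sequence_length[i]). Invalid steps are masked only after the cell has been computed.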
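This can be checked with a small plain-Python sketch. toy_cell and run_all_steps are hypothetical names; a counter records how many times the cell is called, and the loop is assumed to run over every padded time step, as in the example in section 8. Real TensorFlow processes the whole batch in one call per step; the inner loop over examples here is only for readability.

```python
from typing import List, Tuple

cell_calls = 0  # counts how many times the hypothetical cell is invoked


def toy_cell(x: float, state: float) -> Tuple[float, float]:
    """Hypothetical cell: the new state is a running sum; the output equals it."""
    global cell_calls
    cell_calls += 1
    new_state = state + x
    return new_state, new_state


def run_all_steps(batch: List[List[float]], sequence_length: List[int]) -> List[float]:
    """Mimic skip_conditionals=True: call the cell at every step, then select."""
    num_steps = len(batch[0])
    states = [0.0] * len(batch)
    for time in range(num_steps):
        for i, seq in enumerate(batch):
            _, new_state = toy_cell(seq[time], states[i])  # always computed
            if time < sequence_length[i]:  # valid step: keep the new state
                states[i] = new_state
            # invalid step: the computed value is simply discarded
    return states


batch = [[1.0, 2.0, 3.0], [5.0, 0.0, 0.0]]  # second example is padded after step 0
final_states = run_all_steps(batch, sequence_length=[3, 1])
print(final_states, cell_calls)  # [6.0, 5.0] 6
```

The second example is finished after step 0, yet the cell still runs for it at steps 1 and 2; those results are thrown away by the selection step.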

8. Understand _copy_some_through()

Look at the key code below:

    # Vector describing which batch entries are finished.
    copy_cond = time >= sequence_length

    def _copy_one_through(output, new_output):
      # TensorArray and scalar get passed through.
      if isinstance(output, tensor_array_ops.TensorArray):
        return new_output
      if output.shape.ndims == 0:
        return new_output
      # Otherwise propagate the old or the new value.
      with ops.colocate_with(new_output):
        return array_ops.where(copy_cond, output, new_output)

    def _copy_some_through(flat_new_output, flat_new_state):
      # Use broadcasting select to determine which values should get
      # the previous state & zero output, and which values should get
      # a calculated state & output.
      flat_new_output = [
          _copy_one_through(zero_output, new_output)
          for zero_output, new_output in zip(flat_zero_output, flat_new_output)]
      flat_new_state = [
          _copy_one_through(state, new_state)
          for state, new_state in zip(flat_state, flat_new_state)]
      return flat_new_output + flat_new_state
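The row-wise selection done by array_ops.where(copy_cond, output, new_output) can be sketched in plain Python. Tensors are modeled as lists of rows with shape [batch_size, num_units], and where_rows and copy_one_through are hypothetical helpers. They assume the TF1 behavior of where() with a rank-1 condition: row i is taken from the first argument when cond[i] is True, otherwise from the second.

```python
from typing import List

Matrix = List[List[float]]  # shape [batch_size, num_units]


def where_rows(cond: List[bool], x: Matrix, y: Matrix) -> Matrix:
    """Row i comes from x when cond[i] is True, otherwise from y."""
    return [list(xr) if c else list(yr) for c, xr, yr in zip(cond, x, y)]


def copy_one_through(time: int, sequence_length: List[int], old: Matrix, new: Matrix) -> Matrix:
    copy_cond = [time >= length for length in sequence_length]  # finished examples
    return where_rows(copy_cond, old, new)


time = 2
sequence_length = [3, 2]  # example 0 is still running, example 1 is finished
zero_output = [[0.0, 0.0], [0.0, 0.0]]
new_output = [[0.4, 0.1], [0.7, 0.9]]
prev_state = [[1.0, 1.0], [2.0, 2.0]]
new_state = [[1.5, 1.2], [2.5, 2.6]]

print(copy_one_through(time, sequence_length, zero_output, new_output))
# [[0.4, 0.1], [0.0, 0.0]]
print(copy_one_through(time, sequence_length, prev_state, new_state))
# [[1.5, 1.2], [2.0, 2.0]]
```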

_copy_some_through() calls _copy_one_through() on each flattened output and state tensor.

From these two functions we can see that, for each example i in the batch:

(1) If the time step is valid for example i, which means time < sequence_length[i]:

new_output[i] is the output calculated by call_cell()

new_state[i] is the state calculated by call_cell()

(2) If the time step is invalid for example i, which means time >= sequence_length[i]:

new_output[i] is the zero output, i.e. a row of zeros (the whole zero_output tensor has shape [batch_size, num_units])

new_state[i] is the previous state copied through unchanged, so it stays equal to the state of the last valid time step.

Here is an example.

(Figure: the output of TensorFlow dynamic RNN when processing a variable-length sequence)

In this example, there are 5 time steps (x0 to x4), but the sequence length is only 2, which means the last three time steps are invalid.

x2, x3 and x4 are still passed into the RNN, and h2, h3 and h4 are still calculated.

However, because x2, x3 and x4 are invalid time steps, h2, h3 and h4 are set to zero tensors. Meanwhile, c2, c3 and c4 are set to c1.
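The example can be reproduced numerically with a plain-Python sketch. It assumes a hypothetical toy_cell whose state is a single number c (standing in for the LSTM cell state) and whose output h equals the new state. In a real LSTM the state is a (c, h) pair, and both parts are carried through the same way.

```python
from typing import List, Tuple


def toy_cell(x: float, c: float) -> Tuple[float, float]:
    """Hypothetical cell standing in for an LSTM: returns (h, new_c)."""
    new_c = c + x
    h = new_c  # output equals the new state, to keep the numbers easy to follow
    return h, new_c


def dynamic_rnn_sketch(xs: List[float], length: int, c_init: float = 0.0):
    """Run one sequence for all len(xs) steps, masking steps with t >= length."""
    outputs, states = [], []
    c = c_init
    for t, x in enumerate(xs):
        h, new_c = toy_cell(x, c)  # computed even for invalid steps
        if t >= length:            # copy_cond
            h, new_c = 0.0, c      # zero output, carry the previous state through
        outputs.append(h)
        states.append(new_c)
        c = new_c
    return outputs, states, c


outputs, states, final_c = dynamic_rnn_sketch([1.0, 2.0, 3.0, 4.0, 5.0], length=2)
print(outputs)  # [1.0, 3.0, 0.0, 0.0, 0.0]  h2, h3, h4 are zero
print(states)   # [1.0, 3.0, 3.0, 3.0, 3.0]  c2, c3, c4 equal c1
print(final_c)  # 3.0
```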

For the backward cell cell_bw, the output tmp is still in reversed time order, so tf.nn.bidirectional_dynamic_rnn() reverses it back by sequence_length:

    output_bw = _reverse(
      tmp, seq_lengths=sequence_length,
      seq_dim=time_dim, batch_dim=batch_dim)
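To tie the steps together, here is a single-sequence sketch of the whole bidirectional flow. It assumes the same kind of hypothetical running-sum cell and uses plain Python lists instead of tensors; reverse_sequence, run_rnn and bidirectional_sketch are illustrative names, although the return structure ((output_fw, output_bw), (state_fw, state_bw)) mirrors the (outputs, output_states) tuple returned by tf.nn.bidirectional_dynamic_rnn().

```python
from typing import List, Tuple


def reverse_sequence(xs: List[float], length: int) -> List[float]:
    """Reverse only the first `length` steps, leaving padding in place."""
    return xs[:length][::-1] + xs[length:]


def run_rnn(xs: List[float], length: int) -> Tuple[List[float], float]:
    """Toy masked RNN: running-sum state, zero output after `length` steps."""
    outputs, state = [], 0.0
    for t, x in enumerate(xs):
        new_state = state + x  # hypothetical cell computation
        if t >= length:
            outputs.append(0.0)  # zero output, state carried through
        else:
            outputs.append(new_state)
            state = new_state
    return outputs, state


def bidirectional_sketch(xs: List[float], length: int):
    output_fw, state_fw = run_rnn(xs, length)
    inputs_reverse = reverse_sequence(xs, length)
    tmp, state_bw = run_rnn(inputs_reverse, length)
    # Undo the reversal so output_bw[t] lines up with output_fw[t].
    output_bw = reverse_sequence(tmp, length)
    return (output_fw, output_bw), (state_fw, state_bw)


xs = [1.0, 2.0, 3.0, 0.0]  # length 3, one padding step
(output_fw, output_bw), (state_fw, state_bw) = bidirectional_sketch(xs, length=3)
print(output_fw)  # [1.0, 3.0, 6.0, 0.0]
print(output_bw)  # [6.0, 5.0, 3.0, 0.0]
```

After the final reversal, output_bw[0] summarizes the whole valid sequence read from its end, output_bw[2] has only seen x2, and the padding position is zero in both directions.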
