# Get Forward and Backward LSTM Output of Bi-LSTM in TensorFlow – TensorFlow Tutorial

By | November 10, 2021

We usually use a Bi-LSTM or Bi-GRU to model sequences. How do we get its output? In this tutorial, we will discuss this topic.

## Bi-LSTM or Bi-GRU

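A Bi-LSTM or Bi-GRU looks like this:

A Bi-LSTM contains a forward LSTM and a backward LSTM. The forward LSTM reads the sequence from left to right, and the backward LSTM reads it from right to left.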
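The following minimal NumPy sketch makes the two directions concrete. It is an illustration under stated assumptions, not how TensorFlow implements a Bi-LSTM. A plain tanh RNN step (the hypothetical helper `rnn_pass`) stands in for the LSTM cell, and the sketch handles a single unbatched sequence. The forward pass reads the time steps left to right. The backward pass reads them right to left, and its outputs are then flipped back so that both outputs line up with the input time steps.

```python
import numpy as np


def rnn_pass(x, W, U, b):
    """Run a plain tanh RNN (stand-in for an LSTM) over x of shape [time, dim].

    Returns the per-step outputs with shape [time, hidden].
    """
    h = np.zeros(W.shape[0])
    outputs = []
    for x_t in x:
        # New hidden state from the current input and the previous hidden state
        h = np.tanh(x_t @ U + h @ W + b)
        outputs.append(h)
    return np.stack(outputs)


def bidirectional(x, fw_params, bw_params):
    """Return (out_fw, out_bw), both [time, hidden] and aligned with x's time steps."""
    out_fw = rnn_pass(x, *fw_params)
    # The backward cell reads the reversed sequence; its outputs are flipped
    # back so that out_bw[t] corresponds to input step x[t].
    out_bw = rnn_pass(x[::-1], *bw_params)[::-1]
    return out_fw, out_bw


rng = np.random.default_rng(0)
time_steps, dim, hidden = 5, 4, 3


def make_params():
    W = rng.normal(size=(hidden, hidden)) * 0.1  # hidden-to-hidden weights
    U = rng.normal(size=(dim, hidden)) * 0.1     # input-to-hidden weights
    b = np.zeros(hidden)                         # bias
    return W, U, b


fw_params, bw_params = make_params(), make_params()
x = rng.normal(size=(time_steps, dim))
out_fw, out_bw = bidirectional(x, fw_params, bw_params)
print(out_fw.shape, out_bw.shape)  # (5, 3) (5, 3)
```

Because of that flip, the backward direction's final output (the one that has seen the whole sequence) sits at time step 0 rather than at the last step.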

## How to create Bi-LSTM in TensorFlow?

In TensorFlow 1.x, we can use tf.nn.bidirectional_dynamic_rnn() to create a Bi-LSTM network.

Here is an example:

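```python
import tensorflow as tf  # TensorFlow 1.x API
import numpy as np

# inputs: [batch_size, max_seq_length, dim] = [3, 5, 100]
inputs = tf.Variable(tf.truncated_normal([3, 5, 100], stddev=0.1), name="inputs")

hidden_size = 10
# outputs is a tuple (output_fw, output_bw); state is a tuple (state_fw, state_bw)
outputs, state = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=tf.nn.rnn_cell.LSTMCell(hidden_size, forget_bias=1.0),
    cell_bw=tf.nn.rnn_cell.LSTMCell(hidden_size, forget_bias=1.0),
    inputs=inputs,
    sequence_length=None,  # None: every sequence uses all 5 time steps
    dtype=tf.float32,
    scope='bilstm_doc_word'
)
```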

Note that the inputs tensor of the Bi-LSTM has the shape [batch_size, max_seq_length, dim] (the default time_major=False layout). In this example, it is [3, 5, 100].
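The example above passes sequence_length=None, so every sequence is treated as having all 5 time steps. Real batches usually mix lengths. The minimal NumPy sketch below zero-pads a list of variable-length sequences into the [batch_size, max_seq_length, dim] layout and builds the vector of true lengths, which is the kind of value sequence_length expects. `pad_batch` is a hypothetical helper, and the random sequences are made-up data.

```python
import numpy as np


def pad_batch(sequences):
    """Zero-pad a list of [length, dim] arrays into [batch_size, max_seq_length, dim].

    Returns the padded batch and an int32 vector holding each sequence's true length.
    """
    lengths = np.array([len(s) for s in sequences], dtype=np.int32)
    max_len = int(lengths.max())
    dim = sequences[0].shape[1]
    batch = np.zeros((len(sequences), max_len, dim), dtype=np.float32)
    for i, seq in enumerate(sequences):
        # Copy the real steps; positions past len(seq) stay zero (padding)
        batch[i, :len(seq), :] = seq
    return batch, lengths


rng = np.random.default_rng(0)
seqs = [rng.normal(size=(n, 100)) for n in (5, 3, 4)]
batch, lengths = pad_batch(seqs)
print(batch.shape)  # (3, 5, 100)
print(lengths)      # [5 3 4]
```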

## How to get the output of Bi-LSTM?

There are three common ways to get the output of a Bi-LSTM.

Type 1: Concatenate the outputs of the forward and backward LSTMs.

```python
outputs_merge = tf.concat(outputs, 2)
```

Then you will get an output with shape [batch_size, max_seq_length, 2*hidden_size], where hidden_size is the hidden state size of each LSTM.

In the example above, outputs_merge has shape [3, 5, 2*10], i.e. [3, 5, 20].
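The same concatenation can be reproduced in NumPy to check the shape arithmetic. In this sketch, random arrays stand in for outputs[0] (forward) and outputs[1] (backward), and np.concatenate along axis 2 plays the role of tf.concat(outputs, 2).

```python
import numpy as np

batch_size, max_seq_length, hidden_size = 3, 5, 10
rng = np.random.default_rng(0)

# Stand-ins for outputs[0] (forward) and outputs[1] (backward)
out_fw = rng.normal(size=(batch_size, max_seq_length, hidden_size))
out_bw = rng.normal(size=(batch_size, max_seq_length, hidden_size))

# Join along the feature axis, like tf.concat(outputs, 2)
merged = np.concatenate((out_fw, out_bw), axis=2)
print(merged.shape)  # (3, 5, 20)

# The first hidden_size features come from the forward LSTM, the rest from the backward LSTM
assert np.array_equal(merged[..., :hidden_size], out_fw)
assert np.array_equal(merged[..., hidden_size:], out_bw)
```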

Type 2: Get the mean output of the forward and backward LSTMs.

You can learn how to do this in this tutorial:

Average the Output of RNN/GRU/LSTM/BiLSTM for Variable Length Sequences – Deep Learning Tutorial
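This minimal NumPy sketch assumes that "mean output" means averaging over the valid time steps of each sequence, as the linked tutorial's title suggests. It could be applied to either direction or to the concatenated output. `masked_mean` is a hypothetical helper, and a small NumPy array stands in for the TensorFlow output. The mask keeps padded steps out of the average, so shorter sequences are not pulled toward zero.

```python
import numpy as np


def masked_mean(outputs, lengths):
    """Average outputs [batch, time, features] over the first lengths[i] steps of row i."""
    max_len = outputs.shape[1]
    # mask[i, t] is 1.0 for valid steps and 0.0 for padding
    mask = (np.arange(max_len)[None, :] < lengths[:, None]).astype(outputs.dtype)
    summed = (outputs * mask[:, :, None]).sum(axis=1)
    return summed / lengths[:, None]


outputs = np.arange(2 * 3 * 2, dtype=np.float64).reshape(2, 3, 2)
lengths = np.array([3, 2])  # the second sequence has one padded step
print(masked_mean(outputs, lengths))  # [[2. 3.] [7. 8.]]
```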

Type 3: Get the last hidden output of the forward and backward LSTMs.
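Before looking at the TensorFlow code, it helps to see which time index holds each direction's final output. In the TensorFlow example that follows, sequence_length=None, so every sequence has all 5 steps and outputs[0][:, -1, :] is the forward LSTM's final output. tf.nn.bidirectional_dynamic_rnn returns the backward output re-aligned to the input time steps, so the backward LSTM's final output is outputs[1][:, 0, :]. With variable-length batches, dynamic_rnn zero-fills outputs past each sequence's length, so index -1 would return zeros for the shorter sequences. The NumPy sketch below shows one way to pick the last valid forward step instead. `last_outputs` is a hypothetical helper, and random arrays stand in for outputs[0] and outputs[1].

```python
import numpy as np


def last_outputs(out_fw, out_bw, lengths):
    """Pick each direction's final output for every sequence in the batch.

    out_fw, out_bw: [batch, max_seq_length, hidden_size], with out_bw already
    aligned to the input time steps. lengths: [batch] true sequence lengths.
    """
    batch_idx = np.arange(out_fw.shape[0])
    # The forward LSTM finishes at the last valid step of each sequence
    last_fw = out_fw[batch_idx, lengths - 1]
    # The backward LSTM finishes at step 0, whatever the sequence length
    last_bw = out_bw[:, 0]
    return last_fw, last_bw


rng = np.random.default_rng(0)
out_fw = rng.normal(size=(3, 5, 10))
out_bw = rng.normal(size=(3, 5, 10))
lengths = np.array([5, 3, 4])
last_fw, last_bw = last_outputs(out_fw, out_bw, lengths)
print(last_fw.shape, last_bw.shape)  # (3, 10) (3, 10)
```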

Here is an example:

```python
# outputs[0] is the forward output and outputs[1] the backward output, each [3, 5, 10]
last_forward = outputs[0][:, -1, :]   # forward LSTM ends at the last time step: [3, 10]
last_backward = outputs[1][:, 0, :]   # backward LSTM ends at time step 0: [3, 10]
init = tf.global_variables_initializer()
init_local = tf.local_variables_initializer()
with tf.Session() as sess:
    sess.run([init, init_local])
    np.set_printoptions(precision=4, suppress=True)
    # Fetch all three tensors in a single run
    f, b, m = sess.run([last_forward, last_backward, outputs_merge])
    print("last forward shape=", f.shape)
    print(f)
    print("last backward shape=", b.shape)
    print(b)
    print("concatenate forward and backward")
    print(m)
```

Run this code and you will see: