Get LSTM Cell Weights and Regularize LSTM in TensorFlow – TensorFlow Tutorial

July 3, 2019

An LSTM has three gates (input, forget and output):

[Figure: equations of the three LSTM gates]

and an input transform:

[Figure: equation of the LSTM input transform]

To regularize an LSTM, we need the gate weights of each LSTM cell: Wxi, Whi, Wxf, Whf, Wxo, Who, Wxc and Whc.

How do we get these weights?

Step 1: Get all variables in LSTM

List All Trainable and Untrainable Variables in TensorFlow

List All Variables including Constant and Placeholder in TensorFlow

First, we use an LSTM in our model like this:

        with tf.name_scope('doc_word_encode'):
            outputs, state = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=tf.nn.rnn_cell.LSTMCell(self.hidden_size, forget_bias=1.0), #self.hidden_size = 100
                cell_bw=tf.nn.rnn_cell.LSTMCell(self.hidden_size, forget_bias=1.0),
                inputs=inputs,
                sequence_length=word_in_sen_len,
                dtype=tf.float32,
                scope='doc_word'
            )
            outputs = tf.concat(outputs, 2) # [batch_size, max_time, 200]

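The LSTM cells are created inside the doc_word_encode name scope, but their variables are named after the scope argument, doc_word, as the variable names printed below show.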

We can use code below to check all trainable variables.

# -*- coding: utf-8 -*-

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

np.set_printoptions(threshold=np.inf)

model_dataset = 'imdb/1557460934'
checkpoint_file = "../checkpoints/" + model_dataset + "/model-4100"

with tf.Session() as sess:
    # Rebuild the graph from the .meta file, then restore the trained values.
    # No initializer is needed: restore() assigns every saved variable.
    saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
    saver.restore(sess, checkpoint_file)

    # Print the name of every trainable variable.
    v = [n.name for n in tf.trainable_variables()]
    for vv in v:
        print(vv)

For the forward LSTM in the BiLSTM, we get:

[Figure: trainable variables of the forward LSTM cell]

The result shows that the LSTM cell has a single weight variable, kernel, rather than eight separate gate weight matrices. That one kernel therefore holds all eight gate weights.

Step 2: Print the shape of lstm kernel.

Get Tensor Variable by Tensor Name

We can also use the code below:

    # Run inside the tf.Session block above, after saver.restore(...)
    for vv in tf.trainable_variables():
        if 'doc_word/fw/lstm_cell/kernel:0' in vv.name:
            print(vv)

We get this kernel as:

<tf.Variable 'doc_word/fw/lstm_cell/kernel:0' shape=(300, 400) dtype=float32_ref>

Step 3: Why is the shape of the LSTM kernel (300, 400)?

For the forward LSTM in our BiLSTM, the dimension of xt is 200 and self.hidden_size = 100.

This fixes the first dimension of each weight matrix, but not the second:

Wxi is 200*?, Whi is 100*?, Wxf is 200*?, Whf is 100*?, Wxo is 200*?, Who is 100*?, Wxc is 200*?, Whc is 100*?

Check the TensorFlow source code.

We will find:

1. All weights are named kernel.

_BIAS_VARIABLE_NAME = "bias"
_WEIGHTS_VARIABLE_NAME = "kernel"

2. The kernel in BasicLSTMCell(LayerRNNCell) is built with shape [input_depth + h_depth, 4 * self._num_units]

    self._kernel = self.add_variable(
        _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + h_depth, 4 * self._num_units])
    self._bias = self.add_variable(
        _BIAS_VARIABLE_NAME,
        shape=[4 * self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype))

 

In our model:

input_depth = 200, h_depth = self._num_units = 100

so the shape of the kernel is (200 + 100, 4 * 100) = (300, 400)
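The arithmetic above can be checked with a small helper. This is only a sketch in plain Python, not TensorFlow code: lstm_kernel_shape is a hypothetical name, and it assumes a cell without projection, so that h_depth equals num_units.

```python
def lstm_kernel_shape(input_depth, num_units):
    """Shape of the fused LSTM kernel, assuming h_depth == num_units."""
    h_depth = num_units
    # Rows: x_t and h_{t-1} are concatenated.
    # Columns: four blocks of num_units each (i, j, f, o).
    return (input_depth + h_depth, 4 * num_units)


kernel_shape = lstm_kernel_shape(input_depth=200, num_units=100)
print(kernel_shape)  # (300, 400)
```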

How do we get i, j, f, o?

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(
        value=gate_inputs, num_or_size_splits=4, axis=one)

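Inside the cell, xt and ht-1 are concatenated into a [batch_size, 300] tensor, which is multiplied by the (300, 400) kernel before the bias is added. The resulting [batch_size, 400] gate_inputs tensor is then split along its second dimension into four pieces of size self._num_units: i, j, f and o. So the first 200 rows of the kernel hold the input weights (Wx*), the last 100 rows hold the recurrent weights (Wh*), and each block of 100 columns belongs to one of i, j, f, o, where j is the input transform (Wxc and Whc above).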
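As a rough sketch of how the single kernel could be cut back into the eight gate matrices, the NumPy code below slices a stand-in array of shape (300, 400). It assumes the layout suggested by the source snippets: rows ordered as xt then ht-1 (the concatenation order) and columns ordered as i, j, f, o. The random kernel, the helper name split_lstm_kernel and the dictionary keys are illustrative only; in the real model the array would be the value of 'doc_word/fw/lstm_cell/kernel:0'. Wxj and Whj correspond to Wxc and Whc in the notation of this article.

```python
import numpy as np

input_depth, num_units = 200, 100

# Stand-in for the restored kernel (assumption: random values for illustration).
rng = np.random.RandomState(0)
kernel = rng.randn(input_depth + num_units, 4 * num_units)


def split_lstm_kernel(kernel, input_depth, num_units):
    """Slice a fused [input_depth + num_units, 4 * num_units] kernel into
    per-gate matrices, following the i, j, f, o column order."""
    gates = {}
    for k, name in enumerate(["i", "j", "f", "o"]):
        block = kernel[:, k * num_units:(k + 1) * num_units]
        gates["Wx" + name] = block[:input_depth, :]  # multiplies x_t
        gates["Wh" + name] = block[input_depth:, :]  # multiplies h_{t-1}
    return gates


gates = split_lstm_kernel(kernel, input_depth, num_units)
for name in sorted(gates):
    print(name, gates[name].shape)

# Sanity check: the fused matmul matches the per-gate computation for i.
x_t = rng.randn(2, input_depth)
h_prev = rng.randn(2, num_units)
fused = np.concatenate([x_t, h_prev], axis=1).dot(kernel)
i_pre = x_t.dot(gates["Wxi"]) + h_prev.dot(gates["Whi"])
assert np.allclose(fused[:, :num_units], i_pre)

# L2 penalty per matrix, using the sum(w ** 2) / 2 convention of tf.nn.l2_loss.
l2_penalty = {name: float(np.sum(w ** 2) / 2.0) for name, w in gates.items()}
```

With the slices separated like this, a different regularization strength could be chosen for each gate. Wiring the penalties into a TensorFlow training loss is not shown here.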
