Build a Custom GRU Network Using TensorFlow: A Step Guide – TensorFlow Tutorial

By | July 20, 2020

In a previous tutorial, we built a custom LSTM network with TensorFlow:

Build Your Own LSTM Model Using TensorFlow: Steps to Create a Customized LSTM

A GRU network differs from an LSTM in a few ways. For details, see:

Understand GRU (Gated Recurrent Unit): Difference Between GRU and LSTM

In this tutorial, we will show how to build a custom GRU network with TensorFlow. The process is very similar to creating a custom LSTM network.

Formula of GRU

A GRU cell computes its hidden state with the following formulas:

The formula of GRU
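The formulas implemented by the code in step 3 can be written as follows. This is our reconstruction from that code, using its parameter names (the original figure may use different symbols); σ is the sigmoid function and ⊙ is element-wise multiplication:

$$
\begin{aligned}
r_t &= \sigma\left(x_t W_i + h_{t-1} U_i + b_i\right) && \text{(reset gate)} \\
z_t &= \sigma\left(x_t W_f + h_{t-1} U_f + b_f\right) && \text{(update gate)} \\
n_t &= \tanh\left(x_t W_c + (r_t \odot h_{t-1})\, U_c + b_c\right) && \text{(new memory)} \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1} && \text{(hidden state)}
\end{aligned}
$$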

Build a GRU network using TensorFlow

We could use tf.nn.rnn_cell.GRUCell() to create a GRU network; in this tutorial, however, we will build our own GRU cell with TensorFlow.

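1. Create a Python file

We will put our custom GRU in its own Python file. The usage example at the end of this tutorial imports it with import gru, so the file should be named gru.py.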

2. Import libraries

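# This tutorial uses the TensorFlow 1.x graph API (tf.variable_scope, tf.random_normal).
# Under TensorFlow 2.x, use: import tensorflow.compat.v1 as tf; tf.disable_v2_behavior()
import tensorflow as tf
import numpy as np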

3. The full code of the custom GRU

Here is the full code of the GRU class.

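class GRU():
    """Custom GRU layer.

    inputs: batch_size * time_step * dim (dim must equal emb_dim)
    sequence_length: a list with the valid length of each sequence in the batch;
        only used when revers = True
    revers: if True, run the GRU over the time-reversed inputs (backward direction of a BiGRU)
    """
    def __init__(self, inputs, emb_dim, hidden_dim, sequence_length = None, revers = False):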
        self.emb_dim = emb_dim
        self.hidden_dim = hidden_dim
        self.time_step = tf.shape(inputs)[1]        
        self.batch_size = tf.shape(inputs)[0]
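        if revers:
            if sequence_length is not None:
                # reverse only the valid (unpadded) steps of each sequence
                inputs = tf.reverse_sequence(inputs, seq_lengths=sequence_length, seq_axis = 1, batch_axis = 0)
            else:
                # no lengths given: reverse the whole time axis
                inputs = tf.reverse(inputs, axis = [1])
        # time-major: time_step * batch_size * emb_dim
        self.inputs = tf.transpose(inputs, perm=[1, 0, 2])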
        with tf.variable_scope('gru_init'):
            self.g_recurrent_unit = self.create_recurrent_unit()  # maps (x_t, h_tm1) to h_t
            self.g_output_unit = self.create_output_unit()  # maps h_t to o_t (identity here)
        # Initial states
        self.h0 = tf.zeros([self.batch_size, self.hidden_dim])
        gen_o = tf.TensorArray(dtype=tf.float32, size=self.time_step ,
                                             dynamic_size=False, infer_shape=True)
        def _g_recurrence(i, x_t, h_tm1, gen_o):
            h_t = self.g_recurrent_unit(x_t, h_tm1)
            o_t = self.g_output_unit(h_t)
            gen_o = gen_o.write(i, o_t)  # store the output of step i
            # index of the next input, clamped so the lookup stays in range on the last step
            i_next = tf.where(tf.less(i, self.time_step-1), i+1, self.time_step-1)
            x_t_next = tf.nn.embedding_lookup(self.inputs, i_next)  # batch_size x emb_dim
            return i+1, x_t_next, h_t, gen_o
        i_l, _, h_l_t, self.gen_o = tf.while_loop(
            cond=lambda i, _1, _2, _3: i < self.time_step,
            body=_g_recurrence,
            loop_vars=(tf.constant(0, dtype=tf.int32),
                       tf.nn.embedding_lookup(self.inputs, 0),  # x_0: batch_size x emb_dim
                       self.h0, gen_o))

        self.gen_o = self.gen_o.stack()  # time_step x batch_size x hidden_dim
        self.outputs = tf.transpose(self.gen_o, perm=[1, 0, 2])  # batch_size x time_step x hidden_dim
    def init_matrix(self, shape):
        return tf.random_normal(shape, stddev=0.1)
        #return tf.random_uniform(shape, -0.01, 0.01)

    def init_vector(self, shape):
        return tf.zeros(shape)
        #return tf.random_uniform(shape, -0.01, 0.01)
    def create_recurrent_unit(self):
        # Weights (W for the input x_t, U for the previous hidden state) and biases
        # reset gate r
        self.Wi = tf.Variable(self.init_matrix([self.emb_dim, self.hidden_dim]), name = 'input_gate_wi')
        self.Ui = tf.Variable(self.init_matrix([self.hidden_dim, self.hidden_dim]), name = 'input_gate_ui')
        self.bi = tf.Variable(self.init_vector([self.hidden_dim]), name = 'input_gate_bias')
        # update gate z
        self.Wf = tf.Variable(self.init_matrix([self.emb_dim, self.hidden_dim]), name = 'forget_gate_wf')
        self.Uf = tf.Variable(self.init_matrix([self.hidden_dim, self.hidden_dim]), name = 'forget_gate_uf')
        self.bf = tf.Variable(self.init_vector([self.hidden_dim]), name = 'forget_gate_bias')
        # new memory (candidate hidden state) n
        self.Wc = tf.Variable(self.init_matrix([self.emb_dim, self.hidden_dim]), name = 'control_gate_wc')
        self.Uc = tf.Variable(self.init_matrix([self.hidden_dim, self.hidden_dim]), name = 'control_gate_uc')
        self.bc = tf.Variable(self.init_vector([self.hidden_dim]), name = 'control_gate_bias')
        def unit(x, previous_hidden_state):
            # Reset gate
            r = tf.sigmoid(
                tf.matmul(x, self.Wi) +
                tf.matmul(previous_hidden_state, self.Ui) + self.bi)

            # Update gate
            z = tf.sigmoid(
                tf.matmul(x, self.Wf) +
                tf.matmul(previous_hidden_state, self.Uf) + self.bf)

            # New memory (candidate hidden state)
            n = tf.nn.tanh(
                tf.matmul(x, self.Wc) +
                tf.matmul(r * previous_hidden_state, self.Uc) + self.bc)

            # Current hidden state: interpolate between the new memory and the previous state
            current_hidden_state = (1 - z) * n + z * previous_hidden_state
            return current_hidden_state
        return unit
    def create_output_unit(self):
        # Output unit: returns the hidden state unchanged, so o_t = h_t
        def unit(hidden_state):
            return hidden_state

        return unit
    def output(self):
        return self.outputs
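To check the TensorFlow graph numerically, it can help to have a plain NumPy version of the same forward pass. The sketch below is our own illustration, not part of the original tutorial: gru_forward and init_params are hypothetical helper names, the parameters live in a dict keyed by the class attribute names (Wi, Ui, bi, ...), and only the revers = False path is covered.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def gru_forward(inputs, params):
    """NumPy mirror of GRU.__init__ (revers=False).

    inputs: batch_size x time_step x emb_dim
    returns: batch_size x time_step x hidden_dim
    """
    x = np.transpose(inputs, (1, 0, 2))  # time-major, like self.inputs
    hidden_dim = params["Ui"].shape[0]
    h = np.zeros((x.shape[1], hidden_dim))  # like self.h0
    outputs = []
    for x_t in x:  # plays the role of tf.while_loop
        r = sigmoid(x_t @ params["Wi"] + h @ params["Ui"] + params["bi"])
        z = sigmoid(x_t @ params["Wf"] + h @ params["Uf"] + params["bf"])
        n = np.tanh(x_t @ params["Wc"] + (r * h) @ params["Uc"] + params["bc"])
        h = (1 - z) * n + z * h
        outputs.append(h)  # identity output unit: o_t = h_t
    # stack to time_step x batch_size x hidden_dim, then back to batch-major
    return np.transpose(np.stack(outputs), (1, 0, 2))


def init_params(emb_dim, hidden_dim, seed=0):
    """Hypothetical initializer matching init_matrix (stddev 0.1) and init_vector (zeros)."""
    rng = np.random.default_rng(seed)
    p = {}
    for g in ("i", "f", "c"):
        p["W" + g] = rng.normal(0.0, 0.1, (emb_dim, hidden_dim))
        p["U" + g] = rng.normal(0.0, 0.1, (hidden_dim, hidden_dim))
        p["b" + g] = np.zeros(hidden_dim)
    return p


params = init_params(emb_dim=4, hidden_dim=3)
x = np.random.default_rng(1).normal(size=(2, 5, 4))  # batch 2, time 5, dim 4
out = gru_forward(x, params)
print(out.shape)  # (2, 5, 3)
```

Feeding the same weights into the TensorFlow class (for example by assigning them to the tf.Variable objects) should give outputs that match gru_forward up to float32 precision.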

3.1 create_recurrent_unit()

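This function creates the weights and biases of the GRU and uses them to build the reset gate, the update gate and the new memory (candidate hidden state). The code keeps LSTM-style variable names: the reset gate uses the input_gate_* variables (Wi, Ui, bi) and the update gate uses the forget_gate_* variables (Wf, Uf, bf).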
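The roles of the three parameter groups are easier to see in a single step. Below is a NumPy sketch of what unit() computes for one time step; gru_step is a hypothetical name, and the parameters are passed as a dict keyed by the attribute names used in create_recurrent_unit().

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def gru_step(x, h_prev, p):
    """One GRU step in NumPy, mirroring unit(x, previous_hidden_state)."""
    r = sigmoid(x @ p["Wi"] + h_prev @ p["Ui"] + p["bi"])        # reset gate
    z = sigmoid(x @ p["Wf"] + h_prev @ p["Uf"] + p["bf"])        # update gate
    n = np.tanh(x @ p["Wc"] + (r * h_prev) @ p["Uc"] + p["bc"])  # new memory
    return (1 - z) * n + z * h_prev


emb_dim, hidden_dim = 4, 3
# All-zero parameters, just to make the gate values easy to follow by hand.
p = {}
for g in ("i", "f", "c"):
    p["W" + g] = np.zeros((emb_dim, hidden_dim))
    p["U" + g] = np.zeros((hidden_dim, hidden_dim))
    p["b" + g] = np.zeros(hidden_dim)

x = np.ones((2, emb_dim))               # batch_size x emb_dim
h_prev = np.full((2, hidden_dim), 0.8)  # batch_size x hidden_dim

h = gru_step(x, h_prev, p)
# r = z = sigmoid(0) = 0.5 and n = tanh(0) = 0,
# so h = (1 - 0.5) * 0 + 0.5 * 0.8 = 0.4 in every position.
print(h)

# A large update-gate bias pushes z toward 1, so the previous state is carried over.
p["bf"] = np.full(hidden_dim, 20.0)
h_kept = gru_step(x, h_prev, p)
print(h_kept)  # approximately 0.8 everywhere
```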

3.2 output()

This function returns the output of the GRU. Note that the output has shape batch_size * time_step * hidden_dim: one hidden state per time step.

3.3 inputs

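The GRU expects inputs of shape batch_size * time_step * dim, where dim equals emb_dim. The constructor transposes them to time-major order (time_step * batch_size * dim) before the recurrence and transposes the outputs back at the end.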
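If your sequences have different lengths, one common way to get a batch_size * time_step * dim input (an assumption here; the tutorial does not show its data pipeline) is to zero-pad them to the longest length and keep the true lengths, which is the sequence_length list the constructor accepts. pad_batch is a hypothetical helper. Note that the forward pass in the class still runs over the padded steps; sequence_length is only used to reverse the inputs when revers = True.

```python
import numpy as np


def pad_batch(sequences, dtype=np.float32):
    """Stack variable-length sequences (each time_i x emb_dim) into a zero-padded
    batch_size x time_step x emb_dim array, and return their true lengths."""
    lengths = [len(s) for s in sequences]
    time_step = max(lengths)
    emb_dim = sequences[0].shape[1]
    batch = np.zeros((len(sequences), time_step, emb_dim), dtype=dtype)
    for b, s in enumerate(sequences):
        batch[b, : len(s)] = s  # valid steps first, zeros after
    return batch, lengths


seqs = [np.ones((3, 2)), np.full((5, 2), 2.0)]
inputs, sequence_length = pad_batch(seqs)
print(inputs.shape)      # (2, 5, 2)
print(sequence_length)   # [3, 5]
```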

3.4 Remove reset gate

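If you plan to remove the reset gate, modify the create_recurrent_unit() function: delete these lines

            r = tf.sigmoid(
                tf.matmul(x, self.Wi) +
                tf.matmul(previous_hidden_state, self.Ui) + self.bi)

together with the Wi, Ui and bi variables they use, and use previous_hidden_state instead of r * previous_hidden_state when computing n.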
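Removing the reset gate is equivalent to fixing r = 1, so the new memory uses previous_hidden_state directly. Here is a NumPy sketch of the resulting step (gru_step_no_reset is a hypothetical name; parameters are keyed by the attribute names in create_recurrent_unit()), checked against the full step with its reset gate saturated at 1.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def gru_step_no_reset(x, h_prev, p):
    """GRU step without a reset gate (same as fixing r = 1)."""
    z = sigmoid(x @ p["Wf"] + h_prev @ p["Uf"] + p["bf"])  # update gate
    n = np.tanh(x @ p["Wc"] + h_prev @ p["Uc"] + p["bc"])  # previous state used directly
    return (1 - z) * n + z * h_prev


def gru_step(x, h_prev, p):
    """Full GRU step, used here only for comparison."""
    r = sigmoid(x @ p["Wi"] + h_prev @ p["Ui"] + p["bi"])
    z = sigmoid(x @ p["Wf"] + h_prev @ p["Uf"] + p["bf"])
    n = np.tanh(x @ p["Wc"] + (r * h_prev) @ p["Uc"] + p["bc"])
    return (1 - z) * n + z * h_prev


emb_dim, hidden_dim = 4, 3
rng = np.random.default_rng(0)
p = {  # no Wi, Ui, bi: two parameter groups instead of three
    "Wf": rng.normal(0.0, 0.5, (emb_dim, hidden_dim)),
    "Uf": rng.normal(0.0, 0.5, (hidden_dim, hidden_dim)),
    "bf": np.zeros(hidden_dim),
    "Wc": rng.normal(0.0, 0.5, (emb_dim, hidden_dim)),
    "Uc": rng.normal(0.0, 0.5, (hidden_dim, hidden_dim)),
    "bc": np.zeros(hidden_dim),
}
x = rng.normal(size=(2, emb_dim))
h_prev = rng.normal(size=(2, hidden_dim))
h = gru_step_no_reset(x, h_prev, p)

# Saturate the full cell's reset gate at 1 (zero Wi and Ui, large bias bi).
p_full = dict(p, Wi=np.zeros((emb_dim, hidden_dim)),
              Ui=np.zeros((hidden_dim, hidden_dim)),
              bi=np.full(hidden_dim, 50.0))
print(np.allclose(gru_step(x, h_prev, p_full), h))  # True
```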

3.5 Create a BiGRU

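If you plan to create a BiGRU (as with a BiLSTM), build two GRU objects on the same inputs: a forward one, and a backward one with revers = True. Pass sequence_length to the backward GRU so that only the valid steps of each padded sequence are reversed; without it, the whole time axis is reversed.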

For a complete example built the same way with LSTM, see:

Build a Custom BiLSTM Model Using TensorFlow: A Step Guide
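The usual BiRNN recipe (the one tf.nn.bidirectional_dynamic_rnn follows) is to reverse the backward outputs back with the same lengths and then concatenate both directions along the last axis. The GRU class above returns the backward outputs in reversed order, so this re-reversal is left to the caller. The NumPy sketch below shows only the reversal and concatenation; reverse_sequence and combine_bidirectional are hypothetical helper names, and in TensorFlow you would use tf.reverse_sequence and tf.concat instead.

```python
import numpy as np


def reverse_sequence(x, lengths):
    """NumPy version of tf.reverse_sequence(x, lengths, seq_axis=1, batch_axis=0):
    reverse the first lengths[b] steps of each row and leave the padding in place."""
    out = x.copy()
    for b, n in enumerate(lengths):
        out[b, :n] = x[b, :n][::-1]
    return out


def combine_bidirectional(fw_out, bw_out, lengths):
    """Re-reverse the backward outputs so step t of both directions lines up,
    then concatenate: batch_size x time_step x (2 * hidden_dim)."""
    bw_aligned = reverse_sequence(bw_out, lengths)
    return np.concatenate([fw_out, bw_aligned], axis=-1)


x = np.array([[1, 2, 3, 0], [4, 5, 0, 0]], dtype=float)[..., None]  # batch 2, time 4, dim 1
lengths = [3, 2]
print(reverse_sequence(x, lengths)[..., 0])
# [[3. 2. 1. 0.]
#  [5. 4. 0. 0.]]

fw_out = np.zeros((2, 4, 3))  # stand-in for the forward GRU output
bw_out = np.ones((2, 4, 3))   # stand-in for the revers=True GRU output
combined = combine_bidirectional(fw_out, bw_out, lengths)
print(combined.shape)  # (2, 4, 6)
```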

To use this custom GRU in your model:

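import tensorflow as tf
import gru
# x_shape: float32 tensor of shape batch_size x time_step x input_size
custom_gru = gru.GRU(inputs = x_shape, emb_dim = input_size, hidden_dim = hidden_dim)
output = custom_gru.output() # batch_size x 28 x 50 here, i.e. batch_size x time_step x hidden_dim
# average the outputs over the time axis: batch_size x hidden_dim
output_y = tf.reduce_mean(output, 1)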
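tf.reduce_mean(output, 1) averages over all time_step positions, including any zero-padded steps (which this GRU still processes). If your batch is padded, a length-aware mean may be preferable. This is an optional variation, not part of the original tutorial; masked_mean is a hypothetical helper, shown in NumPy.

```python
import numpy as np


def masked_mean(output, lengths):
    """Average batch_size x time_step x hidden_dim outputs over the valid steps only."""
    time_step = output.shape[1]
    # mask[b, t] = 1.0 if t < lengths[b] else 0.0
    mask = (np.arange(time_step)[None, :] < np.asarray(lengths)[:, None]).astype(output.dtype)
    summed = (output * mask[..., None]).sum(axis=1)
    return summed / mask.sum(axis=1, keepdims=True)


output = np.array([[[1.0], [3.0], [0.0]],
                   [[2.0], [4.0], [6.0]]])  # batch 2, time 3, hidden 1
mean_all = output.mean(axis=1)              # like tf.reduce_mean(output, 1): 4/3 and 4.0
mean_valid = masked_mean(output, [2, 3])    # padded step of row 0 ignored: 2.0 and 4.0
print(mean_all)
print(mean_valid)
```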
