# Use Custom GRU Network to Classify MNIST Handwritten Digits – Deep Learning Tutorial

By | July 20, 2020

In this tutorial, we will use our custom GRU network to classify MNIST handwritten digits in order to evaluate how well the custom GRU works.

To learn how to build the custom GRU, you can read this tutorial:

Build a Custom GRU Network Using TensorFlow: A Step Guide – TensorFlow Tutorial

## Import libraries

```python
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
import numpy as np
import random
import gru
```

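Note that the `gru` module contains our custom GRU network. The code uses the TensorFlow 1.x API (`tf.placeholder`, `tf.Session` and the `tensorflow.examples.tutorials.mnist` reader), so it needs TensorFlow 1.x to run.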

```python
mnist = input_data.read_data_sets(os.getcwd() + "/MNIST-data/", one_hot=True)
```

## Set some parameters

```python
# Variable
learning_rate = 1e-3
batch_size = 64
hidden_dim = 50
input_size = 28  # each time step reads one 28-pixel image row
time_step = 28  # 28 rows, i.e. 28 GRU time steps
total_steps = 4000
category_num = 10
steps_per_validate = 15
```

In this experiment, we set the hidden dimension of the GRU to 50, the batch size to 64 and the learning rate to 0.001.
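As a rough sanity check on model size, the sketch below counts trainable parameters for these settings. It assumes a standard GRU in which the update gate, the reset gate and the candidate state each have an input weight matrix, a recurrent weight matrix and one bias vector; the custom `gru` module may organize its weights differently, so treat the result as an estimate. The helper names are hypothetical.

```python
def gru_param_count(input_size: int, hidden_dim: int) -> int:
    # Each of the 3 blocks (update gate, reset gate, candidate state) has
    # W: input_size x hidden_dim, U: hidden_dim x hidden_dim, b: hidden_dim
    per_block = input_size * hidden_dim + hidden_dim * hidden_dim + hidden_dim
    return 3 * per_block


def dense_param_count(in_dim: int, out_dim: int) -> int:
    # Weight matrix plus one bias per output unit
    return in_dim * out_dim + out_dim


gru_params = gru_param_count(input_size=28, hidden_dim=50)
out_params = dense_param_count(in_dim=50, out_dim=10)
print(gru_params, out_params, gru_params + out_params)  # 11850 510 12360
```

With only about 12,000 weights, the model is small compared with the 784-input fully connected networks often used for MNIST.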

## Set model inputs

```python
# Initial Inputs
x = tf.placeholder(tf.float32, [None, 784])
y_label = tf.placeholder(tf.float32, [None, 10])
# batch_size x 28 x 28
x_shape = tf.reshape(x, [-1, time_step, input_size])
```

`x_shape` is the input to the model. Its shape is batch_size × time_step × input_size, for example 64 × 28 × 28: each image is read as a sequence of 28 rows of 28 pixels.
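The NumPy sketch below mirrors the `tf.reshape` call on a stand-in batch (arbitrary values, not real MNIST data). Both NumPy and TensorFlow reshape in row-major order, so time step `t` is exactly row `t` of the image.

```python
import numpy as np

batch_size, time_step, input_size = 64, 28, 28

# Stand-in for a batch of flattened images; the values are arbitrary
flat = np.arange(batch_size * 784, dtype=np.float32).reshape(batch_size, 784)

# Same operation as tf.reshape(x, [-1, time_step, input_size])
seq = flat.reshape(-1, time_step, input_size)
print(seq.shape)  # (64, 28, 28)

# Time step t holds pixels t*28 .. t*28+27 of the flattened image
t = 5
assert np.array_equal(seq[0, t], flat[0, t * input_size:(t + 1) * input_size])
```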

## Get the output of the GRU network

```python
custom_gru = gru.GRU(inputs=x_shape, emb_dim=input_size, hidden_dim=hidden_dim)
output = custom_gru.output()  # batch_size x 28 x 50
# Average the hidden states over the 28 time steps: batch_size x 50
output_y = tf.reduce_mean(output, 1)
```
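The internals of `gru.GRU` are described in the linked tutorial and are not reproduced here. As an illustration only, the NumPy sketch below runs the standard GRU equations over a small random batch and then averages the hidden states over time, which is what `tf.reduce_mean(output, 1)` does. It assumes the update convention h = (1 − z)·h + z·h̃ (some implementations swap the roles of z and 1 − z), uses random untrained weights, and the names `gru_forward` and `params` are hypothetical.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def gru_forward(inputs, params):
    """Run a standard GRU over inputs of shape (batch, time, emb_dim).

    Returns every hidden state, shape (batch, time, hidden_dim).
    """
    Wz, Uz, bz, Wr, Ur, br, Wh, Uh, bh = params
    batch, steps, _ = inputs.shape
    hidden_dim = Uz.shape[0]
    h = np.zeros((batch, hidden_dim))  # initial state h_0 = 0
    outputs = []
    for t in range(steps):
        x_t = inputs[:, t, :]
        z = sigmoid(x_t @ Wz + h @ Uz + bz)              # update gate
        r = sigmoid(x_t @ Wr + h @ Ur + br)              # reset gate
        h_tilde = np.tanh(x_t @ Wh + (r * h) @ Uh + bh)  # candidate state
        h = (1.0 - z) * h + z * h_tilde                  # new hidden state
        outputs.append(h)
    return np.stack(outputs, axis=1)


rng = np.random.default_rng(0)
emb_dim, hidden_dim = 28, 50
params = []
for _ in range(3):  # update gate, reset gate, candidate state
    params += [rng.normal(0, 0.1, (emb_dim, hidden_dim)),
               rng.normal(0, 0.1, (hidden_dim, hidden_dim)),
               np.zeros(hidden_dim)]

x_shape = rng.random((4, 28, emb_dim))
output = gru_forward(x_shape, params)  # (4, 28, 50)
output_y = output.mean(axis=1)         # (4, 50), like tf.reduce_mean(output, 1)
print(output.shape, output_y.shape)
```

Averaging over all 28 states lets every row of the image contribute to the classification; using only the last hidden state would be a common alternative.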

## Get the output of our model

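```python
# Output Layer: map the 50-dim averaged GRU output to 10 class logits
# truncated_normal takes (shape, mean, stddev): zero-centered, stddev 0.01
w = tf.Variable(tf.truncated_normal([hidden_dim, category_num], mean=0.0, stddev=0.01), dtype=tf.float32)
b = tf.Variable(tf.random_uniform([category_num], -0.01, 0.01), dtype=tf.float32)
y = tf.matmul(output_y, w) + b
```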
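The signature is `tf.truncated_normal(shape, mean, stddev)`, so passing `-0.01, 0.01` positionally means a mean of -0.01 and a standard deviation of 0.01, not a range; small zero-centered weights need `mean=0.0, stddev=0.01`. TensorFlow documents that values more than two standard deviations from the mean are dropped and re-picked. The NumPy sketch below imitates that rule to show the resulting weight range; it is an illustration, not TensorFlow's implementation, and `truncated_normal` here is a hypothetical helper.

```python
import numpy as np


def truncated_normal(shape, mean=0.0, stddev=1.0, rng=None):
    """Sample a normal distribution, re-drawing values more than 2 stddevs from the mean."""
    rng = np.random.default_rng() if rng is None else rng
    samples = rng.normal(mean, stddev, size=shape)
    outside = np.abs(samples - mean) > 2 * stddev
    while outside.any():
        # Re-pick only the out-of-range entries
        samples[outside] = rng.normal(mean, stddev, size=outside.sum())
        outside = np.abs(samples - mean) > 2 * stddev
    return samples


rng = np.random.default_rng(0)
w = truncated_normal((50, 10), mean=0.0, stddev=0.01, rng=rng)
print(w.shape, np.abs(w).max() <= 0.02)  # (50, 10) True
```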

## Calculate the loss and prediction

```python
# Loss: one softmax cross-entropy value per example in the batch
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_label, logits=y)
```
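`tf.nn.softmax_cross_entropy_with_logits` returns one loss per example, so it has to be averaged (for example with `tf.reduce_mean`) before an optimizer minimizes it. The NumPy sketch below reproduces what the op computes on two toy rows, using the usual max-subtraction trick for a numerically stable log-softmax.

```python
import numpy as np


def softmax_cross_entropy_with_logits(labels, logits):
    # Numerically stable log-softmax: subtract the row max before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_softmax = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # One value per example, like the TensorFlow op (no averaging)
    return -(labels * log_softmax).sum(axis=1)


logits = np.array([[2.0, 0.0, 0.0],
                   [0.0, 0.0, 0.0]])
labels = np.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0]])
ce = softmax_cross_entropy_with_logits(labels, logits)
print(ce)         # per-example losses, shape (2,)
print(ce.mean())  # the scalar a training op would minimize
```

The second row has equal logits, so the model assigns probability 1/3 to the true class and the loss is ln 3.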

```python
# Prediction: a sample is correct when its largest logit matches its label
correct_prediction = tf.equal(tf.argmax(y, axis=1), tf.argmax(y_label, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
```
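The accuracy computation compares the argmax of the logits with the argmax of the one-hot labels. Because softmax preserves the order of the logits, taking the argmax of the raw logits gives the same prediction as taking it after softmax. A NumPy equivalent on three toy rows:

```python
import numpy as np


def accuracy(logits, one_hot_labels):
    # Predicted class = index of the largest logit; true class = index of the 1
    correct = np.argmax(logits, axis=1) == np.argmax(one_hot_labels, axis=1)
    return correct.astype(np.float32).mean()


logits = np.array([[0.1, 2.0, -1.0],
                   [3.0, 0.5, 0.2],
                   [0.0, 0.3, 0.1]])
labels = np.eye(3)[[1, 0, 2]]  # one-hot rows for classes 1, 0, 2
print(accuracy(logits, labels))  # 2 of 3 correct
```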

## Start training

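```python
# Training op: minimize the mean cross-entropy over the batch.
# Assumption: the original listing uses `train` without defining it;
# Adam with the learning_rate set above is used here.
loss = tf.reduce_mean(cross_entropy)
train = tf.train.AdamOptimizer(learning_rate).minimize(loss)

# Create the initializer after the optimizer so its slot variables are included
init = tf.global_variables_initializer()

try:
    with tf.Session() as sess:
        sess.run(init)

        test_acc = 0.
        dev_acc = 0.
        better_acc = 0.0

        # Training loop
        for step in range(total_steps + 1):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            _, acc = sess.run([train, accuracy], feed_dict={x: batch_x, y_label: batch_y})
            print("train step=" + str(step) + " accuracy = " + str(acc))

            # Validation accuracy every steps_per_validate steps
            if step % steps_per_validate == 0:
                dev_x, dev_y = mnist.validation.images, mnist.validation.labels
                dev_acc = sess.run(accuracy, feed_dict={x: dev_x, y_label: dev_y})
                print("dev step=" + str(step) + " accuracy = " + str(dev_acc))

                # Evaluate on the test set only when validation accuracy improves
                if better_acc < dev_acc:
                    test_x, test_y = mnist.test.images, mnist.test.labels
                    test_acc = sess.run(accuracy, feed_dict={x: test_x, y_label: test_y})
                    print("test step=" + str(step) + " accuracy = " + str(test_acc))
                    better_acc = dev_acc

except Exception as e:
    print(e)
```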

In this experiment, we train for `total_steps` = 4,000 steps, each on a batch of 64 images, and then look at the results.
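To put the training schedule in perspective, the arithmetic below counts the steps, images and validation passes that the loop performs. It assumes the default split of `input_data.read_data_sets`, which keeps 55,000 images in `mnist.train` and 5,000 in `mnist.validation`.

```python
total_steps = 4000
batch_size = 64
steps_per_validate = 15
train_size = 55000  # assumed: size of mnist.train with the default validation split

steps_run = total_steps + 1  # range(total_steps + 1) also runs step 0
images_seen = steps_run * batch_size
epochs = images_seen / train_size
validations = sum(1 for step in range(total_steps + 1) if step % steps_per_validate == 0)

print(steps_run, images_seen, round(epochs, 2), validations)  # 4001 256064 4.66 267
```

So the model sees each training image between four and five times on average, and the validation set is evaluated 267 times.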

Running the training code, you will get a result like this: