Use Your Own Customized LSTM to Classify MNIST Handwritten Digits – TensorFlow Tutorial

June 26, 2020

We have created a customized LSTM model (lstm.py) using TensorFlow. Here is the tutorial:

Build Your Own LSTM Model Using TensorFlow: Steps to Create a Customized LSTM

In this tutorial, we will use this customized LSTM model to train on the MNIST dataset and classify handwritten digits. To understand the MNIST dataset, you can read:

Understand and Read TensorFlow MNIST Dataset for Beginners

Preliminary

First, import the libraries we need.

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
import numpy as np
import random
import lstm

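Here, import lstm loads our customized LSTM model from lstm.py, so that file must be importable, for example by placing it in the same directory as this script.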

Load MNIST data

mnist = input_data.read_data_sets(os.getcwd() + "/MNIST-data/", one_hot=True)

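We have saved our MNIST data in the MNIST-data folder under the current working directory. Because one_hot=True, each label is loaded as a 10-dimensional one-hot vector.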
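To make the one_hot=True label format concrete, here is a minimal NumPy sketch of the encoding. The helper name to_one_hot is made up for illustration and is not part of TensorFlow or of this tutorial's code.

```python
import numpy as np

def to_one_hot(labels, num_classes=10):
    """Turn integer class labels into one-hot rows (illustrative helper)."""
    labels = np.asarray(labels)
    encoded = np.zeros((labels.size, num_classes), dtype=np.float32)
    # Put a single 1.0 in each row at the column given by the label.
    encoded[np.arange(labels.size), labels] = 1.0
    return encoded

example = to_one_hot([3, 0, 9])
print(example.shape)           # (3, 10)
print(example.argmax(axis=1))  # [3 0 9]
```

Taking argmax along axis 1 recovers the original digit, which is the same trick the accuracy computation later in this tutorial relies on.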

Set some hyperparameters

We should set some hyperparameters first.

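# learning rate for the Adam optimizer
learning_rate = 1e-3

hidden_dim = 50          # size of the LSTM hidden state
input_size = 28          # features per time step (one image row of 28 pixels)
time_step = 28           # sequence length: 28 rows, so 28 lstm cells
total_steps = 1000       # training runs for steps 0 to total_steps
category_num = 10        # digit classes 0-9
steps_per_validate = 15  # run validation every 15 training steps
steps_per_test = 15      # not used by the training loop in this tutorial
batch_size = 64          # examples per training batch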
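It can help to check what these numbers imply for the length of a run. The sketch below is plain arithmetic on the values above. It assumes the training loop iterates over range(total_steps + 1) and that mnist.train holds 55000 examples, which is the loader's default split; adjust train_examples if your split differs.

```python
total_steps = 1000
steps_per_validate = 15
batch_size = 64
train_examples = 55000  # assumption: default train split of the MNIST loader

# The training loop runs steps 0..total_steps inclusive.
train_iterations = total_steps + 1
validation_runs = len([s for s in range(train_iterations)
                       if s % steps_per_validate == 0])
examples_seen = train_iterations * batch_size

print(validation_runs)                           # 67
print(examples_seen)                             # 64064
print(round(examples_seen / train_examples, 2))  # 1.16
```

So with these settings the model sees a little more than one pass over the training set and is validated 67 times.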

Define model input

We should create some TensorFlow placeholders.

x = tf.placeholder(tf.float32, [None, 784])
y_label = tf.placeholder(tf.float32, [None, 10])
x_shape = tf.reshape(x, [-1, time_step, input_size]) # batch_size * 28 * 28

batch_size_train = tf.placeholder(tf.int32, [])
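The reshape above turns each flattened 784-pixel image into a sequence of 28 rows with 28 pixels each, so the LSTM reads one image row per time step. A small NumPy sketch of the same reshape, using pixel indices as stand-in data instead of real images:

```python
import numpy as np

time_step, input_size = 28, 28

# Stand-in for a batch of two flattened images; values are just pixel indices.
flat_batch = np.arange(2 * 784, dtype=np.float32).reshape(2, 784)

# Same reshape as tf.reshape(x, [-1, time_step, input_size]).
sequences = flat_batch.reshape(-1, time_step, input_size)

print(sequences.shape)      # (2, 28, 28)
print(sequences[0, 1, :3])  # [28. 29. 30.]
```

The second print shows that time step 1 of the first image starts at flat pixel 28, which is the beginning of the second image row.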

Use the customized LSTM model to classify handwritten digits

Get LSTM output

custom_lstm = lstm.LSTM(x_shape, emb_dim=input_size, hidden_dim=hidden_dim, sequence_length=time_step)
output = custom_lstm.gen_o
output_y = tf.reduce_mean(output, 1)

In this model, we average the outputs of all LSTM cells over the time steps. output_y is the final output of the LSTM.
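The averaging step can be sketched in NumPy as follows. This assumes that the LSTM output has shape [batch, time_step, hidden_dim], which is what the reduce_mean over axis 1 and the later matrix multiplication suggest; the random array is only a stand-in for the real LSTM output.

```python
import numpy as np

batch, time_step, hidden_dim = 2, 28, 50

# Random stand-in for the per-step LSTM outputs,
# assumed shape: [batch, time_step, hidden_dim].
rng = np.random.default_rng(0)
output = rng.normal(size=(batch, time_step, hidden_dim))

# Equivalent of tf.reduce_mean(output, 1): average over the time axis.
output_y = output.mean(axis=1)

print(output_y.shape)  # (2, 50)
```

Averaging gives every time step an equal say in the final representation, rather than relying only on the last step's output.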

Then we can use output_y to train the model and get the prediction.

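# tf.truncated_normal takes (shape, mean, stddev), not (shape, min, max)
w = tf.Variable(tf.truncated_normal([hidden_dim, category_num], mean=0.0, stddev=0.01), dtype=tf.float32)
b = tf.Variable(tf.random_uniform([category_num], -0.01, 0.01), dtype=tf.float32)
y = tf.matmul(output_y, w) + b  # logits, shape [batch, category_num]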

# Loss: mean softmax cross-entropy over the batch
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_label, logits=y))
train = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cross_entropy)

# Prediction
correction_prediction = tf.equal(tf.argmax(y, axis=1), tf.argmax(y_label, axis=1))
accuracy = tf.reduce_mean(tf.cast(correction_prediction, tf.float32))
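To see what the loss and accuracy ops compute, here is a NumPy sketch of softmax cross-entropy and argmax-based accuracy on a tiny hand-made batch. The function names and the three-class example are ours, chosen for illustration; they are not TensorFlow internals.

```python
import numpy as np

def softmax_cross_entropy(logits, labels):
    """Per-example cross-entropy between one-hot labels and softmax(logits)."""
    # Subtract the row maximum for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(labels * log_probs).sum(axis=1)

def batch_accuracy(logits, labels):
    """Fraction of rows where the largest logit matches the one-hot label."""
    return float(np.mean(logits.argmax(axis=1) == labels.argmax(axis=1)))

logits = np.array([[2.0, 0.5, 0.1],
                   [0.2, 0.1, 3.0],
                   [1.5, 1.0, 0.3]])
labels = np.array([[1.0, 0.0, 0.0],
                   [0.0, 0.0, 1.0],
                   [0.0, 1.0, 0.0]])

print(softmax_cross_entropy(logits, labels).mean())
print(batch_accuracy(logits, labels))  # 2 of 3 rows are correct
```

The third row is the misclassified one: its largest logit is at index 0 while the label is at index 1, so it also contributes the largest loss.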

Start to train the model

We will train the LSTM model and print the training progress. As our final result, we report the test accuracy measured at the step with the best validation accuracy.

init = tf.global_variables_initializer()
try:
    with tf.Session() as sess:
        sess.run(init)
        test_acc = 0.
        dev_acc = 0.
        better_acc = 0.0

        # set train times
        for step in range(total_steps + 1):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            # No dropout placeholder is defined in this graph, so none is fed here
            _, acc = sess.run([train, accuracy], feed_dict={x: batch_x, y_label: batch_y, batch_size_train: batch_size})

            # Train Accuracy
            print("train step=" + str(step) + " accuracy = " + str(acc))
            # Validation Accuracy
            if step % steps_per_validate == 0:
                dev_x, dev_y = mnist.validation.images, mnist.validation.labels
                dev_acc = sess.run(accuracy, feed_dict={x: dev_x, y_label: dev_y, batch_size_train: dev_x.shape[0]})
                print("dev step=" + str(step) + " accuracy = " + str(dev_acc))
                # Evaluate on the test set only when validation accuracy improves
                if better_acc < dev_acc:
                    test_x, test_y = mnist.test.images, mnist.test.labels
                    test_acc = sess.run(accuracy, feed_dict={x: test_x, y_label: test_y, batch_size_train: test_x.shape[0]})
                    print("test step=" + str(step) + " accuracy = " + str(test_acc))
                    better_acc = dev_acc

except Exception as e:
    print(e)

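Run this code and it prints the training accuracy at every step, the validation accuracy every 15 steps, and the test accuracy whenever the validation accuracy improves.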
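The rule of reporting the test accuracy at the best validation accuracy can be isolated from the training loop. Here is a minimal sketch in plain Python. The function select_test_accuracy and the numbers in history are made up for illustration; they are not results from this model.

```python
def select_test_accuracy(history):
    """Return (best validation accuracy, test accuracy recorded at that point).

    `history` is a list of (dev_acc, test_acc) pairs, one per validation run.
    """
    best_dev = 0.0
    reported_test = 0.0
    for dev_acc, test_acc in history:
        # Only a strict improvement on validation updates the reported test score.
        if dev_acc > best_dev:
            best_dev = dev_acc
            reported_test = test_acc
    return best_dev, reported_test

# Made-up numbers for illustration only.
history = [(0.20, 0.21), (0.55, 0.53), (0.54, 0.58), (0.71, 0.70), (0.71, 0.74)]
print(select_test_accuracy(history))  # (0.71, 0.7)
```

Note that the last pair has a higher test accuracy (0.74) but does not improve the validation accuracy, so it is ignored. Choosing the checkpoint by validation accuracy keeps the test set out of the model selection decision.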
