Implement Attention Visualization with Python – TensorFlow Tutorial

By | June 27, 2019

The attention mechanism has been widely used in deep learning tasks such as data mining, sentiment analysis and machine translation. No matter which attention strategy you use, you need an attention visualization to compare how different models weigh the same input.

In this tutorial, we will show you how to implement attention visualization in Python.

Step 1: Install seaborn

pip install seaborn

Step 2: Implement attention visualization

Suppose you have two models, and each of them produces attention weights over the same sentence.

For example:

Take the sentence: shit, this food is very disappointment.

The attention weights of Model A are: 0.3276, 0.0003, 0.0009, 0.0000, 0.0010, 0.0192, 0.6497, 0.0013

The attention weights of Model B are: 0.0184, 0.0000, 0.0005, 0.0000, 0.0000, 0.0000, 0.9810, 0.0000
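Before plotting, it can help to sanity-check these numbers. Attention weights are commonly produced by a softmax over per-word scores, so each model's row should sum to about 1. This is an assumption about how the two models compute attention, since the example does not say. The sketch below uses NumPy to check the sums and find the word each model attends to most. It also shows how a softmax turns raw scores into weights of this kind. The `scores` array is made up purely for illustration.

```python
import numpy as np

# Tokens and attention weights copied from the example above
words = ['shit', ',', 'this', 'food', 'is', 'very', 'disappointment', '.']
att_a = np.array([0.3276, 0.0003, 0.0009, 0.0000, 0.0010, 0.0192, 0.6497, 0.0013])
att_b = np.array([0.0184, 0.0000, 0.0005, 0.0000, 0.0000, 0.0000, 0.9810, 0.0000])

def softmax(scores):
    """Turn raw (unnormalized) attention scores into weights that sum to 1."""
    exp = np.exp(scores - np.max(scores))  # subtract the max for numerical stability
    return exp / exp.sum()

# Check that each row sums to about 1 and report the most-attended word
for name, att in [('Model A', att_a), ('Model B', att_b)]:
    top = words[int(np.argmax(att))]
    print(f"{name}: sum={att.sum():.4f}, most attended word={top!r}")

# Hypothetical raw scores (assumption), only to show how softmax yields such weights
scores = np.array([2.0, -5.0, -4.0, -6.0, -4.0, -1.0, 2.7, -3.5])
print(np.round(softmax(scores), 4))
```

Both models put most of their weight on "disappointment", which the heatmap should make visible at a glance.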

To display the difference between them in a graph, you can use the example code below:

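import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

sns.set()  # apply seaborn's default theme

# Words of the sentence (x axis) and model names (y axis)
data_word = ['shit', ',', 'this', 'food', 'is', 'very', 'disappointment', '.']
data_index = ['Model A', 'Model B']
# One row of attention weights per model, one column per word
data_att = [[0.3276, 0.0003, 0.0009, 0.0000, 0.0010, 0.0192, 0.6497, 0.0013],
            [0.0184, 0.0000, 0.0005, 0.0000, 0.0000, 0.0000, 0.9810, 0.0000]]

d = pd.DataFrame(data=data_att, index=data_index, columns=data_word)

f, ax = plt.subplots(figsize=(6, 2))
# Fix the color scale to [0, 1] so both models share the same color mapping
sns.heatmap(d, vmin=0, vmax=1.0, ax=ax, cmap="OrRd")

# Keep model names horizontal and tilt the words so they do not overlap
label_y = ax.get_yticklabels()
plt.setp(label_y, rotation=0, horizontalalignment='right')
label_x = ax.get_xticklabels()
plt.setp(label_x, rotation=45, horizontalalignment='right')
plt.show()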

The result looks like this:

Figure: sentence attention visualization
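You may want the exact weight printed in each cell, or you may need to render the figure on a machine without a display. In either case, the same plot can be wrapped in a small function. The sketch below is an illustrative variant and not part of the original tutorial. `plot_attention` is a hypothetical helper name, and the output file name `attention.png` is arbitrary. It relies on seaborn's `annot`/`fmt` heatmap parameters and on matplotlib's non-interactive Agg backend.

```python
import os

import matplotlib
matplotlib.use("Agg")  # non-interactive backend: render to a file instead of a window
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

def plot_attention(words, weights, model_names, path):
    """Hypothetical helper: draw one heatmap row per model and save it to `path`."""
    df = pd.DataFrame(data=weights, index=model_names, columns=words)
    # Widen the figure with the sentence length; grow its height with the number of models
    fig, ax = plt.subplots(figsize=(max(6, len(words)), 1 + len(model_names)))
    # annot=True writes each weight inside its cell; fmt controls the number format
    sns.heatmap(df, vmin=0, vmax=1.0, cmap="OrRd", annot=True, fmt=".2f", ax=ax)
    plt.setp(ax.get_yticklabels(), rotation=0, horizontalalignment='right')
    plt.setp(ax.get_xticklabels(), rotation=45, horizontalalignment='right')
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)  # free the figure's memory once it is saved
    return df

words = ['shit', ',', 'this', 'food', 'is', 'very', 'disappointment', '.']
weights = [[0.3276, 0.0003, 0.0009, 0.0000, 0.0010, 0.0192, 0.6497, 0.0013],
           [0.0184, 0.0000, 0.0005, 0.0000, 0.0000, 0.0000, 0.9810, 0.0000]]
df = plot_attention(words, weights, ['Model A', 'Model B'], 'attention.png')
print("saved to", os.path.abspath('attention.png'))
```

Passing more rows in `weights`, for example a third model, adds one heatmap row per model. The figure height grows to match.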
