Add Tensor Update Operation to tf.GraphKeys.UPDATE_OPS – TensorFlow Tutorial

By | August 19, 2021

In TensorFlow, we can add tensor update operations to tf.GraphKeys.UPDATE_OPS so that they can be managed in one place. In this tutorial, we will show you how to do it.

Look at the example below:

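import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

# Two non-trainable variables, both initialized to 0
x = tf.get_variable('x', [5, 10], dtype=tf.float32, initializer=tf.constant_initializer(0), trainable=False)
y = tf.get_variable('y', [5, 10], dtype=tf.float32, initializer=tf.constant_initializer(0), trainable=False)
diff = x - y
# Row indices of x to update
xy = tf.convert_to_tensor(np.array([0, 1, 2, 3, 4]))
# x[xy[i], :] -= diff[i, :]
z = tf.compat.v1.scatter_sub(x, xy, diff)
# Register the update operation in UPDATE_OPS, then read the collection back
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, z)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
print(update_ops)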

Run this code, and we will get this result:

[<tf.Tensor 'ScatterSub:0' shape=(5, 10) dtype=float32_ref>]

From the result, we can see that:

We can use tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, z) to add an update operation to tf.GraphKeys.UPDATE_OPS.

We can use tf.get_collection(tf.GraphKeys.UPDATE_OPS) to get a list of all update operations stored in tf.GraphKeys.UPDATE_OPS.
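Collecting an operation does not run it. To see the collected operation actually change a variable, the sketch below runs everything in UPDATE_OPS once inside a session. It is a minimal example, assuming TensorFlow 1.15 or 2.x with the tf.compat.v1 API and TF1 behavior switched on; the initial value 1.0 and the all-ones delta are illustrative values, not taken from the example above.

```python
import numpy as np
import tensorflow as tf

# Assumption: TF 1.15 or 2.x; switch the compat API to TF1 graph behavior.
tf.compat.v1.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    # Illustrative non-trainable variable, initialized to 1
    x = tf.compat.v1.get_variable(
        "x", [5, 10], dtype=tf.float32,
        initializer=tf.compat.v1.constant_initializer(1.0), trainable=False)
    indices = tf.constant([0, 1, 2, 3, 4])
    delta = tf.ones([5, 10], dtype=tf.float32)
    # x[indices[i], :] -= delta[i, :]
    z = tf.compat.v1.scatter_sub(x, indices, delta)
    tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS, z)
    update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
    init = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init)
    sess.run(update_ops)  # runs every op stored in the collection once
    x_value = sess.run(x)

print(update_ops)
print(x_value[0])
```

After one run of the collection, every row of x has dropped from 1.0 to 0.0.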


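tf.add_to_collection() is defined as:

tf.add_to_collection(name, value)

It adds value to the collection named name in the default graph. tf.get_collection(name) then returns the list of values stored in that collection.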


For center loss, we can update the center weights by making the training op depend on the center update operation:

optimizer = tf.train.AdamOptimizer(0.001)

with tf.control_dependencies([self.centers_update_op]):
    train_op = optimizer.minimize(self.loss, global_step=global_step)
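The centers_update_op above usually subtracts a scaled difference from the rows of the centers matrix that match the batch labels. Below is a NumPy sketch of one common formulation of that update. It assumes a hypothetical helper center_update and an update rate alpha, and it is not necessarily the exact formula used in the linked tutorial. np.subtract.at plays the role of tf.scatter_sub, since both accumulate contributions for repeated labels.

```python
import numpy as np


def center_update(centers, features, labels, alpha=0.5):
    """Hypothetical helper: return a copy of `centers` after one center-loss update.

    centers:  (num_classes, dim) array of class centers
    features: (batch, dim) array of features
    labels:   (batch,) integer class labels
    alpha:    assumed update rate for the centers
    """
    centers = centers.astype(np.float64)          # astype returns a copy
    diff = centers[labels] - features             # like tf.gather(centers, labels) - features
    _, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
    appear_times = counts[inverse].reshape(-1, 1)  # occurrences of each sample's label in the batch
    diff = alpha * diff / (1.0 + appear_times)
    np.subtract.at(centers, labels, diff)         # like tf.scatter_sub: repeated labels accumulate
    return centers


centers = np.zeros((2, 2))
features = np.array([[1.0, 1.0], [3.0, 3.0], [2.0, 0.0]])
labels = np.array([0, 0, 1])
new_centers = center_update(centers, features, labels, alpha=0.5)
print(new_centers)
```

Each center moves toward the mean of its batch features: for class 0, the samples at (1, 1) and (3, 3) pull the center from (0, 0) to about (0.667, 0.667).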

The full center loss example is in: Implement Center Loss Function for Text Classification in TensorFlow – TensorFlow Tutorial

Alternatively, we can add centers_update_op to tf.GraphKeys.UPDATE_OPS and make the training op depend on every operation in that collection:

centers_update_op = tf.scatter_sub(centers, label, diff)
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, centers_update_op)

global_step = tf.Variable(0, name="global_step", trainable=False)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optimizer = tf.train.AdamOptimizer(0.001)
    grads_and_vars = optimizer.compute_gradients(mix_model.loss)
    train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
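A small end-to-end sketch of the UPDATE_OPS approach follows. The toy loss, the centers variable, and the constant diff are hypothetical stand-ins for mix_model.loss and the real center update, and the example assumes the tf.compat.v1 API with TF1 behavior. Because train_op is created inside tf.control_dependencies(update_ops), every sess.run(train_op) also runs centers_update_op.

```python
import numpy as np
import tensorflow as tf

# Assumption: TF 1.15 or 2.x; switch the compat API to TF1 graph behavior.
tf.compat.v1.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    # Hypothetical toy model: one trainable weight and a squared loss
    w = tf.compat.v1.get_variable("w", [], initializer=tf.compat.v1.constant_initializer(0.0))
    loss = tf.square(w - 3.0)

    # Hypothetical non-trainable centers, updated outside the optimizer
    centers = tf.compat.v1.get_variable(
        "centers", [2, 2], initializer=tf.compat.v1.constant_initializer(1.0), trainable=False)
    labels = tf.constant([0, 1, 1])
    diff = tf.fill([3, 2], 0.1)
    centers_update_op = tf.compat.v1.scatter_sub(centers, labels, diff)
    tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS, centers_update_op)

    global_step = tf.Variable(0, name="global_step", trainable=False)
    update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
    # Ops created in this block run only after all ops in update_ops
    with tf.control_dependencies(update_ops):
        optimizer = tf.compat.v1.train.AdamOptimizer(0.001)
        grads_and_vars = optimizer.compute_gradients(loss)
        train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
    init = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init)
    for _ in range(2):
        sess.run(train_op)  # each step also runs centers_update_op
    centers_value, step = sess.run([centers, global_step])

print(step)
print(centers_value)
```

After two steps, row 0 of centers has been reduced by 0.1 twice and row 1 (which appears twice in labels) by 0.2 twice. The benefit of this design is that the training code does not need to reference each update op by name: anything added to the collection is picked up automatically.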
