Understand __call__() and call() Inheriting From RNNCell in TensorFlow – TensorFlow Tutorial

By | March 30, 2022

In TensorFlow, we can create a custom RNN cell by inheriting from tf.nn.rnn_cell.RNNCell or tensorflow.contrib.rnn.RNNCell. However, you may find that some RNN classes implement __call__(), while others implement call(). Why? This tutorial explains the reason.

To understand __call__() in Python, you can read:

Python __call__(): Call Function with Dynamic Parameters – Python Tutorial

__call__() and call() in TensorFlow RNNCell

To create a custom RNN cell, we can start as follows:

from tensorflow.contrib.rnn import RNNCell

class DecoderPrenetWrapper(RNNCell):
    ...  # the cell's methods, such as __call__() or call(), go here

Here DecoderPrenetWrapper is a child class of RNNCell.

In this class, you can implement either __call__() or call(). This works because of the inheritance chain, in which each class inherits from the one to its right:

DecoderPrenetWrapper <- tensorflow.contrib.rnn.RNNCell <- base_layer.Layer

In base_layer.Layer, a call() method is defined, and the class's __call__() invokes it.

For example, here is how call() is defined in base_layer.Layer:

  @doc_controls.for_subclass_implementers
  def call(self, inputs, **kwargs):  # pylint: disable=unused-argument
    """This is where the layer's logic lives.

    Arguments:
        inputs: Input tensor, or list/tuple of input tensors.
        **kwargs: Additional keyword arguments.

    Returns:
        A tensor or list/tuple of tensors.
    """
    return inputs


This means we can override call() to create a custom RNN cell: the inherited __call__() will then invoke our version.
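To make the relationship concrete, here is a minimal pure-Python sketch of the same pattern. The Layer and RNNCell classes below are simplified stand-ins written for illustration, not the real TensorFlow classes (whose __call__() does extra bookkeeping), and CallCell and DunderCallCell are hypothetical names.

```python
class Layer:
    """Simplified stand-in for base_layer.Layer (NOT the real TensorFlow class)."""

    def __call__(self, inputs, *args, **kwargs):
        # Assumption: only the dispatch to call() is modelled here.
        return self.call(inputs, *args, **kwargs)

    def call(self, inputs, **kwargs):
        # Default behaviour, as in the excerpt above: return inputs unchanged.
        return inputs


class RNNCell(Layer):
    """Simplified stand-in for tensorflow.contrib.rnn.RNNCell."""


class CallCell(RNNCell):
    # Overrides call(): the inherited __call__() still runs and dispatches here.
    def call(self, inputs, state):
        output = inputs + state
        return output, output


class DunderCallCell(RNNCell):
    # Overrides __call__(): the base-class __call__() is bypassed entirely.
    def __call__(self, inputs, state):
        output = inputs * state
        return output, output


print(CallCell()(1, 2))        # (3, 3)
print(DunderCallCell()(3, 2))  # (6, 6)
```

Both cells are used the same way, cell(inputs, state). The difference is that overriding call() keeps whatever the base-class __call__() does around it, while overriding __call__() replaces that logic completely.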
