Complete Guide to Exporting PyTorch Models to ONNX – PyTorch Tutorial

By | October 8, 2023

In this tutorial, we present a complete guide to exporting PyTorch models to ONNX. You can learn how to do it by following the steps below.


We usually use torch.onnx.export() to export PyTorch models to ONNX. It is defined as:

torch.onnx.export(model, args, f, export_params=True, verbose=False, training=<TrainingMode.EVAL: 0>, input_names=None, output_names=None, operator_export_type=<OperatorExportTypes.ONNX: 0>, opset_version=None, do_constant_folding=True, dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None, export_modules_as_functions=False)

Here are some important parameters we should pay attention to.

model: It should be a torch.nn.Module, torch.jit.ScriptModule or torch.jit.ScriptFunction object. We will export this model to onnx.

args: It can be a tuple or a single tensor. However, we recommend using a tuple.

This parameter supplies the arguments passed to the model's forward() function, so the order of the items in the tuple must match the order of the parameters in forward().

For example, suppose we have created a model as follows:

[Figure: the forward() parameters of the model]

In the forward() of this vits_encoder model, the shapes of x, x_lengths, length_scale and sentiment_prob can be:

x = [batch_size, time_step]
x_lengths = [batch_size]
length_scale = [1,]
sentiment_prob = [batch_size, 10]

Here we can set args as follows:

x = stn_tst.unsqueeze(0)
x_lengths = torch.LongTensor([stn_tst.size(0)])
length_scale = torch.Tensor([1.0])
sentiment_prob = torch.FloatTensor([[0.0, 1.0, 0.0]])

args = (x, x_lengths, length_scale, sentiment_prob)

Here we assume batch_size = 1. Note that the value of time_step can differ from batch to batch.

However, if forward() takes only one parameter, we can do as follows:

import torch

class SumModule(torch.nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1)

# A one-element tuple needs the trailing comma
args = (torch.ones(2, 2),)

f: The file name (path) of the exported ONNX model.

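export_params = True: If True, all model parameters (the trained weights) are exported. Set it to False if you want to export an untrained model.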

verbose = True: Prints a description of the model being exported to stdout

input_names: A list of strings, the names to assign to the input nodes of the graph, in order. You can set it based on args.

For example:

input_names = ["x", "x_lengths", "length_scale", "sentiment_prob"]

output_names: A list of strings, the names to assign to the output nodes of the graph, in order. We should set it based on what the model's forward() returns.

For example:

[Figure: the return statement of the model's forward()]

We can set:

output_names = ["z_p","y_mask"]
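Because both lists are positional, it can help to check them against the model before exporting. The sketch below is one possible way to do that, not part of the original tutorial: ToyEncoder is a hypothetical stand-in with two inputs and two outputs, input_names is read from the signature of forward(), and the length of output_names is compared with the number of values forward() returns.

```python
import inspect

import torch


class ToyEncoder(torch.nn.Module):
    """Hypothetical stand-in model with two inputs and two outputs."""

    def forward(self, x, x_lengths):
        steps = torch.arange(x.size(1)).unsqueeze(0)        # [1, time_step]
        y_mask = (steps < x_lengths.unsqueeze(1)).float()   # [batch_size, time_step]
        z_p = x.float() * y_mask
        return z_p, y_mask


model = ToyEncoder().eval()
x = torch.randint(0, 100, (1, 7))
x_lengths = torch.LongTensor([7])
args = (x, x_lengths)

# forward() is a bound method here, so `self` is not listed.
input_names = list(inspect.signature(model.forward).parameters)
output_names = ["z_p", "y_mask"]

outputs = model(*args)
assert len(input_names) == len(args)
assert len(outputs) == len(output_names)

print(input_names)  # → ['x', 'x_lengths']
```

The names themselves are free to choose; what matters is that their order follows args and the return value of forward().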

opset_version: The version of the default (ai.onnx) opset to target. The supported range depends on the installed PyTorch version; for the signature shown above, it must be >= 7 and <= 16. You can set it to 13.

training: It can be TrainingMode.EVAL (default) or TrainingMode.TRAINING.

TrainingMode.EVAL: export the model in inference mode.

TrainingMode.TRAINING: export the model in training mode. Disables optimizations which might interfere with training.

keep_initializers_as_inputs: If True, all the initializers (typically corresponding to parameters) in the exported graph will also be added as inputs to the graph. If opset_version < 9, initializers must be part of the graph inputs, so this argument is ignored and the behavior is equivalent to setting it to True.

dynamic_axes: This is a very important parameter. It tells ONNX which axes of the inputs and outputs are dynamic.

By default the exported model will have the shapes of all input and output tensors set to exactly match those given in args.

As in the example above:

x = [batch_size, time_step]
x_lengths = [batch_size]
length_scale = [1,]
sentiment_prob = [batch_size, 10]

We can fix batch_size = 1 at inference time, which means we infer only one sample at a time.

However, time_step differs from sample to sample, which means the outputs of forward() may also have dynamic shapes.

To set dynamic_axes correctly, consider this example.

[Figure: torch.onnx.export() dynamic_axes example]

dynamic_axes is a Python dictionary. It contains every input and output that has a dynamic axis.

In the example above, the input named sequences in input_names has a dynamic axis. Its shape is [batch_size, time_step]; since we have fixed batch_size = 1, the shape is [1, time_step]. Axis 1 is therefore dynamic, and we set dynamic_axes = {"sequences": {1: "time_step"}}.

As for the outputs, they are z_p and y_mask, as listed in output_names.

The shape of z_p is [batch_size, 80, time_step]

The shape of y_mask is [batch_size, 80, time_step]

If batch_size = 1, axis 2 (time_step) of both z_p and y_mask is dynamic.

Finally, we can set:

dynamic_axes = {
    "sequences": {1: "time_step"},
    "z_p": {2: "time_step"},
    "y_mask": {2: "time_step"},
}

The name time_step is only a label; you can use any other name you like.
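Every key in dynamic_axes has to match a name in input_names or output_names, and a mismatched key is easy to overlook. The helper below is a small sketch for catching that before exporting; check_dynamic_axes is a hypothetical function written for this illustration, not a PyTorch API, and input_names = ["sequences"] is assumed from the example above.

```python
def check_dynamic_axes(dynamic_axes, input_names, output_names):
    """Return the dynamic_axes keys that match no input or output name."""
    known = set(input_names) | set(output_names)
    return [name for name in dynamic_axes if name not in known]


input_names = ["sequences"]        # assumed from the example above
output_names = ["z_p", "y_mask"]
dynamic_axes = {
    "sequences": {1: "time_step"},
    "z_p": {2: "time_step"},
    "y_mask": {2: "time_step"},
}

print(check_dynamic_axes(dynamic_axes, input_names, output_names))  # → []

# A key that is not in input_names or output_names is reported.
print(check_dynamic_axes({"x": {1: "time_step"}}, input_names, output_names))  # → ['x']
```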

Finally, we can export a PyTorch model to ONNX successfully.

Here is a full example:

[Figure: PyTorch model export to ONNX example]