Guide to Exporting ONNX with Support for the aten::scaled_dot_product_attention Operator – ONNX Tutorial

September 3, 2024

When we export a torch model to ONNX, we may get this error: torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 13 is not supported.
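For example, a minimal sketch like the following (TinyAttention and the input shapes are placeholders, not the original model) reproduces the error when the target opset is 13:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyAttention(nn.Module):
    # A minimal module that calls the operator in question.
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)

model = TinyAttention().eval()
q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# May raise torch.onnx.errors.UnsupportedOperatorError, because no symbolic
# function is registered for aten::scaled_dot_product_attention at opset 13.
torch.onnx.export(model, (q, k, v), "attention.onnx", opset_version=13)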


How to fix this error?

You can fix this error by following the steps below. Otherwise, you may find that the ONNX model exports successfully but cannot be used for inference.

Step 1: prepare the environment

Make sure your torch version is 2.3.1 or higher.

I tested torch 2.0.0, but the error could not be fixed there, so I updated to 2.3.1.

The key versions are:

torch = 2.3.1
onnx = 1.10.0
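A quick way to check the installed versions (a minimal sketch, assuming both packages are installed):

import torch
import onnx

print(torch.__version__)  # should print 2.3.1 or higher
print(onnx.__version__)   # 1.10.0 in our setup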

Why do we use torch 2.3.1?

We suspect that the ONNX symbolic function for scaled_dot_product_attention, which was added in symbolic_opset14.py, does not work correctly in torch 2.0.0.

When testing, we found that with torch 2.0.0 the model can be exported to ONNX, but the result cannot be used for inference. It reports an error like: onnx right operand cannot broadcast on dim 3 LeftShape …

The scaled_dot_product_attention symbolic function is defined here:

https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset14.py#L137
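To compare the linked file with what your installed torch actually ships, you can print the local path of the module and check that the symbolic function exists (a small sketch):

import torch
from torch.onnx import symbolic_opset14

print(torch.__version__)
# Local copy of the file linked above.
print(symbolic_opset14.__file__)
# True if this torch build ships the scaled_dot_product_attention symbolic.
print(hasattr(symbolic_opset14, "scaled_dot_product_attention"))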

Step 2: export the model to ONNX

When we export a torch model to ONNX, we should not simply set opset_version = 14.

From symbolic_opset14.py we can see that ONNX opset_version = 14 supports the scaled_dot_product_attention operation.

We can export the model to ONNX successfully with opset_version = 14.

However, if we load this ONNX model for inference, we may get this error: Opset 14 is under development and support for this is limited
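For illustration, using the same hypothetical TinyAttention module and inputs as in the sketch at the beginning of this guide, the export itself succeeds at opset 14:

# The export succeeds because symbolic_opset14.py provides the symbolic function.
torch.onnx.export(model, (q, k, v), "attention_opset14.onnx", opset_version=14)
# But loading "attention_opset14.onnx" for inference may then report:
# "Opset 14 is under development and support for this is limited"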

Step 3: how to export ONNX with opset_version = 13

opset_version = 13 does not support the scaled_dot_product_attention operation, so we have to add it ourselves based on symbolic_opset14.py.

Import the scaled_dot_product_attention symbolic function from symbolic_opset14.py (or add an equivalent one) and register it in your export script, and you can get a valid ONNX model.

For example:

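The sketch below assumes torch 2.3.1; TinyAttention, the input shapes, and the file name are placeholders for your own model and data. It registers the opset-14 symbolic function for aten::scaled_dot_product_attention so that the exporter can also use it at opset 13:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_opset14 import scaled_dot_product_attention

# Reuse the opset-14 symbolic function when the target opset is 13.
# Assumes torch >= 2.3.1, where symbolic_opset14 provides this function.
register_custom_op_symbolic(
    "aten::scaled_dot_product_attention",
    scaled_dot_product_attention,
    opset_version=13,
)

class TinyAttention(nn.Module):
    # Placeholder model; replace with your own module.
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)

model = TinyAttention().eval()
q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

torch.onnx.export(
    model,
    (q, k, v),
    "attention_opset13.onnx",
    opset_version=13,
    input_names=["q", "k", "v"],
    output_names=["out"],
)

Registering the existing symbolic function avoids re-implementing the attention decomposition by hand; the graph it emits uses standard operators that already exist at opset 13, which is why this trick works here.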

Finally, we get a valid ONNX model that we can use for inference successfully.
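To verify, the exported file can be loaded and run, for example with onnxruntime (assuming it is installed, and using the input names from the sketch above):

import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("attention_opset13.onnx")
x = np.random.randn(1, 8, 16, 64).astype(np.float32)
out = session.run(None, {"q": x, "k": x, "v": x})
print(out[0].shape)  # expected: (1, 8, 16, 64)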
