Fix onnx export (#23180)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23180
The prepare_division_for_onnx pass needs to be run later (after lower_all_tuples) because it breaks JIT graph invariants, and the lower_all_tuples pass still needs a valid JIT graph.
Reviewed By: houseroad
Differential Revision: D16427680
fbshipit-source-id: 427c7e74c59a3d7d62f2855ed626cf6258107509
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index dab67be..555ba87 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -109,13 +109,14 @@
torch._C._jit_pass_lint(graph)
if operator_export_type != OperatorExportTypes.RAW:
- # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
- torch._C._jit_pass_prepare_division_for_onnx(graph)
# onnx does not support tuples, so try to remove them
torch._C._jit_pass_lower_all_tuples(graph)
torch._C._jit_pass_peephole(graph, True)
torch._C._jit_pass_lint(graph)
+ # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
+ torch._C._jit_pass_prepare_division_for_onnx(graph)
+
torch._C._jit_pass_onnx_remove_print(graph)
torch._C._jit_pass_onnx_preprocess_caffe2(graph)