[onnx] Add support for autograd function inlining in ONNX_ATEN_FALLBACK mode (#85736)
Fixes #85027
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85736
Approved by: https://github.com/BowenBao
diff --git a/test/onnx/test_autograd_funs.py b/test/onnx/test_autograd_funs.py
index a0980d2..97f0652 100644
--- a/test/onnx/test_autograd_funs.py
+++ b/test/onnx/test_autograd_funs.py
@@ -148,7 +148,7 @@
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
)
iter = graph.nodes()
- self.assertEqual(next(iter).kind(), "prim::PythonOp")
+ self.assertEqual(next(iter).kind(), "aten::ATen")
def test_inline_and_symbolic(self):
class Exp(torch.autograd.Function):
diff --git a/torch/csrc/jit/passes/onnx.cpp b/torch/csrc/jit/passes/onnx.cpp
index ea183ae..98f3cb4 100644
--- a/torch/csrc/jit/passes/onnx.cpp
+++ b/torch/csrc/jit/passes/onnx.cpp
@@ -465,7 +465,9 @@
// 1. The torch.autograd.Function class of this node object has `symbolic`
// method defined.
// 2. Custom export symbolic is registered for prim::PythonOp.
- if (operator_export_type == ::torch::onnx::OperatorExportTypes::ONNX) {
+ if (operator_export_type == ::torch::onnx::OperatorExportTypes::ONNX ||
+ operator_export_type ==
+ ::torch::onnx::OperatorExportTypes::ONNX_ATEN_FALLBACK) {
try {
inlineAutograd(op);
} catch (const std::exception& ex) {