Fix aten::to symbolic + add expand_as (#13325)
Summary:
https://github.com/pytorch/pytorch/pull/13146 broke some cases of ONNX export, this fixes them
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13325
Differential Revision: D12844294
Pulled By: jamesr66a
fbshipit-source-id: f98dd0685820b2a1e5fcd49733cfa5c19c48a4e7
diff --git a/test/expect/TestJit.test_export_tensoroption_to.expect b/test/expect/TestJit.test_export_tensoroption_to.expect
new file mode 100644
index 0000000..fc7ab97
--- /dev/null
+++ b/test/expect/TestJit.test_export_tensoroption_to.expect
@@ -0,0 +1,22 @@
+ModelProto {
+ producer_name: "pytorch"
+ domain: ""
+ doc_string: ""
+ graph:
+ GraphProto {
+ name: "torch-jit-export"
+ inputs: [{name: "0", type:Tensor dims: 2}]
+ outputs: [{name: "7", type:Tensor dims: 2}]
+ initializers: []
+ nodes: [
+ Node {type: "Constant", inputs: [], outputs: [1], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
+ Node {type: "Gather", inputs: [0,1], outputs: [2], attributes: [{ name: 'axis', type: int, value: 0}]},
+ Node {type: "Constant", inputs: [], outputs: [3], attributes: [{ name: 'value', type: tensor, value:TensorProto shape: []}]},
+ Node {type: "Shape", inputs: [3], outputs: [4], attributes: []},
+ Node {type: "Expand", inputs: [2,4], outputs: [5], attributes: []},
+ Node {type: "Cast", inputs: [5], outputs: [6], attributes: [{ name: 'to', type: int, value: 1}]},
+ Node {type: "Add", inputs: [6,0], outputs: [7], attributes: []}
+ ]
+ }
+ opset_import: [OperatorSetIdProto { domain: }],
+}
diff --git a/test/test_jit.py b/test/test_jit.py
index c2b7d98..5d6007f 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -1917,6 +1917,17 @@
x = torch.rand(3, 4)
self.assertEqual(random_bar(x), (x + 1)[0:1])
+ def test_export_tensoroption_to(self):
+ def foo(x):
+ return x.new_tensor(x[0]).cpu() + x
+
+ traced = torch.jit.trace(foo, (torch.rand([2])))
+ example_outputs = traced(torch.rand([2]))
+
+ f = io.BytesIO()
+ self.assertExpected(torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f,
+ example_outputs=example_outputs))
+
def test_pretty_printer(self):
@torch.jit.script
def if_test(a, b):
diff --git a/torch/onnx/symbolic.py b/torch/onnx/symbolic.py
index 7efdf20..f3963b3 100644
--- a/torch/onnx/symbolic.py
+++ b/torch/onnx/symbolic.py
@@ -350,6 +350,11 @@
return None
+def expand_as(g, self, other):
+ shape = g.op("Shape", other)
+ return g.op("Expand", self, shape)
+
+
def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
return g.op("Gather", weight, indices)
@@ -1097,6 +1102,11 @@
# aten::to(Tensor, Device, ScalarType, bool, bool)
dtype = _get_const(args[1], 'i', 'dtype')
return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
+ elif len(args) == 5:
+ # aten::to(Tensor, ScalarType, Layout, Device, bool, bool) -> Tensor
+ dtype = _get_const(args[0], 'i', 'dtype')
+ # Layout and device are ignored
+ return g.op("Cast", self, to_i=scalar_type_to_onnx[dtype])
else:
raise NotImplementedError("Unknown aten::to signature")