[ONNX] Fix export of flatten (#40418)
Summary:
The shape is passed to _reshape_from_tensor as a Constant, so the exporter cannot infer the shape of the input when the model is exported with dynamic axes set: the sizes traced from the example input get baked into the graph. Instead of a Constant, pass the output of a Shape-Slice-Concat subgraph to compute the shape for the Reshape node in _reshape_from_tensor.
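
For context, a minimal sketch of the failing case, mirroring the test_flatten_dynamic_axes test added below (the output file name is illustrative):

```python
import torch

class FlattenModel(torch.nn.Module):
    def forward(self, x):
        # 4D input, 3D output, so this lowers to Reshape, not ONNX Flatten
        return torch.flatten(x, start_dim=2, end_dim=3)

x = torch.randn(3, 5, 4, 5)
# Before this fix the Reshape target shape was the Constant [3, 5, 20]
# computed from the traced sizes, so running the exported model with any
# batch size other than 3 failed even though dim 0 is marked dynamic.
torch.onnx.export(FlattenModel(), x, "flatten.onnx",
                  input_names=['input'], output_names=['output'],
                  dynamic_axes={'input': {0: 'batch_size'},
                                'output': {0: 'batch_size'}})
```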
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40418
Reviewed By: hl475
Differential Revision: D22480127
Pulled By: houseroad
fbshipit-source-id: 11853adb6e6914936871db1476916699141de435
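
Note on the approach taken in the diff below: the precomputed Constant is replaced by a Shape-Slice-Concat subgraph evaluated at runtime. A pure-Python reference of the shape computation that subgraph performs (the helper name here is illustrative, not part of the patch):

```python
def flatten_target_shape(input_shape, start_dim, end_dim):
    # Mirrors the Shape-Slice-Concat subgraph built by _flatten_helper:
    # keep dims before start_dim, collapse [start_dim, end_dim] into -1,
    # keep dims after end_dim; Reshape resolves the -1 at runtime.
    head = list(input_shape[:start_dim])    # Slice(starts=[0], ends=[start_dim])
    tail = list(input_shape[end_dim + 1:])  # Slice(starts=[end_dim + 1], ends=[dim])
    return head + [-1] + tail               # Concat(head, [-1], tail, axis=0)

# flatten(x, start_dim=2, end_dim=3) on an (N, 5, 4, 5) input reshapes to
# [N, 5, -1]; in the exported graph N is read via Shape at runtime rather
# than being frozen at trace time.
assert flatten_target_shape((3, 5, 4, 5), 2, 3) == [3, 5, -1]
```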
diff --git a/test/onnx/expect/TestOperators.test_flatten.expect b/test/onnx/expect/TestOperators.test_flatten.expect
index a995a1e..418bc94 100644
--- a/test/onnx/expect/TestOperators.test_flatten.expect
+++ b/test/onnx/expect/TestOperators.test_flatten.expect
@@ -1,26 +1,65 @@
 ir_version: 6
 producer_name: "pytorch"
-producer_version: "XXX"
+producer_version: "1.7"
 graph {
   node {
+    input: "0"
     output: "1"
-    name: "Constant_0"
+    name: "Shape_0"
+    op_type: "Shape"
+  }
+  node {
+    input: "1"
+    output: "2"
+    name: "Slice_1"
+    op_type: "Slice"
+    attribute {
+      name: "axes"
+      ints: 0
+      type: INTS
+    }
+    attribute {
+      name: "ends"
+      ints: 0
+      type: INTS
+    }
+    attribute {
+      name: "starts"
+      ints: 0
+      type: INTS
+    }
+  }
+  node {
+    output: "3"
+    name: "Constant_2"
     op_type: "Constant"
     attribute {
       name: "value"
       t {
         dims: 1
         data_type: 7
-        raw_data: "\030\000\000\000\000\000\000\000"
+        raw_data: "\377\377\377\377\377\377\377\377"
       }
       type: TENSOR
     }
   }
   node {
+    input: "2"
+    input: "3"
+    output: "4"
+    name: "Concat_3"
+    op_type: "Concat"
+    attribute {
+      name: "axis"
+      i: 0
+      type: INT
+    }
+  }
+  node {
     input: "0"
-    input: "1"
-    output: "2"
-    name: "Reshape_1"
+    input: "4"
+    output: "5"
+    name: "Reshape_4"
     op_type: "Reshape"
   }
   name: "torch-jit-export"
@@ -47,7 +86,7 @@
     }
   }
   output {
-    name: "2"
+    name: "5"
     type {
       tensor_type {
         elem_type: 1
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index 54e0816..0dbb9f0 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -1620,7 +1620,7 @@
         self.run_model_test(ScatterModel(), train=False, input=(input, indices, values),
                             batch_size=BATCH_SIZE, use_gpu=False)
 
-
+    @skipIfUnsupportedOpsetVersion([10])
     def test_flatten(self):
         class FlattenModel(torch.nn.Module):
             def forward(self, input):
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 72fa374..fc3b610 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -2301,6 +2301,22 @@
         x = torch.randint(10, (1, 2, 3, 4))
         self.run_test(FlattenModel(), x)
 
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_flatten_dynamic_axes(self):
+        class MyModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.flatten(x, start_dim=2, end_dim=3)
+
+        batch_size = 3
+        x = torch.randn(batch_size, 5, 4, 5)
+        y = torch.randn(5, 5, 4, 5)
+        model = MyModule()
+        self.run_test(model, x, test_with_inputs=[y],
+                      input_names=['input'],
+                      output_names=['output'],
+                      dynamic_axes={'input' : {0 : 'batch_size'},
+                                    'output' : {0 : 'batch_size'}})
+
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_getitem(self):
         class GetItemModel(torch.jit.ScriptModule):
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
index 4629748..46cbab8 100644
--- a/torch/onnx/symbolic_helper.py
+++ b/torch/onnx/symbolic_helper.py
@@ -441,6 +441,18 @@
                       op_mode + " mode. The model will be exported in " +
                       training_mode + ", as specified by the export mode.")
 
+def _flatten_helper(g, input, start_dim, end_dim, dim):
+    input_size = g.op("Shape", input)
+    slice1 = _slice_helper(g, input_size, axes=[0], starts=[0], ends=[start_dim])
+    slices = [slice1, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))]
+    if end_dim < dim - 1:
+        slice3 = _slice_helper(g, input_size, axes=[0], starts=[end_dim + 1], ends=[dim])
+        slices = [slice1, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), slice3]
+
+    final_shape = g.op("Concat", *slices, axis_i=0)
+    from torch.onnx.symbolic_opset9 import _reshape_from_tensor
+    return _reshape_from_tensor(g, input, final_shape)
+
 # ---------------------------------------------------------------------
 # ONNX operator version
 # ---------------------------------------------------------------------
diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py
index b02f018..3a83880 100644
--- a/torch/onnx/symbolic_opset11.py
+++ b/torch/onnx/symbolic_opset11.py
@@ -648,29 +648,20 @@
 @parse_args('v', 'i', 'i')
 def flatten(g, input, start_dim, end_dim):
     dim = input.type().dim()
+    if dim is None:
+        return _unimplemented("dim",
+                              "ONNX and PyTorch use different strategies to split the input. "
+                              "Input rank must be known at export time.")
+
     # use ONNX's Flatten operator for cases where the output shape is 2D
     if start_dim == 1:
-        if (end_dim == -1 or (end_dim is not None and end_dim == dim - 1)):
+        if (end_dim == -1 or end_dim == dim - 1):
             return g.op("Flatten", input, axis_i=start_dim)
     elif start_dim == 0:
-        if (end_dim == -2 or (end_dim is not None and end_dim == dim - 2)):
+        if (end_dim == -2 or end_dim == dim - 2):
             return g.op("Flatten", input, axis_i=end_dim + 1)
-    # use Reshape for cases where the output shape is not 2D
-    if not input.isCompleteTensor():
-        return _unimplemented("flatten",
-                              "input size not accessible "
-                              "(consider using reshape op instead of flatten op to export to ONNX)")
     # if end_dim is negative add dim
     if end_dim < 0 :
         end_dim = dim + end_dim
-    input_dims = input.type().sizes()
-    output_dims = []
-    for i in range(0, dim):
-        if start_dim < i and end_dim >= i:
-            output_dims[start_dim] = output_dims[start_dim] * input_dims[i]
-        else:
-            output_dims.append(input_dims[i])
-    shape = g.op("Constant", value_t=torch.LongTensor(output_dims))
-    from torch.onnx.symbolic_opset9 import _reshape_from_tensor
-    p = _reshape_from_tensor(g, input, shape)
-    return p
+
+    return sym_help._flatten_helper(g, input, start_dim, end_dim, dim)
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index bc0bc4a..a96d6c0 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -1992,6 +1992,11 @@
 @parse_args('v', 'i', 'i')
 def flatten(g, input, start_dim, end_dim):
     dim = input.type().dim()
+    if dim is None:
+        return _unimplemented("dim",
+                              "ONNX and PyTorch use different strategies to split the input. "
+                              "Input rank must be known at export time.")
+
     # TODO: remove this as onnx opset 11 spec allows negative axes
     if end_dim < 0 :
         end_dim = dim + end_dim
@@ -2000,22 +2005,8 @@
         return g.op("Flatten", input, axis_i=start_dim)
     if start_dim == 0 and end_dim == dim - 2 :
         return g.op("Flatten", input, axis_i=end_dim + 1)
-    # use Reshape for cases where the output shape is not 2D
-    if not input.isCompleteTensor():
-        return _unimplemented("flatten",
-                              "input size not accessible "
-                              "(consider using reshape op instead of flatten op to export to ONNX)")
-    input_dims = input.type().sizes()
-    output_dims = []
-    for i in range(0, dim):
-        if start_dim < i and end_dim >= i:
-            output_dims[start_dim] = output_dims[start_dim] * input_dims[i]
-        else:
-            output_dims.append(input_dims[i])
-    shape = g.op("Constant", value_t=torch.LongTensor(output_dims))
-    p = _reshape_from_tensor(g, input, shape)
-    return p
+    return sym_help._flatten_helper(g, input, start_dim, end_dim, dim)
 
 
 @parse_args('v')
 def nonzero(g, input):