[ONNX] Fix export of flatten (#40418)

Summary:
The target shape is passed to _reshape_from_tensor as a Constant computed from the example input's concrete sizes, so the exported Reshape cannot follow the input's shape when the model is exported with dynamic axes set. Instead of a Constant, pass the output of a Shape-Slice-Concat subgraph that computes the shape for the Reshape node in _reshape_from_tensor.
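
For illustration (not part of the commit message), here is a minimal repro of the failure mode, mirroring the test_flatten_dynamic_axes test added below; the module name is hypothetical:

```python
import io
import torch

class FlattenModel(torch.nn.Module):
    def forward(self, x):
        # Flattening dims 2..3 only cannot use ONNX's Flatten op (its output
        # is always 2-D), so the exporter falls back to Reshape.
        return torch.flatten(x, start_dim=2, end_dim=3)

f = io.BytesIO()
torch.onnx.export(FlattenModel(), torch.randn(3, 5, 4, 5), f,
                  input_names=['input'], output_names=['output'],
                  dynamic_axes={'input': {0: 'batch_size'},
                                'output': {0: 'batch_size'}})
# Before this fix, the Reshape target shape was a Constant baked from the
# example input's concrete sizes ([3, 5, 20] here), so running the exported
# model with a different batch size failed despite dynamic_axes.
```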

Pull Request resolved: https://github.com/pytorch/pytorch/pull/40418

Reviewed By: hl475

Differential Revision: D22480127

Pulled By: houseroad

fbshipit-source-id: 11853adb6e6914936871db1476916699141de435
diff --git a/test/onnx/expect/TestOperators.test_flatten.expect b/test/onnx/expect/TestOperators.test_flatten.expect
index a995a1e..418bc94 100644
--- a/test/onnx/expect/TestOperators.test_flatten.expect
+++ b/test/onnx/expect/TestOperators.test_flatten.expect
@@ -1,26 +1,65 @@
 ir_version: 6
 producer_name: "pytorch"
-producer_version: "XXX"
+producer_version: "1.7"
 graph {
   node {
+    input: "0"
     output: "1"
-    name: "Constant_0"
+    name: "Shape_0"
+    op_type: "Shape"
+  }
+  node {
+    input: "1"
+    output: "2"
+    name: "Slice_1"
+    op_type: "Slice"
+    attribute {
+      name: "axes"
+      ints: 0
+      type: INTS
+    }
+    attribute {
+      name: "ends"
+      ints: 0
+      type: INTS
+    }
+    attribute {
+      name: "starts"
+      ints: 0
+      type: INTS
+    }
+  }
+  node {
+    output: "3"
+    name: "Constant_2"
     op_type: "Constant"
     attribute {
       name: "value"
       t {
         dims: 1
         data_type: 7
-        raw_data: "\030\000\000\000\000\000\000\000"
+        raw_data: "\377\377\377\377\377\377\377\377"
       }
       type: TENSOR
     }
   }
   node {
+    input: "2"
+    input: "3"
+    output: "4"
+    name: "Concat_3"
+    op_type: "Concat"
+    attribute {
+      name: "axis"
+      i: 0
+      type: INT
+    }
+  }
+  node {
     input: "0"
-    input: "1"
-    output: "2"
-    name: "Reshape_1"
+    input: "4"
+    output: "5"
+    name: "Reshape_4"
     op_type: "Reshape"
   }
   name: "torch-jit-export"
@@ -47,7 +86,7 @@
     }
   }
   output {
-    name: "2"
+    name: "5"
     type {
       tensor_type {
         elem_type: 1
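
For context (not part of the diff): in this expected graph, torch.flatten is called with the default start_dim=0, so the Slice takes shape[0:0] (an empty prefix) and the Concat with the Constant -1 (the \377... int64 raw_data above) yields the 1-D target shape [-1]. A hypothetical way to view the resulting node sequence:

```python
import io
import onnx
import torch

class Model(torch.nn.Module):
    def forward(self, x):
        return torch.flatten(x)  # start_dim=0, end_dim=-1

f = io.BytesIO()
torch.onnx.export(Model(), torch.randn(1, 2, 3, 4), f, opset_version=9)
graph = onnx.load_from_string(f.getvalue()).graph
# With this change (at opset 9, matching the expect file above) this prints:
# ['Shape', 'Slice', 'Constant', 'Concat', 'Reshape']
print([node.op_type for node in graph.node])
```
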
diff --git a/test/onnx/test_pytorch_onnx_caffe2.py b/test/onnx/test_pytorch_onnx_caffe2.py
index 54e0816..0dbb9f0 100644
--- a/test/onnx/test_pytorch_onnx_caffe2.py
+++ b/test/onnx/test_pytorch_onnx_caffe2.py
@@ -1620,7 +1620,7 @@
         self.run_model_test(ScatterModel(), train=False, input=(input, indices, values),
                             batch_size=BATCH_SIZE, use_gpu=False)
 
-
+    @skipIfUnsupportedOpsetVersion([10])
     def test_flatten(self):
         class FlattenModel(torch.nn.Module):
             def forward(self, input):
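
A note on the opset-10 skip (my reading, not stated in the commit): starting with opset 10, Slice takes starts/ends/axes as tensor inputs rather than attributes, and the Caffe2 backend driving this test presumably does not handle that input form. A hypothetical sketch of how a slice helper might dispatch on opset version (not the actual _slice_helper):

```python
import torch

def slice_op(g, input, axes, starts, ends, opset_version):
    if opset_version < 10:
        # Opset 9: starts/ends/axes are node attributes.
        return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends)
    # Opset >= 10: starts/ends/axes become tensor inputs.
    as_const = lambda v: g.op("Constant",
                              value_t=torch.tensor(v, dtype=torch.long))
    return g.op("Slice", input, as_const(starts), as_const(ends), as_const(axes))
```
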
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 72fa374..fc3b610 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -2301,6 +2301,22 @@
         x = torch.randint(10, (1, 2, 3, 4))
         self.run_test(FlattenModel(), x)
 
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_flatten_dynamic_axes(self):
+        class MyModule(torch.nn.Module):
+            def forward(self, x):
+                return torch.flatten(x, start_dim=2, end_dim=3)
+
+        batch_size = 3
+        x = torch.randn(batch_size, 5, 4, 5)
+        y = torch.randn(5, 5, 4, 5)
+        model = MyModule()
+        self.run_test(model, x, test_with_inputs=[y],
+                      input_names=['input'],
+                      output_names=['output'],
+                      dynamic_axes={'input' : {0 : 'batch_size'},
+                                    'output' : {0 : 'batch_size'}})
+
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_getitem(self):
         class GetItemModel(torch.jit.ScriptModule):
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
index 4629748..46cbab8 100644
--- a/torch/onnx/symbolic_helper.py
+++ b/torch/onnx/symbolic_helper.py
@@ -441,6 +441,18 @@
                       op_mode + " mode. The model will be exported in " +
                       training_mode + ", as specified by the export mode.")
 
+def _flatten_helper(g, input, start_dim, end_dim, dim):
+    input_size = g.op("Shape", input)
+    slice1 = _slice_helper(g, input_size, axes=[0], starts=[0], ends=[start_dim])
+    slices = [slice1, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long))]
+    if end_dim < dim - 1:
+        slice3 = _slice_helper(g, input_size, axes=[0], starts=[end_dim + 1], ends=[dim])
+        slices = [slice1, g.op("Constant", value_t=torch.tensor([-1], dtype=torch.long)), slice3]
+
+    final_shape = g.op("Concat", *slices, axis_i=0)
+    from torch.onnx.symbolic_opset9 import _reshape_from_tensor
+    return _reshape_from_tensor(g, input, final_shape)
+
 # ---------------------------------------------------------------------
 # ONNX operator version
 # ---------------------------------------------------------------------
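
To make the new helper concrete, here is a plain-Python rendering (illustrative only; the function name is made up) of the target shape that the Shape-Slice-Concat subgraph computes at runtime:

```python
def flatten_target_shape(input_shape, start_dim, end_dim):
    # Mirrors _flatten_helper: keep dims before start_dim, collapse
    # [start_dim, end_dim] into -1, keep any dims after end_dim.
    dim = len(input_shape)
    shape = list(input_shape[:start_dim]) + [-1]
    if end_dim < dim - 1:
        shape += list(input_shape[end_dim + 1:])
    return shape

assert flatten_target_shape((3, 5, 4, 5), start_dim=2, end_dim=3) == [3, 5, -1]
assert flatten_target_shape((3, 5, 4, 5), start_dim=1, end_dim=2) == [3, -1, 5]
```
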
diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py
index b02f018..3a83880 100644
--- a/torch/onnx/symbolic_opset11.py
+++ b/torch/onnx/symbolic_opset11.py
@@ -648,29 +648,20 @@
 @parse_args('v', 'i', 'i')
 def flatten(g, input, start_dim, end_dim):
     dim = input.type().dim()
+    if dim is None:
+        return _unimplemented("dim",
+                              "ONNX and PyTorch use different strategies to split the input. "
+                              "Input rank must be known at export time.")
+
     # use ONNX's Flatten operator for cases where the output shape is 2D
     if start_dim == 1:
-        if (end_dim == -1 or (end_dim is not None and end_dim == dim - 1)):
+        if (end_dim == -1 or end_dim == dim - 1):
             return g.op("Flatten", input, axis_i=start_dim)
     elif start_dim == 0:
-        if (end_dim == -2 or (end_dim is not None and end_dim == dim - 2)):
+        if (end_dim == -2 or end_dim == dim - 2):
             return g.op("Flatten", input, axis_i=end_dim + 1)
-    # use Reshape for cases where the output shape is not 2D
-    if not input.isCompleteTensor():
-        return _unimplemented("flatten",
-                              "input size not accessible "
-                              "(consider using reshape op instead of flatten op to export to ONNX)")
     # if end_dim is negative add dim
     if end_dim < 0 :
         end_dim = dim + end_dim
-    input_dims = input.type().sizes()
-    output_dims = []
-    for i in range(0, dim):
-        if start_dim < i and end_dim >= i:
-            output_dims[start_dim] = output_dims[start_dim] * input_dims[i]
-        else:
-            output_dims.append(input_dims[i])
-    shape = g.op("Constant", value_t=torch.LongTensor(output_dims))
-    from torch.onnx.symbolic_opset9 import _reshape_from_tensor
-    p = _reshape_from_tensor(g, input, shape)
-    return p
+
+    return sym_help._flatten_helper(g, input, start_dim, end_dim, dim)
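
As a sanity check (illustrative, not part of the commit), the two fast paths above match torch.flatten semantics for the 2-D-output cases that ONNX's Flatten operator can express directly:

```python
import torch

x = torch.randn(2, 3, 4, 5)
# start_dim == 1, end_dim == dim - 1  ->  Flatten(axis=1): (d0, d1*d2*d3)
assert torch.flatten(x, 1, -1).shape == (2, 3 * 4 * 5)
# start_dim == 0, end_dim == dim - 2  ->  Flatten(axis=dim-1): (d0*d1*d2, d3)
assert torch.flatten(x, 0, -2).shape == (2 * 3 * 4, 5)
```
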
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index bc0bc4a..a96d6c0 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -1992,6 +1992,11 @@
 @parse_args('v', 'i', 'i')
 def flatten(g, input, start_dim, end_dim):
     dim = input.type().dim()
+    if dim is None:
+        return _unimplemented("dim",
+                              "ONNX and PyTorch use different strategies to split the input. "
+                              "Input rank must be known at export time.")
+
     # TODO: remove this as onnx opset 11 spec allows negative axes
     if end_dim < 0 :
         end_dim = dim + end_dim
@@ -2000,22 +2005,8 @@
         return g.op("Flatten", input, axis_i=start_dim)
     if start_dim == 0 and end_dim == dim - 2 :
         return g.op("Flatten", input, axis_i=end_dim + 1)
-    # use Reshape for cases where the output shape is not 2D
-    if not input.isCompleteTensor():
-        return _unimplemented("flatten",
-                              "input size not accessible "
-                              "(consider using reshape op instead of flatten op to export to ONNX)")
-    input_dims = input.type().sizes()
-    output_dims = []
-    for i in range(0, dim):
-        if start_dim < i and end_dim >= i:
-            output_dims[start_dim] = output_dims[start_dim] * input_dims[i]
-        else:
-            output_dims.append(input_dims[i])
-    shape = g.op("Constant", value_t=torch.LongTensor(output_dims))
-    p = _reshape_from_tensor(g, input, shape)
-    return p
 
+    return sym_help._flatten_helper(g, input, start_dim, end_dim, dim)
 
 @parse_args('v')
 def nonzero(g, input):