[Core ML] Support enumerated input shapes (#74441)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74441

For xirp based segmentation models, we want to support enumerated input shapes. This allows us to support both landscape and portrait mode images without sacrificing the performance. P488118264
ghstack-source-id: 151736964

Test Plan: `buck run coreml:xirp -- --model="/home/taox/xirp/xirp_20a.pt" --out="/home/taox/xirp/xirp_20a_coreml_enumerated.ptl"`

Reviewed By: mcr229

Differential Revision: D34803184

fbshipit-source-id: c462c0783846a1489ca7ce4d5a654aa6927c9c44
(cherry picked from commit 67d418c97531daaf3d03d1000ca4a4ff60de2a95)
diff --git a/torch/backends/_coreml/preprocess.py b/torch/backends/_coreml/preprocess.py
index 7f27e60..3884058 100644
--- a/torch/backends/_coreml/preprocess.py
+++ b/torch/backends/_coreml/preprocess.py
@@ -1,7 +1,6 @@
 import hashlib
 import json
-from dataclasses import dataclass, astuple, field
-from typing import Dict, Tuple, List
+from typing import Dict, Tuple
 
 import coremltools as ct  # type: ignore[import]
 import torch
@@ -35,86 +34,56 @@
     ALL = "all"
 
 
-@dataclass
-class _TensorSpec:
-    shape: List[int] = field(default_factory=List[int])
-    dtype: int = ScalarType.Float
+def TensorSpec(shape, dtype=ScalarType.Float):
+    return (shape, dtype)
 
 
-def TensorSpec(*args, **kwargs):
-    """
-    TensorSpec specifies the tensor information. The default dtype is float32
-    Example:
-    ts = TensorSpec(
-        shape = [1, 3, 224, 224],
-        dtype = ScalarType.Float
-    )
-    """
-    return astuple(_TensorSpec(*args, **kwargs))
+def CompileSpec(inputs, outputs, backend=CoreMLComputeUnit.CPU, allow_low_precision=True):
+    return (inputs, outputs, backend, allow_low_precision)
 
 
-@dataclass
-class _CompileSpec:
-    inputs: Tuple[_TensorSpec] = ()  # type: ignore[assignment]
-    outputs: Tuple[_TensorSpec] = ()  # type: ignore[assignment]
-    backend: str = CoreMLComputeUnit.CPU
-    allow_low_precision: bool = True
+def _check_enumerated_shape(shape):
+    for s in shape:
+        if not isinstance(s, (list, tuple)):
+            return False
+    return True
 
 
-def CompileSpec(*args, **kwargs):
-    """
-    CompileSpec specifies the model information.
-    Example:
-    cs = CompileSpec(
-            inputs=(
-                TensorSpec(
-                    shape=[1, 3, 224, 224],
-                ),
-            ),
-            outputs=(
-                TensorSpec(
-                    shape=[1, 1000],
-                ),
-            ),
-            backend=CoreMLComputeUnit.CPU,
-            allow_low_precision=True,
-    ),
-    """
-    return astuple(_CompileSpec(*args, **kwargs))
-
-
-def _convert_to_mil_type(spec: _TensorSpec, name: str):
-    ml_type = TensorType(shape=spec.shape, dtype=torch_to_mil_types[spec.dtype])
+def _convert_to_mil_type(shape, dtype, name: str):
+    mil_shape = shape
+    if _check_enumerated_shape(shape):
+        mil_shape = ct.EnumeratedShapes(shape)
+    ml_type = TensorType(shape=mil_shape, dtype=torch_to_mil_types[dtype])
     ml_type.name = name
     return ml_type
 
 
 def preprocess(script_module: torch._C.ScriptObject, compile_spec: Dict[str, Tuple]):
     spec = compile_spec["forward"]
-    forward_spec = _CompileSpec(*spec)
+    input_specs, output_specs, backend, allow_low_precision = spec
     mil_inputs = []
     inputs = []
-    for index, input_spec in enumerate(forward_spec.inputs):
-        input_spec = _TensorSpec(*input_spec)  # type: ignore[misc]
+    for index, input in enumerate(input_specs):
+        shape, dtype = input
         name = "input_" + str(index)
-        inputs.append([name, str(input_spec.dtype), str(input_spec.shape)])
-        ml_type = _convert_to_mil_type(input_spec, name)
+        inputs.append([name, str(dtype), str(shape)])
+        ml_type = _convert_to_mil_type(shape, dtype, name)
         mil_inputs.append(ml_type)
     model = torch.jit.RecursiveScriptModule._construct(script_module, lambda x: None)
     mlmodel = ct.convert(model, inputs=mil_inputs)
     spec = mlmodel.get_spec()
-    output_specs = forward_spec.outputs
     assert len(spec.description.output) == len(output_specs)  # type: ignore[attr-defined]
     outputs = []
-    for index, output_spec in enumerate(output_specs):
-        output_spec = _TensorSpec(*output_spec)  # type: ignore[misc]
+    for index, output in enumerate(output_specs):
+        shape, dtype = output
         name = spec.description.output[index].name  # type: ignore[attr-defined]
-        outputs.append([name, str(output_spec.dtype), str(output_spec.shape)])
+        outputs.append([name, str(dtype), str(shape)])
     mlmodel = ct.models.model.MLModel(spec)
+    print(mlmodel)
     config = {
         "spec_ver": str(spec.specificationVersion),  # type: ignore[attr-defined]
-        "backend": forward_spec.backend,
-        "allow_low_precision": str(forward_spec.allow_low_precision),
+        "backend": backend,
+        "allow_low_precision": str(allow_low_precision),
     }
     metadata = {
         "coremltool_ver": mlmodel.user_defined_metadata[CT_METADATA_VERSION],