[Core ML] Support enumerated input shapes (#74441)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74441
For xirp based segmentation models, we want to support enumerated input shapes. This allows us to support both landscape and portrait mode images without sacrificing the performance. P488118264
ghstack-source-id: 151736964
Test Plan: `buck run coreml:xirp -- --model="/home/taox/xirp/xirp_20a.pt" --out="/home/taox/xirp/xirp_20a_coreml_enumerated.ptl"`
Reviewed By: mcr229
Differential Revision: D34803184
fbshipit-source-id: c462c0783846a1489ca7ce4d5a654aa6927c9c44
(cherry picked from commit 67d418c97531daaf3d03d1000ca4a4ff60de2a95)
diff --git a/torch/backends/_coreml/preprocess.py b/torch/backends/_coreml/preprocess.py
index 7f27e60..3884058 100644
--- a/torch/backends/_coreml/preprocess.py
+++ b/torch/backends/_coreml/preprocess.py
@@ -1,7 +1,6 @@
import hashlib
import json
-from dataclasses import dataclass, astuple, field
-from typing import Dict, Tuple, List
+from typing import Dict, Tuple
import coremltools as ct # type: ignore[import]
import torch
@@ -35,86 +34,56 @@
ALL = "all"
-@dataclass
-class _TensorSpec:
- shape: List[int] = field(default_factory=List[int])
- dtype: int = ScalarType.Float
+def TensorSpec(shape, dtype=ScalarType.Float):
+ return (shape, dtype)
-def TensorSpec(*args, **kwargs):
- """
- TensorSpec specifies the tensor information. The default dtype is float32
- Example:
- ts = TensorSpec(
- shape = [1, 3, 224, 224],
- dtype = ScalarType.Float
- )
- """
- return astuple(_TensorSpec(*args, **kwargs))
+def CompileSpec(inputs, outputs, backend=CoreMLComputeUnit.CPU, allow_low_precision=True):
+ return (inputs, outputs, backend, allow_low_precision)
-@dataclass
-class _CompileSpec:
- inputs: Tuple[_TensorSpec] = () # type: ignore[assignment]
- outputs: Tuple[_TensorSpec] = () # type: ignore[assignment]
- backend: str = CoreMLComputeUnit.CPU
- allow_low_precision: bool = True
+def _check_enumerated_shape(shape):
+ for s in shape:
+ if not isinstance(s, (list, tuple)):
+ return False
+ return True
-def CompileSpec(*args, **kwargs):
- """
- CompileSpec specifies the model information.
- Example:
- cs = CompileSpec(
- inputs=(
- TensorSpec(
- shape=[1, 3, 224, 224],
- ),
- ),
- outputs=(
- TensorSpec(
- shape=[1, 1000],
- ),
- ),
- backend=CoreMLComputeUnit.CPU,
- allow_low_precision=True,
- ),
- """
- return astuple(_CompileSpec(*args, **kwargs))
-
-
-def _convert_to_mil_type(spec: _TensorSpec, name: str):
- ml_type = TensorType(shape=spec.shape, dtype=torch_to_mil_types[spec.dtype])
+def _convert_to_mil_type(shape, dtype, name: str):
+ mil_shape = shape
+ if _check_enumerated_shape(shape):
+ mil_shape = ct.EnumeratedShapes(shape)
+ ml_type = TensorType(shape=mil_shape, dtype=torch_to_mil_types[dtype])
ml_type.name = name
return ml_type
def preprocess(script_module: torch._C.ScriptObject, compile_spec: Dict[str, Tuple]):
spec = compile_spec["forward"]
- forward_spec = _CompileSpec(*spec)
+ input_specs, output_specs, backend, allow_low_precision = spec
mil_inputs = []
inputs = []
- for index, input_spec in enumerate(forward_spec.inputs):
- input_spec = _TensorSpec(*input_spec) # type: ignore[misc]
+ for index, input in enumerate(input_specs):
+ shape, dtype = input
name = "input_" + str(index)
- inputs.append([name, str(input_spec.dtype), str(input_spec.shape)])
- ml_type = _convert_to_mil_type(input_spec, name)
+ inputs.append([name, str(dtype), str(shape)])
+ ml_type = _convert_to_mil_type(shape, dtype, name)
mil_inputs.append(ml_type)
model = torch.jit.RecursiveScriptModule._construct(script_module, lambda x: None)
mlmodel = ct.convert(model, inputs=mil_inputs)
spec = mlmodel.get_spec()
- output_specs = forward_spec.outputs
assert len(spec.description.output) == len(output_specs) # type: ignore[attr-defined]
outputs = []
- for index, output_spec in enumerate(output_specs):
- output_spec = _TensorSpec(*output_spec) # type: ignore[misc]
+ for index, output in enumerate(output_specs):
+ shape, dtype = output
name = spec.description.output[index].name # type: ignore[attr-defined]
- outputs.append([name, str(output_spec.dtype), str(output_spec.shape)])
+ outputs.append([name, str(dtype), str(shape)])
mlmodel = ct.models.model.MLModel(spec)
+ print(mlmodel)
config = {
"spec_ver": str(spec.specificationVersion), # type: ignore[attr-defined]
- "backend": forward_spec.backend,
- "allow_low_precision": str(forward_spec.allow_low_precision),
+ "backend": backend,
+ "allow_low_precision": str(allow_low_precision),
}
metadata = {
"coremltool_ver": mlmodel.user_defined_metadata[CT_METADATA_VERSION],