[JIT] ShapeProp: add missing ops from mobilenet v3. (#59163)
Summary: Adds shape functions for the remaining ops hit by MobileNet v3 (mul.Scalar, div.Scalar, add.Scalar, relu, view, expand_as, mean.dim, addmm). Pull Request resolved: https://github.com/pytorch/pytorch/pull/59163
Test Plan: Imported from OSS
Reviewed By: ejguan
Differential Revision: D28853833
Pulled By: ZolotukhinM
fbshipit-source-id: 451fb9ee848968049d26fb5623a904d8fa7bd6fc
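Each entry in schema_to_function_graph below maps an operator schema to a shape function that receives the input *shapes* as List[int], plus Any placeholders for arguments that cannot affect the output shape, and returns the output shape. As a rough plain-Python sketch of that convention (not part of the patch; unary_two_unused_inputs already exists in the registry, this just mirrors its pattern):

    from typing import Any, List

    # Sketch of the "unused input" convention: aten::add.Scalar's scalar
    # operand and alpha cannot change the output shape, so its shape function
    # copies the input shape and ignores the two placeholders.
    def unary_two_unused_inputs(self: List[int], inp0: Any, inp1: Any) -> List[int]:
        out: List[int] = []
        for elem in self:  # copy rather than alias the input shape list
            out.append(elem)
        return out

    assert unary_two_unused_inputs([8, 16, 1], 2.0, 1) == [8, 16, 1]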
diff --git a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
index a4ca941..fa65872 100644
--- a/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
+++ b/torch/csrc/jit/runtime/symbolic_shape_registry.cpp
@@ -49,6 +49,44 @@
         out.append(elem)
     return out
+def unary_one_unused_input(self: List[int], inp0: Any):
+    out: List[int] = []
+    for elem in self:
+        out.append(elem)
+    return out
+
+def unary(self: List[int]):
+    out: List[int] = []
+    for elem in self:
+        out.append(elem)
+    return out
+
+def view(self: List[int], sizes: List[int]):
+    # TODO: add assertions to check whether requested dims are valid
+    out: List[int] = []
+    for elem in sizes:
+        if elem == -1:
+            # TODO: support -1 in view dimensions
+            raise AssertionError("Shape function doesn't support -1 view dims yet")
+        out.append(elem)
+    return out
+
+def mean_dim(self: List[int], dims: List[int], keep_dim: bool, dt: Any):
+    out: List[int] = []
+    idx: int = 0
+    for elem in self:
+        is_mean_dim: bool = False
+        for reduce_dim in dims:
+            if idx == reduce_dim:
+                is_mean_dim = True
+        if is_mean_dim:
+            if keep_dim:
+                out.append(1)
+        else:
+            out.append(elem)
+        idx = idx + 1
+    return out
+
 def broadcast_one_unused_input(self: List[int], other: List[int], unused: Any):
     return broadcast(self, other)
@@ -150,6 +188,12 @@
         assert broadcast(bias, out) == out
     return out
+def addmm(self: List[int], mat1: List[int], mat2: List[int], beta: Any, alpha: Any):
+    # addmm computes beta * self + alpha * (mat1 @ mat2); only shapes matter here
+    out = mm(mat1, mat2)
+    assert broadcast(self, out) == out
+    return out
+
 def check_non_negative(array: List[int]) -> bool:
     for val in array:
         if val < 0:
@@ -264,9 +308,12 @@
 // clang-format off
 static const OperatorMap<std::string> schema_to_function_graph{
     {"aten::mul.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
+    {"aten::mul.Scalar(Tensor self, Scalar other) -> Tensor", "unary_one_unused_input"},
     {"aten::div.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
+    {"aten::div.Scalar(Tensor self, Scalar other) -> Tensor", "unary_one_unused_input"},
     {"aten::gt.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},
     {"aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "broadcast_one_unused_input"},
+    {"aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "unary_two_unused_inputs"},
     {"aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor", "unary_two_unused_inputs"},
     {"aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor", "adaptive_avg_pool2d"},
     {"aten::mm(Tensor self, Tensor mat2) -> Tensor", "mm"},
@@ -279,6 +326,11 @@
{"aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor", "conv2d"},
{"aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, int[3] stride=1, int[3] padding=0, int[3] dilation=1, int groups=1) -> Tensor", "conv3d"},
{"aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", "flatten"},
+ {"aten::relu(Tensor self) -> Tensor", "unary"},
+ {"aten::view(Tensor(a) self, int[] size) -> Tensor(a)", "view"},
+ {"aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", "view"},
+ {"aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "mean_dim"},
+ {"aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "addmm"},
};
// clang-format on
return schema_to_function_graph;
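A quick standalone sanity check (not part of the patch) that the mean_dim shape function above agrees with eager-mode torch.mean. Note that, like the registry version, this sketch compares against non-negative dim indices only and does not wrap negative dims:

    from typing import Any, List

    import torch

    def mean_dim(self: List[int], dims: List[int], keep_dim: bool, dt: Any) -> List[int]:
        out: List[int] = []
        idx: int = 0
        for elem in self:
            is_mean_dim: bool = False
            for reduce_dim in dims:
                if idx == reduce_dim:
                    is_mean_dim = True
            if is_mean_dim:
                if keep_dim:
                    out.append(1)  # a reduced dim collapses to size 1 when kept
            else:
                out.append(elem)
            idx = idx + 1
        return out

    x = torch.randn(2, 3, 4)
    assert mean_dim([2, 3, 4], [1], False, None) == list(x.mean(dim=[1]).shape)                # [2, 4]
    assert mean_dim([2, 3, 4], [1], True, None) == list(x.mean(dim=[1], keepdim=True).shape)   # [2, 1, 4]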