Register aten.repeat in partitioner, add aten.sqrt (#3510)
Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/3510
As the title says; also fixed some nits in the unit tests.
ghstack-source-id: 225119225
Reviewed By: copyrightly, jorgep31415
Differential Revision: D56961340
fbshipit-source-id: c9b505a9003c7f58541c7799ea290d1867579eba
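
Note on the unittest nit: `assert_outputs_equal` now forwards `equal_nan=True` to `torch.allclose`. `torch.sqrt` over `randn` inputs produces NaN for the negative entries in both the eager reference and the Vulkan output, and plain `allclose` treats NaN as unequal to everything, including NaN. A minimal illustration in plain PyTorch (independent of this diff):

```python
import torch

x = torch.tensor([-1.0, 4.0])
y = torch.sqrt(x)  # tensor([nan, 2.]) -- sqrt of a negative entry is NaN

# By default, allclose reports NaN != NaN, so identical outputs "differ":
assert not torch.allclose(y, y.clone())
# With equal_nan=True, NaNs at matching positions compare equal:
assert torch.allclose(y, y.clone(), equal_nan=True)
```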
diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index 0e171dd..ac7df8c 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -42,6 +42,7 @@
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.tanh.default,
+ exir_ops.edge.aten.sqrt.default,
# Matrix multiplication operators
exir_ops.edge.aten.mm.default,
# Pooling operators
@@ -62,6 +63,7 @@
exir_ops.edge.aten.split_with_sizes_copy.default,
exir_ops.edge.aten.split.Tensor,
exir_ops.edge.aten.slice_copy.Tensor,
+ exir_ops.edge.aten.repeat.default,
# Other
operator.getitem,
exir_ops.edge.aten.full.default,
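
The two new entries extend the partitioner's allowlist of edge-dialect ops that may be delegated to Vulkan. A simplified sketch of how such an allowlist is typically consulted when deciding whether to claim a node; the names below (`SUPPORTED_OPS`, `is_node_supported`) are illustrative, not copied from `vulkan_partitioner.py`:

```python
import torch.fx

# Placeholder strings; the real allowlist holds exir_ops.edge.* op objects.
SUPPORTED_OPS = {"aten.sqrt.default", "aten.repeat.default"}

def is_node_supported(node: torch.fx.Node) -> bool:
    # Only call_function nodes whose target is allowlisted get delegated
    # to the Vulkan backend; everything else stays on the default path.
    return node.op == "call_function" and str(node.target) in SUPPORTED_OPS
```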
diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
index c32593d..2d8ec36 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
@@ -16,3 +16,5 @@
OPERATOR: 1 / (1 + exp(-1 * X))
- NAME: tanh
OPERATOR: tanh(clamp(X, -15.0, 15.0))
+ - NAME: sqrt
+ OPERATOR: sqrt(X)
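
Each `NAME`/`OPERATOR` pair in `unary_op.yaml` yields one generated variant of the unary-op compute shader, with `OPERATOR` spliced in as the elementwise expression applied to each texel `X`. A rough sketch of that substitution step; the shader text below is an illustrative stand-in, not the actual GLSL template:

```python
from string import Template

# Stand-in for the generated shader body; only the ${OPERATOR}
# substitution mechanism is the point here.
SHADER_TEMPLATE = Template("""\
void main() {
  vec4 X = imageLoad(image_in, pos);
  imageStore(image_out, pos, ${OPERATOR});
}
""")

print(SHADER_TEMPLATE.substitute(OPERATOR="sqrt(X)"))
```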
diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
index 3888118..b2fb113 100644
--- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
@@ -103,6 +103,7 @@
DEFINE_ACTIVATION_FN(abs);
DEFINE_ACTIVATION_FN(sigmoid);
DEFINE_ACTIVATION_FN(tanh);
+DEFINE_ACTIVATION_FN(sqrt);
DEFINE_CLAMP_FN(clamp);
DEFINE_CLAMP_FN(hardtanh);
DEFINE_RELU_FN(relu);
@@ -114,6 +115,7 @@
VK_REGISTER_OP(aten.relu.default, relu);
VK_REGISTER_OP(aten.sigmoid.default, sigmoid);
VK_REGISTER_OP(aten.tanh.default, tanh);
+ VK_REGISTER_OP(aten.sqrt.default, sqrt);
}
} // namespace vkcompute
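
On the runtime side, `DEFINE_ACTIVATION_FN(sqrt)` generates the handler that dispatches the `sqrt` shader variant, and `VK_REGISTER_OP(aten.sqrt.default, sqrt)` binds it to the op name the partitioner delegates, so the string here must line up with both the partitioner entry and the shader `NAME`. In spirit the macro populates a name-to-handler table; a Python analogy (not the actual C++ API):

```python
from typing import Any, Callable

# Lookup table from edge-op name to the function that builds the
# corresponding compute node in the Vulkan graph (hypothetical names).
OP_REGISTRY: dict[str, Callable[..., Any]] = {}

def register_op(name: str, handler: Callable[..., Any]) -> None:
    OP_REGISTRY[name] = handler

def sqrt_handler(graph: Any, args: Any) -> None:
    """Would add a unary-op shader dispatch using the 'sqrt' variant."""

register_op("aten.sqrt.default", sqrt_handler)
```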
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index 329ae02..13b5542 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -555,6 +555,18 @@
return test_suite
+def get_unary_ops_inputs():
+ test_suite = VkTestSuite(
+ [
+ (M1,),
+ (M1, M2),
+ (S1, M1, M2),
+ (S1, S2, S2, M2),
+ ]
+ )
+ return test_suite
+
+
test_suites = {
"aten.add.Tensor": get_binary_elementwise_inputs(),
"aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -577,4 +589,5 @@
"aten.cat.default": get_cat_inputs(),
"aten.split_with_sizes_copy.default": get_split_with_sizes_inputs(),
"aten.split.Tensor": get_split_tensor_inputs(),
+ "aten.sqrt.default": get_unary_ops_inputs(),
}
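
`get_unary_ops_inputs` covers 1-D through 4-D shapes; `M1`/`M2`/`S1`/`S2` are size constants defined elsewhere in `cases.py`. Each tuple becomes a random input tensor for the generated op test, roughly as below (placeholder sizes; note that `randn` yields negative entries, so `sqrt` outputs contain NaN and the comparison needs `equal_nan=True`):

```python
import torch

# Placeholder values; the real M*/S* constants live in cases.py.
M1, M2, S1, S2 = 37, 41, 7, 5

for shape in [(M1,), (M1, M2), (S1, M1, M2), (S1, S2, S2, M2)]:
    x = torch.randn(shape)
    ref = torch.sqrt(x)  # eager reference; NaN wherever x < 0
    assert torch.allclose(ref, x.sqrt(), equal_nan=True)
```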
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index 07c7de0..a896ed9 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -29,7 +29,13 @@
class TestBackends(unittest.TestCase):
def assert_outputs_equal(
- self, model_output, ref_output, atol=1e-03, rtol=1e-03, first_output_only=False
+ self,
+ model_output,
+ ref_output,
+ atol=1e-03,
+ rtol=1e-03,
+ first_output_only=False,
+ equal_nan=True,
):
"""
Helper testing function that asserts that the model output and the reference output
@@ -44,19 +50,35 @@
self.assertTrue(len(ref_output) == len(model_output))
if first_output_only:
self.assertTrue(
- torch.allclose(model_output[0], ref_output[0], atol=atol, rtol=rtol)
+ torch.allclose(
+ model_output[0],
+ ref_output[0],
+ atol=atol,
+ rtol=rtol,
+ equal_nan=equal_nan,
+ )
)
else:
for i in range(len(ref_output)):
self.assertTrue(
torch.allclose(
- model_output[i], ref_output[i], atol=atol, rtol=rtol
+ model_output[i],
+ ref_output[i],
+ atol=atol,
+ rtol=rtol,
+ equal_nan=equal_nan,
)
)
else:
            # If there is one output, eager returns a tensor while the executor returns a tuple of size 1
self.assertTrue(
- torch.allclose(model_output[0], ref_output, atol=atol, rtol=rtol)
+ torch.allclose(
+ model_output[0],
+ ref_output,
+ atol=atol,
+ rtol=rtol,
+ equal_nan=equal_nan,
+ )
)
def lower_module_and_test_output(
@@ -304,7 +326,7 @@
self.lower_module_and_test_output(pow_module, sample_inputs)
- def lower_clamp_module_and_test_output(self, module):
+ def lower_unary_module_and_test_output(self, module):
batch = Dim("batch", max=8)
sample_inputs = (torch.randn(8, 16, 96, 92),)
@@ -314,6 +336,7 @@
(torch.randn(6, 5, 35, 89),),
(torch.randn(7, 9, 32, 38),),
]
+
self.lower_module_and_test_output(
module,
sample_inputs,
@@ -329,7 +352,7 @@
def forward(self, x):
return torch.clamp(x, min=-3.14)
- self.lower_clamp_module_and_test_output(ClampModule())
+ self.lower_unary_module_and_test_output(ClampModule())
def test_vulkan_backend_hardtanh(self):
class HardTanHModule(torch.nn.Module):
@@ -340,7 +363,7 @@
def forward(self, x):
return self.tanh(x)
- self.lower_clamp_module_and_test_output(HardTanHModule())
+ self.lower_unary_module_and_test_output(HardTanHModule())
def test_vulkan_backend_relu(self):
class ReLUModule(torch.nn.Module):
@@ -350,7 +373,17 @@
def forward(self, x):
return torch.relu(x)
- self.lower_clamp_module_and_test_output(ReLUModule())
+ self.lower_unary_module_and_test_output(ReLUModule())
+
+ def test_vulkan_backend_sqrt(self):
+ class SqrtModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return torch.sqrt(x)
+
+ self.lower_unary_module_and_test_output(SqrtModule())
def test_vulkan_backend_max_pool2d(self):
class MaxPool2dModule(torch.nn.Module):
@@ -395,7 +428,7 @@
def forward(self, x):
return torch.abs(x)
- self.lower_clamp_module_and_test_output(AbsModule())
+ self.lower_unary_module_and_test_output(AbsModule())
def test_vulkan_backend_sigmoid(self):
class SigmoidModule(torch.nn.Module):
@@ -405,7 +438,7 @@
def forward(self, x):
return torch.sigmoid(x)
- self.lower_clamp_module_and_test_output(SigmoidModule())
+ self.lower_unary_module_and_test_output(SigmoidModule())
def test_vulkan_backend_tanh(self):
class TanhModule(torch.nn.Module):
@@ -415,7 +448,7 @@
def forward(self, x):
return torch.tanh(x)
- self.lower_clamp_module_and_test_output(TanhModule())
+ self.lower_unary_module_and_test_output(TanhModule())
def test_vulkan_backend_partial(self):
class SimpleModel(torch.nn.Module):
@@ -905,6 +938,22 @@
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
)
+ def test_vulkan_backend_repeat(self):
+ class TestModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x.repeat([2, 3, 1, 2])
+
+ sample_inputs = (torch.randn(size=(3, 7, 5, 9), dtype=torch.float32),)
+
+ self.lower_module_and_test_output(
+ TestModule(),
+ sample_inputs,
+ memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+ )
+
def DISABLED_test_vulkan_backend_t_default(self):
# aten.permute_copy.default is not enabled yet in partitioner
class TestModule(torch.nn.Module):
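
For reference, the shape arithmetic behind the new repeat test: `Tensor.repeat` multiplies each dimension by its repeat factor when the repeat list matches the tensor's rank, so the `(3, 7, 5, 9)` input with repeats `[2, 3, 1, 2]` yields a `(6, 21, 5, 18)` output. Quick eager-mode check:

```python
import torch

x = torch.randn(3, 7, 5, 9)
y = x.repeat([2, 3, 1, 2])  # each dim size times its repeat factor
assert y.shape == torch.Size([6, 21, 5, 18])
```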