Register aten.repeat in partitioner, add aten.sqrt (#3510)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/3510

As titled. This also fixes a few nits in the unit tests.
ghstack-source-id: 225119225

Reviewed By: copyrightly, jorgep31415

Differential Revision: D56961340

fbshipit-source-id: c9b505a9003c7f58541c7799ea290d1867579eba
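
To make the new coverage concrete, here is a minimal sketch in the style of the existing unit tests (the module name is hypothetical and does not appear in this change). It exercises both newly registered ops; the equal_nan=True default added to assert_outputs_equal is presumably what lets sqrt on the negative entries of randn inputs, which produce NaN in both eager and Vulkan, still compare equal:

    import torch

    class SqrtRepeatModule(torch.nn.Module):  # hypothetical, for illustration only
        def forward(self, x):
            # aten.sqrt.default maps to the new "sqrt" unary shader;
            # aten.repeat.default is now claimed by the Vulkan partitioner.
            return torch.sqrt(x).repeat([2, 3, 1, 2])

    sample_inputs = (torch.randn(3, 7, 5, 9),)
    # In the test harness this would be driven through
    # TestBackends.lower_module_and_test_output(SqrtRepeatModule(), sample_inputs, ...),
    # much like the new test_vulkan_backend_sqrt and test_vulkan_backend_repeat tests below.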
diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index 0e171dd..ac7df8c 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -42,6 +42,7 @@
             exir_ops.edge.aten.relu.default,
             exir_ops.edge.aten.sigmoid.default,
             exir_ops.edge.aten.tanh.default,
+            exir_ops.edge.aten.sqrt.default,
             # Matrix multiplication operators
             exir_ops.edge.aten.mm.default,
             # Pooling operators
@@ -62,6 +63,7 @@
             exir_ops.edge.aten.split_with_sizes_copy.default,
             exir_ops.edge.aten.split.Tensor,
             exir_ops.edge.aten.slice_copy.Tensor,
+            exir_ops.edge.aten.repeat.default,
             # Other
             operator.getitem,
             exir_ops.edge.aten.full.default,
diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
index c32593d..2d8ec36 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
@@ -16,3 +16,5 @@
       OPERATOR: 1 / (1 + exp(-1 * X))
     - NAME: tanh
       OPERATOR: tanh(clamp(X, -15.0, 15.0))
+    - NAME: sqrt
+      OPERATOR: sqrt(X)
diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
index 3888118..b2fb113 100644
--- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
@@ -103,6 +103,7 @@
 DEFINE_ACTIVATION_FN(abs);
 DEFINE_ACTIVATION_FN(sigmoid);
 DEFINE_ACTIVATION_FN(tanh);
+DEFINE_ACTIVATION_FN(sqrt);
 DEFINE_CLAMP_FN(clamp);
 DEFINE_CLAMP_FN(hardtanh);
 DEFINE_RELU_FN(relu);
@@ -114,6 +115,7 @@
   VK_REGISTER_OP(aten.relu.default, relu);
   VK_REGISTER_OP(aten.sigmoid.default, sigmoid);
   VK_REGISTER_OP(aten.tanh.default, tanh);
+  VK_REGISTER_OP(aten.sqrt.default, sqrt);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index 329ae02..13b5542 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -555,6 +555,18 @@
     return test_suite
 
 
+def get_unary_ops_inputs():
+    test_suite = VkTestSuite(
+        [
+            (M1,),
+            (M1, M2),
+            (S1, M1, M2),
+            (S1, S2, S2, M2),
+        ]
+    )
+    return test_suite
+
+
 test_suites = {
     "aten.add.Tensor": get_binary_elementwise_inputs(),
     "aten.sub.Tensor": get_binary_elementwise_inputs(),
@@ -577,4 +589,5 @@
     "aten.cat.default": get_cat_inputs(),
     "aten.split_with_sizes_copy.default": get_split_with_sizes_inputs(),
     "aten.split.Tensor": get_split_tensor_inputs(),
+    "aten.sqrt.default": get_unary_ops_inputs(),
 }
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index 07c7de0..a896ed9 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -29,7 +29,13 @@
 
 class TestBackends(unittest.TestCase):
     def assert_outputs_equal(
-        self, model_output, ref_output, atol=1e-03, rtol=1e-03, first_output_only=False
+        self,
+        model_output,
+        ref_output,
+        atol=1e-03,
+        rtol=1e-03,
+        first_output_only=False,
+        equal_nan=True,
     ):
         """
         Helper testing function that asserts that the model output and the reference output
@@ -44,19 +50,35 @@
             self.assertTrue(len(ref_output) == len(model_output))
             if first_output_only:
                 self.assertTrue(
-                    torch.allclose(model_output[0], ref_output[0], atol=atol, rtol=rtol)
+                    torch.allclose(
+                        model_output[0],
+                        ref_output[0],
+                        atol=atol,
+                        rtol=rtol,
+                        equal_nan=equal_nan,
+                    )
                 )
             else:
                 for i in range(len(ref_output)):
                     self.assertTrue(
                         torch.allclose(
-                            model_output[i], ref_output[i], atol=atol, rtol=rtol
+                            model_output[i],
+                            ref_output[i],
+                            atol=atol,
+                            rtol=rtol,
+                            equal_nan=equal_nan,
                         )
                     )
         else:
             # If one output, eager returns tensor while executor tuple of size 1
             self.assertTrue(
-                torch.allclose(model_output[0], ref_output, atol=atol, rtol=rtol)
+                torch.allclose(
+                    model_output[0],
+                    ref_output,
+                    atol=atol,
+                    rtol=rtol,
+                    equal_nan=equal_nan,
+                )
             )
 
     def lower_module_and_test_output(
@@ -304,7 +326,7 @@
 
         self.lower_module_and_test_output(pow_module, sample_inputs)
 
-    def lower_clamp_module_and_test_output(self, module):
+    def lower_unary_module_and_test_output(self, module):
         batch = Dim("batch", max=8)
         sample_inputs = (torch.randn(8, 16, 96, 92),)
 
@@ -314,6 +336,7 @@
             (torch.randn(6, 5, 35, 89),),
             (torch.randn(7, 9, 32, 38),),
         ]
+
         self.lower_module_and_test_output(
             module,
             sample_inputs,
@@ -329,7 +352,7 @@
             def forward(self, x):
                 return torch.clamp(x, min=-3.14)
 
-        self.lower_clamp_module_and_test_output(ClampModule())
+        self.lower_unary_module_and_test_output(ClampModule())
 
     def test_vulkan_backend_hardtanh(self):
         class HardTanHModule(torch.nn.Module):
@@ -340,7 +363,7 @@
             def forward(self, x):
                 return self.tanh(x)
 
-        self.lower_clamp_module_and_test_output(HardTanHModule())
+        self.lower_unary_module_and_test_output(HardTanHModule())
 
     def test_vulkan_backend_relu(self):
         class ReLUModule(torch.nn.Module):
@@ -350,7 +373,17 @@
             def forward(self, x):
                 return torch.relu(x)
 
-        self.lower_clamp_module_and_test_output(ReLUModule())
+        self.lower_unary_module_and_test_output(ReLUModule())
+
+    def test_vulkan_backend_sqrt(self):
+        class SqrtModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.sqrt(x)
+
+        self.lower_unary_module_and_test_output(SqrtModule())
 
     def test_vulkan_backend_max_pool2d(self):
         class MaxPool2dModule(torch.nn.Module):
@@ -395,7 +428,7 @@
             def forward(self, x):
                 return torch.abs(x)
 
-        self.lower_clamp_module_and_test_output(AbsModule())
+        self.lower_unary_module_and_test_output(AbsModule())
 
     def test_vulkan_backend_sigmoid(self):
         class SigmoidModule(torch.nn.Module):
@@ -405,7 +438,7 @@
             def forward(self, x):
                 return torch.sigmoid(x)
 
-        self.lower_clamp_module_and_test_output(SigmoidModule())
+        self.lower_unary_module_and_test_output(SigmoidModule())
 
     def test_vulkan_backend_tanh(self):
         class TanhModule(torch.nn.Module):
@@ -415,7 +448,7 @@
             def forward(self, x):
                 return torch.tanh(x)
 
-        self.lower_clamp_module_and_test_output(TanhModule())
+        self.lower_unary_module_and_test_output(TanhModule())
 
     def test_vulkan_backend_partial(self):
         class SimpleModel(torch.nn.Module):
@@ -905,6 +938,22 @@
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
 
+    def test_vulkan_backend_repeat(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x.repeat([2, 3, 1, 2])
+
+        sample_inputs = (torch.randn(size=(3, 7, 5, 9), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            TestModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
     def DISABLED_test_vulkan_backend_t_default(self):
         # aten.permute_copy.default is not enabled yet in partitioner
         class TestModule(torch.nn.Module):