Add custom testing based on OpInfo inputs for ops with custom rules. (#67500)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67500

* #66898

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D32497547

Pulled By: Gamrix

fbshipit-source-id: 07761f0e27f4ac289377ff3279ce6470d4b727dd
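
As context for the diff below: the new tests pull sample inputs from the
OpInfo helpers (`sample_inputs_conv2d`, `sample_inputs_adaptive_avg_pool2d`)
instead of building tensors by hand. A minimal sketch of the pattern, assuming
the helper signature `(op_info, device, dtype, requires_grad)` used in the
test; names here are illustrative:

    import torch
    from torch.testing._internal.common_methods_invocations import (
        sample_inputs_conv2d,
    )

    # The last returned sample exercises conv2d's default arguments.
    sample = sample_inputs_conv2d(None, "cpu", torch.float64, False)[-1]
    args = [sample.input, *sample.args]

    # Eager execution gives the expected output dtype; the JIT dtype
    # propagation pass should reproduce it on the scripted graph.
    expected_dtype = torch.nn.functional.conv2d(*args).dtype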
diff --git a/test/jit/test_dtype_analysis.py b/test/jit/test_dtype_analysis.py
index a6522af..a9fd1ee 100644
--- a/test/jit/test_dtype_analysis.py
+++ b/test/jit/test_dtype_analysis.py
@@ -1,7 +1,13 @@
 from itertools import product
+from typing import Tuple
 
 import torch
 from torch import complex32, float32, float64, int32, int64
+from torch.testing._internal.common_methods_invocations import (
+    SampleInput,
+    sample_inputs_adaptive_avg_pool2d,
+    sample_inputs_conv2d,
+)
 from torch.testing._internal.common_utils import set_default_dtype
 from torch.testing._internal.jit_utils import JitTestCase
 
@@ -31,7 +37,7 @@
     @staticmethod
     def node_output_dtype(graph):
         graph_out = list(graph.outputs())
-        assert(len(graph_out) == 1)
+        assert len(graph_out) == 1
         return graph_out[0].type().dtype()
 
     def prop_dtype_on_graph(self, graph, example_inputs):
@@ -52,22 +58,29 @@
         torch._C._jit_pass_propagate_dtype(graph)
 
     def assert_dtype_equal(self, fn, in_shapes, in_dtypes):
-        # Eager execution
         inputs = [self.get_rand_tensor(s, d) for s, d in zip(in_shapes, in_dtypes)]
         try:
-            expected_res = fn(*inputs)
+            self.assert_dtype_equal_custom_args(fn, inputs)
         except Exception:
-            # Skip anything that doesn't execute in Eager Mode?
+            fail_text = f"Failed for shapes {in_shapes}, and dtypes {in_dtypes}"
+            raise AssertionError(fail_text)
+
+    def assert_dtype_equal_custom_args(self, fn, args):
+        try:
+            # Eager execution
+            expected_res = fn(*args)
+        except RuntimeError:
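+            # Skip anything that doesn't execute in eager mode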
             return
+
         expected_dtype = expected_res.dtype
 
         # Run the Dtype Analysis
         graph = torch.jit.script(fn).graph  # Note this is a cached graph
-        self.prop_dtype_on_graph(graph, inputs)
+        self.prop_dtype_on_graph(graph, args)
         actual_dtype = self.node_output_dtype(graph)
 
-        fail_text = f"Failed for shapes {in_shapes}, and dtypes {in_dtypes}"
-        self.assertEqual(actual_dtype, expected_dtype, fail_text)
+        self.assertEqual(actual_dtype, expected_dtype, "Failed Verification")
 
     def get_rand_tensor(self, shape, dtype):
         if shape is self.SCALAR:
@@ -76,7 +88,9 @@
             elif dtype is int64:
                 return 2
             else:
-                raise RuntimeError("Testing of scalars only supported for fp32 and int64")
+                raise RuntimeError(
+                    "Testing of scalars only supported for fp32 and int64"
+                )
 
         if dtype in (int32, int64):
             rand_tensor = torch.randint(0, 10, shape, dtype=dtype)
@@ -88,7 +102,6 @@
         self.assertEqual(rand_tensor.dtype, dtype)
         return rand_tensor
 
-
     def test_unary(self):
         # Testing the Unary Implementation that uses metatensors
 
@@ -164,10 +177,70 @@
                 scalar_type = in_dtypes[1]
 
                 if scalar_type == float32:
+
                     def add(x, y: float):
                         return x + y
+
                 else:
+
                     def add(x, y: int):
                         return x + y
 
                 self.assert_dtype_equal(add, in_shapes, in_dtypes)
+
+    def test_custom_rules(self):
+        # Test some of the ops that are not covered by Metatensors
+
+        # Note that unlike the Conv2d module, the function conv2d
+        # does not take dtype/device arguments.
+
+        def conv2d_fn(input, weight, bias):
+            return torch.nn.functional.conv2d(input, weight, bias)
+
+        def adaptive_avg_pool2d_fn(input, output_size: Tuple[int]):
+            return torch._C._nn.adaptive_avg_pool2d(input, output_size)
+
+        for fn, inputs_fn in (
+            (conv2d_fn, sample_inputs_conv2d),
+            (adaptive_avg_pool2d_fn, sample_inputs_adaptive_avg_pool2d),
+        ):
+            for dtype in (torch.int8, torch.float64):
+                # The last sample exercises conv2d's default arguments
+                sample_input: SampleInput = inputs_fn(None, "cpu", dtype, False)[-1]
+                input_args = [sample_input.input, *sample_input.args]
+                self.assert_dtype_equal_custom_args(fn, input_args)
+
+    def test_conv_no_mixed_args(self):
+        def conv2d_fn(input, weight, bias):
+            return torch.nn.functional.conv2d(input, weight, bias)
+
+        # Now make sure that conv2d rejects arguments with mixed dtypes
+        conv_ins = sample_inputs_conv2d(None, "cpu", torch.float, False)
+        conv_in = conv_ins[-1]
+        weight, bias = conv_in.args
+        weight = weight.type(torch.long)
+
+        with self.assertRaises(RuntimeError):
+            conv2d_fn(conv_in.input, weight, bias)
+
+        # Check that we also don't propagate
+        graph = torch.jit.script(conv2d_fn).graph  # Note this is a cached graph
+        self.prop_dtype_on_graph(graph, [conv_in.input, weight, bias])
+        actual_dtype = self.node_output_dtype(graph)
+        self.assertEqual(actual_dtype, None)
+
+    def test_combined(self):
+        # Test a case with both custom rules and metatensors
+
+        def func(input, weight, bias, y):
+            conv_out = torch.nn.functional.conv2d(input, weight, bias)
+            conv_2 = conv_out + y
+            flattened = torch.flatten(conv_2, start_dim=2)
+            add_res = flattened + y
+            return add_res
+
+        conv_ins = sample_inputs_conv2d(None, "cpu", torch.int8, False)
+        conv_in = conv_ins[-1]
+        y_val = torch.rand((1,), dtype=torch.float32)
+        input_args = [conv_in.input, *conv_in.args, y_val]
+        self.assert_dtype_equal_custom_args(func, input_args)
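
For readers unfamiliar with the pass under test, here is a rough standalone
sketch of the flow that `prop_dtype_on_graph` drives. Only
`torch._C._jit_pass_propagate_dtype` is visible in this diff; the
`setType`/`with_dtype` seeding of the graph inputs is an assumption about
the helper's body:

    import torch

    def add_fn(x, y):
        return x + y

    graph = torch.jit.script(add_fn).graph

    # Seed each graph input with the dtype of an example tensor
    # (assumed to mirror prop_dtype_on_graph above).
    examples = [torch.rand(2, 2), torch.rand(2, 2, dtype=torch.float64)]
    for graph_in, example in zip(graph.inputs(), examples):
        graph_in.setType(graph_in.type().with_dtype(example.dtype))

    # Run the propagation pass, then read the symbolic output dtype
    # the same way node_output_dtype does.
    torch._C._jit_pass_propagate_dtype(graph)
    out_dtype = list(graph.outputs())[0].type().dtype()
    # Expected: torch.float64, matching eager type promotion.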