Adding custom testing based on opinfos input for ops with custom rules. (#67500)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/67500
* #66898
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D32497547
Pulled By: Gamrix
fbshipit-source-id: 07761f0e27f4ac289377ff3279ce6470d4b727dd
diff --git a/test/jit/test_dtype_analysis.py b/test/jit/test_dtype_analysis.py
index a6522af..a9fd1ee 100644
--- a/test/jit/test_dtype_analysis.py
+++ b/test/jit/test_dtype_analysis.py
@@ -1,7 +1,13 @@
from itertools import product
+from typing import Tuple
import torch
from torch import complex32, float32, float64, int32, int64
+from torch.testing._internal.common_methods_invocations import (
+ SampleInput,
+ sample_inputs_adaptive_avg_pool2d,
+ sample_inputs_conv2d,
+)
from torch.testing._internal.common_utils import set_default_dtype
from torch.testing._internal.jit_utils import JitTestCase
@@ -31,7 +37,7 @@
@staticmethod
def node_output_dtype(graph):
graph_out = list(graph.outputs())
- assert(len(graph_out) == 1)
+ assert len(graph_out) == 1
return graph_out[0].type().dtype()
def prop_dtype_on_graph(self, graph, example_inputs):
@@ -52,22 +58,28 @@
torch._C._jit_pass_propagate_dtype(graph)
def assert_dtype_equal(self, fn, in_shapes, in_dtypes):
- # Eager execution
inputs = [self.get_rand_tensor(s, d) for s, d in zip(in_shapes, in_dtypes)]
try:
- expected_res = fn(*inputs)
+ self.assert_dtype_equal_custom_args(fn, inputs)
except Exception:
- # Skip anything that doesn't execute in Eager Mode?
+ fail_text = f"Failed for shapes {in_shapes}, and dtypes {in_dtypes}"
+ raise AssertionError(fail_text)
+
+ def assert_dtype_equal_custom_args(self, fn, args):
+ try:
+ # Eager execution
+ expected_res = fn(*args)
+ except RuntimeError as e:
return
+
expected_dtype = expected_res.dtype
# Run the Dtype Analysis
graph = torch.jit.script(fn).graph # Note this is a cached graph
- self.prop_dtype_on_graph(graph, inputs)
+ self.prop_dtype_on_graph(graph, args)
actual_dtype = self.node_output_dtype(graph)
- fail_text = f"Failed for shapes {in_shapes}, and dtypes {in_dtypes}"
- self.assertEqual(actual_dtype, expected_dtype, fail_text)
+ self.assertEqual(actual_dtype, expected_dtype, "Failed Verification")
def get_rand_tensor(self, shape, dtype):
if shape is self.SCALAR:
@@ -76,7 +88,9 @@
elif dtype is int64:
return 2
else:
- raise RuntimeError("Testing of scalars only supported for fp32 and int64")
+ raise RuntimeError(
+ "Testing of scalars only supported for fp32 and int64"
+ )
if dtype in (int32, int64):
rand_tensor = torch.randint(0, 10, shape, dtype=dtype)
@@ -88,7 +102,6 @@
self.assertEqual(rand_tensor.dtype, dtype)
return rand_tensor
-
def test_unary(self):
# Testing the Unary Implementation that uses metatensors
@@ -164,10 +177,71 @@
scalar_type = in_dtypes[1]
if scalar_type == float32:
+
def add(x, y: float):
return x + y
+
else:
+
def add(x, y: int):
return x + y
self.assert_dtype_equal(add, in_shapes, in_dtypes)
+
+ def test_custom_rules(self):
+ # Test some of the ops that are not covered by Metatensors
+
+ # Note that unlike the Conv2d module, the function conv2d
+ # does not take dtype/device arguments.
+
+ def conv2d_fn(input, weight, bias):
+ return torch.nn.functional.conv2d(input, weight, bias)
+
+ def adaptive_avg_pool2d_fn(input, output_size: Tuple[int]):
+ return torch._C._nn.adaptive_avg_pool2d(input, output_size)
+
+ for fn, inputs_fn in (
+ (conv2d_fn, sample_inputs_conv2d),
+ (adaptive_avg_pool2d_fn, sample_inputs_adaptive_avg_pool2d),
+ ):
+ for dtype in (torch.int8, torch.float64):
+ # Gets default version for conv2d
+ sample_input: SampleInput = inputs_fn(None, "cpu", dtype, False)[-1]
+ input_args = [sample_input.input, *sample_input.args]
+ self.assert_dtype_equal_custom_args(fn, input_args)
+
+ def test_conv_no_mixed_args(self):
+ def conv2d_fn(input, weight, bias):
+ return torch.nn.functional.conv2d(input, weight, bias)
+
+ # Now make sure that conv2d doesn't support mixed args
+ conv_ins = sample_inputs_conv2d(None, "cpu", torch.float, False)
+ conv_in = conv_ins[-1]
+ weight, bias = conv_in.args
+ weight = weight.type(torch.long)
+
+ with self.assertRaises(RuntimeError):
+ conv2d_fn(conv_in.input, weight, bias)
+
+ # Check that we also don't propagate
+ graph = torch.jit.script(conv2d_fn).graph # Note this is a cached graph
+ self.prop_dtype_on_graph(graph, [conv_in.input, weight, bias])
+ actual_dtype = self.node_output_dtype(graph)
+ self.assertEqual(actual_dtype, None)
+
+
+ def test_combined(self):
+ # Test a case with both custom rules and metatensors
+
+ def func(input, weight, bias, y):
+ conv_out = torch.nn.functional.conv2d(input, weight, bias)
+ conv_2 = conv_out + y
+ flattened = torch.flatten(conv_2, start_dim=2)
+ add_res = flattened + y
+ return add_res
+
+ conv_ins = sample_inputs_conv2d(None, "cpu", torch.int8, False)
+ conv_in = conv_ins[-1]
+ y_val = torch.rand((1,), dtype=torch.float32)
+ input_args = [conv_in.input, *conv_in.args, y_val]
+ self.assert_dtype_equal_custom_args(func, input_args)