Run function check and out check in TestTensorDeviceOps (#43830)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43830
Reviewed By: ailzhang
Differential Revision: D23438101
Pulled By: mruberry
fbshipit-source-id: b581ce779ea2f50ea8dfec51d5469031ec7a0a67
diff --git a/test/test_torch.py b/test/test_torch.py
index f88b92b..aa2699d 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -19461,6 +19461,8 @@
# - cpu_dtype_list (=[]), a list of torch dtypes to test the op(s) on cpu
# - make_inplace_variant (=True), if true the inplace version of the op (op_) is also tested
# - decorators (=[]), a list of decorators to apply to the test
+# - self_position (=-1), the position of self in the arg list, -1 means skip function check
+# - test_out (=False), whether to test the out= version of the operator
tensor_op_tests = [
('add', '', _small_3d, lambda t, d: [_number(3.14, 3, t)], 1e-2),
('add', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d)], 1e-2),
@@ -19528,7 +19530,7 @@
1e-1, 1e-5, _types2, _cpu_types, True,
[_wrap_maybe_warns("This overload of addcmul_? is deprecated")]),
('addmm', '', _medium_2d, lambda t, d: [_medium_2d(t, d), _medium_2d(t, d)],
- 1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True, [tf32_on_and_off(0.005)]),
+ 1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True, [tf32_on_and_off(0.005)], 0, True),
('addmm', 'scalar', _medium_2d,
lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_2d(t, d)],
1e-1, 1e-1, 1e-4, _float_types2, _cpu_types, True,
@@ -19539,7 +19541,7 @@
[tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addmm_? is deprecated")]),
('addmv', '', _medium_1d, lambda t, d: [_medium_2d(t, d), _medium_1d(t, d)],
1e-2, 1e-1, 1e-4, _float_types2 + _complex_types_skip_rocm, _cpu_types,
- True, [tf32_on_and_off(0.005)]),
+ True, [tf32_on_and_off(0.005)], 0, True),
('addmv', 'scalar', _medium_1d,
lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_1d(t, d)],
1e-2, 1e-1, 1e-4, _float_types2 + _complex_types_skip_rocm, _cpu_types, True,
@@ -19828,7 +19830,9 @@
float_precision,
dtype_list,
dtype_cpu_list,
- decorators) -> None:
+ decorators,
+ self_position,
+ test_out) -> None:
def fn(self, device, dtype) -> None:
# Generates the CPU inputs
# Note: CPU tensors are never torch.half
@@ -19860,6 +19864,25 @@
self.assertEqual(cpu_args, device_args, atol=precision, rtol=0, exact_dtype=False)
self.assertEqual(cpu_result, device_result, atol=precision, rtol=0, exact_dtype=False)
+ # check method matches with function
+ if self_position >= 0:
+ cpu_args.insert(self_position, cpu_tensor)
+ device_args.insert(self_position, device_tensor)
+ cpu_function_result = getattr(torch, op_str)(*cpu_args)
+ device_function_result = getattr(torch, op_str)(*device_args)
+ self.assertEqual(cpu_result, cpu_function_result, atol=precision, rtol=0)
+ self.assertEqual(device_result, device_function_result, atol=precision, rtol=0)
+
+ # check method matches with function(out)
+ if test_out:
+ bad_value = math.nan if dtype.is_floating_point or dtype.is_complex else 666
+ cpu_out = torch.full_like(cpu_result, bad_value)
+ device_out = torch.full_like(device_result, bad_value)
+ getattr(torch, op_str)(*cpu_args, out=cpu_out)
+ getattr(torch, op_str)(*device_args, out=device_out)
+ self.assertEqual(cpu_result, cpu_out, atol=precision, rtol=0)
+ self.assertEqual(device_result, device_out, atol=precision, rtol=0)
+
test_name = "test_" + op_str + subtest_str
assert not hasattr(cls, test_name), "{0} already in TestDevicePrecision".format(test_name)
@@ -19889,18 +19912,22 @@
dtype_list=_types,
dtype_cpu_list=_cpu_types,
make_inplace_variant=True,
- decorators=None):
+ decorators=None,
+ self_position=-1,
+ test_out=False):
if subtest_str:
subtest_str = '_' + subtest_str
generate_test_function(cls, op_str, subtest_str, tensor_ctor, arg_ctor, half_precision,
- bfloat16_precision, float_precision, dtype_list, dtype_cpu_list, decorators)
+ bfloat16_precision, float_precision, dtype_list, dtype_cpu_list,
+ decorators, self_position, test_out)
if make_inplace_variant:
op_str = op_str + '_'
subtest_str = 'inplace' + subtest_str
generate_test_function(cls, op_str, subtest_str, tensor_ctor, arg_ctor, half_precision,
- bfloat16_precision, float_precision, dtype_list, dtype_cpu_list, decorators)
+ bfloat16_precision, float_precision, dtype_list, dtype_cpu_list,
+ decorators, -1, False)
for test in tensor_op_tests:
caller(cls, *test)