Refactor values kwarg in foreach tests (#112781)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112781
Approved by: https://github.com/lezcano
ghstack dependencies: #112778
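This renames the pointwise tests' `values` kwarg so sample kwargs can be
forwarded verbatim: scalar lists and packed scalar tensors now travel as
`scalars`, while a single scalar travels as `value`, matching the regular
(non-foreach) op's keyword-only argument. A minimal sketch, not part of this
patch, of the asymmetry the wrapper comment below describes, using `addcmul`
as a representative pointwise op:

    import torch

    a, b, c = (torch.ones(3) for _ in range(3))

    # Regular op: a single scalar, keyword-only.
    out = torch.addcmul(a, b, c, value=2.0)

    # Foreach op: a scalar list (or a 1D CPU scalar tensor), accepted
    # positionally as well as by keyword.
    outs = torch._foreach_addcmul([a], [b], [c], [2.0])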
diff --git a/test/test_foreach.py b/test/test_foreach.py
index 1ac70fd..aeed73c 100644
--- a/test/test_foreach.py
+++ b/test/test_foreach.py
@@ -30,12 +30,14 @@
def __init__(self, func):
self.func = func
- def __call__(self, inputs, values=None, **kwargs):
- if values is not None:
+ def __call__(self, inputs, scalars=None, **kwargs):
+ if scalars is not None:
assert len(inputs) == 3
- if isinstance(values, Number):
- values = [values for _ in range(len(inputs[0]))]
- return [self.func(*i, value=values[idx], **kwargs) for idx, i in enumerate(zip(*inputs))]
+ # Distribute each scalar to its corresponding call of the regular
+ # func. This needs special handling because the scalar is a
+ # keyword-only argument of the regular func. (Strangely, it is not
+ # keyword-only for the foreach func.)
+ return [self.func(*i, value=scalars[idx], **kwargs) for idx, i in enumerate(zip(*inputs))]
if len(inputs) == 2 and isinstance(inputs[1], (Number, torch.Tensor)):
# binary op with tensorlist and scalar.
inputs[1] = [inputs[1] for _ in range(len(inputs[0]))]
@@ -149,14 +151,9 @@
func, ref, _, _ = self._get_funcs(op)
for sample in op.sample_inputs(device, dtype, noncontiguous=noncontiguous):
ref_kwargs = sample.kwargs
- kwargs = ref_kwargs.copy()
# div promotes ints to floats, so we cannot go on the fastpath there
div_slowpath = dtype in integral_types_and(torch.bool) and op.name == '_foreach_div'
expect_fastpath = not (noncontiguous or sample.disable_fastpath or div_slowpath)
- if op in foreach_pointwise_op_db:
- values = kwargs.pop("values", None)
- if values is not None:
- sample.args = (*sample.args, values)
ref_input, ctxmgr = sample.input, nullcontext()
if inplace:
with torch.no_grad():
@@ -164,7 +161,7 @@
ctxmgr = InplaceForeachVersionBumpCheck(self, sample.input)
try:
with ctxmgr:
- actual = func([sample.input, *sample.args], self.is_cuda, expect_fastpath, **kwargs)
+ actual = func([sample.input, *sample.args], self.is_cuda, expect_fastpath, **sample.kwargs)
except Exception as e:
with (
self.assertRaisesRegex(type(e), re.escape(str(e)))
@@ -256,40 +253,44 @@
assert isinstance(sample.args, tuple)
assert len(sample.args) == 2
inputs = [sample.input, *sample.args]
- kwargs = sample.kwargs
+ kwargs = sample.kwargs.copy()
disable_fastpath = sample.disable_fastpath and is_fastpath
wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
- values = kwargs.pop("values", None)
+ scalars = kwargs.pop("scalars", None)
- if is_fastpath and isinstance(values, list):
+ if is_fastpath and scalars:
sample = sample.transform(lambda t: t.clone().detach() if torch.is_tensor(t) else t)
inputs = [sample.input, *sample.args]
- tensor_values = torch.tensor(values)
+ tensor_values = torch.tensor(scalars)
# 1D Tensor of scalars
for is_inplace, op_, ref_ in ((False, wrapped_op, ref), (True, inplace_op, inplace_ref)):
self._pointwise_test(
op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
- values=tensor_values)
+ scalars=tensor_values, **kwargs)
self._pointwise_test(
op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
- values=tensor_values[0],
+ scalars=tensor_values[0],
custom_values_err="Expected packed scalar Tensor to be of dimension 1. Got 0 instead.",
+ **kwargs,
)
if self.is_cuda:
self._pointwise_test(
op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
- values=tensor_values.cuda(),
+ scalars=tensor_values.cuda(),
custom_values_err="Expected scalars to be on CPU, got cuda:0 instead.",
+ **kwargs,
)
self._pointwise_test(
op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
- values=tensor_values[:2],
- custom_values_err=f"Expected length of scalars to match input of length {len(values)} but got 2 instead.",
+ scalars=tensor_values[:2],
+ custom_values_err=f"Expected length of scalars to match input of length {len(scalars)} but got 2 instead.",
+ **kwargs,
)
self._pointwise_test(
op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
- values=torch.tensor([[0, 1], [2, 3]])[:, 1],
+ scalars=torch.tensor([[0, 1], [2, 3]])[:, 1],
custom_values_err="Expected scalars to be contiguous.",
+ **kwargs,
)
# Tests of implicit broadcasting
@@ -307,41 +308,42 @@
]
self._pointwise_test(
wrapped_op, ref, inputs, is_fastpath and disable_fastpath, is_inplace=False,
- values=values)
+ scalars=scalars, **kwargs)
self._pointwise_test(
inplace_op, inplace_ref, inputs, is_fastpath and disable_fastpath,
- is_inplace=True, values=values)
+ is_inplace=True, scalars=scalars, **kwargs)
def _pointwise_test(
self,
op, ref, inputs, is_fastpath, is_inplace,
*,
- values=None, custom_values_err=None,
+ scalars=None, custom_values_err=None, **kwargs
):
- kwargs = {}
ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1], inputs[2]] if is_inplace else inputs
try:
with (InplaceForeachVersionBumpCheck(self, inputs[0]) if is_inplace else nullcontext()):
actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
except RuntimeError as e:
with self.assertRaisesRegex(type(e), re.escape(str(e))):
- ref(ref_inputs)
+ ref(ref_inputs, **kwargs)
else:
- expected = ref(ref_inputs)
+ expected = ref(ref_inputs, **kwargs)
self.assertEqual(expected, actual)
- if values is not None:
+ if scalars is not None:
+ kwargs = kwargs.copy()
+ kwargs["scalars"] = scalars
try:
- actual = op(inputs + [values], self.is_cuda, is_fastpath, **kwargs)
+ actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
except RuntimeError as e:
# Match with error messages from regular non-foreach reference if no
# custom error message was provided.
if custom_values_err is None:
with self.assertRaisesRegex(type(e), re.escape(str(e))):
- ref(ref_inputs, values=values)
+ ref(ref_inputs, **kwargs)
else:
self.assertEqual(re.escape(str(e)), re.escape(custom_values_err))
else:
- expected = ref(ref_inputs, values=values)
+ expected = ref(ref_inputs, **kwargs)
self.assertEqual(expected, actual)
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
@@ -692,8 +694,6 @@
func, *_ = self._get_funcs(op)
sample = list(op.sample_inputs(dtype=dtype, device=device, requires_grad=True, num_input_tensors=[2], same_size=True))[0]
self.assertTrue(all(t.requires_grad for t in sample.input))
- if func.func in foreach_pointwise_op_db:
- sample.kwargs.pop("values", None)
(out1, out2) = func([sample.input, *sample.args], is_cuda=False, expect_fastpath=False, **sample.kwargs)
out1.backward(torch.ones_like(out1))
self.assertIsNotNone(sample.input[0].grad)
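A note on the try/except idiom that `_pointwise_test` and the fastpath test
above share: when the foreach op raises, the test asserts that the slow
per-element reference raises an error with the same message; otherwise the
two results must match. A stripped-down sketch of the pattern, where
`foreach_op`, `reference`, and `ref_inputs` are illustrative names, not from
this patch, and `self` is the usual unittest TestCase:

    try:
        actual = foreach_op(inputs, **kwargs)
    except RuntimeError as e:
        # The reference must fail the same way the foreach op did.
        with self.assertRaisesRegex(type(e), re.escape(str(e))):
            reference(ref_inputs, **kwargs)
    else:
        self.assertEqual(reference(ref_inputs, **kwargs), actual)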
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index ff92642..e6cd427 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -8936,7 +8936,8 @@
sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs)
for _ in range(2)
]
- kwargs["values"] = None
+ if "scalars" in kwargs:
+ del kwargs["scalars"]
kwargs.update(self._sample_kwargs(opinfo, args[-1], ForeachRightmostArgType.TensorList, dtype))
yield ForeachSampleInput(input, *args, **kwargs)
@@ -8959,8 +8960,10 @@
kwargs = {}
if rightmost_arg_type == ForeachRightmostArgType.TensorList:
args.append(rightmost_arg)
+ elif rightmost_arg_type in [ForeachRightmostArgType.Tensor, ForeachRightmostArgType.ScalarList]:
+ kwargs["scalars"] = rightmost_arg
else:
- kwargs["values"] = rightmost_arg
+ kwargs["value"] = rightmost_arg
kwargs.update(self._sample_kwargs(opinfo, rightmost_arg, rightmost_arg_type, dtype))
assert len(args) == 2, f"{len(args)=}"
sample = ForeachSampleInput(input, *args, **kwargs)
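With the routing above in place, the same kwargs dict drives both sides of
the comparison. An end-to-end sketch, with invented data, of what the rename
buys for a pointwise op: the foreach call consumes `scalars` directly, while
the reference distributes each entry as the keyword-only `value`, mirroring
the per-element loop in the wrapper's __call__ shown earlier.

    import torch

    tensors = [[torch.ones(2) for _ in range(3)] for _ in range(3)]
    scalars = [1.0, 2.0, 3.0]

    # Foreach op: one call over the whole tensor lists.
    actual = torch._foreach_addcmul(*tensors, scalars=scalars)

    # Per-element reference: one regular-op call per tensor triple.
    expected = [torch.addcmul(*ts, value=s)
                for ts, s in zip(zip(*tensors), scalars)]
    assert all(torch.equal(e, a) for e, a in zip(expected, actual))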