Refactor values kwarg in foreach tests (#112781)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112781
Approved by: https://github.com/lezcano
ghstack dependencies: #112778
diff --git a/test/test_foreach.py b/test/test_foreach.py
index 1ac70fd..aeed73c 100644
--- a/test/test_foreach.py
+++ b/test/test_foreach.py
@@ -30,12 +30,14 @@
     def __init__(self, func):
         self.func = func
 
-    def __call__(self, inputs, values=None, **kwargs):
-        if values is not None:
+    def __call__(self, inputs, scalars=None, **kwargs):
+        if scalars is not None:
             assert len(inputs) == 3
-            if isinstance(values, Number):
-                values = [values for _ in range(len(inputs[0]))]
-            return [self.func(*i, value=values[idx], **kwargs) for idx, i in enumerate(zip(*inputs))]
+            # We need to distribute each scalar to the regular func and it needs
+            # special consideration as it is a keyword only argument to the
+            # regular func. (Strangely, it is not a keyword only argument to the
+            # foreach func)
+            return [self.func(*i, value=scalars[idx], **kwargs) for idx, i in enumerate(zip(*inputs))]
         if len(inputs) == 2 and isinstance(inputs[1], (Number, torch.Tensor)):
             # binary op with tensorlist and scalar.
             inputs[1] = [inputs[1] for _ in range(len(inputs[0]))]
@@ -149,14 +151,9 @@
             func, ref, _, _ = self._get_funcs(op)
         for sample in op.sample_inputs(device, dtype, noncontiguous=noncontiguous):
             ref_kwargs = sample.kwargs
-            kwargs = ref_kwargs.copy()
             # div promotes ints to floats, so we cannot go on the fastpath there
             div_slowpath = dtype in integral_types_and(torch.bool) and op.name == '_foreach_div'
             expect_fastpath = not (noncontiguous or sample.disable_fastpath or div_slowpath)
-            if op in foreach_pointwise_op_db:
-                values = kwargs.pop("values", None)
-                if values is not None:
-                    sample.args = (*sample.args, values)
             ref_input, ctxmgr = sample.input, nullcontext()
             if inplace:
                 with torch.no_grad():
@@ -164,7 +161,7 @@
                 ctxmgr = InplaceForeachVersionBumpCheck(self, sample.input)
             try:
                 with ctxmgr:
-                    actual = func([sample.input, *sample.args], self.is_cuda, expect_fastpath, **kwargs)
+                    actual = func([sample.input, *sample.args], self.is_cuda, expect_fastpath, **sample.kwargs)
             except Exception as e:
                 with (
                     self.assertRaisesRegex(type(e), re.escape(str(e)))
@@ -256,40 +253,44 @@
             assert isinstance(sample.args, tuple)
             assert len(sample.args) == 2
             inputs = [sample.input, *sample.args]
-            kwargs = sample.kwargs
+            kwargs = sample.kwargs.copy()
             disable_fastpath = sample.disable_fastpath and is_fastpath
             wrapped_op, ref, inplace_op, inplace_ref = self._get_funcs(op)
-            values = kwargs.pop("values", None)
+            scalars = kwargs.pop("scalars", None)
 
-            if is_fastpath and isinstance(values, list):
+            if is_fastpath and scalars:
                 sample = sample.transform(lambda t: t.clone().detach() if torch.is_tensor(t) else t)
                 inputs = [sample.input, *sample.args]
-                tensor_values = torch.tensor(values)
+                tensor_values = torch.tensor(scalars)
                 # 1D Tensor of scalars
                 for is_inplace, op_, ref_ in ((False, wrapped_op, ref), (True, inplace_op, inplace_ref)):
                     self._pointwise_test(
                         op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
-                        values=tensor_values)
+                        scalars=tensor_values, **kwargs)
                     self._pointwise_test(
                         op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
-                        values=tensor_values[0],
+                        scalars=tensor_values[0],
                         custom_values_err="Expected packed scalar Tensor to be of dimension 1. Got 0 instead.",
+                        **kwargs,
                     )
                     if self.is_cuda:
                         self._pointwise_test(
                             op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
-                            values=tensor_values.cuda(),
+                            scalars=tensor_values.cuda(),
                             custom_values_err="Expected scalars to be on CPU, got cuda:0 instead.",
+                            **kwargs,
                         )
                     self._pointwise_test(
                         op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
-                        values=tensor_values[:2],
-                        custom_values_err=f"Expected length of scalars to match input of length {len(values)} but got 2 instead.",
+                        scalars=tensor_values[:2],
+                        custom_values_err=f"Expected length of scalars to match input of length {len(scalars)} but got 2 instead.",
+                        **kwargs,
                     )
                     self._pointwise_test(
                         op_, ref_, inputs, is_fastpath and not disable_fastpath, is_inplace,
-                        values=torch.tensor([[0, 1], [2, 3]])[:, 1],
+                        scalars=torch.tensor([[0, 1], [2, 3]])[:, 1],
                         custom_values_err="Expected scalars to be contiguous.",
+                        **kwargs,
                     )
 
             # Tests of implicit broadcasting
@@ -307,41 +308,42 @@
             ]
             self._pointwise_test(
                 wrapped_op, ref, inputs, is_fastpath and disable_fastpath, is_inplace=False,
-                values=values)
+                scalars=scalars, **kwargs)
             self._pointwise_test(
                 inplace_op, inplace_ref, inputs, is_fastpath and disable_fastpath,
-                is_inplace=True, values=values)
+                is_inplace=True, scalars=scalars, **kwargs)
 
     def _pointwise_test(
         self,
         op, ref, inputs, is_fastpath, is_inplace,
         *,
-        values=None, custom_values_err=None,
+        scalars=None, custom_values_err=None, **kwargs
     ):
-        kwargs = {}
         ref_inputs = [[t.clone().detach() for t in inputs[0]], inputs[1], inputs[2]] if is_inplace else inputs
         try:
             with (InplaceForeachVersionBumpCheck(self, inputs[0]) if is_inplace else nullcontext()):
                 actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
         except RuntimeError as e:
             with self.assertRaisesRegex(type(e), re.escape(str(e))):
-                ref(ref_inputs)
+                ref(ref_inputs, **kwargs)
         else:
-            expected = ref(ref_inputs)
+            expected = ref(ref_inputs, **kwargs)
             self.assertEqual(expected, actual)
-        if values is not None:
+        if scalars is not None:
+            kwargs = kwargs.copy()
+            kwargs["scalars"] = scalars
             try:
-                actual = op(inputs + [values], self.is_cuda, is_fastpath, **kwargs)
+                actual = op(inputs, self.is_cuda, is_fastpath, **kwargs)
             except RuntimeError as e:
                 # Match with error messages from regular non-foreach reference if no
                 # custom error message was provided.
                 if custom_values_err is None:
                     with self.assertRaisesRegex(type(e), re.escape(str(e))):
-                        ref(ref_inputs, values=values)
+                        ref(ref_inputs, **kwargs)
                 else:
                     self.assertEqual(re.escape(str(e)), re.escape(custom_values_err))
             else:
-                expected = ref(ref_inputs, values=values)
+                expected = ref(ref_inputs, **kwargs)
                 self.assertEqual(expected, actual)
 
     @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
@@ -692,8 +694,6 @@
         func, *_ = self._get_funcs(op)
         sample = list(op.sample_inputs(dtype=dtype, device=device, requires_grad=True, num_input_tensors=[2], same_size=True))[0]
         self.assertTrue(all(t.requires_grad for t in sample.input))
-        if func.func in foreach_pointwise_op_db:
-            sample.kwargs.pop("values", None)
         (out1, out2) = func([sample.input, *sample.args], is_cuda=False, expect_fastpath=False, **sample.kwargs)
         out1.backward(torch.ones_like(out1))
         self.assertIsNotNone(sample.input[0].grad)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index ff92642..e6cd427 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -8936,7 +8936,8 @@
             sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, zero_size=True, **_foreach_inputs_kwargs)
             for _ in range(2)
         ]
-        kwargs["values"] = None
+        if "scalars" in kwargs:
+            del kwargs["scalars"]
         kwargs.update(self._sample_kwargs(opinfo, args[-1], ForeachRightmostArgType.TensorList, dtype))
         yield ForeachSampleInput(input, *args, **kwargs)
 
@@ -8959,8 +8960,10 @@
                 kwargs = {}
                 if rightmost_arg_type == ForeachRightmostArgType.TensorList:
                     args.append(rightmost_arg)
+                elif rightmost_arg_type in [ForeachRightmostArgType.Tensor, ForeachRightmostArgType.ScalarList]:
+                    kwargs["scalars"] = rightmost_arg
                 else:
-                    kwargs["values"] = rightmost_arg
+                    kwargs["value"] = rightmost_arg
                 kwargs.update(self._sample_kwargs(opinfo, rightmost_arg, rightmost_arg_type, dtype))
                 assert len(args) == 2, f"{len(args)=}"
                 sample = ForeachSampleInput(input, *args, **kwargs)