[functorch] Add batching rules for threshold_backward, clamp_min, clamp_max
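
Without dedicated batch rules, vmap over these ops goes through the
slow for-loop fallback. With the rules registered, calls like the
following (a rough sketch; shapes are illustrative) take the fast
path instead:

    import torch
    from functorch import vmap

    x = torch.randn(7, 3)
    y = torch.randn(7, 3)
    vmap(lambda t: torch.clamp_min(t, 0.5))(x)  # clamp_min, Scalar overload
    vmap(torch.clamp_max)(x, y)                 # clamp_max.Tensor overload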
diff --git a/functorch/functorch/csrc/BatchRulesBinaryOps.cpp b/functorch/functorch/csrc/BatchRulesBinaryOps.cpp
index 2f27725..ee194c6 100644
--- a/functorch/functorch/csrc/BatchRulesBinaryOps.cpp
+++ b/functorch/functorch/csrc/BatchRulesBinaryOps.cpp
@@ -119,12 +119,23 @@
VMAP_SUPPORT("div.Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::div, const Scalar&>));
BINARY_POINTWISE(div);
VMAP_SUPPORT("tanh_backward", BINARY_POINTWISE_BATCH_RULE(at::tanh_backward));
+ VMAP_SUPPORT("threshold_backward", SINGLE_ARG(
+ binary_pointwise_batch_rule<decltype(&at::threshold_backward), &at::threshold_backward, const Scalar&>));
// at::pow has three out-of-place overloads
VMAP_SUPPORT("pow.Tensor_Tensor", SINGLE_ARG(binary_pointwise_batch_rule<TensorTensorType, &at::pow>));
VMAP_SUPPORT("pow.Tensor_Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::pow, const Scalar&>));
VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);
+ VMAP_SUPPORT("clamp_min.Tensor",
+ SINGLE_ARG(binary_pointwise_batch_rule<TensorTensorType, static_cast<TensorTensorType>(&at::clamp_min)>));
+ VMAP_SUPPORT("clamp_min",
+ SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, static_cast<TensorScalarType>(&at::clamp_min), const Scalar&>));
+ VMAP_SUPPORT("clamp_max.Tensor",
+ SINGLE_ARG(binary_pointwise_batch_rule<TensorTensorType, static_cast<TensorTensorType>(&at::clamp_max)>));
+ VMAP_SUPPORT("clamp_max",
+ SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, static_cast<TensorScalarType>(&at::clamp_max), const Scalar&>));
+
#define COMPARISON_POINTWISE(op) \
VMAP_SUPPORT(#op".Tensor", \
diff --git a/functorch/test/test_vmap.py b/functorch/test/test_vmap.py
index 1634511..bebbf44 100644
--- a/functorch/test/test_vmap.py
+++ b/functorch/test/test_vmap.py
@@ -1079,6 +1079,9 @@
with self.assertRaises(AssertionError):
uses_fallback(self)
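+# Pairs an op with the factory used to construct its test inputs; shared by
+# the parameterized binary pointwise tests below.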
+def _make_case(op, input_getter=TensorFactory.randn):
+ return (op, input_getter)
+
class TestVmapOperators(Namespace.TestVmapBase):
def _vmap_test(self, *args, **kwargs):
@@ -1172,70 +1175,98 @@
with self.assertRaisesRegex(RuntimeError, msg):
vmap(lambda x: x.clone(memory_format=torch.channels_last_3d))(torch.randn(B0))
- def test_binary_pointwise_ops(self):
+ @parameterized('case', {
+ 'clamp_min': _make_case(torch.clamp_min),
+ 'clamp_max': _make_case(torch.clamp_max),
+ })
+ def test_clamp_variant(self, case):
+ test = self._vmap_test
+
def get_number(getter):
return getter([]).item()
- def make_case(op, input_getter=TensorFactory.randn):
- return (op, input_getter)
+ op, getter = case
+ device = 'cpu'
+ B0, B1 = 7, 11
- cases = [
- # Basic arithmetic
- make_case(torch.add),
- make_case(lambda x, y: x + y),
- make_case(torch.sub),
- make_case(lambda x, y: x - y),
- make_case(torch.mul),
- make_case(lambda x, y: x * y),
- make_case(torch.div, input_getter=TensorFactory.randp1),
- make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1),
- make_case(torch.pow, input_getter=TensorFactory.randp1),
- make_case(lambda x, y: x ** y, input_getter=TensorFactory.randp1),
- ]
+ # Single vmap: op(Tensor, Tensor)
+ test(op, (getter([B0, 3], device), getter([B0, 3], device)))
+ test(op, (getter([B0], device), getter([B0, 2, 3], device)))
+ test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1))
+ test(op, (getter([B0], device), getter([2, B0, 3], device)),
+ in_dims=(0, 1), out_dims=1)
+ test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None))
+ test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(0, None))
+
+ # Nested vmap: op(Tensor, Tensor)
+ test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device)))
+ test(vmap(op, in_dims=(None, 0)),
+ (getter([B0, 2, 3], device), getter([B1, 3], device)), in_dims=(0, None))
+
+ # Python number overload: op(Tensor, Number)
+ number = get_number(getter)
+ self._test_unary(lambda t: op(t, number), getter, device)
+
+ @parameterized('case', {
+ 'add': _make_case(torch.add),
+ 'add_dunder': _make_case(lambda x, y: x + y),
+ 'sub': _make_case(torch.sub),
+ 'sub_dunder': _make_case(lambda x, y: x - y),
+ 'mul': _make_case(torch.mul),
+ 'mul_dunder': _make_case(lambda x, y: x * y),
+ 'div': _make_case(torch.div, input_getter=TensorFactory.randp1),
+ 'div_dunder': _make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1),
+ 'pow': _make_case(torch.pow, input_getter=TensorFactory.randp1),
+ 'pow_dunder': _make_case(lambda x, y: x ** y, input_getter=TensorFactory.randp1),
+ })
+ def test_arithmetic(self, case):
test = self._vmap_test
- for op, getter in cases:
- device = 'cpu'
- B0, B1 = 7, 11
+ def get_number(getter):
+ return getter([]).item()
- # Single vmap: op(Tensor, Tensor)
- test(op, (getter([B0, 3], device), getter([B0, 3], device)))
- test(op, (getter([B0], device), getter([B0, 2, 3], device)))
- test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1))
- test(op, (getter([B0], device), getter([2, B0, 3], device)),
- in_dims=(0, 1), out_dims=1)
- test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None))
- test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(0, None))
+ op, getter = case
+ device = 'cpu'
+ B0, B1 = 7, 11
- # Nested vmap: op(Tensor, Tensor)
- test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device)))
- test(vmap(op, in_dims=(None, 0)),
- (getter([B0, 2, 3], device), getter([B1, 3], device)), in_dims=(0, None))
+ # Single vmap: op(Tensor, Tensor)
+ test(op, (getter([B0, 3], device), getter([B0, 3], device)))
+ test(op, (getter([B0], device), getter([B0, 2, 3], device)))
+ test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1))
+ test(op, (getter([B0], device), getter([2, B0, 3], device)),
+ in_dims=(0, 1), out_dims=1)
+ test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None))
+ test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(0, None))
- # Python number overload: op(Tensor, Number) (and vice-versa)
- number = get_number(getter)
- self._test_unary(lambda t: op(t, number), getter, device)
- number = get_number(getter)
- self._test_unary(lambda t: op(number, t), getter, device)
+ # Nested vmap: op(Tensor, Tensor)
+ test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device)))
+ test(vmap(op, in_dims=(None, 0)),
+ (getter([B0, 2, 3], device), getter([B1, 3], device)), in_dims=(0, None))
- # Type promotion: op(Logical Scalar Tensor, Logical Scalar Tensor)
- test(op, (getter([B0], device), getter([B0], device, dtype=torch.double)))
- test(op, (getter([B0], device, dtype=torch.double), getter([B0], device)))
- test(op, (getter([B0], device), getter([B0], device)))
+ # Python number overload: op(Tensor, Number) (and vice-versa)
+ number = get_number(getter)
+ self._test_unary(lambda t: op(t, number), getter, device)
+ number = get_number(getter)
+ self._test_unary(lambda t: op(number, t), getter, device)
- # Type promotion: op(Tensor, Logical Scalar Tensor) (and vice-versa)
- test(op, (getter([B0, 2], device), getter([B0], device, torch.double)))
- test(op, (getter([B0], device, torch.double), getter([B0, 2], device)))
+ # Type promotion: op(Logical Scalar Tensor, Logical Scalar Tensor)
+ test(op, (getter([B0], device), getter([B0], device, dtype=torch.double)))
+ test(op, (getter([B0], device, dtype=torch.double), getter([B0], device)))
+ test(op, (getter([B0], device), getter([B0], device)))
- if not torch.cuda.is_available():
- continue
+ # Type promotion: op(Tensor, Logical Scalar Tensor) (and vice-versa)
+ test(op, (getter([B0, 2], device), getter([B0], device, torch.double)))
+ test(op, (getter([B0], device, torch.double), getter([B0, 2], device)))
- # TODO(rzou): fix the following
- # # Test cross-device scalars
- # number = get_number(getter)
- # self._test_unary(lambda t: op(t, number), getter, device='cuda')
- # self._test_unary(lambda t: op(number, t), getter, device='cuda')
- # self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda')
+ if not torch.cuda.is_available():
+ return
+
+ # TODO(rzou): fix the following
+ # # Test cross-device scalars
+ # number = get_number(getter)
+ # self._test_unary(lambda t: op(t, number), getter, device='cuda')
+ # self._test_unary(lambda t: op(number, t), getter, device='cuda')
+ # self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda')
# TODO: as_strided BR
@unittest.expectedFailure