[functorch] Add maximum/minimum/clamp batching rules
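
This registers vmap batching rules for clamp (the overload taking optional Scalar
min/max), maximum, and minimum in BatchRulesBinaryOps.cpp, and reorders the existing
binary pointwise registrations so each operator's overloads sit together. Because
torch.maximum now has a batching rule, the vmap fallback tests in test_vmap.py switch
to torch.copysign, which still goes through the slow fallback path as of this change.

A minimal sketch of the newly covered behavior (not part of this patch; assumes
functorch is installed and exposes vmap at the top level):

    import torch
    from functorch import vmap

    x = torch.randn(3, 5)
    y = torch.randn(3, 5)

    # Each call now dispatches to a batching rule instead of the
    # per-example vmap fallback loop.
    out_max = vmap(torch.maximum)(x, y)
    out_min = vmap(torch.minimum)(x, y)
    out_clamp = vmap(lambda t: torch.clamp(t, min=-0.5, max=0.5))(x)

    assert torch.allclose(out_max, torch.maximum(x, y))
    assert torch.allclose(out_min, torch.minimum(x, y))
    assert torch.allclose(out_clamp, torch.clamp(x, min=-0.5, max=0.5))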
diff --git a/functorch/functorch/csrc/BatchRulesBinaryOps.cpp b/functorch/functorch/csrc/BatchRulesBinaryOps.cpp
index ac90003..fff8fee 100644
--- a/functorch/functorch/csrc/BatchRulesBinaryOps.cpp
+++ b/functorch/functorch/csrc/BatchRulesBinaryOps.cpp
@@ -123,28 +123,11 @@
VMAP_SUPPORT(#op".Tensor", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(op, Tensor)), &at::op>));
BINARY_POINTWISE_WITH_SCALAR(add);
- BINARY_POINTWISE_WITH_SCALAR(sub);
- BINARY_POINTWISE_WITH_SCALAR(rsub);
- BINARY_POINTWISE(mul);
VMAP_SUPPORT("add.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(add, Scalar)), &at::add, const Scalar&, const Scalar&>));
- VMAP_SUPPORT("sub.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(sub, Scalar)), &at::sub, const Scalar&, const Scalar&>));
- VMAP_SUPPORT("rsub.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(rsub, Scalar)), &at::rsub, const Scalar&, const Scalar&>));
- VMAP_SUPPORT("mul.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(mul, Scalar)), &at::mul, const Scalar&>));
- VMAP_SUPPORT("div.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(div, Scalar)), &at::div, const Scalar&>));
- BINARY_POINTWISE(div);
- VMAP_SUPPORT("tanh_backward", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&at::tanh_backward), &at::tanh_backward>));
- VMAP_SUPPORT("threshold_backward", SINGLE_ARG(
- binary_pointwise_batch_rule<decltype(&at::threshold_backward), &at::threshold_backward, const Scalar&>));
- VMAP_SUPPORT("sigmoid_backward", SINGLE_ARG(
- binary_pointwise_batch_rule<decltype(&at::sigmoid_backward), &at::sigmoid_backward>));
-
VMAP_SUPPORT("atan2", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(atan2)), &at::atan2>));
- // at::pow has three out-of-place overloads
- VMAP_SUPPORT("pow.Tensor_Tensor", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(pow, Tensor_Tensor)), &at::pow>));
- VMAP_SUPPORT("pow.Tensor_Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(pow, Tensor_Scalar)), &at::pow, const Scalar&>));
- VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);
-
+ VMAP_SUPPORT("clamp",
+ SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN(clamp)), &at::clamp, const optional<Scalar>&, const optional<Scalar>&>));
VMAP_SUPPORT("clamp_min.Tensor",
SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(clamp_min, Tensor)), &at::clamp_min>));
VMAP_SUPPORT("clamp_min",
@@ -155,6 +138,32 @@
SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN(clamp_max)), &at::clamp_max, const Scalar&>));
+ BINARY_POINTWISE(div);
+ VMAP_SUPPORT("div.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(div, Scalar)), &at::div, const Scalar&>));
+
+ VMAP_SUPPORT("maximum", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(maximum)), &at::maximum>));
+ VMAP_SUPPORT("minimum", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(minimum)), &at::minimum>));
+
+ BINARY_POINTWISE(mul);
+ VMAP_SUPPORT("mul.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(mul, Scalar)), &at::mul, const Scalar&>));
+
+ // at::pow has three out-of-place overloads
+ VMAP_SUPPORT("pow.Tensor_Tensor", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(pow, Tensor_Tensor)), &at::pow>));
+ VMAP_SUPPORT("pow.Tensor_Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(pow, Tensor_Scalar)), &at::pow, const Scalar&>));
+ VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);
+
+ BINARY_POINTWISE_WITH_SCALAR(sub);
+ VMAP_SUPPORT("sub.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(sub, Scalar)), &at::sub, const Scalar&, const Scalar&>));
+
+ BINARY_POINTWISE_WITH_SCALAR(rsub);
+ VMAP_SUPPORT("rsub.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(rsub, Scalar)), &at::rsub, const Scalar&, const Scalar&>));
+
+ VMAP_SUPPORT("sigmoid_backward", SINGLE_ARG(
+ binary_pointwise_batch_rule<decltype(&at::sigmoid_backward), &at::sigmoid_backward>));
+ VMAP_SUPPORT("tanh_backward", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&at::tanh_backward), &at::tanh_backward>));
+ VMAP_SUPPORT("threshold_backward", SINGLE_ARG(
+ binary_pointwise_batch_rule<decltype(&at::threshold_backward), &at::threshold_backward, const Scalar&>));
+
using TensorScalarInplaceT = Tensor& (Tensor::*)(const Tensor&, const Scalar&) const;
using ScalarScalarInplaceT = Tensor& (Tensor::*)(const Scalar&, const Scalar&) const;
using TensorInplaceT = Tensor& (Tensor::*)(const Tensor&) const;
diff --git a/functorch/test/test_vmap.py b/functorch/test/test_vmap.py
index a36568d..97dbe38 100644
--- a/functorch/test/test_vmap.py
+++ b/functorch/test/test_vmap.py
@@ -540,7 +540,7 @@
- # NB: One day we will implement a batching rule for torch.atan2.
+ # NB: One day we will implement a batching rule for torch.copysign.
# If/when we do, this test should be replaced to test the fallback
# path on another operator to avoid bitrot.
- op = torch.maximum
+ op = torch.copysign
x = torch.randn(11)
y = torch.randn(11)
with warnings.catch_warnings(record=True) as wa:
@@ -575,7 +575,7 @@
- # NB: One day we will implement a batching rule for torch.atan2.
+ # NB: One day we will implement a batching rule for torch.copysign.
# If/when we do, this test should be replaced to test the fallback
# path on another operator to avoid bitrot.
- op = torch.maximum
+ op = torch.copysign
x = torch.randn(11)
y = torch.randn(11)
self._assert_uses_vmap_fallback((op,), (x, y))
@@ -606,7 +606,7 @@
- # NB: One day we will implement a batching rule for torch.atan2.
+ # NB: One day we will implement a batching rule for torch.copysign.
# If/when we do, this test should be replaced to test the fallback
# path on another operator to avoid bitrot.
- op = torch.maximum
+ op = torch.copysign
x = torch.randn(5, 7, 11)
y = torch.randn(5, 7, 11)