[functorch] Batching rules for: threshold_backward, clamp_min, clamp_max
diff --git a/functorch/functorch/csrc/BatchRulesBinaryOps.cpp b/functorch/functorch/csrc/BatchRulesBinaryOps.cpp
index 2f27725..ee194c6 100644
--- a/functorch/functorch/csrc/BatchRulesBinaryOps.cpp
+++ b/functorch/functorch/csrc/BatchRulesBinaryOps.cpp
@@ -119,12 +119,23 @@
   VMAP_SUPPORT("div.Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::div, const Scalar&>));
   BINARY_POINTWISE(div);
   VMAP_SUPPORT("tanh_backward", BINARY_POINTWISE_BATCH_RULE(at::tanh_backward));
+  VMAP_SUPPORT("threshold_backward", SINGLE_ARG(
+        binary_pointwise_batch_rule<decltype(&at::threshold_backward), &at::threshold_backward, const Scalar&>));
 
   // at::pow has three out-of-place overloads
   VMAP_SUPPORT("pow.Tensor_Tensor", SINGLE_ARG(binary_pointwise_batch_rule<TensorTensorType, &at::pow>));
   VMAP_SUPPORT("pow.Tensor_Scalar", SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, &at::pow, const Scalar&>));
   VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);
 
+  VMAP_SUPPORT("clamp_min.Tensor",
+      SINGLE_ARG(binary_pointwise_batch_rule<TensorTensorType, static_cast<TensorTensorType>(&at::clamp_min)>));
+  VMAP_SUPPORT("clamp_min",
+      SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, static_cast<TensorScalarType>(&at::clamp_min), const Scalar&>));
+  VMAP_SUPPORT("clamp_max.Tensor",
+      SINGLE_ARG(binary_pointwise_batch_rule<TensorTensorType, static_cast<TensorTensorType>(&at::clamp_max)>));
+  VMAP_SUPPORT("clamp_max",
+      SINGLE_ARG(basic_unary_batch_rule<TensorScalarType, static_cast<TensorScalarType>(&at::clamp_max), const Scalar&>));
+
 
 #define COMPARISON_POINTWISE(op) \
   VMAP_SUPPORT(#op".Tensor", \
diff --git a/functorch/test/test_vmap.py b/functorch/test/test_vmap.py
index 1634511..bebbf44 100644
--- a/functorch/test/test_vmap.py
+++ b/functorch/test/test_vmap.py
@@ -1079,6 +1079,9 @@
             with self.assertRaises(AssertionError):
                 uses_fallback(self)
 
+def _make_case(op, input_getter=TensorFactory.randn):
+    return (op, input_getter)
+
 
 class TestVmapOperators(Namespace.TestVmapBase):
     def _vmap_test(self, *args, **kwargs):
@@ -1172,70 +1175,98 @@
         with self.assertRaisesRegex(RuntimeError, msg):
             vmap(lambda x: x.clone(memory_format=torch.channels_last_3d))(torch.randn(B0))
 
-    def test_binary_pointwise_ops(self):
+    @parameterized('case', {
+        'clamp_min': _make_case(torch.clamp_min),
+        'clamp_max': _make_case(torch.clamp_max),
+    })
+    def test_clamp_variant(self, case):
+        test = self._vmap_test
+
         def get_number(getter):
             return getter([]).item()
 
-        def make_case(op, input_getter=TensorFactory.randn):
-            return (op, input_getter)
+        op, getter = case
+        device = 'cpu'
+        B0, B1 = 7, 11
 
-        cases = [
-            # Basic arithmetic
-            make_case(torch.add),
-            make_case(lambda x, y: x + y),
-            make_case(torch.sub),
-            make_case(lambda x, y: x - y),
-            make_case(torch.mul),
-            make_case(lambda x, y: x * y),
-            make_case(torch.div, input_getter=TensorFactory.randp1),
-            make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1),
-            make_case(torch.pow, input_getter=TensorFactory.randp1),
-            make_case(lambda x, y: x ** y, input_getter=TensorFactory.randp1),
-        ]
+        # Single vmap: op(Tensor, Tensor)
+        test(op, (getter([B0, 3], device), getter([B0, 3], device)))
+        test(op, (getter([B0], device), getter([B0, 2, 3], device)))
+        test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1))
+        test(op, (getter([B0], device), getter([2, B0, 3], device)),
+             in_dims=(0, 1), out_dims=1)
+        test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None))
+        test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(0, None))
+
+        # Nested vmap: op(Tensor, Tensor)
+        test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device)))
+        test(vmap(op, in_dims=(None, 0)),
+             (getter([B0, 2, 3], device), getter([B1, 3], device)), in_dims=(0, None))
+
+        # Python number overload: op(Tensor, Number)
+        number = get_number(getter)
+        self._test_unary(lambda t: op(t, number), getter, device)
+
+    @parameterized('case', {
+        'add': _make_case(torch.add),
+        'add_dunder': _make_case(lambda x, y: x + y),
+        'sub': _make_case(torch.sub),
+        'sub_dunder': _make_case(lambda x, y: x - y),
+        'mul': _make_case(torch.mul),
+        'mul_dunder': _make_case(lambda x, y: x * y),
+        'div': _make_case(torch.div, input_getter=TensorFactory.randp1),
+        'div_dunder': _make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1),
+        'pow': _make_case(torch.pow, input_getter=TensorFactory.randp1),
+        'pow_dunder': _make_case(lambda x, y: x ** y, input_getter=TensorFactory.randp1),
+    })
+    def test_arithmetic(self, case):
         test = self._vmap_test
 
-        for op, getter in cases:
-            device = 'cpu'
-            B0, B1 = 7, 11
+        def get_number(getter):
+            return getter([]).item()
 
-            # Single vmap: op(Tensor, Tensor)
-            test(op, (getter([B0, 3], device), getter([B0, 3], device)))
-            test(op, (getter([B0], device), getter([B0, 2, 3], device)))
-            test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1))
-            test(op, (getter([B0], device), getter([2, B0, 3], device)),
-                 in_dims=(0, 1), out_dims=1)
-            test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None))
-            test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(0, None))
+        op, getter = case
+        device = 'cpu'
+        B0, B1 = 7, 11
 
-            # Nested vmap: op(Tensor, Tensor)
-            test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device)))
-            test(vmap(op, in_dims=(None, 0)),
-                 (getter([B0, 2, 3], device), getter([B1, 3], device)), in_dims=(0, None))
+        # Single vmap: op(Tensor, Tensor)
+        test(op, (getter([B0, 3], device), getter([B0, 3], device)))
+        test(op, (getter([B0], device), getter([B0, 2, 3], device)))
+        test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1))
+        test(op, (getter([B0], device), getter([2, B0, 3], device)),
+             in_dims=(0, 1), out_dims=1)
+        test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None))
+        test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(0, None))
 
-            # Python number overload: op(Tensor, Number) (and vice-versa)
-            number = get_number(getter)
-            self._test_unary(lambda t: op(t, number), getter, device)
-            number = get_number(getter)
-            self._test_unary(lambda t: op(number, t), getter, device)
+        # Nested vmap: op(Tensor, Tensor)
+        test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device)))
+        test(vmap(op, in_dims=(None, 0)),
+             (getter([B0, 2, 3], device), getter([B1, 3], device)), in_dims=(0, None))
 
-            # Type promotion: op(Logical Scalar Tensor, Logical Scalar Tensor)
-            test(op, (getter([B0], device), getter([B0], device, dtype=torch.double)))
-            test(op, (getter([B0], device, dtype=torch.double), getter([B0], device)))
-            test(op, (getter([B0], device), getter([B0], device)))
+        # Python number overload: op(Tensor, Number) (and vice-versa)
+        number = get_number(getter)
+        self._test_unary(lambda t: op(t, number), getter, device)
+        number = get_number(getter)
+        self._test_unary(lambda t: op(number, t), getter, device)
 
-            # Type promotion: op(Tensor, Logical Scalar Tensor) (and vice-versa)
-            test(op, (getter([B0, 2], device), getter([B0], device, torch.double)))
-            test(op, (getter([B0], device, torch.double), getter([B0, 2], device)))
+        # Type promotion: op(Logical Scalar Tensor, Logical Scalar Tensor)
+        test(op, (getter([B0], device), getter([B0], device, dtype=torch.double)))
+        test(op, (getter([B0], device, dtype=torch.double), getter([B0], device)))
+        test(op, (getter([B0], device), getter([B0], device)))
 
-            if not torch.cuda.is_available():
-                continue
+        # Type promotion: op(Tensor, Logical Scalar Tensor) (and vice-versa)
+        test(op, (getter([B0, 2], device), getter([B0], device, torch.double)))
+        test(op, (getter([B0], device, torch.double), getter([B0, 2], device)))
 
-            # TODO(rzou): fix the following
-            # # Test cross-device scalars
-            # number = get_number(getter)
-            # self._test_unary(lambda t: op(t, number), getter, device='cuda')
-            # self._test_unary(lambda t: op(number, t), getter, device='cuda')
-            # self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda')
+        if not torch.cuda.is_available():
+            return
+
+        # TODO(rzou): fix the following
+        # # Test cross-device scalars
+        # number = get_number(getter)
+        # self._test_unary(lambda t: op(t, number), getter, device='cuda')
+        # self._test_unary(lambda t: op(number, t), getter, device='cuda')
+        # self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda')
 
     # TODO: as_strided BR
     @unittest.expectedFailure