add randomness kwarg to jacfwd (#84220)

From https://github.com/pytorch/functorch/issues/1010: if a user runs jacfwd with a function that uses randomness, it will fail, since the default randomness behavior for vmap is "error". This lets the user specify the randomness behavior for jacfwd as well, since jacfwd is doing vmap(jvp(forward)). This is less likely to show up in jacrev, since that only vmaps over the backwards pass.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84220
Approved by: https://github.com/zou3519
diff --git a/functorch/functorch/_src/eager_transforms.py b/functorch/functorch/_src/eager_transforms.py
index bc6d2e2..8750172 100644
--- a/functorch/functorch/_src/eager_transforms.py
+++ b/functorch/functorch/_src/eager_transforms.py
@@ -838,7 +838,7 @@
     return tensor.unflatten(dim, shape)
 
 
-def jacfwd(func: Callable, argnums: argnums_t = 0, has_aux: bool = False):
+def jacfwd(func: Callable, argnums: argnums_t = 0, has_aux: bool = False, *, randomness: str = "error"):
     """
     Computes the Jacobian of :attr:`func` with respect to the arg(s) at index
     :attr:`argnum` using forward-mode autodiff
@@ -854,6 +854,9 @@
             the function to be differentiated and the second element is
             auxiliary objects that will not be differentiated.
             Default: False.
+        randomness(str): Flag indicating what type of randomness to use.
+            See :func:`vmap` for more detail. Allowed: "different", "same", "error".
+            Default: "error"
 
     Returns:
         Returns a function that takes in the same inputs as :attr:`func` and
@@ -957,7 +960,7 @@
             _, jvp_out = output
             return jvp_out
 
-        results = vmap(push_jvp)(basis)
+        results = vmap(push_jvp, randomness=randomness)(basis)
         if has_aux:
             results, aux = results
             # aux is in the standard basis format, e.g. NxN matrix
diff --git a/functorch/test/test_vmap.py b/functorch/test/test_vmap.py
index ceb3c0c..6b85f37 100644
--- a/functorch/test/test_vmap.py
+++ b/functorch/test/test_vmap.py
@@ -45,7 +45,7 @@
 from collections import namedtuple
 
 import functorch
-from functorch import vmap, grad, grad_and_value, jvp, vjp
+from functorch import vmap, grad, grad_and_value, jvp, vjp, jacfwd
 from functorch.experimental import chunk_vmap
 from functorch._C import reshape_dim_into, reshape_dim_outof
 from functorch._src.make_functional import functional_init_with_buffers
@@ -4479,6 +4479,19 @@
             self._assert_all_slices_unique(output)
 
 
+    def test_jacfwd_with_random(self):
+        # Behavior checks are above; this just checks that jacfwd respects
+        # the randomness param
+
+        x = torch.rand(3, 4)
+        with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
+            jacfwd(torch.bernoulli)(x)
+
+        # x isn't batched, so use bernoulli since it doesn't do in-place randomness
+        jacfwd(torch.bernoulli, randomness="same")(x)
+        jacfwd(torch.bernoulli, randomness="different")(x)
+
+
 class TestTransformFailure(TestCase):
     @parametrize('transform', ['vmap', 'grad', 'grad_and_value', 'vjp', 'jvp', 'jacrev', 'jacfwd'])
     def test_fails_with_autograd_function(self, device, transform):
@@ -4512,7 +4525,6 @@
         with self.assertRaisesRegex(RuntimeError, "autograd.Function"):
             transform(input)
 
-
 only_for = ("cpu", "cuda")
 instantiate_device_type_tests(TestVmapOperatorsOpInfo, globals(), only_for=only_for)