[functorch] Fix normal_ and bernoulli (pytorch/functorch#670)

* Fix `normal_` with a -1 batch dimension (re-enables the previously skipped test)

* Fix and re-enable the `torch.binomial` test by routing through a wrapper that makes `prob` contiguous
diff --git a/functorch/functorch/csrc/BatchRulesBinaryOps.cpp b/functorch/functorch/csrc/BatchRulesBinaryOps.cpp
index 34b1d88..912d25e 100644
--- a/functorch/functorch/csrc/BatchRulesBinaryOps.cpp
+++ b/functorch/functorch/csrc/BatchRulesBinaryOps.cpp
@@ -292,14 +292,18 @@
   return std::make_tuple(out, out_bdim);
 }
 
+Tensor binomial_wrapper(const Tensor& count, const Tensor& prob, c10::optional<Generator> gen) {
+  return at::binomial(count, prob.contiguous(), gen); // Bug in PyTorch, prob shouldn't need to be contiguous
+}
+
 TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
   #define BINARY_RANDOM_POINTWISE(op) \
-  m.impl(#op, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN(op)));
-#define BINARY_RANDOM_POINTWISE2(op, overload) \
-  m.impl(#op"."#overload, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload)));
+    m.impl(#op, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN(op)));
+  #define BINARY_RANDOM_POINTWISE2(op, overload) \
+    m.impl(#op"."#overload, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload)));
 
   BINARY_RANDOM_POINTWISE2(normal, Tensor_Tensor);
-  BINARY_RANDOM_POINTWISE(binomial);
+  m.impl("binomial", BINARY_RANDOM_POINTWISE_BATCH_RULE(at::functorch::binomial_wrapper));
 }
 
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
diff --git a/functorch/test/test_vmap.py b/functorch/test/test_vmap.py
index 071d899..5cf1c86 100644
--- a/functorch/test/test_vmap.py
+++ b/functorch/test/test_vmap.py
@@ -3814,7 +3814,7 @@
             lambda t, _: t.random_(**kwargs),
             lambda t, _: t.random_(100, **kwargs),
             lambda t, _: t.random_(-5, 100, **kwargs),
-            # lambda t, _: t.normal_(**kwargs),  TODO(samdow): fix normal_ with -1 bdim
+            lambda t, _: t.normal_(**kwargs),
             lambda t, _: t.bernoulli_(**kwargs),
             lambda t, _: t.cauchy_(**kwargs),
             lambda t, _: t.exponential_(**kwargs),
@@ -3851,7 +3851,7 @@
                 self.assertEqual(vmap_result, expected)
             else:
                 if batched_input != "none":
-                    passed_expected = passed_expected[0]
+                    passed_expected = passed_expected[0].clone()  # bug in pytorch, normal_ on views doesn't work
                 expected = op(passed_expected, always_batched)
                 self._assert_all_slices_equal(vmap_result)
                 for i in range(B0):
@@ -3923,8 +3923,7 @@
         kwargs = {'generator': generator} if use_generator else {}
         ops = [
             lambda t, o, _: torch.normal(t, o, **kwargs),
-            # TODO(samdow): fix binomial
-            # lambda t, o, _: torch.binomial(t, (o - 0.5), **kwargs),
+            lambda t, o, _: torch.binomial(t, (o - 0.5), **kwargs),
         ]
 
         B0 = 4