[functorch] Fix normal_ and bernoulli (pytorch/functorch#670)
* normal_fix
* fix binomial test
diff --git a/functorch/functorch/csrc/BatchRulesBinaryOps.cpp b/functorch/functorch/csrc/BatchRulesBinaryOps.cpp
index 34b1d88..912d25e 100644
--- a/functorch/functorch/csrc/BatchRulesBinaryOps.cpp
+++ b/functorch/functorch/csrc/BatchRulesBinaryOps.cpp
@@ -292,14 +292,18 @@
return std::make_tuple(out, out_bdim);
}
+Tensor binomial_wrapper(const Tensor& count, const Tensor& prob, c10::optional<Generator> gen) {
+  return at::binomial(count, prob.contiguous(), gen); // Workaround for a PyTorch bug: at::binomial incorrectly requires prob to be contiguous
+}
+
TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {
#define BINARY_RANDOM_POINTWISE(op) \
- m.impl(#op, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN(op)));
-#define BINARY_RANDOM_POINTWISE2(op, overload) \
- m.impl(#op"."#overload, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload)));
+ m.impl(#op, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN(op)));
+ #define BINARY_RANDOM_POINTWISE2(op, overload) \
+ m.impl(#op"."#overload, BINARY_RANDOM_POINTWISE_BATCH_RULE(ATEN_FN2(op, overload)));
BINARY_RANDOM_POINTWISE2(normal, Tensor_Tensor);
- BINARY_RANDOM_POINTWISE(binomial);
+ m.impl("binomial", BINARY_RANDOM_POINTWISE_BATCH_RULE(at::functorch::binomial_wrapper));
}
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
diff --git a/functorch/test/test_vmap.py b/functorch/test/test_vmap.py
index 071d899..5cf1c86 100644
--- a/functorch/test/test_vmap.py
+++ b/functorch/test/test_vmap.py
@@ -3814,7 +3814,7 @@
lambda t, _: t.random_(**kwargs),
lambda t, _: t.random_(100, **kwargs),
lambda t, _: t.random_(-5, 100, **kwargs),
- # lambda t, _: t.normal_(**kwargs), TODO(samdow): fix normal_ with -1 bdim
+ lambda t, _: t.normal_(**kwargs),
lambda t, _: t.bernoulli_(**kwargs),
lambda t, _: t.cauchy_(**kwargs),
lambda t, _: t.exponential_(**kwargs),
@@ -3851,7 +3851,7 @@
self.assertEqual(vmap_result, expected)
else:
if batched_input != "none":
- passed_expected = passed_expected[0]
+                passed_expected = passed_expected[0].clone() # Workaround for a PyTorch bug: in-place normal_ fails on tensor views
expected = op(passed_expected, always_batched)
self._assert_all_slices_equal(vmap_result)
for i in range(B0):
@@ -3923,8 +3923,7 @@
kwargs = {'generator': generator} if use_generator else {}
ops = [
lambda t, o, _: torch.normal(t, o, **kwargs),
- # TODO(samdow): fix binomial
- # lambda t, o, _: torch.binomial(t, (o - 0.5), **kwargs),
+ lambda t, o, _: torch.binomial(t, (o - 0.5), **kwargs),
]
B0 = 4