[bug] Binomial distribution BTRS algorithm has small chance of returning -1 (#38456)
Summary:
I was so excited to take advantage of https://github.com/pytorch/pytorch/issues/36858 getting merged that I installed the nightly build, and I'm glad I did!
It turns out that there's a _very small_ chance that the current algorithm will return a negative value (I imagine only -1 is possible but not sure about that).
Basically the logic [here](https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Distributions.h#L198-L213), which returns a value that passes certain checks before checking if its negative. I can't figure out the particular range that causes this but could reproduce it by taking a billion samples with `count` 1 and `prob` 0.9:
```python
(
torch.distributions.Binomial(
total_count=torch.tensor(1.0).cuda(), probs=torch.tensor(0.9).cuda()
).sample(torch.Size((1000000000,))) >= 0
).all()
```
Reliably evaluates to `tensor(False, device='cuda:0')` on my machine. 100M samples usually does it but not always, so that's around the rate at which this crops up (it took me most of a whole day to run into it!). Seems to be CUDA specific, I imagine due to some esoteric reason I cannot begin to guess.
This PR tries to solve it in the most obvious way: reject negative values _before_ testing the bounding box, not after. But a better solution is probably to figure out why this occurs at all, and stop it.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38456
Differential Revision: D21664886
Pulled By: jerryzh168
fbshipit-source-id: 99b0eed980e214bede484c100388a74d8c40ca55
diff --git a/aten/src/ATen/native/Distributions.h b/aten/src/ATen/native/Distributions.h
index 577617a..e39820f 100644
--- a/aten/src/ATen/native/Distributions.h
+++ b/aten/src/ATen/native/Distributions.h
@@ -201,16 +201,16 @@
us = 0.5 - compat_abs(U);
k = static_cast<scalar_t>(compat_floor((2 * a / us + b) * U + c));
+ // Reject non-sensical answers.
+ if (k < 0 || k > count) {
+ continue;
+ }
// Region for which the box is tight, and we can return our calculated value.
// This should happen 0.86 * v_r times. In the limit as n * p is large,
// the acceptance rate converges to ~79% (and in the lower regime it is ~24%).
if (us >= 0.07 && V <= v_r) {
return k;
}
- // Reject non-sensical answers.
- if (k < 0 || k > count) {
- continue;
- }
// This deviates from Hormann's BTRS algorithm, as there is a log missing.
// For all (u, v) pairs outside of the bounding box, this calculates the