commit | 364f526b9cdf9818a7647b5e637efdee825d61a1 | [log] [tgz] |
---|---|---|
author | min-jean-cho <min.jean.cho@intel.com> | Wed Jan 11 03:24:06 2023 +0000 |
committer | PyTorch MergeBot <pytorchmergebot@users.noreply.github.com> | Wed Jan 11 03:24:10 2023 +0000 |
tree | 75256f77f38deec462f57c5e04c90d2d1f172c36 | |
parent | 554a796aefb068cb1902b69bcb556fab5746a4b6 [diff] |
[Inductor] assert generator for random, dropout (#91833) See comment https://github.com/pytorch/pytorch/pull/90869#discussion_r1063731541 , https://github.com/pytorch/pytorch/pull/91673#discussion_r1061099337. Pull Request resolved: https://github.com/pytorch/pytorch/pull/91833 Approved by: https://github.com/jansel
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 631d34e..607b7ac 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py
@@ -1516,6 +1516,7 @@ @register_decomposition(aten._fused_dropout) @pw_cast_for_opmath def _fused_dropout_decomposition(input, p, generator=None): + assert generator is None mask = (torch.rand_like(input) < p).to(dtype=torch.uint8) res = mask.type_as(input) * input * (1.0 / p) return (res, mask)