Turn off scalar_check for _th_normal.
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29955

Turns off scalar_check for _th_normal in Declarations.cwrap, makes normal_out_cuda build its mean tensor as 0-dim rather than 1-element so broadcasting cannot promote a 0-dim output to shape (1,), and adds shape tests for torch.normal.
Test Plan: Imported from OSS
Differential Revision: D18548051
Pulled By: gchanan
fbshipit-source-id: c652999ac9e37d2592aa85ef022040fe0700b5cf
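
As background (my reading of the cwrap machinery, not text from the patch): scalar_check is the flag that decides whether a TH result gets coerced to a 0-dim tensor after the kernel runs; setting it to false leaves the output's dimensionality alone. The contract the new tests pin down is that torch.normal's output shape follows the mean. A minimal sketch of that contract, with ad-hoc names zero_d/one_d:

    import torch

    zero_d = torch.randn(())    # 0-dim tensor
    one_d = torch.randn((1,))   # 1-element, 1-dim tensor

    assert torch.normal(zero_d, zero_d).shape == ()  # out follows mean: 0-dim
    assert torch.normal(one_d, 1).shape == (1,)      # out follows mean: (1,)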
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index d6ce14d..1a9e4d5 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -1487,6 +1487,7 @@
   backends:
     - CPU
   return: argument 0
+  scalar_check: false
   variants:
     - function
   options:
diff --git a/aten/src/ATen/native/cuda/Distributions.cu b/aten/src/ATen/native/cuda/Distributions.cu
index 4af7ac9..2055911 100644
--- a/aten/src/ATen/native/cuda/Distributions.cu
+++ b/aten/src/ATen/native/cuda/Distributions.cu
@@ -714,7 +714,7 @@
 Tensor& normal_out_cuda(Tensor& output, double mean, const Tensor& std, Generator* gen) {
   normal_cuda_(output, 0, 1, gen);
-  auto mean_tensor = at::full({1}, mean, output.options());
+  auto mean_tensor = at::full({}, mean, output.options());
   // NB: addcmul_out copies the tensor to be added into the output.
   // Please look at aten/src/THC/generic/THCTensorMathPointwise.cu
   // The previous function here was addcmul_out(output, mean_tensor, output, std, 1);
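
Why {} instead of {1}: with a 0-dim output, broadcasting a shape-(1,) mean promotes the result to shape (1,), while a 0-dim mean leaves the output's shape untouched. A quick illustration with ordinary tensor ops (not the CUDA kernel itself):

    import torch

    out = torch.zeros(())                        # 0-dim output
    print((out + torch.full((1,), 3.0)).shape)   # torch.Size([1]) -- shape leaks
    print((out + torch.full((), 3.0)).shape)     # torch.Size([]) -- shape preserved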
diff --git a/test/test_torch.py b/test/test_torch.py
index b470069..96447d1 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -6218,6 +6218,17 @@
         self.assertEqual((), torch.gather(one_d, 0, torch.zeros((), dtype=torch.int64, device=device)).shape)
         self.assertEqual((1,), torch.gather(one_d, 0, torch.zeros((1,), dtype=torch.int64, device=device)).shape)
+        # normal
+        # documentation says out shape matches shape of mean
+        self.assertEqual((), torch.normal(zero_d, zero_d).shape)
+        self.assertEqual((1,), torch.normal(one_d, zero_d).shape)
+        self.assertEqual((), torch.normal(1, zero_d).shape)
+        self.assertEqual((), torch.normal(zero_d, 1).shape)
+        self.assertEqual((1,), torch.normal(one_d, 1).shape)
+        # TODO: this behavior differs on CPU and GPU, see https://github.com/pytorch/pytorch/issues/30480.
+        # self.assertEqual((), torch.normal(zero_d, one_d).shape)
+        # self.assertEqual((), torch.normal(1, one_d).shape)
+
     @onlyCPU
     @dtypes(torch.float)
     def test_diag(self, device, dtype):
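
For context, a rough Python equivalent of normal_out_cuda's flow (a sketch, not code from the patch): fill output with N(0, 1) samples, then fuse the scale-and-shift the addcmul performs.

    import torch

    def normal_out_sketch(output, mean, std, gen=None):
        # Assumes mean is a float and std a tensor broadcastable with
        # output without expanding it (names are illustrative).
        output.normal_(0, 1, generator=gen)        # output ~ N(0, 1)
        mean_tensor = torch.full((), mean)         # 0-dim, shape-preserving
        return output.mul_(std).add_(mean_tensor)  # output*std + mean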