Fix NaN handling in torch.mv. (#31666)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31666
List of changes:
1) Fix a case where torch.mv was not handling NaNs correctly. In particular, with a transposed tensor and an expanded vector, NaNs already present in the output are kept even though beta = 0 (see the sketch after this list).
This is handled in the `out=` case by zeroing out the passed-in Tensor, but the same thing can happen with the non-out variant if the freshly allocated tensor happens to contain a NaN.
Also adds tests for this case.
NOTE: we zero out the output tensor in all cases for mv and mm, even though this is probably overkill. I didn't find another case where this would be a problem, but the old code at least
attempted to do this for all mv and mm calls, and I didn't add comprehensive testing to verify that it's not a problem.
2) On CPU: move mv, mv_out, mm, and mm_out to be direct wrappers around _th_addmv and _th_addmm, rather than having their own wrappers in Declarations.cwrap.
This removes the cpu_zero magic from the codegen, which simplifies the codegen and makes this behavior easier to test.
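For reference, a minimal sketch of the torch.mv case from item 1, mirroring the new test_blas_nan_out test (shapes and variable names here are illustrative, not taken from the change itself):

```python
import torch

# A transposed matrix together with an expanded (stride-0) vector hits the
# code path where a NaN already sitting in the output buffer must be cleared
# explicitly: with beta = 0, the accumulation 0 * NaN would otherwise stay NaN.
n, m = 5, 7
nm = torch.randn((m, n)).t()               # transposed, non-contiguous matrix
vec = torch.randn(()).expand(m)            # expanded vector
nan_out = torch.full((n,), float('nan'))   # out tensor pre-filled with NaN

expected = torch.mv(nm, vec)
result = torch.mv(nm, vec, out=nan_out)
assert torch.allclose(expected, result)
assert not torch.isnan(result).any()
```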
Test Plan: Imported from OSS
Differential Revision: D19239953
Pulled By: gchanan
fbshipit-source-id: 27d0748d215ad46d17a8684696d88f4cfd8a917e
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index f8517aa..48e6da2 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -885,16 +885,16 @@
]]
[[
name: _th_mv
- cpu_bfloat16: True
cuda_bfloat16: True
cname: addmv
+ backends:
+ - CUDA
variants: function
return: argument 0
arguments:
- arg: THTensor* result
output: True
resize: [ [self, 0] ]
- cpu_zero: True
- argument 0
- THTensor* self
- THTensor* vec
@@ -903,9 +903,10 @@
]]
[[
name: _th_mm
- cpu_bfloat16: True
cuda_bfloat16: True
variants: function
+ backends:
+ - CUDA
return: argument 0
options:
- cname: addmm
@@ -913,7 +914,6 @@
- arg: THTensor* result
output: True
resize: [ [self, 0], [mat2,1] ]
- cpu_zero: True
- argument 0
- THTensor* self
- THTensor* mat2
@@ -933,7 +933,6 @@
- arg: THTensor* result
output: True
resize: [ [self,0], [self,1], [mat2,2] ]
- cpu_zero: True
- argument 0
- THTensor* self
- THTensor* mat2
diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py
index 26279be..0a66854 100644
--- a/aten/src/ATen/function_wrapper.py
+++ b/aten/src/ATen/function_wrapper.py
@@ -432,7 +432,6 @@
# Broadcast is originally a str but gets unwrapped to a List or Dict in-place
'broadcast': Any,
'resize': str,
- 'cpu_zero': bool,
'zero': bool,
}, total=False)
@@ -1493,7 +1492,7 @@
initializers.append(resize_arg(arg))
# also special handling where we zero some outputs.
- if arg.get('zero', False) or (arg.get('cpu_zero', False) and not is_cuda):
+ if arg.get('zero', False):
initializers.append("{}.zero_();".format(arg['name']))
# only initialize non-null arguments
diff --git a/aten/src/ATen/native/BlasWrappersCPU.cpp b/aten/src/ATen/native/BlasWrappersCPU.cpp
new file mode 100644
index 0000000..811e2f0
--- /dev/null
+++ b/aten/src/ATen/native/BlasWrappersCPU.cpp
@@ -0,0 +1,37 @@
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/LegacyTHFunctionsCPU.h>
+
+// These are wrappers around the TH Linear Algebra / BLAS wrappers (mv, mm, bmm).
+
+namespace at {
+namespace native {
+
+Tensor & mv_cpu_out(Tensor & result, const Tensor & self, const Tensor & vec) {
+ result.resize_({ self.size(0) });
+ // we likely don't need to do this, see [NOTE: cpu_zero].
+ // We should do a full accounting that all cases are handled correctly, without it, though.
+ result.zero_();
+ return legacy::cpu::_th_addmv_out(result, result, self, vec, 0, 1);
+}
+
+Tensor mv_cpu(const Tensor & self, const Tensor & vec) {
+ Tensor result = at::empty({0}, self.options());
+ return mv_cpu_out(result, self, vec);
+}
+
+Tensor & mm_cpu_out(Tensor & result, const Tensor & self, const Tensor & mat2) {
+ result.resize_({ self.size(0), mat2.size(1) });
+ // we likely don't need to do this, see [NOTE: cpu_zero].
+ // We should do a full accounting that all cases are handled correctly, without it, though.
+ result.zero_();
+ return legacy::cpu::_th_addmm_out(result, result, self, mat2, 0, 1);
+}
+
+Tensor mm_cpu(const Tensor & self, const Tensor & mat2) {
+ Tensor result = at::empty({0}, self.options());
+ return mm_cpu_out(result, self, mat2);
+}
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 5462372..ef87718 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -264,7 +264,7 @@
if (is_bmm_out) {
for (int64_t b = 0; b < bs; b++) {
auto r = self_or_result.select(0, b);
- legacy::cpu::_th_mm_out(r, batch1.select(0, b), batch2.select(0, b));
+ native::mm_cpu_out(r, batch1.select(0, b), batch2.select(0, b));
}
} else {
for (int64_t b = 0; b < bs; b++) {
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index b0cc2bc..61027ac 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1934,7 +1934,7 @@
use_c10_dispatcher: full
variants: function, method
dispatch:
- CPU: legacy::cpu::_th_mm
+ CPU: mm_cpu
CUDA: legacy::cuda::_th_mm
SparseCPU: _sparse_mm
SparseCUDA: _sparse_mm
@@ -1942,7 +1942,7 @@
- func: mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
- CPU: legacy::cpu::_th_mm_out
+ CPU: mm_cpu_out
CUDA: legacy::cuda::_th_mm_out
SparseCPU: _sparse_mm_out
SparseCUDA: _sparse_mm_out
@@ -2007,13 +2007,13 @@
use_c10_dispatcher: full
variants: function, method
dispatch:
- CPU: legacy::cpu::_th_mv
+ CPU: mv_cpu
CUDA: legacy::cuda::_th_mv
supports_named_tensor: True
- func: mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
- CPU: legacy::cpu::_th_mv_out
+ CPU: mv_cpu_out
CUDA: legacy::cuda::_th_mv_out
supports_named_tensor: True
diff --git a/aten/src/TH/generic/THBlas.cpp b/aten/src/TH/generic/THBlas.cpp
index 67cfdf7..e7fc72c 100644
--- a/aten/src/TH/generic/THBlas.cpp
+++ b/aten/src/TH/generic/THBlas.cpp
@@ -87,6 +87,17 @@
int i_n = (int)n;
int i_incx = (int)incx;
+ // [NOTE: cpu_zero]
+ // at least on the following version of BLAS this does not follow the same semantics
+ // when a == 0 and there exists a NaN in the input. Namely, the non-BLAS code below results
+ // in a value of 0, whereas this results in a value of NaN. This is problematic because a
+ // NaN in an output tensor needs to be zeroed explicitly through a separate mechanism.
+ // At the ATen/TH binding layer, this was via "cpu_zero", which would zero out the output
+ // tensor. This probably isn't necessary if we avoid these calls, but I haven't done a
+ // full analysis of the code.
+ // BLAS version:
+ // [conda] blas 1.0 mkl
+ // [conda] mkl 2019.4 243
#if defined(TH_REAL_IS_DOUBLE)
dscal_(&i_n, &a, x, &i_incx);
#else
diff --git a/test/test_torch.py b/test/test_torch.py
index c891ccf..c7301c8 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -11777,6 +11777,46 @@
self.assertEqual(torch.full((2, 3), beta * value, device=device),
torch.addmm(input=input, mat1=mat, mat2=mat2, alpha=alpha, beta=beta, out=out))
+ @onlyCPU # not supported by CUBLAS
+ def test_blas_nan_out(self, device):
+ # These functions should work correctly with NaN filled outputs,
+ # but need special handling, see [NOTE: cpu_zero]
+ b = 3
+ n = 5
+ m = 7
+ p = 11
+
+ # torch.mv
+ nm = torch.randn((m, n), device=device).t()
+ _m = torch.randn((), device=device).expand(m)
+ _m_out = torch.full((m,), float('nan'), device=device)
+ self.assertEqual(torch.mv(nm, _m), torch.mv(nm, _m, out=_m_out))
+ self.assertEqual(0, torch.isnan(torch.mv(nm, _m)).sum())
+
+ # torch.mm
+ mp = torch.randn((p, m), device=device).t()
+ np_out = torch.full((n, p), float('nan'), device=device)
+ self.assertEqual(torch.mm(nm, mp), torch.mm(nm, mp, out=np_out))
+
+ # torch.bmm
+ bnm = torch.randn((b, m, n), device=device).transpose(1, 2)
+ bmp = torch.randn((b, p, m), device=device).transpose(1, 2)
+ bnp_out = torch.full((b, n, p), float('nan'), device=device)
+ self.assertEqual(torch.bmm(bnm, bmp), torch.bmm(bnm, bmp, out=bnp_out))
+
+ @onlyCPU # not supported by CUBLAS
+ def test_blas_mv_large_input(self, device):
+ # This would previously fail if the allocated output had NaNs, see:
+ # https://github.com/pytorch/pytorch/issues/31663 and [NOTE: cpu_zero]
+ n = 3000
+ m = 200
+
+ nm = torch.randn((m, n), device=device).t()
+ _m = torch.randn((), device=device).expand(m)
+ _m_out = torch.full((m,), 0., device=device)
+
+ self.assertEqual(torch.mv(nm, _m), torch.mv(nm, _m, out=_m_out))
+
@skipCUDAIfRocm
def test_unique_dim(self, device):
self.assertFalse(hasattr(torch, 'unique_dim'))