Port `mm` cuda from TH to ATen (#34891)
Summary:
Issue https://github.com/pytorch/pytorch/issues/24596
This PR ports the CUDA `mm` implementation from TH to ATen, along with the internal `addmmImpl` function that the old TH version of `mm` was built on.
It also sets the stage for a fairly straightforward port of CUDA `addmm` to ATen in a future PR, since the TH `mm` and `addmm` share the same `addmmImpl` function at their core.
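For reference, the user-facing calls that now route through the new ATen kernels (per the `native_functions.yaml` changes below) look like this illustrative snippet, which is not part of the PR:

    import torch

    a = torch.randn(5, 3, device="cuda")
    b = torch.randn(3, 4, device="cuda")
    c = torch.mm(a, b)                      # now dispatches to mm_cuda
    out = torch.empty(5, 4, device="cuda")
    torch.mm(a, b, out=out)                 # now dispatches to mm_out_cuda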
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34891
Differential Revision: D20650713
Pulled By: ngimel
fbshipit-source-id: 692aba1bbae65a18d23855b5e101446082d64c66
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 1748895..af0319a 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -827,25 +827,6 @@
- CONSTANT AS_REAL(1)
]]
[[
- name: _th_mm
- cuda_bfloat16: True
- variants: function
- backends:
- - CUDA
- return: argument 0
- options:
- - cname: addmm
- arguments:
- - arg: THTensor* result
- output: True
- resize: [ [self, 0], [mat2,1] ]
- - argument 0
- - THTensor* self
- - THTensor* mat2
- - CONSTANT AS_REAL(0)
- - CONSTANT AS_REAL(1)
-]]
-[[
name: _th_bmm
cuda_bfloat16: True
cname: baddbmm
diff --git a/aten/src/ATen/native/cuda/LinearAlgebra.cu b/aten/src/ATen/native/cuda/LinearAlgebra.cu
index a94afba..a377968 100644
--- a/aten/src/ATen/native/cuda/LinearAlgebra.cu
+++ b/aten/src/ATen/native/cuda/LinearAlgebra.cu
@@ -1,5 +1,6 @@
#include <ATen/ATen.h>
#include <ATen/LegacyTHFunctionsCUDA.h>
+#include <ATen/cuda/CUDABlas.h>
namespace at { namespace native {
@@ -23,4 +24,128 @@
return legacy::cuda::_th_bmm_out(result, batch1, batch2);
}
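+// Returns `tensor` itself when it is already unit-stride along one of its two
+// dimensions (so cuBLAS can consume it directly), otherwise a contiguous clone.
+// `transpose_tensor` records whether column-major cuBLAS must treat the
+// returned matrix as transposed.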
+Tensor prepare_matrix_for_cublas(Tensor& tensor, bool& transpose_tensor) {
+ Tensor tensor_;
+ IntArrayRef tensor_strides = tensor.strides();
+
+ if ((tensor_strides[0] == 1) && (tensor_strides[1] != 0)) {
+ tensor_ = tensor;
+ transpose_tensor = false;
+ } else if ((tensor_strides[1] == 1) && (tensor_strides[0] != 0)) {
+ tensor_ = tensor;
+ transpose_tensor = true;
+ } else {
+ transpose_tensor = true;
+ tensor_ = tensor.clone(at::MemoryFormat::Contiguous);
+ }
+
+ return tensor_;
+}
+
+// Check https://github.com/pytorch/pytorch/issues/22078
+// for information about the bug. We don't know the exact conditions that trigger it,
+// but using Sgemm or Hgemm on Maxwell or Pascal seems to be a
+// necessary condition.
+static void checkCuda90Bug(int i_m, int i_n, int i_k)
+{
+#if defined(CUDA_VERSION) && CUDA_VERSION < 9200 && CUDA_VERSION >= 9000
+ static std::once_flag alreadyWarned;
+ const int LIMIT = 1 << 21;
+ if (i_m > LIMIT || i_n > LIMIT || i_k > LIMIT) {
+ cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
+ if (prop->major == 5 || prop->major == 6) {
+ std::call_once(alreadyWarned, []() {
+ TORCH_WARN("Matrix multiplication for dimensions larger than 2^21 has known bugs on your combination of CUDA version and device type. Please consider upgrading to CUDA 9.2 or later.");
+ });
+ }
+ }
+#endif
+}
+
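+// Computes result = beta * self + alpha * (mat1 @ mat2) with a single cuBLAS GEMM.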
+Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) {
+ TORCH_CHECK(
+ (mat1.dim() == 2) && (mat2.dim() == 2) &&
+ (self.dim() == 2) && (result.dim() == 2),
+ "tensors must be 2-D"
+ );
+ IntArrayRef mat1_sizes = mat1.sizes();
+ IntArrayRef mat2_sizes = mat2.sizes();
+ IntArrayRef self_sizes = self.sizes();
+ TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0");
+ TORCH_CHECK(self_sizes[0] == mat1_sizes[0], "self dim 0 must match mat1 dim 0");
+ TORCH_CHECK(self_sizes[1] == mat2_sizes[1], "self dim 1 must match mat2 dim 1");
+
+  // We can skip copying `self` into `result` when beta is zero, or when `self`
+  // and `result` already share the same data pointer and strides. Otherwise, copy.
+ if (beta.to<double>() != 0.0) {
+ if ((result.data_ptr() != self.data_ptr()) || (result.strides() != self.strides())) {
+ result.copy_(self);
+ }
+ }
+
+ IntArrayRef result_sizes = result.sizes();
+ if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
+ return result;
+ }
+
+ bool transpose_result;
+ Tensor result_ = prepare_matrix_for_cublas(result, transpose_result);
+ bool transpose_mat1;
+ bool transpose_mat2;
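+  // cuBLAS writes a column-major result. If `result_` is row-major
+  // (transpose_result == true), compute the transposed product instead:
+  // swap the operands and flip the transpose flags so that cuBLAS computes
+  // (mat1 * mat2)^T = mat2^T * mat1^T directly into result_'s storage.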
+ Tensor mat1_ = transpose_result ? mat2 : mat1;
+ Tensor mat2_ = transpose_result ? mat1 : mat2;
+ mat1_ = prepare_matrix_for_cublas(mat1_, transpose_mat1);
+ mat2_ = prepare_matrix_for_cublas(mat2_, transpose_mat2);
+
+ if (transpose_result) {
+ transpose_mat1 = !transpose_mat1;
+ transpose_mat2 = !transpose_mat2;
+ mat1_sizes = mat1_.sizes();
+ mat2_sizes = mat2_.sizes();
+ }
+
+ int64_t m = mat1_sizes[transpose_result ? 1 : 0];
+ int64_t k = mat1_sizes[transpose_result ? 0 : 1];
+ int64_t n = mat2_sizes[transpose_result ? 0 : 1];
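+  // Leading dimensions are the strides along each matrix's non-unit-stride
+  // dimension, as arranged by prepare_matrix_for_cublas above.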
+ int64_t mat1_ld = mat1_.stride((transpose_mat1 == transpose_result) ? 1 : 0);
+ int64_t mat2_ld = mat2_.stride((transpose_mat2 == transpose_result) ? 1 : 0);
+ int64_t result_ld = result_.stride(transpose_result ? 0 : 1);
+ at::ScalarType scalar_type = self.scalar_type();
+
+ AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "addmm_cuda", [&] {
+ if (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::Float) {
+ checkCuda90Bug(static_cast<int>(m), static_cast<int>(n), static_cast<int>(k));
+ }
+ scalar_t alpha_val = alpha.to<scalar_t>();
+ scalar_t beta_val = beta.to<scalar_t>();
+ scalar_t* mat1_ptr = mat1_.data_ptr<scalar_t>();
+ scalar_t* mat2_ptr = mat2_.data_ptr<scalar_t>();
+ scalar_t* result_ptr = result_.data_ptr<scalar_t>();
+ at::cuda::blas::gemm<scalar_t>(
+ transpose_mat1 ? 't' : 'n',
+ transpose_mat2 ? 't' : 'n',
+ m, n, k,
+ alpha_val,
+ mat1_ptr, mat1_ld,
+ mat2_ptr, mat2_ld,
+ beta_val,
+ result_ptr, result_ld
+ );
+ });
+ if (result.data_ptr() != result_.data_ptr()) {
+ result.copy_(result_);
+ }
+ return result;
+}
+
+Tensor& mm_out_cuda(Tensor& result, const Tensor& self, const Tensor& mat2) {
+ result.resize_({ self.size(0), mat2.size(1) });
+ return addmm_out_cuda_impl(result, result, self, mat2, 0, 1);
+}
+
+Tensor mm_cuda(const Tensor& self, const Tensor& mat2) {
+ Tensor result = at::empty({ self.size(0), mat2.size(1) }, self.options());
+ return addmm_out_cuda_impl(result, result, self, mat2, 0, 1);
+}
+
} }
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index c33232e..ff99e28 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1961,7 +1961,7 @@
variants: function, method
dispatch:
CPU: mm_cpu
- CUDA: legacy::cuda::_th_mm
+ CUDA: mm_cuda
SparseCPU: _sparse_mm
SparseCUDA: _sparse_mm
supports_named_tensor: True
@@ -1969,7 +1969,7 @@
- func: mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU: mm_cpu_out
- CUDA: legacy::cuda::_th_mm_out
+ CUDA: mm_out_cuda
SparseCPU: _sparse_mm_out
SparseCUDA: _sparse_mm_out
supports_named_tensor: True
diff --git a/test/test_torch.py b/test/test_torch.py
index d2f57e1..0340589 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -14552,8 +14552,9 @@
self.assertEqual(r.dtype, a.dtype)
@slowTest
- @onlyCPU
- def test_mm(self, device):
+ @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.int32, torch.int64)
+ @dtypesIfCUDA(torch.float32, torch.float64)
+ def test_mm(self, device, dtype):
def _test_mm(n, m, p, dtype, genf):
# helper function
def matrixmultiply(mat1, mat2):
@@ -14625,12 +14626,24 @@
res2 = matrixmultiply(mat1, mat2)
self.assertEqual(res, res2)
+ def genf_int(x, y):
+ return torch.randint(0, 100, (x, y), dtype=dtype, device=device)
+
+ def genf_bfloat(x, y):
+ return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype)
+
+ def genf_float(x, y):
+ return torch.randn(x, y, dtype=dtype, device=device)
+
for (n, m, p) in [(20, 10, 5), (15, 5, 10), (5, 18, 10)]:
- _test_mm(n, m, p, torch.float32, lambda x, y: torch.randn(x, y, dtype=torch.float32, device=device))
- _test_mm(n, m, p, torch.float64, lambda x, y: torch.randn(x, y, dtype=torch.float64, device=device))
- _test_mm(n, m, p, torch.int32, lambda x, y: torch.randint(0, 100, (x, y), dtype=torch.int32, device=device))
- _test_mm(n, m, p, torch.int64, lambda x, y: torch.randint(0, 100, (x, y), dtype=torch.int64, device=device))
- _test_mm(n, m, p, torch.bfloat16, lambda x, y: torch.randn(x, y, dtype=torch.float32, device=device).bfloat16())
+ if (dtype == torch.int32) or (dtype == torch.int64):
+ genf = genf_int
+ elif (dtype == torch.bfloat16):
+ genf = genf_bfloat
+ else:
+ genf = genf_float
+
+ _test_mm(n, m, p, dtype, genf)
@onlyCPU
@dtypes(torch.float)