Port `mm` cuda from TH to ATen (#34891)

Summary:
Issue https://github.com/pytorch/pytorch/issues/24596

This PR moves `mm` cuda to ATen. The internal `addmmImpl` function that the old TH version of `mm` cuda was built on is also ported.

This PR also sets up `addmm` cuda to be ported to ATen fairly easily in a future PR, since the TH versions of `mm` and `addmm` shared the same `addmmImpl` function at their core.
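
For context, the new CUDA path implements `mm` as `addmm` with `beta = 0` and `alpha = 1`: the `mm_cuda`/`mm_out_cuda` entry points in the diff below delegate to `addmm_out_cuda_impl`. A minimal smoke-test sketch of that relationship using the public Python API (the sizes and tolerance here are illustrative, not taken from this PR):

```python
import torch

# torch.mm on CUDA now routes through addmm_out_cuda_impl with beta=0 and
# alpha=1, so it should match torch.addmm with a zeroed bias term and also
# agree with the CPU reference implementation.
if torch.cuda.is_available():
    a = torch.randn(20, 10, device="cuda")
    b = torch.randn(10, 5, device="cuda")
    bias = torch.zeros(20, 5, device="cuda")

    out_mm = torch.mm(a, b)
    out_addmm = torch.addmm(bias, a, b, beta=0, alpha=1)

    assert torch.allclose(out_mm, out_addmm)
    assert torch.allclose(out_mm.cpu(), torch.mm(a.cpu(), b.cpu()), atol=1e-5)
```
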
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34891

Differential Revision: D20650713

Pulled By: ngimel

fbshipit-source-id: 692aba1bbae65a18d23855b5e101446082d64c66
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 1748895..af0319a 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -827,25 +827,6 @@
     - CONSTANT AS_REAL(1)
 ]]
 [[
-  name: _th_mm
-  cuda_bfloat16: True
-  variants: function
-  backends:
-    - CUDA
-  return: argument 0
-  options:
-    - cname: addmm
-      arguments:
-        - arg: THTensor* result
-          output: True
-          resize: [ [self, 0], [mat2,1] ]
-        - argument 0
-        - THTensor* self
-        - THTensor* mat2
-        - CONSTANT AS_REAL(0)
-        - CONSTANT AS_REAL(1)
-]]
-[[
   name: _th_bmm
   cuda_bfloat16: True
   cname: baddbmm
diff --git a/aten/src/ATen/native/cuda/LinearAlgebra.cu b/aten/src/ATen/native/cuda/LinearAlgebra.cu
index a94afba..a377968 100644
--- a/aten/src/ATen/native/cuda/LinearAlgebra.cu
+++ b/aten/src/ATen/native/cuda/LinearAlgebra.cu
@@ -1,5 +1,6 @@
 #include <ATen/ATen.h>
 #include <ATen/LegacyTHFunctionsCUDA.h>
+#include <ATen/cuda/CUDABlas.h>
 
 namespace at { namespace native {
 
@@ -23,4 +24,128 @@
   return legacy::cuda::_th_bmm_out(result, batch1, batch2);
 }
 
+// Returns a matrix in a layout cuBLAS can consume without an extra copy when
+// possible: if `tensor` is already column-major (stride[0] == 1) or row-major
+// (stride[1] == 1) it is returned as-is and `transpose_tensor` records whether
+// cuBLAS should treat it as transposed; otherwise a contiguous clone is returned.
+Tensor prepare_matrix_for_cublas(Tensor& tensor, bool& transpose_tensor) {
+  Tensor tensor_;
+  IntArrayRef tensor_strides = tensor.strides();
+
+  if ((tensor_strides[0] == 1) && (tensor_strides[1] != 0)) {
+    tensor_ = tensor;
+    transpose_tensor = false;
+  } else if ((tensor_strides[1] == 1) && (tensor_strides[0] != 0)) {
+    tensor_ = tensor;
+    transpose_tensor = true;
+  } else {
+    transpose_tensor = true;
+    tensor_ = tensor.clone(at::MemoryFormat::Contiguous);
+  }
+
+  return tensor_;
+}
+
+// Check https://github.com/pytorch/pytorch/issues/22078
+// for information about the bug. We don't know the exact conditions that trigger it,
+// but using Sgemm or Hgemm on Maxwell or Pascal seems to be a
+// necessary condition.
+static void checkCuda90Bug(int i_m, int i_n, int i_k)
+{
+#if CUDA_VERSION < 9200 && CUDA_VERSION >= 9000
+  static std::once_flag alreadyWarned;
+  const int LIMIT = 1 << 21;
+  if (i_m > LIMIT || i_n > LIMIT || i_k > LIMIT) {
+    cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
+    if (prop->major == 5 || prop->major == 6) {
+      std::call_once(alreadyWarned, []() {
+        TORCH_WARN("Matrix multiplication for dimensions larger than 2^21 has known bugs on your combination of CUDA version and device type. Please consider upgrading to CUDA 9.2 or later.");
+      });
+    }
+  }
+#endif
+}
+
+Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) {
+  TORCH_CHECK(
+    (mat1.dim() == 2) && (mat2.dim() == 2) &&
+    (self.dim() == 2) && (result.dim() == 2),
+    "tensors must be 2-D"
+  );
+  IntArrayRef mat1_sizes = mat1.sizes();
+  IntArrayRef mat2_sizes = mat2.sizes();
+  IntArrayRef self_sizes = self.sizes();
+  TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0");
+  TORCH_CHECK(self_sizes[0] == mat1_sizes[0], "self dim 0 must match mat1 dim 0");
+  TORCH_CHECK(self_sizes[1] == mat2_sizes[1], "self dim 1 must match mat2 dim 1");
+
+  // If self and result either point to the same data or if beta is zero,
+  // we can avoid copying self into result. Otherwise, we need to copy.
+  if (beta.to<double>() != 0.0) {
+    if ((result.data_ptr() != self.data_ptr()) || (result.strides() != self.strides())) {
+      result.copy_(self);
+    }
+  }
+
+  IntArrayRef result_sizes = result.sizes();
+  if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
+    return result;
+  }
+
+  bool transpose_result;
+  Tensor result_ = prepare_matrix_for_cublas(result, transpose_result);
+  bool transpose_mat1;
+  bool transpose_mat2;
+  // When the result is stored transposed, compute mat2^T @ mat1^T instead,
+  // using (A @ B)^T = B^T @ A^T, so cuBLAS can write directly into result_'s layout.
+  Tensor mat1_ = transpose_result ? mat2 : mat1;
+  Tensor mat2_ = transpose_result ? mat1 : mat2;
+  mat1_ = prepare_matrix_for_cublas(mat1_, transpose_mat1);
+  mat2_ = prepare_matrix_for_cublas(mat2_, transpose_mat2);
+
+  if (transpose_result) {
+    transpose_mat1 = !transpose_mat1;
+    transpose_mat2 = !transpose_mat2;
+    mat1_sizes = mat1_.sizes();
+    mat2_sizes = mat2_.sizes();
+  }
+
+  int64_t m = mat1_sizes[transpose_result ? 1 : 0];
+  int64_t k = mat1_sizes[transpose_result ? 0 : 1];
+  int64_t n = mat2_sizes[transpose_result ? 0 : 1];
+  int64_t mat1_ld = mat1_.stride((transpose_mat1 == transpose_result) ? 1 : 0);
+  int64_t mat2_ld = mat2_.stride((transpose_mat2 == transpose_result) ? 1 : 0);
+  int64_t result_ld = result_.stride(transpose_result ? 0 : 1);
+  at::ScalarType scalar_type = self.scalar_type();
+
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "addmm_cuda", [&] {
+    if (scalar_type == at::ScalarType::Half || scalar_type == at::ScalarType::Float) {
+      checkCuda90Bug(static_cast<int>(m), static_cast<int>(n), static_cast<int>(k));
+    }
+    scalar_t alpha_val = alpha.to<scalar_t>();
+    scalar_t beta_val = beta.to<scalar_t>();
+    scalar_t* mat1_ptr = mat1_.data_ptr<scalar_t>();
+    scalar_t* mat2_ptr = mat2_.data_ptr<scalar_t>();
+    scalar_t* result_ptr = result_.data_ptr<scalar_t>();
+    at::cuda::blas::gemm<scalar_t>(
+      transpose_mat1 ? 't' : 'n',
+      transpose_mat2 ? 't' : 'n',
+      m, n, k,
+      alpha_val,
+      mat1_ptr, mat1_ld,
+      mat2_ptr, mat2_ld,
+      beta_val,
+      result_ptr, result_ld
+    );
+  });
+  if (result.data_ptr() != result_.data_ptr()) {
+    result.copy_(result_);
+  }
+  return result;
+}
+
+Tensor& mm_out_cuda(Tensor& result, const Tensor& self, const Tensor& mat2) {
+  result.resize_({ self.size(0), mat2.size(1) });
+  return addmm_out_cuda_impl(result, result, self, mat2, 0, 1);
+}
+
+Tensor mm_cuda(const Tensor& self, const Tensor& mat2) {
+  Tensor result = at::empty({ self.size(0), mat2.size(1) }, self.options());
+  return addmm_out_cuda_impl(result, result, self, mat2, 0, 1);
+}
+
 } }
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index c33232e..ff99e28 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1961,7 +1961,7 @@
   variants: function, method
   dispatch:
     CPU: mm_cpu
-    CUDA: legacy::cuda::_th_mm
+    CUDA: mm_cuda
     SparseCPU: _sparse_mm
     SparseCUDA: _sparse_mm
   supports_named_tensor: True
@@ -1969,7 +1969,7 @@
 - func: mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
     CPU: mm_cpu_out
-    CUDA: legacy::cuda::_th_mm_out
+    CUDA: mm_out_cuda
     SparseCPU: _sparse_mm_out
     SparseCUDA: _sparse_mm_out
   supports_named_tensor: True
diff --git a/test/test_torch.py b/test/test_torch.py
index d2f57e1..0340589 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -14552,8 +14552,9 @@
         self.assertEqual(r.dtype, a.dtype)
 
     @slowTest
-    @onlyCPU
-    def test_mm(self, device):
+    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.int32, torch.int64)
+    @dtypesIfCUDA(torch.float32, torch.float64)
+    def test_mm(self, device, dtype):
         def _test_mm(n, m, p, dtype, genf):
             # helper function
             def matrixmultiply(mat1, mat2):
@@ -14625,12 +14626,24 @@
             res2 = matrixmultiply(mat1, mat2)
             self.assertEqual(res, res2)
 
+        def genf_int(x, y):
+            return torch.randint(0, 100, (x, y), dtype=dtype, device=device)
+
+        def genf_bfloat(x, y):
+            return torch.randn(x, y, dtype=torch.float32, device=device).to(dtype)
+
+        def genf_float(x, y):
+            return torch.randn(x, y, dtype=dtype, device=device)
+
         for (n, m, p) in [(20, 10, 5), (15, 5, 10), (5, 18, 10)]:
-            _test_mm(n, m, p, torch.float32, lambda x, y: torch.randn(x, y, dtype=torch.float32, device=device))
-            _test_mm(n, m, p, torch.float64, lambda x, y: torch.randn(x, y, dtype=torch.float64, device=device))
-            _test_mm(n, m, p, torch.int32, lambda x, y: torch.randint(0, 100, (x, y), dtype=torch.int32, device=device))
-            _test_mm(n, m, p, torch.int64, lambda x, y: torch.randint(0, 100, (x, y), dtype=torch.int64, device=device))
-            _test_mm(n, m, p, torch.bfloat16, lambda x, y: torch.randn(x, y, dtype=torch.float32, device=device).bfloat16())
+            if (dtype == torch.int32) or (dtype == torch.int64):
+                genf = genf_int
+            elif (dtype == torch.bfloat16):
+                genf = genf_bfloat
+            else:
+                genf = genf_float
+
+            _test_mm(n, m, p, dtype, genf)
 
     @onlyCPU
     @dtypes(torch.float)