[mkldnn_matmul] enable mkldnn matmul for aarch64 bf16 devices (#83671) (#85546)

This PR enables mkldnn matmul on aarch64 bf16 devices, for both bf16 and fp32 inputs.

This PR depends on the cpuinfo commit update PR: https://github.com/pytorch/pytorch/pull/83620
Issue: https://github.com/pytorch/pytorch/issues/83594
This is a reland of https://github.com/pytorch/pytorch/pull/83671

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85546
Approved by: https://github.com/kit1980
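
For context, a minimal usage sketch with the libtorch C++ API (hypothetical example, not part of this PR; it assumes a PyTorch build with MKLDNN/oneDNN enabled running on a BF16-capable aarch64 CPU such as AWS Graviton3). With this change, both fp32 and bf16 matmuls can be considered for the oneDNN path:

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // fp32 inputs: on a BF16-capable aarch64 CPU, this matmul may now be
      // routed to the oneDNN fastmath kernel instead of being excluded by the
      // previous bf16-only dtype check.
      auto a = torch::randn({128, 256});
      auto b = torch::randn({256, 64});
      auto out_fp32 = torch::matmul(a, b);

      // bf16 inputs: handled by the oneDNN path as before.
      auto a_bf16 = a.to(torch::kBFloat16);
      auto b_bf16 = b.to(torch::kBFloat16);
      auto out_bf16 = torch::matmul(a_bf16, b_bf16);

      std::cout << out_fp32.sizes() << " " << out_bf16.sizes() << std::endl;
      return 0;
    }
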
diff --git a/aten/src/ATen/native/mkldnn/Matmul.cpp b/aten/src/ATen/native/mkldnn/Matmul.cpp
index 9b07dbf..d41ebac 100644
--- a/aten/src/ATen/native/mkldnn/Matmul.cpp
+++ b/aten/src/ATen/native/mkldnn/Matmul.cpp
@@ -127,11 +127,24 @@
               (mat1.dim() == 2 && mat2.dim() == 1) || // aten::mv
               (mat1.dim() == 1 && mat2.dim() == 1),  // aten::dot
               "mkldnn_matmul:  unsupported dims for mat and mat2");
-  TORCH_CHECK(mat1.scalar_type() == at::kBFloat16 &&
-   mat2.scalar_type() == at::kBFloat16 &&
-   result.scalar_type() == at::kBFloat16, "mkldnn_matmul:  only enabled for bf16 path");
+
   TORCH_CHECK(mkldnn_bf16_device_check(),
-    "mkldnn_matmul: mkldnn_matmul bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
+    "mkldnn_matmul: mkldnn_matmul bf16 path needs the cpu support avx512bw, avx512vl and avx512dq, or AWS Graviton3");
+
+#if defined(__aarch64__)
+  if (mkldnn_bf16_device_check_arm()) {
+     //onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g. Arm Neoverse V1
+     //so, don't restrict the mkldnn_matmul only for bf16 inputs, allow it for float as well
+     TORCH_CHECK((mat1.scalar_type() == mat2.scalar_type()) && (mat1.scalar_type() == result.scalar_type()) &&
+                 ((mat1.scalar_type() == at::kFloat) || (mat1.scalar_type() == at::kBFloat16)),
+                 "mkldnn_matmul:  only enabled for fp32 and bf16 path");
+  } else
+#endif
+  {
+     TORCH_CHECK(mat1.scalar_type() == at::kBFloat16 &&
+                 mat2.scalar_type() == at::kBFloat16 &&
+                 result.scalar_type() == at::kBFloat16, "mkldnn_matmul:  only enabled for bf16 path");
+  }
 
   auto mat1_unsqueezed = mat1.dim() == 1 ? mat1.unsqueeze(0) : mat1;
   auto mat2_unsqueezed = mat2.dim() == 1 ? mat2.unsqueeze(1) : mat2;
@@ -209,14 +222,29 @@
     const Tensor& mat1,
     const Tensor& mat2,
     const Tensor& result) {
-  return (
-      use_mkldnn_bf16_matmul() &&
-      mat1.scalar_type() == kBFloat16 &&
-      mat2.scalar_type() == kBFloat16 &&
-      (!result.defined() || result.scalar_type() == kBFloat16) &&
-      mat1.numel() != 0 &&
-      mat2.numel() != 0 &&
-      checksize(mat1, mat2));
+#if defined(__aarch64__)
+  if (mkldnn_bf16_device_check_arm()) {
+     //onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g. Arm Neoverse V1
+     //so, don't restrict the mkldnn_matmul only for bf16 inputs, allow it for float as well
+     return (
+        use_mkldnn_bf16_matmul() &&
+        (mat1.scalar_type() == mat2.scalar_type()) && (!result.defined() || (mat1.scalar_type() == result.scalar_type())) &&
+        ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)) &&
+        mat1.numel() != 0 &&
+        mat2.numel() != 0 &&
+        checksize(mat1, mat2));
+  } else
+#endif
+  {
+     return (
+        use_mkldnn_bf16_matmul() &&
+        mat1.scalar_type() == kBFloat16 &&
+        mat2.scalar_type() == kBFloat16 &&
+        (!result.defined() || result.scalar_type() == kBFloat16) &&
+        mat1.numel() != 0 &&
+        mat2.numel() != 0 &&
+        checksize(mat1, mat2));
+  }
 }
 
 } // namespace native
diff --git a/aten/src/ATen/native/mkldnn/Utils.h b/aten/src/ATen/native/mkldnn/Utils.h
index a27b842..37a489b 100644
--- a/aten/src/ATen/native/mkldnn/Utils.h
+++ b/aten/src/ATen/native/mkldnn/Utils.h
@@ -25,8 +25,18 @@
 };
 
 inline bool mkldnn_bf16_device_check() {
-  return cpuinfo_initialize() && cpuinfo_has_x86_avx512bw()
-      && cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512dq();
+  return cpuinfo_initialize() && ((cpuinfo_has_x86_avx512bw()
+     && cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512dq()) || (cpuinfo_has_arm_bf16()));
 }
 
+#if defined(__aarch64__)
+inline bool mkldnn_bf16_device_check_arm() {
+  return (cpuinfo_initialize() && cpuinfo_has_arm_bf16());
+}
+#else
+constexpr bool mkldnn_bf16_device_check_arm() {
+  return false;
+}
+#endif
+
 }
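
For reference, a standalone sketch of the same cpuinfo-based capability query that the new mkldnn_bf16_device_check_arm() helper performs (assumes the cpuinfo headers and library are available; not part of this PR):

    #include <cpuinfo.h>
    #include <cstdio>

    int main() {
    #if defined(__aarch64__)
      // Mirrors mkldnn_bf16_device_check_arm(): true only when cpuinfo
      // initializes and the CPU reports the Arm BF16 extension
      // (e.g. Neoverse V1 / AWS Graviton3).
      const bool arm_bf16 = cpuinfo_initialize() && cpuinfo_has_arm_bf16();
    #else
      const bool arm_bf16 = false;  // non-aarch64 builds never take the Arm path
    #endif
      std::printf("Arm BF16 support: %s\n", arm_bf16 ? "yes" : "no");
      return 0;
    }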