[mkldnn_matmul] enable mkldnn matmul for aarch64 bf16 devices (#83671) (#85546)
This PR enables mkldnn matmul on aarch64 bf16-capable devices for both bf16 and fp32 inputs.
This PR depends on the cpuinfo commit update PR: https://github.com/pytorch/pytorch/pull/83620
Issue: https://github.com/pytorch/pytorch/issues/83594
This is a reland of https://github.com/pytorch/pytorch/pull/83671
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85546
Approved by: https://github.com/kit1980
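
For context, the following is a minimal, self-contained sketch of the dtype gating this patch introduces. The types, the FakeTensor struct, and the has_arm_bf16_hw flag are hypothetical stand-ins rather than the real ATen/cpuinfo APIs; the sketch only illustrates that on Arm bf16 hardware both fp32 and bf16 inputs may take the mkldnn path (via the onednn fastmath mode), while other devices remain bf16-only.

#include <cstdint>
#include <iostream>

// Hypothetical stand-ins for the ATen pieces touched by the patch;
// names and shapes are simplified for illustration only.
enum class ScalarType { Float, BFloat16 };

struct FakeTensor {
  ScalarType dtype;
  int64_t numel;
};

// Stand-in for mkldnn_bf16_device_check_arm(): true when the CPU reports
// Arm bf16 support (e.g. Neoverse V1), false elsewhere. Assumed true here.
bool has_arm_bf16_hw = true;

// Decide whether the mkldnn bf16/fastmath matmul path may be taken,
// following the dtype gating introduced by the patch.
bool use_mkldnn_bf16_matmul(const FakeTensor& mat1,
                            const FakeTensor& mat2,
                            const FakeTensor& result) {
  if (mat1.numel == 0 || mat2.numel == 0) {
    return false;
  }
  const bool same_dtype =
      mat1.dtype == mat2.dtype && mat1.dtype == result.dtype;
  if (has_arm_bf16_hw) {
    // onednn fastmath mode can use bf16 hardware even for fp32 inputs,
    // so fp32 is allowed in addition to bf16 on such devices.
    return same_dtype &&
           (mat1.dtype == ScalarType::Float ||
            mat1.dtype == ScalarType::BFloat16);
  }
  // Elsewhere the path stays bf16-only, as before the patch.
  return same_dtype && mat1.dtype == ScalarType::BFloat16;
}

int main() {
  FakeTensor a{ScalarType::Float, 16};
  FakeTensor b{ScalarType::Float, 16};
  FakeTensor out{ScalarType::Float, 16};
  std::cout << std::boolalpha
            << use_mkldnn_bf16_matmul(a, b, out) << "\n";  // true on Arm bf16 HW
}
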
diff --git a/aten/src/ATen/native/mkldnn/Matmul.cpp b/aten/src/ATen/native/mkldnn/Matmul.cpp
index 9b07dbf..d41ebac 100644
--- a/aten/src/ATen/native/mkldnn/Matmul.cpp
+++ b/aten/src/ATen/native/mkldnn/Matmul.cpp
@@ -127,11 +127,24 @@
(mat1.dim() == 2 && mat2.dim() == 1) || // aten::mv
(mat1.dim() == 1 && mat2.dim() == 1), // aten::dot
"mkldnn_matmul: unsupported dims for mat and mat2");
- TORCH_CHECK(mat1.scalar_type() == at::kBFloat16 &&
- mat2.scalar_type() == at::kBFloat16 &&
- result.scalar_type() == at::kBFloat16, "mkldnn_matmul: only enabled for bf16 path");
+
TORCH_CHECK(mkldnn_bf16_device_check(),
- "mkldnn_matmul: mkldnn_matmul bf16 path needs the cpu support avx512bw, avx512vl and avx512dq");
+ "mkldnn_matmul: mkldnn_matmul bf16 path needs the cpu support avx512bw, avx512vl and avx512dq, or AWS Graviton3");
+
+#if defined(__aarch64__)
+ if (mkldnn_bf16_device_check_arm()) {
+ //onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g. Arm Neoverse V1
+ //so, don't restrict the mkldnn_matmul only for bf16 inputs, allow it for float as well
+ TORCH_CHECK((mat1.scalar_type() == mat2.scalar_type()) && (mat1.scalar_type() == result.scalar_type()) &&
+ ((mat1.scalar_type() == at::kFloat) || (mat1.scalar_type() == at::kBFloat16)),
+ "mkldnn_matmul: only enabled for fp32 and bf16 path");
+ } else
+#endif
+ {
+ TORCH_CHECK(mat1.scalar_type() == at::kBFloat16 &&
+ mat2.scalar_type() == at::kBFloat16 &&
+ result.scalar_type() == at::kBFloat16, "mkldnn_matmul: only enabled for bf16 path");
+ }
auto mat1_unsqueezed = mat1.dim() == 1 ? mat1.unsqueeze(0) : mat1;
auto mat2_unsqueezed = mat2.dim() == 1 ? mat2.unsqueeze(1) : mat2;
@@ -209,14 +222,29 @@
const Tensor& mat1,
const Tensor& mat2,
const Tensor& result) {
- return (
- use_mkldnn_bf16_matmul() &&
- mat1.scalar_type() == kBFloat16 &&
- mat2.scalar_type() == kBFloat16 &&
- (!result.defined() || result.scalar_type() == kBFloat16) &&
- mat1.numel() != 0 &&
- mat2.numel() != 0 &&
- checksize(mat1, mat2));
+#if defined(__aarch64__)
+ if (mkldnn_bf16_device_check_arm()) {
+ //onednn fastmath mode can leverage bf16 HW even for the fp32 input, e.g. Arm Neoverse V1
+ //so, don't restrict the mkldnn_matmul only for bf16 inputs, allow it for float as well
+ return (
+ use_mkldnn_bf16_matmul() &&
+ (mat1.scalar_type() == mat2.scalar_type()) && (!result.defined() || (mat1.scalar_type() == result.scalar_type())) &&
+ ((mat1.scalar_type() == kFloat) || (mat1.scalar_type() == kBFloat16)) &&
+ mat1.numel() != 0 &&
+ mat2.numel() != 0 &&
+ checksize(mat1, mat2));
+ } else
+#endif
+ {
+ return (
+ use_mkldnn_bf16_matmul() &&
+ mat1.scalar_type() == kBFloat16 &&
+ mat2.scalar_type() == kBFloat16 &&
+ (!result.defined() || result.scalar_type() == kBFloat16) &&
+ mat1.numel() != 0 &&
+ mat2.numel() != 0 &&
+ checksize(mat1, mat2));
+ }
}
} // namespace native
diff --git a/aten/src/ATen/native/mkldnn/Utils.h b/aten/src/ATen/native/mkldnn/Utils.h
index a27b842..37a489b 100644
--- a/aten/src/ATen/native/mkldnn/Utils.h
+++ b/aten/src/ATen/native/mkldnn/Utils.h
@@ -25,8 +25,18 @@
};
inline bool mkldnn_bf16_device_check() {
- return cpuinfo_initialize() && cpuinfo_has_x86_avx512bw()
- && cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512dq();
+ return cpuinfo_initialize() && ((cpuinfo_has_x86_avx512bw()
+ && cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512dq()) || (cpuinfo_has_arm_bf16()));
}
+#if defined(__aarch64__)
+inline bool mkldnn_bf16_device_check_arm() {
+ return (cpuinfo_initialize() && cpuinfo_has_arm_bf16());
+}
+#else
+constexpr bool mkldnn_bf16_device_check_arm() {
+ return false;
+}
+#endif
+
}
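
As a rough illustration of the Utils.h change, the standalone probe below uses the cpuinfo calls that appear in the patch to report whether the relaxed mkldnn_bf16_device_check() predicate would pass on the current machine. It is a sketch, not part of the patch, and assumes the program is compiled and linked against the cpuinfo library.

#include <cpuinfo.h>
#include <cstdio>

int main() {
  if (!cpuinfo_initialize()) {
    std::printf("cpuinfo_initialize() failed\n");
    return 1;
  }
  // Same predicate as the patched mkldnn_bf16_device_check():
  // x86 needs avx512bw/avx512vl/avx512dq, while aarch64 only needs Arm bf16.
  const bool x86_ok = cpuinfo_has_x86_avx512bw() &&
                      cpuinfo_has_x86_avx512vl() &&
                      cpuinfo_has_x86_avx512dq();
  const bool arm_ok = cpuinfo_has_arm_bf16();
  std::printf("mkldnn bf16 path available: %s\n",
              (x86_ok || arm_ok) ? "yes" : "no");
  return 0;
}
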