Fix allowed dtypes for mem_eff attention (#116026)
# Summary
Fixes a bug in detecting memory-efficient attention capability for CUDA devices with compute capability below sm80:
https://github.com/pytorch-labs/gpt-fast/issues/49
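
As an illustration (not part of the patch), here is a minimal sketch of the behavior this change enforces, assuming a CUDA device with compute capability below 8.0 (e.g. sm70/sm75). Previously only sm5x devices fell back to the fp16/fp32-only dtype list, so bfloat16 inputs on other pre-sm80 devices could be incorrectly routed to the mem-efficient kernel; after this fix, bfloat16 is only allowed on sm80+.

```python
# Minimal sketch, not part of the patch: assumes a CUDA device with
# compute capability < 8.0 (e.g. sm70/sm75).
import torch
import torch.nn.functional as F

q = k = v = torch.rand(16, 16, 32, 32, device="cuda", dtype=torch.bfloat16)

# Restrict dispatch to the mem-efficient backend only.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=False, enable_mem_efficient=True
):
    try:
        F.scaled_dot_product_attention(q, k, v)
    except RuntimeError as e:
        # Expected after this fix on < sm80 devices: bfloat16 is not in the
        # allowed dtype list for mem-efficient attention, so no kernel is selected.
        print("mem-efficient SDPA rejected bfloat16:", e)
```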
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116026
Approved by: https://github.com/janeyx99
diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
index 1fcbf97..49883cc 100644
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -339,9 +339,9 @@
return false;
#endif
// Constraints specific to mem efficient attention
- constexpr auto default_mem_efficient_dtypes =
+ constexpr auto greater_than_or_equal_sm80_mem_efficient_dtypes =
array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
- constexpr auto sm50_mem_efficient_dtypes =
+ constexpr auto less_than_sm80_mem_efficient_dtypes =
array_of<at::ScalarType>(at::kHalf, at::kFloat);
// Define gate functions that determine if a mem efficient kernel can be ran
@@ -381,10 +381,10 @@
}
auto dprop = at::cuda::getCurrentDeviceProperties();
- if (dprop->major == 5) {
- return check_tensor_dtype(params, sm50_mem_efficient_dtypes, debug);
+ if (dprop->major >= 8) {
+ return check_tensor_dtype(params, greater_than_or_equal_sm80_mem_efficient_dtypes, debug);
}
- return check_tensor_dtype(params, default_mem_efficient_dtypes, debug);
+ return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug);
}
SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 2a4d91c..d7503c3 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -68,6 +68,7 @@
isSM86or89Device = torch.cuda.is_available() and torch.cuda.get_device_capability() in [(8, 6), (8, 9)]
isSM90Device = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
isSM5xDevice = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 5
+isLessThanSM80Device = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8
def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
deviation = true_value - computed_value
@@ -1503,8 +1504,9 @@
@onlyCUDA
- @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isSM5xDevice, "Does not support fused SDPA or not SM50 hardware")
- def test_mem_efficient_fail_bfloat16_sm50(self, device):
+ @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isLessThanSM80Device,
+ "Current platform does not support fused SDPA or is an SM80+ device.")
+ def test_mem_efficient_fail_bfloat16_less_than_sm80(self, device):
dtype = torch.bfloat16
size = SdpaShape(16, 16, 32, 32)
make_tensor = partial(torch.rand, size, device=device, dtype=dtype)