Fix allowed dtypes for mem_eff attention (#116026)

# Summary

Fixes a bug in detecting which dtypes the memory-efficient attention backend supports on CUDA devices with compute capability less than sm80:
https://github.com/pytorch-labs/gpt-fast/issues/49
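
In practice this means bfloat16 inputs are only accepted by the memory-efficient kernel on sm80 and newer devices; below sm80 only fp16/fp32 pass the dtype check. A minimal sketch of how a caller might pick a dtype accordingly (illustrative only; assumes a CUDA build of PyTorch and uses the `torch.backends.cuda.sdp_kernel` context manager purely to force the memory-efficient backend):

```python
import torch
import torch.nn.functional as F

# Mirror the new gating in sdp_utils.cpp: sm80+ accepts fp16/fp32/bf16 for the
# memory-efficient kernel, while pre-sm80 devices accept only fp16/fp32.
major, _ = torch.cuda.get_device_capability()
allowed = {torch.float16, torch.float32}
if major >= 8:
    allowed.add(torch.bfloat16)

# Fall back to fp16 when bf16 is not allowed on this device.
dtype = torch.bfloat16 if torch.bfloat16 in allowed else torch.float16
q, k, v = (torch.rand(16, 16, 32, 32, device="cuda", dtype=dtype) for _ in range(3))

# Force the memory-efficient backend; with this fix, a bfloat16 input on a
# pre-sm80 device is rejected here rather than being reported as supported.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=False, enable_mem_efficient=True
):
    out = F.scaled_dot_product_attention(q, k, v)
```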

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116026
Approved by: https://github.com/janeyx99
diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
index 1fcbf97..49883cc 100644
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -339,9 +339,9 @@
   return false;
 #endif
   // Constraints specific to mem efficient attention
-  constexpr auto default_mem_efficient_dtypes =
+  constexpr auto greater_than_or_equal_sm80_mem_efficient_dtypes =
       array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
-  constexpr auto sm50_mem_efficient_dtypes =
+  constexpr auto less_than_sm80_mem_efficient_dtypes =
       array_of<at::ScalarType>(at::kHalf, at::kFloat);
 
   //  Define gate functions that determine if a mem efficient kernel can be ran
@@ -381,10 +381,10 @@
   }
 
   auto dprop = at::cuda::getCurrentDeviceProperties();
-  if (dprop->major == 5) {
-    return check_tensor_dtype(params, sm50_mem_efficient_dtypes, debug);
+  if (dprop->major >= 8) {
+    return check_tensor_dtype(params, greater_than_or_equal_sm80_mem_efficient_dtypes, debug);
   }
-  return check_tensor_dtype(params, default_mem_efficient_dtypes, debug);
+  return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug);
 }
 
 SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 2a4d91c..d7503c3 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -68,6 +68,7 @@
 isSM86or89Device = torch.cuda.is_available() and torch.cuda.get_device_capability() in [(8, 6), (8, 9)]
 isSM90Device = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
 isSM5xDevice = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 5
+isLessThanSM80Device = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8
 
 def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
     deviation = true_value - computed_value
@@ -1503,8 +1504,9 @@
 
 
     @onlyCUDA
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isSM5xDevice, "Does not support fused SDPA or not SM50 hardware")
-    def test_mem_efficient_fail_bfloat16_sm50(self, device):
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isLessThanSM80Device,
+                     "Current platform does not support fused SDPA or is an SM80+ device.")
+    def test_mem_efficient_fail_bfloat16_less_than_sm80(self, device):
         dtype = torch.bfloat16
         size = SdpaShape(16, 16, 32, 32)
         make_tensor = partial(torch.rand, size, device=device, dtype=dtype)