Fix CUDA graphs & SDPA for dropout==0 (#101280)
Fixes CUDA graph failures from https://github.com/pytorch/pytorch/pull/100931: with dropout_p == 0, the flash-attention path allocated its philox seed/offset tensors on the CPU even while a CUDA graph was being captured. This change allocates them on the CUDA device when a stream capture is active, has cudagraph_trees skip the (legitimately CPU) seed/offset outputs when not capturing, and checks that every output is a CUDA tensor during recording.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101280
Approved by: https://github.com/ngimel
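
As a point of reference (not part of the patch), here is a minimal sketch of the workload this change targets: capturing a dropout_p == 0 SDPA call in a CUDA graph. The shapes, dtypes, and the `torch.backends.cuda.sdp_kernel` pinning below are illustrative assumptions, not something this PR prescribes.

```python
# Hedged repro sketch: capture a dropout_p == 0 SDPA call in a CUDA graph.
# Shapes, dtypes, and the backend pinning are illustrative assumptions.
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Pin the flash-attention backend, since that is the code path touched here.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    # Warm up outside of capture so one-time initialization is not recorded.
    F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
    torch.cuda.synchronize()

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        # With this patch, the flash path allocates its seed/offset tensors on
        # CUDA while a stream capture is active, so the tensors produced inside
        # the capture all live on the CUDA device.
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)

    g.replay()
    torch.cuda.synchronize()
    print(out.shape)
```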
diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp
index 092762a..d4ce611 100644
--- a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp
+++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp
@@ -341,8 +341,13 @@
}
launch_params.params.philox_args = philox_state;
} else {
- seed_t = at::empty({}, at::dtype(at::kLong));
- offset_t = at::empty({}, at::dtype(at::kLong));
+ if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
+ seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
+ offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
+ } else {
+ seed_t = at::empty({}, at::dtype(at::kLong));
+ offset_t = at::empty({}, at::dtype(at::kLong));
+ }
}
run_fmha_fwd(launch_params);
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 053678f..1f4ed7e 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -1990,6 +1990,7 @@
tmp = torch.rand_like(query, device=query.device) # test non-zero intragraph offset
output_tuple = torch.ops.aten._scaled_dot_product_flash_attention(
query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale, return_debug_mask=True)
+ assert all(not isinstance(o, torch.Tensor) or o.is_cuda for o in output_tuple)
g.replay()
out_first = output_tuple[0].clone()
dbug_mask_first = output_tuple[-1].clone()
diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py
index ef009df..a3dd6c2 100644
--- a/torch/_inductor/cudagraph_trees.py
+++ b/torch/_inductor/cudagraph_trees.py
@@ -464,10 +464,6 @@
return f"StorageWeakRefWrapper to {self.data_ptr()}; alive"
-def is_cuda_tensor(x):
- return isinstance(x, torch.Tensor) and x.device.type == "cuda"
-
-
@contextlib.contextmanager
def _use_cuda_memory_pool_manager(device, mem_pool, stream):
"""
@@ -579,9 +575,11 @@
assert len(new_inputs) == 0
+ # sdpa returns cpu tensors when not recording cuda graph
def add_ref(o):
return (
o is not None
+ and o.is_cuda
and o.untyped_storage().data_ptr() not in non_cudagraph_inps
)
@@ -1057,6 +1055,11 @@
self.output_storage_alias.append(UnaliasedStorage)
continue
+ check(
+ o.is_cuda,
+ lambda: f"Expected all cuda outputs in cuda graph recording. Non cuda output from {self.stack_traces[i]}",
+ ),
+
ref = static_input_persistent_storage_ptrs.get(
o.untyped_storage().data_ptr(), None
)