Fix cuda graphs & sdpa for dropout==0 (#101280)

Fixes the CUDA graph capture failures introduced by https://github.com/pytorch/pytorch/pull/100931. With dropout_p == 0, the flash attention path allocated its philox seed/offset tensors on the CPU even while a CUDA graph was being recorded, which the cudagraph trees machinery did not expect. The seed/offset tensors are now allocated on the CUDA device whenever a stream capture is in progress, and `cudagraph_trees` skips non-CUDA outputs when taking storage references and checks that every recorded output is a CUDA tensor.

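For context, a sketch of the kind of call this affects (illustrative shapes; assumes fp16 inputs on a GPU where the flash attention backend is eligible, and `mode="reduce-overhead"` to enable inductor's CUDA graph trees):

```python
import torch
import torch.nn.functional as F

def attn(q, k, v):
    # dropout_p=0.0 takes the code path patched in fmha_api.cpp below
    return F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)

q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

compiled = torch.compile(attn, mode="reduce-overhead")
for _ in range(3):
    out = compiled(q, k, v)  # first iterations warm up/record, later ones replay the graph
```
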
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101280
Approved by: https://github.com/ngimel
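
The fmha_api.cpp hunk amounts to the device-selection logic sketched below in Python (a hypothetical helper, not an API added by this PR; `torch.cuda.is_current_stream_capturing()` is the Python-side analogue of `at::cuda::currentStreamCaptureStatus()`):

```python
import torch

def alloc_philox_seed_offset():
    # While a CUDA graph is being recorded, the philox seed/offset placeholders
    # must live on the GPU so that every output of the op is a CUDA tensor;
    # outside of capture they remain CPU scalars, as before.
    device = "cuda" if torch.cuda.is_current_stream_capturing() else "cpu"
    seed = torch.empty((), dtype=torch.long, device=device)
    offset = torch.empty((), dtype=torch.long, device=device)
    return seed, offset
```
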
diff --git a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp
index 092762a..d4ce611 100644
--- a/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp
+++ b/aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp
@@ -341,8 +341,13 @@
         }
         launch_params.params.philox_args = philox_state;
     } else {
-        seed_t = at::empty({}, at::dtype(at::kLong));
-        offset_t = at::empty({}, at::dtype(at::kLong));
+        if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
+            seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
+            offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
+        } else {
+            seed_t = at::empty({}, at::dtype(at::kLong));
+            offset_t = at::empty({}, at::dtype(at::kLong));
+        }
     }
 
     run_fmha_fwd(launch_params);
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 053678f..1f4ed7e 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -1990,6 +1990,7 @@
             tmp = torch.rand_like(query, device=query.device)  # test non-zero intragraph offset
             output_tuple = torch.ops.aten._scaled_dot_product_flash_attention(
                 query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale, return_debug_mask=True)
+            assert all(not isinstance(o, torch.Tensor) or o.is_cuda for o in output_tuple)
         g.replay()
         out_first = output_tuple[0].clone()
         dbug_mask_first = output_tuple[-1].clone()
diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py
index ef009df..a3dd6c2 100644
--- a/torch/_inductor/cudagraph_trees.py
+++ b/torch/_inductor/cudagraph_trees.py
@@ -464,10 +464,6 @@
             return f"StorageWeakRefWrapper to {self.data_ptr()}; alive"
 
 
-def is_cuda_tensor(x):
-    return isinstance(x, torch.Tensor) and x.device.type == "cuda"
-
-
 @contextlib.contextmanager
 def _use_cuda_memory_pool_manager(device, mem_pool, stream):
     """
@@ -579,9 +575,11 @@
 
         assert len(new_inputs) == 0
 
+        # sdpa returns cpu tensors when not recording cuda graph
         def add_ref(o):
             return (
                 o is not None
+                and o.is_cuda
                 and o.untyped_storage().data_ptr() not in non_cudagraph_inps
             )
 
@@ -1057,6 +1055,11 @@
                 self.output_storage_alias.append(UnaliasedStorage)
                 continue
 
+            check(
+                o.is_cuda,
+                lambda: f"Expected all cuda outputs in cuda graph recording. Non cuda output from {self.stack_traces[i]}",
+            )
+
             ref = static_input_persistent_storage_ptrs.get(
                 o.untyped_storage().data_ptr(), None
             )