Revert "[Inductor] Add FlexAttention backward kernel dynamic shape tests (#127728)"

This reverts commit 10e3406ea5d115a54a7d753d33110762eb6c07ff.

Reverted https://github.com/pytorch/pytorch/pull/127728 on behalf of https://github.com/yanboliang due to internal breakage of https://github.com/pytorch/pytorch/pull/127208, hence reverting ([comment](https://github.com/pytorch/pytorch/pull/127728#issuecomment-2145822667))
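For reference, the reverted tests exercised the pattern sketched below: compile the attention call once with `dynamic=True`, run a second batch with doubled batch size and halved sequence length, run backward, and assert via the dynamo counters that no recompilation happened. This is a minimal illustrative sketch, not the test code itself; it assumes the public `flex_attention` API (the tests build the callable via a `create_attention` helper and also compare against float64 golden outputs), and the shapes and dtype are placeholders.

```python
# Minimal sketch of the dynamic-shape backward test pattern removed by this
# revert. Assumes a CUDA build with the public flex_attention API; the actual
# tests wrap the call via create_attention() and sweep several score_mods.
import torch
from torch.nn.attention.flex_attention import flex_attention

def score_mod(score, b, h, q_idx, kv_idx):
    return score  # identity score_mod; the real tests parametrize over many

B, H, S, D = 4, 8, 1024, 64

def make(b, s):
    return torch.randn(
        (b, H, s, D), dtype=torch.float16, device="cuda", requires_grad=True
    )

q1, k1, v1 = make(B, S), make(B, S), make(B, S)

# Clear dynamo counters first: eager flex attention also uses dynamo tracing.
torch._dynamo.reset()
compiled_sdpa = torch.compile(flex_attention, dynamic=True)

# First batch: compiles once with dynamic shapes.
out1 = compiled_sdpa(q1, k1, v1, score_mod=score_mod)
out1.backward(torch.randn_like(out1))
assert torch._dynamo.utils.counters["frames"]["ok"] == 1

# Second batch, shape (B * 2, H, S // 2, D): reuses the dynamic-shape
# artifact, so the "ok" frame counter must not increase.
q2, k2, v2 = make(B * 2, S // 2), make(B * 2, S // 2), make(B * 2, S // 2)
out2 = compiled_sdpa(q2, k2, v2, score_mod=score_mod)
out2.backward(torch.randn_like(out2))
assert torch._dynamo.utils.counters["frames"]["ok"] == 1
```
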
diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index 16206b5..d4feead 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -151,47 +151,6 @@
             msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
             self.assertTrue(False, msg)
 
-    def _check_out_and_grad(
-        self,
-        golden_out: torch.Tensor,
-        ref_out: torch.Tensor,
-        compiled_out: torch.Tensor,
-        q_gold: torch.Tensor,
-        q_ref: torch.Tensor,
-        q: torch.Tensor,
-        k_gold: torch.Tensor,
-        k_ref: torch.Tensor,
-        k: torch.Tensor,
-        v_gold: torch.Tensor,
-        v_ref: torch.Tensor,
-        v: torch.Tensor,
-    ):
-        dtype = ref_out.dtype
-        with torch.no_grad():
-            # Note, it seems like we really are less accurate than the float32
-            # computation, likely due to the online softmax
-            if dtype == torch.float32:
-                fudge_factor = 10.0
-            else:
-                fudge_factor = 1.1
-
-            # Check output
-            self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")
-
-            # Check gradients
-            q_fudge_factor = 2.5 * fudge_factor
-            self._check_equal(
-                q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query"
-            )
-            k_fudge_factor = 4 * fudge_factor
-            self._check_equal(
-                k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key"
-            )
-            v_fudge_factor = 4 * fudge_factor
-            self._check_equal(
-                v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value"
-            )
-
     def run_test(
         self,
         score_mod: Callable,
@@ -218,20 +177,30 @@
         ref_out.backward(backward_grad)
         compiled_out.backward(backward_grad)
 
-        self._check_out_and_grad(
-            golden_out,
-            ref_out,
-            compiled_out,
-            q_gold,
-            q_ref,
-            q,
-            k_gold,
-            k_ref,
-            k,
-            v_gold,
-            v_ref,
-            v,
-        )
+        with torch.no_grad():
+            # Note, it seems like we really are less accurate than the float32
+            # computation, likely due to the online softmax
+            if dtype == torch.float32:
+                fudge_factor = 10.0
+            else:
+                fudge_factor = 1.1
+
+            # Check output
+            self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")
+
+            # Check gradients
+            q_fudge_factor = 2.5 * fudge_factor
+            self._check_equal(
+                q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query"
+            )
+            k_fudge_factor = 4 * fudge_factor
+            self._check_equal(
+                k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key"
+            )
+            v_fudge_factor = 4 * fudge_factor
+            self._check_equal(
+                v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value"
+            )
 
     def run_dynamic_test(
         self,
@@ -244,34 +213,24 @@
     ):
         sdpa_partial = create_attention(score_mod)
         # The first eager batch, shape (B, H, S, D)
-        q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
-        k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
-        v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
-        q1_ref, k1_ref, v1_ref = query_key_value_clones(q1, k1, v1)
-        q1_gold, k1_gold, v1_gold = query_key_value_clones(q1, k1, v1, torch.float64)
-        ref_out1 = sdpa_partial(q1_ref, k1_ref, v1_ref)
-        golden_out1 = sdpa_partial(q1_gold, k1_gold, v1_gold)
-
-        backward_grad1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
-
-        golden_out1.backward(backward_grad1.to(torch.float64))
-        ref_out1.backward(backward_grad1)
+        q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        golden_out1 = sdpa_partial(
+            q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64)
+        )
+        ref_out1 = sdpa_partial(q1, k1, v1)
 
         # The second eager batch, shape (B * 2, H, S / 2, D)
         B = int(B * 2)
         S = int(S / 2)
-        q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
-        k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
-        v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
-        q2_ref, k2_ref, v2_ref = query_key_value_clones(q2, k2, v2)
-        q2_gold, k2_gold, v2_gold = query_key_value_clones(q2, k2, v2, torch.float64)
-        ref_out2 = sdpa_partial(q2_ref, k2_ref, v2_ref)
-        golden_out2 = sdpa_partial(q2_gold, k2_gold, v2_gold)
-
-        backward_grad2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
-
-        golden_out2.backward(backward_grad2.to(torch.float64))
-        ref_out2.backward(backward_grad2)
+        q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+        golden_out2 = sdpa_partial(
+            q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64)
+        )
+        ref_out2 = sdpa_partial(q2, k2, v2)
 
         # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
         # We check dynamo counters["frames"]["ok"] to ensure there is no re-compilation.
@@ -279,41 +238,20 @@
         # Compiling with dynamic shape in the first batch.
         compiled_sdpa = torch.compile(sdpa_partial, dynamic=True)
         compiled_out1 = compiled_sdpa(q1, k1, v1)
-        compiled_out1.backward(backward_grad1)
 
-        self._check_out_and_grad(
-            golden_out1,
-            ref_out1,
-            compiled_out1,
-            q1_gold,
-            q1_ref,
-            q1,
-            k1_gold,
-            k1_ref,
-            k1,
-            v1_gold,
-            v1_ref,
-            v1,
-        )
+        # Note, it seems like we really are less accurate than the float32
+        # computation, likely due to the online softmax
+        if dtype == torch.float32:
+            fudge_factor = 10.0
+        else:
+            fudge_factor = 1.1
+
+        self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor)
         self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
 
         # No re-compilation, use the compiled dynamic shape version.
         compiled_out2 = compiled_sdpa(q2, k2, v2)
-        compiled_out2.backward(backward_grad2)
-        self._check_out_and_grad(
-            golden_out2,
-            ref_out2,
-            compiled_out2,
-            q2_gold,
-            q2_ref,
-            q2,
-            k2_gold,
-            k2_ref,
-            k2,
-            v2_gold,
-            v2_ref,
-            v2,
-        )
+        self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor)
         self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
 
     def run_automatic_dynamic_test(