Revert "[Inductor] Add FlexAttention backward kernel dynamic shape tests (#127728)"
This reverts commit 10e3406ea5d115a54a7d753d33110762eb6c07ff.
Reverted https://github.com/pytorch/pytorch/pull/127728 on behalf of https://github.com/yanboliang due to internal breakage of https://github.com/pytorch/pytorch/pull/127208, hence reverting ([comment](https://github.com/pytorch/pytorch/pull/127728#issuecomment-2145822667))
diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py
index 16206b5..d4feead 100644
--- a/test/inductor/test_flex_attention.py
+++ b/test/inductor/test_flex_attention.py
@@ -151,47 +151,6 @@
msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
self.assertTrue(False, msg)
- def _check_out_and_grad(
- self,
- golden_out: torch.Tensor,
- ref_out: torch.Tensor,
- compiled_out: torch.Tensor,
- q_gold: torch.Tensor,
- q_ref: torch.Tensor,
- q: torch.Tensor,
- k_gold: torch.Tensor,
- k_ref: torch.Tensor,
- k: torch.Tensor,
- v_gold: torch.Tensor,
- v_ref: torch.Tensor,
- v: torch.Tensor,
- ):
- dtype = ref_out.dtype
- with torch.no_grad():
- # Note, it seems like we really are less accurate than the float32
- # computation, likely due to the online softmax
- if dtype == torch.float32:
- fudge_factor = 10.0
- else:
- fudge_factor = 1.1
-
- # Checkout output
- self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")
-
- # Check gradients
- q_fudge_factor = 2.5 * fudge_factor
- self._check_equal(
- q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query"
- )
- k_fudge_factor = 4 * fudge_factor
- self._check_equal(
- k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key"
- )
- v_fudge_factor = 4 * fudge_factor
- self._check_equal(
- v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value"
- )
-
def run_test(
self,
score_mod: Callable,
@@ -218,20 +177,30 @@
ref_out.backward(backward_grad)
compiled_out.backward(backward_grad)
- self._check_out_and_grad(
- golden_out,
- ref_out,
- compiled_out,
- q_gold,
- q_ref,
- q,
- k_gold,
- k_ref,
- k,
- v_gold,
- v_ref,
- v,
- )
+ with torch.no_grad():
+ # Note, it seems like we really are less accurate than the float32
+ # computation, likely due to the online softmax
+ if dtype == torch.float32:
+ fudge_factor = 10.0
+ else:
+ fudge_factor = 1.1
+
+ # Checkout output
+ self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out")
+
+ # Check gradients
+ q_fudge_factor = 2.5 * fudge_factor
+ self._check_equal(
+ q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query"
+ )
+ k_fudge_factor = 4 * fudge_factor
+ self._check_equal(
+ k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key"
+ )
+ v_fudge_factor = 4 * fudge_factor
+ self._check_equal(
+ v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value"
+ )
def run_dynamic_test(
self,
@@ -244,34 +213,24 @@
):
sdpa_partial = create_attention(score_mod)
# The first eager batch, shape (B, H, S, D)
- q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
- k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
- v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
- q1_ref, k1_ref, v1_ref = query_key_value_clones(q1, k1, v1)
- q1_gold, k1_gold, v1_gold = query_key_value_clones(q1, k1, v1, torch.float64)
- ref_out1 = sdpa_partial(q1_ref, k1_ref, v1_ref)
- golden_out1 = sdpa_partial(q1_gold, k1_gold, v1_gold)
-
- backward_grad1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
-
- golden_out1.backward(backward_grad1.to(torch.float64))
- ref_out1.backward(backward_grad1)
+ q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+ k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+ v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+ golden_out1 = sdpa_partial(
+ q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64)
+ )
+ ref_out1 = sdpa_partial(q1, k1, v1)
# The second eager batch, shape (B * 2, H, S / 2, D)
B = int(B * 2)
S = int(S / 2)
- q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
- k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
- v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True)
- q2_ref, k2_ref, v2_ref = query_key_value_clones(q2, k2, v2)
- q2_gold, k2_gold, v2_gold = query_key_value_clones(q2, k2, v2, torch.float64)
- ref_out2 = sdpa_partial(q2_ref, k2_ref, v2_ref)
- golden_out2 = sdpa_partial(q2_gold, k2_gold, v2_gold)
-
- backward_grad2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
-
- golden_out2.backward(backward_grad2.to(torch.float64))
- ref_out2.backward(backward_grad2)
+ q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+ k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+ v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
+ golden_out2 = sdpa_partial(
+ q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64)
+ )
+ ref_out2 = sdpa_partial(q2, k2, v2)
# Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
# We check dynamo counters["frames"]["ok"] to ensure there is no re-compilation.
@@ -279,41 +238,20 @@
# Compiling with dynamic shape in the first batch.
compiled_sdpa = torch.compile(sdpa_partial, dynamic=True)
compiled_out1 = compiled_sdpa(q1, k1, v1)
- compiled_out1.backward(backward_grad1)
- self._check_out_and_grad(
- golden_out1,
- ref_out1,
- compiled_out1,
- q1_gold,
- q1_ref,
- q1,
- k1_gold,
- k1_ref,
- k1,
- v1_gold,
- v1_ref,
- v1,
- )
+ # Note, it seems like we really are less accurate than the float32
+ # computation, likely due to the online softmax
+ if dtype == torch.float32:
+ fudge_factor = 10.0
+ else:
+ fudge_factor = 1.1
+
+ self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor)
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
# No re-compilation, use the compiled dynamic shape version.
compiled_out2 = compiled_sdpa(q2, k2, v2)
- compiled_out2.backward(backward_grad2)
- self._check_out_and_grad(
- golden_out2,
- ref_out2,
- compiled_out2,
- q2_gold,
- q2_ref,
- q2,
- k2_gold,
- k2_ref,
- k2,
- v2_gold,
- v2_ref,
- v2,
- )
+ self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor)
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
def run_automatic_dynamic_test(