[PyTorch] Exercise MHA fast path in JIT
Tests previously did not exercise the MultiheadAttention native fast path under TorchScript; now they do.
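
For context, a minimal eager-vs-scripted sketch of the same setup, outside the test harness. The conditions it uses (batch_first=True, eval mode, no_grad, and the same tensor passed as query/key/value) mirror the new test; treating these as the conditions that enable the native fast path is an assumption drawn from the test setup, not a complete specification.

```python
import torch

# Same shapes as the new test.
embed_size, nhead, bsz, src_l = 8, 2, 5, 3

# batch_first=True and eval() mirror the test's setup
# (assumed relevant to hitting the fast path).
mha = torch.nn.MultiheadAttention(embed_size, nhead, batch_first=True).eval()

# Self-attention: the same tensor is used for query, key, and value.
x = torch.rand(bsz, src_l, embed_size)

scripted = torch.jit.script(mha)

with torch.no_grad():
    # MultiheadAttention returns (attn_output, attn_output_weights).
    eager_out = mha(x, x, x)
    jit_out = scripted(x, x, x)

# assert_close compares both elements of the returned tuples.
torch.testing.assert_close(jit_out, eager_out)
```

The scripted module should match the eager module's outputs under these conditions, which is what the added test asserts.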
Differential Revision: [D35945821](https://our.internmc.facebook.com/intern/diff/D35945821/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76416
Approved by: https://github.com/ezyang
diff --git a/test/test_jit.py b/test/test_jit.py
index cd6bc43..76e1822 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -15084,6 +15084,22 @@
# print(jit_out / py_out - 1)
self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4)
+ def test_torchscript_multi_head_attn_fast_path(self):
+ src_l = 3
+ bsz = 5
+ embed_size = 8
+ nhead = 2
+ multi_head_attn = torch.nn.MultiheadAttention(embed_size, nhead, batch_first=True)
+ multi_head_attn = multi_head_attn.eval()
+
+ query = key = value = torch.rand((bsz, src_l, embed_size))
+
+ with torch.no_grad():
+ py_out = multi_head_attn(query, key, value)
+ mha = torch.jit.script(multi_head_attn)
+ jit_out = mha(query, key, value)
+ torch.testing.assert_close(jit_out, py_out)
+
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_multi_head_attn_cuda(self):