[PyTorch] Exercise MHA fast path in JIT

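The TorchScript tests previously did not exercise the `nn.MultiheadAttention` (MHA) inference fast path. This adds a test that scripts an eval-mode, `batch_first=True` MHA module and checks, under `torch.no_grad()`, that its output matches eager mode.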

Differential Revision: [D35945821](https://our.internmc.facebook.com/intern/diff/D35945821/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76416

Approved by: https://github.com/ezyang
diff --git a/test/test_jit.py b/test/test_jit.py
index cd6bc43..76e1822 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -15084,6 +15084,22 @@
         # print(jit_out / py_out - 1)
         self.assertEqual(jit_out, py_out, atol=5e-4, rtol=1e-4)
 
+    def test_torchscript_multi_head_attn_fast_path(self):
+        """Scripted MultiheadAttention must match eager output in eval/no_grad mode."""
+        batch, seq_len, embed_dim, num_heads = 5, 3, 8, 2
+        eager_mha = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).eval()
+
+        # One tensor serves as query, key and value (self-attention).
+        inp = torch.rand((batch, seq_len, embed_dim))
+
+        # Inference-only comparison: gradients are disabled for both calls.
+        with torch.no_grad():
+            expected = eager_mha(inp, inp, inp)
+            scripted_mha = torch.jit.script(eager_mha)
+            actual = scripted_mha(inp, inp, inp)
+
+        torch.testing.assert_close(actual, expected)
+
     @unittest.skipIf(not RUN_CUDA, "no CUDA")
     def test_scriptmodule_multi_head_attn_cuda(self):