[PyTorch] Don't enter MHA fast path when bias & query dtypes don't match

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76879

The fast path does not support this: transform_bias_rescale_qkv will try to grab bias.data_ptr() on the assumption that the bias and query dtypes match. (Also, I have no idea how the dtypes end up mismatched in the first place.)
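For reference, a minimal sketch of the configuration that tripped the fast path (it mirrors the new test added below; the sizes, double/half dtypes, and CPU device are illustrative):

```python
import torch

# Build a self-attention module in double precision and put it in eval mode
# so the MHA fast path is eligible.
mha = torch.nn.MultiheadAttention(3, 3, batch_first=True, dtype=torch.double).eval()

# Cast only the input-projection bias to half, so its dtype no longer matches
# the query/weight dtype.
mha.in_proj_bias = torch.nn.Parameter(mha.in_proj_bias.to(torch.half))

query = torch.randn(3, 3, 3, dtype=torch.double)
with torch.no_grad():
    # Before this change, self-attention inference like this could enter the
    # fast path and fail inside transform_bias_rescale_qkv; now it falls back.
    out, attn_weights = mha(query, query, query)
```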

Differential Revision: [D36156872](https://our.internmc.facebook.com/intern/diff/D36156872/)

Approved by: https://github.com/cpuhrsch
diff --git a/test/test_nn.py b/test/test_nn.py
index f7e96d2..ebbeacc 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -20251,6 +20251,14 @@
                 with cm:
                     _test(batch_first=batch_first, training=training, atol=atol, rtol=rtol)
 
+    @dtypes(torch.double)
+    @torch.no_grad()
+    def test_multihead_attn_fast_path_query_and_bias_have_different_dtypes(self, device, dtype):
+        mha = torch.nn.MultiheadAttention(3, 3, batch_first=True, dtype=dtype, device=device).eval()
+        mha.in_proj_bias = torch.nn.Parameter(mha.in_proj_bias.to(torch.half).to(device))
+        query = torch.randn(3, 3, 3, dtype=dtype, device=device)
+        mha(query, query, query)
+
     @dtypes(torch.float)
     @dtypesIfCUDA(torch.half, torch.float)
     def test_transformerencoderlayer_gelu(self, device, dtype):
diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py
index 46ab379..151b0c5 100644
--- a/torch/nn/modules/activation.py
+++ b/torch/nn/modules/activation.py
@@ -1062,7 +1062,15 @@
         if not is_batched:
             why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
         elif query is not key or key is not value:
+            # When lifting this restriction, don't forget to either
+            # enforce that the dtypes all match or test cases where
+            # they don't!
             why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
+        elif query.dtype != self.in_proj_bias.dtype:
+            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
+        elif query.dtype != self.in_proj_weight.dtype:
+            # this case will fail anyway, but at least they'll get a useful error message.
+            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
         elif self.training:
             why_not_fast_path = "training is enabled"
         elif not self.batch_first: