[PyTorch] Don't enter MHA fast path when bias & query dtypes don't match
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76879
The fast path does not support mismatched dtypes: transform_bias_rescale_qkv will try to grab bias.data_ptr() assuming the bias and query dtypes are the same. (It's unclear how users end up with a mismatched in_proj_bias in the first place, but it does happen; the new test reproduces it by swapping in a half-precision Parameter.)
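For context, a minimal sketch of how such a state can arise (it mirrors the test added below; the concrete sizes and dtypes are only illustrative):

```python
import torch

# Build an eval-mode MHA in double precision, then replace its bias with a
# half-precision Parameter so query.dtype != in_proj_bias.dtype.
mha = torch.nn.MultiheadAttention(3, 3, batch_first=True, dtype=torch.double).eval()
mha.in_proj_bias = torch.nn.Parameter(mha.in_proj_bias.to(torch.half))

query = torch.randn(3, 3, 3, dtype=torch.double)
with torch.no_grad():
    # Eval mode + no_grad + batched self-attention is normally eligible for the
    # fast path; the dtype check added here keeps this call off that path (which
    # would read the half-precision bias's data_ptr as doubles) and routes it
    # through the regular implementation instead.
    mha(query, query, query)
```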
Differential Revision: [D36156872](https://our.internmc.facebook.com/intern/diff/D36156872/)
Approved by: https://github.com/cpuhrsch
diff --git a/test/test_nn.py b/test/test_nn.py
index f7e96d2..ebbeacc 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -20251,6 +20251,14 @@
with cm:
_test(batch_first=batch_first, training=training, atol=atol, rtol=rtol)
+    @dtypes(torch.double)
+    @torch.no_grad()
+    def test_multihead_attn_fast_path_query_and_bias_have_different_dtypes(self, device, dtype):
+        mha = torch.nn.MultiheadAttention(3, 3, batch_first=True, dtype=dtype, device=device).eval()
+        mha.in_proj_bias = torch.nn.Parameter(mha.in_proj_bias.to(torch.half).to(device))
+        query = torch.randn(3, 3, 3, dtype=dtype, device=device)
+        mha(query, query, query)
+
     @dtypes(torch.float)
     @dtypesIfCUDA(torch.half, torch.float)
     def test_transformerencoderlayer_gelu(self, device, dtype):
diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py
index 46ab379..151b0c5 100644
--- a/torch/nn/modules/activation.py
+++ b/torch/nn/modules/activation.py
@@ -1062,7 +1062,15 @@
         if not is_batched:
             why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
         elif query is not key or key is not value:
+            # When lifting this restriction, don't forget to either
+            # enforce that the dtypes all match or test cases where
+            # they don't!
             why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
+        elif query.dtype != self.in_proj_bias.dtype:
+            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
+        elif query.dtype != self.in_proj_weight.dtype:
+            # this case will fail anyway, but at least they'll get a useful error message.
+            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
         elif self.training:
             why_not_fast_path = "training is enabled"
         elif not self.batch_first: