Disable src_mask for Transformer and MultiheadAttention fastpath (#81277)

Disable the fastpath when a src_mask is passed to TransformerEncoderLayer, and likewise when an attn_mask is passed to MultiheadAttention.
- Moved test_transformerencoder from test_nn.py to test_transformers.py, parametrized it, and added an all-zeros src_mask case to it.
- Added a dedicated src_mask test, test_transformerencoderlayer_src_mask, in test_transformers.py.

Fixes https://github.com/pytorch/pytorch/issues/81129

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81277
Approved by: https://github.com/zrphercule
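
To illustrate the behavior this PR fixes, here is a minimal sketch modeled on the new test_transformerencoderlayer_src_mask below (shapes and sizes are illustrative, not from the original report):

```python
# Before this change, eval + no_grad routed this call onto the fastpath even
# though a 2D src_mask was supplied, triggering the failure in #81129. With
# this change, any src_mask (or attn_mask for MultiheadAttention) disables
# the fastpath and the call takes the regular path.
import torch

layer = torch.nn.TransformerEncoderLayer(d_model=8, nhead=2, batch_first=True)
layer.eval()

src = torch.rand(2, 4, 8)                        # (batch, seq, d_model)
src_mask = torch.zeros(4, 4, dtype=torch.bool)   # (seq, seq) attention mask

with torch.no_grad():
    out = layer(src, src_mask=src_mask)  # now falls back to the standard path
```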
diff --git a/test/test_nn.py b/test/test_nn.py
index ad39aca..336dd35 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -5915,32 +5915,6 @@
             # output_2d in shape of [T, 1, D]
             self.assertEqual(output_3d[i].unsqueeze(0).transpose(0, 1), output_2d)
 
-    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
-    def test_self_attn_TxT_attn_mask(self):
-        embed_dim = 16
-        num_heads = 4
-        batch_size = 10
-        tgt_len = 16
-
-        query = torch.rand(batch_size, tgt_len, embed_dim, device="cuda")  # [N, T, D]
-        attn_mask = torch.randint(0, 2, (tgt_len, tgt_len)).cuda().float()  # [T, T]
-        attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, float(0.0))
-
-        attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len)
-
-        mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda()
-        mta_model.eval()
-
-        # Generate 3D results
-        with torch.inference_mode():
-            output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0]
-            output_mask_4d = output_mask_4d.transpose(0, 1)  # [N, T, D]
-
-            output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0]
-            output_mask_TxT = output_mask_TxT.transpose(0, 1)  # [N, T, D]
-
-            self.assertEqual(output_mask_4d, output_mask_TxT)
-
     def test_multihead_attn_no_bias(self):
         embed_dim = 8
         num_heads = 4
@@ -7985,190 +7959,6 @@
                                                 [2.42240309, 0.0354595, -0.60659063, -0.05378816]]]))
             torch.testing.assert_close(result, ref_output, rtol=1e-5, atol=0)
 
-    def test_transformerencoder(self):
-        def get_a_test_layer(use_cuda, activation, batch_first=False):
-            d_model = 4
-            nhead = 2
-            dim_feedforward = 16
-            dropout = 0.0
-            device = torch.device("cuda" if use_cuda else "cpu")
-
-            layer = nn.TransformerEncoderLayer(
-                d_model,
-                nhead,
-                dim_feedforward=dim_feedforward,
-                dropout=dropout,
-                activation=activation,
-                batch_first=batch_first,
-            ).to(device)
-
-            with torch.no_grad():
-                # set constant weights of the model
-                for idx, p in enumerate(layer.parameters()):
-                    x = p.data
-                    sz = x.view(-1).size(0)
-                    shape = x.shape
-                    x = torch.cos(torch.arange(0, sz).float().view(shape))
-                    p.data.copy_(x)
-
-            return layer
-
-        # this is a deterministic test for TransformerEncoder
-        activation = F.relu
-        use_cuda = torch.cuda.is_available()
-        device = torch.device("cuda" if use_cuda else "cpu")
-
-        def _test(batch_first, training, enable_nested_tensor):
-            def perm_fn(x):
-                return x.transpose(1, 0) if batch_first else x
-
-            encoder_layer = get_a_test_layer(use_cuda=use_cuda, activation=activation,
-                                             batch_first=batch_first)
-
-            model = nn.TransformerEncoder(encoder_layer, 1).to(device)
-            if not training:
-                model = model.eval()
-
-            # deterministic input
-            encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
-                                                   [0.5387, 0.1655, 0.3565, 0.0471]],
-                                                  [[0.8335, 0.2799, 0.5031, 0.2947],
-                                                   [0.1402, 0.0318, 0.7636, 0.1346]],
-                                                  [[0.6333, 0.9344, 0.1376, 0.9938],
-                                                   [0.8924, 0.2872, 0.6692, 0.2944]],
-                                                  [[0.9897, 0.6915, 0.3154, 0.1733],
-                                                   [0.8645, 0.3513, 0.3064, 0.0767]],
-                                                  [[0.8117, 0.2366, 0.4838, 0.7881],
-                                                   [0.3718, 0.4945, 0.9511, 0.0864]]]
-                                                 )).to(device)
-            result = model(encoder_input)
-            ref_output = perm_fn(torch.tensor([[[2.428589, 0.020835, -0.602055, -0.085249],
-                                                [2.427987, 0.021213, -0.602496, -0.084103]],
-                                               [[2.424689, 0.019155, -0.604793, -0.085672],
-                                                [2.413863, 0.022211, -0.612486, -0.072490]],
-                                               [[2.433774, 0.021598, -0.598343, -0.087548],
-                                                [2.425104, 0.019748, -0.604515, -0.084839]],
-                                               [[2.436185, 0.022682, -0.596625, -0.087261],
-                                                [2.433556, 0.021891, -0.598509, -0.086832]],
-                                               [[2.416246, 0.017512, -0.610712, -0.082961],
-                                                [2.422901, 0.024187, -0.606178, -0.074929]]]
-                                              )).to(device)
-            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
-
-            # all 0
-            mask = torch.zeros([2, 5]).to(device) == 1
-            result = model(encoder_input, src_key_padding_mask=mask)
-            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
-            mask[0, 1] = 1
-            mask[1, 3] = 1
-            mask[1, 4] = 1
-            # If mask is not left aligned
-            # We disable nested tensor
-            model.enable_nested_tensor = enable_nested_tensor
-            result = model(encoder_input, src_key_padding_mask=mask)
-            ref_output = perm_fn(torch.tensor([[[2.429026, 0.020793, -0.601741, -0.085642],
-                                                [2.428811, 0.021445, -0.601912, -0.084252]],
-                                               [[2.425009, 0.019155, -0.604566, -0.085899],
-                                                [2.415408, 0.02249, -0.611415, -0.073]],
-                                               [[2.434199, 0.021682, -0.598039, -0.087699],
-                                                [2.42598, 0.019941, -0.603896, -0.085091]],
-                                               [[2.436457, 0.022736, -0.59643, -0.08736],
-                                                [2.434021, 0.022093, -0.598179, -0.08679]],
-                                               [[2.416531, 0.017498, -0.610513, -0.083181],
-                                                [2.4242, 0.024653, -0.605266, -0.074959]]]
-                                              )).to(device)
-            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
-
-            # test case 2, multiple layers no norm
-            model = nn.TransformerEncoder(encoder_layer, 2, enable_nested_tensor=enable_nested_tensor).to(device)
-            if not training:
-                model = model.eval()
-            result = model(encoder_input, src_key_padding_mask=mask)
-            ref_output = perm_fn(torch.tensor([[[2.419051, 0.017446, -0.608738, -0.085003],
-                                                [2.419102, 0.017452, -0.608703, -0.085026]],
-                                               [[2.419043, 0.017445, -0.608744, -0.084999],
-                                                [2.419052, 0.017446, -0.608738, -0.085004]],
-                                               [[2.419067, 0.017448, -0.608727, -0.085010],
-                                                [2.419098, 0.017452, -0.608706, -0.085024]],
-                                               [[2.419072, 0.017449, -0.608724, -0.085012],
-                                                [2.419119, 0.017455, -0.608691, -0.085034]],
-                                               [[2.419019, 0.017442, -0.608761, -0.084989],
-                                                [2.419075, 0.017449, -0.608722, -0.085014]]]
-                                              )).to(device)
-            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
-
-            model = nn.TransformerEncoder(encoder_layer, 6, enable_nested_tensor=enable_nested_tensor).to(device)
-            if not training:
-                model = model.eval()
-            result = model(encoder_input, src_key_padding_mask=mask)
-            ref_output = perm_fn(torch.tensor([[[2.419101, 0.017453, -0.608703, -0.085025],
-                                                [2.419101, 0.017453, -0.608704, -0.085025]],
-                                               [[2.419101, 0.017453, -0.608703, -0.085025],
-                                                [2.419101, 0.017453, -0.608704, -0.085025]],
-                                               [[2.419101, 0.017453, -0.608703, -0.085025],
-                                                [2.419101, 0.017453, -0.608704, -0.085025]],
-                                               [[2.419101, 0.017453, -0.608703, -0.085025],
-                                                [2.419101, 0.017453, -0.608704, -0.085025]],
-                                               [[2.419101, 0.017453, -0.608703, -0.085025],
-                                                [2.419101, 0.017453, -0.608704, -0.085025]]]
-                                              )).to(device)
-            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
-
-            # test case 3, multiple layers with norm
-            # d_model = 4
-            norm = nn.LayerNorm(4)
-            model = nn.TransformerEncoder(encoder_layer, 2, norm=norm, enable_nested_tensor=enable_nested_tensor).to(device)
-            if not training:
-                model = model.eval()
-            result = model(encoder_input, src_key_padding_mask=mask)
-            ref_output = perm_fn(torch.tensor([[[1.695949, -0.357635, -0.893077, -0.445238],
-                                                [1.695955, -0.357639, -0.893050, -0.445266]],
-                                               [[1.695948, -0.357634, -0.893082, -0.445233],
-                                                [1.695950, -0.357635, -0.893077, -0.445238]],
-                                               [[1.695951, -0.357636, -0.893069, -0.445246],
-                                                [1.695955, -0.357639, -0.893052, -0.445264]],
-                                               [[1.695952, -0.357636, -0.893066, -0.445249],
-                                                [1.695957, -0.357641, -0.893041, -0.445276]],
-                                               [[1.695946, -0.357632, -0.893095, -0.445220],
-                                                [1.695952, -0.357637, -0.893065, -0.445251]]]
-                                              )).to(device)
-            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
-
-            model = nn.TransformerEncoder(encoder_layer, 6, norm=norm, enable_nested_tensor=enable_nested_tensor).to(device)
-            if not training:
-                model = model.eval()
-            result = model(encoder_input, src_key_padding_mask=mask)
-            ref_output = perm_fn(torch.tensor([[[1.695955, -0.357639, -0.893051, -0.445265],
-                                                [1.695955, -0.357639, -0.893051, -0.445265]],
-                                               [[1.695955, -0.357639, -0.893051, -0.445265],
-                                                [1.695955, -0.357639, -0.893051, -0.445265]],
-                                               [[1.695955, -0.357639, -0.893051, -0.445265],
-                                                [1.695955, -0.357639, -0.893051, -0.445265]],
-                                               [[1.695955, -0.357639, -0.893051, -0.445265],
-                                                [1.695955, -0.357639, -0.893051, -0.445265]],
-                                               [[1.695955, -0.357639, -0.893051, -0.445265],
-                                                [1.695955, -0.357639, -0.893051, -0.445265]]]
-                                              )).to(device)
-            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
-            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
-
-        for batch_first in (True, False):
-            for training in (True, False):
-                for enable_nested_tensor in (True, False):
-                    # Fast path requires inference mode.
-                    if training:
-                        cm = contextlib.nullcontext()
-                    else:
-                        cm = torch.no_grad()
-                    with cm:
-                        _test(batch_first, training, enable_nested_tensor)
-
     def test_transformerdecoder(self):
         def get_a_test_layer(use_cuda, activation, batch_first=False):
             d_model = 4
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 19670f4..bff415e 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -1,19 +1,83 @@
 # Owner(s): ["module: nn"]
 
+import contextlib
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 import unittest
 
 from torch.testing._internal.common_nn import NNTestCase
-from torch.testing._internal.common_utils import TEST_FAIRSEQ, parametrize, instantiate_parametrized_tests
+from torch.testing._internal.common_utils import TEST_FAIRSEQ, run_tests, parametrize, instantiate_parametrized_tests
 from torch.testing._internal.common_cuda import TEST_CUDA
 
 if TEST_FAIRSEQ:
     import fairseq.models.transformer as fairseq_transformer
 
+@contextlib.contextmanager
+def set_default_dtype(dtype):
+    saved_dtype = torch.get_default_dtype()
+    torch.set_default_dtype(dtype)
+    try:
+        yield
+    finally:
+        torch.set_default_dtype(saved_dtype)
+
 class TestTransformers(NNTestCase):
     _do_cuda_memory_leak_check = True
     _do_cuda_non_default_stream = True
 
+    device_list = ['cpu']  # TODO: is there a way to do parametrize for this?
+    if TEST_CUDA:
+        device_list.append('cuda')
+
+    @unittest.skip("4D mask not supported yet - activate when 4D mask supported")
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")  # TODO: make this work for both cuda and cpu
+    def test_self_attn_TxT_attn_mask(self):
+        embed_dim = 16
+        num_heads = 4
+        batch_size = 10
+        tgt_len = 16
+
+        query = torch.rand(batch_size, tgt_len, embed_dim, device="cuda")  # [N, T, D]
+        attn_mask = torch.randint(0, 2, (tgt_len, tgt_len)).cuda().float()  # [T, T]
+        attn_mask = attn_mask.masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, float(0.0))
+
+        attn_mask_4d = attn_mask.expand(batch_size, num_heads, tgt_len, tgt_len)
+
+        mta_model = torch.nn.MultiheadAttention(embed_dim, num_heads, batch_first=True).cuda()
+        mta_model.eval()
+
+        # Generate 3D results
+        with torch.inference_mode():
+            output_mask_4d = mta_model(query, query, query, attn_mask=attn_mask_4d)[0]
+            output_mask_4d = output_mask_4d.transpose(0, 1)  # [N, T, D]
+
+            output_mask_TxT = mta_model(query, query, query, attn_mask=attn_mask)[0]
+            output_mask_TxT = output_mask_TxT.transpose(0, 1)  # [N, T, D]
+
+            self.assertEqual(output_mask_4d, output_mask_TxT)
+
+    @parametrize("device", device_list)
+    def test_transformerencoderlayer_src_mask(self, device):
+        batch_size = 2
+        seqlen = 4
+        d_model = 8
+        nhead = 8
+        dim_feedforward = 32
+
+        model = torch.nn.TransformerEncoderLayer(
+            d_model=d_model,
+            nhead=nhead,
+            dim_feedforward=dim_feedforward,
+            batch_first=True).to(device)
+        src = torch.rand(batch_size, seqlen, d_model).to(device)  # bs, seqlen, d_model
+        src_mask = torch.zeros(seqlen, seqlen).to(torch.bool).to(device)
+
+        model(src, src_mask=src_mask)
+        model.eval()
+        with torch.no_grad():
+            model(src, src_mask=src_mask)
+
     @parametrize("use_torchscript", [True, False])
     @parametrize("with_no_grad", [True, False])
     @parametrize("training", [True, False])
@@ -39,12 +103,252 @@
         mask = torch.Tensor([[0, 1]]).to(torch.bool)
 
         if with_no_grad:
-            with torch.no_grad():
-                model(x, src_key_padding_mask=mask)
+            cm = torch.no_grad()
         else:
+            cm = contextlib.nullcontext()
+        with cm:
             model(x, src_key_padding_mask=mask)
 
-    @unittest.skipIf(not TEST_FAIRSEQ, "numpy not found")
+    @parametrize("with_no_grad", [True, False])
+    @parametrize("training", [True, False])
+    @parametrize("enable_nested_tensor", [False])
+    @parametrize("device", device_list)
+    def test_transformerencoder_square_input(self, with_no_grad, training, enable_nested_tensor, device):
+        """
+        Test for edge cases when input of shape (batch size, sequence length, embedding dimension) has
+        batch size == sequence length
+        """
+        model = torch.nn.TransformerEncoder(
+            torch.nn.TransformerEncoderLayer(d_model=4, nhead=2, dim_feedforward=16, dropout=0.0, batch_first=True),
+            num_layers=2,
+            enable_nested_tensor=enable_nested_tensor
+        ).to(device)
+
+        with torch.no_grad():
+            # set constant weights of the model
+            for idx, p in enumerate(model.parameters()):
+                x = p.data
+                sz = x.view(-1).size(0)
+                shape = x.shape
+                x = torch.cos(torch.arange(0, sz).float().view(shape))
+                p.data.copy_(x)
+
+        if training:
+            model = model.train()
+        else:
+            model = model.eval()
+        x = torch.arange(0, 16).reshape(2, 2, 4).to(torch.float).to(device)
+        src_mask = torch.Tensor([[0, 1], [0, 0]]).to(torch.bool).to(device)
+
+        if with_no_grad:
+            cm = torch.no_grad()
+        else:
+            cm = contextlib.nullcontext()
+        with cm:
+            result = model(x, mask=src_mask)
+
+        ref_output = torch.Tensor([[[2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351],
+                                    [2.420306205749512, 0.017629241570830, -0.607857942581177, -0.085519507527351]],
+                                   [[2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689],
+                                    [2.419836044311523, 0.017548924311996, -0.608187675476074, -0.085347734391689]]]
+                                  ).to(device)
+        self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
+        torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
+
+    @parametrize("batch_first", [True, False])
+    @parametrize("training", [True, False])
+    @parametrize("enable_nested_tensor", [True, False])
+    @parametrize("device", device_list)
+    def test_transformerencoder(self, batch_first, training, enable_nested_tensor, device):
+        def get_a_test_layer(activation, batch_first=False):
+            d_model = 4
+            nhead = 2
+            dim_feedforward = 16
+            dropout = 0.0
+
+            layer = nn.TransformerEncoderLayer(
+                d_model,
+                nhead,
+                dim_feedforward=dim_feedforward,
+                dropout=dropout,
+                activation=activation,
+                batch_first=batch_first,
+            ).to(device)
+
+            with torch.no_grad():
+                # set constant weights of the model
+                for idx, p in enumerate(layer.parameters()):
+                    x = p.data
+                    sz = x.view(-1).size(0)
+                    shape = x.shape
+                    x = torch.cos(torch.arange(0, sz).float().view(shape))
+                    p.data.copy_(x)
+
+            return layer
+
+        # this is a deterministic test for TransformerEncoder
+        activation = F.relu
+
+        def _test(batch_first, training, enable_nested_tensor):
+            def perm_fn(x):
+                return x.transpose(1, 0) if batch_first else x
+
+            encoder_layer = get_a_test_layer(activation=activation,
+                                             batch_first=batch_first)
+
+            model = nn.TransformerEncoder(encoder_layer, 1).to(device)
+            if not training:
+                model = model.eval()
+
+            # deterministic input
+            encoder_input = perm_fn(torch.tensor([[[0.7462, 0.6653, 0.5679, 0.4891],
+                                                   [0.5387, 0.1655, 0.3565, 0.0471]],
+                                                  [[0.8335, 0.2799, 0.5031, 0.2947],
+                                                   [0.1402, 0.0318, 0.7636, 0.1346]],
+                                                  [[0.6333, 0.9344, 0.1376, 0.9938],
+                                                   [0.8924, 0.2872, 0.6692, 0.2944]],
+                                                  [[0.9897, 0.6915, 0.3154, 0.1733],
+                                                   [0.8645, 0.3513, 0.3064, 0.0767]],
+                                                  [[0.8117, 0.2366, 0.4838, 0.7881],
+                                                   [0.3718, 0.4945, 0.9511, 0.0864]]]
+                                                 )).to(device)
+            result = model(encoder_input)
+            ref_output = perm_fn(torch.tensor([[[2.428589, 0.020835, -0.602055, -0.085249],
+                                                [2.427987, 0.021213, -0.602496, -0.084103]],
+                                               [[2.424689, 0.019155, -0.604793, -0.085672],
+                                                [2.413863, 0.022211, -0.612486, -0.072490]],
+                                               [[2.433774, 0.021598, -0.598343, -0.087548],
+                                                [2.425104, 0.019748, -0.604515, -0.084839]],
+                                               [[2.436185, 0.022682, -0.596625, -0.087261],
+                                                [2.433556, 0.021891, -0.598509, -0.086832]],
+                                               [[2.416246, 0.017512, -0.610712, -0.082961],
+                                                [2.422901, 0.024187, -0.606178, -0.074929]]]
+                                              )).to(device)
+            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
+
+            # all 0 src_mask
+            src_mask = torch.zeros([5, 5]).to(device) == 1
+            result = model(encoder_input, mask=src_mask)
+            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
+
+            # all 0
+            mask = torch.zeros([2, 5]).to(device) == 1
+            result = model(encoder_input, src_key_padding_mask=mask)
+            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
+
+            mask[0, 1] = 1
+            mask[1, 3] = 1
+            mask[1, 4] = 1
+            # If mask is not left aligned
+            # We disable nested tensor
+            model.enable_nested_tensor = enable_nested_tensor
+            result = model(encoder_input, src_key_padding_mask=mask)
+            ref_output = perm_fn(torch.tensor([[[2.429026, 0.020793, -0.601741, -0.085642],
+                                                [2.428811, 0.021445, -0.601912, -0.084252]],
+                                               [[2.425009, 0.019155, -0.604566, -0.085899],
+                                                [2.415408, 0.02249, -0.611415, -0.073]],
+                                               [[2.434199, 0.021682, -0.598039, -0.087699],
+                                                [2.42598, 0.019941, -0.603896, -0.085091]],
+                                               [[2.436457, 0.022736, -0.59643, -0.08736],
+                                                [2.434021, 0.022093, -0.598179, -0.08679]],
+                                               [[2.416531, 0.017498, -0.610513, -0.083181],
+                                                [2.4242, 0.024653, -0.605266, -0.074959]]]
+                                              )).to(device)
+            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
+
+            # test case 2, multiple layers no norm
+            model = nn.TransformerEncoder(encoder_layer, 2, enable_nested_tensor=enable_nested_tensor).to(device)
+            if not training:
+                model = model.eval()
+            result = model(encoder_input, src_key_padding_mask=mask)
+            ref_output = perm_fn(torch.tensor([[[2.419051, 0.017446, -0.608738, -0.085003],
+                                                [2.419102, 0.017452, -0.608703, -0.085026]],
+                                               [[2.419043, 0.017445, -0.608744, -0.084999],
+                                                [2.419052, 0.017446, -0.608738, -0.085004]],
+                                               [[2.419067, 0.017448, -0.608727, -0.085010],
+                                                [2.419098, 0.017452, -0.608706, -0.085024]],
+                                               [[2.419072, 0.017449, -0.608724, -0.085012],
+                                                [2.419119, 0.017455, -0.608691, -0.085034]],
+                                               [[2.419019, 0.017442, -0.608761, -0.084989],
+                                                [2.419075, 0.017449, -0.608722, -0.085014]]]
+                                              )).to(device)
+            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
+
+            model = nn.TransformerEncoder(encoder_layer, 6, enable_nested_tensor=enable_nested_tensor).to(device)
+            if not training:
+                model = model.eval()
+            result = model(encoder_input, src_key_padding_mask=mask)
+            ref_output = perm_fn(torch.tensor([[[2.419101, 0.017453, -0.608703, -0.085025],
+                                                [2.419101, 0.017453, -0.608704, -0.085025]],
+                                               [[2.419101, 0.017453, -0.608703, -0.085025],
+                                                [2.419101, 0.017453, -0.608704, -0.085025]],
+                                               [[2.419101, 0.017453, -0.608703, -0.085025],
+                                                [2.419101, 0.017453, -0.608704, -0.085025]],
+                                               [[2.419101, 0.017453, -0.608703, -0.085025],
+                                                [2.419101, 0.017453, -0.608704, -0.085025]],
+                                               [[2.419101, 0.017453, -0.608703, -0.085025],
+                                                [2.419101, 0.017453, -0.608704, -0.085025]]]
+                                              )).to(device)
+            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
+
+            # test case 3, multiple layers with norm
+            # d_model = 4
+            norm = nn.LayerNorm(4)
+            model = nn.TransformerEncoder(encoder_layer, 2, norm=norm, enable_nested_tensor=enable_nested_tensor).to(device)
+            if not training:
+                model = model.eval()
+            result = model(encoder_input, src_key_padding_mask=mask)
+            ref_output = perm_fn(torch.tensor([[[1.695949, -0.357635, -0.893077, -0.445238],
+                                                [1.695955, -0.357639, -0.893050, -0.445266]],
+                                               [[1.695948, -0.357634, -0.893082, -0.445233],
+                                                [1.695950, -0.357635, -0.893077, -0.445238]],
+                                               [[1.695951, -0.357636, -0.893069, -0.445246],
+                                                [1.695955, -0.357639, -0.893052, -0.445264]],
+                                               [[1.695952, -0.357636, -0.893066, -0.445249],
+                                                [1.695957, -0.357641, -0.893041, -0.445276]],
+                                               [[1.695946, -0.357632, -0.893095, -0.445220],
+                                                [1.695952, -0.357637, -0.893065, -0.445251]]]
+                                              )).to(device)
+            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
+
+            model = nn.TransformerEncoder(encoder_layer, 6, norm=norm, enable_nested_tensor=enable_nested_tensor).to(device)
+            if not training:
+                model = model.eval()
+            result = model(encoder_input, src_key_padding_mask=mask)
+            ref_output = perm_fn(torch.tensor([[[1.695955, -0.357639, -0.893051, -0.445265],
+                                                [1.695955, -0.357639, -0.893051, -0.445265]],
+                                               [[1.695955, -0.357639, -0.893051, -0.445265],
+                                                [1.695955, -0.357639, -0.893051, -0.445265]],
+                                               [[1.695955, -0.357639, -0.893051, -0.445265],
+                                                [1.695955, -0.357639, -0.893051, -0.445265]],
+                                               [[1.695955, -0.357639, -0.893051, -0.445265],
+                                                [1.695955, -0.357639, -0.893051, -0.445265]],
+                                               [[1.695955, -0.357639, -0.893051, -0.445265],
+                                                [1.695955, -0.357639, -0.893051, -0.445265]]]
+                                              )).to(device)
+            self.assertEqual(tuple(result.shape), tuple(ref_output.shape))
+            torch.testing.assert_close(result, ref_output, rtol=1e-7, atol=1e-5)
+
+        # TODO: remove set default dtype to double by making ref_output more precise.
+        # Added because this test was copied from test_nn.py, which has default
+        # dtype double. If default dtype is float, tests will say tensors not close because
+        # ref output precision too low
+        with set_default_dtype(torch.double):
+            if training:
+                cm = contextlib.nullcontext()
+            else:
+                cm = torch.no_grad()  # transformer fast path requires no grad
+            with cm:
+                _test(batch_first, training, enable_nested_tensor)
+
+    @unittest.skipIf(not TEST_FAIRSEQ, "Fairseq not found")
     @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
     def test_decoder_only_layer(self):
         DEFAULT_PADDING_IDX = 0
@@ -347,3 +651,6 @@
         torch.testing.assert_close(result, ref_output, atol=1e-3, rtol=1e-2)
 
 instantiate_parametrized_tests(TestTransformers)
+
+if __name__ == '__main__':
+    run_tests()
diff --git a/torch/nn/modules/activation.py b/torch/nn/modules/activation.py
index 3aae87b..dd52f65 100644
--- a/torch/nn/modules/activation.py
+++ b/torch/nn/modules/activation.py
@@ -1087,10 +1087,10 @@
             why_not_fast_path = "add_zero_attn was enabled"
         elif not self._qkv_same_embed_dim:
             why_not_fast_path = "_qkv_same_embed_dim was not True"
-        elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
-            why_not_fast_path = "key_padding_mask and attn_mask are not supported with NestedTensor input"
-        elif not query.is_nested and key_padding_mask is not None and attn_mask is not None:
-            why_not_fast_path = "key_padding_mask and attn_mask were both supplied"
+        elif attn_mask is not None:
+            why_not_fast_path = "attn_mask was not None"
+        elif query.is_nested and key_padding_mask is not None:
+            why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
 
         if not why_not_fast_path:
             tensor_args = (
diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py
index 57ad916..728c7d8 100644
--- a/torch/nn/modules/transformer.py
+++ b/torch/nn/modules/transformer.py
@@ -451,10 +451,10 @@
             why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
         elif not (self.norm1.eps == self.norm2.eps):
             why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
-        elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None):
-            why_not_sparsity_fast_path = "src_key_padding_mask and src_mask are not supported with NestedTensor input"
-        elif (not src.is_nested) and (src_key_padding_mask is not None and src_mask is not None):
-            why_not_sparsity_fast_path = "src_key_padding_mask and src_mask were both supplied"
+        elif src_mask is not None:
+            why_not_sparsity_fast_path = "src_mask is not supported for fastpath"
+        elif src.is_nested and src_key_padding_mask is not None:
+            why_not_sparsity_fast_path = "src_key_padding_mask is not supported with NestedTensor input for fastpath"
 
         if not why_not_sparsity_fast_path:
             tensor_args = (
@@ -503,7 +503,7 @@
                     self.linear1.bias,
                     self.linear2.weight,
                     self.linear2.bias,
-                    src_mask if src_mask is not None else src_key_padding_mask,
+                    src_mask if src_mask is not None else src_key_padding_mask,  # TODO: split into two args
                 )
 
         x = src
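
For quick reference, a condensed paraphrase of the revised gating in the two hunks above; the standalone helper and its name are mine, for illustration only:

```python
# Sketch of the new fastpath eligibility check: any attention mask now
# disables the fastpath outright, while a key padding mask only blocks it
# when the input is a NestedTensor.
def fastpath_blocker(query_is_nested, key_padding_mask, attn_mask):
    if attn_mask is not None:
        return "attn_mask was not None"
    if query_is_nested and key_padding_mask is not None:
        return "key_padding_mask is not supported with NestedTensor input"
    return None  # fastpath still eligible, subject to the earlier checks
```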