Add fastpath test for mask check flag (#82999)

Summary: Check that fastpath is taken, which type (sparsity fastpath or normal) for mask that is aligned and one that is not.

Test Plan: buck test caffe2/test:test_transformers

Differential Revision: D38259928

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82999
Approved by: https://github.com/jbschlosser
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 68fb896..6066650 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -5,10 +5,17 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import unittest
+from unittest.mock import patch
 
 from torch.testing._internal.common_nn import NNTestCase
 from torch.testing._internal.common_utils import (
-    TEST_FAIRSEQ, run_tests, parametrize, instantiate_parametrized_tests, freeze_rng_state)
+    TEST_FAIRSEQ,
+    run_tests,
+    parametrize,
+    instantiate_parametrized_tests,
+    freeze_rng_state,
+    TEST_WITH_CROSSREF
+)
 from torch.testing._internal.common_cuda import TEST_CUDA
 
 if TEST_FAIRSEQ:
@@ -724,6 +731,59 @@
             if dropout_p == 0.0 or device == 'cpu':
                 self.assertEqual(actual, expected)
 
+    @unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref')
+    @torch.no_grad()
+    def test_mask_check_fastpath(self):
+        """
+        Test that fastpath is executed independently of the mask that is passed.
+        If the passed mask is left aligned or mask_check=False, test that nested tensors are used (sparsity fastpath),
+        otherwise use fastpath with traditional tensors.
+        """
+
+        x = torch.Tensor([[[1, 2], [3, 4], [5, 6]]]).to(torch.float)
+
+        def _test_fastpath(model, mask, mock_return_value, nested_tensors=True):
+            with patch('torch._transformer_encoder_layer_fwd') as fastpath_mock:
+                fastpath_mock.return_value = mock_return_value
+                model(x, src_key_padding_mask=mask)
+
+                # If mock was called, fastpath was taken
+                self.assertTrue(fastpath_mock.called)
+
+                # If mock was called with nested tensors, sparsity fastpath was taken
+                for call_args, _ in fastpath_mock.call_args_list:
+                    self.assertEqual(call_args[0].is_nested, nested_tensors)
+
+        encoder_layer = torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True)
+
+        model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=True)
+        model.eval()
+
+        aligned_mask = torch.Tensor([[0, 0, 1]]).to(torch.bool)
+        not_aligned_mask = torch.Tensor([[1, 0, 1]]).to(torch.bool)
+        nested_tensor_return_value = torch.nested_tensor([torch.ones((2, 2), dtype=torch.float)])
+        tensor_return_value = torch.ones((1, 3, 2), dtype=torch.float)
+
+        # Left aligned mask results in sparsity fastpath
+        _test_fastpath(model, aligned_mask, nested_tensor_return_value, nested_tensors=True)
+
+        # Not aligned mask results in fastpath
+        _test_fastpath(model, not_aligned_mask, tensor_return_value, nested_tensors=False)
+
+        model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=False, mask_check=True)
+        model.eval()
+
+        # If nested tensor disabled, fastpath is always taken
+        _test_fastpath(model, aligned_mask, tensor_return_value, nested_tensors=False)
+        _test_fastpath(model, not_aligned_mask, tensor_return_value, nested_tensors=False)
+
+
+        model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=False)
+        model.eval()
+
+        # Mask check disabled results in sparisty fastpath, independently of the mask
+        _test_fastpath(model, aligned_mask, nested_tensor_return_value, nested_tensors=True)
+        _test_fastpath(model, not_aligned_mask, nested_tensor_return_value, nested_tensors=True)
 
 # TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for
 # cross device / dtype testing.