Add fastpath test for mask check flag (#82999)
Summary: Check that the fastpath is taken, and which type of fastpath (sparsity fastpath or normal), both for a mask that is left-aligned and for one that is not.
Test Plan: buck test caffe2/test:test_transformers
Differential Revision: D38259928
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82999
Approved by: https://github.com/jbschlosser
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 68fb896..6066650 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -5,10 +5,17 @@
import torch.nn as nn
import torch.nn.functional as F
import unittest
+from unittest.mock import patch
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
- TEST_FAIRSEQ, run_tests, parametrize, instantiate_parametrized_tests, freeze_rng_state)
+ TEST_FAIRSEQ,
+ run_tests,
+ parametrize,
+ instantiate_parametrized_tests,
+ freeze_rng_state,
+ TEST_WITH_CROSSREF
+)
from torch.testing._internal.common_cuda import TEST_CUDA
if TEST_FAIRSEQ:
@@ -724,6 +731,59 @@
if dropout_p == 0.0 or device == 'cpu':
self.assertEqual(actual, expected)
+ @unittest.skipIf(TEST_WITH_CROSSREF, 'Fastpath not available with crossref')
+ @torch.no_grad()
+ def test_mask_check_fastpath(self):
+ """
+ Test that the fastpath is executed independently of the mask that is passed.
+ With enable_nested_tensor=True, if the passed mask is left-aligned or mask_check=False, test that nested tensors
+ are used (sparsity fastpath); otherwise (including enable_nested_tensor=False), test that the fastpath is taken with regular tensors.
+ """
+
+ x = torch.Tensor([[[1, 2], [3, 4], [5, 6]]]).to(torch.float)
+
+ def _test_fastpath(model, mask, mock_return_value, nested_tensors=True):
+ with patch('torch._transformer_encoder_layer_fwd') as fastpath_mock:
+ fastpath_mock.return_value = mock_return_value
+ model(x, src_key_padding_mask=mask)
+
+ # If mock was called, fastpath was taken
+ self.assertTrue(fastpath_mock.called)
+
+ # If mock was called with nested tensors, sparsity fastpath was taken
+ for call_args, _ in fastpath_mock.call_args_list:
+ self.assertEqual(call_args[0].is_nested, nested_tensors)
+
+ encoder_layer = torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True)
+
+ model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=True)
+ model.eval()
+
+ aligned_mask = torch.Tensor([[0, 0, 1]]).to(torch.bool)
+ not_aligned_mask = torch.Tensor([[1, 0, 1]]).to(torch.bool)
+ nested_tensor_return_value = torch.nested_tensor([torch.ones((2, 2), dtype=torch.float)])
+ tensor_return_value = torch.ones((1, 3, 2), dtype=torch.float)
+
+ # Left aligned mask results in sparsity fastpath
+ _test_fastpath(model, aligned_mask, nested_tensor_return_value, nested_tensors=True)
+
+ # A mask that is not left-aligned results in the regular (non-nested) fastpath
+ _test_fastpath(model, not_aligned_mask, tensor_return_value, nested_tensors=False)
+
+ model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=False, mask_check=True)
+ model.eval()
+
+ # If nested tensors are disabled, the regular (non-nested) fastpath is always taken
+ _test_fastpath(model, aligned_mask, tensor_return_value, nested_tensors=False)
+ _test_fastpath(model, not_aligned_mask, tensor_return_value, nested_tensors=False)
+
+
+ model = torch.nn.TransformerEncoder(encoder_layer, num_layers=2, enable_nested_tensor=True, mask_check=False)
+ model.eval()
+
+ # Mask check disabled results in sparsity fastpath, independently of the mask
+ _test_fastpath(model, aligned_mask, nested_tensor_return_value, nested_tensors=True)
+ _test_fastpath(model, not_aligned_mask, nested_tensor_return_value, nested_tensors=True)
# TODO: Replace this with instantiate_device_type_tests() to take advantage of test framework support for
# cross device / dtype testing.