Handle trailing masked column behavior for nested tensor (#100113)
Summary:
Handle trailing masked column behavior for nested tensors: `to_padded_tensor` pads only up to the nested tensor's maximum sequence length, so trailing columns that are masked in every sequence shrink the output relative to the input. Fix this by passing the original input size to `to_padded_tensor` so the output is padded back to that size.
Fixes https://github.com/pytorch/pytorch/issues/97111
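
A minimal sketch of the behavior, with illustrative shapes (not taken from the PR's test); `torch._nested_tensor_from_mask` and the `output_size` argument of `to_padded_tensor` are the same internals the patch relies on in `torch/nn/modules/transformer.py`:

```python
import torch

# Key padding mask whose trailing column is masked for every sequence.
src = torch.randn(2, 4, 8)                        # (batch, seq, embed)
kpm = torch.zeros(2, 4, dtype=torch.bool)         # True = padded position
kpm[:, -1] = True                                 # trailing column fully masked

# _nested_tensor_from_mask expects True = keep, hence logical_not(),
# mirroring the call in torch/nn/modules/transformer.py.
nt = torch._nested_tensor_from_mask(src, kpm.logical_not())

print(nt.to_padded_tensor(0.).shape)              # torch.Size([2, 3, 8]): seq dim shrank
print(nt.to_padded_tensor(0., src.size()).shape)  # torch.Size([2, 4, 8]): original size kept
```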
Test Plan: sandcastle & github
Differential Revision: D45167874
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100113
Approved by: https://github.com/bertmaher, https://github.com/cpuhrsch, https://github.com/drisspg
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 20c4b65..b480910 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -688,6 +688,25 @@
mha(query=x, key=x, value=x, key_padding_mask=pad_mask)
+ def test_kpm_mask_trailing_column_with_nested_tensor(self, device):
+ encoder_layer = nn.TransformerEncoderLayer(
+ d_model=256,
+ nhead=4,
+ dim_feedforward=512,
+ activation='gelu',
+ norm_first=False,
+ batch_first=False,
+ )
+ transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=True).to(device)
+
+ x = torch.randn(10, 6, 256).to(device)  # (seq, batch, embed) since batch_first=False
+ mask = torch.ones(6, 10)  # (batch, seq) key padding mask; True means masked
+ mask[0, :] = 0  # every sequence except the first stays fully masked
+ mask = mask.bool().to(device)
+ out = transformer_encoder(src=x, src_key_padding_mask=mask)
+ self.assertEqual(out.shape[1], 6)
+
+
@onlyCUDA
@unittest.skipIf(not TEST_FAIRSEQ, "Fairseq not found")
def test_decoder_only_layer(self):
diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py
index 4a08c12..6eda9a3 100644
--- a/torch/nn/modules/transformer.py
+++ b/torch/nn/modules/transformer.py
@@ -229,6 +229,7 @@
)
output = src
+ output_size = output.size()  # remember the input size before nested tensor conversion
convert_to_nested = False
first_layer = self.layers[0]
src_key_padding_mask_for_layers = src_key_padding_mask
@@ -315,7 +316,7 @@
output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)
if convert_to_nested:
- output = output.to_padded_tensor(0.)
+ output = output.to_padded_tensor(0., output_size)  # pad back to the original input size
if self.norm is not None:
output = self.norm(output)