Handle trailing masked column behavior for nested tensor (#100113)

Summary:
Handle trailing masked column behavior for nested tensors by padding the output of `to_padded_tensor` back to the original input tensor size.

https://github.com/pytorch/pytorch/issues/97111

Test Plan: sandcastle & github

Differential Revision: D45167874

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100113
Approved by: https://github.com/bertmaher, https://github.com/cpuhrsch, https://github.com/drisspg
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 20c4b65..b480910 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -688,6 +688,25 @@
 
             mha(query=x, key=x, value=x, key_padding_mask=pad_mask)
 
+    def test_kpm_mask_trailing_column_with_nested_tensor(self, device):
+        encoder_layer = nn.TransformerEncoderLayer(
+            d_model=256,
+            nhead=4,
+            dim_feedforward=512,
+            activation='gelu',
+            norm_first=False,
+            batch_first=False,
+        )
+        transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=3, enable_nested_tensor=True).to(device)
+
+        x = torch.randn(10, 6, 256).to(device)
+        mask = torch.ones(6, 10)
+        mask[0, :] = 0  # unmask batch element 0 only; the other 5 batch rows stay fully masked, so the nested tensor would otherwise drop trailing columns
+        mask = mask.bool().to(device)
+        out = transformer_encoder(src=x, src_key_padding_mask=mask)
+        self.assertEqual(out.shape[1], 6)
+
+
     @onlyCUDA
     @unittest.skipIf(not TEST_FAIRSEQ, "Fairseq not found")
     def test_decoder_only_layer(self):
diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py
index 4a08c12..6eda9a3 100644
--- a/torch/nn/modules/transformer.py
+++ b/torch/nn/modules/transformer.py
@@ -229,6 +229,7 @@
         )
 
         output = src
+        output_size = output.size()
         convert_to_nested = False
         first_layer = self.layers[0]
         src_key_padding_mask_for_layers = src_key_padding_mask
@@ -315,7 +316,7 @@
             output = mod(output, src_mask=mask, is_causal=is_causal, src_key_padding_mask=src_key_padding_mask_for_layers)
 
         if convert_to_nested:
-            output = output.to_padded_tensor(0.)
+            output = output.to_padded_tensor(0., output_size)
 
         if self.norm is not None:
             output = self.norm(output)