| # Owner(s): ["module: nn"] |
| |
| import unittest |
| |
| import torch |
| |
| from torch.testing._internal.common_nn import NNTestCase |
| from torch.testing._internal.common_utils import TEST_FAIRSEQ, run_tests, parametrize, instantiate_parametrized_tests |
| from torch.testing._internal.common_cuda import TEST_CUDA |
| |
| if TEST_FAIRSEQ: |
| import fairseq.models.transformer as fairseq_transformer |
| |
| class TestTransformers(NNTestCase): |
| _do_cuda_memory_leak_check = True |
| _do_cuda_non_default_stream = True |
| |
| @parametrize("use_torchscript", [True, False]) |
| @parametrize("with_no_grad", [True, False]) |
| @parametrize("training", [True, False]) |
| def test_transformerencoder_fastpath_torchscript(self, use_torchscript, with_no_grad, training): |
| """ |
| Test TransformerEncoder does not crash |
| """ |
| model = torch.nn.TransformerEncoder( |
| torch.nn.TransformerEncoderLayer(d_model=2, nhead=2, dim_feedforward=8, batch_first=True), |
| num_layers=2, |
| enable_nested_tensor=True |
| ) |
| |
| if training: |
| model = model.train() |
| else: |
| model = model.eval() |
| |
| if use_torchscript: |
| model = torch.jit.script(model) |
| |
| x = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]]) |
| mask = torch.tensor([[0, 1]], dtype=torch.bool) |
| |
| if with_no_grad: |
| with torch.no_grad(): |
| model(x, src_key_padding_mask=mask) |
| else: |
| model(x, src_key_padding_mask=mask) |
| |
| @unittest.skipIf(not TEST_FAIRSEQ, "fairseq not found") |
| @unittest.skipIf(not TEST_CUDA, "CUDA not available") |
| def test_decoder_only_layer(self): |
| DEFAULT_PADDING_IDX = 0 |
| |
| class FairseqDecoder(torch.nn.Module): |
| def __init__( |
| self, |
| embed_dim, |
| attention_heads, |
| ffn_embed_dim, |
| num_layers, |
| embedding_layer, # torch.nn.Embedding. Must have a padding_idx field |
| dropout=0, |
| normalize_before=False, |
| torch_encoder=None, # torch encoder that you can map weights from |
| activation="relu", |
| ): |
| super().__init__() |
| |
| cfg = fairseq_transformer.TransformerConfig() |
| cfg.decoder.embed_dim = embed_dim |
| cfg.decoder.output_dim = embed_dim |
| cfg.decoder.attention_heads = attention_heads |
| cfg.decoder.ffn_embed_dim = ffn_embed_dim |
| cfg.dropout = dropout |
| cfg.decoder.normalize_before = normalize_before |
| cfg.decoder.layers = num_layers |
| # make embedding behavior same as other encoders |
| cfg.no_token_positional_embeddings = True |
| cfg.no_scale_embedding = True |
| cfg.activation_fn = activation |
| |
| dictionary = {} # TODO: verify what this is |
| |
| self.decoder = fairseq_transformer.TransformerDecoder( |
| cfg, |
| dictionary, |
| embedding_layer, |
| no_encoder_attn=True, |
| output_projection=None, |
| ) |
| |
| if torch_encoder is not None: |
| self.decoder = torch_to_fairseq(torch_encoder, self.decoder) |
| self.decoder = self.decoder.eval().cuda().half() |
| |
| def forward( |
| self, |
| tokens, |
| src_lengths=None, |
| with_triangle_mask=False, |
| incremental_state=None, |
| ): |
| return self.decoder( |
| prev_output_tokens=tokens, |
| encoder_out=None, |
| incremental_state=incremental_state, |
| features_only=True, |
| full_context_alignment=not with_triangle_mask, |
| alignment_layer=None, |
| alignment_heads=None, |
| src_lengths=src_lengths, |
| return_all_hiddens=False, |
| )[0] |
| |
| class BetterDecoder(torch.nn.Module): |
| """ |
| Only incremental decoder for now |
| """ |
| |
| def __init__(self, transformer, embedding, pad_idx): |
| super().__init__() |
| self.transformer = transformer |
| self.embedding = embedding |
| self.padding_idx = pad_idx |
| |
| def forward( |
| self, |
| x, |
| src_mask=None, |
| include_padding_mask=True, |
| incr_key_lst=None, |
| incr_value_lst=None, |
| is_incremental_decoding=False, |
| ): |
| padding_mask = None |
| if not x.is_nested and include_padding_mask: |
| padding_mask = x.eq(self.padding_idx) |
| if is_incremental_decoding: |
| x = x[:, -1:] # only take the last token |
| x = self.embedding(x) |
| |
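| # The fused decoder-only op takes its parameters explicitly, so pull them from |
| # the first layer here; all layers are assumed to share the same embed_dim, |
| # num_heads, activation, and norm placement. |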
| one_encoder_layer = self.transformer.layers[0] |
| self_attn = one_encoder_layer.self_attn |
| embed_dim = self_attn.embed_dim |
| num_heads = self_attn.num_heads |
| |
| # See the attention impl in torch/nn/modules/activation.py: |
| # 0 == neither relu nor gelu, 1 == relu, 2 == gelu |
| use_gelu = one_encoder_layer.activation_relu_or_gelu == 2 |
| assert one_encoder_layer.activation_relu_or_gelu != 0 |
| |
| norm_first = one_encoder_layer.norm_first |
| |
| # TODO: make this a bit less janky, but for now we initialize with an empty tensor. |
| if not is_incremental_decoding: |
| assert len(incr_key_lst) == 0 or incr_key_lst[0] is None |
| assert len(incr_value_lst) == 0 or incr_value_lst[0] is None |
| while len(incr_key_lst) <= len(self.transformer.layers): |
| if is_incremental_decoding: |
| incr_key_lst.append(torch.Tensor([]).cuda().half()) |
| incr_value_lst.append(torch.Tensor([]).cuda().half()) |
| else: |
| incr_key_lst.append(None) |
| incr_value_lst.append(None) |
| |
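| # Run each layer through the private decoder-only fast-path op; incr_key and |
| # incr_value are the per-layer K/V caches used when decoding incrementally. |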
| for i, layer in enumerate(self.transformer.layers): |
| incr_key = incr_key_lst[i] |
| incr_value = incr_value_lst[i] |
| |
| x, incr_key, incr_value = torch._transformer_decoder_only_layer_fwd( |
| src=x, |
| embed_dim=embed_dim, |
| num_heads=num_heads, |
| qkv_weight=layer.self_attn.in_proj_weight, |
| qkv_bias=layer.self_attn.in_proj_bias, |
| proj_weight=layer.self_attn.out_proj.weight, |
| proj_bias=layer.self_attn.out_proj.bias, |
| use_gelu=use_gelu, |
| norm_first=norm_first, |
| # TODO: layer_norm_eps hardcoded to the nn.TransformerEncoderLayer default; |
| # fix by pulling it from layer.norm1.eps |
| eps=1e-5, |
| norm_weight_1=layer.norm1.weight, |
| norm_bias_1=layer.norm1.bias, |
| norm_weight_2=layer.norm2.weight, |
| norm_bias_2=layer.norm2.bias, |
| ffn_weight_1=layer.linear1.weight, |
| ffn_bias_1=layer.linear1.bias, |
| ffn_weight_2=layer.linear2.weight, |
| ffn_bias_2=layer.linear2.bias, |
| mask=src_mask, |
| incr_key=incr_key, # altered in place |
| incr_value=incr_value, |
| ) |
| |
| # not in-place |
| if not is_incremental_decoding: |
| incr_key = None |
| incr_value = None |
| incr_key_lst[i] = incr_key |
| incr_value_lst[i] = incr_value |
| |
| return x, incr_key_lst, incr_value_lst |
| |
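| # Map weights from a torch.nn.TransformerEncoder onto a fairseq decoder: torch |
| # stores Q/K/V as a single fused in_proj tensor, while fairseq uses separate |
| # q_proj/k_proj/v_proj modules, so the fused tensors are chunked along dim 0. |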
| def torch_to_fairseq(torch_encoder, fairseq_encoder): |
| for src_layer, dst_layer in zip(torch_encoder.layers, fairseq_encoder.layers): |
| w_q, w_k, w_v = src_layer.self_attn.in_proj_weight.chunk(3, dim=0) |
| b_q, b_k, b_v = src_layer.self_attn.in_proj_bias.chunk(3, dim=0) |
| |
| dst_layer.self_attn.q_proj.weight = torch.nn.Parameter(w_q) |
| dst_layer.self_attn.q_proj.bias = torch.nn.Parameter(b_q) |
| dst_layer.self_attn.k_proj.weight = torch.nn.Parameter(w_k) |
| dst_layer.self_attn.k_proj.bias = torch.nn.Parameter(b_k) |
| dst_layer.self_attn.v_proj.weight = torch.nn.Parameter(w_v) |
| dst_layer.self_attn.v_proj.bias = torch.nn.Parameter(b_v) |
| |
| dst_layer.self_attn.out_proj.weight = src_layer.self_attn.out_proj.weight |
| dst_layer.self_attn.out_proj.bias = src_layer.self_attn.out_proj.bias |
| |
| dst_layer.fc1.weight = src_layer.linear1.weight |
| dst_layer.fc1.bias = src_layer.linear1.bias |
| |
| # fairseq may use fusedlayernorm from nvidia apex - diff properties |
| dst_layer.self_attn_layer_norm.load_state_dict(src_layer.norm1.state_dict()) |
| |
| dst_layer.fc2.weight = src_layer.linear2.weight |
| dst_layer.fc2.bias = src_layer.linear2.bias |
| |
| dst_layer.final_layer_norm.load_state_dict(src_layer.norm2.state_dict()) |
| |
| return fairseq_encoder |
| |
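| # Fill every parameter with a fixed cosine pattern so both the torch and |
| # fairseq models get identical weights and their outputs can be compared. |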
| def set_weights_deterministic(model): |
| for p in model.parameters(): |
| x = p.data |
| sz = x.view(-1).size(0) |
| shape = x.shape |
| x = torch.cos(torch.arange(0, sz).float().view(shape)) |
| p.data.copy_(x) |
| |
| D = 4 # d_model |
| H = 2 # nhead |
| FD = 16 # dim_feedforward |
| V = 100 # vocab size |
| L = 2 # num layers |
| |
| embedding_layer = torch.nn.Embedding(V, D, DEFAULT_PADDING_IDX) |
| layer = torch.nn.TransformerEncoderLayer( |
| d_model=D, |
| nhead=H, |
| dim_feedforward=FD, |
| batch_first=True, |
| activation="gelu", |
| ) |
| transformer = torch.nn.TransformerEncoder( |
| layer, |
| num_layers=L, |
| ).eval().cuda().half() |
| |
| set_weights_deterministic(embedding_layer) |
| set_weights_deterministic(transformer) |
| |
| better_decoder = ( |
| BetterDecoder(transformer, embedding_layer, DEFAULT_PADDING_IDX) |
| .eval() |
| .cuda() |
| .half() |
| ) |
| fairseq_decoder = ( |
| FairseqDecoder( |
| D, |
| H, |
| FD, |
| L, |
| embedding_layer, |
| dropout=0, |
| normalize_before=False, |
| torch_encoder=transformer, |
| activation="gelu", |
| ) |
| .eval() |
| .cuda() |
| .half() |
| ) |
| |
| tokens = torch.tensor([ |
| [5, 6, 7, 8], |
| [9, 10, 11, 12] |
| ], dtype=torch.int).cuda() |
| lengths_tensor = torch.tensor([2, 2], dtype=torch.int).cuda() |
| # bs = 2, seqlen = 4 |
| bs, seqlen = tokens.shape |
| |
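| # Additive causal mask: large negative values above the diagonal, expanded per |
| # batch element and attention head before being passed as src_mask. |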
| upper_triangle = torch.zeros(seqlen, seqlen) |
| upper_triangle.fill_(-100000000) |
| upper_triangle = torch.triu(upper_triangle, 1) |
| upper_triangle = upper_triangle.cuda().half() |
| upper_triangle_expanded = upper_triangle.unsqueeze(0).unsqueeze(0) |
| upper_triangle_expanded = upper_triangle_expanded.expand( |
| bs, H, -1, -1 |
| ) |
| |
| # test forced decoding |
| with torch.no_grad(): |
| result, _, _ = better_decoder( |
| tokens, |
| src_mask=upper_triangle_expanded, |
| include_padding_mask=False, |
| incr_key_lst=[], |
| incr_value_lst=[], |
| is_incremental_decoding=False, |
| ) |
| ref_output = fairseq_decoder(tokens, lengths_tensor, with_triangle_mask=True) |
| |
| self.assertEqual(result.shape, ref_output.shape) |
| torch.testing.assert_close(result, ref_output, atol=1e-3, rtol=1e-2) |
| |
| # test incremental decoding |
| bs, seqlen = tokens.shape |
| |
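| # Reference: run the fairseq decoder one token at a time with its |
| # incremental_state and stack the per-step outputs for comparison. |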
| incr_state = {} |
| ref_outputs = [fairseq_decoder( |
| tokens[:, :i], |
| src_lengths=None, |
| with_triangle_mask=False, |
| incremental_state=incr_state, |
| ) for i in range(1, seqlen + 1)] |
| ref_output = torch.stack(ref_outputs) |
| |
| incr_key_lst = [] |
| incr_value_lst = [] |
| results = [] |
| for i in range(1, seqlen + 1): |
| res, incr_key_lst, incr_value_lst = better_decoder( |
| tokens[:, :i], |
| src_mask=None, |
| include_padding_mask=False, |
| incr_key_lst=incr_key_lst, |
| incr_value_lst=incr_value_lst, |
| is_incremental_decoding=True, |
| ) |
| results.append(res) |
| result = torch.stack(results) |
| |
| self.assertEqual(result.shape, ref_output.shape) |
| torch.testing.assert_close(result, ref_output, atol=1e-3, rtol=1e-2) |
| |
| instantiate_parametrized_tests(TestTransformers) |
| |
| if __name__ == "__main__": |
| run_tests() |