| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
| |
| |
class ASRJoiner(nn.Module):
    """
    ASR joiner implementation following the code in https://fburl.com/code/ierfau7c
    Having a local implementation has the benefit that we don't need to pull in
    the heavy dependencies and wait for a few minutes to run tests.
    """

    def __init__(self, B=1, H=10, T=1, U=1, D=768) -> None:
        """
        B: source batch size
        H: number of hypotheses for beam search
        T: source sequence length
        U: target sequence length
        D: encoding (some sort of embedding?) dimension
        """
        super().__init__()
        self.B, self.H, self.T, self.U, self.D = B, H, T, U, D
        # Mirrors the production module:
        # SequentialContainer(
        #     (module_list): ModuleList(
        #         (0): ReLULayer(inplace=False)
        #         (1): LinearLayer(input_dim=768, output_dim=4096, bias=True, context_dim=0, pruning_aware_training=False, parameter_noise=0.1, qat_qconfig=None, freeze_rex_pattern=None)
        #     )
        # )
        self.module = nn.Sequential(
            nn.ReLU(),
            nn.Linear(D, 4096),
        )

    def forward(self, src_encodings, src_lengths, tgt_encodings, tgt_lengths):
        """
        Combine source and target encodings into joint log-probabilities.

        Args:
            src_encodings: (B, T, D) source encodings.
            src_lengths: (B,) valid length of each source sequence.
            tgt_encodings: (B*H, U, D) target encodings, H hypotheses per
                source sequence (H is inferred from the batch-size ratio).
            tgt_lengths: per-hypothesis target lengths; passed through
                unchanged.

        Returns:
            (log_probs, src_lengths, tgt_lengths) where log_probs has shape
            (B*H, T, U, 4096) and src_lengths is expanded to (B*H,).

        One simplification we make here is we assume src_encodings and
        tgt_encodings are not None. In the original implementation, either
        can be None.
        """
        B = src_encodings.shape[0]
        # Hypotheses per source; assumes the target batch is an exact
        # multiple of the source batch.
        H = tgt_encodings.shape[0] // B
        # [0,..,0, 1,..,1, ..., B-1,..,B-1] with each index repeated H
        # times, so every source row is replicated once per hypothesis.
        new_order = (
            (torch.arange(B).view(-1, 1).repeat(1, H).view(-1))
            .long()
            .to(device=src_encodings.device)
        )
        # src_encodings: (B, T, D) -> (B*H, T, D)
        src_encodings = torch.index_select(
            src_encodings, dim=0, index=new_order
        ).contiguous()

        # src_lengths: (B,) -> (B*H,)
        src_lengths = torch.index_select(
            src_lengths, dim=0, index=new_order.to(device=src_lengths.device)
        )

        # src_encodings: (B*H, T, D) -> (B*H, T, 1, D)
        src_encodings = src_encodings.unsqueeze(dim=2).contiguous()

        # tgt_encodings: (B*H, U, D) -> (B*H, 1, U, D)
        tgt_encodings = tgt_encodings.unsqueeze(dim=1).contiguous()

        # Broadcast add -> joint_encodings: (B*H, T, U, D)
        joint_encodings = src_encodings + tgt_encodings

        # Normalize over the 4096-dim output vocabulary.
        output = F.log_softmax(self.module(joint_encodings), dim=-1)

        return output, src_lengths, tgt_lengths

    def get_random_inputs(self):
        """
        Build a random (src_encodings, src_lengths, tgt_encodings,
        tgt_lengths) tuple matching forward()'s expected shapes.
        """
        return (
            torch.rand(self.B, self.T, self.D),
            torch.randint(0, 10, (self.B,)),
            torch.rand(self.B * self.H, self.U, self.D),
            # Fix: target tensors are per-hypothesis, so tgt_lengths must
            # have batch B*H to match tgt_encodings (was (B,)).
            torch.randint(0, 10, (self.B * self.H,)),
        )