blob: 0a892367eec31f0e818792902e06dc26e7711b8c [file] [log] [blame]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from torch import nn
class ASRJoiner(nn.Module):
    """
    ASR joiner implementation following the code in https://fburl.com/code/ierfau7c

    Having a local implementation has the benefit that we don't need to pull in
    the heavy dependencies and wait for a few minutes to run tests.
    """

    def __init__(self, B=1, H=10, T=1, U=1, D=768, output_dim=4096) -> None:
        """
        B: source batch size
        H: number of hypotheses for beam search
        T: source sequence length
        U: target sequence length
        D: encoding (some sort of embedding?) dimension
        output_dim: size of the joiner's output projection; defaults to 4096
            to match the original module being mirrored.
        """
        super().__init__()
        self.B, self.H, self.T, self.U, self.D = B, H, T, U, D
        # The mirrored module looks like:
        # SequentialContainer(
        #   (module_list): ModuleList(
        #     (0): ReLULayer(inplace=False)
        #     (1): LinearLayer(input_dim=768, output_dim=4096, bias=True,
        #          context_dim=0, pruning_aware_training=False,
        #          parameter_noise=0.1, qat_qconfig=None, freeze_rex_pattern=None)
        #   )
        # )
        self.module = nn.Sequential(
            nn.ReLU(),
            nn.Linear(D, output_dim),
        )

    def forward(self, src_encodings, src_lengths, tgt_encodings, tgt_lengths):
        """
        Combine source and target encodings into joint log-probabilities.

        One simplification we make here is we assume src_encodings and
        tgt_encodings are not None. In the original implementation, either
        can be None.

        src_encodings: (B, T, D)
        src_lengths: (B,)
        tgt_encodings: (B*H, U, D) — H hypotheses per source utterance
        tgt_lengths: passed through unchanged

        Returns:
            output: (B*H, T, U, output_dim) log-probabilities
            src_lengths: (B*H,) — each source length repeated H times
            tgt_lengths: unchanged
        """
        B = src_encodings.shape[0]
        # Hypotheses per utterance; assumes the target batch is an exact
        # multiple of the source batch (integer division truncates otherwise).
        H = tgt_encodings.shape[0] // B
        # Index [0,...,0, 1,...,1, ...] (each source row repeated H times) so
        # that source row i lines up with its H beam hypotheses in the target.
        new_order = (
            (torch.arange(B).view(-1, 1).repeat(1, H).view(-1))
            .long()
            .to(device=src_encodings.device)
        )
        # src_encodings: (B, T, D) -> (B*H, T, D)
        src_encodings = torch.index_select(
            src_encodings, dim=0, index=new_order
        ).contiguous()
        # src_lengths: (B,) -> (B*H,)
        src_lengths = torch.index_select(
            src_lengths, dim=0, index=new_order.to(device=src_lengths.device)
        )
        # src_encodings: (B*H, T, D) -> (B*H, T, 1, D)
        src_encodings = src_encodings.unsqueeze(dim=2).contiguous()
        # tgt_encodings: (B*H, U, D) -> (B*H, 1, U, D)
        tgt_encodings = tgt_encodings.unsqueeze(dim=1).contiguous()
        # Broadcast-add: joint_encodings: (B*H, T, U, D)
        joint_encodings = src_encodings + tgt_encodings
        # Project to output_dim and normalize over the vocabulary axis.
        output = F.log_softmax(self.module(joint_encodings), dim=-1)
        return output, src_lengths, tgt_lengths

    def get_random_inputs(self):
        """Build random tensors shaped to match forward()'s expectations."""
        # NOTE(review): tgt_lengths here is (B,) while tgt_encodings has batch
        # B*H; forward() never reads tgt_lengths, so this is harmless, but the
        # shape looks inconsistent — confirm against the original code.
        return (
            torch.rand(self.B, self.T, self.D),
            torch.randint(0, 10, (self.B,)),
            torch.rand(self.B * self.H, self.U, self.D),
            torch.randint(0, 10, (self.B,)),
        )