# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Example script for exporting simple models to flatbuffer

import logging

import torch

from executorch.backends.cadence.aot.ops_registrations import *  # noqa
from typing import List, Optional, Tuple

from executorch.backends.cadence.aot.export_example import export_model
from torchaudio.prototype.models import ConvEmformer

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


if __name__ == "__main__":
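
    # _TimeReduction folds every `stride` consecutive frames into the feature
    # dimension, shrinking the time axis by `stride` and rescaling the lengths
    # to match.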
    class _TimeReduction(torch.nn.Module):
        def __init__(self, stride: int) -> None:
            super().__init__()
            self.stride = stride

        def forward(
            self, input: torch.Tensor, lengths: torch.Tensor
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            B, T, D = input.shape
            num_frames = T - (T % self.stride)
            input = input[:, :num_frames, :]
            lengths = lengths.div(self.stride, rounding_mode="trunc")
            T_max = num_frames // self.stride
            output = input.reshape(B, T_max, D * self.stride)
            output = output.contiguous()
            return output, lengths
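
    # ConvEmformerEncoder chains the time reduction and a linear input projection
    # in front of torchaudio's ConvEmformer, then applies an output projection and
    # layer norm. forward() handles full utterances; infer() is the streaming path.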
    class ConvEmformerEncoder(torch.nn.Module):
        def __init__(
            self,
            *,
            input_dim: int,
            output_dim: int,
            segment_length: int,
            kernel_size: int,
            right_context_length: int,
            time_reduction_stride: int,
            transformer_input_dim: int,
            transformer_num_heads: int,
            transformer_ffn_dim: int,
            transformer_num_layers: int,
            transformer_left_context_length: int,
            transformer_dropout: float = 0.0,
            transformer_activation: str = "relu",
            transformer_max_memory_size: int = 0,
            transformer_weight_init_scale_strategy: str = "depthwise",
            transformer_tanh_on_mem: bool = False,
        ) -> None:
            super().__init__()
            self.time_reduction = _TimeReduction(time_reduction_stride)
            self.input_linear = torch.nn.Linear(
                input_dim * time_reduction_stride,
                transformer_input_dim,
                bias=False,
            )
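            # segment_length and right_context_length are given in raw input frames,
            # so they are divided by the time-reduction stride to express them in
            # time-reduced frames for ConvEmformer.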
            self.transformer = ConvEmformer(
                transformer_input_dim,
                transformer_num_heads,
                transformer_ffn_dim,
                transformer_num_layers,
                segment_length // time_reduction_stride,
                kernel_size=kernel_size,
                dropout=transformer_dropout,
                ffn_activation=transformer_activation,
                left_context_length=transformer_left_context_length,
                right_context_length=right_context_length // time_reduction_stride,
                max_memory_size=transformer_max_memory_size,
                weight_init_scale_strategy=transformer_weight_init_scale_strategy,
                tanh_on_mem=transformer_tanh_on_mem,
                conv_activation="silu",
            )
            self.output_linear = torch.nn.Linear(transformer_input_dim, output_dim)
            self.layer_norm = torch.nn.LayerNorm(output_dim)
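
        # Non-streaming pass: time reduction -> input projection -> ConvEmformer
        # -> output projection -> layer norm.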
        def forward(
            self, input: torch.Tensor, lengths: torch.Tensor
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            time_reduction_out, time_reduction_lengths = self.time_reduction(
                input, lengths
            )
            input_linear_out = self.input_linear(time_reduction_out)
            transformer_out, transformer_lengths = self.transformer(
                input_linear_out, time_reduction_lengths
            )
            output_linear_out = self.output_linear(transformer_out)
            layer_norm_out = self.layer_norm(output_linear_out)
            return layer_norm_out, transformer_lengths
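
        # Streaming pass: same stages as forward(), but per-layer ConvEmformer
        # states are threaded through transformer.infer() so audio can be
        # processed segment by segment.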
        @torch.jit.export
        def infer(
            self,
            input: torch.Tensor,
            lengths: torch.Tensor,
            states: Optional[List[List[torch.Tensor]]],
        ) -> Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]]:
            time_reduction_out, time_reduction_lengths = self.time_reduction(
                input, lengths
            )
            input_linear_out = self.input_linear(time_reduction_out)
            (
                transformer_out,
                transformer_lengths,
                transformer_states,
            ) = self.transformer.infer(input_linear_out, time_reduction_lengths, states)
            output_linear_out = self.output_linear(transformer_out)
            layer_norm_out = self.layer_norm(output_linear_out)
            return layer_norm_out, transformer_lengths, transformer_states

    # Instantiate model
    time_reduction_stride = 4
    encoder = ConvEmformerEncoder(
        input_dim=80,
        output_dim=256,
        segment_length=4 * time_reduction_stride,
        kernel_size=7,
        right_context_length=1 * time_reduction_stride,
        time_reduction_stride=time_reduction_stride,
        transformer_input_dim=128,
        transformer_num_heads=4,
        transformer_ffn_dim=512,
        transformer_num_layers=1,
        transformer_left_context_length=10,
        transformer_tanh_on_mem=True,
    )

    # Sizes for the dummy inputs
    batch_size = 1
    max_input_length = 100
    input_dim = 80
    right_context_length = 4

    # Dummy inputs
    transcriber_input = torch.rand(
        batch_size, max_input_length + right_context_length, input_dim
    )
    transcriber_lengths = torch.randint(1, max_input_length + 1, (batch_size,))

    example_inputs = (
        transcriber_input,
        transcriber_lengths,
    )
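
    # Export the encoder to flatbuffer using the Cadence AOT example flow and the
    # dummy inputs above.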
    export_model(encoder, example_inputs)