blob: 55a9a4c0a7171895d347159273a0b1a887f62f10 [file] [log] [blame]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import torch
import torchaudio
from ..model_base import EagerModelBase
# Log record format: "[file.py:123] message" — file/line prefix for traceability.
FORMAT = "[%(filename)s:%(lineno)s] %(message)s"
# NOTE(review): no level= is passed, so the root logger stays at the default
# WARNING level and the logging.info() calls below are suppressed unless the
# level is configured elsewhere — confirm this is intended.
logging.basicConfig(format=FORMAT)
# Public API: only the EagerModelBase wrappers; the *Example nn.Modules are
# internal helpers.
__all__ = [
    "EmformerRnntTranscriberModel",
    "EmformerRnntPredictorModel",
    "EmformerRnntJoinerModel",
]
class EmformerRnntTranscriberExample(torch.nn.Module):
    """
    Wrapper exposing only the transcriber (acoustic encoder) stage of the
    Emformer RNN-T architecture.

    Intended purely to validate the export workflow; it does not reflect
    real decoding usage such as beam search.
    """

    def __init__(self) -> None:
        super().__init__()
        pipeline = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
        # Keep the underlying RNN-T model from the bundled decoder.
        self.rnnt = pipeline.get_decoder().model

    def forward(self, transcribe_inputs):
        # transcribe_inputs is a tuple of positional arguments for
        # RNNT.transcribe(); unpack it directly into the call.
        return self.rnnt.transcribe(*transcribe_inputs)
class EmformerRnntTranscriberModel(EagerModelBase):
    """EagerModelBase wrapper producing the Emformer RNN-T transcriber example."""

    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        # Construct the transcriber wrapper; logging brackets the (potentially
        # slow) pipeline download/instantiation.
        logging.info("Loading emformer rnnt transcriber")
        model = EmformerRnntTranscriberExample()
        logging.info("Loaded emformer rnnt transcriber")
        return model

    def get_example_inputs(self):
        # One utterance: 128 frames of 80-dim features, plus its length.
        features = torch.randn(1, 128, 80)
        lengths = torch.tensor([128])
        return ((features, lengths),)
class EmformerRnntPredictorExample(torch.nn.Module):
    """
    Wrapper exposing only the predictor stage of the Emformer RNN-T
    architecture.

    Intended purely to validate the export workflow; it does not reflect
    real decoding usage such as beam search.
    """

    def __init__(self) -> None:
        super().__init__()
        pipeline = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
        # Keep the underlying RNN-T model from the bundled decoder.
        self.rnnt = pipeline.get_decoder().model

    def forward(self, predict_inputs):
        # predict_inputs is a tuple of positional arguments for
        # RNNT.predict(); unpack it directly into the call.
        return self.rnnt.predict(*predict_inputs)
class EmformerRnntPredictorModel(EagerModelBase):
    """EagerModelBase wrapper producing the Emformer RNN-T predictor example."""

    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        """Instantiate and return the predictor wrapper module."""
        logging.info("Loading emformer rnnt predictor")
        m = EmformerRnntPredictorExample()
        logging.info("Loaded emformer rnnt predictor")
        return m

    def get_example_inputs(self):
        """Example predictor inputs: a (1, 128) token batch, its length, no state.

        The original used ``dtype=int`` (the Python builtin), which torch maps
        to int64 implicitly; spelling it as ``torch.int64`` is explicit and
        unambiguous while producing identical tensors.
        """
        predict_inputs = (
            torch.zeros([1, 128], dtype=torch.int64),
            torch.tensor([128], dtype=torch.int64),
            None,  # no previous predictor state
        )
        return (predict_inputs,)
class EmformerRnntJoinerExample(torch.nn.Module):
    """
    Wrapper exposing only the joiner stage of the Emformer RNN-T
    architecture.

    Intended purely to validate the export workflow; it does not reflect
    real decoding usage such as beam search.
    """

    def __init__(self) -> None:
        super().__init__()
        pipeline = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH
        # Keep the underlying RNN-T model from the bundled decoder.
        self.rnnt = pipeline.get_decoder().model

    def forward(self, predict_inputs):
        # Parameter name kept for interface compatibility; the tuple actually
        # holds positional arguments for RNNT.join().
        return self.rnnt.join(*predict_inputs)
class EmformerRnntJoinerModel(EagerModelBase):
    """EagerModelBase wrapper producing the Emformer RNN-T joiner example."""

    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        # Construct the joiner wrapper; logging brackets the (potentially
        # slow) pipeline download/instantiation.
        logging.info("Loading emformer rnnt joiner")
        model = EmformerRnntJoinerExample()
        logging.info("Loaded emformer rnnt joiner")
        return model

    def get_example_inputs(self):
        # Transcriber and predictor outputs: two (1, 128, 1024) feature
        # tensors, each paired with its valid length.
        source = torch.rand([1, 128, 1024])
        source_lengths = torch.tensor([128])
        target = torch.rand([1, 128, 1024])
        target_lengths = torch.tensor([128])
        return ((source, source_lengths, target, target_lengths),)