blob: 6c709ee6b567da613df7f7102e8a31c89faa5c0a [file] [log] [blame]
from typing import Callable
import torch
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ["WrapperModule"]
class WrapperModule(torch.nn.Module):
    """Adapt an arbitrary callable so it can be used as a :class:`torch.nn.Module`.

    Reach for this when an API (for example, export) requires a module but
    all you have is a plain function or other callable.

    Args:
        fn: The callable that :meth:`forward` delegates to. It is stored on
            the instance as the attribute ``fn``.
    """

    def __init__(self, fn: Callable):
        super(WrapperModule, self).__init__()
        # Stored under the public attribute name ``fn`` so callers can still
        # reach the wrapped callable directly.
        self.fn = fn

    def forward(self, *args, **kwargs):
        """Invoke the wrapped callable, passing every argument through unchanged.

        Returns:
            Whatever the callable given to :meth:`WrapperModule.__init__` returns.
        """
        wrapped = self.fn
        return wrapped(*args, **kwargs)