from typing import Callable | |
import torch | |
__all__ = ["WrapperModule"] | |
class WrapperModule(torch.nn.Module):
    """Adapter that presents a plain callable as a :class:`torch.nn.Module`.

    Use this when you are trying to export a callable: wrap it in this class
    and hand the resulting module to the code that expects a module.
    """

    def __init__(self, fn: Callable) -> None:
        """Store ``fn`` on the module as the ``fn`` attribute.

        Args:
            fn: The callable that :meth:`forward` will invoke.
        """
        # Initialise the base Module first, then keep a reference to the callable.
        super(WrapperModule, self).__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        """Call the wrapped ``fn`` with the given arguments.

        All positional and keyword arguments are passed through untouched,
        and whatever ``fn`` returns is returned as-is.
        """
        outcome = self.fn(*args, **kwargs)
        return outcome