| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| import torch |
| from executorch.exir.backend.compile_spec_schema import CompileSpec |
| |
| from ..model_base import EagerModelBase |
| |
| |
| class MulModule(torch.nn.Module, EagerModelBase): |
| def __init__(self) -> None: |
| super().__init__() |
| |
| def forward(self, input, other): |
| return input * other |
| |
| def get_eager_model(self) -> torch.nn.Module: |
| return self |
| |
| def get_example_inputs(self): |
| return (torch.randn(3, 2), torch.randn(3, 2)) |
| |
| |
| class LinearModule(torch.nn.Module, EagerModelBase): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(3, 3) |
| |
| def forward(self, arg): |
| return self.linear(arg) |
| |
| def get_eager_model(self) -> torch.nn.Module: |
| return self |
| |
| def get_example_inputs(self): |
| return (torch.randn(3, 3),) |
| |
| |
| class AddModule(torch.nn.Module, EagerModelBase): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, x, y): |
| z = x + y |
| return z |
| |
| def get_eager_model(self) -> torch.nn.Module: |
| return self |
| |
| def get_example_inputs(self): |
| return (torch.ones(1), torch.ones(1)) |
| |
| |
| class AddMulModule(torch.nn.Module, EagerModelBase): |
| def __init__(self): |
| super().__init__() |
| |
| def forward(self, a, x, b): |
| y = torch.mm(a, x) |
| z = torch.add(y, b) |
| return z |
| |
| def get_eager_model(self) -> torch.nn.Module: |
| return self |
| |
| def get_example_inputs(self): |
| return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2)) |
| |
| def get_compile_spec(self): |
| max_value = self.get_example_inputs()[0].shape[0] |
| return [CompileSpec("max_value", bytes([max_value]))] |
| |
| |
| class SoftmaxModule(torch.nn.Module, EagerModelBase): |
| def __init__(self): |
| super().__init__() |
| self.softmax = torch.nn.Softmax() |
| |
| def forward(self, x): |
| z = self.softmax(x) |
| return z |
| |
| def get_eager_model(self) -> torch.nn.Module: |
| return self |
| |
| def get_example_inputs(self): |
| return (torch.ones(2, 2),) |