| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
class LinearModel(torch.nn.Module):
    """Toy module computing the elementwise affine map ``3 * x + 2``.

    The scale and offset are fixed 2x2 ``torch.float`` tensors. ``x`` must
    therefore be broadcast-compatible with shape (2, 2) (assumed by the
    callers; not checked here).
    """

    def __init__(self):
        super().__init__()
        # NOTE: plain tensor attributes, not parameters or registered buffers.
        # They do not appear in state_dict() and are not moved by .to(device).
        self.a = torch.full((2, 2), 3.0, dtype=torch.float)  # scale
        self.b = torch.full((2, 2), 2.0, dtype=torch.float)  # offset

    def forward(self, x: torch.Tensor):
        """Return ``self.a * x + self.b`` using explicit torch.mul / torch.add."""
        return torch.add(torch.mul(self.a, x), self.b)