# Usage: python create_dummy_model.py <name_of_the_file>
# Scripts a small dummy network with TorchScript and saves it to <name_of_the_file>.
import sys | |
import torch | |
from torch import nn | |
class NeuralNetwork(nn.Module):
    """Small fully connected network used as a dummy model.

    Each input sample is flattened and passed through three linear layers
    with a ReLU after the first two, yielding one raw score (logit) per
    output unit of the last layer.

    Args:
        in_features: Number of values in one flattened sample.
            Defaults to ``28 * 28``.
        hidden_features: Width of both hidden layers. Defaults to ``512``.
        num_classes: Number of logits produced per sample. Defaults to ``10``.
    """

    def __init__(
        self,
        in_features: int = 28 * 28,
        hidden_features: int = 512,
        num_classes: int = 10,
    ) -> None:
        super().__init__()
        # nn.Flatten() keeps the first (batch) dimension and collapses the rest.
        self.flatten = nn.Flatten()
        # Attribute name and layer order are kept stable so that parameter
        # names (e.g. ``linear_relu_stack.0.weight``) do not change.
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the logits for a batch of inputs.

        Args:
            x: Batched input whose non-batch dimensions multiply out to
                ``in_features``.

        Returns:
            Tensor with one row per sample and ``num_classes`` columns.
        """
        flat = self.flatten(x)
        return self.linear_relu_stack(flat)
if __name__ == '__main__':
    # Without an output path, exit with a usage hint (written to stderr,
    # exit status 1) instead of an IndexError traceback from sys.argv[1].
    if len(sys.argv) < 2:
        sys.exit(f"Usage: python {sys.argv[0]} <name_of_the_file>")

    # Compile a freshly initialised network to TorchScript and write it to
    # the path given as the first command-line argument.
    jit_module = torch.jit.script(NeuralNetwork())
    torch.jit.save(jit_module, sys.argv[1])