# blob: 8dc587fae9360d8e3cbce916aa3b0e949f4ad785 [file] [log] [blame]
import torch
from torch._export import aot_compile, dynamic_dim
# Fix the global RNG seed so the randomly initialized weights and the random
# inputs generated below are reproducible from run to run.
torch.manual_seed(1337)
class Net(torch.nn.Module):
    """Tiny two-input model: ``fc(sin(x) + cos(y))``.

    The single linear layer maps a trailing dimension of 64 to 10, so both
    inputs must have a last dimension of 64 (after broadcasting against
    each other).
    """

    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(64, 10)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the linear projection of ``sin(x) + cos(y)``."""
        return self.fc(torch.sin(x) + torch.cos(y))
# Collects, per device, the value returned by aot_compile as the model path,
# the example inputs, and the eager-mode reference outputs. Keys are suffixed
# with the device name.
data = {}

# NOTE(review): "cuda" is hard-coded here, so this loop presumably needs a
# CUDA-capable build and device to run to completion -- confirm before using
# it on CPU-only hosts.
for device in ["cpu", "cuda"]:
    model = Net().to(device=device)
    x = torch.randn((32, 64), device=device)
    y = torch.randn((32, 64), device=device)

    # Eager-mode result, stored as the reference output for this device.
    with torch.no_grad():
        ref_output = model(x, y)

    # Reset TorchDynamo state before compiling this device's model.
    torch._dynamo.reset()
    with torch.no_grad():
        # Dimension 0 of x is dynamic within [1, 1024] and must equal
        # dimension 0 of y.
        constraints = [
            dynamic_dim(x, 0) >= 1,
            dynamic_dim(x, 0) <= 1024,
            dynamic_dim(x, 0) == dynamic_dim(y, 0),
        ]
        # Only the first element of the returned pair is used.
        model_so_path, _ = aot_compile(model, (x, y), constraints=constraints)

    data.update({
        f"model_so_path_{device}": model_so_path,
        f"inputs_{device}": [x, y],
        f"outputs_{device}": [ref_output],
    })
# Use this to communicate tensors to the cpp code
class Serializer(torch.nn.Module):
    """Module whose attributes mirror the entries of a dictionary.

    Every key in ``data`` becomes an attribute of the module holding the
    corresponding value.
    """

    def __init__(self, data):
        """Copy each ``key -> value`` pair of ``data`` onto ``self``.

        Args:
            data: Mapping from attribute name to the value to attach.
        """
        super().__init__()
        for key in data:
            setattr(self, key, data[key])
# Wrap the collected entries in a Serializer, script it with TorchScript, and
# write the result to "data.pt" (a relative path, so it is resolved against
# the current working directory).
torch.jit.script(Serializer(data)).save("data.pt")