blob: 62d9b79b8cba632928f827db2104d0830cfe3776 [file] [log] [blame]
import torch
from torch._export import aot_compile
from torch.export import Dim
torch.manual_seed(1337)
class Net(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.w_pre = torch.randn(4, 4, device=device)
self.w_add = torch.randn(4, 4, device=device)
def forward(self, x):
w_transpose = torch.transpose(self.w_pre, 0, 1)
w_relu = torch.nn.functional.relu(w_transpose)
w = w_relu + self.w_add
return torch.matmul(x, w)
data = {}
for device in ["cpu", "cuda"]:
for use_runtime_constant_folding in [True, False]:
if device == "cpu" and use_runtime_constant_folding:
# We do not test runtime const folding for cpu mode.
continue
model = Net(device).to(device=device)
x = torch.randn((4, 4), device=device)
with torch.no_grad():
ref_output = model(x)
torch._dynamo.reset()
with torch.no_grad():
dim0_x = Dim("dim0_x", min=1, max=1024)
dynamic_shapes = {"x": {0: dim0_x}}
model_so_path = aot_compile(
model,
(x,),
dynamic_shapes=dynamic_shapes,
options={"aot_inductor.use_runtime_constant_folding": use_runtime_constant_folding})
suffix = f"{device}"
if use_runtime_constant_folding:
suffix += "_use_runtime_constant_folding"
data.update({
f"model_so_path_{suffix}": model_so_path,
f"inputs_{suffix}": [x],
f"outputs_{suffix}": [ref_output],
f"w_pre_{suffix}": model.w_pre,
f"w_add_{suffix}": model.w_add,
})
# Use this to communicate tensors to the cpp code
class Serializer(torch.nn.Module):
def __init__(self, data):
super().__init__()
for key in data:
setattr(self, key, data[key])
torch.jit.script(Serializer(data)).save("data.pt")