# blob: d9f25c5954995c50a5f3117be7bf04f75cffdd07 [file] [log] [blame]
# Owner(s): ["module: dynamo"]
import importlib
import json
import os
import unittest
from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.test_case
from torch._dynamo.optimizations import backends
from torch._dynamo.optimizations.analysis import has_mutation
from torch._dynamo.optimizations.log_args import conv_args_analysis
from torch._dynamo.optimizations.normalize import Inplacifier, normalize
from torch._dynamo.testing import same
def has_onnxruntime():
    """Return True when the ``onnxruntime`` module imports, False on ImportError."""
    try:
        importlib.import_module("onnxruntime")
    except ImportError:
        available = False
    else:
        available = True
    return available
def has_ipex():
    """Return True when ``intel_extension_for_pytorch`` imports, False on ImportError."""
    try:
        importlib.import_module("intel_extension_for_pytorch")
    except ImportError:
        available = False
    else:
        available = True
    return available
def has_functorch():
    """Return True when the ``functorch`` module imports, False on ImportError."""
    try:
        importlib.import_module("functorch")
    except ImportError:
        available = False
    else:
        available = True
    return available
class Seq(torch.nn.Module):
    """Four-stage sequential model: Linear(10, 10), ReLU, Linear(10, 10), Sigmoid."""

    def __init__(self):
        super().__init__()
        # Stages are created in the same order as they run, so submodule
        # indices (layers.0 .. layers.3) follow the list positions.
        stages = [
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.Sigmoid(),
        ]
        self.layers = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        out = self.layers(x)
        return out
class Conv_Bn_Relu(torch.nn.Module):
    """Bias-free Conv2d followed by BatchNorm2d (eps=0.001) and ReLU.

    Args:
        in_channels: Input channel count passed to ``Conv2d``.
        out_channels: Output channel count for ``Conv2d`` and ``BatchNorm2d``.
        **kwargs: Extra keyword arguments forwarded to ``Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Apply convolution, batch norm, and ReLU in that order."""
        convolved = self.conv(x)
        normalized = self.bn(convolved)
        return self.relu(normalized)
class TestOptimizations(torch._dynamo.test_case.TestCase):
    """Tests for the dynamo optimization passes and backends imported by this file."""

    @staticmethod
    def _conv_bn_relu_fixture():
        """Build the shared Conv_Bn_Relu model and input used by several tests.

        Returns:
            A ``(model, example_input)`` pair: the model is
            ``Conv_Bn_Relu(3, 32, kernel_size=3, stride=1)`` in channels-last
            format and eval mode; the input is a random ``(8, 3, 64, 64)``
            tensor made contiguous in channels-last format.
        """
        model = Conv_Bn_Relu(3, 32, kernel_size=3, stride=1)
        model = model.to(memory_format=torch.channels_last)
        model = model.eval()
        example_input = torch.randn(8, 3, 64, 64).contiguous(
            memory_format=torch.channels_last
        )
        return model, example_input

    def test_inplacifier(self):
        """After normalize + Inplacifier, the regenerated code uses in-place forms."""
        gm = torch.fx.symbolic_trace(Seq())
        normalize(gm)
        Inplacifier(gm).inplacify()
        gm.recompile()
        # Strip spaces so the checks do not depend on code formatting.
        code = gm.code.replace(" ", "")
        self.assertIn("inplace=True", code)
        self.assertIn("out=linear_1", code)

    def test_has_mutation(self):
        """has_mutation is False for Seq and True for a module calling add_."""
        gm = torch.fx.symbolic_trace(Seq())
        self.assertFalse(has_mutation(gm, torch.rand([10, 10])))

        class Mutating(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, arg):
                return arg.add_(1)

        gm = torch.fx.symbolic_trace(Mutating())
        self.assertTrue(has_mutation(gm, torch.rand([10, 1, 1, 1])))

    def test_has_mutation_factory(self):
        """has_mutation is True for a graph that fills a freshly created tensor."""

        def fn():
            x = torch.empty(2)
            x.fill_(2)
            return x

        def compiler_fn(graph, example_inputs):
            # NOTE(review): the only assertion lives in the backend callback,
            # so this test passes vacuously if the backend is never invoked.
            self.assertTrue(has_mutation(graph, example_inputs))
            return graph

        opt_fn = torch._dynamo.optimize(compiler_fn)(fn)
        opt_fn()

    def test_example_inputs(self):
        """Graph run on the backend's example inputs matches eager and optimized runs."""

        def fn(a, bc, d):
            b, c = bc
            return a / d - b / c

        def compiler_fn(graph, example_inputs):
            nonlocal r1
            # Capture the graph's first output when called on the example inputs.
            r1 = graph(*example_inputs)[0]
            return graph.forward

        a = torch.empty(2).fill_(1)
        b = torch.empty(2).fill_(2)
        c = torch.empty(2).fill_(3)
        d = 4
        r1 = None
        r2 = fn(a, (b, c), d)
        opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn)
        r3 = opt_fn(a, (b, c), d)

        self.assertIsNotNone(r1)
        self.assertTrue(same(r1, r2))
        self.assertTrue(same(r1, r3))

    @patch.object(torch._dynamo.config, "fake_tensor_propagation", False)
    @unittest.skipIf(not has_functorch(), "requires functorch")
    def test_log_conv_args(self):
        """conv_args_analysis writes tmp/conv_args.json with the convolution arg names."""
        model, example_input = self._conv_bn_relu_fixture()
        r1 = model(example_input)

        # check tmp/conv_args.json exists and has keys as arg names
        filename = "tmp/conv_args.json"
        if os.path.exists(filename):
            os.remove(filename)
        try:
            opt_model = torch._dynamo.optimize(conv_args_analysis)(model)
            with torch.no_grad():
                r2 = opt_model(example_input)
            self.assertTrue(same(r1, r2.float(), tol=0.1))
            self.assertTrue(os.path.exists(filename))
            with open(filename) as f:
                args_dict = json.load(f)
            self.assertIn("convolution", args_dict)
            conv_args_dict = args_dict["convolution"]
            expected_keys = (
                "input",
                "weight",
                "bias",
                "stride",
                "padding",
                "dilation",
                "transposed",
                "output_padding",
                "groups",
            )
            for key in expected_keys:
                self.assertIn(key, conv_args_dict)
        finally:
            # Remove the output file even when an assertion above fails, so a
            # failed run does not leave it behind for later runs.
            if os.path.exists(filename):
                os.remove(filename)

    @unittest.skipIf(not has_ipex(), "requires ipex")
    def test_ipex_fp32(self):
        """backends.ipex_fp32 output matches eager and stays float32."""
        model, example_input = self._conv_bn_relu_fixture()
        r1 = model(example_input)
        opt_model = torch._dynamo.optimize(backends.ipex_fp32)(model)
        with torch.no_grad():
            r2 = opt_model(example_input)
        self.assertTrue(same(r1, r2))
        self.assertEqual(r2.dtype, torch.float32)

    @unittest.skipIf(not has_ipex(), "requires ipex")
    def test_ipex_bf16(self):
        """backends.ipex_bf16 output is bfloat16 and matches eager within tol=0.1."""
        model, example_input = self._conv_bn_relu_fixture()
        r1 = model(example_input)
        opt_model = torch._dynamo.optimize(backends.ipex_bf16)(model)
        with torch.no_grad(), torch.cpu.amp.autocast():
            r2 = opt_model(example_input)
        self.assertTrue(same(r1, r2.float(), tol=0.1))
        self.assertEqual(r2.dtype, torch.bfloat16)
class NormalizeIRTests(torch._dynamo.test_case.TestCase):
    """Checks on functions compiled through the ``aot_eager`` backend."""

    @unittest.skipIf(not has_functorch(), "requires functorch")
    def test_inplace_normalize(self):
        """Eager and compiled results agree when a float64 tensor is added in place."""

        def fn(a, b):
            acc = torch.cos(a)
            # The in-place add is the behavior under test; keep it as ``+=``.
            acc += b
            return torch.sin(acc)

        first = torch.randn(10)
        second = torch.randn(10).to(torch.float64)

        expected = fn(first, second)

        compiled_fn = torch._dynamo.optimize("aot_eager")(fn)
        actual = compiled_fn(first, second)

        self.assertTrue(same(expected, actual))
if __name__ == "__main__":
    # Run this module's tests through dynamo's own test runner.
    torch._dynamo.test_case.run_tests()