blob: 2fb83b3add6cfbc9fbcb053e0184d52bb6c92a2e [file] [log] [blame]
# Owner(s): ["module: dynamo"]
from copy import deepcopy
from unittest.mock import patch
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.eval_frame import unsupported
from torch._dynamo.mutation_guard import GenerationTracker
from torch._dynamo.testing import same
from torch.nn import functional as F
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.parameter import Parameter, UninitializedParameter
try:
from . import test_functions
except ImportError:
import test_functions
class BasicModule(torch.nn.Module):
    """Linear(10, 10) -> ReLU, multiplied by a fixed random scale.

    ``scale`` is a plain tensor attribute; it is not wrapped in
    ``Parameter`` and not registered as a buffer.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        return F.relu(self.linear1(x)) * self.scale
class FnMember(torch.nn.Module):
    """Module that stores a plain function (``F.relu``) as an attribute.

    ``forward`` applies the stored function only when the attribute is
    truthy.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.activation = F.relu

    def forward(self, x):
        x = self.linear1(x)
        # Truthiness check on a function-valued attribute.
        if self.activation:
            x = self.activation(x)
        return x
class FnMemberCmp(torch.nn.Module):
    """Module whose optional activation is tested with ``is`` / ``is not``.

    Args:
        activation: callable applied after the linear layer, or ``None``,
            in which case ``torch.sigmoid`` is applied instead.
    """

    def __init__(self, activation):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.activation = activation

    def forward(self, x):
        x = self.linear1(x)
        # The two identity comparisons are deliberately written as separate
        # `if` statements rather than if/else.
        if self.activation is not None:
            x = self.activation(x)
        if self.activation is None:
            x = torch.sigmoid(x)
        return x
class SubmoduleExample(torch.nn.Module):
    """Chains two ``BasicModule`` submodules and scales the result."""

    def __init__(self):
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        # Plain tensor attribute (not a Parameter or buffer).
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x * self.scale
class IsTrainingCheck(torch.nn.Module):
    """Selects ``linear1`` or ``linear2`` based on ``self.training``.

    The constructor puts the module in training mode via ``train(True)``.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.linear2 = torch.nn.Linear(10, 10)
        self.train(True)

    def forward(self, x):
        if self.training:
            mod = self.linear1
        else:
            mod = self.linear2
        return F.relu(mod(x))
class IsEvalCheck(IsTrainingCheck):
    """Same as ``IsTrainingCheck`` but constructed in eval mode."""

    def __init__(self):
        super().__init__()
        # Parent __init__ calls train(True); switch back to eval here.
        self.train(False)
class ModuleMethodCall(torch.nn.Module):
    """``forward`` delegates to an instance helper method for each submodule."""

    def __init__(self):
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.scale = torch.randn(1, 10)

    def call_and_scale(self, mod, x):
        """Run ``mod`` on ``x`` and multiply the result by ``self.scale``."""
        x = mod(x)
        return x * self.scale

    def forward(self, x):
        x1 = self.call_and_scale(self.layer1, x)
        x2 = self.call_and_scale(self.layer2, x)
        return x1 + x2
class UnsupportedMethodCall(torch.nn.Module):
    """Module whose helper method ends with a call to ``unsupported``.

    ``unsupported`` is imported from ``torch._dynamo.eval_frame``; the
    helper's tensor work happens before that call.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = BasicModule()
        self.scale = torch.randn(1, 10)

    def call_and_scale(self, mod, x):
        """Run ``mod``, scale the result, then pass it through ``unsupported``."""
        x = mod(x)
        x = x * self.scale
        return unsupported(x, x)

    def forward(self, x):
        x1 = self.call_and_scale(self.layer1, x)
        return x + x1
class UnsupportedModule(torch.nn.Module):
    """Module whose ``forward`` itself ends with a call to ``unsupported``."""

    def __init__(self):
        super().__init__()
        self.layer1 = BasicModule()
        self.scale = torch.randn(1, 10)

    def forward(self, x):
        x = self.layer1(x) * self.scale
        return unsupported(x, x)
class UnsupportedModuleCall(torch.nn.Module):
    """Wraps an ``UnsupportedModule`` with arithmetic before and after the call."""

    def __init__(self):
        super().__init__()
        self.mod = UnsupportedModule()

    def forward(self, x):
        return 1 + self.mod(x * 1.5)
class ModuleStaticMethodCall(torch.nn.Module):
    """Like ``ModuleMethodCall`` but the helper is a ``@staticmethod``.

    The scale tensor is passed explicitly because the helper has no
    access to ``self``.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.scale = torch.randn(1, 10)

    @staticmethod
    def call_and_scale(scale, mod, x):
        """Run ``mod`` on ``x`` and multiply the result by ``scale``."""
        x = mod(x)
        return x * scale

    def forward(self, x):
        # The static method is invoked through the instance.
        x1 = self.call_and_scale(self.scale, self.layer1, x)
        x2 = self.call_and_scale(self.scale, self.layer2, x)
        return x1 + x2
class ModuleClassMethodCall(torch.nn.Module):
    """Like ``ModuleMethodCall`` but the helper is a ``@classmethod``."""

    def __init__(self):
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = BasicModule()
        self.scale = torch.randn(1, 10)

    @classmethod
    def call_and_scale(cls, scale, mod, x):
        """Run ``mod`` on ``x`` and multiply the result by ``scale``."""
        x = mod(x)
        return x * scale

    def forward(self, x):
        # The class method is invoked through the instance.
        x1 = self.call_and_scale(self.scale, self.layer1, x)
        x2 = self.call_and_scale(self.scale, self.layer2, x)
        return x1 + x2
class ModuleProperty(torch.nn.Module):
    """Reads its scale tensor through a ``@property`` alias in ``forward``."""

    def __init__(self):
        super().__init__()
        self.scale = torch.randn(1, 10)

    @property
    def scale_alias(self):
        """Return the same tensor object as ``self.scale``."""
        return self.scale

    def forward(self, x):
        return x * self.scale_alias
class ConstLoop(torch.nn.Module):
    """Applies ``sigmoid(linear1(x))`` ``self.count`` (3) times in a loop."""

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        # Python int attribute used as the loop bound.
        self.count = 3

    def forward(self, x):
        for i in range(self.count):
            x = torch.sigmoid(self.linear1(x))
        return x
class ViaModuleCall(torch.nn.Module):
    """Calls a function reached through another module's namespace.

    ``forward`` invokes ``test_functions.constant3`` (defined in the
    sibling ``test_functions`` module) on the activated linear output and
    the original input.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)

    def forward(self, x):
        return test_functions.constant3(torch.sigmoid(self.linear1(x)), x)
class IsNoneLayer(torch.nn.Module):
    """Has one real layer and one ``None`` layer, each guarded by ``is not None``."""

    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(10, 10)
        self.layer2 = None
        self.train(True)

    def forward(self, x):
        if self.layer1 is not None:
            x = self.layer1(x)
        # layer2 is None as constructed, so this branch is skipped.
        if self.layer2 is not None:
            x = self.layer2(x)
        return x
class LayerList(torch.nn.Module):
    """Holds its layers in a plain Python ``list`` and iterates over it.

    Because the container is a regular list (not ``torch.nn.ModuleList``),
    the layers are not registered as submodules of this module.
    """

    def __init__(self):
        super().__init__()
        self.layers = [
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
        ]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
class ModuleList(torch.nn.Module):
    """Exercises several ways of iterating a ``torch.nn.ModuleList``.

    ``forward`` runs the four layers repeatedly using: index access via
    ``range(len(...))``, direct iteration, ``zip`` with tensors, ``zip``
    with Python ints, ``enumerate``, and ``enumerate`` over a reversed
    slice. The loop forms are intentional and must be kept as written.
    """

    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
                torch.nn.Linear(10, 10),
                torch.nn.ReLU(),
            ]
        )

    def forward(self, x):
        # Index-based access.
        for i in range(len(self.layers)):
            x = self.layers[i](x)

        # Direct iteration.
        for layer in self.layers:
            x = layer(x)

        # zip against a tuple of tensors.
        for layer, val in zip(self.layers, (x, x, x, x)):
            x = layer(x) + val

        # zip against a tuple of Python ints.
        for layer, val in zip(self.layers, (1, 2, 3, 4)):
            x = layer(x) + val

        # enumerate; idx == 0 multiplies the first result by zero.
        for idx, layer in enumerate(self.layers):
            x = layer(x) * idx

        # enumerate over a reversed slice of the ModuleList.
        for idx, layer in enumerate(self.layers[::-1]):
            x = layer(x) * idx

        return x
class ModuleDict(torch.nn.Module):
    """Looks up a single layer in a ``torch.nn.ModuleDict`` by string key."""

    def __init__(self):
        super().__init__()
        self.layers = torch.nn.ModuleDict(
            {
                "0": torch.nn.Linear(10, 10),
            }
        )

    def forward(self, x):
        # TODO(future PR): handle more logic
        x = self.layers["0"](x)
        return x
class TensorList(torch.nn.Module):
    """Multiplies the input by each tensor in a tuple attribute.

    The tuple alternates (1, 10) and (10, 1) tensors, so the product
    relies on broadcasting.
    """

    def __init__(self):
        super().__init__()
        self.layers = (
            torch.randn((1, 10)),
            torch.randn((10, 1)),
            torch.randn((1, 10)),
            torch.randn((10, 1)),
        )

    def forward(self, x):
        for layer in self.layers:
            x = x * layer
        return x
class Children(torch.nn.Module):
    """Runs the input through every submodule yielded by ``self.children()``."""

    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(10, 10)
        self.l2 = torch.nn.ReLU()
        self.l3 = torch.nn.Linear(10, 10)
        self.l4 = torch.nn.ReLU()

    def forward(self, x):
        for block in self.children():
            x = block(x)
        return x
class IntArg(torch.nn.Module):
    """``forward`` takes an extra non-tensor argument with a default value."""

    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Linear(10, 10)

    def forward(self, x, offset=1):
        x = F.relu(self.layer1(x)) + offset
        return x
class Seq(torch.nn.Module):
    """Wraps a ``torch.nn.Sequential`` of Linear/ReLU/Linear/ReLU."""

    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
            torch.nn.Linear(10, 10),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        return self.layers(x)
class Cfg:
    """Plain (non-Module) config object holding a float and an int."""

    def __init__(self):
        self.val = 0.5
        self.count = 3
class CfgModule(torch.nn.Module):
    """Reads loop count and an additive constant from a nested ``Cfg`` object."""

    def __init__(self):
        super().__init__()
        self.cfg = Cfg()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        # Both cfg.count and cfg.val are read via two-level attribute access.
        for i in range(self.cfg.count):
            x = self.layer(x + self.cfg.val)
        return x
class StringMember(torch.nn.Module):
    """Branches on equality of a string attribute.

    ``forward`` has no ``else`` branch: if ``mode`` is changed to anything
    other than ``"some_string"`` it falls through and returns ``None``.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.mode = "some_string"

    def forward(self, x):
        if self.mode == "some_string":
            return F.relu(self.linear1(x))
class _Block(torch.nn.Module):
    """Concatenates a sequence of tensors along dim 1 and scales by 1.5."""

    def forward(self, x):
        # `x` is a sequence of tensors here, not a single tensor.
        return 1.5 * torch.cat(x, 1)
class _DenseBlock(torch.nn.ModuleDict):
    """``ModuleDict`` subclass that feeds a growing feature list to each layer.

    Args:
        num_layers: number of ``_Block`` layers to register, named
            ``denselayer1`` .. ``denselayerN``.
    """

    _version = 2

    def __init__(
        self,
        num_layers: int = 3,
    ) -> None:
        super().__init__()
        for i in range(num_layers):
            self.add_module(f"denselayer{i + 1}", _Block())

    def forward(self, init_features):
        features = [init_features]
        # Iterates via ModuleDict.items(); each layer sees all features so far.
        for name, layer in self.items():
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)
class DenseNetBlocks(torch.nn.Module):
    """Thin wrapper that owns a ``_DenseBlock`` and forwards to it."""

    def __init__(self):
        super().__init__()
        self.layers = _DenseBlock()

    def forward(self, x):
        return self.layers(x)
class MaterializedModule(torch.nn.Module):
    """Once the below lazy module is initialized with its first input,
    it is transformed into this module."""

    param: Parameter

    def __init__(self):
        super().__init__()
        # Register the slot with no value; LazyModule assigns it later.
        self.register_parameter("param", None)

    def forward(self, x):
        # Identity: the input is returned unchanged.
        return x
class LazyModule(LazyModuleMixin, MaterializedModule):
    """Lazy counterpart of ``MaterializedModule``.

    ``param`` starts as an ``UninitializedParameter`` and is materialized
    with the shape of the input passed to ``initialize_parameters``.
    """

    param: UninitializedParameter
    # Class this instance is meant to turn into (used by LazyModuleMixin).
    cls_to_become = MaterializedModule

    def __init__(self):
        super().__init__()
        self.param = UninitializedParameter()

    def initialize_parameters(self, x):
        """Materialize ``param`` with the same shape as ``x``."""
        self.param.materialize(x.shape)
def requires_grad1(module: torch.nn.Module, recurse: bool = False) -> bool:
    """Return True if any parameter of ``module`` requires grad.

    Args:
        module: module whose parameters are inspected.
        recurse: forwarded to ``module.parameters``.

    This variant deliberately passes a *list comprehension* to ``any``;
    ``requires_grad2`` is the generator-expression twin. Keep the two forms
    distinct.
    """
    requires_grad = any([p.requires_grad for p in module.parameters(recurse)])
    return requires_grad
def requires_grad2(module: torch.nn.Module, recurse: bool = False) -> bool:
    """Return True if any parameter of ``module`` requires grad.

    Args:
        module: module whose parameters are inspected.
        recurse: forwarded to ``module.parameters``.

    This variant passes a *generator expression* to ``any``;
    ``requires_grad1`` is the list-comprehension twin.
    """
    requires_grad = any(p.requires_grad for p in module.parameters(recurse))
    return requires_grad
class ParametersModule1(torch.nn.Module):
    """Branches in ``forward`` on ``requires_grad1(self)``.

    ``scale`` is a real ``torch.nn.Parameter`` here, unlike the plain
    tensor used by ``BasicModule``.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)
        self.scale = torch.nn.Parameter(torch.randn(1, 10))

    def forward(self, x):
        # requires_grad1 is called with recurse=False (its default).
        if not requires_grad1(self):
            return F.relu(self.linear1(x)) * self.scale
        else:
            return x + 1
class ParametersModule2(ParametersModule1):
    """Same as ``ParametersModule1`` but branches on ``requires_grad2``."""

    def forward(self, x):
        if not requires_grad2(self):
            return F.relu(self.linear1(x)) * self.scale
        else:
            return x + 1
class ParametersModule3(ParametersModule1):
    """Uses the dtype of the first parameter to build a constant tensor."""

    def forward(self, x):
        # next(self.parameters()) pulls the first item from the parameter iterator.
        ones = torch.ones(10, dtype=next(self.parameters()).dtype)
        return F.relu(self.linear1(x)) * self.scale + ones
class SuperModule(BasicModule):
    """Calls the parent ``forward`` through ``super()`` and adds 10.0."""

    def forward(self, x):
        x = super().forward(x)
        return x + 10.0
class ComplicatedSuperParent(torch.nn.Module):
    """Parent class providing a classmethod that doubles its argument."""

    @classmethod
    def custom_add(cls, x):
        """Return ``x + x``."""
        x = x + x
        return x
class SuperChildCallsClassMethod(ComplicatedSuperParent):
    """Child whose classmethod reaches the parent's classmethod via ``super()``."""

    @classmethod
    def child_func(cls, x):
        """Delegate to the parent's ``custom_add`` through zero-arg ``super()``."""
        x = super().custom_add(x)
        return x

    def forward(self, x):
        # The classmethod is invoked through the instance.
        x = self.child_func(x)
        return x
class HasAttrModule(torch.nn.Module):
    """Uses ``hasattr`` on itself for an existing and a missing attribute."""

    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.randn(1, 10))

    def forward(self, x):
        x = F.relu(x)
        # "scale" exists, so this in-place multiply runs.
        if hasattr(self, "scale"):
            x *= self.scale
        # "scale2" is never defined on this class, so this branch is skipped.
        if hasattr(self, "scale2"):
            x *= self.scale2
        return x
class EnumValues(torch.nn.ModuleDict):
    """Like ``_DenseBlock`` but iterates with ``enumerate(self.values())``.

    Args:
        num_layers: number of ``_Block`` layers to register, named
            ``denselayer1`` .. ``denselayerN``.
    """

    def __init__(
        self,
        num_layers: int = 3,
    ) -> None:
        super().__init__()
        for i in range(num_layers):
            self.add_module(f"denselayer{i + 1}", _Block())

    def forward(self, init_features):
        features = [init_features]
        # Each layer receives the list of all features collected so far.
        for idx, layer in enumerate(self.values()):
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)
class CallForwardDirectly(torch.nn.Module):
    """Invokes ``.forward(x)`` on its submodules instead of calling them."""

    def __init__(self):
        super().__init__()
        self.layer1 = BasicModule()
        self.layer2 = torch.nn.Linear(10, 10)

    def forward(self, x):
        # Explicit .forward() calls on a user module and a built-in module.
        x = self.layer1.forward(x)
        x = self.layer2.forward(x)
        return x
class ModuleNameString(torch.nn.Module):
    """Branches on ``__class__.__name__`` of itself and of a submodule."""

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(10, 10)

    def forward(self, x):
        # This class is not named "ABC", so the first branch is not taken.
        if self.__class__.__name__ == "ABC":
            return 10
        if self.linear1.__class__.__name__ == "Linear":
            return F.relu(self.linear1(x) + 10)
        return 11
class SelfMutatingModule(torch.nn.Module):
    """Module that mutates its own ``counter`` attribute on every call.

    Args:
        layer: module applied to the input before the counter is added.
    """

    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.counter = 0

    def forward(self, x):
        # The current counter value is used first, then incremented.
        result = self.layer(x) + self.counter
        self.counter += 1
        return F.relu(result)
class ModuleAttributePrecedenceBase(torch.nn.Module):
    """Base class defining a ``linear`` *method* that doubles its input."""

    def __init__(self):
        super().__init__()

    def linear(self, x):
        return x * 2.0
class ModuleAttributePrecedence(ModuleAttributePrecedenceBase):
    """Assigns instance attributes whose names collide with class methods.

    ``activation`` and ``linear`` are assigned ``nn.Module`` values, while
    ``initializer`` and ``scale`` are assigned a tensor and a float. Each
    name is also defined as a method on this class or its base, so
    ``forward`` depends on which definition attribute lookup resolves to.
    """

    def __init__(self):
        super().__init__()
        self.activation = torch.nn.ReLU()
        self.linear = torch.nn.Linear(10, 10)
        self.initializer = torch.ones([10, 10])
        self.scale = 0.5

    def activation(self, x):
        return x * 1.2

    def initializer(self):
        return torch.zeros([10, 10])

    def scale(self):
        return 2.0

    def forward(self, x):
        # object attribute takes precedence unless it's a nn.Module
        return self.activation(self.linear(self.initializer + x)) * self.scale
def make_test(fn, expected_ops=None):
    """Build a test method that runs ``standard_test`` on ``fn``.

    Args:
        fn: module instance under test; it is switched to eval mode here,
            at test-definition time.
        expected_ops: forwarded to ``standard_test`` as ``expected_ops``.

    Returns:
        A function taking the test-case instance as ``self``, suitable for
        assignment as a class attribute of a ``TestCase``.
    """

    def test_fn(self):
        return torch._dynamo.testing.standard_test(
            self, fn=fn, nargs=1, expected_ops=expected_ops
        )

    fn.eval()
    return test_fn
class NNModuleTests(torch._dynamo.test_case.TestCase):
    """Dynamo tests over the fixture modules defined in this file.

    Most tests are generated by ``make_test`` from a fixture instance; the
    hand-written methods below cover cases that need custom assertions.
    """

    test_seq = make_test(Seq())
    test_basicmodule1 = make_test(BasicModule())
    test_basicmodule2 = make_test(BasicModule())
    test_submodules1 = make_test(SubmoduleExample())
    test_submodules2 = make_test(SubmoduleExample())
    test_modulemethod1 = make_test(ModuleMethodCall())
    test_modulemethod2 = make_test(ModuleMethodCall())
    test_module_static_method = make_test(ModuleStaticMethodCall())
    test_fnmember = make_test(FnMember())
    test_fnmembercmp1 = make_test(FnMemberCmp(F.relu))
    test_fnmembercmp2 = make_test(FnMemberCmp(None))
    test_constloop = make_test(ConstLoop())
    test_istraining1 = make_test(IsTrainingCheck())
    test_istraining2 = make_test(IsTrainingCheck())
    test_iseval1 = make_test(IsEvalCheck())
    test_iseval2 = make_test(IsEvalCheck())
    test_viamodulecall = make_test(ViaModuleCall())
    test_isnonelayer = make_test(IsNoneLayer())
    test_layerlist = make_test(LayerList())
    test_tensorlist = make_test(TensorList())
    test_intarg = make_test(IntArg())
    test_cfgmod = make_test(CfgModule())
    test_stringmember = make_test(StringMember())
    test_modulelist = make_test(ModuleList())
    test_moduledict = make_test(ModuleDict())
    test_super1 = make_test(SuperModule())
    test_super_class_method = make_test(SuperChildCallsClassMethod())
    test_children = make_test(Children())
    test_densenet = make_test(DenseNetBlocks())
    test_parameters1 = make_test(ParametersModule1())
    test_parameters2 = make_test(ParametersModule2())
    test_parameters3 = make_test(ParametersModule3(), expected_ops=5)
    test_hasattr = make_test(HasAttrModule())
    test_enumvalues = make_test(EnumValues())
    test_module_class_method = make_test(ModuleClassMethodCall())
    test_module_property = make_test(ModuleProperty())
    test_forward_directly = make_test(CallForwardDirectly())
    test_module_name_string = make_test(ModuleNameString())
    test_module_attribute_precedence = make_test(ModuleAttributePrecedence())

    def test_unsupportedmethod(self):
        """Optimized and eager results match; 5 ops are counted."""
        m = UnsupportedMethodCall()
        i = torch.randn(10)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_m = torch._dynamo.optimize(cnt)(m)
        r = opt_m(i)
        self.assertTrue(torch._dynamo.testing.same(r, m(i)))
        self.assertEqual(cnt.op_count, 5)

    def test_unsupportedmodule(self):
        """Optimized and eager results match; 6 ops are counted."""
        m = UnsupportedModuleCall()
        i = torch.randn(10)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_m = torch._dynamo.optimize(cnt)(m)
        r = opt_m(i)
        self.assertTrue(torch._dynamo.testing.same(r, m(i)))
        self.assertEqual(cnt.op_count, 6)

    def test_self_mutating1(self):
        """Self-mutating modules give the same outputs eagerly and optimized."""
        m1 = torch.nn.Linear(10, 10)
        # Three wrappers share the same inner layer; m2 is the eager reference.
        m2 = SelfMutatingModule(m1)
        m3 = SelfMutatingModule(m1)
        m4 = SelfMutatingModule(m1)
        i = torch.randn(10)
        out2 = [m2(i), m2(i), m2(i)]
        cnt = torch._dynamo.testing.CompileCounter()
        opt_m3 = torch._dynamo.optimize_assert(cnt)(m3)
        opt_m4 = torch._dynamo.optimize_assert(cnt)(m4)
        out3 = [opt_m3(i), opt_m3(i), opt_m3(i)]
        out4 = [opt_m4(i), opt_m4(i), opt_m4(i)]
        self.assertTrue(torch._dynamo.testing.same(out2, out3))
        self.assertTrue(torch._dynamo.testing.same(out2, out4))
        self.assertEqual(cnt.frame_count, 3)

    @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
    def test_generation_tag(self):
        """Modules created inside an optimize context get the next generation."""
        cnt = torch._dynamo.testing.CompileCounter()

        # guarantee that we have installed
        # the generation tagging function
        with torch._dynamo.optimize_assert(cnt):
            pass

        m1 = torch.nn.Linear(10, 10)
        prev_generation = GenerationTracker.get_generation_value(m1)
        cur_generation = prev_generation + 1

        with torch._dynamo.optimize_assert(cnt):
            m2 = torch.nn.Linear(10, 10)

        self.assertEqual(GenerationTracker.get_generation_value(m1), prev_generation)
        self.assertEqual(GenerationTracker.get_generation_value(m2), cur_generation)
        # check that newly constructed instances
        # also have the same generation (even if copied from an old instance)
        m3 = deepcopy(m1)
        self.assertEqual(GenerationTracker.get_generation_value(m3), cur_generation)

    def test_simple_torch_function(self):
        """A registered tensor subclass traces with nopython=True (4 ops)."""

        def foo(x):
            # function call, twice to test wrapping
            x = F.sigmoid(x)
            x = F.sigmoid(x)
            # method call, twice to test wrapping
            x = x.sigmoid()
            x = x.sigmoid()
            return x

        class TensorProxy(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                return super().__torch_function__(func, types, args, kwargs)

        torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy)
        # try/finally so a failing assertion cannot leave the global config
        # modified for later tests.
        try:
            x = torch.randn(1).as_subclass(TensorProxy)
            cnt = torch._dynamo.testing.CompileCounter()
            out1 = foo(x)
            opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
            out2 = opt_foo(x)

            self.assertEqual(cnt.op_count, 4)
            self.assertTrue(torch._dynamo.testing.same(out1, out2))
        finally:
            torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy)

    def test_torch_function_with_closure(self):
        """Like test_simple_torch_function, with a closure read in the subclass."""

        def run():
            counter = 0

            def foo(x):
                # function call, twice to test wrapping
                x = F.sigmoid(x)
                x = F.sigmoid(x)
                # method call, twice to test wrapping
                x = x.sigmoid()
                x = x.sigmoid()
                return x

            class TensorProxy(torch.Tensor):
                @classmethod
                def __torch_function__(cls, func, types, args=(), kwargs=None):
                    nonlocal counter
                    # for now, only support reads from closure cells
                    # TODO(future PR): support writes as well
                    counter + 1
                    return super().__torch_function__(func, types, args, kwargs)

            torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy)
            # try/finally so a failing assertion cannot leave the global
            # config modified for later tests.
            try:
                x = torch.randn(1).as_subclass(TensorProxy)
                # NOTE(review): this rebinding replaces the TensorProxy input
                # with a plain tensor, so foo never sees the subclass below.
                # Kept as-is; confirm whether it is intentional.
                x = torch.randn(1)
                cnt = torch._dynamo.testing.CompileCounter()
                out1 = foo(x)
                opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo)
                out2 = opt_foo(x)

                self.assertEqual(cnt.op_count, 4)
                self.assertTrue(torch._dynamo.testing.same(out1, out2))
            finally:
                torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy)

        run()

    @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
    def test_nn_moduledict_contains(self):
        """``"key" in module_dict`` is handled for present and absent keys."""

        class M(torch.nn.Module):
            def __init__(self, module_dict):
                super().__init__()
                self.module_dict = module_dict

            def forward(self, x):
                if "foo" in self.module_dict:
                    x = torch.mul(x, 1.0)
                x = torch.add(x, 1.0)
                return x

        # Key present: both mul and add are executed (2 ops).
        module_dict = torch.nn.ModuleDict({"foo": torch.nn.Conv2d(1, 1, 1)})
        m = M(module_dict)
        data = torch.randn(1)
        out1 = m(data)
        cnt = torch._dynamo.testing.CompileCounter()
        opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
        out2 = opt_m(data)
        self.assertEqual(cnt.op_count, 2)
        self.assertTrue(torch._dynamo.testing.same(out1, out2))

        # Key absent: only add is executed (1 op).
        module_dict = torch.nn.ModuleDict({"bar": torch.nn.Conv2d(1, 1, 1)})
        m = M(module_dict)
        data = torch.randn(1)
        out1 = m(data)
        cnt = torch._dynamo.testing.CompileCounter()
        torch._dynamo.reset()
        opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
        out2 = opt_m(data)

        self.assertEqual(cnt.op_count, 1)
        self.assertTrue(torch._dynamo.testing.same(out1, out2))

        module_dict = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)})
        pre = m(data)
        cnt.clear()

        # NOTE(review): the extent of this `with` block is assumed to cover
        # the four statements below - confirm against upstream.
        with torch._dynamo.optimize(cnt, nopython=False):
            opt_pre = m(data)
            m = M(module_dict)
            data = torch.randn(1)
            out1 = m(data)

        out_post = m(data)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)
        self.assertTrue(torch._dynamo.testing.same(pre, opt_pre))
        self.assertTrue(torch._dynamo.testing.same(out1, out_post))

    def test_lazy_module(self):
        """Lazy modules are materialized when run under an optimized function."""
        input_shape = (16, 3, 6, 7, 8)

        cnt = torch._dynamo.testing.CompileCounter()
        module = LazyModule()

        def test_static_module():
            input = torch.ones(*input_shape)
            module(input)

        opt_test_static_module = torch._dynamo.optimize(cnt)(test_static_module)
        opt_test_static_module()

        self.assertIsInstance(
            module,
            MaterializedModule,
            "Module should be transformed to an instance of MaterializedModule.",
        )
        self.assertEqual(module.param.shape, input_shape)

        # test when mapped to UnspecializedNNModule
        module = LazyModule()

        def test_unspecialized():
            nonlocal module
            # The module is re-created inside the optimized function.
            module = LazyModule()
            input = torch.ones(*input_shape)
            module(input)

        opt_test_unspecialized = torch._dynamo.optimize(cnt)(test_unspecialized)
        opt_test_unspecialized()

        self.assertIsInstance(
            module,
            MaterializedModule,
            "Module should be transformed to an instance of MaterializedModule.",
        )
        self.assertEqual(module.param.shape, input_shape)

        # test with a static module in torch.*
        module = torch.nn.modules.LazyBatchNorm3d(
            affine=False, track_running_stats=False
        )

        cnt = torch._dynamo.testing.CompileCounter()

        torch._dynamo.reset()

        def test_torch_static():
            input = torch.ones(*input_shape)
            return module(input)  # fully materialized

        opt_test_torch_static = torch._dynamo.optimize(cnt)(test_torch_static)
        opt_test_torch_static()
        out = opt_test_torch_static()

        self.assertTrue(same(out, module(torch.ones(*input_shape))))

        self.assertIsInstance(
            module,
            torch.nn.modules.batchnorm.BatchNorm3d,
            "Module should be transformed to an instance of BatchNorm3d.",
        )
        self.assertEqual(cnt.frame_count, 1, "No guards should have triggered.")

    def test_call_fn_with_non_const_inputs_safe(self):
        """Export works when a submodule method receives tensor attributes."""

        class ModuleSpecialFwd(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=3, out_channels=20, kernel_size=(5, 5)
                )

            def _conv_forward(self, x):
                # Calls the conv's own _conv_forward with its weight and bias
                # passed explicitly as arguments.
                return self.conv._conv_forward(x, self.conv.weight, self.conv.bias)

            def forward(self, x):
                return self._conv_forward(x)

        mod = ModuleSpecialFwd()
        rx = torch.randn([3, 10, 10])
        real = mod(rx)
        graph, _ = torch._dynamo.export(mod, rx)
        self.assertTrue(torch._dynamo.testing.same(real, graph(rx)))
# Script entry point: run this file's tests through the Dynamo test runner.
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()