| # Owner(s): ["module: dynamo"] |
| |
| import types |
| from copy import deepcopy |
| from unittest.mock import patch |
| |
| import torch |
| |
| import torch._dynamo.test_case |
| import torch._dynamo.testing |
| from torch._dynamo.eval_frame import unsupported |
| from torch._dynamo.mutation_guard import GenerationTracker |
| from torch._dynamo.testing import same |
| from torch.nn import functional as F |
| from torch.nn.modules.lazy import LazyModuleMixin |
| from torch.nn.parameter import Parameter, UninitializedParameter |
| |
| try: |
| from . import test_functions |
| except ImportError: |
| import test_functions |
| |
| |
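# The small modules below are self-contained fixtures used by the tests further
# down, mostly via make_test() in NNModuleTests, which compiles each module and
# compares it against eager execution.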
| class BasicModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear1 = torch.nn.Linear(10, 10) |
| self.scale = torch.randn(1, 10) |
| |
| def forward(self, x): |
| return F.relu(self.linear1(x)) * self.scale |
| |
| |
| class FnMember(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear1 = torch.nn.Linear(10, 10) |
| self.activation = F.relu |
| |
| def forward(self, x): |
| x = self.linear1(x) |
| if self.activation: |
| x = self.activation(x) |
| return x |
| |
| |
| class FnMemberCmp(torch.nn.Module): |
| def __init__(self, activation): |
| super().__init__() |
| self.linear1 = torch.nn.Linear(10, 10) |
| self.activation = activation |
| |
| def forward(self, x): |
| x = self.linear1(x) |
| if self.activation is not None: |
| x = self.activation(x) |
| if self.activation is None: |
| x = torch.sigmoid(x) |
| return x |
| |
| |
| class SubmoduleExample(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layer1 = BasicModule() |
| self.layer2 = BasicModule() |
| self.scale = torch.randn(1, 10) |
| |
| def forward(self, x): |
| x = self.layer1(x) |
| x = self.layer2(x) |
| return x * self.scale |
| |
| |
| class IsTrainingCheck(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear1 = torch.nn.Linear(10, 10) |
| self.linear2 = torch.nn.Linear(10, 10) |
| self.train(True) |
| |
| def forward(self, x): |
| if self.training: |
| mod = self.linear1 |
| else: |
| mod = self.linear2 |
| return F.relu(mod(x)) |
| |
| |
| class IsEvalCheck(IsTrainingCheck): |
| def __init__(self): |
| super().__init__() |
| self.train(False) |
| |
| |
| class ModuleMethodCall(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layer1 = BasicModule() |
| self.layer2 = BasicModule() |
| self.scale = torch.randn(1, 10) |
| |
| def call_and_scale(self, mod, x): |
| x = mod(x) |
| return x * self.scale |
| |
| def forward(self, x): |
| x1 = self.call_and_scale(self.layer1, x) |
| x2 = self.call_and_scale(self.layer2, x) |
| return x1 + x2 |
| |
| |
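# `unsupported` is a helper that dynamo does not trace, so calling it inside
# forward forces a graph break; the modules below exercise that fallback path.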
| class UnsupportedMethodCall(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layer1 = BasicModule() |
| self.scale = torch.randn(1, 10) |
| |
| def call_and_scale(self, mod, x): |
| x = mod(x) |
| x = x * self.scale |
| return unsupported(x, x) |
| |
| def forward(self, x): |
| x1 = self.call_and_scale(self.layer1, x) |
| return x + x1 |
| |
| |
| class UnsupportedModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layer1 = BasicModule() |
| self.scale = torch.randn(1, 10) |
| |
| def forward(self, x): |
| x = self.layer1(x) * self.scale |
| return unsupported(x, x) |
| |
| |
| class UnsupportedModuleCall(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mod = UnsupportedModule() |
| |
| def forward(self, x): |
| return 1 + self.mod(x * 1.5) |
| |
| |
| class ModuleWithStaticForward(torch.nn.Module): |
| @staticmethod |
| def forward(x): |
| return x * torch.sigmoid(x) |
| |
| |
| class ModuleCallModuleWithStaticForward(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mod = ModuleWithStaticForward() |
| |
| def forward(self, x): |
| return self.mod(x) |
| |
| |
| class ModuleStaticMethodCall(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layer1 = BasicModule() |
| self.layer2 = BasicModule() |
| self.scale = torch.randn(1, 10) |
| |
| @staticmethod |
| def call_and_scale(scale, mod, x): |
| x = mod(x) |
| return x * scale |
| |
| def forward(self, x): |
| x1 = self.call_and_scale(self.scale, self.layer1, x) |
| x2 = self.call_and_scale(self.scale, self.layer2, x) |
| return x1 + x2 |
| |
| |
| class ModuleClassMethodCall(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layer1 = BasicModule() |
| self.layer2 = BasicModule() |
| self.scale = torch.randn(1, 10) |
| |
| @classmethod |
| def call_and_scale(cls, scale, mod, x): |
| x = mod(x) |
| return x * scale |
| |
| def forward(self, x): |
| x1 = self.call_and_scale(self.scale, self.layer1, x) |
| x2 = self.call_and_scale(self.scale, self.layer2, x) |
| return x1 + x2 |
| |
| |
| class ModuleProperty(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.scale = torch.randn(1, 10) |
| |
| @property |
| def scale_alias(self): |
| return self.scale |
| |
| def forward(self, x): |
| return x * self.scale_alias |
| |
| |
| class ConstLoop(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear1 = torch.nn.Linear(10, 10) |
| self.count = 3 |
| |
| def forward(self, x): |
| for i in range(self.count): |
| x = torch.sigmoid(self.linear1(x)) |
| return x |
| |
| |
| class ViaModuleCall(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear1 = torch.nn.Linear(10, 10) |
| |
| def forward(self, x): |
| return test_functions.constant3(torch.sigmoid(self.linear1(x)), x) |
| |
| |
| class IsNoneLayer(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layer1 = torch.nn.Linear(10, 10) |
| self.layer2 = None |
| self.train(True) |
| |
| def forward(self, x): |
| if self.layer1 is not None: |
| x = self.layer1(x) |
| if self.layer2 is not None: |
| x = self.layer2(x) |
| return x |
| |
| |
| class LayerList(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layers = [ |
| torch.nn.Linear(10, 10), |
| torch.nn.ReLU(), |
| torch.nn.Linear(10, 10), |
| ] |
| |
| def forward(self, x): |
| for layer in self.layers: |
| x = layer(x) |
| return x |
| |
| |
| class ModuleList(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layers = torch.nn.ModuleList( |
| [ |
| torch.nn.Linear(10, 10), |
| torch.nn.ReLU(), |
| torch.nn.Linear(10, 10), |
| torch.nn.ReLU(), |
| ] |
| ) |
| |
| def forward(self, x): |
| for i in range(len(self.layers)): |
| x = self.layers[i](x) |
| |
| for layer in self.layers: |
| x = layer(x) |
| |
| for layer, val in zip(self.layers, (x, x, x, x)): |
| x = layer(x) + val |
| |
| for layer, val in zip(self.layers, (1, 2, 3, 4)): |
| x = layer(x) + val |
| |
| for idx, layer in enumerate(self.layers): |
| x = layer(x) * idx |
| |
| for idx, layer in enumerate(self.layers[::-1]): |
| x = layer(x) * idx |
| |
| return x |
| |
| |
| class ModuleDict(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layers = torch.nn.ModuleDict( |
| { |
| "0": torch.nn.Linear(10, 10), |
| } |
| ) |
| |
| def forward(self, x): |
| # TODO(future PR): handle more logic |
| x = self.layers["0"](x) |
| return x |
| |
| |
| class TensorList(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layers = ( |
| torch.randn((1, 10)), |
| torch.randn((10, 1)), |
| torch.randn((1, 10)), |
| torch.randn((10, 1)), |
| ) |
| |
| def forward(self, x): |
| for layer in self.layers: |
| x = x * layer |
| return x |
| |
| |
| class Children(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.l1 = torch.nn.Linear(10, 10) |
| self.l2 = torch.nn.ReLU() |
| self.l3 = torch.nn.Linear(10, 10) |
| self.l4 = torch.nn.ReLU() |
| |
| def forward(self, x): |
| for block in self.children(): |
| x = block(x) |
| return x |
| |
| |
| class IntArg(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layer1 = torch.nn.Linear(10, 10) |
| |
| def forward(self, x, offset=1): |
| x = F.relu(self.layer1(x)) + offset |
| return x |
| |
| |
| class Seq(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layers = torch.nn.Sequential( |
| torch.nn.Linear(10, 10), |
| torch.nn.ReLU(), |
| torch.nn.Linear(10, 10), |
| torch.nn.ReLU(), |
| ) |
| |
| def forward(self, x): |
| return self.layers(x) |
| |
| |
| class Cfg: |
| def __init__(self): |
| self.val = 0.5 |
| self.count = 3 |
| |
| |
| class CfgModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.cfg = Cfg() |
| self.layer = torch.nn.Linear(10, 10) |
| |
| def forward(self, x): |
| for i in range(self.cfg.count): |
| x = self.layer(x + self.cfg.val) |
| return x |
| |
| |
| class StringMember(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear1 = torch.nn.Linear(10, 10) |
| self.mode = "some_string" |
| |
| def forward(self, x): |
| if self.mode == "some_string": |
| return F.relu(self.linear1(x)) |
| |
| |
| class _Block(torch.nn.Module): |
| def forward(self, x): |
| return 1.5 * torch.cat(x, 1) |
| |
| |
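# A trimmed-down version of torchvision's DenseNet dense block, used to test
# iterating over ModuleDict items inside a compiled forward.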
| class _DenseBlock(torch.nn.ModuleDict): |
| _version = 2 |
| |
| def __init__( |
| self, |
| num_layers: int = 3, |
| ) -> None: |
| super().__init__() |
| for i in range(num_layers): |
| self.add_module("denselayer%d" % (i + 1), _Block()) |
| |
| def forward(self, init_features): |
| features = [init_features] |
| for name, layer in self.items(): |
| new_features = layer(features) |
| features.append(new_features) |
| return torch.cat(features, 1) |
| |
| |
| class DenseNetBlocks(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layers = _DenseBlock() |
| |
| def forward(self, x): |
| return self.layers(x) |
| |
| |
| class MaterializedModule(torch.nn.Module): |
| """Once the below lazy module is initialized with its first input, |
| it is transformed into this module.""" |
| |
| param: Parameter |
| |
| def __init__(self): |
| super().__init__() |
| self.register_parameter("param", None) |
| |
| def forward(self, x): |
| return x |
| |
| |
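# LazyModuleMixin defers parameter initialization until the first forward call;
# once initialized, the instance's class is swapped to `cls_to_become`.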
| class LazyModule(LazyModuleMixin, MaterializedModule): |
| param: UninitializedParameter |
| cls_to_become = MaterializedModule |
| |
| def __init__(self): |
| super().__init__() |
| self.param = UninitializedParameter() |
| |
| def initialize_parameters(self, x): |
| self.param.materialize(x.shape) |
| |
| |
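# The two helpers below differ only in passing a list comprehension vs. a
# generator expression to any(), exercising different bytecode paths in dynamo.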
| def requires_grad1(module: torch.nn.Module, recurse: bool = False) -> bool: |
| requires_grad = any([p.requires_grad for p in module.parameters(recurse)]) |
| return requires_grad |
| |
| |
| def requires_grad2(module: torch.nn.Module, recurse: bool = False) -> bool: |
| requires_grad = any(p.requires_grad for p in module.parameters(recurse)) |
| return requires_grad |
| |
| |
| class ParametersModule1(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear1 = torch.nn.Linear(10, 10) |
| self.scale = torch.nn.Parameter(torch.randn(1, 10)) |
| |
| def forward(self, x): |
| if not requires_grad1(self): |
| return F.relu(self.linear1(x)) * self.scale |
| else: |
| return x + 1 |
| |
| |
| class ParametersModule2(ParametersModule1): |
| def forward(self, x): |
| if not requires_grad2(self): |
| return F.relu(self.linear1(x)) * self.scale |
| else: |
| return x + 1 |
| |
| |
| class ParametersModule3(ParametersModule1): |
| def forward(self, x): |
| ones = torch.ones(10, dtype=next(self.parameters()).dtype) |
| return F.relu(self.linear1(x)) * self.scale + ones |
| |
| |
| class SuperModule(BasicModule): |
| def forward(self, x): |
| x = super().forward(x) |
| return x + 10.0 |
| |
| |
| class SuperModule2(BasicModule): |
| def forward(self, x): |
| return BasicModule.forward(self, x) |
| |
| |
| class ComplicatedSuperParent(torch.nn.Module): |
| @classmethod |
| def custom_add(cls, x): |
| x = x + x |
| return x |
| |
| |
| class SuperChildCallsClassMethod(ComplicatedSuperParent): |
| @classmethod |
| def child_func(cls, x): |
| x = super().custom_add(x) |
| return x |
| |
| def forward(self, x): |
| x = self.child_func(x) |
| return x |
| |
| |
| class HasAttrModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.scale = torch.nn.Parameter(torch.randn(1, 10)) |
| |
| def forward(self, x): |
| x = F.relu(x) |
| if hasattr(self, "scale"): |
| x *= self.scale |
| if hasattr(self, "scale2"): |
| x *= self.scale2 |
| return x |
| |
| |
| class EnumValues(torch.nn.ModuleDict): |
| def __init__( |
| self, |
| num_layers: int = 3, |
| ) -> None: |
| super().__init__() |
| for i in range(num_layers): |
| self.add_module("denselayer%d" % (i + 1), _Block()) |
| |
| def forward(self, init_features): |
| features = [init_features] |
| for idx, layer in enumerate(self.values()): |
| new_features = layer(features) |
| features.append(new_features) |
| return torch.cat(features, 1) |
| |
| |
| class AccessByKeys(torch.nn.ModuleDict): |
| def __init__( |
| self, |
| num_layers: int = 3, |
| ) -> None: |
| super().__init__() |
| for i in range(num_layers): |
| self.add_module("denselayer%d" % (i + 1), _Block()) |
| |
| def forward(self, init_features): |
| features = [init_features] |
| for k in self.keys(): |
| new_features = self[k](features) |
| features.append(new_features) |
| return torch.cat(features, 1) |
| |
| |
| class CallForwardDirectly(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layer1 = BasicModule() |
| self.layer2 = torch.nn.Linear(10, 10) |
| |
| def forward(self, x): |
| x = self.layer1.forward(x) |
| x = self.layer2.forward(x) |
| return x |
| |
| |
| class ModuleNameString(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear1 = torch.nn.Linear(10, 10) |
| |
| def forward(self, x): |
| if self.__class__.__name__ == "ABC": |
| return 10 |
| if self.linear1.__class__.__name__ == "Linear": |
| return F.relu(self.linear1(x) + 10) |
| return 11 |
| |
| |
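# Mutates an attribute on every call, so each invocation sees a different
# self.counter value and the compiled code must be re-specialized.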
| class SelfMutatingModule(torch.nn.Module): |
| def __init__(self, layer): |
| super().__init__() |
| self.layer = layer |
| self.counter = 0 |
| |
| def forward(self, x): |
| result = self.layer(x) + self.counter |
| self.counter += 1 |
| return F.relu(result) |
| |
| |
| class ModuleAttributePrecedenceBase(torch.nn.Module): |
| def linear(self, x): |
| return x * 2.0 |
| |
| |
| class ModuleAttributePrecedence(ModuleAttributePrecedenceBase): |
| def __init__(self): |
| super().__init__() |
| self.activation = torch.nn.ReLU() |
| self.linear = torch.nn.Linear(10, 10) |
| self.initializer = torch.ones([10, 10]) |
| self.scale = 0.5 |
| |
| def activation(self, x): |
| return x * 1.2 |
| |
| def initializer(self): |
| return torch.zeros([10, 10]) |
| |
| def scale(self): |
| return 2.0 |
| |
| def forward(self, x): |
        # the instance attribute takes precedence unless it is an nn.Module,
        # in which case the method defined on the class is used
| return self.activation(self.linear(self.initializer + x)) * self.scale |
| |
| |
| class ModuleForwardHasGraphBreak(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.layer1 = BasicModule() |
| self.layer2 = BasicModule() |
| self.layer3 = torch.nn.Sequential(BasicModule(), BasicModule()) |
| self.layer4 = torch.nn.ModuleList( |
| [ |
| torch.nn.Linear(10, 10), |
| torch.nn.ReLU(), |
| torch.nn.Linear(10, 10), |
| torch.nn.ReLU(), |
| ] |
| ) |
| self.layer5 = torch.nn.ModuleDict( |
| { |
| "0": torch.nn.Linear(10, 10), |
| } |
| ) |
| self.scale = torch.randn(1, 10) |
| |
| def forward(self, x): |
| """ |
        This is used to test whether the results of functions like `named_parameters`
        can be reconstructed correctly after a graph break.
| |
| https://github.com/pytorch/torchdynamo/issues/1931 |
| """ |
| x = self.layer1(x) |
| params1 = dict(self.named_parameters()) |
| params2 = list(self.parameters()) |
| buffers1 = dict(self.named_buffers()) |
| buffers2 = list(self.buffers()) |
| modules1 = dict(self.named_modules()) |
| modules2 = list(self.modules()) |
| torch._dynamo.graph_break() |
| y = modules2 |
| y = modules1 |
| y = buffers2 |
| y = buffers1 |
| y = params2 |
| y = params1 |
| x = ( |
| self.layer2(x) |
| + y["layer3.1.linear1.weight"] |
| + y["layer4.2.weight"] |
| + y["layer5.0.weight"] |
| ) |
| return x * self.scale |
| |
| |
| class ModulePatch1(torch.nn.Module): |
| pass |
| |
| |
| class ModulePatch2(torch.nn.Module): |
| def forward(self, x): |
| return x - 1 |
| |
| |
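# make_test() wraps a module instance into an NNModuleTests method: the module
# is put in eval mode, compiled, and compared against eager execution via
# torch._dynamo.testing.standard_test (optionally checking the op count).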
| def make_test(fn, expected_ops=None): |
| def test_fn(self): |
| return torch._dynamo.testing.standard_test( |
| self, fn=fn, nargs=1, expected_ops=expected_ops |
| ) |
| |
| fn.eval() |
| return test_fn |
| |
| |
| class NNModuleTests(torch._dynamo.test_case.TestCase): |
| test_seq = make_test(Seq()) |
| test_basicmodule1 = make_test(BasicModule()) |
| test_basicmodule2 = make_test(BasicModule()) |
| test_submodules1 = make_test(SubmoduleExample()) |
| test_submodules2 = make_test(SubmoduleExample()) |
| test_modulemethod1 = make_test(ModuleMethodCall()) |
| test_modulemethod2 = make_test(ModuleMethodCall()) |
| test_module_call_module_with_static_forward = make_test( |
| ModuleCallModuleWithStaticForward() |
| ) |
| test_module_static_method = make_test(ModuleStaticMethodCall()) |
| test_fnmember = make_test(FnMember()) |
| test_fnmembercmp1 = make_test(FnMemberCmp(F.relu)) |
| test_fnmembercmp2 = make_test(FnMemberCmp(None)) |
| test_constloop = make_test(ConstLoop()) |
| test_istraining1 = make_test(IsTrainingCheck()) |
| test_istraining2 = make_test(IsTrainingCheck()) |
| test_iseval1 = make_test(IsEvalCheck()) |
| test_iseval2 = make_test(IsEvalCheck()) |
| test_viamodulecall = make_test(ViaModuleCall()) |
| test_isnonelayer = make_test(IsNoneLayer()) |
| test_layerlist = make_test(LayerList()) |
| test_tensorlist = make_test(TensorList()) |
| test_intarg = make_test(IntArg()) |
| test_cfgmod = make_test(CfgModule()) |
| test_stringmember = make_test(StringMember()) |
| test_modulelist = make_test(ModuleList()) |
| test_moduledict = make_test(ModuleDict()) |
| test_super1 = make_test(SuperModule()) |
| test_super2 = make_test(SuperModule2()) |
| test_super_class_method = make_test(SuperChildCallsClassMethod()) |
| test_children = make_test(Children()) |
| test_densenet = make_test(DenseNetBlocks()) |
| test_parameters1 = make_test(ParametersModule1()) |
| test_parameters2 = make_test(ParametersModule2()) |
| test_parameters3 = make_test(ParametersModule3(), expected_ops=5) |
| test_hasattr = make_test(HasAttrModule()) |
| test_enumvalues = make_test(EnumValues()) |
| test_access_by_keys = make_test(AccessByKeys()) |
| test_module_class_method = make_test(ModuleClassMethodCall()) |
| test_module_property = make_test(ModuleProperty()) |
| test_forward_directly = make_test(CallForwardDirectly()) |
| test_module_name_string = make_test(ModuleNameString()) |
| test_module_attribute_precedence = make_test(ModuleAttributePrecedence()) |
| |
| def test_module_forward_has_graph_break(self): |
| m = ModuleForwardHasGraphBreak() |
| x = torch.rand([10, 10]) |
| ref = m(x) |
| opt_m = torch._dynamo.optimize("eager")(m) |
| res = opt_m(x) |
| self.assertTrue(torch.allclose(ref, res)) |
| |
| def test_unsupportedmethod(self): |
| m = UnsupportedMethodCall() |
| i = torch.randn(10) |
| cnt = torch._dynamo.testing.CompileCounter() |
| opt_m = torch._dynamo.optimize(cnt)(m) |
| r = opt_m(i) |
| self.assertTrue(torch._dynamo.testing.same(r, m(i))) |
| self.assertEqual(cnt.op_count, 5) |
| |
| def test_unsupportedmodule(self): |
| m = UnsupportedModuleCall() |
| i = torch.randn(10) |
| cnt = torch._dynamo.testing.CompileCounter() |
| opt_m = torch._dynamo.optimize(cnt)(m) |
| r = opt_m(i) |
| self.assertTrue(torch._dynamo.testing.same(r, m(i))) |
| self.assertEqual(cnt.op_count, 6) |
| |
| def test_self_mutating1(self): |
| m1 = torch.nn.Linear(10, 10) |
| m2 = SelfMutatingModule(m1) |
| m3 = SelfMutatingModule(m1) |
| m4 = SelfMutatingModule(m1) |
| i = torch.randn(10) |
| out2 = [m2(i), m2(i), m2(i)] |
| cnt = torch._dynamo.testing.CompileCounter() |
| opt_m3 = torch._dynamo.optimize_assert(cnt)(m3) |
| opt_m4 = torch._dynamo.optimize_assert(cnt)(m4) |
| out3 = [opt_m3(i), opt_m3(i), opt_m3(i)] |
| out4 = [opt_m4(i), opt_m4(i), opt_m4(i)] |
| self.assertTrue(torch._dynamo.testing.same(out2, out3)) |
| self.assertTrue(torch._dynamo.testing.same(out2, out4)) |
| self.assertEqual(cnt.frame_count, 3) |
| |
| @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False) |
| def test_generation_tag(self): |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| # guarantee that we have installed |
| # the generation tagging function |
| with torch._dynamo.optimize_assert(cnt): |
| pass |
| |
| m1 = torch.nn.Linear(10, 10) |
| prev_generation = GenerationTracker.get_generation_value(m1) |
| cur_generation = prev_generation + 1 |
| |
| with torch._dynamo.optimize_assert(cnt): |
| m2 = torch.nn.Linear(10, 10) |
| |
| self.assertEqual(GenerationTracker.get_generation_value(m1), prev_generation) |
| self.assertEqual(GenerationTracker.get_generation_value(m2), cur_generation) |
| # check that newly constructed instances |
| # also have the same generation (even if copied from an old instance) |
| m3 = deepcopy(m1) |
| self.assertEqual(GenerationTracker.get_generation_value(m3), cur_generation) |
| |
| def test_simple_torch_function(self): |
| def foo(x): |
| # function call, twice to test wrapping |
| x = F.sigmoid(x) |
| x = F.sigmoid(x) |
| # method call, twice to test wrapping |
| x = x.sigmoid() |
| x = x.sigmoid() |
| return x |
| |
| class TensorProxy(torch.Tensor): |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| return super().__torch_function__(func, types, args, kwargs) |
| |
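        # Registering the subclass lets dynamo trace through its __torch_function__
        # instead of falling back to eager on the unknown tensor type.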
| torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy) |
| |
        try:
            x = torch.randn(1).as_subclass(TensorProxy)
| cnt = torch._dynamo.testing.CompileCounter() |
| out1 = foo(x) |
| opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo) |
| out2 = opt_foo(x) |
| |
| self.assertEqual(cnt.op_count, 4) |
| self.assertTrue(torch._dynamo.testing.same(out1, out2)) |
| |
| finally: |
| torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy) |
| |
| def test_torch_function_with_closure(self): |
        def run():
            counter = 0
| |
| def foo(x): |
| # function call, twice to test wrapping |
| x = F.sigmoid(x) |
| x = F.sigmoid(x) |
| # method call, twice to test wrapping |
| x = x.sigmoid() |
| x = x.sigmoid() |
| return x |
| |
| class TensorProxy(torch.Tensor): |
| @classmethod |
| def __torch_function__(cls, func, types, args=(), kwargs=None): |
| nonlocal counter |
| # for now, only support reads from closure cells |
| # TODO(future PR): support writes as well |
| counter + 1 |
| return super().__torch_function__(func, types, args, kwargs) |
| |
| torch._dynamo.config.traceable_tensor_subclasses.add(TensorProxy) |
| |
| try: |
| x = torch.randn(1).as_subclass(TensorProxy) |
| x = torch.randn(1) |
| cnt = torch._dynamo.testing.CompileCounter() |
| out1 = foo(x) |
| opt_foo = torch._dynamo.optimize(cnt, nopython=True)(foo) |
| out2 = opt_foo(x) |
| |
| self.assertEqual(cnt.op_count, 4) |
| self.assertTrue(torch._dynamo.testing.same(out1, out2)) |
| finally: |
| torch._dynamo.config.traceable_tensor_subclasses.remove(TensorProxy) |
| |
| run() |
| |
| @patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False) |
| def test_nn_moduledict_contains(self): |
| class M(torch.nn.Module): |
| def __init__(self, module_dict): |
| super().__init__() |
| self.module_dict = module_dict |
| |
| def forward(self, x): |
| if "foo" in self.module_dict: |
| x = torch.mul(x, 1.0) |
| x = torch.add(x, 1.0) |
| return x |
| |
| module_dict = torch.nn.ModuleDict({"foo": torch.nn.Conv2d(1, 1, 1)}) |
| m = M(module_dict) |
| data = torch.randn(1) |
| out1 = m(data) |
| cnt = torch._dynamo.testing.CompileCounter() |
| opt_m = torch._dynamo.optimize(cnt, nopython=True)(m) |
| out2 = opt_m(data) |
| self.assertEqual(cnt.op_count, 2) |
| self.assertTrue(torch._dynamo.testing.same(out1, out2)) |
| |
| module_dict = torch.nn.ModuleDict({"bar": torch.nn.Conv2d(1, 1, 1)}) |
| m = M(module_dict) |
| data = torch.randn(1) |
| out1 = m(data) |
| cnt = torch._dynamo.testing.CompileCounter() |
| torch._dynamo.reset() |
| opt_m = torch._dynamo.optimize(cnt, nopython=True)(m) |
| out2 = opt_m(data) |
| |
| self.assertEqual(cnt.op_count, 1) |
| self.assertTrue(torch._dynamo.testing.same(out1, out2)) |
| |
| module_dict = torch.nn.ModuleDict({"cat": torch.nn.Conv2d(1, 1, 1)}) |
| pre = m(data) |
| cnt.clear() |
| |
| with torch._dynamo.optimize(cnt, nopython=False): |
| opt_pre = m(data) |
| m = M(module_dict) |
| data = torch.randn(1) |
| out1 = m(data) |
| |
| out_post = m(data) |
| self.assertEqual(cnt.frame_count, 1) |
| self.assertEqual(cnt.op_count, 1) |
| self.assertTrue(torch._dynamo.testing.same(pre, opt_pre)) |
| self.assertTrue(torch._dynamo.testing.same(out1, out_post)) |
| |
| def test_lazy_module(self): |
| input_shape = (16, 3, 6, 7, 8) |
| |
| cnt = torch._dynamo.testing.CompileCounter() |
| module = LazyModule() |
| |
| def test_static_module(): |
| input = torch.ones(*input_shape) |
| module(input) |
| |
| opt_test_static_module = torch._dynamo.optimize(cnt)(test_static_module) |
| opt_test_static_module() |
| |
| self.assertTrue( |
| isinstance(module, MaterializedModule), |
| "Module should be transformed to an instance of MaterializedModule.", |
| ) |
| self.assertEqual(module.param.shape, input_shape) |
| |
| # test when mapped to UnspecializedNNModule |
| module = LazyModule() |
| |
| def test_unspecialized(): |
| nonlocal module |
| module = LazyModule() |
| input = torch.ones(*input_shape) |
| module(input) |
| |
| opt_test_unspecialized = torch._dynamo.optimize(cnt)(test_unspecialized) |
| opt_test_unspecialized() |
| |
| self.assertTrue( |
| isinstance(module, MaterializedModule), |
| "Module should be transformed to an instance of MaterializedModule.", |
| ) |
| self.assertEqual(module.param.shape, input_shape) |
| |
| # test with a static module in torch.* |
| module = torch.nn.modules.LazyBatchNorm3d( |
| affine=False, track_running_stats=False |
| ) |
| |
| cnt = torch._dynamo.testing.CompileCounter() |
| |
| torch._dynamo.reset() |
| |
| def test_torch_static(): |
| input = torch.ones(*input_shape) |
| return module(input) # fully materialized |
| |
| opt_test_torch_static = torch._dynamo.optimize(cnt)(test_torch_static) |
| opt_test_torch_static() |
| out = opt_test_torch_static() |
| |
| self.assertTrue(same(out, module(torch.ones(*input_shape)))) |
| |
| self.assertTrue( |
| isinstance(module, torch.nn.modules.batchnorm.BatchNorm3d), |
| "Module should be transformed to an instance of BatchNorm3d.", |
| ) |
        self.assertEqual(cnt.frame_count, 1, "No recompilation should have occurred.")
| |
| def test_call_fn_with_non_const_inputs_safe(self): |
| class ModuleSpecialFwd(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv = torch.nn.Conv2d( |
| in_channels=3, out_channels=20, kernel_size=(5, 5) |
| ) |
| |
| def _conv_forward(self, x): |
| return self.conv._conv_forward(x, self.conv.weight, self.conv.bias) |
| |
| def forward(self, x): |
| return self._conv_forward(x) |
| |
| mod = ModuleSpecialFwd() |
| rx = torch.randn([3, 10, 10]) |
| real = mod(rx) |
| graph, _ = torch._dynamo.export(mod, rx) |
| self.assertTrue(torch._dynamo.testing.same(real, graph(rx))) |
| |
| |
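# Module fixture shared by the OptimizedModuleTest cases below.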
| class MockModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.relu = torch.nn.ReLU() |
| self.linear = torch.nn.Linear(10, 10) |
| self.register_buffer("buf0", torch.randn(10, 10)) |
| |
| def forward(self, x): |
| return self.relu(self.linear(x) + self.buf0) |
| |
| |
| class OptimizedModuleTest(torch._dynamo.test_case.TestCase): |
| def test_nn_module(self): |
| mod = MockModule() |
| cnt = torch._dynamo.testing.CompileCounter() |
| opt_mod = torch._dynamo.optimize(cnt)(mod) |
| self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) |
| |
| x = torch.randn(10, 10) |
| self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) |
| self.assertEqual(cnt.frame_count, 1) |
| |
| def test_to(self): |
| mod = MockModule() |
| cnt = torch._dynamo.testing.CompileCounter() |
| opt_mod = torch._dynamo.optimize(cnt)(mod) |
| x = torch.randn(10, 10) |
| self.assertTrue(torch._dynamo.testing.same(mod(x), opt_mod(x))) |
| self.assertEqual(cnt.frame_count, 1) |
| |
| # Ensure that there is no recompilation |
| opt_mod(x) |
| self.assertEqual(cnt.frame_count, 1) |
| |
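        # Moving to float64 changes parameter/buffer dtypes, so existing guards
        # fail and the next call should recompile.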
| opt_mod = opt_mod.to(device="cpu").to(dtype=torch.float64) |
| self.assertIsInstance(opt_mod, torch._dynamo.OptimizedModule) |
| x = torch.randn(10, 10).to(dtype=torch.float64) |
| opt_mod(x) |
| # Ensure that there is a recompilation |
| self.assertEqual(cnt.frame_count, 2) |
| |
| # Ensure that there is no recompilation |
| opt_mod(x) |
| self.assertEqual(cnt.frame_count, 2) |
| |
| torch._dynamo.reset() |
| opt_mod(x) |
| self.assertEqual(cnt.frame_count, 3) |
| |
| def test_attr(self): |
| class MockModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = torch.nn.Linear(10, 10) |
| self.register_buffer("buf0", torch.randn(10, 10)) |
| |
| def forward(self, x): |
                return self.linear(torch.sin(x)) + self.buf0
| |
| mod = MockModule() |
| opt_mod = torch._dynamo.optimize("eager")(mod) |
| |
        # Check parameters and buffers
| for (p1, p2) in zip(mod.parameters(), opt_mod.parameters()): |
| self.assertTrue(id(p1) == id(p2)) |
| |
| def test_recursion(self): |
| mod = MockModule() |
| cnt = torch._dynamo.testing.CompileCounter() |
| opt_mod = torch._dynamo.optimize(cnt)(mod) |
| |
| for _ in range(5): |
| opt_mod = torch._dynamo.optimize(cnt)(opt_mod) |
| opt_mod(torch.randn(10, 10)) |
| self.assertEqual(cnt.frame_count, 1) |
| |
| def test_composition(self): |
| class InnerModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| return self.relu(torch.sin(x)) |
| |
| opt_inner_mod = InnerModule() |
| |
| class OuterModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mod = opt_inner_mod |
| |
| def forward(self, x): |
| return self.mod(torch.cos(x)) |
| |
| outer_mod = OuterModule() |
| cnt = torch._dynamo.testing.CompileCounter() |
| opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) |
| |
| x = torch.randn(4) |
| self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) |
| self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) |
| self.assertEqual(cnt.frame_count, 1) |
| |
| def test_composition_with_opt_mod(self): |
| class InnerModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| return self.relu(torch.sin(x)) |
| |
| inner_mod = InnerModule() |
| cnt = torch._dynamo.testing.CompileCounter() |
| opt_inner_mod = torch._dynamo.optimize(cnt)(inner_mod) |
| |
| class OuterModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.mod = opt_inner_mod |
| |
| def forward(self, x): |
| return self.mod(torch.cos(x)) |
| |
| outer_mod = OuterModule() |
| opt_outer_mod = torch._dynamo.optimize(cnt)(outer_mod) |
| |
| x = torch.randn(4) |
| self.assertIsInstance(opt_outer_mod, torch._dynamo.OptimizedModule) |
| self.assertTrue(torch._dynamo.testing.same(outer_mod(x), opt_outer_mod(x))) |
        # There will be a graph break because the inner mod is an OptimizedModule
| self.assertEqual(cnt.frame_count, 2) |
| |
| def test_module_patch(self): |
| mod = ModulePatch1() |
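        # Bind ModulePatch2.forward to the instance; dynamo should call the
        # patched method, not a forward defined on ModulePatch1 (there is none).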
| mod.forward = types.MethodType(ModulePatch2.forward, mod) |
| |
| def fn(x): |
| return mod(x) |
| |
| self.assertTrue( |
| torch.allclose( |
| torch._dynamo.optimize("eager", nopython=True)(fn)(torch.ones(10)), |
| torch.zeros(1), |
| ) |
| ) |
| |
| |
| if __name__ == "__main__": |
| from torch._dynamo.test_case import run_tests |
| |
| run_tests() |