commit | fedae41c57bade067bd875c6ba9e4bc90812c016 | [log] [tgz] |
---|---|---|
author | Animesh Jain <anijain@umich.edu> | Mon Jul 15 16:27:32 2024 -0700 |
committer | PyTorch MergeBot <pytorchmergebot@users.noreply.github.com> | Tue Jul 16 06:55:46 2024 +0000 |
tree | 43dbf37a69a55f77f2f567a1d01a19b144f51256 | |
parent | 83eedf66b9e7f52323d9f45c5dfaa64472452595 [diff] |
[dynamo] Do not mark nn.module containers as BuiltinNNModuleVariable (#130773) Pull Request resolved: https://github.com/pytorch/pytorch/pull/130773 Approved by: https://github.com/williamwen42, https://github.com/mlazos
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index 9abd236..c11dfa0 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py
@@ -1278,7 +1278,9 @@ # this will get cleaned up once compile ends self.tx.output.nn_modules[self.name] = value - if value.__module__.startswith(("torch.nn.", "torch.ao.")): + if value.__module__.startswith( + ("torch.nn.", "torch.ao.") + ) and not value.__module__.startswith("torch.nn.modules.container"): result = UnspecializedBuiltinNNModuleVariable(value, source=self.source) else: result = UnspecializedNNModuleVariable(value, source=self.source)