[dynamo] Remove incorrect sources (#112961)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112961
Approved by: https://github.com/voznesenskym, https://github.com/Skylion007
ghstack dependencies: #111306, #111415, #111725, #111726, #112962
diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py
index eb48354..a2a2613 100644
--- a/torch/_dynamo/variables/misc.py
+++ b/torch/_dynamo/variables/misc.py
@@ -422,7 +422,10 @@
fwd_bwd_tracer=None,
).call_function(tx, args, kwargs)
- source = AttrSource(AttrSource(self.source, "__class__"), "forward")
+ if self.source:
+ source = AttrSource(AttrSource(self.source, "__class__"), "forward")
+ else:
+ source = None
fn = self.fn_cls.forward
if isinstance(fn, types.FunctionType):
return variables.UserFunctionVariable(fn, source=source).call_function(
@@ -440,8 +443,7 @@
)
def call_function(self, tx, args, kwargs):
- # TODO(jansel): BUG! the source here seems wrong, I believe it should just be None
- return AutogradFunctionVariable(self.fn_cls, source=self.source)
+ return AutogradFunctionVariable(self.fn_cls)
def call_method(
self,
diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py
index 52011de..68c6de8 100644
--- a/torch/_dynamo/variables/user_defined.py
+++ b/torch/_dynamo/variables/user_defined.py
@@ -382,8 +382,7 @@
}
partial_kwargs.update(kwargs)
if is_utils_checkpoint(self.value.func):
- # TODO(jansel): BUG? passing self.source here is a bit suss, expect None to be better
- return build_checkpoint_variable(source=self.source).call_function(
+ return build_checkpoint_variable().call_function(
tx, partial_args, partial_kwargs
)
return variables.TorchVariable(self.value.func).call_function(