Fix resolution callback for @script_method (#8912)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/8715. This was peeking too few frames up when we instantiate the callback
Closes https://github.com/pytorch/pytorch/pull/8912
Reviewed By: ezyang
Differential Revision: D8684972
Pulled By: jamesr66a
fbshipit-source-id: 11dbb919ae7273f92cbe25fe21f7946b9fa28aeb
diff --git a/test/expect/TestScript.test_call_python_fn_from_script_module.expect b/test/expect/TestScript.test_call_python_fn_from_script_module.expect
new file mode 100644
index 0000000..6688c82
--- /dev/null
+++ b/test/expect/TestScript.test_call_python_fn_from_script_module.expect
@@ -0,0 +1,6 @@
+graph(%x : Dynamic
+ %1 : Dynamic) {
+ %2 : Dynamic = aten::mm(%x, %1)
+ %3 : Dynamic = ^python_fn()(%2)
+ return (%3);
+}
diff --git a/test/expect/TestScript.test_call_script_fn_from_script_module.expect b/test/expect/TestScript.test_call_script_fn_from_script_module.expect
new file mode 100644
index 0000000..08ffa4b
--- /dev/null
+++ b/test/expect/TestScript.test_call_script_fn_from_script_module.expect
@@ -0,0 +1,6 @@
+graph(%x : Dynamic
+ %1 : Dynamic) {
+ %2 : Dynamic = aten::mm(%x, %1)
+ %3 : Dynamic = aten::neg(%2)
+ return (%3);
+}
diff --git a/test/expect/TestScript.test_call_tracing_fn_from_script_module.expect b/test/expect/TestScript.test_call_tracing_fn_from_script_module.expect
new file mode 100644
index 0000000..6194232
--- /dev/null
+++ b/test/expect/TestScript.test_call_tracing_fn_from_script_module.expect
@@ -0,0 +1,6 @@
+graph(%x : Dynamic
+ %1 : Dynamic) {
+ %2 : Dynamic = aten::mm(%x, %1)
+ %3 : Double(3, 3) = aten::neg(%2)
+ return (%3);
+}
diff --git a/test/test_jit.py b/test/test_jit.py
index 6516118..fd6d87b 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -3920,7 +3920,6 @@
self.assertExpected(str(script_fn.graph))
- @unittest.skip('TODO: Python value resolution broken')
def test_call_python_fn_from_script_module(self):
def python_fn(x):
return torch.neg(x)
@@ -3935,14 +3934,7 @@
return python_fn(torch.mm(x, self.param))
sm = ScriptMod()
- # TODO: At the time of writing this test fails with:
- # RuntimeError
- # undefined value python_fn:
- # @torch.jit.script_method
- # def forward(self, x):
- # return python_fn(torch.mm(x, self.param))
- # ~~~~~~~~~ <--- HERE
- self.assertExpected(str(sm.graph))
+ self.assertExpected(str(sm.__getattr__('forward').graph))
def test_call_python_mod_from_script_module(self):
class PythonMod(torch.nn.Module):
@@ -3968,7 +3960,6 @@
# are NOT inlined
self.assertExpected(str(sm.graph))
- @unittest.skip('TODO: Python value resolution broken')
def test_call_tracing_fn_from_script_module(self):
@torch.jit.trace(torch.rand(3, 3))
def traced_fn(x):
@@ -3984,14 +3975,7 @@
return traced_fn(torch.mm(x, self.param))
sm = ScriptMod()
- # FIXME: at the time of writing we fail with the following:
- # RuntimeError:
- # undefined value traced_fn:
- # @torch.jit.script_method
- # def forward(self, x):
- # return traced_fn(torch.mm(x, self.param))
- # ~~~~~~~~~ <--- HERE
- self.assertExpected(str(sm.graph))
+ self.assertExpected(str(sm.__getattr__('forward').graph))
def test_call_tracing_mod_from_script_module(self):
class TracedMod(torch.nn.Module):
@@ -4018,7 +4002,6 @@
# inlined
self.assertExpected(str(sm.graph))
- @unittest.skip('TODO: Python value resolution broken')
def test_call_script_fn_from_script_module(self):
@torch.jit.script
def script_fn(x):
@@ -4034,14 +4017,7 @@
return script_fn(torch.mm(x, self.param))
sm = ScriptMod()
- # FIXME: at the time of writing, this failes with
- # RuntimeError:
- # undefined value traced_fn:
- # @torch.jit.script_method
- # def forward(self, x):
- # return traced_fn(torch.mm(x, self.param))
- # ~~~~~~~~~ <--- HERE
- self.assertExpected(str(sm.graph))
+ self.assertExpected(str(sm.__getattr__('forward').graph))
def test_call_script_mod_from_script_module(self):
class ScriptMod1(torch.jit.ScriptModule):
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 937dff5..9c70516 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -384,7 +384,19 @@
def script_method(fn):
- return ScriptMethodStub(createResolutionCallback(frames_up=1), get_jit_ast(fn), fn)
+ # NOTE: we need to traverse two frames here because the meta-class frame
+ # for ScriptModule will be present, as opposed to invoking @script on a
+ # a function or invoking define() on a CompilationUnit.
+ # The stack will look like:
+ #
+ # 0. createResolutionCallback()
+ # 1. script_method()
+ # 2. ScriptModule metaclass frame
+ # 3. Surrounding scope
+ #
+ # createResolutionCallback internally adds 1 to get us to the scope of this
+ # function (the calling function). Adding 2 gets us to the proper surrounding scope.
+ return ScriptMethodStub(createResolutionCallback(frames_up=2), get_jit_ast(fn), fn)
# These OrderedDictWrapper classes replace the actual OrderedDicts in
@@ -610,6 +622,14 @@
return sorted(Module.__dir__(self) + self._method_names())
def define(self, lang):
+ # We use frames_up=1 to get to the proper surrounding scope. The stack
+ # will look like:
+ # 0. createResolutionCallback
+ # 1. define()
+ # 2. surrounding scope.
+ #
+ # createResolutionCallback internally adds 1 to get us to our frame, then
+ # we add 1 to get to the proper surrounding scope.
rcb = createResolutionCallback(frames_up=1)
self._define(lang, rcb, True)