[jit] Allow instance overrides of ignored methods (#61076)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61076
Previously we would always retrieve ignored methods from the
type, which doesn't work when the user has overriden the ignored method
for a specific instance.
This PR changes things up so we retrieve the ignored method as a bound
method from the object being scripted, unwrap it, then re-bind it to the
scriptmodule.
Test Plan: Imported from OSS
Differential Revision: D29504421
Pulled By: suo
fbshipit-source-id: 14649863ea69a8d2180dd2c4341ec9a826039de1
diff --git a/test/jit/test_recursive_script.py b/test/jit/test_recursive_script.py
index 7baf955..0f04b0b 100644
--- a/test/jit/test_recursive_script.py
+++ b/test/jit/test_recursive_script.py
@@ -1,5 +1,6 @@
import os
import sys
+import types
import typing
import typing_extensions
from typing import List, Dict, Optional, Tuple
@@ -729,3 +730,23 @@
self.checkModule(mod, (torch.rand(2, 2),))
mod.foo = None
self.checkModule(mod, (torch.rand(2, 2),))
+
+ def test_override_instance_method_ignore(self):
+ class M(torch.nn.Module):
+ @torch.jit.ignore
+ def i_am_ignored(self):
+ return "old"
+
+ m = M()
+
+ # Override the ignored method by binding a new method to this instance.
+ @torch.jit.ignore
+ def i_am_ignored(self):
+ return "new"
+
+ m.i_am_ignored = types.MethodType(i_am_ignored, m)
+ self.assertEqual(m.i_am_ignored(), "new")
+
+ # ScriptModule should correctly reflect the override.
+ s = torch.jit.script(m)
+ self.assertEqual(s.i_am_ignored(), "new")
diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py
index d88a985..2df0fbf 100644
--- a/torch/jit/_recursive.py
+++ b/torch/jit/_recursive.py
@@ -499,7 +499,7 @@
continue
item = getattr(nn_module, name, None)
if inspect.ismethod(item) and _jit_internal.is_ignored_fn(item):
- unbound_function = getattr(type(nn_module), name)
+ unbound_function = getattr(nn_module, name).__func__
bound_method = unbound_function.__get__(script_module)
setattr(script_module, name, bound_method)
elif concrete_type.is_ignored_attribute(name):