preserve ignored function return value type (#25262)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25262
Preserve the type of ignore'd functions on serialization. Currently we compile an ignore'd function with its annotated type when it is first compiled, but we do not preserve that type across serialization. This is important for being able to compile models that use features the JIT does not yet support:
```
@torch.jit.ignore
def unsupported(x):
    return x

def foo():
    if not torch.jit._is_scripting():
        return torch.linear(...)
    else:
        return unsupported(...)
```
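With the annotated type preserved, a call to a dropped function still typechecks after the model is saved and loaded; it only fails if the dropped path is actually taken at runtime. A condensed sketch of that behavior (mirroring the `test_ignored_as_value` test added below; `Model` and `unsupported` are illustrative names, not part of this PR):
```
import io

import torch

class Model(torch.nn.Module):
    @torch.jit.unused
    def unsupported(self, x):
        # type: (Tensor) -> Tensor
        # Arbitrary non-scriptable Python; only the annotated
        # signature matters to the compiler.
        return x

    def forward(self, x, use_unsupported):
        # type: (Tensor, bool) -> Tensor
        if use_unsupported:
            # Compiles because the return type is known from the
            # annotation, which is now preserved on save.
            return self.unsupported(x)
        return x + 1

scripted = torch.jit.script(Model())
print(scripted(torch.ones(2), False))  # tensor([2., 2.])

buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

print(loaded(torch.ones(2), False))  # still works after a round-trip
loaded(torch.ones(2), True)          # raises: "... annotated to be ignored and cannot be run"
```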
Test Plan: Imported from OSS
Reviewed By: driazati
Differential Revision: D17199043
Pulled By: eellison
fbshipit-source-id: 1196fd94c207b9fbee1087e4b2ef7d4656a6647f
diff --git a/docs/source/jit.rst b/docs/source/jit.rst
index 7d2b7e1..72de09f 100644
--- a/docs/source/jit.rst
+++ b/docs/source/jit.rst
@@ -202,8 +202,8 @@
The :func:`@torch.jit.ignore <torch.jit.ignore>` annotation's behavior changes in
PyTorch 1.2. Before PyTorch 1.2 the @ignore decorator was used to make a function
or method callable from code that is exported. To get this functionality back,
- use ``@torch.jit.ignore(drop_on_export=True)``. ``@torch.jit.ignore`` is now equivalent
- to ``@torch.jit.ignore(drop_on_export=False)``. See :func:`@torch.jit.ignore <torch.jit.ignore>`
+ use ``@torch.jit.ignore(drop=True)``. ``@torch.jit.ignore`` is now equivalent
+ to ``@torch.jit.ignore(drop=False)``. See :func:`@torch.jit.ignore <torch.jit.ignore>`
for details.
When passed to the :func:`torch.jit.script <torch.jit.script>` function, a ``torch.nn.Module``\'s data is
diff --git a/test/test_jit.py b/test/test_jit.py
index e1cb821..47f2528 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -14003,37 +14003,70 @@
out = m(torch.ones(5, 5, 5).cuda())
self.assertTrue(out[0].is_cuda)
-
def test_ignore_decorator(self):
- class M(torch.jit.ScriptModule):
- def __init__(self):
- super(M, self).__init__()
- tensor = torch.zeros(1, requires_grad=False)
- self.register_buffer('some_state', torch.nn.Parameter(tensor))
+ with warnings.catch_warnings(record=True) as warns:
+ class M(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M, self).__init__()
+ tensor = torch.zeros(1, requires_grad=False)
+ self.register_buffer('some_state', torch.nn.Parameter(tensor))
- @torch.jit.script_method
- def forward(self, x):
- self.ignored_code(x)
- return x
+ @torch.jit.script_method
+ def forward(self, x):
+ self.ignored_code(x)
+ return x
- @torch.jit.ignore(drop_on_export=True)
- def ignored_code(self, x):
- self.some_state = torch.tensor((100,))
+ @torch.jit.ignore(drop_on_export=True)
+ def ignored_code(self, x):
+ self.some_state = torch.tensor((100,))
+
+ FileCheck().check("TorchScript will now drop the function call on compilation.").run(str(warns[0]))
# Assert ignored code is run
m = M()
- self.assertEqual(m.some_state, torch.zeros(1))
- m(torch.ones(1))
- self.assertEqual(m.some_state, torch.zeros(1) + 100)
m2 = self.getExportImportCopy(m)
pp = str(m2.forward.code)
- self.assertIn('IgnoredPythonOp', pp)
self.assertNotIn('ignored_code', pp)
- with self.assertRaisesRegex(torch.jit.Error, "This Python function is annotated to be ignored"):
+ with self.assertRaisesRegex(torch.jit.Error, "annotated to be ignored and cannot be run"):
m2.forward(torch.ones(1))
+ def test_ignored_as_value(self):
+ class Model(nn.Module):
+ def __init__(self):
+ super(Model, self).__init__()
+
+ @torch.jit.unused
+ def tuple_ignored(self, x):
+ # type: (Tensor) -> Tuple[Tensor, Tensor]
+ return x, x
+
+ @torch.jit.unused
+ def single_val_ignored(self, x, y):
+ # type: (Tensor, Tensor) -> Tensor
+ return x
+
+ def forward(self, x, use_ignore_path):
+ # type: (Tensor, bool) -> Tuple[Tensor, Tensor]
+ if False:
+ return self.tuple_ignored(x)
+ if use_ignore_path:
+ return self.single_val_ignored(x, x), self.single_val_ignored(x, x)
+ return x, x
+
+ original = Model()
+ scripted = torch.jit.script(original)
+ self.assertEqual(scripted(torch.tensor(.5), False), (torch.tensor(.5), torch.tensor(.5)))
+
+ buffer = io.BytesIO()
+ torch.jit.save(scripted, buffer)
+ buffer.seek(0)
+ loaded = torch.jit.load(buffer)
+
+ with self.assertRaisesRegex(torch._C.JITException, "annotated to be ignored and cannot be run"):
+ loaded(torch.tensor(.5), True)
+
def test_module_error(self):
class MyModule(torch.nn.Module):
def __init__(self):
diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py
index 7327689..80193cf 100644
--- a/torch/_jit_internal.py
+++ b/torch/_jit_internal.py
@@ -6,6 +6,7 @@
import inspect
import weakref
+import warnings
import torch._C
from torch._six import builtins
@@ -168,7 +169,7 @@
Used to denote the behavior of a function in TorchScript. See export() and
ignore() for details.
"""
- IGNORE_AND_DROP = "ignore (leave as a call to Python, replace with a 'raise' on torch.jit.save)"
+ UNUSED = "unused (ignored and replaced with raising of an exception)"
IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)"
EXPORT = "export (compile this function even if nothing calls it)"
DEFAULT = "default (compile if called from a exported function / forward)"
@@ -219,23 +220,52 @@
return fn
-def ignore(drop_on_export=False):
+def unused(fn):
"""
This decorator indicates to the compiler that a function or method should
- be ignored and left as a Python function.
+ be ignored and replaced with the raising of an exception. This allows you
+ to leave code in your model that is not yet TorchScript compatible and still
+ export your model.
- Arguments:
+ Example (using ``@torch.jit.unused`` on a method)::
- drop_on_export (bool): When ``False``, calls to this function will
- that will be run with ``example_inputs``.
- arguments and returns to ``func`` must be tensors
- or (possibly nested) tuples that
- contain tensors. When ``True``, any calls to
- this function from other TorchScript code will be replaced
- with a `raise` when the model is saved.
- This allows you to leave code in your TorchScript model that is only ever
- run when the Python interpreter is present, but not run after you save
- and load your model.
+ import torch
+ import torch.nn as nn
+
+ class MyModule(nn.Module):
+ def __init__(self, use_memory_efficient):
+ super(MyModule, self).__init__()
+ self.use_memory_efficient = use_memory_efficient
+
+ @torch.jit.unused
+ def memory_efficient(self, x):
+ import pdb
+ pdb.set_trace()
+ return x + 10
+
+ def forward(self, x):
+ # Use not-yet-scriptable memory efficient mode
+ if self.use_memory_efficient:
+ return self.memory_efficient(x)
+ else:
+ return x + 10
+
+ m = torch.jit.script(MyModule(use_memory_efficient=False))
+ m.save("m.pt")
+
+ m = torch.jit.script(MyModule(use_memory_efficient=True))
+ # exception raised
+ m(torch.rand(100))
+ """
+ fn._torchscript_modifier = FunctionModifiers.UNUSED
+ return fn
+
+def ignore(drop=False, **kwargs):
+ """
+ This decorator indicates to the compiler that a function or method should
+ be ignored and left as a Python function. This allows you to leave code in
+ your model that is not yet TorchScript compatible. Models with ignored
+ functions cannot be exported; use ``torch.jit.unused`` instead.
Example (using ``@torch.jit.ignore`` on a method)::
@@ -261,7 +291,7 @@
# Error! The call `debugger` cannot be saved since it calls into Python
m.save("m.pt")
- Example (using ``@torch.jit.ignore(drop_on_export=True)`` on a method):
+ Example (using ``@torch.jit.ignore(drop=True)`` on a method):
.. testcode::
@@ -269,7 +299,7 @@
import torch.nn as nn
class MyModule(nn.Module):
- @torch.jit.ignore(drop_on_export=True)
+ @torch.jit.ignore(drop=True)
def training_method(self, x):
import pdb
pdb.set_trace()
@@ -290,24 +320,37 @@
import os
os.remove('m.pt')
"""
- if callable(drop_on_export):
- # used without any args, so drop_on_export is actually a function
+
+ if callable(drop):
+ # used without any args, so drop is actually a function
# @torch.jit.ignore
# def fn(...):
- fn = drop_on_export
+ fn = drop
fn._torchscript_modifier = FunctionModifiers.IGNORE
return fn
- if isinstance(drop_on_export, bool):
- def decorator(fn):
- if drop_on_export:
- fn._torchscript_modifier = FunctionModifiers.IGNORE_AND_DROP
- else:
- fn._torchscript_modifier = FunctionModifiers.IGNORE
- return fn
- return decorator
- raise RuntimeError("Argument to @torch.jit.ignore must be a bool or "
- "a function but got {}".format(drop_on_export))
+ if not isinstance(drop, bool):
+ raise RuntimeError("Argument to @torch.jit.ignore must be a bool or "
+ "a function but got {}".format(drop))
+
+ # for backwards compat
+ drop_on_export = kwargs.pop("drop_on_export", None)
+ if drop_on_export:
+ warnings.warn("ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the drop "
+ "call on compilation. Use torch.jit.unused now. {}", category=DeprecationWarning)
+
+ drop = drop_on_export
+ elif drop:
+ warnings.warn("ignore(True) has been deprecated. TorchScript will now drop the drop "
+ "call on compilation. Use torch.jit.unused now. {}", category=DeprecationWarning)
+
+ def decorator(fn):
+ if drop:
+ fn._torchscript_modifier = FunctionModifiers.UNUSED
+ else:
+ fn._torchscript_modifier = FunctionModifiers.IGNORE
+ return fn
+ return decorator
def module_has_exports(mod):
@@ -318,16 +361,16 @@
return True
return False
-def should_drop_on_export(fn):
+def should_drop(fn):
attr = get_torchscript_modifier(fn)
if attr is None:
return False
- return attr is FunctionModifiers.IGNORE_AND_DROP
+ return attr is FunctionModifiers.UNUSED
def is_ignored_fn(fn):
mod = get_torchscript_modifier(fn)
- return mod is FunctionModifiers.IGNORE_AND_DROP or mod is FunctionModifiers.IGNORE
+ return mod is FunctionModifiers.UNUSED or mod is FunctionModifiers.IGNORE
def get_torchscript_modifier(fn):
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 83f91e4..6a962ee 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -1293,10 +1293,6 @@
struct TORCH_API PythonOp : public Node {
using Node::Node;
- // should this Python function be skipped over when exported (i.e. for
- // debugging functions that only run in Python)
- bool ignore_on_export = false;
-
virtual std::string name() const = 0;
virtual void writeScalars(std::ostream& out) const = 0;
void cloneFrom(Node* other_) override = 0;
diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp
index 6fc71a0..99f2e1b 100644
--- a/torch/csrc/jit/passes/python_print.cpp
+++ b/torch/csrc/jit/passes/python_print.cpp
@@ -950,22 +950,17 @@
switch (node->kind()) {
case prim::PythonOp: {
auto value = static_cast<const PythonOp*>(node);
- if (enforce_importable_ && !value->ignore_on_export) {
+ if (enforce_importable_) {
throw script::ErrorReport(node->sourceRange())
<< "Could not export Python function call '" << value->name()
<< "'. Remove calls to Python functions before export. "
<< "Did you forget add @script or @script_method annotation? "
<< "If this is a nn.ModuleList, add it to __constants__";
}
-
- if (value->ignore_on_export) {
- stmt << "ops.prim.IgnoredPythonOp";
- } else {
- std::stringstream scalars_stream;
- stmt << "^" << value->name();
- value->writeScalars(scalars_stream);
- stmt << scalars_stream.str();
- }
+ std::stringstream scalars_stream;
+ stmt << "^" << value->name();
+ value->writeScalars(scalars_stream);
+ stmt << scalars_stream.str();
printValueList(stmt, node->inputs(), "(", ")");
} break;
case prim::Uninitialized: {
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index 3f20c93..201ec0f 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -139,7 +139,6 @@
this->cconv = other->cconv;
Py_INCREF(other->pyobj.get());
this->pyobj = THPObjectPtr(other->pyobj.get());
- this->ignore_on_export = other->ignore_on_export;
for (auto& sa : other->scalar_args) {
Py_INCREF(sa.get());
this->scalar_args.emplace_back(sa.get());
diff --git a/torch/csrc/jit/script/python_sugared_value.cpp b/torch/csrc/jit/script/python_sugared_value.cpp
index 246497b..fb5bb32 100644
--- a/torch/csrc/jit/script/python_sugared_value.cpp
+++ b/torch/csrc/jit/script/python_sugared_value.cpp
@@ -108,18 +108,28 @@
if (!matched_schema)
throw ErrorReport(loc) << failure_messages.str();
+ // If a function is marked as dropped,
+ // we throw an exception if it is invoked.
+ if (py::cast<bool>(py::module::import("torch._jit_internal")
+ .attr("should_drop")(self))) {
+ auto g = m.graph();
+ auto err_msg = insertConstant(
+ *g,
+ IValue(
+ "This Python function is annotated to be ignored and cannot be run"));
+ g->insert(prim::RaiseException, {err_msg}, {}, loc);
+ return std::make_shared<SimpleValue>(
+ g->insertNode(
+ g->createUninitialized(matched_schema->return_types.at(0)))
+ ->output());
+ }
+
// Release the function object so we can wrap it in a PythonOp
py::object func = self;
std::string cconv(inputs.size(), 'd');
Node* new_node = m.graph()->insertNode(
m.graph()->createPythonOp(THPObjectPtr(func.release().ptr()), cconv, {}));
- // Mark if function is ignored on export
- if (py::cast<bool>(py::module::import("torch._jit_internal")
- .attr("should_drop_on_export")(self))) {
- auto python_op = static_cast<PythonOp*>(new_node);
- python_op->ignore_on_export = true;
- }
new_node->setSourceRange(loc);
for (auto& i : matched_schema->inputs)
new_node->addInput(i);
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 4edef27..f651a5b 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -32,7 +32,7 @@
# These are imported so users can access them from the `torch.jit` module
from torch._jit_internal import Final, _overload, _overload_method # noqa: F401
-from torch._jit_internal import ignore, export # noqa: F401
+from torch._jit_internal import ignore, export, unused # noqa: F401
if sys.version_info[0] > 2:
import pathlib