Revert D17199043: [JIT] preserve ignored function return value type
Test Plan: revert-hammer
Differential Revision: D17199043
Original commit changeset: 1196fd94c207
fbshipit-source-id: 49789ae1f128262bc40a9d5b0d2b7bfbbf0b7e1e
diff --git a/docs/source/jit.rst b/docs/source/jit.rst
index 72de09f..7d2b7e1 100644
--- a/docs/source/jit.rst
+++ b/docs/source/jit.rst
@@ -202,8 +202,8 @@
The :func:`@torch.jit.ignore <torch.jit.ignore>` annotation's behavior changes in
PyTorch 1.2. Before PyTorch 1.2 the @ignore decorator was used to make a function
or method callable from code that is exported. To get this functionality back,
- use ``@torch.jit.ignore(drop=True)``. ``@torch.jit.ignore`` is now equivalent
- to ``@torch.jit.ignore(drop=False)``. See :func:`@torch.jit.ignore <torch.jit.ignore>`
+ use ``@torch.jit.ignore(drop_on_export=True)``. ``@torch.jit.ignore`` is now equivalent
+ to ``@torch.jit.ignore(drop_on_export=False)``. See :func:`@torch.jit.ignore <torch.jit.ignore>`
for details.
When passed to the :func:`torch.jit.script <torch.jit.script>` function, a ``torch.nn.Module``\'s data is
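For context on the doc change above, a minimal sketch of the ``drop_on_export`` spelling this revert restores (the module and method names here are illustrative assumptions, not from the patch):

    import torch
    import torch.nn as nn

    class MyModule(nn.Module):
        # Restored keyword; immediately before this revert the same effect
        # was spelled @torch.jit.ignore(drop=True).
        @torch.jit.ignore(drop_on_export=True)
        def debug_hook(self, x):
            # Only ever runs under the Python interpreter; the call is
            # replaced with a `raise` when the model is saved.
            print("input shape:", x.shape)
            return x

        def forward(self, x):
            self.debug_hook(x)
            return x + 1

    scripted = torch.jit.script(MyModule())
    scripted(torch.ones(2))  # runs, including the Python-only hook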
diff --git a/test/test_jit.py b/test/test_jit.py
index 13e47dc..325e871 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -14007,70 +14007,37 @@
out = m(torch.ones(5, 5, 5).cuda())
self.assertTrue(out[0].is_cuda)
+
def test_ignore_decorator(self):
- with warnings.catch_warnings(record=True) as warns:
- class M(torch.jit.ScriptModule):
- def __init__(self):
- super(M, self).__init__()
- tensor = torch.zeros(1, requires_grad=False)
- self.register_buffer('some_state', torch.nn.Parameter(tensor))
+ class M(torch.jit.ScriptModule):
+ def __init__(self):
+ super(M, self).__init__()
+ tensor = torch.zeros(1, requires_grad=False)
+ self.register_buffer('some_state', torch.nn.Parameter(tensor))
- @torch.jit.script_method
- def forward(self, x):
- self.ignored_code(x)
- return x
+ @torch.jit.script_method
+ def forward(self, x):
+ self.ignored_code(x)
+ return x
- @torch.jit.ignore(drop_on_export=True)
- def ignored_code(self, x):
- self.some_state = torch.tensor((100,))
-
- FileCheck().check("TorchScript will now drop the drop call on compilation.").run(str(warns[0]))
+ @torch.jit.ignore(drop_on_export=True)
+ def ignored_code(self, x):
+ self.some_state = torch.tensor((100,))
# Assert ignored code is run
m = M()
+ self.assertEqual(m.some_state, torch.zeros(1))
+ m(torch.ones(1))
+ self.assertEqual(m.some_state, torch.zeros(1) + 100)
m2 = self.getExportImportCopy(m)
pp = str(m2.forward.code)
+ self.assertIn('IgnoredPythonOp', pp)
self.assertNotIn('ignored_code', pp)
- with self.assertRaisesRegex(torch.jit.Error, "annotated to be ignored and cannot be run"):
+ with self.assertRaisesRegex(torch.jit.Error, "This Python function is annotated to be ignored"):
m2.forward(torch.ones(1))
- def test_ignored_as_value(self):
- class Model(nn.Module):
- def __init__(self):
- super(Model, self).__init__()
-
- @torch.jit.unused
- def tuple_ignored(self, x):
- # type: (Tensor) -> Tuple[Tensor, Tensor]
- return x, x
-
- @torch.jit.unused
- def single_val_ignored(self, x, y):
- # type: (Tensor, Tensor) -> Tensor
- return x
-
- def forward(self, x, use_ignore_path):
- # type: (Tensor, bool) -> Tuple[Tensor, Tensor]
- if False:
- return self.tuple_ignored(x)
- if use_ignore_path:
- return self.single_val_ignored(x, x), self.single_val_ignored(x, x)
- return x, x
-
- original = Model()
- scripted = torch.jit.script(original)
- self.assertEqual(scripted(torch.tensor(.5), False), (torch.tensor(.5), torch.tensor(.5)))
-
- buffer = io.BytesIO()
- torch.jit.save(scripted, buffer)
- buffer.seek(0)
- loaded = torch.jit.load(buffer)
-
- with self.assertRaisesRegex(torch._C.JITException, "annotated to be ignored and cannot be run"):
- loaded(torch.tensor(.5), True)
-
def test_module_error(self):
class MyModule(torch.nn.Module):
def __init__(self):
diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py
index 80193cf..7327689 100644
--- a/torch/_jit_internal.py
+++ b/torch/_jit_internal.py
@@ -6,7 +6,6 @@
import inspect
import weakref
-import warnings
import torch._C
from torch._six import builtins
@@ -169,7 +168,7 @@
Used to denote the behavior of a function in TorchScript. See export() and
ignore() for details.
"""
- UNUSED = "unused (ignored and replaced with raising of an exception)"
+ IGNORE_AND_DROP = "ignore (leave as a call to Python, replace with a 'raise' on torch.jit.save)"
IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)"
EXPORT = "export (compile this function even if nothing calls it)"
DEFAULT = "default (compile if called from a exported function / forward)"
@@ -220,52 +219,23 @@
return fn
-def unused(fn):
+def ignore(drop_on_export=False):
"""
This decorator indicates to the compiler that a function or method should
- be ignored and replaced with the raising of an exception. This allows you
- to leave code in your model that is not yet TorchScript compatible and still
- export your model.
+ be ignored and left as a Python function.
- Example (using ``@torch.jit.unused`` on a method)::
+ Arguments:
- import torch
- import torch.nn as nn
-
- class MyModule(nn.Module):
- def __init__(self, use_memory_efficent):
- super(MyModule, self).__init__()
- self.use_memory_efficent = use_memory_efficent
-
- @torch.jit.unused
- def memory_efficient(self, x):
- import pdb
- pdb.set_trace()
- return x + 10
-
- def forward(self, x):
- # Use not-yet-scriptable memory efficient mode
- if self.use_memory_efficient:
- return self.memory_efficient(x)
- else:
- return x + 10
-
- m = torch.jit.script(MyModule(use_memory_efficent=False))
- m.save("m.pt")
-
- m = torch.jit.script(MyModule(use_memory_efficient=True))
- # exception raised
- m(torch.rand(100))
- """
- fn._torchscript_modifier = FunctionModifiers.UNUSED
- return fn
-
-def ignore(drop=False, **kwargs):
- """
- This decorator indicates to the compiler that a function or method should
- be ignored and left as a Python function. This allows you to leave code in
- your model that is not yet TorchScript compatible. Models with ignored
- functions cannot be exported; use torch.jit.unused instead.
+        drop_on_export (bool): When ``False``, calls to this function will
+            be left as calls into Python, so the resulting model cannot be
+            saved with ``torch.jit.save``. When ``True``, any calls to
+            this function from other TorchScript code will be replaced
+            with a `raise` when the model is saved.
+            This allows you to leave code in your TorchScript model that is only ever
+            run when the Python interpreter is present, but not run after you save
+            and load your model.
Example (using ``@torch.jit.ignore`` on a method)::
@@ -291,7 +261,7 @@
# Error! The call `debugger` cannot be saved since it calls into Python
m.save("m.pt")
- Example (using ``@torch.jit.ignore(drop=True)`` on a method):
+ Example (using ``@torch.jit.ignore(drop_on_export=True)`` on a method):
.. testcode::
@@ -299,7 +269,7 @@
import torch.nn as nn
class MyModule(nn.Module):
- @torch.jit.ignore(drop=True)
+ @torch.jit.ignore(drop_on_export=True)
def training_method(self, x):
import pdb
pdb.set_trace()
@@ -320,37 +290,24 @@
import os
os.remove('m.pt')
"""
-
- if callable(drop):
- # used without any args, so drop is actually a function
+ if callable(drop_on_export):
+ # used without any args, so drop_on_export is actually a function
# @torch.jit.ignore
# def fn(...):
- fn = drop
+ fn = drop_on_export
fn._torchscript_modifier = FunctionModifiers.IGNORE
return fn
- if not isinstance(drop, bool):
- raise RuntimeError("Argument to @torch.jit.ignore must be a bool or "
- "a function but got {}".format(drop))
-
- # for backwards compat
- drop_on_export = kwargs.pop("drop_on_export", None)
- if drop_on_export:
- warnings.warn("ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the drop "
- "call on compilation. Use torch.jit.unused now. {}", category=DeprecationWarning)
-
- drop = drop_on_export
- elif drop:
- warnings.warn("ignore(True) has been deprecated. TorchScript will now drop the drop "
- "call on compilation. Use torch.jit.unused now. {}", category=DeprecationWarning)
-
- def decorator(fn):
- if drop:
- fn._torchscript_modifier = FunctionModifiers.UNUSED
- else:
- fn._torchscript_modifier = FunctionModifiers.IGNORE
- return fn
- return decorator
+ if isinstance(drop_on_export, bool):
+ def decorator(fn):
+ if drop_on_export:
+ fn._torchscript_modifier = FunctionModifiers.IGNORE_AND_DROP
+ else:
+ fn._torchscript_modifier = FunctionModifiers.IGNORE
+ return fn
+ return decorator
+ raise RuntimeError("Argument to @torch.jit.ignore must be a bool or "
+ "a function but got {}".format(drop_on_export))
def module_has_exports(mod):
@@ -361,16 +318,16 @@
return True
return False
-def should_drop(fn):
+def should_drop_on_export(fn):
attr = get_torchscript_modifier(fn)
if attr is None:
return False
- return attr is FunctionModifiers.UNUSED
+ return attr is FunctionModifiers.IGNORE_AND_DROP
def is_ignored_fn(fn):
mod = get_torchscript_modifier(fn)
- return mod is FunctionModifiers.UNUSED or mod is FunctionModifiers.IGNORE
+ return mod is FunctionModifiers.IGNORE_AND_DROP or mod is FunctionModifiers.IGNORE
def get_torchscript_modifier(fn):
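How the two predicates relate, in a quick sketch against the decorators defined in this file:

    from torch._jit_internal import ignore, is_ignored_fn, should_drop_on_export

    @ignore
    def kept(x):
        return x

    @ignore(drop_on_export=True)
    def dropped(x):
        return x

    assert is_ignored_fn(kept) and is_ignored_fn(dropped)  # both IGNORE*
    assert not should_drop_on_export(kept)      # IGNORE: kept on save
    assert should_drop_on_export(dropped)       # IGNORE_AND_DROP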
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 6a962ee..83f91e4 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -1293,6 +1293,10 @@
struct TORCH_API PythonOp : public Node {
using Node::Node;
+ // should this Python function be skipped over when exported (i.e. for
+ // debugging functions that only run in Python)
+ bool ignore_on_export = false;
+
virtual std::string name() const = 0;
virtual void writeScalars(std::ostream& out) const = 0;
void cloneFrom(Node* other_) override = 0;
diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp
index 99f2e1b..6fc71a0 100644
--- a/torch/csrc/jit/passes/python_print.cpp
+++ b/torch/csrc/jit/passes/python_print.cpp
@@ -950,17 +950,22 @@
switch (node->kind()) {
case prim::PythonOp: {
auto value = static_cast<const PythonOp*>(node);
- if (enforce_importable_) {
+ if (enforce_importable_ && !value->ignore_on_export) {
throw script::ErrorReport(node->sourceRange())
<< "Could not export Python function call '" << value->name()
<< "'. Remove calls to Python functions before export. "
<< "Did you forget add @script or @script_method annotation? "
<< "If this is a nn.ModuleList, add it to __constants__";
}
- std::stringstream scalars_stream;
- stmt << "^" << value->name();
- value->writeScalars(scalars_stream);
- stmt << scalars_stream.str();
+
+ if (value->ignore_on_export) {
+ stmt << "ops.prim.IgnoredPythonOp";
+ } else {
+ std::stringstream scalars_stream;
+ stmt << "^" << value->name();
+ value->writeScalars(scalars_stream);
+ stmt << scalars_stream.str();
+ }
printValueList(stmt, node->inputs(), "(", ")");
} break;
case prim::Uninitialized: {
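Given this branch and the ``assertIn('IgnoredPythonOp', pp)`` check in the restored test, the serialized code for a dropped call should look roughly like the following (the exact variable names and signature are an assumption):

    def forward(self, x: Tensor) -> Tensor:
        _0 = ops.prim.IgnoredPythonOp(x)
        return x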
diff --git a/torch/csrc/jit/python_ir.cpp b/torch/csrc/jit/python_ir.cpp
index 201ec0f..3f20c93 100644
--- a/torch/csrc/jit/python_ir.cpp
+++ b/torch/csrc/jit/python_ir.cpp
@@ -139,6 +139,7 @@
this->cconv = other->cconv;
Py_INCREF(other->pyobj.get());
this->pyobj = THPObjectPtr(other->pyobj.get());
+ this->ignore_on_export = other->ignore_on_export;
for (auto& sa : other->scalar_args) {
Py_INCREF(sa.get());
this->scalar_args.emplace_back(sa.get());
diff --git a/torch/csrc/jit/script/python_sugared_value.cpp b/torch/csrc/jit/script/python_sugared_value.cpp
index fb5bb32..246497b 100644
--- a/torch/csrc/jit/script/python_sugared_value.cpp
+++ b/torch/csrc/jit/script/python_sugared_value.cpp
@@ -108,28 +108,18 @@
if (!matched_schema)
throw ErrorReport(loc) << failure_messages.str();
- // If if a function is marked as dropped,
- // we throw an exception if it is invoked.
- if (py::cast<bool>(py::module::import("torch._jit_internal")
- .attr("should_drop")(self))) {
- auto g = m.graph();
- auto err_msg = insertConstant(
- *g,
- IValue(
- "This Python function is annotated to be ignored and cannot be run"));
- g->insert(prim::RaiseException, {err_msg}, {}, loc);
- return std::make_shared<SimpleValue>(
- g->insertNode(
- g->createUninitialized(matched_schema->return_types.at(0)))
- ->output());
- }
-
// Release the function object so we can wrap it in a PythonOp
py::object func = self;
std::string cconv(inputs.size(), 'd');
Node* new_node = m.graph()->insertNode(
m.graph()->createPythonOp(THPObjectPtr(func.release().ptr()), cconv, {}));
+ // Mark if function is ignored on export
+ if (py::cast<bool>(py::module::import("torch._jit_internal")
+ .attr("should_drop_on_export")(self))) {
+ auto python_op = static_cast<PythonOp*>(new_node);
+ python_op->ignore_on_export = true;
+ }
new_node->setSourceRange(loc);
for (auto& i : matched_schema->inputs)
new_node->addInput(i);
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index f651a5b..4edef27 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -32,7 +32,7 @@
# These are imported so users can access them from the `torch.jit` module
from torch._jit_internal import Final, _overload, _overload_method # noqa: F401
-from torch._jit_internal import ignore, export, unused # noqa: F401
+from torch._jit_internal import ignore, export # noqa: F401
if sys.version_info[0] > 2:
import pathlib