Use real argument names for Python functions (#29300)
Summary:
This hooks up `inspect` so that Python functions get their parameter
names attached instead of being named `0, 1, 2, ...`. This also fixes
issue #28537, where `ignore` functions were improperly typing `self`.
Differential Revision: [D19256434](https://our.intern.facebook.com/intern/diff/19256434/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29300
Pulled By: driazati
Differential Revision: D19256434
fbshipit-source-id: 6a1fe7bd0afab708b8439517798955d0abfeb44c
diff --git a/test/test_jit.py b/test/test_jit.py
index e2b4be0..db77164 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -10526,7 +10526,7 @@
fn = torch.jit.ignore(fn)
with self.assertRaisesRegex(RuntimeError, r"Expected a value of type 'Tensor' for argument"
- r" '0' but instead found type 'Tuple\[Tensor,"):
+ r" 'x' but instead found type 'Tuple\[Tensor,"):
@torch.jit.script
def bad_fn(x):
x, y = fn((x, x), x, x)
@@ -10605,11 +10605,11 @@
with self.assertRaisesRegex(RuntimeError, "Expected at most 2 arguments but found 3"):
ModuleTooMany()
- with self.assertRaisesRegex(RuntimeError, "Argument 1 not provided"):
+ with self.assertRaisesRegex(RuntimeError, "Argument y not provided"):
ModuleTooFew()
with self.assertRaisesRegex(RuntimeError, "need 3 values .* found only 2"):
ModuleTooManyAssign()
- with self.assertRaisesRegex(RuntimeError, "Argument 1 not provided."):
+ with self.assertRaisesRegex(RuntimeError, "Argument y not provided."):
ModuleDefault()
def test_script_define_order(self):
diff --git a/test/test_jit_py3.py b/test/test_jit_py3.py
index 94ab51e..f096489 100644
--- a/test/test_jit_py3.py
+++ b/test/test_jit_py3.py
@@ -1,7 +1,7 @@
from common_utils import run_tests
from jit_utils import JitTestCase
from torch.testing import FileCheck
-from typing import NamedTuple, List, Optional, Any
+from typing import NamedTuple, List, Optional, Any, Dict
from jit.test_module_interface import TestModuleInterface # noqa: F401
import unittest
import sys
@@ -89,6 +89,53 @@
self.assertEqual(out.sequence_features, [3.0])
self.assertEqual(out.time_since_first, 3.0)
+ def test_ignore_with_types(self):
+ @torch.jit.ignore
+ def fn(x: Dict[str, Optional[torch.Tensor]]):
+ return x + 10
+
+ class M(torch.nn.Module):
+ def __init__(self):
+ super(M, self).__init__()
+
+ def forward(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> torch.Tensor:
+ self.dropout_modality(in_batch)
+ fn(in_batch)
+ return torch.tensor(1)
+
+ @torch.jit.ignore
+ def dropout_modality(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> Dict[str, Optional[torch.Tensor]]:
+ return in_batch
+
+ sm = torch.jit.script(M())
+ FileCheck().check("dropout_modality").check("in_batch").run(str(sm.graph))
+
+ def test_python_callable(self):
+ class MyPythonClass(object):
+ @torch.jit.ignore
+ def __call__(self, *args) -> str:
+ return str(type(args[0]))
+
+ the_class = MyPythonClass()
+ @torch.jit.script
+ def fn(x):
+ return the_class(x)
+
+ # This doesn't involve the string frontend, so don't use checkScript
+ x = torch.ones(2)
+ self.assertEqual(fn(x), the_class(x))
+
+ def test_bad_types(self):
+ @torch.jit.ignore
+ def fn(my_arg):
+ return my_arg + 10
+
+ with self.assertRaisesRegex(RuntimeError, "argument 'my_arg'"):
+ @torch.jit.script
+ def other_fn(x):
+ return fn('2')
+
+
def test_named_tuple_slice_unpack(self):
class MyCoolNamedTuple(NamedTuple):
a : int
diff --git a/torch/csrc/jit/script/python_sugared_value.cpp b/torch/csrc/jit/script/python_sugared_value.cpp
index 0b3d6dd..c17b167 100644
--- a/torch/csrc/jit/script/python_sugared_value.cpp
+++ b/torch/csrc/jit/script/python_sugared_value.cpp
@@ -32,54 +32,37 @@
const size_t n_binders,
const SourceRange& loc) {
auto annotations = py::module::import("torch.jit.annotations");
- const auto fn_to_get_signature =
- moduleSelf_ ? py::getattr(self, "original_fn") : self;
+ const auto callable = moduleSelf_ ? py::getattr(self, "original_fn") : self;
+
+ // Make sure the function is not a class instantiation (e.g. `Exception()`)
+ annotations.attr("check_fn")(callable, loc);
+ auto is_vararg = py::cast<bool>(annotations.attr("is_vararg")(callable));
+
auto signature = annotations.attr("get_signature")(
- fn_to_get_signature, rcb ? *rcb : py::none(), loc);
+ callable, rcb ? *rcb : py::none(), loc, bool(moduleSelf_));
std::vector<Argument> args, rets;
+ auto py_param_names = annotations.attr("get_param_names")(callable, n_args);
+ auto param_names = py::cast<std::vector<std::string>>(py_param_names);
+ auto names_it = param_names.begin();
if (moduleSelf_) {
- args.push_back(Argument("self", moduleSelf_->type(), {}, {}, false));
+ // If there is a `self` parameter on the callable, skip it on the names list
+ args.emplace_back(Argument(*names_it, moduleSelf_->type(), {}, {}, false));
+ ++names_it;
}
- // We may mutate this if we can determine the number of args from Python
- // introspection.
- size_t actual_n_args = moduleSelf_ ? n_args + 1 : n_args;
- if (!signature.is_none()) {
- std::vector<TypePtr> arg_types;
- TypePtr ret_type;
- std::tie(arg_types, ret_type) =
- py::cast<std::pair<std::vector<TypePtr>, TypePtr>>(signature);
- args.reserve(arg_types.size());
- size_t idx = 0; // Fake argument names by putting in the index
- for (auto& arg_type : arg_types) {
- args.push_back(
- Argument(std::to_string(idx++), std::move(arg_type), {}, {}, false));
+ if (signature.is_none()) {
+ // No type signature was provided on the callable, so make a default
+ // signature where each argument is typed as a Tensor
+ for (; names_it != param_names.end(); ++names_it) {
+ args.emplace_back(Argument(
+ /*name=*/*names_it,
+ /*type=*/TensorType::get(),
+ /*N=*/c10::nullopt,
+ /*default_value=*/c10::nullopt,
+ /*kwarg_only=*/false));
}
- rets.push_back(Argument("0", std::move(ret_type), {}, {}, false));
- } else {
- // Create a default signature using what information we have
- // First see if we can introspect the number of function parameters
- // irrespective of the presence of explicit type annotations
- auto num_params =
- annotations.attr("get_num_params")(fn_to_get_signature, loc);
- if (!num_params.is_none()) {
- // Return a signature with the correct number of params according to the
- // Python function. The error handling in call() will catch any mismatch
- // later.
- actual_n_args = py::cast<size_t>(num_params);
- if (moduleSelf_) {
- TORCH_INTERNAL_ASSERT(actual_n_args > 0);
- --actual_n_args;
- }
- }
- // Construct the default signature: all arguments and returns will be
- // DynamicType
- args.reserve(actual_n_args);
- for (size_t i = 0; i < actual_n_args; ++i) {
- args.push_back(
- Argument(std::to_string(i), TensorType::get(), {}, {}, false));
- }
+ // Use as many outputs as are requested to make the return type
TypePtr ret_type = TensorType::get();
if (n_binders == 0) {
ret_type = NoneType::get();
@@ -87,16 +70,38 @@
std::vector<TypePtr> tuple_values(n_binders, ret_type);
ret_type = TupleType::create(std::move(tuple_values));
}
- rets.push_back(Argument("0", ret_type, {}, {}, false));
+ rets.emplace_back(Argument("0", ret_type, {}, {}, false));
+ } else {
+ // Use the provided type signature
+ std::vector<TypePtr> arg_types;
+ TypePtr ret_type;
+ std::tie(arg_types, ret_type) =
+ py::cast<std::pair<std::vector<TypePtr>, TypePtr>>(signature);
+
+ // arg_types does not include self but param_names does, so adjust for that
+ // if needed
+ TORCH_INTERNAL_ASSERT(arg_types.size() == param_names.size() - (moduleSelf_ ? 1 : 0));
+
+ auto types_it = arg_types.begin();
+ for (; types_it != arg_types.end(); ++types_it, ++names_it) {
+ args.push_back(Argument(
+ /*name=*/*names_it,
+ /*type=*/std::move(*types_it),
+ /*N=*/c10::nullopt,
+ /*default_value=*/c10::nullopt,
+ /*kwarg_only=*/false));
+ }
+ rets.push_back(Argument("0", std::move(ret_type), {}, {}, false));
}
- std::string name("");
- // Use the qualified name if possible
+
+ std::string name;
if (py::hasattr(self, "__qualname__")) {
+ // Use the qualified name if possible
name = py::str(py::getattr(self, "__qualname__"));
} else if (py::hasattr(self, "__name__")) {
name = py::str(py::getattr(self, "__name__"));
}
- return FunctionSchema("", "", std::move(args), std::move(rets));
+ return FunctionSchema(name, "", std::move(args), std::move(rets), is_vararg);
}
std::shared_ptr<SugaredValue> PythonValue::call(
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index a3e8fb5..9084590 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -1990,7 +1990,7 @@
def _compile_function_with_overload(overload_fn, qual_name, impl_fn):
overload_decl = torch.jit.get_jit_def(overload_fn).decl()
- overload_signature = torch.jit.annotations.get_signature(overload_fn, None, None)
+ overload_signature = torch.jit.annotations.get_signature(overload_fn, None, None, inspect.ismethod(overload_fn))
impl_ast = torch.jit.get_jit_def(impl_fn)
overload_defaults = get_default_args(overload_fn)
implementation_defaults = get_default_args(impl_fn)
diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py
index ce04975..733614b 100644
--- a/torch/jit/_recursive.py
+++ b/torch/jit/_recursive.py
@@ -425,7 +425,7 @@
return overload_name_mappings
def _check_no_signature(func):
- signature = torch.jit.annotations.get_signature(func, None, None)
+ signature = torch.jit.annotations.get_signature(func, None, None, inspect.ismethod(func))
if signature is None:
qual_name = torch.jit._qualified_name(func)
raise RuntimeError("Must explicitly add type annotations to overloaded functions: {}".format(qual_name))
diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py
index 89e516d..ed18cf9 100644
--- a/torch/jit/annotations.py
+++ b/torch/jit/annotations.py
@@ -11,7 +11,7 @@
DeviceObjType
from textwrap import dedent
-from torch._six import builtins
+from torch._six import builtins, PY2
from torch._utils_internal import get_source_lines_and_file
@@ -51,50 +51,92 @@
return self.rcb(name)
return getattr(builtins, name, None)
-def get_signature(fn, rcb, loc):
+def get_signature(fn, rcb, loc, is_method):
# Python 3.5 adds support for the nice annotation syntax, so try that first.
+ signature = None
if PY35:
- sig = try_real_annotations(fn)
- if sig is not None:
- return sig
+ signature = try_real_annotations(fn)
+ if signature is not None and is_method:
+ # If this is a method, then the signature will include a type for
+ # `self`, but type comments do not contain a `self`. So strip it
+ # away here so everything is consistent (`inspect.ismethod` does
+ # not work here since `fn` is unbound at this point)
+ param_types, return_type = signature
+ param_types = param_types[1:]
+ signature = (param_types, return_type)
- type_line, source = None, None
- try:
- source = dedent(''.join(get_source_lines_and_file(fn)[0]))
- type_line = get_type_line(source)
- except TypeError:
- pass
- # This might happen both because we failed to get the source of fn, or
- # because it didn't have any annotations.
- if type_line is None:
- return None
+ if signature is None:
+ type_line, source = None, None
+ try:
+ source = dedent(''.join(get_source_lines_and_file(fn)[0]))
+ type_line = get_type_line(source)
+ except TypeError:
+ pass
+ # This might happen both because we failed to get the source of fn, or
+ # because it didn't have any annotations.
+ if type_line is not None:
+ signature = parse_type_line(type_line, rcb, loc)
- return parse_type_line(type_line, rcb, loc)
+ return signature
-# This is essentially a weaker form of get_signature(), where we don't care if
-# we have the types, we just care that we can figure out how many parameters
-# a function takes.
-def get_num_params(fn, loc):
+def is_function_or_method(the_callable):
+ # A stricter version of `inspect.isroutine` that does not pass for built-in
+ # functions
+ return inspect.isfunction(the_callable) or inspect.ismethod(the_callable)
+
+
+def is_vararg(the_callable):
+ if not is_function_or_method(the_callable) and hasattr(the_callable, '__call__'): # noqa: B004
+ # If `the_callable` is a class, de-sugar the call so we can still get
+ # the signature
+ the_callable = the_callable.__call__
+
+ if is_function_or_method(the_callable):
+ if PY2:
+ # [inspect args]
+ # `inspect.getfullargspec` is not available in Python 2 but
+ # `inspect.getargspec` is deprecated in Python 3, so we have to
+ # switch over them
+ return inspect.getargspec(the_callable).varargs is not None
+ else:
+ return inspect.getfullargspec(the_callable).varargs is not None
+ else:
+ return False
+
+
+def get_param_names(fn, n_args):
+ if not is_function_or_method(fn) and hasattr(fn, '__call__') and is_function_or_method(fn.__call__): # noqa: B004
+ # De-sugar calls to classes
+ fn = fn.__call__
+
+ if is_function_or_method(fn):
+ if PY2:
+ # see [inspect args]
+ return inspect.getargspec(fn).args
+ else:
+ return inspect.getfullargspec(fn).args
+ else:
+ # The `fn` was not a method or function (maybe a class with a __call__
+ # method), so use a default param name list
+ return [str(i) for i in range(n_args)]
+
+
+def check_fn(fn, loc):
+ # Make sure the function definition is not a class instantiation
try:
source = dedent(''.join(get_source_lines_and_file(fn)[0]))
except (TypeError, IOError):
- return None
+ return
if source is None:
- return None
+ return
+
py_ast = ast.parse(source)
if len(py_ast.body) == 1 and isinstance(py_ast.body[0], ast.ClassDef):
raise torch.jit.frontend.FrontendError(
loc, "Cannot instantiate class '{}' in a script function".format(py_ast.body[0].name))
if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
raise torch.jit.frontend.FrontendError(loc, "Expected a single top-level function")
- py_def = py_ast.body[0]
- if py_def.args.vararg is not None:
- return None
- elif hasattr(py_def.args, 'kwonlyargs') and len(py_def.args.kwonlyargs) > 0:
- return None
- else:
- return len(py_def.args.args)
def parse_type_line(type_line, rcb, loc):
@@ -279,7 +321,8 @@
# TODO: Consider not exporting these during wildcard import (reserve
# that for the types; for idiomatic typing code.)
'get_signature',
- 'get_num_params',
+ 'check_fn',
+ 'get_param_names',
'parse_type_line',
'get_type_line',
'split_type_line',