Use real argument names for Python functions (#29300)

Summary:
This hooks up `inspect` so that Python functions get their parameters
names attached instead of naming them `0, 1, 2, ...`. This also fixes
issue #28537 where `ignore` functions were improperly typing `self`.
](https://our.intern.facebook.com/intern/diff/19256434/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29300

Pulled By: driazati

Differential Revision: D19256434

fbshipit-source-id: 6a1fe7bd0afab708b8439517798955d0abfeb44c
diff --git a/test/test_jit.py b/test/test_jit.py
index e2b4be0..db77164 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -10526,7 +10526,7 @@
             fn = torch.jit.ignore(fn)
 
             with self.assertRaisesRegex(RuntimeError, r"Expected a value of type 'Tensor' for argument"
-                                                      r" '0' but instead found type 'Tuple\[Tensor,"):
+                                                      r" 'x' but instead found type 'Tuple\[Tensor,"):
                 @torch.jit.script
                 def bad_fn(x):
                     x, y = fn((x, x), x, x)
@@ -10605,11 +10605,11 @@
 
         with self.assertRaisesRegex(RuntimeError, "Expected at most 2 arguments but found 3"):
             ModuleTooMany()
-        with self.assertRaisesRegex(RuntimeError, "Argument 1 not provided"):
+        with self.assertRaisesRegex(RuntimeError, "Argument y not provided"):
             ModuleTooFew()
         with self.assertRaisesRegex(RuntimeError, "need 3 values .* found only 2"):
             ModuleTooManyAssign()
-        with self.assertRaisesRegex(RuntimeError, "Argument 1 not provided."):
+        with self.assertRaisesRegex(RuntimeError, "Argument y not provided."):
             ModuleDefault()
 
     def test_script_define_order(self):
diff --git a/test/test_jit_py3.py b/test/test_jit_py3.py
index 94ab51e..f096489 100644
--- a/test/test_jit_py3.py
+++ b/test/test_jit_py3.py
@@ -1,7 +1,7 @@
 from common_utils import run_tests
 from jit_utils import JitTestCase
 from torch.testing import FileCheck
-from typing import NamedTuple, List, Optional, Any
+from typing import NamedTuple, List, Optional, Any, Dict
 from jit.test_module_interface import TestModuleInterface  # noqa: F401
 import unittest
 import sys
@@ -89,6 +89,53 @@
         self.assertEqual(out.sequence_features, [3.0])
         self.assertEqual(out.time_since_first, 3.0)
 
+    def test_ignore_with_types(self):
+        @torch.jit.ignore
+        def fn(x: Dict[str, Optional[torch.Tensor]]):
+            return x + 10
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+
+            def forward(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> torch.Tensor:
+                self.dropout_modality(in_batch)
+                fn(in_batch)
+                return torch.tensor(1)
+
+            @torch.jit.ignore
+            def dropout_modality(self, in_batch: Dict[str, Optional[torch.Tensor]]) -> Dict[str, Optional[torch.Tensor]]:
+                return in_batch
+
+        sm = torch.jit.script(M())
+        FileCheck().check("dropout_modality").check("in_batch").run(str(sm.graph))
+
+    def test_python_callable(self):
+        class MyPythonClass(object):
+            @torch.jit.ignore
+            def __call__(self, *args) -> str:
+                return str(type(args[0]))
+
+        the_class = MyPythonClass()
+        @torch.jit.script
+        def fn(x):
+            return the_class(x)
+
+        # This doesn't involve the string frontend, so don't use checkScript
+        x = torch.ones(2)
+        self.assertEqual(fn(x), the_class(x))
+
+    def test_bad_types(self):
+        @torch.jit.ignore
+        def fn(my_arg):
+            return my_arg + 10
+
+        with self.assertRaisesRegex(RuntimeError, "argument 'my_arg'"):
+            @torch.jit.script
+            def other_fn(x):
+                return fn('2')
+
+
     def test_named_tuple_slice_unpack(self):
         class MyCoolNamedTuple(NamedTuple):
             a : int
diff --git a/torch/csrc/jit/script/python_sugared_value.cpp b/torch/csrc/jit/script/python_sugared_value.cpp
index 0b3d6dd..c17b167 100644
--- a/torch/csrc/jit/script/python_sugared_value.cpp
+++ b/torch/csrc/jit/script/python_sugared_value.cpp
@@ -32,54 +32,37 @@
     const size_t n_binders,
     const SourceRange& loc) {
   auto annotations = py::module::import("torch.jit.annotations");
-  const auto fn_to_get_signature =
-      moduleSelf_ ? py::getattr(self, "original_fn") : self;
+  const auto callable = moduleSelf_ ? py::getattr(self, "original_fn") : self;
+
+  // Make sure the function is not a class instantiation (e.g. `Exception()`)
+  annotations.attr("check_fn")(callable, loc);
+  auto is_vararg = py::cast<bool>(annotations.attr("is_vararg")(callable));
+
   auto signature = annotations.attr("get_signature")(
-      fn_to_get_signature, rcb ? *rcb : py::none(), loc);
+      callable, rcb ? *rcb : py::none(), loc, bool(moduleSelf_));
   std::vector<Argument> args, rets;
 
+  auto py_param_names = annotations.attr("get_param_names")(callable, n_args);
+  auto param_names = py::cast<std::vector<std::string>>(py_param_names);
+  auto names_it = param_names.begin();
   if (moduleSelf_) {
-    args.push_back(Argument("self", moduleSelf_->type(), {}, {}, false));
+    // If there is a `self` parameter on the callable, skip it on the names list
+    args.emplace_back(Argument(*names_it, moduleSelf_->type(), {}, {}, false));
+    ++names_it;
   }
-  // We may mutate this if we can determine the number of args from Python
-  // introspection.
-  size_t actual_n_args = moduleSelf_ ? n_args + 1 : n_args;
-  if (!signature.is_none()) {
-    std::vector<TypePtr> arg_types;
-    TypePtr ret_type;
-    std::tie(arg_types, ret_type) =
-        py::cast<std::pair<std::vector<TypePtr>, TypePtr>>(signature);
-    args.reserve(arg_types.size());
-    size_t idx = 0; // Fake argument names by putting in the index
-    for (auto& arg_type : arg_types) {
-      args.push_back(
-          Argument(std::to_string(idx++), std::move(arg_type), {}, {}, false));
+  if (signature.is_none()) {
+    // No type signature was provided on the callable, so make a default
+    // signature where each argument is typed as a Tensor
+    for (; names_it != param_names.end(); ++names_it) {
+      args.emplace_back(Argument(
+          /*name=*/*names_it,
+          /*type=*/TensorType::get(),
+          /*N=*/c10::nullopt,
+          /*default_value=*/c10::nullopt,
+          /*kwarg_only=*/false));
     }
-    rets.push_back(Argument("0", std::move(ret_type), {}, {}, false));
-  } else {
-    // Create a default signature using what information we have
 
-    // First see if we can introspect the number of function parameters
-    // irrespective of the presence of explicit type annotations
-    auto num_params =
-        annotations.attr("get_num_params")(fn_to_get_signature, loc);
-    if (!num_params.is_none()) {
-      // Return a signature with the correct number of params according to the
-      // Python function. The error handling in call() will catch any mismatch
-      // later.
-      actual_n_args = py::cast<size_t>(num_params);
-      if (moduleSelf_) {
-        TORCH_INTERNAL_ASSERT(actual_n_args > 0);
-        --actual_n_args;
-      }
-    }
-    // Construct the default signature: all arguments and returns will be
-    // DynamicType
-    args.reserve(actual_n_args);
-    for (size_t i = 0; i < actual_n_args; ++i) {
-      args.push_back(
-          Argument(std::to_string(i), TensorType::get(), {}, {}, false));
-    }
+    // Use as many outputs as are requested to make the return type
     TypePtr ret_type = TensorType::get();
     if (n_binders == 0) {
       ret_type = NoneType::get();
@@ -87,16 +70,38 @@
       std::vector<TypePtr> tuple_values(n_binders, ret_type);
       ret_type = TupleType::create(std::move(tuple_values));
     }
-    rets.push_back(Argument("0", ret_type, {}, {}, false));
+    rets.emplace_back(Argument("0", ret_type, {}, {}, false));
+  } else {
+    // Use the provided type signature
+    std::vector<TypePtr> arg_types;
+    TypePtr ret_type;
+    std::tie(arg_types, ret_type) =
+        py::cast<std::pair<std::vector<TypePtr>, TypePtr>>(signature);
+
+    // arg_types does not include self but param_names does, so adjust for that
+    // if needed
+    TORCH_INTERNAL_ASSERT(arg_types.size() == param_names.size() - (moduleSelf_ ? 1 : 0));
+
+    auto types_it = arg_types.begin();
+    for (; types_it != arg_types.end(); ++types_it, ++names_it) {
+      args.push_back(Argument(
+          /*name=*/*names_it,
+          /*type=*/std::move(*types_it),
+          /*N=*/c10::nullopt,
+          /*default_value=*/c10::nullopt,
+          /*kwarg_only=*/false));
+    }
+    rets.push_back(Argument("0", std::move(ret_type), {}, {}, false));
   }
-  std::string name("");
-  // Use the qualified name if possible
+
+  std::string name;
   if (py::hasattr(self, "__qualname__")) {
+    // Use the qualified name if possible
     name = py::str(py::getattr(self, "__qualname__"));
   } else if (py::hasattr(self, "__name__")) {
     name = py::str(py::getattr(self, "__name__"));
   }
-  return FunctionSchema("", "", std::move(args), std::move(rets));
+  return FunctionSchema(name, "", std::move(args), std::move(rets), is_vararg);
 }
 
 std::shared_ptr<SugaredValue> PythonValue::call(
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index a3e8fb5..9084590 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -1990,7 +1990,7 @@
 
 def _compile_function_with_overload(overload_fn, qual_name, impl_fn):
     overload_decl = torch.jit.get_jit_def(overload_fn).decl()
-    overload_signature = torch.jit.annotations.get_signature(overload_fn, None, None)
+    overload_signature = torch.jit.annotations.get_signature(overload_fn, None, None, inspect.ismethod(overload_fn))
     impl_ast = torch.jit.get_jit_def(impl_fn)
     overload_defaults = get_default_args(overload_fn)
     implementation_defaults = get_default_args(impl_fn)
diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py
index ce04975..733614b 100644
--- a/torch/jit/_recursive.py
+++ b/torch/jit/_recursive.py
@@ -425,7 +425,7 @@
     return overload_name_mappings
 
 def _check_no_signature(func):
-    signature = torch.jit.annotations.get_signature(func, None, None)
+    signature = torch.jit.annotations.get_signature(func, None, None, inspect.ismethod(func))
     if signature is None:
         qual_name = torch.jit._qualified_name(func)
         raise RuntimeError("Must explicitly add type annotations to overloaded functions: {}".format(qual_name))
diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py
index 89e516d..ed18cf9 100644
--- a/torch/jit/annotations.py
+++ b/torch/jit/annotations.py
@@ -11,7 +11,7 @@
     DeviceObjType
 
 from textwrap import dedent
-from torch._six import builtins
+from torch._six import builtins, PY2
 from torch._utils_internal import get_source_lines_and_file
 
 
@@ -51,50 +51,92 @@
             return self.rcb(name)
         return getattr(builtins, name, None)
 
-def get_signature(fn, rcb, loc):
+def get_signature(fn, rcb, loc, is_method):
     # Python 3.5 adds support for the nice annotation syntax, so try that first.
+    signature = None
     if PY35:
-        sig = try_real_annotations(fn)
-        if sig is not None:
-            return sig
+        signature = try_real_annotations(fn)
+        if signature is not None and is_method:
+            # If this is a method, then the signaure will include a type for
+            # `self`, but type comments do not contain a `self`. So strip it
+            # away here so everything is consistent (`inspect.ismethod` does
+            # not work here since `fn` is unbound at this point)
+            param_types, return_type = signature
+            param_types = param_types[1:]
+            signature = (param_types, return_type)
 
-    type_line, source = None, None
-    try:
-        source = dedent(''.join(get_source_lines_and_file(fn)[0]))
-        type_line = get_type_line(source)
-    except TypeError:
-        pass
-    # This might happen both because we failed to get the source of fn, or
-    # because it didn't have any annotations.
-    if type_line is None:
-        return None
+    if signature is None:
+        type_line, source = None, None
+        try:
+            source = dedent(''.join(get_source_lines_and_file(fn)[0]))
+            type_line = get_type_line(source)
+        except TypeError:
+            pass
+        # This might happen both because we failed to get the source of fn, or
+        # because it didn't have any annotations.
+        if type_line is not None:
+            signature = parse_type_line(type_line, rcb, loc)
 
-    return parse_type_line(type_line, rcb, loc)
+    return signature
 
 
-# This is essentially a weaker form of get_signature(), where we don't care if
-# we have the types, we just care that we can figure out how many parameters
-# a function takes.
-def get_num_params(fn, loc):
+def is_function_or_method(the_callable):
+    # A stricter version of `inspect.isroutine` that does not pass for built-in
+    # functions
+    return inspect.isfunction(the_callable) or inspect.ismethod(the_callable)
+
+
+def is_vararg(the_callable):
+    if not is_function_or_method(the_callable) and hasattr(the_callable, '__call__'):  # noqa: B004
+        # If `the_callable` is a class, de-sugar the call so we can still get
+        # the signature
+        the_callable = the_callable.__call__
+
+    if is_function_or_method(the_callable):
+        if PY2:
+            # [inspect args]
+            # `inspect.getfullargspec` is not available in Python 2 but
+            # `inspect.getargspec` is deprecated in Python 3, so we have to
+            # switch over them
+            return inspect.getargspec(the_callable).varargs is not None
+        else:
+            return inspect.getfullargspec(the_callable).varargs is not None
+    else:
+        return False
+
+
+def get_param_names(fn, n_args):
+    if not is_function_or_method(fn) and hasattr(fn, '__call__') and is_function_or_method(fn.__call__):  # noqa: B004
+        # De-sugar calls to classes
+        fn = fn.__call__
+
+    if is_function_or_method(fn):
+        if PY2:
+            # see [inspect args]
+            return inspect.getargspec(fn).args
+        else:
+            return inspect.getfullargspec(fn).args
+    else:
+        # The `fn` was not a method or function (maybe a class with a __call__
+        # method, so use a default param name list)
+        return [str(i) for i in range(n_args)]
+
+
+def check_fn(fn, loc):
+    # Make sure the function definition is not a class instantiation
     try:
         source = dedent(''.join(get_source_lines_and_file(fn)[0]))
     except (TypeError, IOError):
-        return None
+        return
     if source is None:
-        return None
+        return
+
     py_ast = ast.parse(source)
     if len(py_ast.body) == 1 and isinstance(py_ast.body[0], ast.ClassDef):
         raise torch.jit.frontend.FrontendError(
             loc, "Cannot instantiate class '{}' in a script function".format(py_ast.body[0].name))
     if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
         raise torch.jit.frontend.FrontendError(loc, "Expected a single top-level function")
-    py_def = py_ast.body[0]
-    if py_def.args.vararg is not None:
-        return None
-    elif hasattr(py_def.args, 'kwonlyargs') and len(py_def.args.kwonlyargs) > 0:
-        return None
-    else:
-        return len(py_def.args.args)
 
 
 def parse_type_line(type_line, rcb, loc):
@@ -279,7 +321,8 @@
     # TODO: Consider not exporting these during wildcard import (reserve
     # that for the types; for idiomatic typing code.)
     'get_signature',
-    'get_num_params',
+    'check_fn',
+    'get_param_names',
     'parse_type_line',
     'get_type_line',
     'split_type_line',