blob: 32fbd3084d213be32b55319c3169b4cadfbf4283 [file] [log] [blame]
import sys
import ast
import inspect
import re
import torch
from .._jit_internal import List, BroadcastingList1, BroadcastingList2, \
BroadcastingList3, Tuple, is_tuple, is_list, Dict, is_dict, Optional, \
is_optional, _qualified_name, Any
from torch._C import TensorType, TupleType, FloatType, IntType, \
ListType, StringType, DictType, BoolType, OptionalType, ClassType, InterfaceType, AnyType, NoneType
from textwrap import dedent
from torch._six import builtins
from torch._utils_internal import get_source_lines_and_file
# True on Python 3.5+, where get_signature() first tries real function
# annotations via inspect.signature (see try_real_annotations).
PY35 = sys.version_info[:2] >= (3, 5)
class _ModuleMemberError(RuntimeError, AttributeError):
    """Raised when a name is looked up on a :class:`Module` that lacks it.

    Derives from RuntimeError, which is what this lookup has always raised (so
    existing ``except RuntimeError`` handlers keep working), and from
    AttributeError so that ``hasattr``, ``getattr(obj, name, default)`` and the
    copy/pickle protocols treat the name as simply missing.
    """


class Module(object):
    """A stand-in module object exposing a fixed mapping of members.

    ``module.<attr>`` returns ``members[<attr>]``; any other attribute raises
    an error that is both a RuntimeError and an AttributeError.
    """

    def __init__(self, name, members):
        self.name = name          # display name used in error messages
        self.members = members    # mapping of attribute name -> value

    def __getattr__(self, name):
        # Only called for attributes that normal lookup did not find. Read our
        # own fields through __dict__: an instance created without __init__
        # (e.g. by copy.copy) has no 'members' yet, and 'self.members' would
        # re-enter __getattr__ and recurse forever.
        fields = self.__dict__
        try:
            return fields.get('members', {})[name]
        except KeyError:
            raise _ModuleMemberError("Module {} has no member called {}".format(
                fields.get('name'), name))
class EvalEnv(object):
    """Name-lookup mapping passed as the ``locals`` of ``eval`` for type comments.

    A name is resolved from the fixed ``env`` table first. Failing that, it is
    handed to the resolution callback ``rcb`` when one was supplied; otherwise
    it is looked up on ``builtins``. Without an ``rcb``, a name found in neither
    place evaluates to None instead of raising NameError.
    """

    # Names that are always available inside a type comment.
    env = {
        'torch': Module('torch', {'Tensor': torch.Tensor}),
        'Tensor': torch.Tensor,
        'typing': Module('typing', {'Tuple': Tuple}),
        'Tuple': Tuple,
        'List': List,
        'Dict': Dict,
        'Optional': Optional,
    }

    def __init__(self, rcb):
        self.rcb = rcb  # resolution callback, or None

    def __getitem__(self, name):
        try:
            return self.env[name]
        except KeyError:
            pass
        resolve = self.rcb
        if resolve is None:
            return getattr(builtins, name, None)
        return resolve(name)
def get_signature(fn, rcb, loc):
    """Return the TorchScript signature of ``fn`` as ``(arg_types, return_type)``.

    On Python 3.5+ real annotations are tried first (try_real_annotations).
    Otherwise the function's source is searched for a PEP 484 ``# type:``
    comment, which is parsed by parse_type_line with ``rcb`` (resolution
    callback) and ``loc`` (forwarded for type resolution / error reporting).

    Returns None when ``fn`` carries no annotations at all, or when its source
    cannot be retrieved.
    """
    # Python 3.5 adds support for the nice annotation syntax, so try that first.
    if PY35:
        sig = try_real_annotations(fn)
        if sig is not None:
            return sig

    try:
        source = dedent(''.join(get_source_lines_and_file(fn)[0]))
    except (TypeError, IOError):
        # The source is unavailable. These are the same failures get_num_params
        # treats as "no source" (inspect reports builtins with TypeError and
        # missing source files with IOError/OSError), so there is nothing to parse.
        return None

    type_line = get_type_line(source)
    # No '# type:' comment means the function is simply unannotated.
    if type_line is None:
        return None
    return parse_type_line(type_line, rcb, loc)
def get_num_params(fn, loc):
    """Return the number of positional parameters ``fn`` declares, or None.

    This is essentially a weaker form of get_signature(), where we don't care
    if we have the types, we just care that we can figure out how many
    parameters a function takes.

    Returns None when the source of ``fn`` cannot be retrieved, or when the
    function takes ``*args`` or keyword-only arguments.

    Raises:
        torch.jit.frontend.FrontendError: (reported at ``loc``) if the source
            is a class definition, or is not a single top-level function.
    """
    try:
        source = dedent(''.join(get_source_lines_and_file(fn)[0]))
    except (TypeError, IOError):
        return None

    py_ast = ast.parse(source)
    if len(py_ast.body) == 1 and isinstance(py_ast.body[0], ast.ClassDef):
        raise torch.jit.frontend.FrontendError(
            loc, "Cannot instantiate class '{}' in a script function".format(py_ast.body[0].name))
    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
        raise torch.jit.frontend.FrontendError(loc, "Expected a single top-level function")

    args = py_ast.body[0].args
    if args.vararg is not None:
        return None
    if getattr(args, 'kwonlyargs', None):
        return None
    # Python 3.8+ stores positional-only parameters (those before '/') in a
    # separate 'posonlyargs' list. They still accept positional arguments, so
    # count them too. Older ASTs have no such attribute.
    return len(getattr(args, 'posonlyargs', [])) + len(args.args)
def parse_type_line(type_line, rcb, loc):
    """Parses a type annotation specified as a comment.

    Example inputs:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor]
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor

    The argument part and the return part are each evaluated with ``eval``
    against an EvalEnv built from ``rcb``. The text comes from the annotated
    function's own source. The results are converted with ann_to_type using the
    ``(rcb, loc)`` resolver.

    Returns:
        ``(arg_types, return_type)``: a list of TorchScript types and a single type.

    Raises:
        RuntimeError: if either part fails to evaluate with NameError/SyntaxError.
    """
    def evaluate(text, error_template):
        try:
            return eval(text, {}, EvalEnv(rcb))  # noqa: P204
        except (NameError, SyntaxError) as e:
            raise RuntimeError(error_template.format(str(e)))

    arg_text, ret_text = split_type_line(type_line)

    arg_anns = evaluate(arg_text, "Failed to parse the argument list of a type annotation: {}")
    # A single parenthesised argument evaluates to the bare annotation, not a tuple.
    if not isinstance(arg_anns, tuple):
        arg_anns = (arg_anns,)
    ret_ann = evaluate(ret_text, "Failed to parse the return type of a type annotation: {}")

    resolver = (rcb, loc)
    return [ann_to_type(a, resolver) for a in arg_anns], ann_to_type(ret_ann, resolver)
def get_type_line(source):
    """Tries to find the line containing a comment with the type annotation.

    ``source`` is the dedented source text of a single function. Two PEP 484
    forms are recognised:

    * a single ``# type: (A, B) -> R`` comment, returned stripped;
    * per-argument comments (``a,  # type: A``) followed by a
      ``# type: (...) -> R`` line. The argument types are spliced into the
      ``...`` placeholder and the combined line is returned stripped.

    Returns None when no ``# type:`` comment is present.

    Raises:
        RuntimeError: if a malformed prefix such as ``#type :`` is found instead
            of ``# type:``, or if several type comments are present but none of
            them is a ``# type: (...) -> R`` return line.
    """
    type_comment = '# type:'

    lines = list(enumerate(source.split('\n')))
    type_lines = [(line_num, line) for line_num, line in lines if type_comment in line]

    if not type_lines:
        type_pattern = re.compile('#[\t ]*type[\t ]*:')
        wrong_type_lines = [(line_num, line) for line_num, line in lines if type_pattern.search(line)]
        if wrong_type_lines:
            # The reported number is the 0-based line index within ``source``.
            raise RuntimeError("The annotation prefix in line " + str(wrong_type_lines[0][0])
                               + " is probably invalid.\nIt must be '# type:'"
                               + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"  # noqa
                               + "\nfor examples")
        return None
    if len(type_lines) == 1:
        # Only 1 type line, quit now
        return type_lines[0][1].strip()

    # Parse split up argument types according to PEP 484
    # https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code
    # Scan forwards: the per-argument comments precede the '(...) ->' return
    # line. Any '# type:' comments after it are in the function body (e.g.
    # variable annotations) and must not be treated as parameters.
    return_line = None
    parameter_type_lines = []
    for _, line in type_lines:
        if '# type: (...) -> ' in line:
            return_line = line
            break
        parameter_type_lines.append(line)

    if return_line is None:
        raise RuntimeError("Return type line '# type: (...) -> ...' not found on multiline "
                           "type annotation\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)")  # noqa

    def get_parameter_type(line):
        item_type = line[line.find(type_comment) + len(type_comment):]
        return item_type.strip()

    parameter_types = ", ".join(get_parameter_type(line) for line in parameter_type_lines)
    # Strip so the result starts with '# type:' like the single-line case, and
    # replace only the first '...' (the argument-list placeholder).
    return return_line.strip().replace("...", parameter_types, 1)
def split_type_line(type_line):
    """Splits the comment with the type annotation into parts for argument and return types.

    For example, for an input of:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor]
    This function will return:
        ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]")

    Leading indentation or code before the ``# type:`` prefix is tolerated.

    Raises:
        RuntimeError: if no ``->`` follows the prefix.
    """
    type_comment = '# type:'
    prefix_pos = type_line.find(type_comment)
    # Normally the line starts with the prefix. If it is indented or follows
    # code, skip past wherever the prefix actually is. If it is absent, keep
    # the historical fixed offset.
    if prefix_pos == -1:
        start_offset = len(type_comment)
    else:
        start_offset = prefix_pos + len(type_comment)
    try:
        arrow_pos = type_line.index('->', start_offset)
    except ValueError:
        raise RuntimeError("Syntax error in type annotation (couldn't find `->`)")
    return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2:].strip()
def try_real_annotations(fn):
    """Tries to use the Py3.5+ annotation syntax to get the type.

    Returns ``(arg_types, return_type)`` converted with ann_to_type, where a
    missing annotation is passed to ann_to_type as None. Returns None when
    ``fn`` has no introspectable signature or carries no annotations at all.
    """
    try:
        sig = inspect.signature(fn)
    except (ValueError, TypeError):
        # inspect.signature raises ValueError when no signature can be provided
        # and TypeError for unsupported objects (e.g. non-callables). Either way
        # there are no real annotations to use.
        return None

    params = list(sig.parameters.values())
    all_annots = [sig.return_annotation] + [p.annotation for p in params]
    if all(ann is sig.empty for ann in all_annots):
        return None

    def as_ann(ann):
        # sig.empty is really annoying so convert it to None
        return None if ann is sig.empty else ann

    arg_types = [ann_to_type(as_ann(p.annotation)) for p in params]
    return_type = ann_to_type(as_ann(sig.return_annotation))
    return arg_types, return_type
def ann_to_type(ann, resolver=None):
    """Convert a Python type annotation into the corresponding TorchScript type.

    Args:
        ann: the annotation object. ``None`` (no annotation) maps to Tensor.
        resolver: optional ``(rcb, loc)`` pair, i.e. a resolution callback and
            a source range. When given, it is the last resort for looking the
            annotation up by name through ``torch._C._resolve_type`` (e.g. to
            turn a NamedTuple into a tuple type). It is also forwarded to the
            element types of Tuple/List/Dict/Optional, so nested annotations are
            resolved the same way.

    Returns:
        A ``torch._C`` type object.

    Raises:
        ValueError: if the annotation cannot be converted.
    """
    if ann is None or ann is torch.Tensor:
        return TensorType.get()
    elif is_tuple(ann):
        return TupleType([ann_to_type(a, resolver) for a in ann.__args__])
    elif is_list(ann):
        return ListType(ann_to_type(ann.__args__[0], resolver))
    elif is_dict(ann):
        key = ann_to_type(ann.__args__[0], resolver)
        value = ann_to_type(ann.__args__[1], resolver)
        return DictType(key, value)
    elif is_optional(ann):
        # Optional[T] is Union[T, None], and NoneType may sit in either slot.
        # Compare by identity: issubclass() can raise TypeError when the other
        # slot holds a typing construct such as List[int].
        first, second = ann.__args__[0], ann.__args__[1]
        contained = first if second is type(None) else second
        return OptionalType(ann_to_type(contained, resolver))
    elif ann is float:
        return FloatType.get()
    elif ann is int:
        return IntType.get()
    elif ann is str:
        return StringType.get()
    elif ann is bool:
        return BoolType.get()
    elif ann is Any:
        return AnyType.get()
    elif ann is type(None):
        return NoneType.get()
    elif hasattr(ann, "__torch_script_class__"):
        return ClassType(_qualified_name(ann))
    elif hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(_qualified_name(ann))
    elif resolver is not None:
        # Maybe resolve a NamedTuple to a Tuple Type. Objects without a
        # __name__ cannot be looked up by name; report them as unknown below
        # instead of failing with AttributeError.
        name = getattr(ann, '__name__', None)
        if name is not None:
            rcb, loc = resolver
            the_type = torch._C._resolve_type(name, loc, rcb)
            if the_type is not None:
                return the_type
    raise ValueError("Unknown type annotation: '{}'".format(ann))
# Names exported by a wildcard import of this module.
__all__ = [
    # Typing helpers re-exported from _jit_internal.
    'Any',
    'List',
    'BroadcastingList1',
    'BroadcastingList2',
    'BroadcastingList3',
    'Tuple',
    'is_tuple',
    'is_list',
    'Dict',
    'is_dict',
    # TorchScript type classes re-exported from torch._C.
    'TensorType',
    'TupleType',
    'FloatType',
    'IntType',
    'ListType',
    'StringType',
    'DictType',
    'AnyType',
    # Defined in this module.
    'Module',
    # TODO: Consider not exporting these during wildcard import (reserve
    # that for the types; for idiomatic typing code.)
    'get_signature',
    'get_num_params',
    'parse_type_line',
    'get_type_line',
    'split_type_line',
    'try_real_annotations',
    'ann_to_type',
]