Add support for resolvers to model side effects of functions on local variables. Add support for static value annotations, which are needed to properly resolve names in modules. Add support for attributes, based on the static value if available. Clean up the resolver interface to reduce the amount of indirection.
PiperOrigin-RevId: 323562369
Change-Id: I934311ca47edf5d175e95940994e76b57d8a2c9c
diff --git a/tensorflow/python/autograph/pyct/anno.py b/tensorflow/python/autograph/pyct/anno.py
index e6c40fc..90535ff 100644
--- a/tensorflow/python/autograph/pyct/anno.py
+++ b/tensorflow/python/autograph/pyct/anno.py
@@ -36,8 +36,14 @@
class NoValue(enum.Enum):
+ def of(self, node, default=None):
+ return getanno(node, self, default=default)
+
+ def exists(self, node):
+ return hasanno(node, self)
+
def __repr__(self):
- return self.name
+ return str(self.name)
class Basic(NoValue):
@@ -102,6 +108,7 @@
LIVE_VARS_IN = ('Symbols live when entering the node. See liveness.py.')
TYPES = 'Static type information. See type_inference.py.'
CLOSURE_TYPES = 'Types of closure symbols at each detected call site.'
+ VALUE = 'Static value information. See type_inference.py.'
FAIL = object()
diff --git a/tensorflow/python/autograph/pyct/static_analysis/type_inference.py b/tensorflow/python/autograph/pyct/static_analysis/type_inference.py
index 4e8a9a9..cf866ad 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/type_inference.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/type_inference.py
@@ -20,6 +20,10 @@
* global and local symbols visible to the function at analysis time
* literals
+Important: This analysis is static, and does not detect dynamic type changes.
+The analysis attempts to use the values of external symbols, if available. These
+values are also considered static for the purpose of analysis.
+
Requires reaching function definitions analysis.
"""
@@ -27,7 +31,7 @@
from __future__ import division
from __future__ import print_function
-from typing import Any, Tuple
+from typing import Tuple
import gast
@@ -41,42 +45,68 @@
class Resolver(object):
"""Resolver objects handle the process of looking up actual names and types.
- All resolve_* methods:
+ Unless noted otherwise, all resolve_* methods:
* have a first namespace argument, mapping string to actual values
+ * have a second types_namespace argument, mapping string to actual inferred
+ types
* specify names as QN objects
* specify types as a Set of inferred types
- All resolve_* methods must return either:
+ Unless noted otherwise, all resolve_* methods must return either:
* a set of `type` objects
* None
"""
- def res_name(self, ns, name):
- """Resolves the type an external (e.g. closure, global) variable."""
+ def res_name(self, ns, types_ns, name):
+ """Resolves the type/value an external (e.g. closure, global) variable.
+
+ Args:
+ ns: namespace
+ types_ns: types namespace
+ name: symbol name
+ Returns:
+ Tuple (type, static_value). The first element is the type to use for
+ inferrence. The second is the static value to use. Return None to treat it
+ as unknown.
+ """
raise NotImplementedError('subclasses must implement')
def res_value(self, ns, value):
- """Resolves the type a literal value."""
+ """Resolves the type a literal or static value."""
raise NotImplementedError('subclasses must implement')
- # TODO(mdan): Allow caller to model side effects.
- def res_call(self, ns, name, target, args, keywords, starargs, kwargs):
+ def res_arg(self, ns, types_ns, f_name, name, type_anno):
+ """Resolves the type of a (possibly annotated) function argument."""
+ raise NotImplementedError('subclasses must implement')
+
+ def res_call(self, ns, types_ns, node, args, keywords):
"""Resolves the return type an external function or method call.
Args:
ns: namespace
- name: str, the function name
- target: if this is a method call, the types of the method target, None
- otherwise
- args: list or argument types
- keywords: dict of name to argument types
- starargs: list of types of the *args arguments (should be at most one)
- kwargs: list of types of the **kwargs arguments (in order of appearance)
+ types_ns: types namespace
+ node: str, the function name
+ args: types of each respective argument in node.args
+ keywords: types of each respective argument in node.keywords
+
+ Returns:
+ Tuple (return_type, side_effect_types). The first element is just the
+ return types of the function. The second element is a map from
+ argument names to sets of types, and allow modelling side effects of
+ functions (for example via global or nonlocal).
"""
raise NotImplementedError('subclasses must implement')
- def res_arg(self, ns, f_name, arg_name, type_anno):
- """Resolves the type of a (possibly annotated) function argument."""
+ def res_subscript(self, ns, types_ns, node, value, slice_):
+ """Resolves the return type of a unary operation."""
+ raise NotImplementedError('subclasses must implement')
+
+ def res_compare(self, ns, types_ns, node, left, right):
+ """Resolves the return type of a unary operation."""
+ raise NotImplementedError('subclasses must implement')
+
+ def res_binop(self, ns, types_ns, node, left, right):
+ """Resolves the return type of a unary operation."""
raise NotImplementedError('subclasses must implement')
@@ -86,23 +116,23 @@
This is a value type. Only implements the strictly necessary operators.
Attributes:
- value: Dict[qual_names.QN, Set[Type]], mapping symbols to the set of
+ types: Dict[qual_names.QN, Set[Type]], mapping symbols to the set of
possible types.
"""
def __init__(self, init_from=None):
if init_from:
assert isinstance(init_from, _SymbolTable)
- self.value = {
- s: set(other_types) for s, other_types in init_from.value.items()
+ self.types = {
+ s: set(other_types) for s, other_types in init_from.types.items()
}
else:
- self.value = {}
+ self.types = {}
def __eq__(self, other):
- if frozenset(self.value.keys()) != frozenset(other.value.keys()):
+ if frozenset(self.types.keys()) != frozenset(other.types.keys()):
return False
- ret = all(self.value[s] == other.value[s] for s in self.value)
+ ret = all(self.types[s] == other.types[s] for s in self.types)
return ret
def __ne__(self, other):
@@ -111,52 +141,17 @@
def __or__(self, other):
assert isinstance(other, _SymbolTable)
result = _SymbolTable(self)
- for s, other_types in other.value.items():
- if s not in result.value:
+ for s, other_types in other.types.items():
+ if s not in result.types:
self_types = set()
- result.value[s] = self_types
+ result.types[s] = self_types
else:
- self_types = result.value[s]
+ self_types = result.types[s]
self_types.update(other_types)
return result
def __repr__(self):
- return 'SymbolTable {}'.format(self.value)
-
-
-_GETITEM = qual_names.QN('__getitem__')
-
-_HANDLERS = {
- gast.Eq: qual_names.QN('__eq__'),
- gast.NotEq: qual_names.QN('__ne__'),
- gast.Lt: qual_names.QN('__lt__'),
- gast.LtE: qual_names.QN('__le__'),
- gast.Gt: qual_names.QN('__gt__'),
- gast.GtE: qual_names.QN('__ge__'),
- gast.In: qual_names.QN('__contains__'),
- # TODO(mdan): Is this actually correct?
- # NotIn(*) = Not(In(*))
- gast.NotIn: qual_names.QN('__not__'),
-
- gast.Add: qual_names.QN('__add__'),
- gast.Sub: qual_names.QN('__sub__'),
- gast.Mult: qual_names.QN('__mul__'),
- gast.Div: qual_names.QN('__div__'),
- gast.FloorDiv: qual_names.QN('__floordiv__'),
- gast.Mod: qual_names.QN('__mod__'),
- gast.Pow: qual_names.QN('__pow__'),
- gast.LShift: qual_names.QN('__lshift__'),
- gast.RShift: qual_names.QN('__rshift__'),
- gast.BitOr: qual_names.QN('__or__'),
- gast.BitXor: qual_names.QN('__xor__'),
- gast.BitAnd: qual_names.QN('__and__'),
- gast.MatMult: qual_names.QN('__matmul__'),
-}
-
-_FIXED_RETTYPES = {
- gast.Is: bool,
- gast.IsNot: bool,
-}
+ return 'SymbolTable {}'.format(self.types)
class StmtInferrer(gast.NodeVisitor):
@@ -164,6 +159,21 @@
This visitor annotates most nodes with type information. It also sets types
for the symbols modified by this statement in its types_out property.
+
+ Note: this inferrer is able to capture side effects of functions, however,
+ these side effects will not be applied to the current expression. Doing so
+ would create too much of a dependence on the runtime's internal rules about
+ execution order.
+ Example:
+
+ def f():
+ nonlocal a
+ a = 1
+ return a
+
+ a = 0.0
+ b = f() + a # a = float; side effect of f() ignored
+ print(a) # a = int; side effect of f() accounted for
"""
def __init__(self, resolver, scope, namespace, closure_types, types_in):
@@ -173,7 +183,7 @@
self.closure_types = closure_types
self.types_in = types_in
self.new_symbols = {}
- self.rvalue = None
+ self.rtype = None
def visit(self, node):
types = super().visit(node)
@@ -184,10 +194,19 @@
def visit_FunctionDef(self, node):
# Skip local function definitions. They are analyzed separately.
+ # TODO(mdan): Don't skip. Analyze side effects instead.
return None
+ def _check_set(self, value):
+ if value is not None and not isinstance(value, set):
+ raise ValueError('{} method expected to return set, got {}'.format(
+ self.resolver, value))
+
def visit_Constant(self, node):
- return self.resolver.res_value(self.namespace, node.value)
+ types = self.resolver.res_value(self.namespace, node.value)
+ if __debug__:
+ self._check_set(types)
+ return types
def visit_Tuple(self, node):
if isinstance(node.ctx, gast.Load):
@@ -214,116 +233,156 @@
def visit_Name(self, node):
name = anno.getanno(node, anno.Basic.QN)
+
if isinstance(node.ctx, gast.Load):
- types = self.types_in.value.get(name, None)
+ types = self.types_in.types.get(name, None)
if (types is None) and (name not in self.scope.bound):
if name in self.closure_types:
types = self.closure_types[name]
else:
- types = self.resolver.res_name(self.namespace, name)
- return types
+ types, value = self.resolver.res_name(
+ self.namespace, self.types_in.types, name)
+ if value is not None:
+ anno.setanno(node, anno.Static.VALUE, value)
elif isinstance(node.ctx, gast.Param):
type_name = anno.getanno(node.annotation, anno.Basic.QN, None)
- types = self.resolver.res_arg(self.namespace, self.scope.function_name,
- name, type_name)
+ types = self.resolver.res_arg(self.namespace, self.types_in.types,
+ self.scope.function_name, name, type_name)
if types is not None:
self.new_symbols[name] = types
- return types
elif isinstance(node.ctx, gast.Store):
- if self.rvalue is not None:
- self.new_symbols[name] = self.rvalue
- else:
- # No type information, assume Any.
- self.new_symbols[name] = {Any}
- return self.rvalue
+ if self.rtype is not None:
+ self.new_symbols[name] = self.rtype
+ types = self.rtype
- assert False, 'unknown ctx'
+ else:
+ assert False, 'unknown ctx'
+
+ if __debug__:
+ self._check_set(types)
+
+ return types
+
+ def visit_Attribute(self, node):
+ parent_types = self.visit(node.value)
+
+ # Attempt to use the static value if known.
+ parent_value = anno.Static.VALUE.of(node.value, None)
+ if parent_value is not None:
+ static_value = getattr(parent_value, node.attr, None)
+
+ else:
+ # Fall back to the type if that is known.
+ if parent_types is None:
+ return None
+
+ inferred_values = [getattr(t, node.attr, None) for t in parent_types]
+ if not inferred_values:
+ return None
+
+ static_value = inferred_values[0]
+ if static_value is None:
+ return None
+
+ if any(v is not static_value for v in inferred_values[1:]):
+ # Static value not stable, assume it's dynamic.
+ return None
+
+ types = self.resolver.res_value(self.namespace, static_value)
+ anno.setanno(node, anno.Static.VALUE, static_value)
+
+ if __debug__:
+ self._check_set(types)
+
+ return types
def visit_Call(self, node):
+ self.visit(node.func)
+
f_name = anno.getanno(node.func, anno.Basic.QN)
-
- kwargs = [self.visit(kw.value) for kw in node.keywords if kw.arg is None]
- keywords = {
- kw.arg: self.visit(kw.value)
- for kw in node.keywords
- if kw.arg is not None
- }
- is_starred = [isinstance(a, gast.Starred) for a in node.args]
- args = [
- self.visit(a)
- for a, starred in zip(node.args, is_starred)
- if not starred
- ]
- starargs = [
- self.visit(a.value)
- for a, starred in zip(node.args, is_starred)
- if starred
- ]
-
if f_name in self.scope.bound:
# Don't attempt external resolution of local functions.
# TODO(mdan): Use type annotations of the local definition.
return None
- return self.resolver.res_call(
- self.namespace, f_name, None, args, keywords, starargs, kwargs)
+ arg_types = [self.visit(a) for a in node.args]
+ keyword_types = [self.visit(kw.value) for kw in node.keywords]
+
+ ret_type, side_effects = self.resolver.res_call(self.namespace,
+ self.types_in.types, node,
+ arg_types, keyword_types)
+ if __debug__:
+ self._check_set(ret_type)
+ if side_effects:
+ if not isinstance(side_effects, dict):
+ raise ValueError(
+ 'side effects must be dict, got {}'.format(side_effects))
+ for k, v in side_effects.items():
+ if not isinstance(k, qual_names.QN):
+ raise ValueError('side effect keys must be QNs, got {}'.format(k))
+ self._check_set(v)
+
+ if side_effects:
+ self.new_symbols.update(side_effects)
+ return ret_type
def visit_Index(self, node):
return self.visit(node.value)
def visit_Assign(self, node):
- self.rvalue = self.visit(node.value)
+ self.rtype = self.visit(node.value)
for t in node.targets:
self.visit(t)
- self.rvalue = None
+ self.rtype = None
def visit_Subscript(self, node):
- val_type = self.visit(node.value)
- slice_type = self.visit(node.slice)
+ val_types = self.visit(node.value)
+ slice_types = self.visit(node.slice)
- if val_type is None or slice_type is None:
+ if val_types is None or slice_types is None:
return None
- return self.resolver.res_call(self.namespace, _GETITEM, val_type,
- (slice_type,), {}, (), ())
+ types = self.resolver.res_subscript(
+ self.namespace, self.types_in.types, node, val_types, slice_types)
+
+ if __debug__:
+ self._check_set(types)
+
+ return types
def visit_Compare(self, node):
+ left_types = self.visit(node.left)
right_types = [self.visit(c) for c in node.comparators]
- op_types = [type(o) for o in node.ops]
- if len(op_types) > 1:
- raise NotImplementedError('chained comparisons')
- assert len(right_types) == 1
- left_type = self.visit(node.left)
- right_type, = right_types
- op_type, = op_types
-
- if left_type is None or right_type is None:
+ if left_types is None or any(t is None for t in right_types):
return None
- f_name = _HANDLERS.get(op_type, None)
- if f_name is None:
- # Python doesn't allow overriding these operators. Their return types are
- # fixed.
- return {_FIXED_RETTYPES[op_type]}
- return self.resolver.res_call(self.namespace, _HANDLERS[op_type],
- left_type, (right_type,), {}, (), ())
+ types = self.resolver.res_compare(
+ self.namespace, self.types_in.types, node, left_types, right_types)
+
+ if __debug__:
+ self._check_set(types)
+
+ return types
def visit_BinOp(self, node):
- left_type = self.visit(node.left)
- right_type = self.visit(node.right)
+ left_types = self.visit(node.left)
+ right_types = self.visit(node.right)
- if left_type is None or right_type is None:
+ if left_types is None or right_types is None:
return None
- # TODO(mdan): This does not fully follow Python operator semantics.
- # For example, in `a + b` Python will try `a.__add__`, but also `b.__radd__`
- return self.resolver.res_call(self.namespace, _HANDLERS[type(node.op)],
- left_type, (right_type,), {}, (), ())
+ types = self.resolver.res_binop(
+ self.namespace, self.types_in.types, node, left_types, right_types)
+
+ if __debug__:
+ self._check_set(types)
+
+ return types
class Analyzer(cfg.GraphVisitor):
@@ -355,7 +414,7 @@
existing_types = {}
anno.setanno(ast_node, anno.Static.CLOSURE_TYPES, existing_types)
- for k, v in types.value.items():
+ for k, v in types.types.items():
if k in existing_types:
existing_types[k].update(v)
else:
@@ -371,10 +430,10 @@
types_out = _SymbolTable(types_in)
ast_node = node.ast_node
- inferrer = StmtInferrer(
- self.resolver, self.scope, self.namespace, self.closure_types, types_in)
+ inferrer = StmtInferrer(self.resolver, self.scope, self.namespace,
+ self.closure_types, types_in)
inferrer.visit(ast_node)
- types_out.value.update(inferrer.new_symbols)
+ types_out.types.update(inferrer.new_symbols)
reaching_fndefs = anno.getanno(ast_node, anno.Static.DEFINED_FNS_IN)
node_scope = anno.getanno(ast_node, anno.Static.SCOPE, None)
@@ -404,8 +463,8 @@
scope = anno.getanno(node, annos.NodeAnno.ARGS_AND_BODY_SCOPE)
closure_types = anno.getanno(node, anno.Static.CLOSURE_TYPES, {})
- analyzer = Analyzer(
- subgraph, self.resolver, self.ctx.info.namespace, scope, closure_types)
+ analyzer = Analyzer(subgraph, self.resolver, self.ctx.info.namespace, scope,
+ closure_types)
analyzer.visit_forward()
# Recursively process any remaining subfunctions.
diff --git a/tensorflow/python/autograph/pyct/static_analysis/type_inference_test.py b/tensorflow/python/autograph/pyct/static_analysis/type_inference_test.py
index fb7324a..e3cb7e0 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/type_inference_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/type_inference_test.py
@@ -29,37 +29,24 @@
from tensorflow.python.platform import test
-class TestResolver(type_inference.Resolver):
+class BasicTestResolver(type_inference.Resolver):
"""A very basic resolver for testing."""
- def res_name(self, ns, name):
- return {type(ns[str(name)])}
+ def res_name(self, ns, types_ns, name):
+ return {type(ns[str(name)])}, ns[str(name)]
def res_value(self, ns, value):
- del ns
return {type(value)}
- def res_call(self, ns, name, target, args, keywords, starargs, kwargs):
- name_str = str(name)
- if name_str in ns:
- return {ns[name_str].__annotations__['return']}
- if target is None:
- return {'unk_{}'.format(name_str)}
- return {'{}_{}'.format(list(target)[0], name_str)}
-
- def res_arg(self, ns, f_name, arg_name, type_anno):
- if f_name == 'magic_no_types':
- return None
- if type_anno is not None:
- return {{'int': int, 'float': float}[str(type_anno)]}
- return {'{}_{}'.format(f_name, arg_name)}
+ def res_arg(self, ns, types_ns, f_name, name, type_anno):
+ return {str(type_anno)}
class TestTranspiler(transpiler.GenericTranspiler):
- def __init__(self):
+ def __init__(self, resolver_type):
super().__init__()
- self.resolver = TestResolver()
+ self.resolver = resolver_type()
def get_transformed_name(self, _):
return 'test_item'
@@ -87,16 +74,58 @@
actual = {str(k): v for k, v in actual.items()}
self.assertDictEqual(actual, expected)
+ def test_no_inference_on_unknown_operand_types(self):
+
+ class Resolver(type_inference.Resolver):
+
+ def res_arg(self, ns, types_ns, f_name, name, type_anno):
+ return None
+
+ def test_fn(a, b):
+ return a < b, a - b
+
+ node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+ fn_body = node.body
+
+ # With no information on operand types, the operators will infer nothing.
+ self.assertFalse(
+ anno.hasanno(fn_body[0].value.elts[0], anno.Static.TYPES))
+ self.assertFalse(
+ anno.hasanno(fn_body[0].value.elts[1], anno.Static.TYPES))
+
+ def test_resolver_output_checked(self):
+
+ class Resolver(type_inference.Resolver):
+
+ def res_arg(self, ns, types_ns, f_name, name, type_anno):
+ return 1
+
+ def test_fn(a):
+ del a
+ pass
+
+ with self.assertRaisesRegex(ValueError, 'expected to return set'):
+ TestTranspiler(Resolver).transform(test_fn, None)
+
def test_argument(self):
+ test_self = self
+
+ class Resolver(type_inference.Resolver):
+
+ def res_arg(self, ns, types_ns, f_name, name, type_anno):
+ if name == qual_names.QN('a'):
+ test_self.assertEqual(type_anno, qual_names.QN('int'))
+ return {str(name) + '_type'}
+
def test_fn(a: int, b):
return a, b
- node, _ = TestTranspiler().transform(test_fn, None)
+ node, _ = TestTranspiler(Resolver).transform(test_fn, None)
fn_body = node.body
- self.assertTypes(fn_body[0].value.elts[0], int)
- self.assertTypes(fn_body[0].value.elts[1], 'test_fn_b')
+ self.assertTypes(fn_body[0].value.elts[0], 'a_type')
+ self.assertTypes(fn_body[0].value.elts[1], 'b_type')
def test_argument_of_local_function(self):
@@ -107,42 +136,238 @@
return foo(a)
- tr = TestTranspiler()
+ tr = TestTranspiler(BasicTestResolver)
node, _ = tr.transform(test_fn, None)
fn_body = node.body
- self.assertTypes(fn_body[0].body[0].value, float)
- self.assertClosureTypes(fn_body[0], {'a': {int}})
+ self.assertTypes(fn_body[0].body[0].value, 'float')
+ self.assertClosureTypes(fn_body[0], {'a': {'int'}})
- def test_straightline_assignment(self):
+ def test_assign_straightline(self):
- def test_fn(a: int, c):
+ def test_fn(a: int, c: float):
b = a
return a, b, c
- node, _ = TestTranspiler().transform(test_fn, None)
+ node, _ = TestTranspiler(BasicTestResolver).transform(test_fn, None)
fn_body = node.body
- self.assertTypes(fn_body[0].targets[0], int)
- self.assertTypes(fn_body[0].value, int)
- self.assertTypes(fn_body[1].value.elts[0], int)
- self.assertTypes(fn_body[1].value.elts[1], int)
- self.assertTypes(fn_body[1].value.elts[2], 'test_fn_c')
+ self.assertTypes(fn_body[0].targets[0], 'int')
+ self.assertTypes(fn_body[0].value, 'int')
+ self.assertTypes(fn_body[1].value.elts[0], 'int')
+ self.assertTypes(fn_body[1].value.elts[1], 'int')
+ self.assertTypes(fn_body[1].value.elts[2], 'float')
- def test_assignment_overwrite(self):
+ def test_expr(self):
+
+ self_test = self
+
+ class Resolver(type_inference.Resolver):
+
+ def res_value(self, ns, value):
+ self_test.assertEqual(value, tc.a)
+ return {str}
+
+ def res_name(self, ns, types_ns, name):
+ self_test.assertEqual(name, qual_names.QN('tc'))
+ return {TestClass}, tc
+
+ def res_call(self, ns, types_ns, node, args, keywords):
+ return {int}, None
+
+ class TestClass:
+
+ def a(self):
+ pass
+
+ tc = TestClass()
+
+ def test_fn():
+ tc.a()
+
+ node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+ fn_body = node.body
+
+ self.assertTypes(fn_body[0].value, int)
+ self.assertTypes(fn_body[0].value.func, str)
+ self.assertEqual(
+ anno.getanno(fn_body[0].value.func, anno.Static.VALUE), tc.a)
+
+ def test_assign_overwriting(self):
def test_fn(a: int, b: float):
c = a
c = b
return c
- node, _ = TestTranspiler().transform(test_fn, None)
+ node, _ = TestTranspiler(BasicTestResolver).transform(test_fn, None)
fn_body = node.body
- self.assertTypes(fn_body[0].targets[0], int)
+ self.assertTypes(fn_body[0].targets[0], 'int')
+ self.assertTypes(fn_body[0].value, 'int')
+ self.assertTypes(fn_body[1].targets[0], 'float')
+ self.assertTypes(fn_body[1].value, 'float')
+
+ def test_dynamic_attribute_of_static_value(self):
+
+ test_self = self
+
+ class Resolver(type_inference.Resolver):
+
+ def res_value(self, ns, value):
+ test_self.assertEqual(value, tc.a)
+ return {int}
+
+ def res_name(self, ns, types_ns, name):
+ test_self.assertEqual(name, qual_names.QN('tc'))
+ return {TestClass}, tc
+
+ class TestClass:
+
+ def __init__(self):
+ self.a = 1
+
+ tc = TestClass()
+
+ def test_fn():
+ return tc.a
+
+ node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+ fn_body = node.body
+
+ self.assertTypes(fn_body[0].value.value, TestClass)
self.assertTypes(fn_body[0].value, int)
- self.assertTypes(fn_body[1].targets[0], float)
- self.assertTypes(fn_body[1].value, float)
+ self.assertIs(anno.getanno(fn_body[0].value.value, anno.Static.VALUE), tc)
+ self.assertEqual(anno.getanno(fn_body[0].value, anno.Static.VALUE), tc.a)
+
+ def test_static_attribute_of_typed_value(self):
+
+ test_self = self
+
+ class TestClass:
+
+ a = 1
+
+ tc = TestClass()
+
+ class Resolver(type_inference.Resolver):
+
+ def res_name(self, ns, types_ns, name):
+ test_self.assertEqual(name, qual_names.QN('tc'))
+ return {TestClass}, None
+
+ def res_value(self, ns, value):
+ test_self.assertIs(value, tc.a)
+ return {str}
+
+ def test_fn():
+ return tc.a
+
+ node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+ fn_body = node.body
+
+ self.assertTypes(fn_body[0].value.value, TestClass)
+ self.assertTypes(fn_body[0].value, str) # Resolver is SOT
+ self.assertFalse(anno.hasanno(fn_body[0].value.value, anno.Static.VALUE))
+ self.assertEqual(anno.getanno(fn_body[0].value, anno.Static.VALUE), 1)
+
+ def test_static_attribute_of_ambiguous_type(self):
+
+ test_self = self
+
+ class TestClass1:
+
+ a = 1
+
+ class TestClass2:
+
+ a = 2
+
+ tc = TestClass1()
+
+ class Resolver(type_inference.Resolver):
+
+ def res_name(self, ns, types_ns, name):
+ test_self.assertEqual(name, qual_names.QN('tc'))
+ return {TestClass1, TestClass2}, None
+
+ def res_value(self, ns, value):
+ test_self.assertIn(value, (1, 2))
+ return {str}
+
+ def test_fn():
+ return tc.a
+
+ node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+ fn_body = node.body
+
+ self.assertTypes(fn_body[0].value.value, (TestClass1, TestClass2))
+ self.assertFalse(anno.hasanno(fn_body[0].value, anno.Static.TYPES))
+ self.assertFalse(anno.hasanno(fn_body[0].value.value, anno.Static.VALUE))
+ self.assertFalse(anno.hasanno(fn_body[0].value, anno.Static.VALUE))
+
+ def test_property_of_typed_value(self):
+
+ test_self = self
+
+ class TestClass:
+
+ @property
+ def a(self):
+ return 1
+
+ tc = TestClass()
+
+ class Resolver(type_inference.Resolver):
+
+ def res_name(self, ns, types_ns, name):
+ test_self.assertEqual(name, qual_names.QN('tc'))
+ return {TestClass}, None
+
+ def res_value(self, ns, value):
+ test_self.assertIs(value, TestClass.a)
+ test_self.assertNotEqual(value, 1) # Can't evaluate property of class.
+ return {property}
+
+ def test_fn():
+ return tc.a
+
+ node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+ fn_body = node.body
+
+ self.assertTypes(fn_body[0].value.value, TestClass)
+ self.assertTypes(fn_body[0].value, property)
+ self.assertFalse(anno.hasanno(fn_body[0].value.value, anno.Static.VALUE))
+ self.assertEqual(
+ anno.getanno(fn_body[0].value, anno.Static.VALUE), TestClass.a)
+
+ def test_dynamic_attribute_of_typed_value(self):
+
+ test_self = self
+
+ class TestClass:
+
+ def __init__(self):
+ self.a = 1
+
+ tc = TestClass()
+
+ class Resolver(type_inference.Resolver):
+
+ def res_name(self, ns, types_ns, name):
+ test_self.assertEqual(name, qual_names.QN('tc'))
+ return {TestClass}, None
+
+ def test_fn():
+ return tc.a
+
+ node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+ fn_body = node.body
+
+ self.assertTypes(fn_body[0].value.value, TestClass)
+ self.assertFalse(anno.hasanno(fn_body[0].value, anno.Static.TYPES))
+ self.assertFalse(anno.hasanno(fn_body[0].value.value, anno.Static.VALUE))
+ self.assertFalse(anno.hasanno(fn_body[0].value, anno.Static.VALUE))
def test_external_value(self):
@@ -152,7 +377,7 @@
b = a
return b
- node, _ = TestTranspiler().transform(test_fn, None)
+ node, _ = TestTranspiler(BasicTestResolver).transform(test_fn, None)
fn_body = node.body
self.assertTypes(fn_body[0].targets[0], str)
@@ -160,6 +385,19 @@
def test_external_function(self):
+ test_self = self
+
+ class Resolver(type_inference.Resolver):
+
+ def res_name(self, ns, types_ns, name):
+ test_self.assertEqual(name, qual_names.QN('g'))
+ return {str}, g
+
+ def res_call(self, ns, types_ns, node, args, keywords):
+ test_self.assertEqual(
+ anno.getanno(node.func, anno.Basic.QN), qual_names.QN('g'))
+ return {float}, None
+
def g() -> float:
return 1.0
@@ -167,12 +405,49 @@
a = g()
return a
- node, _ = TestTranspiler().transform(test_fn, None)
+ node, _ = TestTranspiler(Resolver).transform(test_fn, None)
fn_body = node.body
+ self.assertTypes(fn_body[0].value.func, str)
self.assertTypes(fn_body[0].targets[0], float)
self.assertTypes(fn_body[1].value, float)
+ def test_external_function_side_effects(self):
+
+ test_self = self
+
+ class Resolver(type_inference.Resolver):
+
+ def res_name(self, ns, types_ns, name):
+ test_self.assertEqual(name, qual_names.QN('g'))
+ return None, g
+
+ def res_arg(self, ns, types_ns, f_name, name, type_anno):
+ return {str(type_anno)}
+
+ def res_call(self, ns, types_ns, node, args, keywords):
+ return None, {qual_names.QN('x'): {str}}
+
+ def g():
+ # The resolver will pretend that this function has the following body:
+ #
+ # nonlocal x
+ # x = 'a'
+ pass
+
+ def test_fn(x: int):
+ y = x
+ g()
+ return x, y
+
+ node, _ = TestTranspiler(Resolver).transform(test_fn, None)
+ fn_body = node.body
+
+ self.assertTypes(fn_body[0].targets[0], 'int')
+ self.assertTypes(fn_body[0].value, 'int')
+ self.assertTypes(fn_body[2].value.elts[0], str)
+ self.assertTypes(fn_body[2].value.elts[1], 'int')
+
def test_local_function_closure(self):
def test_fn(x: int):
@@ -182,27 +457,27 @@
foo()
- node, _ = TestTranspiler().transform(test_fn, None)
+ node, _ = TestTranspiler(BasicTestResolver).transform(test_fn, None)
fn_body = node.body
- self.assertTypes(fn_body[0].body[0].value, int)
- self.assertClosureTypes(fn_body[0], {'x': {int}})
+ self.assertTypes(fn_body[0].body[0].value, 'int')
+ self.assertClosureTypes(fn_body[0], {'x': {'int'}})
def test_local_function_closure_ignored_for_bound_symbols(self):
- def test_fn(x: int): # pylint:disable=unused-argument
+ def test_fn(x: float): # pylint:disable=unused-argument
def foo():
x = x + 1 # pylint:disable=used-before-assignment
foo()
- node, _ = TestTranspiler().transform(test_fn, None)
+ node, _ = TestTranspiler(BasicTestResolver).transform(test_fn, None)
fn_body = node.body
self.assertFalse(
anno.hasanno(fn_body[0].body[0].value.left, anno.Static.TYPES))
- self.assertClosureTypes(fn_body[0], {'x': {int}})
+ self.assertClosureTypes(fn_body[0], {'x': {'float'}})
def test_local_function_closure_uses_call_site_types(self):
@@ -214,7 +489,7 @@
x = 1.0
foo()
- node, _ = TestTranspiler().transform(test_fn, None)
+ node, _ = TestTranspiler(BasicTestResolver).transform(test_fn, None)
fn_body = node.body
self.assertTypes(fn_body[0].body[0].value, float)
@@ -223,54 +498,78 @@
def test_subscript(self):
+ test_self = self
+
+ class Resolver(type_inference.Resolver):
+
+ def res_arg(self, ns, types_ns, f_name, name, type_anno):
+ return {list}
+
+ def res_value(self, ns, value):
+ return {int}
+
+ def res_subscript(self, ns, types_ns, node, value, slice_):
+ test_self.assertSetEqual(value, {list})
+ test_self.assertSetEqual(slice_, {int})
+ return {str}
+
def test_fn(a):
return a[1]
- node, _ = TestTranspiler().transform(test_fn, None)
+ node, _ = TestTranspiler(Resolver).transform(test_fn, None)
fn_body = node.body
- self.assertTypes(fn_body[0].value, 'test_fn_a___getitem__')
- self.assertTypes(fn_body[0].value.value, 'test_fn_a')
+ self.assertTypes(fn_body[0].value, str)
+ self.assertTypes(fn_body[0].value.value, list)
self.assertTypes(fn_body[0].value.slice.value, int)
def test_compare(self):
+ test_self = self
+
+ class Resolver(type_inference.Resolver):
+
+ def res_arg(self, ns, types_ns, f_name, name, type_anno):
+ return {int}
+
+ def res_compare(self, ns, types_ns, node, left, right):
+ test_self.assertSetEqual(left, {int})
+ test_self.assertListEqual(right, [{int}])
+ return {bool}
+
def test_fn(a, b):
return a < b
- node, _ = TestTranspiler().transform(test_fn, None)
+ node, _ = TestTranspiler(Resolver).transform(test_fn, None)
fn_body = node.body
- self.assertTypes(fn_body[0].value, 'test_fn_a___lt__')
- self.assertTypes(fn_body[0].value.left, 'test_fn_a')
- self.assertTypes(fn_body[0].value.comparators[0], 'test_fn_b')
+ self.assertTypes(fn_body[0].value, bool)
+ self.assertTypes(fn_body[0].value.left, int)
+ self.assertTypes(fn_body[0].value.comparators[0], int)
def test_binop(self):
+ test_self = self
+
+ class Resolver(type_inference.Resolver):
+
+ def res_arg(self, ns, types_ns, f_name, name, type_anno):
+ return {list}
+
+ def res_binop(self, ns, types_ns, node, left, right):
+ test_self.assertSetEqual(left, {list})
+ test_self.assertSetEqual(right, {list})
+ return {float}
+
def test_fn(a, b):
return a @ b
- node, _ = TestTranspiler().transform(test_fn, None)
+ node, _ = TestTranspiler(Resolver).transform(test_fn, None)
fn_body = node.body
- self.assertTypes(fn_body[0].value, 'test_fn_a___matmul__')
- self.assertTypes(fn_body[0].value.left, 'test_fn_a')
- self.assertTypes(fn_body[0].value.right, 'test_fn_b')
-
- def test_no_inference_on_unknown_operand_types(self):
-
- # No information on types of a and b, see TestResolver.
- def magic_no_types(a, b):
- return a < b, a - b
-
- node, _ = TestTranspiler().transform(magic_no_types, None)
- fn_body = node.body
-
- # With no information on operand types, the operators will assert nothing.
- self.assertFalse(
- anno.hasanno(fn_body[0].value.elts[0], anno.Static.TYPES))
- self.assertFalse(
- anno.hasanno(fn_body[0].value.elts[1], anno.Static.TYPES))
+ self.assertTypes(fn_body[0].value, float)
+ self.assertTypes(fn_body[0].value.left, list)
+ self.assertTypes(fn_body[0].value.right, list)
if __name__ == '__main__':