# blob: 086eda7574a4c0846ba35a8fcbe897be6bdccc1a [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Activity analysis.
Requires qualified name annotations (see qual_names.py).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import weakref
import gast
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import qual_names
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
# TODO(mdan): Add support for PY3 (e.g. Param vs arg).
# TODO(alexbw): Ignore named literals (e.g. None)
class Scope(object):
  """Encloses local symbol definition and usage information.

  This can track for instance whether a symbol is modified in the current scope.
  Note that scopes do not necessarily align with Python's scopes. For example,
  the body of an if statement may be considered a separate scope.

  Attributes:
    isolated: whether symbols created in this scope are kept out of the
      parent scope.
    parent: the enclosing Scope, or None for a root scope.
    add_unknown_symbols: whether creating an attribute/subscript whose base
      name was never seen is tolerated (the base is then marked as read)
      instead of raising ValueError.
    modified: identifiers modified in this scope
    created: identifiers created in this scope
    used: identifiers referenced in this scope
    params: maps parameter identifiers to weak references to their owner
      node, as recorded by mark_param
    returned: identifiers marked as returned in this scope
  """

  def __init__(self, parent, isolated=True, add_unknown_symbols=False):
    """Create a new scope.

    Args:
      parent: A Scope or None.
      isolated: Whether the scope is isolated, that is, whether variables
        created in this scope should be hidden from the parent scope.
      add_unknown_symbols: Whether to handle attributes and subscripts
        without having first seen the base name.
        E.g., analyzing the statement 'x.y = z' without first having seen 'x'.
    """
    self.isolated = isolated
    self.parent = parent
    self.add_unknown_symbols = add_unknown_symbols
    self.modified = set()
    # TODO(mdan): Completely remove this.
    self.created = set()
    self.used = set()
    self.params = {}
    self.returned = set()

  # TODO(mdan): Rename to `locals`
  @property
  def referenced(self):
    """Identifiers used here and, unless isolated, in the ancestor scopes."""
    if not self.isolated and self.parent is not None:
      return self.used | self.parent.referenced
    return self.used

  def __repr__(self):
    return 'Scope{r=%s, c=%s, w=%s}' % (tuple(self.used), tuple(self.created),
                                        tuple(self.modified))

  def copy_from(self, other):
    """Recursively copies the contents of this scope from another scope.

    Args:
      other: A Scope whose parent chain has the same length as this one's.

    Raises:
      ValueError: if the two parent chains have different lengths.
    """
    if (self.parent is None) != (other.parent is None):
      raise ValueError('cannot copy scopes of different structures')
    if other.parent is not None:
      self.parent.copy_from(other.parent)
    self.isolated = other.isolated
    # Copy the flag too; otherwise copy_of() would silently reset it to False.
    self.add_unknown_symbols = other.add_unknown_symbols
    # Shallow copies: the containers become independent, elements are shared.
    self.modified = copy.copy(other.modified)
    self.created = copy.copy(other.created)
    self.used = copy.copy(other.used)
    self.params = copy.copy(other.params)
    self.returned = copy.copy(other.returned)

  @classmethod
  def copy_of(cls, other):
    """Returns a new scope chain with the same contents as `other`'s chain."""
    if other.parent is not None:
      parent = cls.copy_of(other.parent)
    else:
      parent = None
    new_copy = cls(parent)
    new_copy.copy_from(other)
    return new_copy

  def merge_from(self, other):
    """Recursively unions the contents of `other` into this scope.

    Args:
      other: A Scope whose parent chain has the same length as this one's.

    Raises:
      ValueError: if the two parent chains have different lengths.
    """
    if (self.parent is None) != (other.parent is None):
      raise ValueError('cannot merge scopes of different structures')
    if other.parent is not None:
      self.parent.merge_from(other.parent)
    self.modified |= other.modified
    self.created |= other.created
    self.used |= other.used
    self.params.update(other.params)
    self.returned |= other.returned

  def has(self, name):
    """Whether `name` was modified in this scope or any of its ancestors."""
    if name in self.modified:
      return True
    elif self.parent is not None:
      return self.parent.has(name)
    return False

  def mark_read(self, name):
    """Records a read of `name`.

    The read propagates to the parent unless `name` was created in this scope.
    """
    self.used.add(name)
    if self.parent is not None and name not in self.created:
      self.parent.mark_read(name)

  def mark_param(self, name, owner):
    """Records that `name` is a parameter of the `owner` node."""
    # Assumption: all AST nodes have the same life span. This lets us use
    # a weak reference to mark the connection between a symbol node and the
    # function node whose argument that symbol is.
    self.params[name] = weakref.ref(owner)

  def mark_creation(self, name, writes_create_symbol=False):
    """Mark a qualified name as created.

    Args:
      name: A qualified name.
      writes_create_symbol: Only relevant for composite names (attributes and
        subscripts). If False, composite names are ignored. If True, they are
        created provided their parent name is known.

    Raises:
      ValueError: if a composite name's parent is unknown and
        add_unknown_symbols is False.
    """
    if name.is_composite():
      parent = name.parent
      if not writes_create_symbol:
        return
      else:
        if not self.has(parent):
          if self.add_unknown_symbols:
            self.mark_read(parent)
          else:
            raise ValueError('Unknown symbol "%s".' % parent)
    self.created.add(name)

  def mark_write(self, name):
    """Marks the given symbol as modified in the current scope.

    In a non-isolated scope the write also propagates to the parent. The
    symbol is recorded as created here only when no ancestor has modified it
    yet.
    """
    self.modified.add(name)
    if self.isolated:
      self.mark_creation(name)
    else:
      if self.parent is None:
        self.mark_creation(name)
      else:
        if not self.parent.has(name):
          self.mark_creation(name)
        self.parent.mark_write(name)

  def mark_returned(self, name):
    """Marks `name` as returned; propagates to the parent unless isolated."""
    self.returned.add(name)
    if not self.isolated and self.parent is not None:
      self.parent.mark_returned(name)
class ActivityAnalyzer(transformer.Base):
  """Annotates nodes with local scope information.

  See Scope.

  The use of this class requires that qual_names.resolve() has been called on
  the node. This class will ignore nodes that have not been annotated with
  their qualified names.
  """

  def __init__(self, context, parent_scope=None, add_unknown_symbols=False):
    super(ActivityAnalyzer, self).__init__(context)
    # The root scope is not isolated, so writes recorded during analysis also
    # reach `parent_scope` when one is given.
    self.scope = Scope(
        parent_scope, isolated=False, add_unknown_symbols=add_unknown_symbols)
    self._in_return_statement = False
    self._in_aug_assign = False

  @property
  def _in_constructor(self):
    """Whether the innermost entity is an `__init__` directly inside a class."""
    if len(self.enclosing_entities) > 1:
      innermost = self.enclosing_entities[-1]
      parent = self.enclosing_entities[-2]
      # getattr guards against unnamed enclosing entities (e.g. lambdas).
      return (isinstance(parent, gast.ClassDef) and
              getattr(innermost, 'name', None) == '__init__')
    return False

  def _node_sets_self_attribute(self, node):
    """Whether `node`'s qualified name is an attribute of `self` (`self.x`)."""
    if anno.hasanno(node, anno.Basic.QN):
      qn = anno.getanno(node, anno.Basic.QN)
      # TODO(mdan): The 'self' argument is not guaranteed to be called 'self'.
      # has_attr must be called; the bare bound method is always truthy.
      if qn.has_attr() and qn.parent.qn == ('self',):
        return True
    return False

  def _track_symbol(self,
                    node,
                    composite_writes_alter_parent=False,
                    writes_create_symbol=False):
    """Records a read, write or parameter definition of `node`'s symbol.

    Nodes that carry a qualified name are also annotated with
    NodeAnno.IS_LOCAL.

    Args:
      node: A Name, Attribute or Subscript node.
      composite_writes_alter_parent: Whether a write to a composite name
        (e.g. `a.b`) also counts as a write to its parent (`a`).
      writes_create_symbol: Whether a write to a composite name marks it as
        created in the current scope.

    Raises:
      ValueError: if the node has an unrecognized expression context.
    """
    # A QN may be missing when we have an attribute (or subscript) on a function
    # call. Example: a().b
    if not anno.hasanno(node, anno.Basic.QN):
      return
    qn = anno.getanno(node, anno.Basic.QN)

    if isinstance(node.ctx, gast.Store):
      self.scope.mark_write(qn)
      # is_composite must be called; the bare bound method is always truthy.
      if qn.is_composite() and composite_writes_alter_parent:
        self.scope.mark_write(qn.parent)
      if writes_create_symbol:
        self.scope.mark_creation(qn, writes_create_symbol=True)
      if self._in_aug_assign:
        # The target of `a += b` is read as well as written.
        self.scope.mark_read(qn)
    elif isinstance(node.ctx, gast.Load):
      self.scope.mark_read(qn)
    elif isinstance(node.ctx, gast.Param):
      # Param contexts appear in function defs, so they have the meaning of
      # defining a variable.
      self.scope.mark_write(qn)
      self.scope.mark_param(qn, self.enclosing_entities[-1])
    else:
      raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn))

    anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn))

    if self._in_return_statement:
      self.scope.mark_returned(qn)

  def _enter_scope(self, isolated):
    """Pushes a child scope; it inherits add_unknown_symbols from its parent."""
    self.scope = Scope(
        self.scope,
        isolated=isolated,
        add_unknown_symbols=self.scope.add_unknown_symbols)

  def _exit_scope(self):
    """Pops the current scope, returning to its parent."""
    self.scope = self.scope.parent

  def _process_statement(self, node):
    """Visits `node` in a fresh non-isolated scope and annotates it with it."""
    self._enter_scope(False)
    node = self.generic_visit(node)
    anno.setanno(node, anno.Static.SCOPE, self.scope)
    self._exit_scope()
    return node

  def visit_Expr(self, node):
    return self._process_statement(node)

  def visit_Return(self, node):
    self._in_return_statement = True
    node = self._process_statement(node)
    self._in_return_statement = False
    return node

  def visit_Assign(self, node):
    return self._process_statement(node)

  def visit_AugAssign(self, node):
    # Special rules for AugAssign. In Assign, the target is only written,
    # but in AugAssign (e.g. a += b), the target is both read and written.
    self._in_aug_assign = True
    node = self._process_statement(node)
    self._in_aug_assign = False
    return node

  def visit_Name(self, node):
    node = self.generic_visit(node)
    self._track_symbol(node)
    return node

  def visit_Attribute(self, node):
    node = self.generic_visit(node)
    # Inside a constructor, `self.x = ...` creates `self.x` and modifies `self`.
    if self._in_constructor and self._node_sets_self_attribute(node):
      self._track_symbol(
          node, composite_writes_alter_parent=True, writes_create_symbol=True)
    else:
      self._track_symbol(node)
    return node

  def visit_Subscript(self, node):
    node = self.generic_visit(node)
    # Subscript writes (e.g. a[b] = "value") are tracked as modifying the
    # element itself (a[b]).
    # NOTE(review): an earlier comment claimed the parent (a) is modified as
    # well, but composite_writes_alter_parent is not passed here, so it is
    # not. Confirm which behavior is intended.
    self._track_symbol(node)
    return node

  def visit_Print(self, node):
    # Python 2 print statement: its values get their own scope.
    self._enter_scope(False)
    node.values = self.visit_block(node.values)
    anno.setanno(node, anno.Static.SCOPE, self.scope)
    anno.setanno(node, NodeAnno.ARGS_SCOPE, self.scope)
    self._exit_scope()
    return node

  def visit_Assert(self, node):
    return self._process_statement(node)

  def visit_Call(self, node):
    # Arguments are analyzed in their own scope (recorded as ARGS_SCOPE). The
    # callee expression is visited afterwards, in the enclosing scope.
    self._enter_scope(False)
    node.args = self.visit_block(node.args)
    node.keywords = self.visit_block(node.keywords)
    # TODO(mdan): Account starargs, kwargs
    anno.setanno(node, NodeAnno.ARGS_SCOPE, self.scope)
    self._exit_scope()
    node.func = self.visit(node.func)
    return node

  def _process_block_node(self, node, block, scope_name):
    """Visits `block` in a fresh non-isolated scope; annotates `node` with it.

    Args:
      node: The node owning `block`.
      block: A list of statements (e.g. `node.body`). It is updated in place.
      scope_name: The annotation key under which the scope is stored.

    Returns:
      `node`.
    """
    self._enter_scope(False)
    visited = self.visit_block(block)
    # Write the visited statements back into the node's own list so that any
    # node replacements made while visiting are not discarded.
    if visited is not None:
      block[:] = visited
    anno.setanno(node, scope_name, self.scope)
    self._exit_scope()
    return node

  def _process_parallel_blocks(self, parent, children):
    """Analyzes sibling blocks (e.g. body and orelse) from the same state.

    Args:
      parent: The node owning the blocks; it receives one scope annotation
        per child.
      children: Iterable of (block, scope_name) pairs.

    Returns:
      `parent`.
    """
    # Because the scopes are not isolated, processing any child block
    # modifies the parent state causing the other child blocks to be
    # processed incorrectly. So we need to checkpoint the parent scope so that
    # each child sees the same context.
    before_parent = Scope.copy_of(self.scope)
    after_children = []
    for child, scope_name in children:
      self.scope.copy_from(before_parent)
      parent = self._process_block_node(parent, child, scope_name)
      after_child = Scope.copy_of(self.scope)
      after_children.append(after_child)
    # The resulting state is the union of what every child block produced.
    for after_child in after_children:
      self.scope.merge_from(after_child)
    return parent

  def visit_arguments(self, node):
    return self._process_statement(node)

  def visit_FunctionDef(self, node):
    # The FunctionDef node itself has a Scope object that tracks the creation
    # of its name, along with the usage of any decorators accompanying it.
    self._enter_scope(False)
    node.decorator_list = self.visit_block(node.decorator_list)
    self.scope.mark_write(qual_names.QN(node.name))
    anno.setanno(node, anno.Static.SCOPE, self.scope)
    self._exit_scope()

    # A separate Scope tracks the actual function definition.
    self._enter_scope(True)
    node.args = self.visit(node.args)

    # Track the body separately. This is for compatibility reasons, it may not
    # be strictly needed.
    self._enter_scope(False)
    node.body = self.visit_block(node.body)
    anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope)
    self._exit_scope()

    self._exit_scope()
    return node

  def visit_With(self, node):
    self._enter_scope(False)
    node = self.generic_visit(node)
    anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope)
    self._exit_scope()
    return node

  def visit_withitem(self, node):
    return self._process_statement(node)

  def visit_If(self, node):
    self._enter_scope(False)
    node.test = self.visit(node.test)
    anno.setanno(node, NodeAnno.COND_SCOPE, self.scope)
    anno.setanno(node.test, anno.Static.SCOPE, self.scope)
    self._exit_scope()
    node = self._process_parallel_blocks(node,
                                         ((node.body, NodeAnno.BODY_SCOPE),
                                          (node.orelse, NodeAnno.ORELSE_SCOPE)))
    return node

  def visit_For(self, node):
    self._enter_scope(False)
    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    anno.setanno(node.iter, anno.Static.SCOPE, self.scope)
    self._exit_scope()
    node = self._process_parallel_blocks(node,
                                         ((node.body, NodeAnno.BODY_SCOPE),
                                          (node.orelse, NodeAnno.ORELSE_SCOPE)))
    return node

  def visit_While(self, node):
    self._enter_scope(False)
    node.test = self.visit(node.test)
    anno.setanno(node, NodeAnno.COND_SCOPE, self.scope)
    anno.setanno(node.test, anno.Static.SCOPE, self.scope)
    self._exit_scope()
    node = self._process_parallel_blocks(node,
                                         ((node.body, NodeAnno.BODY_SCOPE),
                                          (node.orelse, NodeAnno.ORELSE_SCOPE)))
    return node
def resolve(node, context, parent_scope=None, add_unknown_symbols=False):
  """Runs activity analysis over `node`, annotating it with scope information.

  Args:
    node: The AST node to analyze. qual_names.resolve() must already have been
      called on it.
    context: Passed through to ActivityAnalyzer.
    parent_scope: Optional Scope that encloses `node`.
    add_unknown_symbols: Whether to tolerate attribute/subscript writes whose
      base name has not been seen yet, instead of raising ValueError.

  Returns:
    The visited node.
  """
  analyzer = ActivityAnalyzer(
      context, parent_scope, add_unknown_symbols=add_unknown_symbols)
  return analyzer.visit(node)