blob: 718c5bd3ca5b7a4015fad97226abb867292447b6 [file] [log] [blame]
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Lowers break statements to conditionals."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
class _Break(object):
def __init__(self):
self.used = False
self.control_var_name = None
def __repr__(self):
return 'used: %s, var: %s' % (self.used, self.control_var_name)
class BreakTransformer(converter.Base):
  """Canonicalizes break statements into additional conditionals.

  Each `break` is replaced by an assignment of True to a fresh control
  variable followed by `continue`. The enclosing loop is then rewritten so
  that it stops iterating once that variable is set, and its `else` clause
  (if any) is guarded so it only runs when no break occurred.
  """

  def visit_Break(self, node):
    """Replaces `break` with `<control_var> = True; continue`.

    Also marks the innermost loop state as having used a break, which tells
    the enclosing loop visitor to rewrite the loop.
    """
    self.state[_Break].used = True
    var_name = self.state[_Break].control_var_name
    # TODO(mdan): This will fail when expanded inside a top-level else block.
    template = """
      var_name = True
      continue
    """
    return templates.replace(template, var_name=var_name)

  def _guard_if_present(self, block, var_name):
    """Prevents the block from executing if var_name is set.

    Args:
      block: list of statement nodes; may be empty.
      var_name: the break control variable symbol.

    Returns:
      `block` itself when it is empty; otherwise the nodes produced by
      wrapping `block` in `if ag__.not_(var_name):`.
    """
    if not block:
      return block

    template = """
      if ag__.not_(var_name):
        block
    """
    node = templates.replace(
        template,
        var_name=var_name,
        block=block)
    return node

  def _process_body(self, nodes, break_var):
    """Visits a loop body under a fresh break-tracking state.

    Args:
      nodes: the loop body statements.
      break_var: symbol that any `break` in this body should set to True.

    Returns:
      A tuple (visited_nodes, break_used). break_used tells whether
      visit_Break fired while this loop's state was the innermost one.
    """
    # Nested loops push their own state, so breaks inside them do not leak
    # into this loop's `used` flag.
    self.state[_Break].enter()
    self.state[_Break].control_var_name = break_var
    nodes = self.visit_block(nodes)
    break_used = self.state[_Break].used
    self.state[_Break].exit()
    return nodes, break_used

  def visit_While(self, node):
    """Lowers `break` statements inside a `while` loop.

    If the body contains a break, the loop is rewritten as:

        break_ = False
        while ag__.and_(lambda: ag__.not_(break_), lambda: test):
          body
        else:
          if ag__.not_(break_):  # only when the loop has an else clause
            orelse

    The control variable is checked before `test`. `ag__.and_` is lazy, so
    after a break the original loop condition is not evaluated again. This
    matches plain Python and matters when `test` has side effects.
    """
    original_node = node
    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', scope.referenced)

    node.test = self.visit(node.test)
    node.body, break_used = self._process_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if break_used:
      # Python's else clause only triggers if the loop exited cleanly (e.g.
      # break did not trigger).
      guarded_orelse = self._guard_if_present(node.orelse, break_var)

      # The break flag comes first so that the user's loop condition is
      # short-circuited (not re-evaluated) once a break has fired.
      template = """
        var_name = False
        while ag__.and_(lambda: ag__.not_(var_name), lambda: test):
          body
        else:
          orelse
      """
      node = templates.replace(
          template,
          var_name=break_var,
          test=node.test,
          body=node.body,
          orelse=guarded_orelse)

      # node[0] is the control-variable initialization; node[1] is the loop.
      new_while_node = node[1]
      anno.copyanno(original_node, new_while_node, anno.Basic.DIRECTIVES)

    return node

  def visit_For(self, node):
    """Lowers `break` statements inside a `for` loop.

    If the body contains a break, the loop is rewritten as:

        break_ = False
        for target in iter_:
          (break_,)
          body
        else:
          if ag__.not_(break_):  # only when the loop has an else clause
            orelse

    A `for` header has no condition to fold the break check into. Instead,
    `ag__.not_(break_)` is attached to the new `for` node as the
    'extra_test' annotation, presumably for a later control-flow pass to
    consume (TODO confirm against that pass).
    """
    original_node = node
    scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
    break_var = self.ctx.namer.new_symbol('break_', scope.referenced)

    node.target = self.visit(node.target)
    node.iter = self.visit(node.iter)
    node.body, break_used = self._process_body(node.body, break_var)
    # A break in the else clause applies to the containing scope.
    node.orelse = self.visit_block(node.orelse)

    if break_used:
      # Python's else clause only triggers if the loop exited cleanly (e.g.
      # break did not trigger).
      guarded_orelse = self._guard_if_present(node.orelse, break_var)
      extra_test = templates.replace_as_expression(
          'ag__.not_(var_name)', var_name=break_var)

      # The extra test is hidden in the AST, which will confuse the static
      # analysis. To mitigate that, we insert a no-op statement that ensures
      # the control variable is marked as used.
      # TODO(mdan): Use a marker instead, e.g. ag__.condition_loop_on(var_name)
      template = """
        var_name = False
        for target in iter_:
          (var_name,)
          body
        else:
          orelse
      """
      node = templates.replace(
          template,
          var_name=break_var,
          iter_=node.iter,
          target=node.target,
          body=node.body,
          orelse=guarded_orelse)

      # node[0] is the control-variable initialization; node[1] is the loop.
      new_for_node = node[1]
      anno.setanno(new_for_node, 'extra_test', extra_test)
      anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES)

    return node
def transform(node, ctx):
  """Runs BreakTransformer over `node` and returns the converted AST."""
  return BreakTransformer(ctx).visit(node)