Preserve the directives annotation while lowering break statements.
PiperOrigin-RevId: 295780462
Change-Id: I48fa59628c110aafe250ba20b7b6cdf2cae73e26
diff --git a/tensorflow/python/autograph/converters/break_statements.py b/tensorflow/python/autograph/converters/break_statements.py
index c540907..718c5bd 100644
--- a/tensorflow/python/autograph/converters/break_statements.py
+++ b/tensorflow/python/autograph/converters/break_statements.py
@@ -71,6 +71,7 @@
return nodes, break_used
def visit_While(self, node):
+ original_node = node
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
break_var = self.ctx.namer.new_symbol('break_', scope.referenced)
@@ -98,9 +99,13 @@
body=node.body,
orelse=guarded_orelse)
+ new_while_node = node[1]
+ anno.copyanno(original_node, new_while_node, anno.Basic.DIRECTIVES)
+
return node
def visit_For(self, node):
+ original_node = node
scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
break_var = self.ctx.namer.new_symbol('break_', scope.referenced)
@@ -137,7 +142,9 @@
body=node.body,
orelse=guarded_orelse)
- anno.setanno(node[1], 'extra_test', extra_test)
+ new_for_node = node[1]
+ anno.setanno(new_for_node, 'extra_test', extra_test)
+ anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES)
return node
diff --git a/tensorflow/python/autograph/converters/break_statements_test.py b/tensorflow/python/autograph/converters/break_statements_test.py
index c789ced..37accdc 100644
--- a/tensorflow/python/autograph/converters/break_statements_test.py
+++ b/tensorflow/python/autograph/converters/break_statements_test.py
@@ -20,6 +20,7 @@
from tensorflow.python.autograph.converters import break_statements
from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.pyct import anno
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
@@ -46,6 +47,21 @@
self.assertTransformedEquivalent(test_fn, 1)
self.assertTransformedEquivalent(test_fn, 4)
+ def test_while_loop_preserves_directives(self):
+
+ def test_fn(x):
+ while x > 0:
+ x -= 1
+ if x % 2 == 0:
+ break
+
+ node, ctx = self.prepare(test_fn, {})
+ fake_annotation = object()
+ anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
+ node = break_statements.transform(node, ctx)
+ self.assertIs(
+ anno.getanno(node.body[1], anno.Basic.DIRECTIVES), fake_annotation)
+
def test_for_loop(self):
def test_fn(a):
@@ -63,6 +79,20 @@
# but the section following the break will be skipped.
self.assertEqual([3], result.test_fn([5, 4]))
+ def test_for_loop_preserves_directives(self):
+
+ def test_fn(a):
+ for x in a:
+ if x % 2 == 0:
+ break
+
+ node, ctx = self.prepare(test_fn, {})
+ fake_annotation = object()
+ anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
+ node = break_statements.transform(node, ctx)
+ self.assertIs(
+ anno.getanno(node.body[1], anno.Basic.DIRECTIVES), fake_annotation)
+
def test_nested(self):
def test_fn(x):