Preserve the directives annotation while lowering break statements.

PiperOrigin-RevId: 295780462
Change-Id: I48fa59628c110aafe250ba20b7b6cdf2cae73e26
diff --git a/tensorflow/python/autograph/converters/break_statements.py b/tensorflow/python/autograph/converters/break_statements.py
index c540907..718c5bd 100644
--- a/tensorflow/python/autograph/converters/break_statements.py
+++ b/tensorflow/python/autograph/converters/break_statements.py
@@ -71,6 +71,7 @@
     return nodes, break_used
 
   def visit_While(self, node):
+    original_node = node
     scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
     break_var = self.ctx.namer.new_symbol('break_', scope.referenced)
 
@@ -98,9 +99,13 @@
           body=node.body,
           orelse=guarded_orelse)
 
+      new_while_node = node[1]
+      anno.copyanno(original_node, new_while_node, anno.Basic.DIRECTIVES)
+
     return node
 
   def visit_For(self, node):
+    original_node = node
     scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
     break_var = self.ctx.namer.new_symbol('break_', scope.referenced)
 
@@ -137,7 +142,9 @@
           body=node.body,
           orelse=guarded_orelse)
 
-      anno.setanno(node[1], 'extra_test', extra_test)
+      new_for_node = node[1]
+      anno.setanno(new_for_node, 'extra_test', extra_test)
+      anno.copyanno(original_node, new_for_node, anno.Basic.DIRECTIVES)
 
     return node
 
diff --git a/tensorflow/python/autograph/converters/break_statements_test.py b/tensorflow/python/autograph/converters/break_statements_test.py
index c789ced..37accdc 100644
--- a/tensorflow/python/autograph/converters/break_statements_test.py
+++ b/tensorflow/python/autograph/converters/break_statements_test.py
@@ -20,6 +20,7 @@
 
 from tensorflow.python.autograph.converters import break_statements
 from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.pyct import anno
 from tensorflow.python.framework import constant_op
 from tensorflow.python.platform import test
 
@@ -46,6 +47,21 @@
     self.assertTransformedEquivalent(test_fn, 1)
     self.assertTransformedEquivalent(test_fn, 4)
 
+  def test_while_loop_preserves_directives(self):
+
+    def test_fn(x):
+      while x > 0:
+        x -= 1
+        if x % 2 == 0:
+          break
+
+    node, ctx = self.prepare(test_fn, {})
+    fake_annotation = object()
+    anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
+    node = break_statements.transform(node, ctx)
+    self.assertIs(
+        anno.getanno(node.body[1], anno.Basic.DIRECTIVES), fake_annotation)
+
   def test_for_loop(self):
 
     def test_fn(a):
@@ -63,6 +79,20 @@
       # but the section following the break will be skipped.
       self.assertEqual([3], result.test_fn([5, 4]))
 
+  def test_for_loop_preserves_directives(self):
+
+    def test_fn(a):
+      for x in a:
+        if x % 2 == 0:
+          break
+
+    node, ctx = self.prepare(test_fn, {})
+    fake_annotation = object()
+    anno.setanno(node.body[0], anno.Basic.DIRECTIVES, fake_annotation)
+    node = break_statements.transform(node, ctx)
+    self.assertIs(
+        anno.getanno(node.body[1], anno.Basic.DIRECTIVES), fake_annotation)
+
   def test_nested(self):
 
     def test_fn(x):