Add a nested tf.function with control flow test.

PiperOrigin-RevId: 310589571
Change-Id: Icb71cd7f50d77fe4b67ba21bedf415cdc8ff24bd
diff --git a/tensorflow/python/distribute/custom_training_loop_models_test.py b/tensorflow/python/distribute/custom_training_loop_models_test.py
index 3c748bd..48f2af0 100644
--- a/tensorflow/python/distribute/custom_training_loop_models_test.py
+++ b/tensorflow/python/distribute/custom_training_loop_models_test.py
@@ -380,6 +380,46 @@
 
   @combinations.generate(
       combinations.combine(
+          distribution=strategy_combinations.all_strategies, mode=["eager"]))
+  def test_nested_tf_functions_with_control_flow(self, distribution):
+    inputs = np.random.random((10, 3)).astype(np.float32)
+    targets = np.ones((10, 4), dtype=np.float32)
+    dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).repeat()
+    dataset = dataset.batch(10, drop_remainder=True)
+    input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
+
+    def get_model():
+      x = keras.layers.Input(shape=(3,), name="input")
+      y = keras.layers.Dense(4, name="dense")(x)
+      model = keras.Model(x, y)
+      return model
+
+    with distribution.scope():
+      model = get_model()
+      optimizer = keras.optimizer_v2.gradient_descent.SGD(0.1, momentum=0.01)
+
+    @def_function.function
+    def train_step(iterator):
+
+      def step_fn(inputs):
+        images, targets = inputs
+        with backprop.GradientTape() as tape:
+          outputs = model(images)
+          loss = math_ops.reduce_sum(outputs - targets)
+        grads = tape.gradient(loss, model.variables)
+        optimizer.apply_gradients(zip(grads, model.variables))
+
+      distribution.run(step_fn, args=(next(iterator),))
+
+    @def_function.function
+    def train_steps(iterator):
+      for _ in math_ops.range(10):
+        train_step(iterator)
+
+    train_steps(input_iterator)
+
+  @combinations.generate(
+      combinations.combine(
           distribution=strategy_combinations.all_strategies,
           mode=["eager"]
       ))