Add a nested tf.function with control flow test.
PiperOrigin-RevId: 310589571
Change-Id: Icb71cd7f50d77fe4b67ba21bedf415cdc8ff24bd
diff --git a/tensorflow/python/distribute/custom_training_loop_models_test.py b/tensorflow/python/distribute/custom_training_loop_models_test.py
index 3c748bd..48f2af0 100644
--- a/tensorflow/python/distribute/custom_training_loop_models_test.py
+++ b/tensorflow/python/distribute/custom_training_loop_models_test.py
@@ -380,6 +380,46 @@
@combinations.generate(
combinations.combine(
+ distribution=strategy_combinations.all_strategies, mode=["eager"]))
+ def test_nested_tf_functions_with_control_flow(self, distribution):
+ inputs = np.random.random((10, 3)).astype(np.float32)
+ targets = np.ones((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets)).repeat()
+ dataset = dataset.batch(10, drop_remainder=True)
+ input_iterator = iter(distribution.experimental_distribute_dataset(dataset))
+
+ def get_model():
+ x = keras.layers.Input(shape=(3,), name="input")
+ y = keras.layers.Dense(4, name="dense")(x)
+ model = keras.Model(x, y)
+ return model
+
+ with distribution.scope():
+ model = get_model()
+ optimizer = keras.optimizer_v2.gradient_descent.SGD(0.1, momentum=0.01)
+
+ @def_function.function
+ def train_step(iterator):
+
+ def step_fn(inputs):
+ images, targets = inputs
+ with backprop.GradientTape() as tape:
+ outputs = model(images)
+ loss = math_ops.reduce_sum(outputs - targets)
+ grads = tape.gradient(loss, model.variables)
+ optimizer.apply_gradients(zip(grads, model.variables))
+
+ distribution.run(step_fn, args=(next(iterator),))
+
+ @def_function.function
+ def train_steps(iterator):
+ for _ in math_ops.range(10):
+ train_step(iterator)
+
+ train_steps(input_iterator)
+
+ @combinations.generate(
+ combinations.combine(
distribution=strategy_combinations.all_strategies,
mode=["eager"]
))