Remove the requirement to specify steps_per_epoch when using MultiWorkerMirroredStrategy (MWMS) with Keras model.fit.
PiperOrigin-RevId: 323885033
Change-Id: I6db8e98b401b09083a89edf2d779cbc14eb8a10d
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index bdd4cbc..50a7bee 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -176,11 +176,13 @@
"//tensorflow/python:training_lib",
"//tensorflow/python:training_util",
"//tensorflow/python:variables",
+ "//tensorflow/python/compat:v2_compat",
"//tensorflow/python/distribute:collective_all_reduce_strategy",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:cross_device_utils",
"//tensorflow/python/distribute:multi_worker_test_base",
"//tensorflow/python/distribute:multi_worker_util",
+ "//tensorflow/python/distribute:strategy_combinations",
"//tensorflow/python/distribute:strategy_test_lib",
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
"//tensorflow/python/eager:context",
diff --git a/tensorflow/python/keras/distribute/collective_all_reduce_strategy_test.py b/tensorflow/python/keras/distribute/collective_all_reduce_strategy_test.py
index f2869e4..60b7d46 100644
--- a/tensorflow/python/keras/distribute/collective_all_reduce_strategy_test.py
+++ b/tensorflow/python/keras/distribute/collective_all_reduce_strategy_test.py
@@ -22,22 +22,29 @@
import numpy as np
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.compat import v2_compat
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import cross_device_utils
from tensorflow.python.distribute import multi_worker_test_base
from tensorflow.python.distribute import multi_worker_util
+from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras import layers
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import sequential
+from tensorflow.python.keras.engine import training
from tensorflow.python.keras.mixed_precision.experimental import policy
from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util
+from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
@@ -316,5 +323,51 @@
self._test_mixed_precision(None, None, required_gpus)
+@combinations.generate(
+ combinations.combine(
+ strategy=[
+ strategy_combinations.multi_worker_mirrored_2x1_cpu,
+ strategy_combinations.multi_worker_mirrored_2x1_gpu,
+ ],
+ mode=['eager']))
+class DistributedCollectiveAllReduceStrategyEagerTest(test.TestCase,
+ parameterized.TestCase):
+
+ def testFitWithoutStepsPerEpochPartialBatch(self, strategy):
+
+ def _model_fn():
+ x = layers.Input(shape=(1,), name='input')
+ y = layers.Dense(1, name='dense')(x)
+ model = training.Model(x, y)
+ return model
+
+ def _get_dataset():
+ inputs = array_ops.expand_dims_v2(constant_op.constant(range(10)), axis=1)
+ targets = array_ops.expand_dims_v2(
+ constant_op.constant(range(10)), axis=1)
+ # Make global batch size 12 for 2 replicas and a non-repeated dataset with
+ # 10 elements so that we have a partial batch.
+ dataset = dataset_ops.Dataset.from_tensor_slices(
+ (inputs, targets)).batch(12, drop_remainder=False)
+ return dataset
+
+ with strategy.scope():
+ optimizer_fn = gradient_descent_keras.SGD
+ optimizer = optimizer_fn(0.001)
+ model = _model_fn()
+ loss = 'mse'
+ metrics = ['mae']
+ model.compile(
+ optimizer,
+ loss,
+ metrics=metrics)
+ dataset = _get_dataset()
+ kernel_before = model.get_weights()[0][0]
+ model.fit(dataset, epochs=10)
+ kernel_after = model.get_weights()[0][0]
+ self.assertNotEqual(kernel_before, kernel_after)
+ self.assertGreater(abs(kernel_before-1), abs(kernel_after-1))
+
if __name__ == '__main__':
- test.main()
+ v2_compat.enable_v2_behavior()
+ combinations.main()
diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py
index 7403605..0e4886f 100644
--- a/tensorflow/python/keras/engine/data_adapter.py
+++ b/tensorflow/python/keras/engine/data_adapter.py
@@ -1238,18 +1238,6 @@
if adapter_steps is not None:
return adapter_steps
- if (ds_context.get_strategy().extended._in_multi_worker_mode() and # pylint: disable=protected-access
- (dataset.options().experimental_distribute.auto_shard_policy !=
- distribute_options.AutoShardPolicy.OFF)):
- # If the dataset would be auto-sharded, we should not infer a local
- # steps_per_epoch due to the possible inbalanced sharding between workers.
- raise ValueError("When dataset is sharded across workers, please "
- "specify a reasonable `steps_per_epoch` such that all "
- "workers will train the same number of steps and each "
- "step can get data from dataset without EOF. This is "
- "required for allreduce to succeed. We will handle the "
- "last partial batch in the future.")
-
size = cardinality.cardinality(dataset)
if size == cardinality.INFINITE and steps is None:
raise ValueError("When passing an infinitely repeating dataset, you "