Remove the constraint that `steps_per_epoch` must be specified for MWMS + Keras `model.fit`.

PiperOrigin-RevId: 323885033
Change-Id: I6db8e98b401b09083a89edf2d779cbc14eb8a10d
diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD
index bdd4cbc..50a7bee 100644
--- a/tensorflow/python/keras/distribute/BUILD
+++ b/tensorflow/python/keras/distribute/BUILD
@@ -176,11 +176,13 @@
         "//tensorflow/python:training_lib",
         "//tensorflow/python:training_util",
         "//tensorflow/python:variables",
+        "//tensorflow/python/compat:v2_compat",
         "//tensorflow/python/distribute:collective_all_reduce_strategy",
         "//tensorflow/python/distribute:combinations",
         "//tensorflow/python/distribute:cross_device_utils",
         "//tensorflow/python/distribute:multi_worker_test_base",
         "//tensorflow/python/distribute:multi_worker_util",
+        "//tensorflow/python/distribute:strategy_combinations",
         "//tensorflow/python/distribute:strategy_test_lib",
         "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
         "//tensorflow/python/eager:context",
diff --git a/tensorflow/python/keras/distribute/collective_all_reduce_strategy_test.py b/tensorflow/python/keras/distribute/collective_all_reduce_strategy_test.py
index f2869e4..60b7d46 100644
--- a/tensorflow/python/keras/distribute/collective_all_reduce_strategy_test.py
+++ b/tensorflow/python/keras/distribute/collective_all_reduce_strategy_test.py
@@ -22,22 +22,29 @@
 import numpy as np
 
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.compat import v2_compat
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.distribute import collective_all_reduce_strategy
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import cross_device_utils
 from tensorflow.python.distribute import multi_worker_test_base
 from tensorflow.python.distribute import multi_worker_util
+from tensorflow.python.distribute import strategy_combinations
 from tensorflow.python.distribute import strategy_test_lib
 from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
 from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.keras import layers
 from tensorflow.python.keras import testing_utils
 from tensorflow.python.keras.engine import sequential
+from tensorflow.python.keras.engine import training
 from tensorflow.python.keras.mixed_precision.experimental import policy
 from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util
+from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_keras
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variables
@@ -316,5 +323,51 @@
       self._test_mixed_precision(None, None, required_gpus)
 
 
+@combinations.generate(
+    combinations.combine(
+        strategy=[
+            strategy_combinations.multi_worker_mirrored_2x1_cpu,
+            strategy_combinations.multi_worker_mirrored_2x1_gpu,
+        ],
+        mode=['eager']))
+class DistributedCollectiveAllReduceStrategyEagerTest(test.TestCase,
+                                                      parameterized.TestCase):
+
+  def testFitWithoutStepsPerEpochPartialBatch(self, strategy):
+
+    def _model_fn():
+      x = layers.Input(shape=(1,), name='input')
+      y = layers.Dense(1, name='dense')(x)
+      model = training.Model(x, y)
+      return model
+
+    def _get_dataset():
+      inputs = array_ops.expand_dims_v2(constant_op.constant(range(10)), axis=1)
+      targets = array_ops.expand_dims_v2(
+          constant_op.constant(range(10)), axis=1)
+      # Make the global batch size 12 for 2 replicas and use a non-repeated
+      # dataset with 10 elements so that we get a partial batch.
+      dataset = dataset_ops.Dataset.from_tensor_slices(
+          (inputs, targets)).batch(12, drop_remainder=False)
+      return dataset
+
+    with strategy.scope():
+      optimizer_fn = gradient_descent_keras.SGD
+      optimizer = optimizer_fn(0.001)
+      model = _model_fn()
+      loss = 'mse'
+      metrics = ['mae']
+      model.compile(
+          optimizer,
+          loss,
+          metrics=metrics)
+    dataset = _get_dataset()
+    kernel_before = model.get_weights()[0][0]
+    model.fit(dataset, epochs=10)
+    kernel_after = model.get_weights()[0][0]
+    self.assertNotEqual(kernel_before, kernel_after)
+    self.assertGreater(abs(kernel_before-1), abs(kernel_after-1))
+
 if __name__ == '__main__':
-  test.main()
+  v2_compat.enable_v2_behavior()
+  combinations.main()
diff --git a/tensorflow/python/keras/engine/data_adapter.py b/tensorflow/python/keras/engine/data_adapter.py
index 7403605..0e4886f 100644
--- a/tensorflow/python/keras/engine/data_adapter.py
+++ b/tensorflow/python/keras/engine/data_adapter.py
@@ -1238,18 +1238,6 @@
     if adapter_steps is not None:
       return adapter_steps
 
-    if (ds_context.get_strategy().extended._in_multi_worker_mode() and  # pylint: disable=protected-access
-        (dataset.options().experimental_distribute.auto_shard_policy !=
-         distribute_options.AutoShardPolicy.OFF)):
-      # If the dataset would be auto-sharded, we should not infer a local
-      # steps_per_epoch due to the possible inbalanced sharding between workers.
-      raise ValueError("When dataset is sharded across workers, please "
-                       "specify a reasonable `steps_per_epoch` such that all "
-                       "workers will train the same number of steps and each "
-                       "step can get data from dataset without EOF. This is "
-                       "required for allreduce to succeed. We will handle the "
-                       "last partial batch in the future.")
-
     size = cardinality.cardinality(dataset)
     if size == cardinality.INFINITE and steps is None:
       raise ValueError("When passing an infinitely repeating dataset, you "