[tf.data] Reimplement the `RebatchDataset` kernel as a native C++ dataset that
slices each input batch directly (instead of rebatching via a Grappler graph
rewrite). Also add a `prefetch` after `_RebatchDataset` in `input_lib.py` for
performance.
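
For illustration, a minimal sketch of the intended slicing behavior. It mirrors
the updated expectations in rebatch_dataset_test.py and uses the internal
`_RebatchDataset` Python wrapper that the tests exercise; it is an example
only, not part of the change:

  from tensorflow.python.data.experimental.ops import distribute
  from tensorflow.python.data.ops import dataset_ops

  # A batch of 32 split across 4 replicas yields sub-batches of 8.
  dataset = dataset_ops.Dataset.range(64).batch(32, drop_remainder=True)
  rebatched = distribute._RebatchDataset(dataset, num_replicas=4)
  # Produces [0..7], [8..15], [16..23], [24..31], [32..39], ..., [56..63].

  # When the batch size is not evenly divisible by num_replicas, each slice
  # has size ceil(batch_size / num_replicas), so trailing replicas may
  # receive fewer (possibly zero) elements.
  dataset = dataset_ops.Dataset.range(16).batch(16, drop_remainder=True)
  rebatched = distribute._RebatchDataset(dataset, num_replicas=5)
  # Produces [0..3], [4..7], [8..11], [12..15], [] (empty for the last replica).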

PiperOrigin-RevId: 268544896
diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
index 8aef8ee..b1ec124 100644
--- a/tensorflow/core/kernels/data/experimental/BUILD
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -322,7 +322,7 @@
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/grappler/optimizers/data:rebatch",
-        "//tensorflow/core/kernels/data:rewrite_utils",
+        "//tensorflow/core/kernels/data:name_utils",
     ],
 )
 
diff --git a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc
index 6158821..030166a 100644
--- a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc
@@ -13,24 +13,26 @@
 limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/kernels/data/rewrite_utils.h"
-#include "tensorflow/core/protobuf/rewriter_config.pb.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/kernels/data/name_utils.h"
 
 namespace tensorflow {
 namespace data {
 namespace experimental {
 namespace {
 
-constexpr char kOptimizerName[] = "tf_data_rebatcher";
-constexpr char kUseFallbackAttr[] = "use_fallback";
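+// Returns ceil(dividend / divisor) for positive arguments,
+// e.g. CeilDiv(16, 5) == 4 and CeilDiv(16, 4) == 4.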
+inline int64 CeilDiv(int64 dividend, int64 divisor) {
+  return (dividend - 1 + divisor) / divisor;
+}
+
+constexpr const char* const kDatasetType = "Rebatch";
 
 class RebatchDatasetOp : public UnaryDatasetOpKernel {
  public:
   explicit RebatchDatasetOp(OpKernelConstruction* ctx)
       : UnaryDatasetOpKernel(ctx) {
-    if (ctx->HasAttr(kUseFallbackAttr)) {
-      OP_REQUIRES_OK(ctx, ctx->GetAttr(kUseFallbackAttr, &use_fallback_));
-    }
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
   }
 
  protected:
@@ -42,40 +44,213 @@
     OP_REQUIRES(
         ctx, num_replicas > 0,
         errors::InvalidArgument("num_replicas must be greater than zero."));
-
-    auto config_factory = [num_replicas, this]() {
-      return CreateConfig(num_replicas, this->use_fallback_);
-    };
-
-    // We only want to optimize functions for some particular datasets like
-    // FlatMapDataset, InterleaveDataset etc. So we disable generalized
-    // function optimization and explicitly handle function modifications
-    // for those datasets in the rewrite.
-    OP_REQUIRES_OK(ctx,
-                   RewriteDataset(ctx, input, std::move(config_factory),
-                                  /*optimize_function_library=*/false, output));
+    *output =
+        new Dataset(ctx, input, num_replicas, output_types_, output_shapes_);
   }
 
  private:
-  static RewriterConfig CreateConfig(int64 num_replicas, bool use_fallback) {
-    RewriterConfig rewriter_config;
-    rewriter_config.set_fail_on_optimizer_errors(true);
-    rewriter_config.add_optimizers(kOptimizerName);
-    rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
-    auto custom_optimizer = rewriter_config.add_custom_optimizers();
-    custom_optimizer->set_name(kOptimizerName);
-    AttrValue num_replicas_attr;
-    num_replicas_attr.set_i(num_replicas);
-    (*custom_optimizer->mutable_parameter_map())["num_replicas"] =
-        num_replicas_attr;
-    AttrValue use_fallback_attr;
-    use_fallback_attr.set_b(use_fallback);
-    (*custom_optimizer->mutable_parameter_map())["use_fallback"] =
-        use_fallback_attr;
-    return rewriter_config;
-  }
+  class Dataset : public DatasetBase {
+   public:
+    Dataset(OpKernelContext* ctx, const DatasetBase* input,
+            const int64 num_replicas, const DataTypeVector& output_types,
+            const std::vector<PartialTensorShape>& output_shapes)
+        : DatasetBase(DatasetContext(ctx)),
+          input_(input),
+          num_replicas_(num_replicas),
+          output_types_(output_types),
+          output_shapes_(output_shapes) {
+      input_->Ref();
+    }
 
-  bool use_fallback_ = true;
+    ~Dataset() override { input_->Unref(); }
+
+    std::unique_ptr<IteratorBase> MakeIteratorInternal(
+        const string& prefix) const override {
+      name_utils::IteratorPrefixParams params;
+      return absl::make_unique<Iterator>(Iterator::Params{
+          this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
+    }
+
+    const DataTypeVector& output_dtypes() const override {
+      return output_types_;
+    }
+
+    const std::vector<PartialTensorShape>& output_shapes() const override {
+      return output_shapes_;
+    }
+
+    string DebugString() const override {
+      name_utils::DatasetDebugStringParams params;
+      params.set_args(num_replicas_);
+      return name_utils::DatasetDebugString(kDatasetType, params);
+    }
+
+    Status CheckExternalState() const override {
+      return input_->CheckExternalState();
+    }
+
+   protected:
+    Status AsGraphDefInternal(SerializationContext* ctx,
+                              DatasetGraphDefBuilder* b,
+                              Node** output) const override {
+      Node* input_graph_node = nullptr;
+      TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+      Node* num_replicas = nullptr;
+      TF_RETURN_IF_ERROR(b->AddScalar(num_replicas_, &num_replicas));
+      TF_RETURN_IF_ERROR(
+          b->AddDataset(this, {input_graph_node, num_replicas}, output));
+      return Status::OK();
+    }
+
+   private:
+    class Iterator : public DatasetIterator<Dataset> {
+     public:
+      explicit Iterator(const Params& params)
+          : DatasetIterator<Dataset>(params) {}
+
+      ~Iterator() override {}
+
+      Status Initialize(IteratorContext* ctx) override {
+        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+      }
+
+      Status GetNextInternal(IteratorContext* ctx,
+                             std::vector<Tensor>* out_tensors,
+                             bool* end_of_sequence) override {
+        mutex_lock l(mu_);
+        if (slice_number_ % dataset()->num_replicas_ == 0) {
+          input_descriptors_.clear();
+          std::vector<Tensor> input_tensors;
+          TF_RETURN_IF_ERROR(
+              input_impl_->GetNext(ctx, &input_tensors, end_of_sequence));
+          if (*end_of_sequence) {
+            return Status::OK();
+          }
+
+          input_descriptors_.reserve(input_tensors.size());
+          for (int i = 0; i < input_tensors.size(); ++i) {
+            if (input_tensors[i].dims() == 0) {
+              return errors::InvalidArgument(
+                  "Cannot rebatch dataset: All components must have at least "
+                  "one dimension. Perhaps your input dataset is not batched? "
+                  "Component ",
+                  i, " is scalar.");
+            }
+
+            int64 original_batch_dim = input_tensors[i].dim_size(0);
+            int64 interval =
+                CeilDiv(original_batch_dim, dataset()->num_replicas_);
+            input_descriptors_.push_back(
+                {std::move(input_tensors[i]), original_batch_dim, interval});
+          }
+        }
+
+        out_tensors->reserve(input_descriptors_.size());
+
+        // We slice each component independently because they may have
+        // different batch dimensions.
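+        // For example, with num_replicas_ == 5, components whose batch
+        // dimensions are 10 and 5 get intervals of CeilDiv(10, 5) == 2 and
+        // CeilDiv(5, 5) == 1 respectively.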
+        for (const auto& input_desc : input_descriptors_) {
+          int64 start = input_desc.interval * slice_number_;
+          int64 end = std::min(start + input_desc.interval,
+                               input_desc.original_batch_dim);
+          if (start >= end) {
+            // We can get here if ceil(original_batch_dim / interval) <
+            // num_replicas_, i.e. the batch isn't big enough to distribute
+            // across all replicas. In this case, we return empty tensors for
+            // the remaining iterations that correspond to this batch.
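+            // For example, a batch of 16 rebatched across 5 replicas yields
+            // four slices of size 4 followed by one empty slice.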
+            start = end;
+          }
+          Tensor slice = input_desc.whole_tensor.Slice(start, end);
+          if (slice.IsAligned()) {
+            out_tensors->push_back(std::move(slice));
+          } else {
+            out_tensors->push_back(tensor::DeepCopy(std::move(slice)));
+          }
+        }
+        slice_number_ = (slice_number_ + 1) % dataset()->num_replicas_;
+        return Status::OK();
+      }
+
+     protected:
+      Status SaveInternal(IteratorStateWriter* writer) override {
+        mutex_lock l(mu_);
+        if (!input_impl_) {
+          TF_RETURN_IF_ERROR(
+              writer->WriteScalar(full_name("input_impl_empty"), ""));
+        } else {
+          TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+        }
+        TF_RETURN_IF_ERROR(
+            writer->WriteScalar(full_name("slice_number"), slice_number_));
+
+        if (slice_number_ % dataset()->num_replicas_ != 0) {
+          // We are mid-way through the current batch, so save the input
+          // tensors in order to recreate the remaining slices on restore.
+          for (int i = 0; i < input_descriptors_.size(); ++i) {
+            TF_RETURN_IF_ERROR(writer->WriteTensor(
+                full_name(strings::StrCat("tensors[", i, "]")),
+                input_descriptors_[i].whole_tensor));
+          }
+        }
+        return Status::OK();
+      }
+
+      Status RestoreInternal(IteratorContext* ctx,
+                             IteratorStateReader* reader) override {
+        mutex_lock l(mu_);
+        if (!reader->Contains(full_name("input_impl_empty"))) {
+          TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+        } else {
+          input_impl_.reset();
+        }
+        TF_RETURN_IF_ERROR(
+            reader->ReadScalar(full_name("slice_number"), &slice_number_));
+
+        input_descriptors_.clear();
+        input_descriptors_.resize(dataset()->output_dtypes().size());
+        if (slice_number_ % dataset()->num_replicas_ != 0) {
+          for (int i = 0; i < input_descriptors_.size(); ++i) {
+            TF_RETURN_IF_ERROR(reader->ReadTensor(
+                full_name(strings::StrCat("tensors[", i, "]")),
+                &input_descriptors_[i].whole_tensor));
+            input_descriptors_[i].original_batch_dim =
+                input_descriptors_[i].whole_tensor.dim_size(0);
+            input_descriptors_[i].interval =
+                CeilDiv(input_descriptors_[i].original_batch_dim,
+                        dataset()->num_replicas_);
+          }
+        }
+        return Status::OK();
+      }
+
+     private:
+      // Describes one component of the input.
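+      // "whole_tensor" is the full batch for this component,
+      // "original_batch_dim" is its batch dimension, and "interval" is the
+      // per-replica slice size, CeilDiv(original_batch_dim, num_replicas_).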
+      struct InputDescriptor {
+        InputDescriptor() {}
+        InputDescriptor(Tensor&& whole_tensor, int64 original_batch_dim,
+                        int64 interval)
+            : whole_tensor(std::move(whole_tensor)),
+              original_batch_dim(original_batch_dim),
+              interval(interval) {}
+
+        Tensor whole_tensor;
+        int64 original_batch_dim;
+        int64 interval;
+      };
+
+      mutex mu_;
+      std::unique_ptr<IteratorBase> input_impl_;
+      std::vector<InputDescriptor> input_descriptors_ GUARDED_BY(mu_);
+      int64 slice_number_ GUARDED_BY(mu_) = 0;
+    };
+
+    const DatasetBase* const input_;
+    const int64 num_replicas_;
+    const DataTypeVector output_types_;
+    const std::vector<PartialTensorShape> output_shapes_;
+  };
+
+  DataTypeVector output_types_;
+  std::vector<PartialTensorShape> output_shapes_;
 };
 
 REGISTER_KERNEL_BUILDER(Name("RebatchDataset").Device(DEVICE_CPU),
diff --git a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py
index c12d991..ee24067 100644
--- a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py
@@ -29,12 +29,10 @@
 from tensorflow.python.data.experimental.ops import grouping
 from tensorflow.python.data.experimental.ops import readers
 from tensorflow.python.data.experimental.ops import scan_ops
-from tensorflow.python.data.experimental.ops import sleep
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
 from tensorflow.python.framework import test_util
 from tensorflow.python.lib.io import python_io
 from tensorflow.python.ops import array_ops
@@ -90,6 +88,19 @@
       i += 4
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
+  def testBatchSizeNotDivisibleByNumReplicas2(self):
+    dataset = dataset_ops.Dataset.range(32).batch(16, drop_remainder=True)
+    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
+    # This will rebatch into sub-batches of size 4, since
+    # ceil(16 / 5) = 4. However, that means only the first 4 replicas will get
+    # data.
+    expected_output = [[k for k in range(i, i + 4)] for i in range(0, 16, 4)]
+    expected_output.extend([[]])  # Last replica gets an empty batch
+    expected_output.extend(
+        [[k for k in range(i, i + 4)] for i in range(16, 32, 4)])
+    expected_output.extend([[]])  # Last replica gets an empty batch
+    self.assertDatasetProduces(rebatched_dataset, expected_output)
+
   def testTupleOutput(self):
     dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(32)
     rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
@@ -119,7 +130,9 @@
     # makes up a complete minibatch.
     expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
     if not drop_remainder:
-      expected_output.append([k for k in range(1024, 1032)])
+      # The last partial batch of size 8 is split over 4 replicas
+      expected_output.extend(
+          [[k for k in range(i, i + 2)] for i in range(1024, 1032, 2)])
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
   @parameterized.named_parameters(drop_remainder_cases)
@@ -132,7 +145,8 @@
 
     expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)]  # pylint: disable=g-complex-comprehension
     if not drop_remainder:
-      expected_output += [[32, 33]]
+      # The last partial batch of size 2 is split over 4 replicas
+      expected_output += [[32], [33], [], []]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
   def testMultipleBatches(self):
@@ -214,9 +228,8 @@
     dataset2 = dataset_ops.Dataset.range(32).batch(8)
     dataset = dataset1.concatenate(dataset2)
     rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-    self.assertEqual(
-        [[None]],
-        [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+    self.assertEqual([[None]],
+                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
     expected_output = ([[i, i + 1, i + 2, i + 3] for i in range(0, 64, 4)] +
                        [[i, i + 1] for i in range(0, 32, 2)])
     self.assertDatasetProduces(rebatched_dataset, expected_output)
@@ -242,24 +255,6 @@
                        for i in range(0, 32, 2)]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
-  def testUnsupportedTransformError(self):
-    dataset = dataset_ops.Dataset.range(1024).batch(32).apply(sleep.sleep(10))
-    with self.assertRaises(errors.InvalidArgumentError):
-      rebatched_dataset = distribute._RebatchDataset(
-          dataset, num_replicas=4, use_fallback=False)
-      next_element = self.getNext(rebatched_dataset)
-      self.evaluate(next_element())
-
-  def testUnsupportedTransformInFlatMapError(self):
-    dataset = dataset_ops.Dataset.range(2).flat_map(
-        lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
-            32).apply(sleep.sleep(10)))
-    with self.assertRaises(errors.InvalidArgumentError):
-      rebatched_dataset = distribute._RebatchDataset(
-          dataset, num_replicas=4, use_fallback=False)
-      next_element = self.getNext(rebatched_dataset)
-      self.evaluate(next_element())
-
   def testFlatMapBatching(self):
     dataset = dataset_ops.Dataset.range(2).flat_map(
         lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
@@ -290,11 +285,8 @@
     rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
     self.assertEqual([[None]],
                      [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-    # List of 4 elements where each element is a list of 8 numbering from 0 to
-    # 31 repeated twice.
-    expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
-                       for i in range(0, 32, 8)  # generates 4 elements
-                       for _ in range(2)]
+    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)]
+    expected_output += expected_output
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
   def testParallelInterleaveBatching(self):
@@ -310,11 +302,8 @@
     rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
     self.assertEqual([[None]],
                      [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-    # List of 4 elements where each element is a list of 8 numbering from 0 to
-    # 31 repeated twice in collated fashion i.e [0...8], [0...8] etc.
-    expected_output = [[k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
-                       for i in range(0, 32, 8)  # generates 4 elements
-                       for _ in range(2)]
+    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)]
+    expected_output += expected_output
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
   def testGroupByWindowStaticBatch(self):
@@ -350,8 +339,7 @@
             key_func=lambda x: x, reduce_func=reduce_fn, window_size=10))
     dataset = distribute._RebatchDataset(dataset, num_replicas=2)
 
-    self.assertEqual([[None]],
-                     [ts.as_list() for ts in _flat_shapes(dataset)])
+    self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
 
     # The batches of 5 (value == 0) will be split into minibatches of (3, 2) and
     # the batches of 10 (value == 1) split into minibatches of (5, 5)
@@ -377,8 +365,8 @@
 
     self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
 
-    pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (1, 0), (5, 1), (5, 1), (1, 1),
-             (3, 0), (2, 0), (2, 0), (2, 0), (5, 1), (4, 1)]
+    pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (1, 0), (0, 0), (5, 1), (5, 1),
+             (1, 1), (0, 1), (3, 0), (2, 0), (2, 0), (2, 0), (5, 1), (4, 1)]
     expected_output = [[value] * batch_size for batch_size, value in pairs]
     self.assertDatasetProduces(dataset, expected_output)
 
@@ -450,92 +438,5 @@
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
 
-@test_util.run_all_in_graph_and_eager_modes
-class RebatchDatasetFallbackTest(test_base.DatasetTestBase):
-
-  def testWithNoBatchDataset(self):
-    dataset = dataset_ops.Dataset.from_tensor_slices(
-        [[k for k in range(i, i + 32)] for i in range(0, 1024, 32)])  # pylint: disable=g-complex-comprehension
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-    self.assertEqual([[32]], [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual([[8]],
-                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-
-    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
-    self.assertDatasetProduces(rebatched_dataset, expected_output)
-
-  def testWithUnhandledTransformation(self):
-    dataset = dataset_ops.Dataset.range(1024).batch(
-        32, drop_remainder=True).apply(sleep.sleep(10))
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-    self.assertEqual([[32]], [ts.as_list() for ts in _flat_shapes(dataset)])
-    self.assertEqual([[8]],
-                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-
-    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
-    self.assertDatasetProduces(rebatched_dataset, expected_output)
-
-  def testWithUnhandledTransformationInFlatMap(self):
-    dataset = dataset_ops.Dataset.range(2).flat_map(
-        lambda _: dataset_ops.Dataset.range(32).batch(  # pylint: disable=g-long-lambda
-            32, drop_remainder=True).apply(sleep.sleep(10)))
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-
-    self.assertEqual([[8]],
-                     [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-
-    # Two elements where each element is a list of 4 elements where each element
-    # is a list of 8.
-    expected_output = [
-        [k for k in range(i, i + 8)]  # pylint: disable=g-complex-comprehension
-        for _ in range(2) for i in range(0, 32, 8)]  # generates 4 elements
-    self.assertDatasetProduces(rebatched_dataset, expected_output)
-
-  def testWithUnknownBatchDim(self):
-    dataset = dataset_ops.Dataset.range(1024).batch(
-        32, drop_remainder=False).apply(sleep.sleep(10))
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-
-    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
-    self.assertDatasetProduces(rebatched_dataset, expected_output)
-
-  def testWithUnknownBatchDimInSecondComponent(self):
-    dataset0 = dataset_ops.Dataset.range(1024).batch(32, drop_remainder=True)
-    dataset1 = dataset_ops.Dataset.range(1024).batch(
-        32, drop_remainder=False).apply(sleep.sleep(10))
-    dataset = dataset_ops.Dataset.zip((dataset0, dataset1))
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-
-    expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)]  # pylint: disable=g-complex-comprehension
-    expected_output = [(x, x) for x in expected_output]
-    self.assertDatasetProduces(rebatched_dataset, expected_output)
-
-  def testBatchSizeNotDivisibleByNumReplicas(self):
-    dataset = dataset_ops.Dataset.range(64).batch(
-        32, drop_remainder=True).apply(sleep.sleep(10))
-
-    rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
-
-    expected_output = []
-    i = 0
-    for _ in range(2):  # number of steps
-      # first four minibatches have seven elements
-      for _ in range(4):
-        expected_output.append([k for k in range(i, i + 7)])
-        i += 7
-      # last minibatch has four elements
-      expected_output.append([k for k in range(i, i + 4)])
-      i += 4
-    self.assertDatasetProduces(rebatched_dataset, expected_output)
-
-  def testBatchSizesDontMatch(self):
-    dataset = dataset_ops.Dataset.from_tensors((np.arange(10), np.arange(5)))
-    with self.assertRaisesRegexp(errors.InvalidArgumentError,
-                                 "Cannot use rebatching fallback"):
-      rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
-      next_element = self.getNext(rebatched_dataset)
-      self.evaluate(next_element())
-
-
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/data/experimental/ops/distribute.py b/tensorflow/python/data/experimental/ops/distribute.py
index 7245a3d..dcf67fb 100644
--- a/tensorflow/python/data/experimental/ops/distribute.py
+++ b/tensorflow/python/data/experimental/ops/distribute.py
@@ -17,7 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.compat import compat
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.util import nest
 from tensorflow.python.data.util import structure
@@ -99,17 +98,10 @@
 
     self._element_spec = structure.convert_legacy_structure(
         input_types, output_shapes, input_classes)
-    if compat.forward_compatible(2019, 8, 13) or not use_fallback:
-      variant_tensor = ged_ops.rebatch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          num_replicas=num_replicas,
-          use_fallback=use_fallback,
-          **self._flat_structure)
-    else:
-      variant_tensor = ged_ops.rebatch_dataset(
-          self._input_dataset._variant_tensor,  # pylint: disable=protected-access
-          num_replicas=num_replicas,
-          **self._flat_structure)
+    variant_tensor = ged_ops.rebatch_dataset(
+        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
+        num_replicas=num_replicas,
+        **self._flat_structure)
     super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
 
   @property
diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index b55f933..34e6c39 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -484,6 +484,11 @@
         # pylint: disable=protected-access
         with ops.colocate_with(dataset._variant_tensor):
           dataset = distribute._RebatchDataset(dataset, split_batch_by)
+          # Add a prefetch to pipeline rebatching for performance.
+          # TODO(rachelim): Instead of inserting an extra prefetch stage here,
+          # leverage static graph rewrites to insert _RebatchDataset before
+          # the final `prefetch` if it exists.
+          dataset = dataset.prefetch(split_batch_by)
       except errors.InvalidArgumentError as e:
         if "without encountering a batch" in str(e):
           six.reraise(