[tf.data] Add a new `RebatchDatasetV2` op that does rebatching (instead of
rebatching via graph rewrites). Also adds a `prefetch` after `RebatchDataset` for performance.
PiperOrigin-RevId: 268544896
diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
index 8aef8ee..b1ec124 100644
--- a/tensorflow/core/kernels/data/experimental/BUILD
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -322,7 +322,7 @@
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler/optimizers/data:rebatch",
- "//tensorflow/core/kernels/data:rewrite_utils",
+ "//tensorflow/core/kernels/data:name_utils",
],
)
diff --git a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc
index 6158821..030166a 100644
--- a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc
@@ -13,24 +13,26 @@
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/kernels/data/rewrite_utils.h"
-#include "tensorflow/core/protobuf/rewriter_config.pb.h"
+#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/kernels/data/name_utils.h"
namespace tensorflow {
namespace data {
namespace experimental {
namespace {
-constexpr char kOptimizerName[] = "tf_data_rebatcher";
-constexpr char kUseFallbackAttr[] = "use_fallback";
+inline int64 CeilDiv(int64 dividend, int64 divisor) {
+ return (dividend - 1 + divisor) / divisor;
+}
+
+constexpr const char* const kDatasetType = "Rebatch";
class RebatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit RebatchDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx) {
- if (ctx->HasAttr(kUseFallbackAttr)) {
- OP_REQUIRES_OK(ctx, ctx->GetAttr(kUseFallbackAttr, &use_fallback_));
- }
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
}
protected:
@@ -42,40 +44,213 @@
OP_REQUIRES(
ctx, num_replicas > 0,
errors::InvalidArgument("num_replicas must be greater than zero."));
-
- auto config_factory = [num_replicas, this]() {
- return CreateConfig(num_replicas, this->use_fallback_);
- };
-
- // We only want to optimize functions for some particular datasets like
- // FlatMapDataset, InterleaveDataset etc. So we disable generalized
- // function optimization and explicitly handle function modifications
- // for those datasets in the rewrite.
- OP_REQUIRES_OK(ctx,
- RewriteDataset(ctx, input, std::move(config_factory),
- /*optimize_function_library=*/false, output));
+ *output =
+ new Dataset(ctx, input, num_replicas, output_types_, output_shapes_);
}
private:
- static RewriterConfig CreateConfig(int64 num_replicas, bool use_fallback) {
- RewriterConfig rewriter_config;
- rewriter_config.set_fail_on_optimizer_errors(true);
- rewriter_config.add_optimizers(kOptimizerName);
- rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
- auto custom_optimizer = rewriter_config.add_custom_optimizers();
- custom_optimizer->set_name(kOptimizerName);
- AttrValue num_replicas_attr;
- num_replicas_attr.set_i(num_replicas);
- (*custom_optimizer->mutable_parameter_map())["num_replicas"] =
- num_replicas_attr;
- AttrValue use_fallback_attr;
- use_fallback_attr.set_b(use_fallback);
- (*custom_optimizer->mutable_parameter_map())["use_fallback"] =
- use_fallback_attr;
- return rewriter_config;
- }
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const int64 num_replicas, const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ num_replicas_(num_replicas),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {
+ input_->Ref();
+ }
- bool use_fallback_ = true;
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ name_utils::IteratorPrefixParams params;
+ return absl::make_unique<Iterator>(Iterator::Params{
+ this, name_utils::IteratorPrefix(kDatasetType, prefix, params)});
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ name_utils::DatasetDebugStringParams params;
+ params.set_args(num_replicas_);
+ return name_utils::DatasetDebugString(kDatasetType, params);
+ }
+
+ Status CheckExternalState() const override {
+ return input_->CheckExternalState();
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ Node* num_replicas = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(num_replicas_, &num_replicas));
+ TF_RETURN_IF_ERROR(
+ b->AddDataset(this, {input_graph_node, num_replicas}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ ~Iterator() override {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (slice_number_ % dataset()->num_replicas_ == 0) {
+ input_descriptors_.clear();
+ std::vector<Tensor> input_tensors;
+ TF_RETURN_IF_ERROR(
+ input_impl_->GetNext(ctx, &input_tensors, end_of_sequence));
+ if (*end_of_sequence) {
+ return Status::OK();
+ }
+
+ input_descriptors_.reserve(input_tensors.size());
+ for (int i = 0; i < input_tensors.size(); ++i) {
+ if (input_tensors[i].dims() == 0) {
+ return errors::InvalidArgument(
+ "Cannot rebatch dataset: All components must have at least "
+ "one dimension. Perhaps your input dataset is not batched? "
+ "Component ",
+ i, " is scalar.");
+ }
+
+ int64 original_batch_dim = input_tensors[i].dim_size(0);
+ int64 interval =
+ CeilDiv(original_batch_dim, dataset()->num_replicas_);
+ input_descriptors_.push_back(
+ {std::move(input_tensors[i]), original_batch_dim, interval});
+ }
+ }
+
+ out_tensors->reserve(input_descriptors_.size());
+
+ // We slice each component independently because they may have
+ // different batch dimensions.
+ for (const auto& input_desc : input_descriptors_) {
+ int64 start = input_desc.interval * slice_number_;
+ int64 end = std::min(start + input_desc.interval,
+ input_desc.original_batch_dim);
+ if (start >= end) {
+          // We can get here if ceil(original_batch_dim / interval) <
+          // num_replicas_, i.e. the original batch isn't big enough to yield
+          // one non-empty slice per replica. In this case, we return empty
+          // tensors for the remaining iterations that correspond to this batch.
+ start = end;
+ }
+ Tensor slice = input_desc.whole_tensor.Slice(start, end);
+ if (slice.IsAligned()) {
+ out_tensors->push_back(std::move(slice));
+ } else {
+ out_tensors->push_back(tensor::DeepCopy(std::move(slice)));
+ }
+ }
+ slice_number_ = (slice_number_ + 1) % dataset()->num_replicas_;
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ if (!input_impl_) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("input_impl_empty"), ""));
+ } else {
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ }
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("slice_number"), slice_number_));
+
+ if (slice_number_ % dataset()->num_replicas_ != 0) {
+ // Save state of input tensors.
+ for (int i = 0; i < input_descriptors_.size(); ++i) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat("tensors[", i, "]")),
+ input_descriptors_[i].whole_tensor));
+ }
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ if (!reader->Contains(full_name("input_impl_empty"))) {
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ } else {
+ input_impl_.reset();
+ }
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("slice_number"), &slice_number_));
+
+ input_descriptors_.clear();
+ input_descriptors_.resize(dataset()->output_dtypes().size());
+ if (slice_number_ % dataset()->num_replicas_ != 0) {
+ for (int i = 0; i < input_descriptors_.size(); ++i) {
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat("tensors[", i, "]")),
+ &input_descriptors_[i].whole_tensor));
+ input_descriptors_[i].original_batch_dim =
+ input_descriptors_[i].whole_tensor.dim_size(0);
+ input_descriptors_[i].interval =
+ CeilDiv(input_descriptors_[i].original_batch_dim,
+ dataset()->num_replicas_);
+ }
+ }
+ return Status::OK();
+ }
+
+ private:
+ // Describes one component of the input.
+ struct InputDescriptor {
+ InputDescriptor() {}
+ InputDescriptor(Tensor&& whole_tensor, int64 original_batch_dim,
+ int64 interval)
+ : whole_tensor(std::move(whole_tensor)),
+ original_batch_dim(original_batch_dim),
+ interval(interval) {}
+
+ Tensor whole_tensor;
+ int64 original_batch_dim;
+ int64 interval;
+ };
+
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_;
+ std::vector<InputDescriptor> input_descriptors_ GUARDED_BY(mu_);
+ int64 slice_number_ GUARDED_BY(mu_) = 0;
+ };
+
+ const DatasetBase* const input_;
+ const int64 num_replicas_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
};
REGISTER_KERNEL_BUILDER(Name("RebatchDataset").Device(DEVICE_CPU),
diff --git a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py
index c12d991..ee24067 100644
--- a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py
@@ -29,12 +29,10 @@
from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.experimental.ops import scan_ops
-from tensorflow.python.data.experimental.ops import sleep
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
@@ -90,6 +88,19 @@
i += 4
self.assertDatasetProduces(rebatched_dataset, expected_output)
+ def testBatchSizeNotDivisibleByNumReplicas2(self):
+ dataset = dataset_ops.Dataset.range(32).batch(16, drop_remainder=True)
+ rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
+ # This will rebatch into sub-batches of size 4, since
+ # ceil(16 / 5) = 4. However, that means only the first 4 replicas will get
+ # data.
+ expected_output = [[k for k in range(i, i + 4)] for i in range(0, 16, 4)]
+ expected_output.extend([[]]) # Last replica gets an empty batch
+ expected_output.extend(
+ [[k for k in range(i, i + 4)] for i in range(16, 32, 4)])
+ expected_output.extend([[]]) # Last replica gets an empty batch
+ self.assertDatasetProduces(rebatched_dataset, expected_output)
+
def testTupleOutput(self):
dataset = dataset_ops.Dataset.range(1024).map(lambda x: (x, x)).batch(32)
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
@@ -119,7 +130,9 @@
# makes up a complete minibatch.
expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension
if not drop_remainder:
- expected_output.append([k for k in range(1024, 1032)])
+ # The last partial batch of size 8 is split over 4 replicas
+ expected_output.extend(
+ [[k for k in range(i, i + 2)] for i in range(1024, 1032, 2)])
self.assertDatasetProduces(rebatched_dataset, expected_output)
@parameterized.named_parameters(drop_remainder_cases)
@@ -132,7 +145,8 @@
expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)] # pylint: disable=g-complex-comprehension
if not drop_remainder:
- expected_output += [[32, 33]]
+ # The last partial batch of size 2 is split over 4 replicas
+ expected_output += [[32], [33], [], []]
self.assertDatasetProduces(rebatched_dataset, expected_output)
def testMultipleBatches(self):
@@ -214,9 +228,8 @@
dataset2 = dataset_ops.Dataset.range(32).batch(8)
dataset = dataset1.concatenate(dataset2)
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
- self.assertEqual(
- [[None]],
- [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
+ self.assertEqual([[None]],
+ [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
expected_output = ([[i, i + 1, i + 2, i + 3] for i in range(0, 64, 4)] +
[[i, i + 1] for i in range(0, 32, 2)])
self.assertDatasetProduces(rebatched_dataset, expected_output)
@@ -242,24 +255,6 @@
for i in range(0, 32, 2)]
self.assertDatasetProduces(rebatched_dataset, expected_output)
- def testUnsupportedTransformError(self):
- dataset = dataset_ops.Dataset.range(1024).batch(32).apply(sleep.sleep(10))
- with self.assertRaises(errors.InvalidArgumentError):
- rebatched_dataset = distribute._RebatchDataset(
- dataset, num_replicas=4, use_fallback=False)
- next_element = self.getNext(rebatched_dataset)
- self.evaluate(next_element())
-
- def testUnsupportedTransformInFlatMapError(self):
- dataset = dataset_ops.Dataset.range(2).flat_map(
- lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda
- 32).apply(sleep.sleep(10)))
- with self.assertRaises(errors.InvalidArgumentError):
- rebatched_dataset = distribute._RebatchDataset(
- dataset, num_replicas=4, use_fallback=False)
- next_element = self.getNext(rebatched_dataset)
- self.evaluate(next_element())
-
def testFlatMapBatching(self):
dataset = dataset_ops.Dataset.range(2).flat_map(
lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda
@@ -290,11 +285,8 @@
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
self.assertEqual([[None]],
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
- # List of 4 elements where each element is a list of 8 numbering from 0 to
- # 31 repeated twice.
- expected_output = [[k for k in range(i, i + 8)] # pylint: disable=g-complex-comprehension
- for i in range(0, 32, 8) # generates 4 elements
- for _ in range(2)]
+ expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)]
+ expected_output += expected_output
self.assertDatasetProduces(rebatched_dataset, expected_output)
def testParallelInterleaveBatching(self):
@@ -310,11 +302,8 @@
rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
self.assertEqual([[None]],
[ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
- # List of 4 elements where each element is a list of 8 numbering from 0 to
- # 31 repeated twice in collated fashion i.e [0...8], [0...8] etc.
- expected_output = [[k for k in range(i, i + 8)] # pylint: disable=g-complex-comprehension
- for i in range(0, 32, 8) # generates 4 elements
- for _ in range(2)]
+ expected_output = [[k for k in range(i, i + 8)] for i in range(0, 32, 8)]
+ expected_output += expected_output
self.assertDatasetProduces(rebatched_dataset, expected_output)
def testGroupByWindowStaticBatch(self):
@@ -350,8 +339,7 @@
key_func=lambda x: x, reduce_func=reduce_fn, window_size=10))
dataset = distribute._RebatchDataset(dataset, num_replicas=2)
- self.assertEqual([[None]],
- [ts.as_list() for ts in _flat_shapes(dataset)])
+ self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
# The batches of 5 (value == 0) will be split into minibatches of (3, 2) and
# the batches of 10 (value == 1) split into minibatches of (5, 5)
@@ -377,8 +365,8 @@
self.assertEqual([[None]], [ts.as_list() for ts in _flat_shapes(dataset)])
- pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (1, 0), (5, 1), (5, 1), (1, 1),
- (3, 0), (2, 0), (2, 0), (2, 0), (5, 1), (4, 1)]
+ pairs = [(3, 0), (2, 0), (3, 0), (2, 0), (1, 0), (0, 0), (5, 1), (5, 1),
+ (1, 1), (0, 1), (3, 0), (2, 0), (2, 0), (2, 0), (5, 1), (4, 1)]
expected_output = [[value] * batch_size for batch_size, value in pairs]
self.assertDatasetProduces(dataset, expected_output)
@@ -450,92 +438,5 @@
self.assertDatasetProduces(rebatched_dataset, expected_output)
-@test_util.run_all_in_graph_and_eager_modes
-class RebatchDatasetFallbackTest(test_base.DatasetTestBase):
-
- def testWithNoBatchDataset(self):
- dataset = dataset_ops.Dataset.from_tensor_slices(
- [[k for k in range(i, i + 32)] for i in range(0, 1024, 32)]) # pylint: disable=g-complex-comprehension
- rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
- self.assertEqual([[32]], [ts.as_list() for ts in _flat_shapes(dataset)])
- self.assertEqual([[8]],
- [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-
- expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension
- self.assertDatasetProduces(rebatched_dataset, expected_output)
-
- def testWithUnhandledTransformation(self):
- dataset = dataset_ops.Dataset.range(1024).batch(
- 32, drop_remainder=True).apply(sleep.sleep(10))
- rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
- self.assertEqual([[32]], [ts.as_list() for ts in _flat_shapes(dataset)])
- self.assertEqual([[8]],
- [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-
- expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension
- self.assertDatasetProduces(rebatched_dataset, expected_output)
-
- def testWithUnhandledTransformationInFlatMap(self):
- dataset = dataset_ops.Dataset.range(2).flat_map(
- lambda _: dataset_ops.Dataset.range(32).batch( # pylint: disable=g-long-lambda
- 32, drop_remainder=True).apply(sleep.sleep(10)))
- rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-
- self.assertEqual([[8]],
- [ts.as_list() for ts in _flat_shapes(rebatched_dataset)])
-
- # Two elements where each element is a list of 4 elements where each element
- # is a list of 8.
- expected_output = [
- [k for k in range(i, i + 8)] # pylint: disable=g-complex-comprehension
- for _ in range(2) for i in range(0, 32, 8)] # generates 4 elements
- self.assertDatasetProduces(rebatched_dataset, expected_output)
-
- def testWithUnknownBatchDim(self):
- dataset = dataset_ops.Dataset.range(1024).batch(
- 32, drop_remainder=False).apply(sleep.sleep(10))
- rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-
- expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension
- self.assertDatasetProduces(rebatched_dataset, expected_output)
-
- def testWithUnknownBatchDimInSecondComponent(self):
- dataset0 = dataset_ops.Dataset.range(1024).batch(32, drop_remainder=True)
- dataset1 = dataset_ops.Dataset.range(1024).batch(
- 32, drop_remainder=False).apply(sleep.sleep(10))
- dataset = dataset_ops.Dataset.zip((dataset0, dataset1))
- rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=4)
-
- expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension
- expected_output = [(x, x) for x in expected_output]
- self.assertDatasetProduces(rebatched_dataset, expected_output)
-
- def testBatchSizeNotDivisibleByNumReplicas(self):
- dataset = dataset_ops.Dataset.range(64).batch(
- 32, drop_remainder=True).apply(sleep.sleep(10))
-
- rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
-
- expected_output = []
- i = 0
- for _ in range(2): # number of steps
- # first four minibatches have seven elements
- for _ in range(4):
- expected_output.append([k for k in range(i, i + 7)])
- i += 7
- # last minibatch has four elements
- expected_output.append([k for k in range(i, i + 4)])
- i += 4
- self.assertDatasetProduces(rebatched_dataset, expected_output)
-
- def testBatchSizesDontMatch(self):
- dataset = dataset_ops.Dataset.from_tensors((np.arange(10), np.arange(5)))
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- "Cannot use rebatching fallback"):
- rebatched_dataset = distribute._RebatchDataset(dataset, num_replicas=5)
- next_element = self.getNext(rebatched_dataset)
- self.evaluate(next_element())
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/experimental/ops/distribute.py b/tensorflow/python/data/experimental/ops/distribute.py
index 7245a3d..dcf67fb 100644
--- a/tensorflow/python/data/experimental/ops/distribute.py
+++ b/tensorflow/python/data/experimental/ops/distribute.py
@@ -17,7 +17,6 @@
from __future__ import division
from __future__ import print_function
-from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
@@ -99,17 +98,10 @@
self._element_spec = structure.convert_legacy_structure(
input_types, output_shapes, input_classes)
- if compat.forward_compatible(2019, 8, 13) or not use_fallback:
- variant_tensor = ged_ops.rebatch_dataset(
- self._input_dataset._variant_tensor, # pylint: disable=protected-access
- num_replicas=num_replicas,
- use_fallback=use_fallback,
- **self._flat_structure)
- else:
- variant_tensor = ged_ops.rebatch_dataset(
- self._input_dataset._variant_tensor, # pylint: disable=protected-access
- num_replicas=num_replicas,
- **self._flat_structure)
+ variant_tensor = ged_ops.rebatch_dataset(
+ self._input_dataset._variant_tensor, # pylint: disable=protected-access
+ num_replicas=num_replicas,
+ **self._flat_structure)
super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
@property
diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index b55f933..34e6c39 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -484,6 +484,11 @@
# pylint: disable=protected-access
with ops.colocate_with(dataset._variant_tensor):
dataset = distribute._RebatchDataset(dataset, split_batch_by)
+ # Add a prefetch to pipeline rebatching for performance.
+ # TODO(rachelim): Instead of inserting an extra prefetch stage here,
+ # leverage static graph rewrites to insert _RebatchDataset before
+ # the final `prefetch` if it exists.
+ dataset = dataset.prefetch(split_batch_by)
except errors.InvalidArgumentError as e:
if "without encountering a batch" in str(e):
six.reraise(