[NFC] Move some infeed/outfeed logic into {In|Out}feedManager

GpuTransferManager::TransferLiteralToInfeed and
GpuTransferManager::TransferLiteralFromOutfeed now delegate to
InfeedManager::TransferLiteralToInfeed and
OutfeedManager::TransferLiteralFromOutfeed, respectively. This will simplify a
subsequent change.
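
In sketch form (names exactly as in the diff below), the transfer-manager
entry points now reduce to thin delegations to the per-executor managers:

  Status GpuTransferManager::TransferLiteralToInfeed(
      se::StreamExecutor* executor, const LiteralSlice& literal) {
    return gpu::GetOrCreateInfeedManager(executor)->TransferLiteralToInfeed(
        executor, literal);
  }

  Status GpuTransferManager::TransferLiteralFromOutfeed(
      se::StreamExecutor* executor, MutableBorrowingLiteral literal) {
    return gpu::GetOrCreateOutfeedManager(executor)->TransferLiteralFromOutfeed(
        executor, literal);
  }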

PiperOrigin-RevId: 380936688
Change-Id: Ibd28afb9c290229dc03d7a81554d822fa3036060
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
index 6582303..c93255b 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
@@ -47,99 +47,14 @@
 
 Status GpuTransferManager::TransferLiteralToInfeed(
     se::StreamExecutor* executor, const LiteralSlice& literal) {
-  const Shape& literal_shape = literal.shape();
-  VLOG(2) << "Transferring literal to infeed with shape: "
-          << ShapeUtil::HumanString(literal_shape);
-
-  // For a tuple, we transfer each of its elements to the device and
-  // enqueue the resulting destination device addresses with the
-  // infeed manager.
-  ShapeTree<InfeedBuffer> buffer_tree(literal_shape);
-  for (auto& leaf : buffer_tree.leaves()) {
-    const Shape& sub_shape = ShapeUtil::GetSubshape(literal_shape, leaf.first);
-    CHECK(sub_shape.IsArray()) << ShapeUtil::HumanStringWithLayout(sub_shape);
-    int64 tuple_element_size = GetByteSizeRequirement(sub_shape);
-    TF_ASSIGN_OR_RETURN(leaf.second, TransferBufferToInfeedInternal(
-                                         executor, tuple_element_size,
-                                         literal.untyped_data(leaf.first)));
-  }
-  return EnqueueBuffersToInfeed(executor, std::move(buffer_tree));
-}
-
-Status GpuTransferManager::EnqueueBuffersToInfeed(
-    se::StreamExecutor* executor, ShapeTree<InfeedBuffer> buffers) {
-  gpu::InfeedManager* infeed_manager = gpu::GetOrCreateInfeedManager(executor);
-  se::Stream* stream = infeed_manager->GetStream();
-
-  // TODO(b/30467474): Since this stream is shared across different
-  // infeed requests, blocking on the stream might be
-  // heavy-handed. Figure out if finer-grained acknowledgement is
-  // possible.
-  Status block_status = stream->BlockHostUntilDone();
-  if (!block_status.ok()) {
-    return InternalError("Failed to complete data transfer on stream %p: %s",
-                         stream, block_status.error_message());
-  }
-
-  infeed_manager->EnqueueDestination(std::move(buffers));
-
-  VLOG(2) << "Infeed data transferred";
-
-  return Status::OK();
-}
-
-StatusOr<InfeedBuffer> GpuTransferManager::TransferBufferToInfeedInternal(
-    se::StreamExecutor* executor, int64 size, const void* source) {
-  if (size > std::numeric_limits<int32>::max()) {
-    return InvalidArgument("GPU infeed of %d bytes exceeds maximum of %d bytes",
-                           size, std::numeric_limits<int32>::max());
-  }
-
-  if (size == 0) {
-    return InvalidArgument("Infeed shape needs 0 bytes");
-  }
-
-  gpu::InfeedManager* infeed_manager = gpu::GetOrCreateInfeedManager(executor);
-  se::Stream* stream = infeed_manager->GetStream();
-  if (stream == nullptr) {
-    return InternalError("Failed to obtain a stream");
-  }
-
-  InfeedBuffer buffer(executor, size);
-  stream->ThenMemcpy(buffer.device_memory(), source, size);
-
-  VLOG(2) << "Queued infeed data on stream " << stream;
-
-  return std::move(buffer);
+  return gpu::GetOrCreateInfeedManager(executor)->TransferLiteralToInfeed(
+      executor, literal);
 }
 
 Status GpuTransferManager::TransferLiteralFromOutfeed(
     se::StreamExecutor* executor, MutableBorrowingLiteral literal) {
-  ShapeTree<std::unique_ptr<gpu::OutfeedBuffer>> outfeed_buffers(
-      &literal.shape());
-
-  for (auto& leaf : outfeed_buffers.leaves()) {
-    const Shape& shape = ShapeUtil::GetSubshape(literal.shape(), leaf.first);
-    CHECK(shape.IsArray()) << ShapeUtil::HumanStringWithLayout(shape);
-    leaf.second =
-        absl::make_unique<gpu::OutfeedBuffer>(GetByteSizeRequirement(shape));
-    leaf.second->set_destination(
-        absl::make_unique<MutableBorrowingLiteral>(literal, leaf.first));
-  }
-
-  // Give the tree of buffers to the outfeed manager. The device will fill it
-  // while we're waiting for it below.
-  gpu::OutfeedManager* outfeed_manager =
-      gpu::GetOrCreateOutfeedManager(executor);
-  outfeed_manager->EnqueueDestination(&outfeed_buffers);
-
-  // Now wait till all the buffers are written.
-  for (auto& leaf : outfeed_buffers.leaves()) {
-    const Shape& shape = ShapeUtil::GetSubshape(literal.shape(), leaf.first);
-    CHECK(shape.IsArray()) << ShapeUtil::HumanStringWithLayout(shape);
-    leaf.second->WaitUntilAvailable();
-  }
-  return Status::OK();
+  return gpu::GetOrCreateOutfeedManager(executor)->TransferLiteralFromOutfeed(
+      executor, literal);
 }
 
 }  // namespace gpu
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
index acc301f..97d4477 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
@@ -44,16 +44,6 @@
                                     MutableBorrowingLiteral literal) override;
 
  private:
-  // Initiates the infeed data transfers. InfeedBuffer->Done() must be
-  // called to clean up the memory allocated for InfeedBuffer.
-  StatusOr<InfeedBuffer> TransferBufferToInfeedInternal(
-      se::StreamExecutor* executor, int64 size, const void* source);
-
-  // Enqueues infeed data buffers with the infeed manager after their
-  // transfer completes.
-  Status EnqueueBuffersToInfeed(se::StreamExecutor* executor,
-                                ShapeTree<InfeedBuffer> buffers);
-
   TF_DISALLOW_COPY_AND_ASSIGN(GpuTransferManager);
 };
 
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
index 06b877b..2a2a6ee 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc
@@ -16,6 +16,7 @@
 #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h"
 
 #include "absl/memory/memory.h"
+#include "tensorflow/compiler/xla/shape_util.h"
 
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 #include "tensorflow/compiler/xla/service/gpu/xla_executor_state.h"
@@ -30,6 +31,56 @@
   stream_->Init();
 }
 
+static StatusOr<se::ScopedDeviceMemory<uint8>> CopyBufferToDevice(
+    se::Stream* stream, int64 size, const void* source) {
+  if (size > std::numeric_limits<int32>::max()) {
+    return InvalidArgument("GPU infeed of %d bytes exceeds maximum of %d bytes",
+                           size, std::numeric_limits<int32>::max());
+  }
+
+  if (size == 0) {
+    return InvalidArgument("Infeed shape needs 0 bytes");
+  }
+
+  se::StreamExecutor* executor = stream->parent();
+  se::ScopedDeviceMemory<uint8> buffer(executor,
+                                       executor->AllocateArray<uint8>(size));
+  stream->ThenMemcpy(buffer.ptr(), source, size);
+
+  return std::move(buffer);
+}
+
+Status InfeedManager::TransferLiteralToInfeed(se::StreamExecutor* executor,
+                                              const LiteralSlice& literal) {
+  const Shape& literal_shape = literal.shape();
+  VLOG(2) << "Transferring literal to infeed with shape: "
+          << ShapeUtil::HumanString(literal_shape);
+
+  // For a tuple, we transfer each of its elements to the device and enqueue the
+  // resulting destination device addresses with the infeed manager.
+  ShapeTree<se::ScopedDeviceMemory<uint8>> buffer_tree(literal_shape);
+  for (auto& leaf : buffer_tree.leaves()) {
+    const Shape& sub_shape = ShapeUtil::GetSubshape(literal_shape, leaf.first);
+    CHECK(sub_shape.IsArray()) << ShapeUtil::HumanStringWithLayout(sub_shape);
+    TF_ASSIGN_OR_RETURN(
+        leaf.second,
+        CopyBufferToDevice(stream(), ShapeUtil::ByteSizeOf(sub_shape),
+                           literal.untyped_data(leaf.first)));
+  }
+
+  // TODO(b/30467474): Since this stream is shared across different infeed
+  // requests, blocking on the stream might be heavy-handed. Figure out if
+  // finer-grained acknowledgement is possible.
+  Status block_status = stream()->BlockHostUntilDone();
+  if (!block_status.ok()) {
+    return InternalError("Failed to complete data transfer on stream %p: %s",
+                         stream(), block_status.error_message());
+  }
+
+  EnqueueDestination(std::move(buffer_tree));
+  return Status::OK();
+}
+
 InfeedManager *GetOrCreateInfeedManager(se::StreamExecutor *executor) {
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   stream_executor::gpu::GpuExecutor *gpu_executor =
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.h b/tensorflow/compiler/xla/service/gpu/infeed_manager.h
index 519e9a7..597021b 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_manager.h
+++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.h
@@ -21,6 +21,7 @@
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_INFEED_MANAGER_H_
 
 #include "absl/base/thread_annotations.h"
+#include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/service/gpu/xfeed_queue.h"
 #include "tensorflow/compiler/xla/shape_tree.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -41,35 +42,18 @@
 // memory. Potential solution is to pre-allocate a fixed amount of
 // memory and block when that memory is full.
 
-// Defines an infeed buffer that is passed to the runtime by
-// the client. The client manages the memory of the buffer.
-class InfeedBuffer {
- public:
-  InfeedBuffer() = default;
-  InfeedBuffer(se::StreamExecutor* executor, int64 length)
-      : device_memory_(executor, executor->AllocateArray<uint8>(length)),
-        length_(length) {
-    CHECK(!device_memory_->is_null());
-  }
-
-  int64 length() const { return length_; }
-
-  se::DeviceMemoryBase* device_memory() { return device_memory_.ptr(); }
-
- private:
-  se::ScopedDeviceMemory<uint8> device_memory_;
-  int64 length_;
-};
-
 // Client-side class used to enqueue infeed buffers.
-class InfeedManager : public XfeedQueue<ShapeTree<InfeedBuffer>> {
+class InfeedManager
+    : public XfeedQueue<ShapeTree<se::ScopedDeviceMemory<uint8>>> {
  public:
   explicit InfeedManager(se::StreamExecutor* executor);
 
-  // Returns a stream for this infeed manager.
-  se::Stream* GetStream() const { return stream_.get(); }
+  Status TransferLiteralToInfeed(se::StreamExecutor* executor,
+                                 const LiteralSlice& literal);
 
  private:
+  se::Stream* stream() const { return stream_.get(); }
+
   // Stream used to enqueue infeed device copies.
   std::unique_ptr<se::Stream> stream_;
 };
diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
index 789eb52..4961b12 100644
--- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc
@@ -40,14 +40,14 @@
   auto op_profiler =
       params.profiler->MakeScopedInstructionProfiler(profile_index());
 
-  ShapeTree<InfeedBuffer> source_buffers =
+  ShapeTree<se::ScopedDeviceMemory<uint8>> source_buffers =
       GetOrCreateInfeedManager(stream.parent())->BlockingGetNextDestination();
 
   size_t index = 0;
   for (auto& source : source_buffers.leaves()) {
     // Assert that the shapes are compatible.
     const ShapeIndex& shape_index = source.first;
-    InfeedBuffer& buffer = source.second;
+    se::ScopedDeviceMemory<uint8>& buffer = source.second;
     const Shape& source_shape =
         ShapeUtil::GetSubshape(source_buffers.shape(), shape_index);
     TF_RET_CHECK(ShapeUtil::Equal(dest_slices_[index].shape, source_shape))
@@ -57,7 +57,7 @@
         << ShapeUtil::HumanStringWithLayout(dest_slices_[index].shape);
     se::DeviceMemoryBase dest_address =
         buffer_allocations.GetDeviceAddress(dest_slices_[index++].slice);
-    stream.ThenMemcpy(&dest_address, *buffer.device_memory(), buffer.length());
+    stream.ThenMemcpy(&dest_address, *buffer.ptr(), buffer.ptr()->size());
   }
 
   // Make sure that all dest slices have been copied into.
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc
index d0feb5a..e83e2f4 100644
--- a/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.cc
@@ -40,5 +40,33 @@
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 }
 
+Status OutfeedManager::TransferLiteralFromOutfeed(
+    se::StreamExecutor* executor, MutableBorrowingLiteral literal) {
+  ShapeTree<std::unique_ptr<gpu::OutfeedBuffer>> outfeed_buffers(
+      &literal.shape());
+
+  for (auto& leaf : outfeed_buffers.leaves()) {
+    const Shape& shape = ShapeUtil::GetSubshape(literal.shape(), leaf.first);
+    CHECK(shape.IsArray()) << ShapeUtil::HumanStringWithLayout(shape);
+    leaf.second =
+        absl::make_unique<gpu::OutfeedBuffer>(ShapeUtil::ByteSizeOf(shape));
+    leaf.second->set_destination(
+        absl::make_unique<MutableBorrowingLiteral>(literal, leaf.first));
+  }
+
+  // Give the tree of buffers to this outfeed manager. The device will fill it
+  // while we're waiting for it below.
+  EnqueueDestination(&outfeed_buffers);
+
+  // Now wait till all the buffers are written.
+  for (auto& leaf : outfeed_buffers.leaves()) {
+    const Shape& shape = ShapeUtil::GetSubshape(literal.shape(), leaf.first);
+    CHECK(shape.IsArray()) << ShapeUtil::HumanStringWithLayout(shape);
+    leaf.second->WaitUntilAvailable();
+  }
+
+  return Status::OK();
+}
+
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_manager.h b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h
index 9e00464..29a94ef 100644
--- a/tensorflow/compiler/xla/service/gpu/outfeed_manager.h
+++ b/tensorflow/compiler/xla/service/gpu/outfeed_manager.h
@@ -56,7 +56,12 @@
 
 // Manages a thread-safe queue of buffers. The buffers are supposed to be
 // produced by the transfer manager and consumed by the device.
-using OutfeedManager = XfeedQueue<ShapeTree<std::unique_ptr<OutfeedBuffer>>*>;
+class OutfeedManager
+    : public XfeedQueue<ShapeTree<std::unique_ptr<OutfeedBuffer>>*> {
+ public:
+  Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
+                                    MutableBorrowingLiteral literal);
+};
 
 // Returns the GPU outfeed manager for the given stream executor.
 OutfeedManager* GetOrCreateOutfeedManager(se::StreamExecutor* executor);