Add support in collective_rma_distributed for worker interfaces that populate the buffer pointer provided in the RecvBuf request directly with the remote tensor, rather than populating transport_options with a serialized proto containing the remote tensor. This allows a more efficient implementation that avoids the extra serialization / deserialization cost of going from the transport format into a proto and finally into the tensor buffer.
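
For illustration, a minimal sketch of the worker-side fast path this enables
(MyWorker and LookupProducedTensor are hypothetical placeholders; the
buf_ptr/num_bytes handling mirrors the FakeWorker in the updated test):

  // Hypothetical worker-side RecvBufAsync override: instead of packing the
  // tensor bytes into response->transport_options, copy them directly into
  // the buffer address supplied by the requester.
  void MyWorker::RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                              RecvBufResponse* response, StatusCallback done) {
    // Look up the tensor produced for this key (rendezvous details elided).
    const Tensor* prod_value =
        LookupProducedTensor(request->buf_rendezvous_key());
    const int64 num_bytes = prod_value->TotalBytes();
    if (request->num_bytes() != num_bytes) {
      done(errors::Internal("Tensor Size Mismatch."));
      return;
    }
    // Direct copy into the requester's pre-allocated buffer; no
    // transport_options proto is attached to the response, so the receiver's
    // PopulateTensorFromResponse becomes a no-op.
    memcpy(reinterpret_cast<void*>(request->buf_ptr()),
           DMAHelper::base(prod_value), num_bytes);
    done(Status::OK());
  }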

PiperOrigin-RevId: 372976144
Change-Id: I2f46d9c6b7dd30ea7003de7587703dc8666e7c27
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
index 29fcd82..05032da 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed.cc
@@ -14,6 +14,8 @@
 ==============================================================================*/
 #include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
 
+#include <memory>
+
 #include "tensorflow/core/common_runtime/base_collective_executor.h"
 #include "tensorflow/core/common_runtime/copy_tensor.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
@@ -24,6 +26,7 @@
 #include "tensorflow/core/distributed_runtime/request_id.h"
 #include "tensorflow/core/distributed_runtime/worker_cache.h"
 #include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/platform/protobuf_internal.h"
 #include "tensorflow/core/protobuf/transport_options.pb.h"
 #include "tensorflow/core/protobuf/worker.pb.h"
@@ -73,6 +76,32 @@
     head += tensor_content_chunk.size();
   }
 }
+
+Status PopulateTensorFromResponse(const RecvBufResponse& response,
+                                  Tensor* cpu_tensor) {
+  const bool has_transport_options = response.has_transport_options();
+
+  // If there are no transport options, then the tensor has already been
+  // copied into request.buf_ptr.
+  if (!has_transport_options) return Status::OK();
+
+  const int64 total_bytes = cpu_tensor->TotalBytes();
+  int64 num_bytes = 0;
+  RecvBufRespExtra extra;
+  response.transport_options().UnpackTo(&extra);
+  for (const auto& chunk : extra.tensor_content()) {
+    num_bytes += chunk.size();
+  }
+
+  if (num_bytes != total_bytes) {
+    return errors::Internal("Tensor Size Mismatch: RecvBufResponse returned ",
+                            num_bytes,
+                            " bytes, expected: ", cpu_tensor->TotalBytes());
+  }
+  PopulateTensorFromExtra(extra, cpu_tensor);
+  return Status::OK();
+}
+
 }  // namespace
 
 void CollectiveRemoteAccessDistributed::RecvFromPeer(
@@ -94,83 +123,95 @@
   struct State {
     DeviceAttributes server_attributes;
     std::unique_ptr<RecvBufCall> call;
+    std::unique_ptr<Tensor> cpu_tensor;
   };
   State* state = new State;
 
-  // Logic to be executed on the RecvBufAsync callback.
-  auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr,
-                            to_device_ctx, to_tensor, dev_to_dev_stream_index,
-                            done](const Status& s) {
-    if (s.ok()) {
-      // In this generic implementation the bytes come back in the
-      // RPC response protobuf rather than via RDMA so we need to copy
-      // them into the destination tensor here.
-      RecvBufRespExtra extra;
-      state->call->resp_.transport_options().UnpackTo(&extra);
-      int64 num_bytes = 0;
-      for (const auto& chunk : extra.tensor_content()) {
-        num_bytes += chunk.size();
-      }
-      const int64 total_bytes = to_tensor->TotalBytes();
-      if (num_bytes != total_bytes) {
-        done(errors::Internal("RecvBufResponse returned ", num_bytes,
-                              " bytes where to_tensor expected ",
-                              to_tensor->TotalBytes()));
-        delete state;
-        return;
-      }
-      if (to_device->tensorflow_gpu_device_info()) {
-        // Move the bytes into a CPU tensor then use tensor-to-tensor copy.
-        // Use GPU-registered memory for the CPU tensor so the transfer
-        // goes faster.
-        Device* cpu_dev = nullptr;
-        Status status = dev_mgr_->LookupDevice("CPU:0", &cpu_dev);
-        if (!status.ok()) {
-          done(status);
-          delete state;
-          return;
-        }
-        AllocatorAttributes cpu_attr;
-        cpu_attr.set_gpu_compatible(true);
-        ScopedMemoryDebugAnnotation op_annotation(
-            "CollectiveRemoteAccessDistributed::RecvFromPeer"
-            "::recv_buf_callback",
-            step_id_, "dynamic", to_tensor->dtype(), &to_tensor->shape());
-        Tensor* cpu_tensor = new Tensor(cpu_dev->GetAllocator(cpu_attr),
-                                        to_tensor->dtype(), to_tensor->shape());
-        PopulateTensorFromExtra(extra, cpu_tensor);
-        // Then copy it to the GPU.
-        CopyTensor::ViaDMA("",  // edge name (non-existent)
-                           nullptr /*send_dev_ctx*/, to_device_ctx, cpu_dev,
-                           to_device, cpu_attr, to_alloc_attr, cpu_tensor,
-                           to_tensor, dev_to_dev_stream_index,
-                           [this, cpu_tensor, done](const Status& s) {
-                             delete cpu_tensor;
-                             // This callback must not block, so execute
-                             // done in another thread.
-                             work_queue_->Schedule([s, done] { done(s); });
-                           });
-        delete state;
-        return;
-      } else {
-        // CPU device
-        PopulateTensorFromExtra(extra, to_tensor);
-      }
-    }
-
-    delete state;
-    done(s);
-  };
-
   Status s = dev_resolver_->GetDeviceAttributes(peer_device,
                                                 &state->server_attributes);
   if (!s.ok()) {
-    recv_buf_callback(s);
+    delete state;
+    done(s);
     return;
   }
+
+  Tensor* dst_tensor = nullptr;
+  Device* cpu_dev = nullptr;
+  if (to_device->tensorflow_gpu_device_info()) {
+    // Receive the bytes into a CPU staging tensor, then use tensor-to-tensor
+    // copy to move them to the GPU. Use GPU-registered memory for the CPU
+    // tensor so the transfer goes faster.
+
+    Status status = dev_mgr_->LookupDevice("CPU:0", &cpu_dev);
+    if (!status.ok()) {
+      delete state;
+      done(status);
+      return;
+    }
+    AllocatorAttributes cpu_attr;
+    cpu_attr.set_gpu_compatible(true);
+    ScopedMemoryDebugAnnotation op_annotation(
+        "CollectiveRemoteAccessDistributed::RecvFromPeer"
+        "::recv_buf_callback",
+        step_id_, "dynamic", to_tensor->dtype(), &to_tensor->shape());
+
+    state->cpu_tensor =
+        std::make_unique<Tensor>(cpu_dev->GetAllocator(cpu_attr),
+                                 to_tensor->dtype(), to_tensor->shape());
+    dst_tensor = state->cpu_tensor.get();
+  } else {
+    dst_tensor = to_tensor;
+  }
+
+  // Logic to be executed on the RecvBufAsync callback.
+  auto recv_buf_callback =
+      [this, state, to_device, to_alloc_attr, to_device_ctx, to_tensor, cpu_dev,
+       dev_to_dev_stream_index, dst_tensor, done](const Status& s) {
+        if (s.ok()) {
+          // In this generic implementation the bytes come back in one of
+          // two ways:
+          // 1. In the transport_options field of the response proto, or
+          // 2. Already copied into the buffer at RecvBufCall::req_.buf_ptr(),
+          //    which is set to dst_tensor: the temporary cpu_tensor when
+          //    to_device is a GPU device, or to_tensor itself otherwise.
+          //
+          // PopulateTensorFromResponse handles both cases; it is a no-op in
+          // the second case.
+          // If the final to_tensor is on GPU, dst_tensor is a temporary CPU
+          // buffer whose contents still need to be copied over to
+          // to_tensor.
+          Status status =
+              PopulateTensorFromResponse(state->call->resp_, dst_tensor);
+          if (!status.ok()) {
+            done(status);
+            delete state;
+            return;
+          }
+
+          if (to_device->tensorflow_gpu_device_info()) {
+            AllocatorAttributes cpu_attr;
+            cpu_attr.set_gpu_compatible(true);
+            CopyTensor::ViaDMA("",  // edge name (non-existent)
+                               nullptr /*send_dev_ctx*/, to_device_ctx, cpu_dev,
+                               to_device, cpu_attr, to_alloc_attr, dst_tensor,
+                               to_tensor, dev_to_dev_stream_index,
+                               [this, state, done](const Status& s) {
+                                 delete state;
+                                 // This callback must not block, so execute
+                                 // done in another thread.
+                                 work_queue_->Schedule([s, done] { done(s); });
+                               });
+            return;
+          }
+        }
+        delete state;
+        done(s);
+      };
+
   state->call.reset(new RecvBufCall(
       step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
-      to_alloc_attr, to_tensor, client_locality, state->server_attributes,
+      to_alloc_attr, dst_tensor, client_locality, state->server_attributes,
       cancellation_manager, worker_cache_));
   CancellationToken abortion_token =
       abortion_cancel_mgr_.get_cancellation_token();
@@ -194,10 +235,10 @@
     done(Status::OK());
     return;
   }
-  // We send a GetStatus RPC with fail_fast=false to check the health of a peer
-  // task. If the RPC succeeds, we verify if the peer_device incarnation matches
-  // the local record if we have it. Note that DeviceResolverInterface always
-  // caches the device attributes.
+  // We send a GetStatus RPC to check the health of a peer task. If the RPC
+  // succeeds, we verify if the peer_device incarnation matches the local record
+  // if we have it. Note that DeviceResolverInterface always caches the device
+  // attributes.
   WorkerInterface* wi = worker_cache_->GetOrCreateWorker(peer_task);
   if (wi == nullptr) {
     done(errors::InvalidArgument(peer_task,
diff --git a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
index 74282be..35ae27f 100644
--- a/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
+++ b/tensorflow/core/distributed_runtime/collective_rma_distributed_test.cc
@@ -21,13 +21,16 @@
 #include "tensorflow/core/common_runtime/process_util.h"
 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
 #include "tensorflow/core/distributed_runtime/test_utils.h"
+#include "tensorflow/core/framework/allocator.h"
 #include "tensorflow/core/framework/cancellation.h"
 #include "tensorflow/core/framework/device_attributes.pb.h"
 #include "tensorflow/core/lib/core/notification.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/random/random.h"
 #include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/protobuf/transport_options.pb.h"
 #include "tensorflow/core/protobuf/worker.pb.h"
@@ -42,20 +45,33 @@
 namespace tensorflow {
 namespace {
 
-static std::unique_ptr<Device> NewDevice(const string& type,
-                                         const string& name) {
+class FakeAllocator : public Allocator {
+ public:
+  string Name() override { return "fake"; }
+  void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+    return port::AlignedMalloc(num_bytes, alignment);
+  }
+  void DeallocateRaw(void* ptr) override { port::AlignedFree(ptr); }
+};
+
+static std::unique_ptr<Device> NewDevice(const string& type, const string& name,
+                                         Allocator* allocator) {
   class FakeDevice : public Device {
    public:
-    explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
+    explicit FakeDevice(const DeviceAttributes& attr, Allocator* allocator)
+        : Device(nullptr, attr), allocator_(allocator) {}
     Status Sync() override { return Status::OK(); }
-    Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
+    Allocator* GetAllocator(AllocatorAttributes) override { return allocator_; }
+
+   private:
+    Allocator* const allocator_;
   };
   DeviceAttributes attr;
   attr.set_name(name);
   attr.set_device_type(type);
   attr.mutable_locality()->set_numa_node(3);  // a non-default value
   attr.set_incarnation(random::New64());
-  return absl::make_unique<FakeDevice>(attr);
+  return absl::make_unique<FakeDevice>(attr, allocator);
 }
 
 static int64 kStepId = 123;
@@ -63,12 +79,14 @@
 class FakeWorker : public TestWorkerInterface {
  public:
   FakeWorker(const string& name, DeviceMgr* dev_mgr,
-             DeviceResolverDistributed* dres, bool is_failed)
+             DeviceResolverDistributed* dres, bool is_failed,
+             bool set_tensor_in_extra)
       : name_(name),
         device_mgr_(dev_mgr),
         device_resolver_(dres),
         buf_rendezvous_(kStepId, dev_mgr),
-        is_failed_(is_failed) {}
+        is_failed_(is_failed),
+        set_tensor_in_extra_(set_tensor_in_extra) {}
 
   // Direct access to a BufRendezvous that holds whatever the remote
   // worker is supposed to have.
@@ -112,17 +130,29 @@
     buf_rendezvous_.ConsumeBuf(
         request->buf_rendezvous_key(), request->src_device(),
         request->src_incarnation(),
-        [opts, response, done](const Status& s, BufRendezvous::Hook* h) {
+        [this, opts, request, response, done](const Status& status,
+                                              BufRendezvous::Hook* h) {
+          Status s = status;
           if (s.ok()) {
             opts->ClearCancelCallback();
-            // Since this is not really RDMA into pre-allocated memory send the
-            // bytes in the response.
-            RecvBufRespExtra extra;
             int64 num_bytes = h->prod_value->TotalBytes();
-            extra.add_tensor_content(string(
-                reinterpret_cast<const char*>(DMAHelper::base(h->prod_value)),
-                num_bytes));
-            response->mutable_transport_options()->PackFrom(extra);
+
+            if (set_tensor_in_extra_) {
+              // Since this is not really RDMA into pre-allocated memory send
+              // the bytes in the response.
+              RecvBufRespExtra extra;
+              extra.add_tensor_content(string(
+                  reinterpret_cast<const char*>(DMAHelper::base(h->prod_value)),
+                  num_bytes));
+              response->mutable_transport_options()->PackFrom(extra);
+            } else {
+              if (request->num_bytes() != num_bytes) {
+                s = errors::Internal("Tensor Size Mismatch.");
+              } else {
+                memcpy(reinterpret_cast<void*>(request->buf_ptr()),
+                       DMAHelper::base(h->prod_value), num_bytes);
+              }
+            }
           }
           done(s);
           if (h) BufRendezvous::DoneWithHook(h);
@@ -136,6 +166,7 @@
   DeviceResolverDistributed* device_resolver_;
   BufRendezvous buf_rendezvous_;
   bool is_failed_;
+  const bool set_tensor_in_extra_;
 };
 
 class FakeCache : public TestWorkerCache {
@@ -179,7 +210,19 @@
   }
 };
 
-class CollRMADistTest : public ::testing::Test {
+enum TEST_PARAM_DEVICE_TYPE {
+  TEST_PARAM_DEVICE_TYPE_CPU = 0,
+  TEST_PARAM_DEVICE_TYPE_GPU,
+};
+
+enum TEST_PARAM_TENSOR_LOC {
+  TEST_PARAM_TENSOR_LOC_AT_BUF_PTR = 0,
+  TEST_PARAM_TENSOR_LOC_IN_EXTRA,
+};
+
+class CollRMADistTest
+    : public ::testing::TestWithParam<
+          std::tuple<TEST_PARAM_DEVICE_TYPE, TEST_PARAM_TENSOR_LOC>> {
  protected:
   CollRMADistTest()
       : work_queue_(
@@ -217,12 +260,17 @@
     const int kNumElts = 8;
     expected_value_ = Tensor(DT_FLOAT, {kNumElts});
     to_tensor_ = Tensor(DT_FLOAT, {kNumElts});
+    large_response_ = Tensor(DT_FLOAT, {2 * kNumElts});
     auto exp_alias = expected_value_.flat<float>();
     auto to_alias = to_tensor_.flat<float>();
+    auto large_response_alias = large_response_.flat<float>();
     for (int i = 0; i < kNumElts; ++i) {
       exp_alias(i) = i;
       to_alias(i) = -1;
     }
+    for (int i = 0; i < 2 * kNumElts; ++i) {
+      large_response_alias(i) = -2;
+    }
   }
 
   // Populates all device resolvers with device attributes of the cluster. This
@@ -243,7 +291,8 @@
     for (int i = 0; i < num_devices; ++i) {
       devices.push_back(NewDevice(
           device_type,
-          strings::StrCat(worker_name, "/device:", device_type, ":", i)));
+          strings::StrCat(worker_name, "/device:", device_type, ":", i),
+          &fake_allocator_));
     }
     DeviceMgr* dev_mgr = new StaticDeviceMgr(std::move(devices));
     device_mgrs_.push_back(dev_mgr);
@@ -254,7 +303,12 @@
     }
     DeviceResolverDistributed* dev_res = new DeviceResolverDistributed(dev_mgr);
     dev_resolvers_[worker_name] = dev_res;
-    FakeWorker* fw = new FakeWorker(worker_name, dev_mgr, dev_res, is_failed);
+    FakeWorker* fw =
+        new FakeWorker(worker_name, dev_mgr, dev_res, is_failed,
+                       /*set_tensor_in_extra=*/
+                       std::get<TEST_PARAM_TENSOR_LOC>(GetParam()) ==
+                           TEST_PARAM_TENSOR_LOC_IN_EXTRA);
+
     workers_.push_back(fw);
     wc_.AddWorker(worker_name, fw);
   }
@@ -280,6 +334,19 @@
     }
   }
 
+  void ValidateResultTensorUnchanged() {
+    for (int i = 0; i < to_tensor_.NumElements(); ++i) {
+      EXPECT_FLOAT_EQ(-1, to_tensor_.flat<float>()(i));
+    }
+  }
+
+  void MaybeSetGPUDevice(Device* dst_device) {
+    if (std::get<TEST_PARAM_DEVICE_TYPE>(GetParam()) ==
+        TEST_PARAM_DEVICE_TYPE_GPU) {
+      dst_device->set_tensorflow_gpu_device_info(&gpu_device_info_);
+    }
+  }
+
   FakeCache wc_;
   CancellationManager cm_;
   std::vector<DeviceMgr*> device_mgrs_;
@@ -292,13 +359,16 @@
   int num_done_ TF_GUARDED_BY(mu_);
   condition_variable done_;
   Tensor expected_value_;
+  Tensor large_response_;
   Tensor to_tensor_;
   CallOptions opts_;
   DeviceLocality device_locality_;
   AllocatorAttributes alloc_attr_;
+  FakeAllocator fake_allocator_;
+  DeviceBase::GpuDeviceInfo gpu_device_info_;
 };
 
-TEST_F(CollRMADistTest, ProdFirstOK) {
+TEST_P(CollRMADistTest, ProdFirstOK) {
   ResolveDeviceAttributes();
   Notification consumer_note;
   Notification producer_note;
@@ -318,6 +388,7 @@
   string dev_name = "CPU:0";
   TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
   DeviceContext* to_device_ctx = nullptr;
+  MaybeSetGPUDevice(dst_device);
   rma_->RecvFromPeer(
       "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
       "/job:worker/replica:0/task:1",                     // peer_task
@@ -336,7 +407,7 @@
   ValidateResultTensor();
 }
 
-TEST_F(CollRMADistTest, ConsFirstOK) {
+TEST_P(CollRMADistTest, ConsFirstOK) {
   ResolveDeviceAttributes();
   Notification consumer_note;
   Notification producer_note;
@@ -347,6 +418,7 @@
   Device* dst_device = nullptr;
   string dev_name = "CPU:0";
   TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+  MaybeSetGPUDevice(dst_device);
   DeviceContext* to_device_ctx = nullptr;
   rma_->RecvFromPeer(
       "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
@@ -374,7 +446,7 @@
   ValidateResultTensor();
 }
 
-TEST_F(CollRMADistTest, ConsFirstAbort) {
+TEST_P(CollRMADistTest, ConsFirstAbort) {
   ResolveDeviceAttributes();
   Notification consumer_note;
   Status consumer_status;
@@ -382,6 +454,7 @@
   Device* dst_device = nullptr;
   string dev_name = "CPU:0";
   TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+  MaybeSetGPUDevice(dst_device);
   DeviceContext* to_device_ctx = nullptr;
   rma_->RecvFromPeer(
       "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
@@ -399,7 +472,47 @@
   EXPECT_EQ(consumer_status.error_message(), "Cancelled");
 }
 
-TEST_F(CollRMADistTest, WorkerRestart) {
+TEST_P(CollRMADistTest, ResponseTooLarge) {
+  ResolveDeviceAttributes();
+  Notification consumer_note;
+  Notification producer_note;
+  Status consumer_status;
+  Status producer_status;
+  FakeWorker* wi = workers_[1];
+  const string kBufKey = "fake_buf_key";
+  wi->buf_rendezvous()->ProvideBuf(
+      kBufKey, nullptr /*device*/, nullptr /*dev_ctx*/, &large_response_,
+      AllocatorAttributes(),
+      [&producer_note, &producer_status](const Status& s) {
+        producer_status.Update(s);
+        producer_note.Notify();
+      },
+      nullptr /*cancellation_manager*/);
+  Device* dst_device = nullptr;
+  string dev_name = "CPU:0";
+  TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+  DeviceContext* to_device_ctx = nullptr;
+  MaybeSetGPUDevice(dst_device);
+  rma_->RecvFromPeer(
+      "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
+      "/job:worker/replica:0/task:1",                     // peer_task
+      false,                                              // peer_is_local
+      kBufKey, dst_device, to_device_ctx, alloc_attr_, &to_tensor_,
+      device_locality_, 0 /*dev_to_dev_stream_index*/,
+      nullptr /*cancellation_manager*/,
+      [&consumer_status, &consumer_note](const Status& s) {
+        consumer_status = s;
+        consumer_note.Notify();
+      });
+  consumer_note.WaitForNotification();
+  EXPECT_THAT(consumer_status.error_message(),
+              ::testing::HasSubstr("Tensor Size Mismatch"));
+  producer_note.WaitForNotification();
+  TF_EXPECT_OK(producer_status);
+  ValidateResultTensorUnchanged();
+}
+
+TEST_P(CollRMADistTest, WorkerRestart) {
   ResolveDeviceAttributes();
   Notification consumer_note;
   Notification producer_note;
@@ -410,6 +523,7 @@
   Device* dst_device = nullptr;
   string dev_name = "CPU:0";
   TF_EXPECT_OK(device_mgrs_[0]->LookupDevice(dev_name, &dst_device));
+  MaybeSetGPUDevice(dst_device);
   DeviceContext* to_device_ctx = nullptr;
   rma_->RecvFromPeer(
       "/job:worker/replica:0/task:1/device:" + dev_name,  // peer_dev
@@ -454,7 +568,7 @@
   EXPECT_TRUE(errors::IsFailedPrecondition(consumer_status));
 }
 
-TEST_F(CollRMADistTest, CheckHealthOKWithCachedAttr) {
+TEST_P(CollRMADistTest, CheckHealthOKWithCachedAttr) {
   ResolveDeviceAttributes();
   Status check_health_status;
   Notification check_health_done;
@@ -468,7 +582,7 @@
   TF_EXPECT_OK(check_health_status);
 }
 
-TEST_F(CollRMADistTest, CheckHealthOKWithoutCachedAttr) {
+TEST_P(CollRMADistTest, CheckHealthOKWithoutCachedAttr) {
   Status check_health_status;
   Notification check_health_done;
   rma_->CheckPeerHealth(
@@ -481,7 +595,7 @@
   EXPECT_TRUE(check_health_status.ok());
 }
 
-TEST_F(CollRMADistTest, CheckHealthRestarted) {
+TEST_P(CollRMADistTest, CheckHealthRestarted) {
   ResolveDeviceAttributes();
   RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1);
 
@@ -497,7 +611,7 @@
   EXPECT_TRUE(errors::IsFailedPrecondition(check_health_status));
 }
 
-TEST_F(CollRMADistTest, CheckHealthFailedPeer) {
+TEST_P(CollRMADistTest, CheckHealthFailedPeer) {
   ResolveDeviceAttributes();
   RestartWorker("/job:worker/replica:0/task:1", "CPU", /*num_devices*/ 1,
                 /*is_failed*/ true);
@@ -514,7 +628,7 @@
   EXPECT_TRUE(errors::IsUnavailable(check_health_status));
 }
 
-TEST_F(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) {
+TEST_P(CollRMADistTest, CheckHealthRestartedWithDifferentDevices) {
   ResolveDeviceAttributes();
   RestartWorker("/job:worker/replica:0/task:1", "GPU", /*num_devices*/ 1);
   Status check_health_status;
@@ -529,5 +643,12 @@
   EXPECT_TRUE(errors::IsFailedPrecondition(check_health_status));
 }
 
+INSTANTIATE_TEST_SUITE_P(
+    TensorInBufPtrOrExtra, CollRMADistTest,
+    ::testing::Combine(::testing::Values(TEST_PARAM_DEVICE_TYPE_CPU,
+                                         TEST_PARAM_DEVICE_TYPE_GPU),
+                       ::testing::Values(TEST_PARAM_TENSOR_LOC_AT_BUF_PTR,
+                                         TEST_PARAM_TENSOR_LOC_IN_EXTRA)));
+
 }  // namespace
 }  // namespace tensorflow