blob: 46889e737e7c36182a35ef4ed3d5d69794a063e3 [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/cancellable_call.h"
#include "tensorflow/core/distributed_runtime/request_id.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
namespace tensorflow {
namespace {
class RecvBufCall : public CancellableCall {
public:
RecvBufCall(int64 step_id, const string& peer_device, const string& peer_task,
const string& key, Device* to_device,
DeviceContext* to_device_ctx,
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
const DeviceLocality& client_locality,
const DeviceAttributes& server_attributes,
CancellationManager* cancel_mgr, WorkerCacheInterface* wc)
: CancellableCall(cancel_mgr, peer_task, wc) {
req_.set_step_id(step_id);
req_.set_buf_rendezvous_key(key);
*req_.mutable_client_locality() = client_locality;
*req_.mutable_server_locality() = server_attributes.locality();
req_.set_num_bytes(to_tensor->TotalBytes());
req_.set_buf_ptr(reinterpret_cast<int64>(DMAHelper::base(to_tensor)));
req_.set_src_device(peer_device);
req_.set_src_incarnation(server_attributes.incarnation());
req_.set_dst_device(to_device->name());
req_.set_request_id(GetUniqueRequestId());
}
~RecvBufCall() override {}
void IssueCall(const StatusCallback& done) override {
wi_->RecvBufAsync(&opts_, &req_, &resp_, done);
}
RecvBufRequest req_;
RecvBufResponse resp_;
};
void PopulateTensorFromExtra(const RecvBufRespExtra& extra,
Tensor* cpu_tensor) {
char* head = reinterpret_cast<char*>(DMAHelper::base(cpu_tensor));
for (const auto& tensor_content_chunk : extra.tensor_content()) {
memcpy(head, std::string(tensor_content_chunk).data(),
tensor_content_chunk.size());
head += tensor_content_chunk.size();
}
}
} // namespace
void CollectiveRemoteAccessDistributed::RecvFromPeer(
const string& peer_device, const string& peer_task, bool peer_is_local,
const string& key, Device* to_device, DeviceContext* to_device_ctx,
const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
const DeviceLocality& client_locality, int dev_to_dev_stream_index,
const StatusCallback& done) {
if (peer_is_local) {
CollectiveRemoteAccessLocal::RecvFromPeer(
peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index,
done);
return;
}
// State that needs to be threaded through a couple of async calls
// in order to make this function completely non-blocking.
struct State {
DeviceAttributes server_attributes;
std::unique_ptr<RecvBufCall> call;
};
State* state = new State;
// Logic to be executed on the RecvBufAsync callback.
auto recv_buf_callback = [this, state, peer_task, to_device, to_alloc_attr,
to_device_ctx, to_tensor, dev_to_dev_stream_index,
done](const Status& s) {
if (s.ok()) {
// In this generic implementation the bytes come back in the
// RPC response protobuf rather than via RDMA so we need to copy
// them into the destination tensor here.
RecvBufRespExtra extra;
state->call->resp_.transport_options().UnpackTo(&extra);
int64 num_bytes = 0;
for (const auto& chunk : extra.tensor_content()) {
num_bytes += chunk.size();
}
const int64 total_bytes = to_tensor->TotalBytes();
if (num_bytes != total_bytes) {
done(errors::Internal("RecvBufResponse returned ", num_bytes,
" bytes where to_tensor expected ",
to_tensor->TotalBytes()));
delete state;
return;
}
if (to_device->tensorflow_gpu_device_info()) {
// Move the bytes into a CPU tensor then use tensor-to-tensor copy.
// Use GPU-registered memory for the CPU tensor so the transfer
// goes faster.
Device* cpu_dev = nullptr;
Status status = dev_mgr_->LookupDevice("CPU:0", &cpu_dev);
if (!status.ok()) {
done(status);
delete state;
return;
}
AllocatorAttributes cpu_attr;
cpu_attr.set_gpu_compatible(true);
ScopedMemoryDebugAnnotation op_annotation(
"CollectiveRemoteAccessDistributed::RecvFromPeer"
"::recv_buf_callback",
step_id_, "dynamic", to_tensor->dtype(), &to_tensor->shape());
Tensor* cpu_tensor = new Tensor(cpu_dev->GetAllocator(cpu_attr),
to_tensor->dtype(), to_tensor->shape());
PopulateTensorFromExtra(extra, cpu_tensor);
// Then copy it to the GPU.
CopyTensor::ViaDMA("", // edge name (non-existent)
nullptr /*send_dev_ctx*/, to_device_ctx, cpu_dev,
to_device, cpu_attr, to_alloc_attr, cpu_tensor,
to_tensor, dev_to_dev_stream_index,
[this, cpu_tensor, done](const Status& s) {
delete cpu_tensor;
// This callback must not block, so execute
// done in another thread.
RunClosure([s, done] { done(s); });
});
delete state;
return;
} else {
// CPU device
PopulateTensorFromExtra(extra, to_tensor);
}
}
if (!s.ok() && errors::IsFailedPrecondition(s)) {
dev_resolver_->ClearTask(peer_task);
}
delete state;
done(s);
};
// Logic to execute once we have the device attributes for the server-side
// device.
auto dev_attributes_callback = [this, state, peer_device, peer_task, key,
to_device, to_device_ctx, to_alloc_attr,
to_tensor, client_locality,
recv_buf_callback](const Status& s) {
if (!s.ok()) {
recv_buf_callback(s);
} else {
state->call.reset(new RecvBufCall(
step_id_, peer_device, peer_task, key, to_device, to_device_ctx,
to_alloc_attr, to_tensor, client_locality, state->server_attributes,
&cancel_mgr_, worker_cache_));
state->call->Start(recv_buf_callback);
}
};
dev_resolver_->GetDeviceAttributesAsync(peer_device, peer_task,
&state->server_attributes,
dev_attributes_callback);
}
void CollectiveRemoteAccessDistributed::StartAbort(const Status& s) {
CollectiveRemoteAccessLocal::StartAbort(s);
cancel_mgr_.StartCancel();
}
} // namespace tensorflow