// blob: f573685c0a599c287695327f14d0abaa5ccf30de [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/client.h"
#include <string>
#include <utility>
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
namespace xla {
Client::Client(ServiceInterface* stub) : stub_(stub) {}
// Destructor is defaulted here, out of line, in the implementation file.
Client::~Client() = default;
// Requests the literal behind `data` from the service.
//
// When `shape_with_layout` is non-null its proto form is attached to the
// request. Returns the stub's status if the call fails, a FailedPrecondition
// error if the response carries no literal, and otherwise whatever
// Literal::CreateFromProto yields for the returned literal.
StatusOr<Literal> Client::Transfer(const GlobalData& data,
                                   const Shape* shape_with_layout) {
  TransferToClientRequest request;
  *request.mutable_data() = data.handle();
  if (shape_with_layout != nullptr) {
    *request.mutable_shape_with_layout() = shape_with_layout->ToProto();
  }

  VLOG(1) << "making transfer request";
  VLOG(3) << "TransferToClientRequest: {" << request.DebugString() << "}";
  TransferToClientResponse response;
  Status rpc_status = stub_->TransferToClient(&request, &response);
  VLOG(1) << "done with request";
  if (!rpc_status.ok()) {
    return rpc_status;
  }

  VLOG(3) << "TransferToClientResponse: {" << response.DebugString() << "}";
  if (response.has_literal()) {
    return Literal::CreateFromProto(*response.mutable_literal());
  }
  return FailedPrecondition(
      "server provided response without a literal in "
      "TransferToClient request");
}
// Sends `literal` to the service and wraps the returned data handle.
//
// `device_handle` is optional; when non-null it is copied into the request.
// Returns the stub's status on failure, a FailedPrecondition error if the
// response has no data handle, and otherwise a new GlobalData built from
// `stub_` and the handle in the response.
StatusOr<std::unique_ptr<GlobalData>> Client::TransferToServer(
    const LiteralSlice& literal, const DeviceHandle* device_handle) {
  TransferToServerRequest request;
  *request.mutable_literal() = literal.ToProto();
  if (device_handle != nullptr) {
    *request.mutable_device_handle() = *device_handle;
  }

  VLOG(1) << "making transfer to server request";
  VLOG(3) << "TransferToServerRequest: {" << request.DebugString() << "}";
  TransferToServerResponse response;
  Status rpc_status = stub_->TransferToServer(&request, &response);
  VLOG(1) << "done with request";
  if (!rpc_status.ok()) {
    return rpc_status;
  }

  VLOG(3) << "TransferToServerResponse: {" << response.DebugString() << "}";
  if (response.has_data()) {
    return absl::make_unique<GlobalData>(stub_, response.data());
  }
  return FailedPrecondition(
      "server provided response without a data handle in "
      "TransferToServer request");
}
// Sends `literal` to the service's infeed for the given `replica_id`.
//
// `device_handle` is optional; when non-null it is copied into the request.
// Returns the stub's status on failure and OK otherwise.
Status Client::TransferToInfeed(const LiteralSlice& literal, int64_t replica_id,
                                const DeviceHandle* device_handle) {
  TransferToInfeedRequest request;
  *request.mutable_literal() = literal.ToProto();
  if (device_handle != nullptr) {
    *request.mutable_device_handle() = *device_handle;
  }
  request.set_replica_id(replica_id);

  VLOG(1) << "making transfer to infeed request";
  VLOG(3) << "TransferToInfeedRequest: {" << request.DebugString() << "}";
  TransferToInfeedResponse response;
  Status rpc_status = stub_->TransferToInfeed(&request, &response);
  VLOG(1) << "done with request";
  if (!rpc_status.ok()) {
    return rpc_status;
  }

  VLOG(3) << "TransferToInfeedResponse: {" << response.DebugString() << "}";
  return ::tensorflow::OkStatus();
}
// Requests a literal from the service's outfeed for the given `replica_id`.
//
// `device_handle` and `shape_with_layout` are both optional; each is attached
// to the request only when non-null. Returns the stub's status on failure, a
// FailedPrecondition error if the response carries no literal, and otherwise
// whatever Literal::CreateFromProto yields for the returned literal.
StatusOr<Literal> Client::TransferFromOutfeed(
    const Shape* shape_with_layout, int64_t replica_id,
    const DeviceHandle* device_handle) {
  TransferFromOutfeedRequest request;
  if (device_handle != nullptr) {
    *request.mutable_device_handle() = *device_handle;
  }
  request.set_replica_id(replica_id);
  if (shape_with_layout != nullptr) {
    *request.mutable_shape_with_layout() = shape_with_layout->ToProto();
  }

  TransferFromOutfeedResponse response;
  VLOG(1) << "making transfer from outfeed request";
  VLOG(3) << "TransferFromOutfeedRequest: {" << request.DebugString() << "}";
  Status s = stub_->TransferFromOutfeed(&request, &response);
  VLOG(1) << "done with request";
  if (!s.ok()) {
    return s;
  }

  VLOG(3) << "TransferFromOutfeedResponse: {" << response.DebugString() << "}";
  if (!response.has_literal()) {
    // Name the RPC that was actually issued; the message previously blamed
    // TransferToClient, which sent anyone debugging to the wrong call.
    return FailedPrecondition(
        "server provided response without a literal in "
        "TransferFromOutfeed request");
  }
  return Literal::CreateFromProto(response.literal());
}
// Issues a ResetDevice call with an empty request.
//
// Returns the stub's status on failure and OK otherwise.
Status Client::ResetDevice() {
  ResetDeviceRequest request;
  VLOG(1) << "making reset device request";
  VLOG(3) << "ResetDeviceRequest: {" << request.DebugString() << "}";

  ResetDeviceResponse response;
  Status rpc_status = stub_->ResetDevice(&request, &response);
  VLOG(1) << "done with request";
  if (!rpc_status.ok()) {
    return rpc_status;
  }

  VLOG(3) << "ResetDeviceResponse: {" << response.DebugString() << "}";
  return ::tensorflow::OkStatus();
}
// Runs `computation` via Execute() and then fetches the result with
// Transfer().
//
// If `execution_options` is non-null and carries an output layout shape, that
// shape is passed to Transfer(); otherwise Transfer() receives a null shape.
// Any error from Execute() is returned unchanged.
StatusOr<Literal> Client::ExecuteAndTransfer(
    const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
    const ExecutionOptions* execution_options,
    ExecutionProfile* execution_profile) {
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<GlobalData> data,
      Execute(computation, arguments, execution_options, execution_profile));

  const bool has_output_layout =
      execution_options != nullptr &&
      execution_options->has_shape_with_output_layout();
  if (!has_output_layout) {
    return Transfer(*data, nullptr);
  }

  // The shape only needs to outlive the Transfer() call below.
  Shape output_shape(execution_options->shape_with_output_layout());
  return Transfer(*data, &output_shape);
}
// Sends `computation` in a ComputeConstantGraph request and returns the
// literal from the response.
//
// `output_layout` is optional; when non-null its proto form is attached to
// the request. Returns the stub's status on failure, an InternalError if the
// response carries no literal, and otherwise whatever
// Literal::CreateFromProto yields for the returned literal.
StatusOr<Literal> Client::ComputeConstant(const XlaComputation& computation,
                                          const Layout* output_layout) const {
  ComputeConstantGraphRequest request;
  *request.mutable_computation() = computation.proto();
  if (output_layout != nullptr) {
    *request.mutable_output_layout() = output_layout->ToProto();
  }

  VLOG(2) << "making compute-constant-graph request";
  ComputeConstantResponse response;
  Status rpc_status = stub_->ComputeConstantGraph(&request, &response);
  VLOG(2) << "done with request";
  if (!rpc_status.ok()) {
    return rpc_status;
  }

  VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";
  if (response.has_literal()) {
    return Literal::CreateFromProto(response.literal());
  }
  return InternalError(
      "no computed literal in the provided response in ComputeConstantGraph "
      "request");
}
// Wraps the HLO module stored in `module` in an XlaComputation.
//
// Fails the TF_RET_CHECK (returning an error rather than a computation) when
// the snapshot has no `hlo` field or that field has no `hlo_module`.
StatusOr<XlaComputation> Client::LoadSnapshot(const HloSnapshot& module) {
  TF_RET_CHECK(module.has_hlo() && module.hlo().has_hlo_module());
  const auto& snapshot_module = module.hlo().hlo_module();
  return XlaComputation(snapshot_module);
}
// Sends `computation` to the service in a Compile request and returns the
// execution handle from the response.
//
// `execution_options` is optional; when null, the result of
// CreateDefaultExecutionOptions() is used. Requests whose options list more
// than one device handle are rejected with InvalidArgument before any call
// is made. Returns the stub's status on failure, and fails the TF_RET_CHECK
// if the response has no handle.
StatusOr<ExecutionHandle> Client::Compile(
    const XlaComputation& computation, absl::Span<const Shape> argument_shapes,
    const ExecutionOptions* execution_options) {
  CompileRequest request;
  *request.mutable_computation() = computation.proto();

  if (execution_options != nullptr) {
    *request.mutable_execution_options() = *execution_options;
  } else {
    *request.mutable_execution_options() = CreateDefaultExecutionOptions();
  }
  if (request.execution_options().device_handles_size() > 1) {
    return InvalidArgument(
        "Compiling with multiple device handles is not supported. Use "
        "'Execute' instead.");
  }

  // Argument shapes are part of the request because they influence how the
  // computation gets compiled.
  for (const Shape& shape : argument_shapes) {
    *request.add_input_shape_with_layout() = shape.ToProto();
  }

  VLOG(1) << "making compile request: " << request.ShortDebugString();
  CompileResponse response;
  Status rpc_status = stub_->Compile(&request, &response);
  VLOG(1) << "done with request";
  if (!rpc_status.ok()) {
    return rpc_status;
  }

  TF_RET_CHECK(response.has_handle());
  return response.handle();
}
// Executes the computation referred to by `handle` with the given arguments.
//
// Every entry of `arguments` must be non-null (enforced with CHECK). When
// `execution_profile` is non-null it receives the profile from the response.
// Returns the stub's status on failure, otherwise a new GlobalData built
// from `stub_` and the output handle in the response.
StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
    const ExecutionHandle& handle, absl::Span<GlobalData* const> arguments,
    ExecutionProfile* execution_profile) {
  ExecuteRequest request;
  *request.mutable_handle() = handle;
  for (GlobalData* argument : arguments) {
    CHECK(argument != nullptr) << "Argument pointers must not be null.";
    *request.add_arguments() = argument->handle();
  }

  VLOG(1) << "making execute request: " << request.ShortDebugString();
  ExecuteResponse response;
  Status rpc_status = stub_->Execute(&request, &response);
  VLOG(1) << "done with request";
  if (!rpc_status.ok()) {
    return rpc_status;
  }

  if (execution_profile != nullptr) {
    *execution_profile = response.profile();
  }
  return absl::make_unique<GlobalData>(stub_, response.output());
}
// Executes `computation` once by routing it through ExecuteParallel().
//
// If `execution_options` is null or lists no device handles, a private copy
// of the options (or the defaults) is made and given the first handle
// returned by GetDeviceHandles(1). The returned handle is the first result
// whose shape is not an empty tuple; if every result is an empty tuple,
// result 0 is returned. Errors from GetDeviceHandles(), ExecuteParallel()
// and GetShape() are propagated.
StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
    const XlaComputation& computation, absl::Span<GlobalData* const> arguments,
    const ExecutionOptions* execution_options,
    ExecutionProfile* execution_profile) {
  // Owns the adjusted options when the caller's cannot be used as-is.
  absl::optional<ExecutionOptions> owned_options;
  const bool needs_device_handle =
      execution_options == nullptr ||
      execution_options->device_handles().empty();
  if (needs_device_handle) {
    if (execution_options == nullptr) {
      owned_options.emplace(CreateDefaultExecutionOptions());
    } else {
      owned_options.emplace(*execution_options);
    }
    execution_options = &*owned_options;

    TF_ASSIGN_OR_RETURN(auto device_handles,
                        GetDeviceHandles(/*device_count=*/1));
    TF_RET_CHECK(!device_handles.empty());
    *owned_options->add_device_handles() = std::move(device_handles[0]);
  }

  std::vector<XlaComputationInstance> computation_instances = {
      XlaComputationInstance{
          computation,
          std::vector<GlobalData*>(arguments.begin(), arguments.end()),
          *execution_options, execution_profile}};

  // Go through Service::ExecuteParallel() for our single computation rather
  // than Compile() + Execute(): Compile() caches the executable forever,
  // which isn't what we want.
  VLOG(1) << "Making ExecuteParallel request: "
          << execution_options->DebugString();
  TF_ASSIGN_OR_RETURN(auto results, ExecuteParallel(computation_instances));
  VLOG(1) << "ExecuteParallel request done.";

  // Picking the result this way is a bit hacky, but better than assuming it
  // lives on device 0.
  //
  // TODO(b/118493728): Allow Execute to return one result per computation.
  for (int64_t i = 0, end = results.size(); i < end; i++) {
    TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(*results[i]));
    if (ShapeUtil::IsEmptyTuple(shape)) {
      continue;
    }
    VLOG(3) << "Fetching result from device " << i << ": "
            << ShapeUtil::HumanString(shape);
    return std::move(results[i]);
  }

  TF_RET_CHECK(!results.empty());
  VLOG(1) << "Defaulting to device 0 result";
  return std::move(results[0]);
}
// Executes several computations with one ExecuteGraphParallel call.
//
// One sub-request is built per entry of `computations`, carrying its
// computation proto, argument handles and execution options. Every argument
// pointer must be non-null (enforced with CHECK, matching the handle-based
// Execute()). Returns the stub's status on failure; otherwise one GlobalData
// per response entry, in response order. For response index i, when
// `computations[i]` exists and has a non-null `execution_profile`, that
// profile is filled in from the response.
StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
    absl::Span<const XlaComputationInstance> computations) {
  ExecuteGraphParallelRequest request;
  for (const XlaComputationInstance& computation : computations) {
    // Populate the sub-request in place so the computation proto is copied
    // once instead of once into a temporary and again into `request`.
    ExecuteGraphRequest* single_request = request.add_requests();
    *single_request->mutable_computation() = computation.computation.proto();
    for (GlobalData* argument : computation.arguments) {
      CHECK(argument != nullptr) << "Argument pointers must not be null.";
      *single_request->add_arguments() = argument->handle();
    }
    *single_request->mutable_execution_options() =
        computation.execution_options;
  }

  ExecuteParallelResponse response;
  VLOG(1) << "making execute-graph-parallel request: "
          << request.ShortDebugString();
  Status s = stub_->ExecuteGraphParallel(&request, &response);
  VLOG(1) << "done with request";
  if (!s.ok()) {
    return s;
  }

  std::vector<std::unique_ptr<GlobalData>> outputs;
  outputs.reserve(response.responses_size());
  for (size_t i = 0, end = response.responses_size(); i < end; ++i) {
    outputs.push_back(
        absl::make_unique<GlobalData>(stub_, response.responses(i).output()));
    // The response may hold more entries than were requested; only touch
    // profiles that belong to a computation we were given.
    if (i < computations.size() &&
        computations[i].execution_profile != nullptr) {
      *computations[i].execution_profile = response.responses(i).profile();
    }
  }
  // The vector is move-only and is converted to StatusOr on return, so the
  // move is spelled out.
  return std::move(outputs);
}
// Asks the service for `device_count` device handles.
//
// Rejects `device_count` values below 1 with InvalidArgument before any call
// is made. Returns the stub's status on failure, otherwise a copy of the
// handles listed in the response.
StatusOr<std::vector<DeviceHandle>> Client::GetDeviceHandles(
    int64_t device_count) {
  if (device_count < 1) {
    return InvalidArgument("device_count must be greater than 0");
  }

  GetDeviceHandlesRequest request;
  request.set_device_count(device_count);
  VLOG(1) << "making get device request: " << request.ShortDebugString();

  GetDeviceHandlesResponse response;
  Status rpc_status = stub_->GetDeviceHandles(&request, &response);
  VLOG(1) << "done with request";
  if (!rpc_status.ok()) {
    return rpc_status;
  }

  // Copy the repeated field into a plain vector in one step.
  const auto& returned_handles = response.device_handles();
  std::vector<DeviceHandle> device_handles(returned_handles.begin(),
                                           returned_handles.end());
  return device_handles;
}
// Sends an Unregister request containing the handle of `data` and returns
// the stub's status as-is.
Status Client::Unregister(const GlobalData& data) {
  UnregisterRequest request;
  *request.add_data() = data.handle();

  VLOG(1) << "making unregister request";
  UnregisterResponse response;
  Status rpc_status = stub_->Unregister(&request, &response);
  VLOG(1) << "done with request";
  return rpc_status;
}
// Asks the service to split the tuple behind `data` into its elements.
//
// Returns the stub's status on failure, otherwise one GlobalData (built from
// `stub_` and the element handle) per element handle in the response, in
// response order.
StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::DeconstructTuple(
    const GlobalData& data) {
  DeconstructTupleRequest request;
  *request.mutable_tuple_handle() = data.handle();

  DeconstructTupleResponse response;
  // Log the real RPC name; the previous text said "DestructTuple".
  VLOG(1) << "making DeconstructTuple request";
  Status s = stub_->DeconstructTuple(&request, &response);
  VLOG(1) << "done with request";
  if (!s.ok()) {
    return s;
  }

  std::vector<std::unique_ptr<GlobalData>> handles;
  handles.reserve(response.element_handles_size());
  for (const auto& handle : response.element_handles()) {
    handles.push_back(absl::make_unique<GlobalData>(stub_, handle));
  }
  // The vector is move-only and is converted to StatusOr on return, so the
  // move is spelled out.
  return std::move(handles);
}
// Requests statistics for `computation` from the service, using the given
// `debug_options`.
//
// Returns the stub's status on failure. A response without stats is reported
// as an error through TF_RET_CHECK instead of aborting the process, matching
// how Compile() treats a response without a handle.
StatusOr<ComputationStats> Client::GetComputationStats(
    const XlaComputation& computation,
    const DebugOptions& debug_options) const {
  ComputationGraphStatsRequest request;

  // TODO(b/74197823): Find a way to avoid the copy of the hlo proto.
  *request.mutable_computation() = computation.proto();
  *request.mutable_debug_options() = debug_options;

  ComputationStatsResponse response;
  VLOG(1) << "making computation graph stats request";
  Status s = stub_->GetComputationGraphStats(&request, &response);
  VLOG(1) << "done with request";
  if (!s.ok()) {
    return s;
  }

  // A malformed response is the service's fault, not a local invariant
  // violation, so surface it to the caller as a status.
  TF_RET_CHECK(response.has_stats());
  return response.stats();
}
// Returns a heap-allocated copy of the program shape reported by
// `computation.GetProgramShape()`, or that call's error.
StatusOr<std::unique_ptr<ProgramShape>> Client::GetComputationShape(
    const XlaComputation& computation) {
  TF_ASSIGN_OR_RETURN(const auto& program_shape,
                      computation.GetProgramShape());
  return absl::make_unique<ProgramShape>(program_shape);
}
// Asks the service for the shape of `data`.
//
// Returns the stub's status on failure, otherwise a Shape constructed from
// the shape proto in the response.
StatusOr<Shape> Client::GetShape(const GlobalData& data) {
  GetShapeRequest request;
  *request.mutable_data() = data.handle();

  VLOG(1) << "making get shape request";
  GetShapeResponse response;
  Status rpc_status = stub_->GetShape(&request, &response);
  VLOG(1) << "done with request";
  if (!rpc_status.ok()) {
    return rpc_status;
  }

  return Shape(response.shape());
}
// Formats a one-line summary of execution statistics for `computation`.
//
// The flop and transcendental counts come from GetComputationStats() (using
// the debug options from flags); timing and cycle counts come from
// `profile`. When `profile.compute_time_ns()` is not positive, a fixed
// "not available" string is returned. Errors from GetComputationStats() are
// propagated.
StatusOr<std::string> Client::ExecutionStatsAsString(
    const XlaComputation& computation, const ExecutionProfile& profile) {
  TF_ASSIGN_OR_RETURN(
      auto computation_stats,
      GetComputationStats(computation, GetDebugOptionsFromFlags()));

  if (profile.compute_time_ns() <= 0) {
    return std::string("[Execution Statistics] not available.");
  }

  const int64_t nanoseconds = profile.compute_time_ns();
  const int64_t cycle_count = profile.compute_cycle_count();

  // Operations per nanosecond equals giga-operations per second. Do the
  // arithmetic in floating point: dividing two integers here would truncate
  // the rate to a whole number (and to 0 for anything below 1 gflop/s).
  const double total_flops =
      static_cast<double>(computation_stats.flop_count() +
                          computation_stats.transcendental_count());
  const double gflops = total_flops / static_cast<double>(nanoseconds);

  return absl::StrCat(
      "[Execution Statistics] flop count: ", computation_stats.flop_count(),
      ", transcendental count: ", computation_stats.transcendental_count(),
      ", compute execution time: ", nanoseconds, " nsec",
      ", compute cycles: ", cycle_count, ", performance: ", gflops,
      "gflop/s");
}
// Asks the service to create a channel handle of the given `type`.
//
// Returns the stub's status on failure, otherwise the channel from the
// response.
StatusOr<ChannelHandle> Client::CreateChannelHandleByType(
    ChannelHandle::ChannelType type) {
  CreateChannelHandleRequest request;
  request.set_channel_type(type);

  VLOG(1) << "making create channel handle request";
  CreateChannelHandleResponse response;
  Status rpc_status = stub_->CreateChannelHandle(&request, &response);
  VLOG(1) << "done with request";
  if (!rpc_status.ok()) {
    return rpc_status;
  }

  return response.channel();
}
// Creates a channel handle of type DEVICE_TO_DEVICE.
StatusOr<ChannelHandle> Client::CreateChannelHandle() {
  const ChannelHandle::ChannelType channel_type =
      ChannelHandle::DEVICE_TO_DEVICE;
  return CreateChannelHandleByType(channel_type);
}
// Creates a channel handle of type HOST_TO_DEVICE.
StatusOr<ChannelHandle> Client::CreateHostToDeviceChannelHandle() {
  const ChannelHandle::ChannelType channel_type =
      ChannelHandle::HOST_TO_DEVICE;
  return CreateChannelHandleByType(channel_type);
}
// Creates a channel handle of type DEVICE_TO_HOST.
StatusOr<ChannelHandle> Client::CreateDeviceToHostChannelHandle() {
  const ChannelHandle::ChannelType channel_type =
      ChannelHandle::DEVICE_TO_HOST;
  return CreateChannelHandleByType(channel_type);
}
} // namespace xla