blob: 8a178a281fe1443d9d85a3cfab328ea93ebf45e0 [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/profiler/rpc/client/profiler_client.h"
#include <limits>
#include "grpcpp/grpcpp.h"
#include "absl/memory/memory.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
namespace tensorflow {
namespace profiler {
namespace {
// Translates a gRPC status into a tensorflow::Status.
//
// An OK gRPC status maps to Status::OK(). Otherwise the numeric gRPC error
// code is cast directly to error::Code and the gRPC error message is carried
// over unchanged.
inline Status FromGrpcStatus(const ::grpc::Status& s) {
  if (s.ok()) {
    return Status::OK();
  }
  const auto code = static_cast<error::Code>(s.error_code());
  return Status(code, s.error_message());
}
template <typename T>
std::unique_ptr<typename T::Stub> CreateStub(
const std::string& service_address) {
::grpc::ChannelArguments channel_args;
channel_args.SetMaxReceiveMessageSize(std::numeric_limits<int32>::max());
// Default URI prefix is "dns:///" if not provided.
auto channel = ::grpc::CreateCustomChannel(
service_address, ::grpc::InsecureChannelCredentials(), channel_args);
if (!channel) {
LOG(ERROR) << "Unable to create channel" << service_address;
}
return T::NewStub(channel);
}
} // namespace
// Sends `request` to the ProfilerService at `service_address` via the stub's
// Profile() call, which fills in `response`.
//
// A fresh stub (and channel) is created for every invocation. Returns the
// gRPC status of the call converted to a tensorflow::Status.
Status ProfileGrpc(const std::string& service_address,
                   const ProfileRequest& request, ProfileResponse* response) {
  ::grpc::ClientContext context;
  auto stub = CreateStub<grpc::ProfilerService>(service_address);
  const ::grpc::Status rpc_status =
      stub->Profile(&context, request, response);
  return FromGrpcStatus(rpc_status);
}
// Sends `request` to the ProfileAnalysis service at `service_address` via the
// stub's NewSession() call, which fills in `response`.
//
// A fresh stub (and channel) is created for every invocation. Returns the
// gRPC status of the call converted to a tensorflow::Status.
Status NewSessionGrpc(const std::string& service_address,
                      const NewProfileSessionRequest& request,
                      NewProfileSessionResponse* response) {
  ::grpc::ClientContext context;
  auto stub = CreateStub<grpc::ProfileAnalysis>(service_address);
  const ::grpc::Status rpc_status =
      stub->NewSession(&context, request, response);
  return FromGrpcStatus(rpc_status);
}
// Sends `request` to the ProfilerService at `service_address` via the stub's
// Monitor() call, which fills in `response`.
//
// A fresh stub (and channel) is created for every invocation. Returns the
// gRPC status of the call converted to a tensorflow::Status.
Status MonitorGrpc(const std::string& service_address,
                   const MonitorRequest& request, MonitorResponse* response) {
  ::grpc::ClientContext context;
  auto stub = CreateStub<grpc::ProfilerService>(service_address);
  const ::grpc::Status rpc_status =
      stub->Monitor(&context, request, response);
  return FromGrpcStatus(rpc_status);
}
/*static*/ std::unique_ptr<RemoteProfilerSession> RemoteProfilerSession::Create(
std::string service_address, absl::Time deadline,
ProfileRequest profile_request) {
auto instance = absl::WrapUnique(new RemoteProfilerSession(
std::move(service_address), deadline, std::move(profile_request)));
instance->ProfileAsync();
return instance;
}
// Stores the connection parameters and prepares everything needed to issue the
// Profile() RPC; the RPC itself is not started here (see ProfileAsync()).
//
// `service_address` and `profile_request` are taken by value and moved into
// the corresponding members; `deadline` is copied.
RemoteProfilerSession::RemoteProfilerSession(std::string service_address,
absl::Time deadline,
ProfileRequest profile_request)
// The response object is allocated up front so that it can be handed to the
// RPC as its output buffer.
: response_(absl::make_unique<ProfileResponse>()),
service_address_(std::move(service_address)),
// Reads the member `service_address_`, not the (moved-from) parameter.
// NOTE(review): this relies on service_address_ being declared before stub_
// in the class definition -- confirm against the header.
stub_(CreateStub<grpc::ProfilerService>(service_address_)),
deadline_(deadline),
profile_request_(std::move(profile_request)) {}
// Collects the outstanding RPC result (if the caller has not already done so)
// and then requests cancellation on the client context.
RemoteProfilerSession::~RemoteProfilerSession() {
  LOG(INFO) << "Waiting for completion.";
  // Neither the status nor the returned response is of interest at this
  // point; the call is made only to drain the pending completion.
  Status ignored_status;
  WaitForCompletion(ignored_status);
  grpc_context_.TryCancel();
}
// Issues the Profile() RPC through the async stub API and returns without
// collecting its result; the result is picked up from the completion queue in
// WaitForCompletion().
void RemoteProfilerSession::ProfileAsync() {
  LOG(INFO) << "Asynchronous gRPC Profile() to " << service_address_;
  // The deadline must be set on the context before the call is started.
  grpc_context_.set_deadline(absl::ToChronoTime(deadline_));
  VLOG(1) << "Deadline set to " << deadline_;
  rpc_ = stub_->AsyncProfile(&grpc_context_, profile_request_, &cq_);
  // Register the output locations for the response and the final gRPC status.
  // The address of status_on_completion_ serves as the tag that
  // WaitForCompletion() expects to read back from the completion queue.
  rpc_->Finish(response_.get(), &grpc_status_,
               static_cast<void*>(&status_on_completion_));
  // Separate the message text from the timestamp so the log line is readable.
  VLOG(2) << "Asynchronous gRPC Profile() issued at " << absl::Now();
}
// Waits for the event of the Profile() RPC issued by ProfileAsync() and hands
// the response over to the caller.
//
// When a valid event is read, `out_status` receives the RPC status (converted
// from gRPC) and the response object is returned, regardless of whether that
// status is OK. Non-OK statuses are logged: DEADLINE_EXCEEDED as a warning,
// anything else as an error.
//
// The response can be collected only once. Every later call, including the one
// made by the destructor, returns nullptr with a FAILED_PRECONDITION status.
//
// If the completion queue does not deliver a valid event, `out_status` is set
// to INTERNAL and nullptr is returned; the session is then considered finished
// as well.
std::unique_ptr<ProfileResponse> RemoteProfilerSession::WaitForCompletion(
    Status& out_status) {
  if (!response_) {
    out_status = errors::FailedPrecondition(
        "WaitForCompletion must only be called once.");
    return nullptr;
  }

  void* got_tag = nullptr;
  bool ok = false;
  // Next blocks until there is a response in the completion queue. Expect the
  // completion queue to have exactly a single response because deadline is set
  // and completion queue is only drained once at destruction time.
  const bool success = cq_.Next(&got_tag, &ok);
  if (!success || !ok || got_tag == nullptr) {
    // Only one operation is ever enqueued by this session, so there is
    // nothing left to wait for after this point. Drop the response so that a
    // later call (notably the one in the destructor) takes the early return
    // above instead of calling Next() again on a queue that is not expected
    // to produce another event.
    response_.reset();
    out_status =
        errors::Internal("Missing or invalid event from completion queue.");
    return nullptr;
  }

  VLOG(1) << "Writing out status.";
  // For the event read from the completion queue, expect that got_tag points to
  // the memory location of status_on_completion.
  DCHECK_EQ(got_tag, &status_on_completion_);
  // tagged status points to pre-allocated memory which is okay to overwrite.
  status_on_completion_.Update(FromGrpcStatus(grpc_status_));
  if (status_on_completion_.code() == error::DEADLINE_EXCEEDED) {
    LOG(WARNING) << status_on_completion_;
  } else if (!status_on_completion_.ok()) {
    LOG(ERROR) << status_on_completion_;
  }

  out_status = status_on_completion_;
  // response_ is a member, so the explicit move is required; it also leaves
  // response_ empty, which marks the session as already collected.
  return std::move(response_);
}
} // namespace profiler
} // namespace tensorflow