blob: 87575ee67a11e24d696b3ea689a1b9b5ff406bba [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h"
#include <cstddef>
#include <memory>
#include <utility>
#include "absl/memory/memory.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/time/clock.h"
#include "absl/time/time.h"
#include "tensorflow/core/platform/env_time.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/rpc/client/save_profile.h"
#include "tensorflow/core/profiler/utils/time_utils.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
namespace tensorflow {
namespace profiler {
namespace {
// Value written to ProfileRequest::max_events for every request built here.
constexpr uint64 kMaxEvents = 1000000;

// Fills `request` from `options`, tagging it with `session_id` and
// `host_name` and registering the fixed list of tools below.
//
// TODO(yisitu) merge with the implementation in capture_profile.
void PopulateProfileRequest(const RemoteProfilerSessionManagerOptions& options,
                            absl::string_view session_id,
                            absl::string_view host_name,
                            ProfileRequest* request) {
  // Tools are added in exactly this order.
  static constexpr const char* kTools[] = {
      "trace_viewer",  "op_profile",     "input_pipeline",
      "kernel_stats",  "memory_viewer",  "memory_profile",
      "overview_page", "pod_viewer",     "tensorflow_stats",
  };

  const auto& profiler_opts = options.profiler_options();

  request->set_max_events(kMaxEvents);
  request->set_repository_root(profiler_opts.repository_path());
  request->set_session_id(session_id.data(), session_id.size());
  for (const char* tool : kTools) {
    request->add_tools(tool);
  }
  request->set_host_name(host_name.data(), host_name.size());
  *request->mutable_opts() = profiler_opts;
  request->set_duration_ms(profiler_opts.duration_ms());
}
} // namespace
// Factory: constructs a RemoteProfilerSessionManager from `options` and
// `resolver`, then runs Init() on it.
//
// `out_status` always receives the result of Init(). Returns the manager when
// that status is OK and nullptr otherwise (the partially initialized manager
// is destroyed before returning).
/*static*/ std::unique_ptr<RemoteProfilerSessionManager>
RemoteProfilerSessionManager::Create(
    const RemoteProfilerSessionManagerOptions& options,
    tensorflow::Status& out_status, AddressResolver resolver) {
  VLOG(1) << "Creating a RemoteProfilerSessionManager.";
  // `resolver` is our own by-value copy and is not used again below, so hand
  // it to the constructor's sink parameter by move instead of copying it.
  auto session_manager = absl::WrapUnique(
      new RemoteProfilerSessionManager(options, std::move(resolver)));
  out_status = session_manager->Init();
  if (!out_status.ok()) {
    return nullptr;
  }
  return session_manager;
}
// Stores `options` and installs the address resolver. When the caller passes
// an empty `resolver`, an identity resolver (returns the address unchanged as
// a std::string) is used instead.
RemoteProfilerSessionManager::RemoteProfilerSessionManager(
    RemoteProfilerSessionManagerOptions options, AddressResolver resolver)
    : options_(std::move(options)) {
  if (!resolver) {
    // No custom resolver supplied: fall back to a pass-through.
    resolver = [](absl::string_view addr) { return std::string(addr); };
  }
  resolver_ = std::move(resolver);
}
// Destructor. Only emits a verbose (level 2) log line; no explicit cleanup is
// performed here, members are released by their own destructors.
RemoteProfilerSessionManager::~RemoteProfilerSessionManager() {
VLOG(2) << "Destroying RemoteProfilerSessionManager.";
}
// Validates the options, computes the session deadline and creates one
// RemoteProfilerSession per configured service address. Holds `mutex_` for
// the whole call.
//
// Returns the validation error (after logging it) if the options are
// rejected, otherwise OK.
Status RemoteProfilerSessionManager::Init() {
  mutex_lock lock(mutex_);
  VLOG(1) << "SessionManager initializing.";

  // TODO(b/169482824) Move validation to call site.
  Status validation = ValidateOptionsLocked();
  if (!validation.ok()) {
    LOG(ERROR) << validation;
    return validation;
  }

  // Every request issued by this manager shares one session id.
  const std::string session_id = GetCurrentTimeStampAsString();

  // deadline = session creation time + maximum session duration.
  const auto created_ns = options_.session_creation_timestamp_ns();
  const auto max_duration_ms = options_.max_session_duration_ms();
  const absl::Time created_at = absl::FromUnixNanos(created_ns);
  const absl::Time deadline = created_at + absl::Milliseconds(max_duration_ms);

  LOG(INFO) << "Deadline set to " << deadline
            << " because max_session_duration_ms was " << max_duration_ms
            << " and session_creation_timestamp_ns was " << created_ns << " ["
            << created_at << "]";

  // Prepare a list of clients.
  clients_.reserve(options_.service_addresses_size());
  for (const auto& address : options_.service_addresses()) {
    std::string resolved = resolver_(address);

    // The request must be populated before `resolved` is moved from below.
    ProfileRequest request;
    PopulateProfileRequest(options_, session_id, resolved, &request);

    // Creation also issues Profile RPC asynchronously.
    clients_.push_back(RemoteProfilerSession::Create(
        std::move(resolved), deadline, std::move(request)));
  }

  LOG(INFO) << absl::StrFormat("Issued Profile gRPC to %u clients",
                               clients_.size());
  return Status::OK();
}
// Checks `options_` for consistency. Init() calls this while holding
// `mutex_`.
//
// Returns InvalidArgument when:
//   * no service address is configured;
//   * the local profiler duration is 0 but the maximum session duration is
//     not 0;
//   * the maximum session duration is smaller than the local profiler
//     duration.
// Returns OK otherwise.
Status RemoteProfilerSessionManager::ValidateOptionsLocked() {
  if (options_.service_addresses_size() == 0) {
    return errors::InvalidArgument("No service addresses specified.");
  }

  const auto local_duration_ms = options_.profiler_options().duration_ms();
  const auto max_session_ms = options_.max_session_duration_ms();

  if (local_duration_ms == 0 && max_session_ms != 0) {
    return errors::InvalidArgument(
        "If local profiler duration is unbounded, profiling session duration "
        "must be unbounded.");
  }

  if (max_session_ms < local_duration_ms) {
    return errors::InvalidArgument(
        "The maximum profiling session duration must be greater than or equal "
        "to the local profiler duration.");
  }

  return Status::OK();
}
std::vector<RemoteProfilerSessionManager::Response>
RemoteProfilerSessionManager::WaitForCompletion() {
mutex_lock lock(mutex_);
std::vector<RemoteProfilerSessionManager::Response> remote_responses;
remote_responses.reserve(clients_.size());
for (auto& client : clients_) {
remote_responses.emplace_back();
auto* profile_response = &remote_responses.back().profile_response;
Status& status = remote_responses.back().status;
std::string* service_addr = &remote_responses.back().service_addr;
*profile_response = client->WaitForCompletion(status);
*service_addr = client->GetServiceAddress();
}
return remote_responses;
}
} // namespace profiler
} // namespace tensorflow