* Added a RemoteProfilerSession that encapsulates connection to a gRPC remote profiler session.
* Added a RemoteProfilerSessionManager class that encapsulates a collection of RemoteProfilerSessions.
* Added tests for each case.
PiperOrigin-RevId: 334508892
Change-Id: I81988a6e0fa60ec1f6588661c35358f7ff09b8dd
diff --git a/tensorflow/core/profiler/lib/profiler_session.cc b/tensorflow/core/profiler/lib/profiler_session.cc
index d895818..dc8f1ce 100644
--- a/tensorflow/core/profiler/lib/profiler_session.cc
+++ b/tensorflow/core/profiler/lib/profiler_session.cc
@@ -147,6 +147,7 @@
}
ProfilerSession::~ProfilerSession() {
+ VLOG(1) << "Profiler session stopping.";
for (auto& profiler : profilers_) {
profiler->Stop().IgnoreError();
}
diff --git a/tensorflow/core/profiler/profiler_options.proto b/tensorflow/core/profiler/profiler_options.proto
index 8b4fc3d..7858f08 100644
--- a/tensorflow/core/profiler/profiler_options.proto
+++ b/tensorflow/core/profiler/profiler_options.proto
@@ -2,6 +2,7 @@
package tensorflow;
+// Next ID: 11
message ProfileOptions {
// Some default value of option are not proto3 default value. Use this version
// to determine if we should use default option value instead of proto3
@@ -50,5 +51,32 @@
// Whether serialize hlo_proto when XLA is used. (version >= 1)
bool enable_hlo_proto = 7;
- // next-field: 8
+ // The local profiler starts profiling at this Unix timestamp in nanoseconds.
+ uint64 start_timestamp_ns = 8;
+
+ // The local profiler collects `duration_ms` milliseconds of data. If the
+ // value is 0, profiling continues until interrupted.
+ uint64 duration_ms = 9;
+
+ // Directory to save profile data to. No-op when empty.
+ string repository_path = 10;
+}
+
+// Options for remote profiler session manager.
+// Next ID: 5
+message RemoteProfilerSessionManagerOptions {
+ // Options for each local profiler.
+ ProfileOptions profiler_options = 1;
+
+ // List of servers to profile. Supported formats: host:port.
+ repeated string service_addresses = 2;
+
+ // Unix timestamp of when the session was started.
+ uint64 session_creation_timestamp_ns = 3;
+
+ // Maximum time (in milliseconds) a profiling session manager waits for all
+ // profilers to finish after issuing gRPC request. If value is 0, session
+ // continues until interrupted. Otherwise, value must be greater than
+ // profiler_options.duration_ms.
+ uint64 max_session_duration_ms = 4;
}
diff --git a/tensorflow/core/profiler/rpc/BUILD b/tensorflow/core/profiler/rpc/BUILD
index 3531e0d..bb4c1a5 100644
--- a/tensorflow/core/profiler/rpc/BUILD
+++ b/tensorflow/core/profiler/rpc/BUILD
@@ -55,6 +55,7 @@
hdrs = ["profiler_server.h"],
visibility = [
"//tensorflow/compiler/xla/python:__pkg__",
+ "//tensorflow/core/profiler:internal",
"//tensorflow/python:__pkg__",
"//tensorflow/python/profiler/internal:__pkg__",
],
diff --git a/tensorflow/core/profiler/rpc/client/BUILD b/tensorflow/core/profiler/rpc/client/BUILD
index 5702e83..9695fba 100644
--- a/tensorflow/core/profiler/rpc/client/BUILD
+++ b/tensorflow/core/profiler/rpc/client/BUILD
@@ -1,6 +1,7 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency")
load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper") # buildifier: disable=same-origin-load
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
# For platform specific build config
load(
@@ -58,6 +59,9 @@
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler:profiler_analysis_proto_cc",
"//tensorflow/core/profiler:profiler_service_proto_cc",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/time",
tf_grpc_cc_dependency(),
],
alwayslink = True,
@@ -68,3 +72,67 @@
visibility = ["//tensorflow/python/profiler/internal:__pkg__"],
deps = [":profiler_client_impl"],
)
+
+tf_cc_test(
+ name = "profiler_client_test",
+ srcs = [
+ "profiler_client_test.cc",
+ "profiler_client_test_util.h",
+ ],
+ tags = ["external"], # So that test suite reruns unconditionally.
+ deps = [
+ ":profiler_client_impl",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/lib/core:errors",
+ "//tensorflow/core/platform",
+ "//tensorflow/core/platform:env",
+ "//tensorflow/core/profiler:profiler_service_proto_cc",
+ "//tensorflow/core/profiler/lib:profiler_session",
+ "//tensorflow/core/profiler/rpc:profiler_server_impl",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+cc_library(
+ name = "remote_profiler_session_manager",
+ srcs = ["remote_profiler_session_manager.cc"],
+ hdrs = ["remote_profiler_session_manager.h"],
+ visibility = ["//tensorflow/core/profiler:internal"],
+ deps = [
+ ":profiler_client_impl",
+ ":save_profile",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/profiler:profiler_options_proto_cc",
+ "//tensorflow/core/profiler/lib:profiler_session",
+ "//tensorflow/core/profiler/utils:time_utils",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/time",
+ ],
+)
+
+tf_cc_test(
+ name = "remote_profiler_session_manager_test",
+ srcs = [
+ "profiler_client_test_util.h",
+ "remote_profiler_session_manager_test.cc",
+ ],
+ tags = ["external"], # So that test suite reruns unconditionally.
+ deps = [
+ ":remote_profiler_session_manager",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core/platform",
+ "//tensorflow/core/profiler:profiler_service_proto_cc",
+ "//tensorflow/core/profiler/lib:profiler_session",
+ "//tensorflow/core/profiler/rpc:profiler_server_impl",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/time",
+ ],
+)
diff --git a/tensorflow/core/profiler/rpc/client/profiler_client.cc b/tensorflow/core/profiler/rpc/client/profiler_client.cc
index 5eec02b..8a178a2 100644
--- a/tensorflow/core/profiler/rpc/client/profiler_client.cc
+++ b/tensorflow/core/profiler/rpc/client/profiler_client.cc
@@ -17,6 +17,9 @@
#include <limits>
#include "grpcpp/grpcpp.h"
+#include "absl/memory/memory.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
@@ -34,50 +37,122 @@
}
template <typename T>
-std::unique_ptr<typename T::Stub> CreateStub(const std::string& service_addr) {
+std::unique_ptr<typename T::Stub> CreateStub(
+ const std::string& service_address) {
::grpc::ChannelArguments channel_args;
channel_args.SetMaxReceiveMessageSize(std::numeric_limits<int32>::max());
// Default URI prefix is "dns:///" if not provided.
auto channel = ::grpc::CreateCustomChannel(
- service_addr, ::grpc::InsecureChannelCredentials(), channel_args);
+ service_address, ::grpc::InsecureChannelCredentials(), channel_args);
if (!channel) {
- LOG(ERROR) << "Unable to create channel" << service_addr;
+ LOG(ERROR) << "Unable to create channel" << service_address;
}
return T::NewStub(channel);
}
} // namespace
-Status ProfileGrpc(const std::string& service_addr,
+Status ProfileGrpc(const std::string& service_address,
const ProfileRequest& request, ProfileResponse* response) {
::grpc::ClientContext context;
std::unique_ptr<grpc::ProfilerService::Stub> stub =
- CreateStub<grpc::ProfilerService>(service_addr);
+ CreateStub<grpc::ProfilerService>(service_address);
TF_RETURN_IF_ERROR(
FromGrpcStatus(stub->Profile(&context, request, response)));
return Status::OK();
}
-Status NewSessionGrpc(const std::string& service_addr,
+Status NewSessionGrpc(const std::string& service_address,
const NewProfileSessionRequest& request,
NewProfileSessionResponse* response) {
::grpc::ClientContext context;
std::unique_ptr<grpc::ProfileAnalysis::Stub> stub =
- CreateStub<grpc::ProfileAnalysis>(service_addr);
+ CreateStub<grpc::ProfileAnalysis>(service_address);
TF_RETURN_IF_ERROR(
FromGrpcStatus(stub->NewSession(&context, request, response)));
return Status::OK();
}
-Status MonitorGrpc(const std::string& service_addr,
+Status MonitorGrpc(const std::string& service_address,
const MonitorRequest& request, MonitorResponse* response) {
::grpc::ClientContext context;
std::unique_ptr<grpc::ProfilerService::Stub> stub =
- CreateStub<grpc::ProfilerService>(service_addr);
+ CreateStub<grpc::ProfilerService>(service_address);
TF_RETURN_IF_ERROR(
FromGrpcStatus(stub->Monitor(&context, request, response)));
return Status::OK();
}
+/*static*/ std::unique_ptr<RemoteProfilerSession> RemoteProfilerSession::Create(
+ std::string service_address, absl::Time deadline,
+ ProfileRequest profile_request) {
+ auto instance = absl::WrapUnique(new RemoteProfilerSession(
+ std::move(service_address), deadline, std::move(profile_request)));
+ instance->ProfileAsync();
+ return instance;
+}
+
+RemoteProfilerSession::RemoteProfilerSession(std::string service_address,
+ absl::Time deadline,
+ ProfileRequest profile_request)
+ : response_(absl::make_unique<ProfileResponse>()),
+ service_address_(std::move(service_address)),
+ stub_(CreateStub<grpc::ProfilerService>(service_address_)),
+ deadline_(deadline),
+ profile_request_(std::move(profile_request)) {}
+
+RemoteProfilerSession::~RemoteProfilerSession() {
+ LOG(INFO) << "Waiting for completion.";
+ Status dummy;
+ WaitForCompletion(dummy);
+ grpc_context_.TryCancel();
+}
+
+void RemoteProfilerSession::ProfileAsync() {
+ LOG(INFO) << "Asynchronous gRPC Profile() to " << service_address_;
+ grpc_context_.set_deadline(absl::ToChronoTime(deadline_));
+ VLOG(1) << "Deadline set to " << deadline_;
+ rpc_ = stub_->AsyncProfile(&grpc_context_, profile_request_, &cq_);
+ rpc_->Finish(response_.get(), &grpc_status_,
+ static_cast<void*>(&status_on_completion_));
+ VLOG(2) << "Asynchronous gRPC Profile() issued." << absl::Now();
+}
+
+std::unique_ptr<ProfileResponse> RemoteProfilerSession::WaitForCompletion(
+ Status& out_status) {
+ if (!response_) {
+ out_status = errors::FailedPrecondition(
+ "WaitForCompletion must only be called once.");
+ return nullptr;
+ }
+
+ void* got_tag = nullptr;
+ bool ok = false;
+ // Next blocks until there is a response in the completion queue. Expect the
+ // completion queue to have exactly a single response because deadline is set
+ // and completion queue is only drained once at destruction time.
+ bool success = cq_.Next(&got_tag, &ok);
+ if (!success || !ok || got_tag == nullptr) {
+ out_status =
+ errors::Internal("Missing or invalid event from completion queue.");
+ return nullptr;
+ }
+
+ VLOG(1) << "Writing out status.";
+ // For the event read from the completion queue, expect that got_tag points to
+ // the memory location of status_on_completion.
+ DCHECK_EQ(got_tag, &status_on_completion_);
+ // tagged status points to pre-allocated memory which is okay to overwrite.
+ status_on_completion_.Update(FromGrpcStatus(grpc_status_));
+ if (status_on_completion_.code() == error::DEADLINE_EXCEEDED) {
+ LOG(WARNING) << status_on_completion_;
+ } else if (!status_on_completion_.ok()) {
+ LOG(ERROR) << status_on_completion_;
+ }
+
+ out_status = status_on_completion_;
+ return std::move(response_);
+}
+
} // namespace profiler
} // namespace tensorflow
diff --git a/tensorflow/core/profiler/rpc/client/profiler_client.h b/tensorflow/core/profiler/rpc/client/profiler_client.h
index d946d60..b171c67 100644
--- a/tensorflow/core/profiler/rpc/client/profiler_client.h
+++ b/tensorflow/core/profiler/rpc/client/profiler_client.h
@@ -17,6 +17,11 @@
#ifndef TENSORFLOW_CORE_PROFILER_RPC_CLIENT_PROFILER_CLIENT_H_
#define TENSORFLOW_CORE_PROFILER_RPC_CLIENT_PROFILER_CLIENT_H_
+#include <memory>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "absl/time/time.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/profiler_analysis.grpc.pb.h"
#include "tensorflow/core/profiler/profiler_service.grpc.pb.h"
@@ -24,16 +29,67 @@
namespace tensorflow {
namespace profiler {
-Status ProfileGrpc(const std::string& service_addr,
+// Note that tensorflow/tools/def_file_filter/symbols_pybind.txt is incompatible
+// with absl::string_view.
+Status ProfileGrpc(const std::string& service_address,
const ProfileRequest& request, ProfileResponse* response);
-Status NewSessionGrpc(const std::string& service_addr,
+Status NewSessionGrpc(const std::string& service_address,
const NewProfileSessionRequest& request,
NewProfileSessionResponse* response);
-Status MonitorGrpc(const std::string& service_addr,
+Status MonitorGrpc(const std::string& service_address,
const MonitorRequest& request, MonitorResponse* response);
+class RemoteProfilerSession {
+ public:
+ // Creates an instance and starts a remote profiling session immediately.
+ // This is a non-blocking call and does not wait for a response.
+ // Response must outlive the instantiation.
+ static std::unique_ptr<RemoteProfilerSession> Create(
+ std::string service_address, absl::Time deadline,
+ ProfileRequest profile_request);
+
+ // Not copyable or movable.
+ RemoteProfilerSession(const RemoteProfilerSession&) = delete;
+ RemoteProfilerSession operator=(const RemoteProfilerSession&) = delete;
+
+ ~RemoteProfilerSession();
+
+ absl::string_view GetServiceAddress() const { return service_address_; }
+
+ // Blocks until a response has been received or until deadline expiry,
+ // whichever is first. Subsequent calls after the first will yield nullptr and
+ // an error status.
+ std::unique_ptr<ProfileResponse> WaitForCompletion(Status& out_status);
+
+ private:
+ explicit RemoteProfilerSession(std::string service_addr, absl::Time deadline,
+ ProfileRequest profile_request);
+
+ // Starts a remote profiling session. This is a non-blocking call.
+ // Will be called exactly once during instantiation.
+ // RPC will write to response.profile_response eagerly. However, since
+ // response.status requires a conversion from grpc::Status, it can only be
+ // evaluated lazily at WaitForCompletion() time.
+ void ProfileAsync();
+
+ Status status_on_completion_;
+ std::unique_ptr<ProfileResponse> response_;
+ // Client address and connection attributes.
+ std::string service_address_;
+ std::unique_ptr<grpc::ProfilerService::Stub> stub_;
+ absl::Time deadline_;
+ ::grpc::ClientContext grpc_context_;
+ std::unique_ptr<::grpc::ClientAsyncResponseReader<ProfileResponse>> rpc_;
+ ::grpc::Status grpc_status_;
+
+ // Asynchronous completion queue states.
+ ::grpc::CompletionQueue cq_;
+
+ ProfileRequest profile_request_;
+};
+
} // namespace profiler
} // namespace tensorflow
diff --git a/tensorflow/core/profiler/rpc/client/profiler_client_test.cc b/tensorflow/core/profiler/rpc/client/profiler_client_test.cc
new file mode 100644
index 0000000..119d4fd
--- /dev/null
+++ b/tensorflow/core/profiler/rpc/client/profiler_client_test.cc
@@ -0,0 +1,147 @@
+/* Copyright 2020 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/profiler/rpc/client/profiler_client.h"
+
+#include <memory>
+#include <string>
+
+#include "absl/strings/str_format.h"
+#include "absl/time/time.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/platform.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/profiler/lib/profiler_session.h"
+#include "tensorflow/core/profiler/profiler_service.pb.h"
+#include "tensorflow/core/profiler/rpc/client/profiler_client_test_util.h"
+#include "tensorflow/core/profiler/rpc/profiler_server.h"
+
+namespace tensorflow {
+namespace profiler {
+namespace {
+
+using ::tensorflow::profiler::test::DurationApproxLess;
+using ::tensorflow::profiler::test::DurationNear;
+using ::tensorflow::profiler::test::StartServer;
+
+TEST(RemoteProfilerSession, Simple) {
+ absl::Duration duration = absl::Milliseconds(10);
+ ProfileRequest request;
+ std::string service_addr;
+ auto server = StartServer(duration, &service_addr, &request);
+ absl::Duration grace = absl::Seconds(1);
+ absl::Duration max_duration = duration + grace;
+ absl::Time approx_start = absl::Now();
+ absl::Time deadline = approx_start + max_duration;
+
+ auto remote_session =
+ RemoteProfilerSession::Create(service_addr, deadline, request);
+
+ Status status;
+ auto response = remote_session->WaitForCompletion(status);
+ absl::Duration elapsed = absl::Now() - approx_start;
+ // At end of session this evaluates to true still.
+ EXPECT_TRUE(status.ok());
+ EXPECT_FALSE(response->empty_trace());
+ EXPECT_GT(response->tool_data_size(), 0);
+ EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
+}
+
+TEST(RemoteProfilerSession, WaitNotCalled) {
+ absl::Duration duration = absl::Milliseconds(10);
+ ProfileRequest request;
+ std::string service_addr;
+ auto server = StartServer(duration, &service_addr, &request);
+ absl::Duration grace = absl::Seconds(1);
+ absl::Duration max_duration = duration + grace;
+ absl::Time approx_start = absl::Now();
+ absl::Time deadline = approx_start + max_duration;
+
+ auto remote_session =
+ RemoteProfilerSession::Create(service_addr, deadline, request);
+ absl::Duration elapsed = absl::Now() - approx_start;
+
+ EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
+}
+
+TEST(RemoteProfilerSession, Timeout) {
+ absl::Duration duration = absl::Milliseconds(10);
+ ProfileRequest request;
+ std::string service_addr;
+ auto server = StartServer(duration, &service_addr, &request);
+ // Expect this to fail immediately since deadline was set to the past,
+ auto remote_session =
+ RemoteProfilerSession::Create(service_addr, absl::Now(), request);
+ Status status;
+ auto response = remote_session->WaitForCompletion(status);
+ // At end of session we will have a timeout error.
+ EXPECT_EQ(status.code(), error::DEADLINE_EXCEEDED);
+
+ EXPECT_FALSE(response->empty_trace()); // This defaults to false.
+ EXPECT_EQ(response->tool_data_size(), 0);
+}
+
+TEST(RemoteProfilerSession, LongDeadline) {
+ absl::Duration duration = absl::Milliseconds(10);
+ ProfileRequest request;
+ std::string service_addr;
+ auto server = StartServer(duration, &service_addr, &request);
+
+ absl::Time approx_start = absl::Now();
+ absl::Duration grace = absl::Seconds(1000);
+ absl::Duration max_duration = duration + grace;
+ const absl::Time deadline = approx_start + max_duration;
+
+ auto remote_session =
+ RemoteProfilerSession::Create(service_addr, deadline, request);
+ Status status;
+ auto response = remote_session->WaitForCompletion(status);
+ absl::Duration elapsed = absl::Now() - approx_start;
+ // At end of session this evaluates to true still.
+ EXPECT_TRUE(status.ok());
+ EXPECT_FALSE(response->empty_trace());
+ EXPECT_GT(response->tool_data_size(), 0);
+ // Elapsed time is near profiling duration despite long grace period.
+ EXPECT_THAT(elapsed, DurationNear(duration));
+}
+
+TEST(RemoteProfilerSession, LongDuration) {
+ absl::Duration duration = absl::Seconds(3);
+ ProfileRequest request;
+ std::string service_addr;
+ auto server = StartServer(duration, &service_addr, &request);
+
+ absl::Time approx_start = absl::Now();
+ // Empirically determined value.
+ absl::Duration grace = absl::Seconds(2);
+ absl::Duration max_duration = duration + grace;
+ const absl::Time deadline = approx_start + max_duration;
+
+ auto remote_session =
+ RemoteProfilerSession::Create(service_addr, deadline, request);
+ Status status;
+ auto response = remote_session->WaitForCompletion(status);
+ absl::Duration elapsed = absl::Now() - approx_start;
+ // At end of session this evaluates to true still.
+ EXPECT_TRUE(status.ok());
+ EXPECT_FALSE(response->empty_trace());
+ EXPECT_GT(response->tool_data_size(), 0);
+ // Elapsed time takes longer to complete for larger traces.
+ EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
+}
+
+} // namespace
+} // namespace profiler
+} // namespace tensorflow
diff --git a/tensorflow/core/profiler/rpc/client/profiler_client_test_util.h b/tensorflow/core/profiler/rpc/client/profiler_client_test_util.h
new file mode 100644
index 0000000..925f174
--- /dev/null
+++ b/tensorflow/core/profiler/rpc/client/profiler_client_test_util.h
@@ -0,0 +1,77 @@
+/* Copyright 2020 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GRPC client to perform on-demand profiling
+
+#ifndef TENSORFLOW_CORE_PROFILER_RPC_CLIENT_PROFILER_CLIENT_TEST_H_
+#define TENSORFLOW_CORE_PROFILER_RPC_CLIENT_PROFILER_CLIENT_TEST_H_
+
+#include <memory>
+#include <string>
+
+#include "absl/memory/memory.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "tensorflow/core/platform/platform.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/profiler/lib/profiler_session.h"
+#include "tensorflow/core/profiler/profiler_service.pb.h"
+#include "tensorflow/core/profiler/rpc/profiler_server.h"
+
+namespace tensorflow {
+namespace profiler {
+namespace test {
+
+inline std::unique_ptr<ProfilerServer> StartServer(
+ absl::Duration duration, std::string* service_addresses,
+ ProfileRequest* request = nullptr) {
+ auto profiler_server = absl::make_unique<ProfilerServer>();
+ int port = testing::PickUnusedPortOrDie();
+ profiler_server->StartProfilerServer(port);
+
+ DCHECK(service_addresses);
+ *service_addresses = absl::StrFormat("localhost:%d", port);
+
+ if (request) {
+ request->set_duration_ms(absl::ToInt64Milliseconds(duration));
+ request->set_max_events(10000);
+ *request->mutable_opts() = ProfilerSession::DefaultOptions();
+ request->mutable_opts()->set_duration_ms(
+ absl::ToInt64Milliseconds(duration));
+ request->set_session_id("test_session");
+ request->set_host_name(*service_addresses);
+ }
+
+ LOG(INFO) << "Started " << *service_addresses << " at " << absl::Now();
+ LOG(INFO) << "Duration: " << duration;
+
+ return profiler_server;
+}
+
+inline ::testing::Matcher<absl::Duration> DurationNear(
+ const absl::Duration duration, absl::Duration epsilon = absl::Seconds(1)) {
+ return ::testing::AllOf(::testing::Ge(duration - epsilon),
+ ::testing::Le(duration + epsilon));
+}
+
+inline ::testing::Matcher<absl::Duration> DurationApproxLess(
+ const absl::Duration duration, absl::Duration epsilon = absl::Seconds(1)) {
+ return ::testing::Le(duration + epsilon);
+}
+
+} // namespace test
+} // namespace profiler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PROFILER_RPC_CLIENT_PROFILER_CLIENT_TEST_H_
diff --git a/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.cc b/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.cc
new file mode 100644
index 0000000..87575ee
--- /dev/null
+++ b/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.cc
@@ -0,0 +1,177 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h"
+
+#include <cstddef>
+#include <memory>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "tensorflow/core/platform/env_time.h"
+#include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/rpc/client/save_profile.h"
+#include "tensorflow/core/profiler/utils/time_utils.h"
+#include "tensorflow/core/protobuf/error_codes.pb.h"
+
+namespace tensorflow {
+namespace profiler {
+namespace {
+
+constexpr uint64 kMaxEvents = 1000000;
+
+// TODO(yisitu) merge with the implementation in capture_profile.
+void PopulateProfileRequest(const RemoteProfilerSessionManagerOptions& options,
+ absl::string_view session_id,
+ absl::string_view host_name,
+ ProfileRequest* request) {
+ request->set_max_events(kMaxEvents);
+ request->set_repository_root(options.profiler_options().repository_path());
+ request->set_session_id(session_id.data(), session_id.size());
+ request->add_tools("trace_viewer");
+ request->add_tools("op_profile");
+ request->add_tools("input_pipeline");
+ request->add_tools("kernel_stats");
+ request->add_tools("memory_viewer");
+ request->add_tools("memory_profile");
+ request->add_tools("overview_page");
+ request->add_tools("pod_viewer");
+ request->add_tools("tensorflow_stats");
+ request->set_host_name(host_name.data(), host_name.size());
+ *request->mutable_opts() = options.profiler_options();
+ request->set_duration_ms(options.profiler_options().duration_ms());
+}
+
+} // namespace
+
+/*static*/ std::unique_ptr<RemoteProfilerSessionManager>
+RemoteProfilerSessionManager::Create(
+ const RemoteProfilerSessionManagerOptions& options,
+ tensorflow::Status& out_status, AddressResolver resolver) {
+ VLOG(1) << "Creating a RemoteProfilerSessionManager.";
+ auto session_manager =
+ absl::WrapUnique(new RemoteProfilerSessionManager(options, resolver));
+ out_status = session_manager->Init();
+ if (!out_status.ok()) {
+ return nullptr;
+ }
+ return session_manager;
+}
+
+RemoteProfilerSessionManager::RemoteProfilerSessionManager(
+ RemoteProfilerSessionManagerOptions options, AddressResolver resolver)
+ : options_(std::move(options)) {
+ if (resolver) {
+ resolver_ = std::move(resolver);
+ } else {
+ resolver_ = [](absl::string_view addr) { return std::string(addr); };
+ }
+}
+
+RemoteProfilerSessionManager::~RemoteProfilerSessionManager() {
+ VLOG(2) << "Destroying RemoteProfilerSessionManager.";
+}
+
+Status RemoteProfilerSessionManager::Init() {
+ mutex_lock lock(mutex_);
+ VLOG(1) << "SessionManager initializing.";
+ // TODO(b/169482824) Move validation to call site.
+ Status status = ValidateOptionsLocked();
+ if (!status.ok()) {
+ LOG(ERROR) << status;
+ return status;
+ }
+
+ std::string session_id = GetCurrentTimeStampAsString();
+ const absl::Time session_created_ts =
+ absl::FromUnixNanos(options_.session_creation_timestamp_ns());
+ const absl::Time deadline =
+ session_created_ts +
+ absl::Milliseconds(options_.max_session_duration_ms());
+
+ LOG(INFO) << "Deadline set to " << deadline
+ << " because max_session_duration_ms was "
+ << options_.max_session_duration_ms()
+ << " and session_creation_timestamp_ns was "
+ << options_.session_creation_timestamp_ns() << " ["
+ << session_created_ts << "]";
+
+ // Prepare a list of clients.
+ clients_.reserve(options_.service_addresses_size());
+
+ for (auto& service_addr : options_.service_addresses()) {
+ std::string resolved_service_addr = resolver_(service_addr);
+ ProfileRequest profile_request;
+ PopulateProfileRequest(options_, session_id, resolved_service_addr,
+ &profile_request);
+
+ // Creation also issues Profile RPC asynchronously.
+ auto client = RemoteProfilerSession::Create(
+ std::move(resolved_service_addr), deadline, std::move(profile_request));
+
+ clients_.push_back(std::move(client));
+ }
+
+ LOG(INFO) << absl::StrFormat("Issued Profile gRPC to %u clients",
+ clients_.size());
+ return Status::OK();
+}
+
+Status RemoteProfilerSessionManager::ValidateOptionsLocked() {
+ if (!options_.service_addresses_size()) {
+ return errors::InvalidArgument("No service addresses specified.");
+ }
+
+ if (options_.profiler_options().duration_ms() == 0) {
+ if (options_.max_session_duration_ms() != 0) {
+ return errors::InvalidArgument(
+ "If local profiler duration is unbounded, profiling session duration "
+ "must be unbounded.");
+ }
+ }
+
+ if (options_.max_session_duration_ms() <
+ options_.profiler_options().duration_ms()) {
+ return errors::InvalidArgument(
+ "The maximum profiling session duration must be greater than or equal "
+ "to the local profiler duration.");
+ }
+ return Status::OK();
+}
+
+std::vector<RemoteProfilerSessionManager::Response>
+RemoteProfilerSessionManager::WaitForCompletion() {
+ mutex_lock lock(mutex_);
+ std::vector<RemoteProfilerSessionManager::Response> remote_responses;
+ remote_responses.reserve(clients_.size());
+
+ for (auto& client : clients_) {
+ remote_responses.emplace_back();
+ auto* profile_response = &remote_responses.back().profile_response;
+ Status& status = remote_responses.back().status;
+ std::string* service_addr = &remote_responses.back().service_addr;
+ *profile_response = client->WaitForCompletion(status);
+ *service_addr = client->GetServiceAddress();
+ }
+ return remote_responses;
+}
+
+} // namespace profiler
+} // namespace tensorflow
diff --git a/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h b/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h
new file mode 100644
index 0000000..59dc490
--- /dev/null
+++ b/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h
@@ -0,0 +1,92 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PROFILER_RPC_CLIENT_REMOTE_PROFILER_SESSION_MANAGER_H_
+#define TENSORFLOW_CORE_PROFILER_RPC_CLIENT_REMOTE_PROFILER_SESSION_MANAGER_H_
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/lib/profiler_session.h"
+#include "tensorflow/core/profiler/profiler_options.pb.h"
+#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
+#include "tensorflow/core/profiler/rpc/client/profiler_client.h"
+
+namespace tensorflow {
+namespace profiler {
+
+using AddressResolver = std::function<std::string(absl::string_view)>;
+
+// Manages one or more remote profiling sessions.
+class RemoteProfilerSessionManager {
+ public:
+ struct Response {
+ std::string service_addr;
+ std::unique_ptr<ProfileResponse> profile_response;
+ Status status;
+ };
+ // Instantiates a collection of RemoteProfilerSessions starts profiling on
+ // each of them immediately.
+ static std::unique_ptr<RemoteProfilerSessionManager> Create(
+ const RemoteProfilerSessionManagerOptions& options,
+ tensorflow::Status& out_status, AddressResolver resolver = nullptr);
+
+ static RemoteProfilerSessionManagerOptions DefaultOptions() {
+ RemoteProfilerSessionManagerOptions options;
+ *options.mutable_profiler_options() = ProfilerSession::DefaultOptions();
+ return options;
+ }
+
+ // Awaits for responses from remote profiler sessions and returns them as a
+ // list. Subsequent calls will yield an empty list.
+ std::vector<Response> WaitForCompletion();
+
+ // Not copyable or movable.
+ RemoteProfilerSessionManager(const RemoteProfilerSessionManager&) = delete;
+ RemoteProfilerSessionManager operator=(const RemoteProfilerSessionManager&) =
+ delete;
+
+ ~RemoteProfilerSessionManager();
+
+ private:
+ explicit RemoteProfilerSessionManager(
+ RemoteProfilerSessionManagerOptions options, AddressResolver resolver);
+
+ // Initialization of all client contexts.
+ Status Init();
+
+ Status ValidateOptionsLocked() TF_EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+
+ mutex mutex_;
+ // Remote profiler session options.
+ RemoteProfilerSessionManagerOptions options_ TF_GUARDED_BY(mutex_);
+ // List of clients, each connects to a profiling service.
+ std::vector<std::unique_ptr<RemoteProfilerSession>> clients_
+ TF_GUARDED_BY(mutex_);
+ // Resolves an address into a format that gRPC understands.
+ AddressResolver resolver_ TF_GUARDED_BY(mutex_);
+};
+
+} // namespace profiler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PROFILER_RPC_CLIENT_REMOTE_PROFILER_SESSION_MANAGER_H_
diff --git a/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager_test.cc b/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager_test.cc
new file mode 100644
index 0000000..3e693f1
--- /dev/null
+++ b/tensorflow/core/profiler/rpc/client/remote_profiler_session_manager_test.cc
@@ -0,0 +1,122 @@
+/* Copyright 2020 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/profiler/rpc/client/remote_profiler_session_manager.h"
+
+#include "absl/strings/str_format.h"
+#include "absl/time/clock.h"
+#include "absl/time/time.h"
+#include "tensorflow/core/platform/platform.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/profiler/profiler_service.pb.h"
+#include "tensorflow/core/profiler/rpc/client/profiler_client_test_util.h"
+#include "tensorflow/core/profiler/rpc/profiler_server.h"
+
+namespace tensorflow {
+namespace profiler {
+namespace {
+
+using ::tensorflow::profiler::test::DurationApproxLess;
+using ::tensorflow::profiler::test::DurationNear;
+using ::tensorflow::profiler::test::StartServer;
+using Response = tensorflow::profiler::RemoteProfilerSessionManager::Response;
+
+TEST(RemoteProfilerSessionManagerTest, Simple) {
+ absl::Duration duration = absl::Milliseconds(30);
+ RemoteProfilerSessionManagerOptions options =
+ RemoteProfilerSessionManager::DefaultOptions();
+ options.mutable_profiler_options()->set_duration_ms(
+ absl::ToInt64Milliseconds(duration));
+
+ std::string service_addresses;
+ auto server = StartServer(duration, &service_addresses);
+ options.add_service_addresses(service_addresses);
+ absl::Time approx_start = absl::Now();
+ absl::Duration grace = absl::Seconds(1);
+ absl::Duration max_duration = duration + grace;
+ options.set_max_session_duration_ms(absl::ToInt64Milliseconds(max_duration));
+ options.set_session_creation_timestamp_ns(absl::ToUnixNanos(approx_start));
+
+ Status status;
+ auto sessions = RemoteProfilerSessionManager::Create(options, status);
+ EXPECT_TRUE(status.ok());
+ std::vector<Response> responses = sessions->WaitForCompletion();
+ absl::Duration elapsed = absl::Now() - approx_start;
+ ASSERT_EQ(responses.size(), 1);
+ EXPECT_TRUE(responses.back().status.ok());
+ EXPECT_FALSE(responses.back().profile_response->empty_trace());
+ EXPECT_GT(responses.back().profile_response->tool_data_size(), 0);
+ EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
+}
+
+TEST(RemoteProfilerSessionManagerTest, ExpiredDeadline) {
+ absl::Duration duration = absl::Milliseconds(30);
+ RemoteProfilerSessionManagerOptions options =
+ RemoteProfilerSessionManager::DefaultOptions();
+ options.mutable_profiler_options()->set_duration_ms(
+ absl::ToInt64Milliseconds(duration));
+
+ std::string service_addresses;
+ auto server = StartServer(duration, &service_addresses);
+ options.add_service_addresses(service_addresses);
+ absl::Duration grace = absl::Seconds(1);
+ absl::Duration max_duration = duration + grace;
+ options.set_max_session_duration_ms(absl::ToInt64Milliseconds(max_duration));
+ // This will create a deadline in the past.
+ options.set_session_creation_timestamp_ns(0);
+
+ absl::Time approx_start = absl::Now();
+ Status status;
+ auto sessions = RemoteProfilerSessionManager::Create(options, status);
+ EXPECT_TRUE(status.ok());
+ std::vector<Response> responses = sessions->WaitForCompletion();
+ absl::Duration elapsed = absl::Now() - approx_start;
+ EXPECT_THAT(elapsed, DurationNear(absl::Seconds(0)));
+ ASSERT_EQ(responses.size(), 1);
+ EXPECT_EQ(responses.back().status.code(), error::DEADLINE_EXCEEDED);
+ EXPECT_FALSE(responses.back().profile_response->empty_trace());
+ EXPECT_EQ(responses.back().profile_response->tool_data_size(), 0);
+}
+
+TEST(RemoteProfilerSessionManagerTest, LongSession) {
+ absl::Duration duration = absl::Seconds(3);
+ RemoteProfilerSessionManagerOptions options =
+ RemoteProfilerSessionManager::DefaultOptions();
+ options.mutable_profiler_options()->set_duration_ms(
+ absl::ToInt64Milliseconds(duration));
+
+ std::string service_addresses;
+ auto server = StartServer(duration, &service_addresses);
+ options.add_service_addresses(service_addresses);
+ absl::Time approx_start = absl::Now();
+ absl::Duration grace = absl::Seconds(2);
+ absl::Duration max_duration = duration + grace;
+ options.set_max_session_duration_ms(absl::ToInt64Milliseconds(max_duration));
+ options.set_session_creation_timestamp_ns(absl::ToUnixNanos(approx_start));
+
+ Status status;
+ auto sessions = RemoteProfilerSessionManager::Create(options, status);
+ EXPECT_TRUE(status.ok());
+ std::vector<Response> responses = sessions->WaitForCompletion();
+ absl::Duration elapsed = absl::Now() - approx_start;
+ ASSERT_EQ(responses.size(), 1);
+ EXPECT_TRUE(responses.back().status.ok());
+ EXPECT_FALSE(responses.back().profile_response->empty_trace());
+ EXPECT_GT(responses.back().profile_response->tool_data_size(), 0);
+ EXPECT_THAT(elapsed, DurationApproxLess(max_duration));
+}
+
+} // namespace
+} // namespace profiler
+} // namespace tensorflow