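Enable profiler for tf.data service.

Registers the TensorFlow profiler gRPC service on the tf.data service
servers: GrpcDataServerBase::Start now calls the new
AddProfilerServiceToBuilder() in addition to the per-server
AddDataServiceToBuilder() (renamed from AddServiceToBuilder()). gRPC library
types in the data service code are now spelled with a leading `::grpc::`.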
PiperOrigin-RevId: 326331067
Change-Id: I704968182fb23d5241a071978c6c9afca265f8ab
diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD
index 13034eb..c64748c 100644
--- a/tensorflow/core/data/service/BUILD
+++ b/tensorflow/core/data/service/BUILD
@@ -3,6 +3,7 @@
"//tensorflow/core/platform:build_config.bzl",
"tf_additional_all_protos",
"tf_proto_library",
+ "tf_protos_profiler_service",
)
load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency")
load(
@@ -332,6 +333,7 @@
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
+ "//tensorflow/core/profiler/rpc:profiler_service_impl",
tf_grpc_cc_dependency(),
],
alwayslink = 1,
@@ -375,14 +377,14 @@
":test_util",
":worker_cc_grpc_proto",
":worker_proto_cc",
+ "@com_google_absl//absl/strings",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/data:compression_utils",
"//tensorflow/core/kernels/data:dataset_test_base",
- "@com_google_absl//absl/strings",
tf_grpc_cc_dependency(),
- ],
+ ] + tf_protos_profiler_service(),
)
cc_grpc_library(
diff --git a/tensorflow/core/data/service/credentials_factory.cc b/tensorflow/core/data/service/credentials_factory.cc
index 88b0073..43b56d5 100644
--- a/tensorflow/core/data/service/credentials_factory.cc
+++ b/tensorflow/core/data/service/credentials_factory.cc
@@ -65,7 +65,8 @@
}
Status CredentialsFactory::CreateServerCredentials(
- absl::string_view protocol, std::shared_ptr<grpc::ServerCredentials>* out) {
+ absl::string_view protocol,
+ std::shared_ptr<::grpc::ServerCredentials>* out) {
CredentialsFactory* factory;
TF_RETURN_IF_ERROR(CredentialsFactory::Get(protocol, &factory));
TF_RETURN_IF_ERROR(factory->CreateServerCredentials(out));
@@ -74,7 +75,7 @@
Status CredentialsFactory::CreateClientCredentials(
absl::string_view protocol,
- std::shared_ptr<grpc::ChannelCredentials>* out) {
+ std::shared_ptr<::grpc::ChannelCredentials>* out) {
CredentialsFactory* factory;
TF_RETURN_IF_ERROR(CredentialsFactory::Get(protocol, &factory));
TF_RETURN_IF_ERROR(factory->CreateClientCredentials(out));
@@ -86,14 +87,14 @@
std::string Protocol() override { return "grpc"; }
Status CreateServerCredentials(
- std::shared_ptr<grpc::ServerCredentials>* out) override {
- *out = grpc::InsecureServerCredentials();
+ std::shared_ptr<::grpc::ServerCredentials>* out) override {
+ *out = ::grpc::InsecureServerCredentials();
return Status::OK();
}
Status CreateClientCredentials(
- std::shared_ptr<grpc::ChannelCredentials>* out) override {
- *out = grpc::InsecureChannelCredentials();
+ std::shared_ptr<::grpc::ChannelCredentials>* out) override {
+ *out = ::grpc::InsecureChannelCredentials();
return Status::OK();
}
};
diff --git a/tensorflow/core/data/service/credentials_factory.h b/tensorflow/core/data/service/credentials_factory.h
index a93b941..2407f64 100644
--- a/tensorflow/core/data/service/credentials_factory.h
+++ b/tensorflow/core/data/service/credentials_factory.h
@@ -36,11 +36,11 @@
// Stores server credentials to `*out`.
virtual Status CreateServerCredentials(
- std::shared_ptr<grpc::ServerCredentials>* out) = 0;
+ std::shared_ptr<::grpc::ServerCredentials>* out) = 0;
// Stores client credentials to `*out`.
virtual Status CreateClientCredentials(
- std::shared_ptr<grpc::ChannelCredentials>* out) = 0;
+ std::shared_ptr<::grpc::ChannelCredentials>* out) = 0;
// Registers a credentials factory.
static void Register(CredentialsFactory* factory);
@@ -49,13 +49,13 @@
// `protocol`, and stores them to `*out`.
static Status CreateServerCredentials(
absl::string_view protocol,
- std::shared_ptr<grpc::ServerCredentials>* out);
+ std::shared_ptr<::grpc::ServerCredentials>* out);
// Creates client credentials using the credentials factory registered as
// `protocol`, and stores them to `*out`.
static Status CreateClientCredentials(
absl::string_view protocol,
- std::shared_ptr<grpc::ChannelCredentials>* out);
+ std::shared_ptr<::grpc::ChannelCredentials>* out);
private:
// Gets the credentials factory registered via `Register` for the specified
diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.h b/tensorflow/core/data/service/grpc_dispatcher_impl.h
index 1810c3f..7e8910b 100644
--- a/tensorflow/core/data/service/grpc_dispatcher_impl.h
+++ b/tensorflow/core/data/service/grpc_dispatcher_impl.h
@@ -35,16 +35,16 @@
//
class GrpcDispatcherImpl : public DispatcherService::Service {
public:
- explicit GrpcDispatcherImpl(grpc::ServerBuilder* server_builder,
+ explicit GrpcDispatcherImpl(::grpc::ServerBuilder* server_builder,
const experimental::DispatcherConfig& config);
~GrpcDispatcherImpl() override {}
Status Start();
-#define HANDLER(method) \
- grpc::Status method(grpc::ServerContext* context, \
- const method##Request* request, \
- method##Response* response) override;
+#define HANDLER(method) \
+ ::grpc::Status method(::grpc::ServerContext* context, \
+ const method##Request* request, \
+ method##Response* response) override;
HANDLER(RegisterWorker);
HANDLER(WorkerUpdate);
HANDLER(GetOrRegisterDataset);
diff --git a/tensorflow/core/data/service/grpc_util.cc b/tensorflow/core/data/service/grpc_util.cc
index 7f9d2ac..c86496c 100644
--- a/tensorflow/core/data/service/grpc_util.cc
+++ b/tensorflow/core/data/service/grpc_util.cc
@@ -26,7 +26,7 @@
namespace data {
namespace grpc_util {
-Status WrapError(const std::string& message, const grpc::Status& status) {
+Status WrapError(const std::string& message, const ::grpc::Status& status) {
if (status.ok()) {
return errors::Internal("Expected a non-ok grpc status. Wrapping message: ",
message);
diff --git a/tensorflow/core/data/service/grpc_util.h b/tensorflow/core/data/service/grpc_util.h
index b0e39df..0ae2a86 100644
--- a/tensorflow/core/data/service/grpc_util.h
+++ b/tensorflow/core/data/service/grpc_util.h
@@ -24,7 +24,7 @@
namespace grpc_util {
// Wraps a grpc::Status in a tensorflow::Status with the given message.
-Status WrapError(const std::string& message, const grpc::Status& status);
+Status WrapError(const std::string& message, const ::grpc::Status& status);
// Retries the given function if the function produces UNAVAILABLE, ABORTED, or
// CANCELLED status codes. We retry these codes because they can all indicate
diff --git a/tensorflow/core/data/service/grpc_worker_impl.cc b/tensorflow/core/data/service/grpc_worker_impl.cc
index 5e3183d..b3a37fe 100644
--- a/tensorflow/core/data/service/grpc_worker_impl.cc
+++ b/tensorflow/core/data/service/grpc_worker_impl.cc
@@ -35,11 +35,11 @@
return impl_.Start(worker_address);
}
-#define HANDLER(method) \
- grpc::Status GrpcWorkerImpl::method(ServerContext* context, \
- const method##Request* request, \
- method##Response* response) { \
- return ToGrpcStatus(impl_.method(request, response)); \
+#define HANDLER(method) \
+ ::grpc::Status GrpcWorkerImpl::method(ServerContext* context, \
+ const method##Request* request, \
+ method##Response* response) { \
+ return ToGrpcStatus(impl_.method(request, response)); \
}
HANDLER(ProcessTask);
HANDLER(GetElement);
diff --git a/tensorflow/core/data/service/grpc_worker_impl.h b/tensorflow/core/data/service/grpc_worker_impl.h
index 49caab2..c42e563 100644
--- a/tensorflow/core/data/service/grpc_worker_impl.h
+++ b/tensorflow/core/data/service/grpc_worker_impl.h
@@ -35,16 +35,16 @@
//
class GrpcWorkerImpl : public WorkerService::Service {
public:
- explicit GrpcWorkerImpl(grpc::ServerBuilder* server_builder,
+ explicit GrpcWorkerImpl(::grpc::ServerBuilder* server_builder,
const experimental::WorkerConfig& config);
~GrpcWorkerImpl() override {}
Status Start(const std::string& worker_address);
-#define HANDLER(method) \
- grpc::Status method(grpc::ServerContext* context, \
- const method##Request* request, \
- method##Response* response) override;
+#define HANDLER(method) \
+ ::grpc::Status method(::grpc::ServerContext* context, \
+ const method##Request* request, \
+ method##Response* response) override;
HANDLER(ProcessTask);
HANDLER(GetElement);
#undef HANDLER
diff --git a/tensorflow/core/data/service/server_lib.cc b/tensorflow/core/data/service/server_lib.cc
index fb33319..477f785 100644
--- a/tensorflow/core/data/service/server_lib.cc
+++ b/tensorflow/core/data/service/server_lib.cc
@@ -51,7 +51,8 @@
credentials, &bound_port_);
builder.SetMaxReceiveMessageSize(-1);
- AddServiceToBuilder(&builder);
+ AddDataServiceToBuilder(&builder);
+ AddProfilerServiceToBuilder(&builder);
server_ = builder.BuildAndStart();
if (!server_) {
return errors::Internal("Could not start gRPC server");
@@ -77,6 +78,12 @@
int GrpcDataServerBase::BoundPort() { return bound_port(); }
+void GrpcDataServerBase::AddProfilerServiceToBuilder(
+ ::grpc::ServerBuilder* builder) {
+ profiler_service_ = CreateProfilerService();
+ builder->RegisterService(profiler_service_.get());
+}
+
DispatchGrpcDataServer::DispatchGrpcDataServer(
const experimental::DispatcherConfig& config)
: GrpcDataServerBase(config.port(), config.protocol(), "DispatchServer"),
@@ -84,7 +91,8 @@
DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
-void DispatchGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
+void DispatchGrpcDataServer::AddDataServiceToBuilder(
+ ::grpc::ServerBuilder* builder) {
service_ = absl::make_unique<GrpcDispatcherImpl>(builder, config_).release();
}
@@ -95,8 +103,8 @@
Status DispatchGrpcDataServer::NumWorkers(int* num_workers) {
GetWorkersRequest req;
GetWorkersResponse resp;
- grpc::ServerContext ctx;
- grpc::Status s = service_->GetWorkers(&ctx, &req, &resp);
+ ::grpc::ServerContext ctx;
+ ::grpc::Status s = service_->GetWorkers(&ctx, &req, &resp);
if (!s.ok()) {
return grpc_util::WrapError("Failed to get workers", s);
}
@@ -111,7 +119,8 @@
WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
-void WorkerGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
+void WorkerGrpcDataServer::AddDataServiceToBuilder(
+ ::grpc::ServerBuilder* builder) {
service_ = absl::make_unique<GrpcWorkerImpl>(builder, config_).release();
}
diff --git a/tensorflow/core/data/service/server_lib.h b/tensorflow/core/data/service/server_lib.h
index 62662e6..0ddc806 100644
--- a/tensorflow/core/data/service/server_lib.h
+++ b/tensorflow/core/data/service/server_lib.h
@@ -19,6 +19,7 @@
#include "grpcpp/server.h"
#include "grpcpp/server_builder.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
#include "tensorflow/core/protobuf/data/experimental/service_config.pb.h"
namespace tensorflow {
@@ -52,7 +53,8 @@
int BoundPort();
protected:
- virtual void AddServiceToBuilder(::grpc::ServerBuilder* builder) = 0;
+ virtual void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) = 0;
+ void AddProfilerServiceToBuilder(::grpc::ServerBuilder* builder);
// Starts the service. This will be called after building the service, so
// bound_port() will return the actual bound port.
virtual Status StartServiceInternal() = 0;
@@ -68,7 +70,9 @@
bool started_ = false;
bool stopped_ = false;
- std::unique_ptr<grpc::Server> server_;
+ std::unique_ptr<::grpc::Server> server_;
+ // TensorFlow profiler service implementation.
+ std::unique_ptr<grpc::ProfilerService::Service> profiler_service_ = nullptr;
};
class DispatchGrpcDataServer : public GrpcDataServerBase {
@@ -80,7 +84,7 @@
Status NumWorkers(int* num_workers);
protected:
- void AddServiceToBuilder(grpc::ServerBuilder* builder) override;
+ void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) override;
Status StartServiceInternal() override;
private:
@@ -95,7 +99,7 @@
~WorkerGrpcDataServer() override;
protected:
- void AddServiceToBuilder(grpc::ServerBuilder* builder) override;
+ void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) override;
Status StartServiceInternal() override;
private:
diff --git a/tensorflow/core/profiler/rpc/BUILD b/tensorflow/core/profiler/rpc/BUILD
index 06e5d2e..496e0c7 100644
--- a/tensorflow/core/profiler/rpc/BUILD
+++ b/tensorflow/core/profiler/rpc/BUILD
@@ -13,6 +13,7 @@
features = ["-layering_check"],
visibility = tf_external_workspace_visible(
[
+ "//tensorflow/core/data/service:__pkg__",
"//tensorflow/core/distributed_runtime/rpc:__pkg__",
"//tensorflow_serving/model_servers:__pkg__",
],