Enable profiler for tf.data service.

PiperOrigin-RevId: 326331067
Change-Id: I704968182fb23d5241a071978c6c9afca265f8ab
diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD
index 13034eb..c64748c 100644
--- a/tensorflow/core/data/service/BUILD
+++ b/tensorflow/core/data/service/BUILD
@@ -3,6 +3,7 @@
     "//tensorflow/core/platform:build_config.bzl",
     "tf_additional_all_protos",
     "tf_proto_library",
+    "tf_protos_profiler_service",
 )
 load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency")
 load(
@@ -332,6 +333,7 @@
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:tensorflow",
+        "//tensorflow/core/profiler/rpc:profiler_service_impl",
         tf_grpc_cc_dependency(),
     ],
     alwayslink = 1,
@@ -375,14 +377,14 @@
         ":test_util",
         ":worker_cc_grpc_proto",
         ":worker_proto_cc",
+        "@com_google_absl//absl/strings",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/data:compression_utils",
         "//tensorflow/core/kernels/data:dataset_test_base",
-        "@com_google_absl//absl/strings",
         tf_grpc_cc_dependency(),
-    ],
+    ] + tf_protos_profiler_service(),
 )
 
 cc_grpc_library(
diff --git a/tensorflow/core/data/service/credentials_factory.cc b/tensorflow/core/data/service/credentials_factory.cc
index 88b0073..43b56d5 100644
--- a/tensorflow/core/data/service/credentials_factory.cc
+++ b/tensorflow/core/data/service/credentials_factory.cc
@@ -65,7 +65,8 @@
 }
 
 Status CredentialsFactory::CreateServerCredentials(
-    absl::string_view protocol, std::shared_ptr<grpc::ServerCredentials>* out) {
+    absl::string_view protocol,
+    std::shared_ptr<::grpc::ServerCredentials>* out) {
   CredentialsFactory* factory;
   TF_RETURN_IF_ERROR(CredentialsFactory::Get(protocol, &factory));
   TF_RETURN_IF_ERROR(factory->CreateServerCredentials(out));
@@ -74,7 +75,7 @@
 
 Status CredentialsFactory::CreateClientCredentials(
     absl::string_view protocol,
-    std::shared_ptr<grpc::ChannelCredentials>* out) {
+    std::shared_ptr<::grpc::ChannelCredentials>* out) {
   CredentialsFactory* factory;
   TF_RETURN_IF_ERROR(CredentialsFactory::Get(protocol, &factory));
   TF_RETURN_IF_ERROR(factory->CreateClientCredentials(out));
@@ -86,14 +87,14 @@
   std::string Protocol() override { return "grpc"; }
 
   Status CreateServerCredentials(
-      std::shared_ptr<grpc::ServerCredentials>* out) override {
-    *out = grpc::InsecureServerCredentials();
+      std::shared_ptr<::grpc::ServerCredentials>* out) override {
+    *out = ::grpc::InsecureServerCredentials();
     return Status::OK();
   }
 
   Status CreateClientCredentials(
-      std::shared_ptr<grpc::ChannelCredentials>* out) override {
-    *out = grpc::InsecureChannelCredentials();
+      std::shared_ptr<::grpc::ChannelCredentials>* out) override {
+    *out = ::grpc::InsecureChannelCredentials();
     return Status::OK();
   }
 };
diff --git a/tensorflow/core/data/service/credentials_factory.h b/tensorflow/core/data/service/credentials_factory.h
index a93b941..2407f64 100644
--- a/tensorflow/core/data/service/credentials_factory.h
+++ b/tensorflow/core/data/service/credentials_factory.h
@@ -36,11 +36,11 @@
 
   // Stores server credentials to `*out`.
   virtual Status CreateServerCredentials(
-      std::shared_ptr<grpc::ServerCredentials>* out) = 0;
+      std::shared_ptr<::grpc::ServerCredentials>* out) = 0;
 
   // Stores client credentials to `*out`.
   virtual Status CreateClientCredentials(
-      std::shared_ptr<grpc::ChannelCredentials>* out) = 0;
+      std::shared_ptr<::grpc::ChannelCredentials>* out) = 0;
 
   // Registers a credentials factory.
   static void Register(CredentialsFactory* factory);
@@ -49,13 +49,13 @@
   // `protocol`, and stores them to `*out`.
   static Status CreateServerCredentials(
       absl::string_view protocol,
-      std::shared_ptr<grpc::ServerCredentials>* out);
+      std::shared_ptr<::grpc::ServerCredentials>* out);
 
   // Creates client credentials using the credentials factory registered as
   // `protocol`, and stores them to `*out`.
   static Status CreateClientCredentials(
       absl::string_view protocol,
-      std::shared_ptr<grpc::ChannelCredentials>* out);
+      std::shared_ptr<::grpc::ChannelCredentials>* out);
 
  private:
   // Gets the credentials factory registered via `Register` for the specified
diff --git a/tensorflow/core/data/service/grpc_dispatcher_impl.h b/tensorflow/core/data/service/grpc_dispatcher_impl.h
index 1810c3f..7e8910b 100644
--- a/tensorflow/core/data/service/grpc_dispatcher_impl.h
+++ b/tensorflow/core/data/service/grpc_dispatcher_impl.h
@@ -35,16 +35,16 @@
 //
 class GrpcDispatcherImpl : public DispatcherService::Service {
  public:
-  explicit GrpcDispatcherImpl(grpc::ServerBuilder* server_builder,
+  explicit GrpcDispatcherImpl(::grpc::ServerBuilder* server_builder,
                               const experimental::DispatcherConfig& config);
   ~GrpcDispatcherImpl() override {}
 
   Status Start();
 
-#define HANDLER(method)                               \
-  grpc::Status method(grpc::ServerContext* context,   \
-                      const method##Request* request, \
-                      method##Response* response) override;
+#define HANDLER(method)                                 \
+  ::grpc::Status method(::grpc::ServerContext* context, \
+                        const method##Request* request, \
+                        method##Response* response) override;
   HANDLER(RegisterWorker);
   HANDLER(WorkerUpdate);
   HANDLER(GetOrRegisterDataset);
diff --git a/tensorflow/core/data/service/grpc_util.cc b/tensorflow/core/data/service/grpc_util.cc
index 7f9d2ac..c86496c 100644
--- a/tensorflow/core/data/service/grpc_util.cc
+++ b/tensorflow/core/data/service/grpc_util.cc
@@ -26,7 +26,7 @@
 namespace data {
 namespace grpc_util {
 
-Status WrapError(const std::string& message, const grpc::Status& status) {
+Status WrapError(const std::string& message, const ::grpc::Status& status) {
   if (status.ok()) {
     return errors::Internal("Expected a non-ok grpc status. Wrapping message: ",
                             message);
diff --git a/tensorflow/core/data/service/grpc_util.h b/tensorflow/core/data/service/grpc_util.h
index b0e39df..0ae2a86 100644
--- a/tensorflow/core/data/service/grpc_util.h
+++ b/tensorflow/core/data/service/grpc_util.h
@@ -24,7 +24,7 @@
 namespace grpc_util {
 
 // Wraps a grpc::Status in a tensorflow::Status with the given message.
-Status WrapError(const std::string& message, const grpc::Status& status);
+Status WrapError(const std::string& message, const ::grpc::Status& status);
 
 // Retries the given function if the function produces UNAVAILABLE, ABORTED, or
 // CANCELLED status codes. We retry these codes because they can all indicate
diff --git a/tensorflow/core/data/service/grpc_worker_impl.cc b/tensorflow/core/data/service/grpc_worker_impl.cc
index 5e3183d..b3a37fe 100644
--- a/tensorflow/core/data/service/grpc_worker_impl.cc
+++ b/tensorflow/core/data/service/grpc_worker_impl.cc
@@ -35,11 +35,11 @@
   return impl_.Start(worker_address);
 }
 
-#define HANDLER(method)                                               \
-  grpc::Status GrpcWorkerImpl::method(ServerContext* context,         \
-                                      const method##Request* request, \
-                                      method##Response* response) {   \
-    return ToGrpcStatus(impl_.method(request, response));             \
+#define HANDLER(method)                                                 \
+  ::grpc::Status GrpcWorkerImpl::method(ServerContext* context,         \
+                                        const method##Request* request, \
+                                        method##Response* response) {   \
+    return ToGrpcStatus(impl_.method(request, response));               \
   }
 HANDLER(ProcessTask);
 HANDLER(GetElement);
diff --git a/tensorflow/core/data/service/grpc_worker_impl.h b/tensorflow/core/data/service/grpc_worker_impl.h
index 49caab2..c42e563 100644
--- a/tensorflow/core/data/service/grpc_worker_impl.h
+++ b/tensorflow/core/data/service/grpc_worker_impl.h
@@ -35,16 +35,16 @@
 //
 class GrpcWorkerImpl : public WorkerService::Service {
  public:
-  explicit GrpcWorkerImpl(grpc::ServerBuilder* server_builder,
+  explicit GrpcWorkerImpl(::grpc::ServerBuilder* server_builder,
                           const experimental::WorkerConfig& config);
   ~GrpcWorkerImpl() override {}
 
   Status Start(const std::string& worker_address);
 
-#define HANDLER(method)                               \
-  grpc::Status method(grpc::ServerContext* context,   \
-                      const method##Request* request, \
-                      method##Response* response) override;
+#define HANDLER(method)                                 \
+  ::grpc::Status method(::grpc::ServerContext* context, \
+                        const method##Request* request, \
+                        method##Response* response) override;
   HANDLER(ProcessTask);
   HANDLER(GetElement);
 #undef HANDLER
diff --git a/tensorflow/core/data/service/server_lib.cc b/tensorflow/core/data/service/server_lib.cc
index fb33319..477f785 100644
--- a/tensorflow/core/data/service/server_lib.cc
+++ b/tensorflow/core/data/service/server_lib.cc
@@ -51,7 +51,8 @@
                            credentials, &bound_port_);
   builder.SetMaxReceiveMessageSize(-1);
 
-  AddServiceToBuilder(&builder);
+  AddDataServiceToBuilder(&builder);
+  AddProfilerServiceToBuilder(&builder);
   server_ = builder.BuildAndStart();
   if (!server_) {
     return errors::Internal("Could not start gRPC server");
@@ -77,6 +78,12 @@
 
 int GrpcDataServerBase::BoundPort() { return bound_port(); }
 
+void GrpcDataServerBase::AddProfilerServiceToBuilder(
+    ::grpc::ServerBuilder* builder) {
+  profiler_service_ = CreateProfilerService();
+  builder->RegisterService(profiler_service_.get());
+}
+
 DispatchGrpcDataServer::DispatchGrpcDataServer(
     const experimental::DispatcherConfig& config)
     : GrpcDataServerBase(config.port(), config.protocol(), "DispatchServer"),
@@ -84,7 +91,8 @@
 
 DispatchGrpcDataServer::~DispatchGrpcDataServer() { delete service_; }
 
-void DispatchGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
+void DispatchGrpcDataServer::AddDataServiceToBuilder(
+    ::grpc::ServerBuilder* builder) {
   service_ = absl::make_unique<GrpcDispatcherImpl>(builder, config_).release();
 }
 
@@ -95,8 +103,8 @@
 Status DispatchGrpcDataServer::NumWorkers(int* num_workers) {
   GetWorkersRequest req;
   GetWorkersResponse resp;
-  grpc::ServerContext ctx;
-  grpc::Status s = service_->GetWorkers(&ctx, &req, &resp);
+  ::grpc::ServerContext ctx;
+  ::grpc::Status s = service_->GetWorkers(&ctx, &req, &resp);
   if (!s.ok()) {
     return grpc_util::WrapError("Failed to get workers", s);
   }
@@ -111,7 +119,8 @@
 
 WorkerGrpcDataServer::~WorkerGrpcDataServer() { delete service_; }
 
-void WorkerGrpcDataServer::AddServiceToBuilder(grpc::ServerBuilder* builder) {
+void WorkerGrpcDataServer::AddDataServiceToBuilder(
+    ::grpc::ServerBuilder* builder) {
   service_ = absl::make_unique<GrpcWorkerImpl>(builder, config_).release();
 }
 
diff --git a/tensorflow/core/data/service/server_lib.h b/tensorflow/core/data/service/server_lib.h
index 62662e6..0ddc806 100644
--- a/tensorflow/core/data/service/server_lib.h
+++ b/tensorflow/core/data/service/server_lib.h
@@ -19,6 +19,7 @@
 #include "grpcpp/server.h"
 #include "grpcpp/server_builder.h"
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/profiler/rpc/profiler_service_impl.h"
 #include "tensorflow/core/protobuf/data/experimental/service_config.pb.h"
 
 namespace tensorflow {
@@ -52,7 +53,8 @@
   int BoundPort();
 
  protected:
-  virtual void AddServiceToBuilder(::grpc::ServerBuilder* builder) = 0;
+  virtual void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) = 0;
+  void AddProfilerServiceToBuilder(::grpc::ServerBuilder* builder);
   // Starts the service. This will be called after building the service, so
   // bound_port() will return the actual bound port.
   virtual Status StartServiceInternal() = 0;
@@ -68,7 +70,9 @@
   bool started_ = false;
   bool stopped_ = false;
 
-  std::unique_ptr<grpc::Server> server_;
+  std::unique_ptr<::grpc::Server> server_;
+  // TensorFlow profiler service implementation.
+  std::unique_ptr<grpc::ProfilerService::Service> profiler_service_ = nullptr;
 };
 
 class DispatchGrpcDataServer : public GrpcDataServerBase {
@@ -80,7 +84,7 @@
   Status NumWorkers(int* num_workers);
 
  protected:
-  void AddServiceToBuilder(grpc::ServerBuilder* builder) override;
+  void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) override;
   Status StartServiceInternal() override;
 
  private:
@@ -95,7 +99,7 @@
   ~WorkerGrpcDataServer() override;
 
  protected:
-  void AddServiceToBuilder(grpc::ServerBuilder* builder) override;
+  void AddDataServiceToBuilder(::grpc::ServerBuilder* builder) override;
   Status StartServiceInternal() override;
 
  private:
diff --git a/tensorflow/core/profiler/rpc/BUILD b/tensorflow/core/profiler/rpc/BUILD
index 06e5d2e..496e0c7 100644
--- a/tensorflow/core/profiler/rpc/BUILD
+++ b/tensorflow/core/profiler/rpc/BUILD
@@ -13,6 +13,7 @@
     features = ["-layering_check"],
     visibility = tf_external_workspace_visible(
         [
+            "//tensorflow/core/data/service:__pkg__",
             "//tensorflow/core/distributed_runtime/rpc:__pkg__",
             "//tensorflow_serving/model_servers:__pkg__",
         ],