Place all py_func op on the local host's address space if eager execution is enabled.

PiperOrigin-RevId: 290993424
Change-Id: I0c33cdf781fa4b3c401ea5e8649f606137e42862
diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc
index ac160d4..5b3bd6a 100644
--- a/tensorflow/c/eager/c_api_experimental.cc
+++ b/tensorflow/c/eager/c_api_experimental.cc
@@ -18,6 +18,7 @@
 #include "tensorflow/c/c_api.h"
 #include "tensorflow/c/eager/c_api_internal.h"
 #include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/lib/monitoring/counter.h"
 #include "tensorflow/core/lib/monitoring/gauge.h"
 #include "tensorflow/core/lib/monitoring/sampler.h"
@@ -618,3 +619,16 @@
 TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
   return new TFE_Executor(&ctx->context->Executor());
 }
+
+void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
+  auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
+      ctx->context->HostCPU()->parsed_name());
+  auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
+  void* data = tensorflow::port::Malloc(str.length());
+  str.copy(static_cast<char*>(data), str.length(), 0);
+  buf->data = data;
+  buf->length = str.length();
+  buf->data_deallocator = [](void* data, size_t length) {
+    tensorflow::port::Free(data);
+  };
+}
diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h
index 0937258..92132b0 100644
--- a/tensorflow/c/eager/c_api_experimental.h
+++ b/tensorflow/c/eager/c_api_experimental.h
@@ -458,6 +458,11 @@
     void (*deallocator)(void* data, size_t len, void* arg),
     void* deallocator_arg, TF_Status* status);
 
+// Retrieves the address space (i.e. job, replia, task) of the local host and
+// saves it in the buffer.
+TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
+                                                TF_Buffer* buf);
+
 #ifdef __cplusplus
 } /* end extern "C" */
 #endif
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index e0fb805..b580a55 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -785,6 +785,13 @@
     """List of the names of devices available to execute operations."""
     return self._devices
 
+  def host_address_space(self):
+    self.ensure_initialized()
+    with c_api_util.tf_buffer() as buffer_:
+      pywrap_tfe.TFE_HostAddressSpace(self._context_handle, buffer_)
+      address_space = pywrap_tfe.TF_GetBuffer(buffer_).decode("utf-8")
+    return address_space
+
   # TODO(fishx): remove this property.
   @property
   def execution_mode(self):
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index cd2ae83..07b07ab 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -449,6 +449,11 @@
     A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
     if `func` returns None.
   """
+  if ops.executing_eagerly_outside_functions():
+    with ops.device(context.context().host_address_space()):
+      return _internal_py_func(
+          func=func, inp=inp, Tout=Tout, eager=True, name=name)
+
   return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name)
 
 
@@ -518,6 +523,16 @@
       result, = result
     return result
 
+  if ops.executing_eagerly_outside_functions():
+    with ops.device(context.context().host_address_space()):
+      return _internal_py_func(
+          func=func,
+          inp=inp,
+          Tout=Tout,
+          stateful=stateful,
+          eager=False,
+          name=name)
+
   return _internal_py_func(
       func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)
 
diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc
index 2841597..9de5a19 100644
--- a/tensorflow/python/tfe_wrapper.cc
+++ b/tensorflow/python/tfe_wrapper.cc
@@ -364,6 +364,9 @@
         return output;
       },
       py::return_value_policy::reference);
+  m.def("TFE_HostAddressSpace", [](py::handle& o, TF_Buffer& buf) {
+    TFE_HostAddressSpace(tensorflow::InputTFE_Context(o), &buf);
+  });
   m.def("TFE_ContextAddFunction", [](py::handle& ctx, py::handle& func) {
     tensorflow::Safe_TF_StatusPtr status =
         tensorflow::make_safe(TF_NewStatus());