Place all py_func op on the local host's address space if eager execution is enabled.
PiperOrigin-RevId: 290993424
Change-Id: I0c33cdf781fa4b3c401ea5e8649f606137e42862
diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc
index ac160d4..5b3bd6a 100644
--- a/tensorflow/c/eager/c_api_experimental.cc
+++ b/tensorflow/c/eager/c_api_experimental.cc
@@ -18,6 +18,7 @@
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
@@ -618,3 +619,16 @@
TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
return new TFE_Executor(&ctx->context->Executor());
}
+
+void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
+ auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
+ ctx->context->HostCPU()->parsed_name());
+ auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
+ void* data = tensorflow::port::Malloc(str.length());
+ str.copy(static_cast<char*>(data), str.length(), 0);
+ buf->data = data;
+ buf->length = str.length();
+ buf->data_deallocator = [](void* data, size_t length) {
+ tensorflow::port::Free(data);
+ };
+}
diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h
index 0937258..92132b0 100644
--- a/tensorflow/c/eager/c_api_experimental.h
+++ b/tensorflow/c/eager/c_api_experimental.h
@@ -458,6 +458,11 @@
void (*deallocator)(void* data, size_t len, void* arg),
void* deallocator_arg, TF_Status* status);
+// Retrieves the address space (i.e. job, replia, task) of the local host and
+// saves it in the buffer.
+TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
+ TF_Buffer* buf);
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index e0fb805..b580a55 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -785,6 +785,13 @@
"""List of the names of devices available to execute operations."""
return self._devices
+ def host_address_space(self):
+ self.ensure_initialized()
+ with c_api_util.tf_buffer() as buffer_:
+ pywrap_tfe.TFE_HostAddressSpace(self._context_handle, buffer_)
+ address_space = pywrap_tfe.TF_GetBuffer(buffer_).decode("utf-8")
+ return address_space
+
# TODO(fishx): remove this property.
@property
def execution_mode(self):
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index cd2ae83..07b07ab 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -449,6 +449,11 @@
A list of `Tensor` or a single `Tensor` which `func` computes; an empty list
if `func` returns None.
"""
+ if ops.executing_eagerly_outside_functions():
+ with ops.device(context.context().host_address_space()):
+ return _internal_py_func(
+ func=func, inp=inp, Tout=Tout, eager=True, name=name)
+
return _internal_py_func(func=func, inp=inp, Tout=Tout, eager=True, name=name)
@@ -518,6 +523,16 @@
result, = result
return result
+ if ops.executing_eagerly_outside_functions():
+ with ops.device(context.context().host_address_space()):
+ return _internal_py_func(
+ func=func,
+ inp=inp,
+ Tout=Tout,
+ stateful=stateful,
+ eager=False,
+ name=name)
+
return _internal_py_func(
func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name)
diff --git a/tensorflow/python/tfe_wrapper.cc b/tensorflow/python/tfe_wrapper.cc
index 2841597..9de5a19 100644
--- a/tensorflow/python/tfe_wrapper.cc
+++ b/tensorflow/python/tfe_wrapper.cc
@@ -364,6 +364,9 @@
return output;
},
py::return_value_policy::reference);
+ m.def("TFE_HostAddressSpace", [](py::handle& o, TF_Buffer& buf) {
+ TFE_HostAddressSpace(tensorflow::InputTFE_Context(o), &buf);
+ });
m.def("TFE_ContextAddFunction", [](py::handle& ctx, py::handle& func) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());