| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
Licensed under the Apache License, Version 2.0 (the "License");
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include <memory> |
| |
| #include "Python.h" |
| #include "absl/strings/str_format.h" |
| #include "pybind11/chrono.h" |
| #include "pybind11/complex.h" |
| #include "pybind11/functional.h" |
| #include "pybind11/pybind11.h" |
| #include "pybind11/stl.h" |
| #include "tensorflow/c/c_api.h" |
| #include "tensorflow/c/c_api_experimental.h" |
| #include "tensorflow/c/eager/c_api.h" |
| #include "tensorflow/c/eager/c_api_experimental.h" |
| #include "tensorflow/c/eager/c_api_internal.h" |
| #include "tensorflow/c/eager/dlpack.h" |
| #include "tensorflow/c/tf_status.h" |
| #include "tensorflow/c/tf_status_helper.h" |
| #include "tensorflow/compiler/jit/flags.h" |
| #include "tensorflow/python/eager/pywrap_tensor_conversion.h" |
| #include "tensorflow/python/eager/pywrap_tfe.h" |
| #include "tensorflow/python/lib/core/py_exception_registry.h" |
| #include "tensorflow/python/lib/core/pybind11_lib.h" |
| #include "tensorflow/python/lib/core/pybind11_status.h" |
| #include "tensorflow/python/lib/core/safe_ptr.h" |
| #include "tensorflow/python/util/util.h" |
| |
| namespace py = pybind11; |
| |
| PYBIND11_MAKE_OPAQUE(TFE_Executor); |
| PYBIND11_MAKE_OPAQUE(TFE_ContextOptions); |
| PYBIND11_MAKE_OPAQUE(TFE_CancellationManager); |
| |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter0); |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter1); |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter2); |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge0); |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge1); |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGauge2); |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge0); |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge1); |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGauge2); |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge0); |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge1); |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGauge2); |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler0); |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler1); |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringSampler2); |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounterCell); |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringIntGaugeCell); |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringStringGaugeCell); |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringBoolGaugeCell); |
| PYBIND11_MAKE_OPAQUE(TFE_MonitoringSamplerCell); |
| |
| PYBIND11_MAKE_OPAQUE(TF_DeviceList); |
| PYBIND11_MAKE_OPAQUE(TF_Function); |
| PYBIND11_MAKE_OPAQUE(TF_Buffer); |
| |
| // Eager helper functions migrated from pywrap_tfe.i. |
| |
| namespace tensorflow { |
| |
// We cannot use Context as an opaque type. SWIG also had
// difficulty passing the pointer around directly. These
// typemaps are migrated over from pywrap_tfe.i. I tried
// using a custom type caster, but it caused intermittent segfaults.
| |
| // TODO(amitpatankar): Move input and output logic of Context into a |
| // pybind11 custom type caster. |
| |
| TFE_Context* InputTFE_Context(const py::handle& ctx) { |
| return static_cast<TFE_Context*>(PyCapsule_GetPointer(ctx.ptr(), nullptr)); |
| } |
| |
| PyObject* OutputTFE_Context(TFE_Context* context) { |
| return PyCapsule_New(context, nullptr, TFE_DeleteContextCapsule); |
| } |
| |
| TF_Buffer* ProtoStringToTFBuffer(PyObject* input) { |
| // Convert a Python string object to TF_Buffer. |
| char* c_string; |
| Py_ssize_t py_size; |
| // PyBytes_AsStringAndSize() does not copy but simply interprets the input |
| if (PyBytes_AsStringAndSize(input, &c_string, &py_size) == -1) { |
| // Python has raised an error (likely TypeError or UnicodeEncodeError). |
| throw py::error_already_set(); |
| } |
| return TF_NewBufferFromString(static_cast<void*>(c_string), |
| static_cast<size_t>(py_size)); |
| } |
| |
| // These functions are typemaps from the Python side. I did not use |
| // a custom type caster since the logic is slightly harder to follow. This |
| // converter is also only used once in `TFE_Py_ExecuteCancelable_wrapper`. |
| TFE_InputTensorHandles InputTFE_InputTensorHandles( |
| const py::handle& input_tensors) { |
| TFE_InputTensorHandles input_tensor_handles; |
| if (input_tensors.ptr() != Py_None) { |
| if (!PyList_Check(input_tensors.ptr())) { |
| tensorflow::ThrowTypeError("must provide a list of Tensors as inputs"); |
| } |
| Py_ssize_t len = PyList_Size(input_tensors.ptr()); |
| input_tensor_handles.resize(len); |
| for (Py_ssize_t i = 0; i < len; ++i) { |
| PyObject* elem = PyList_GetItem(input_tensors.ptr(), i); |
| if (!elem) { |
| tensorflow::ThrowTypeError("Input Tensor does not exist."); |
| } |
| if (EagerTensor_CheckExact(elem)) { |
| (input_tensor_handles)[i] = EagerTensor_Handle(elem); |
| } else if (tensorflow::swig::IsEagerTensorSlow(elem)) { |
| // Use equivalent of object.__getattribute__ to get the underlying |
| // tf wrapped EagerTensor (if there is one). |
| tensorflow::Safe_PyObjectPtr tf_should_use_attr( |
| #if PY_MAJOR_VERSION < 3 |
| PyString_InternFromString("_tf_should_use_wrapped_value") |
| #else |
| PyUnicode_InternFromString("_tf_should_use_wrapped_value") |
| #endif |
| ); |
| tensorflow::Safe_PyObjectPtr value_attr( |
| PyObject_GenericGetAttr(elem, tf_should_use_attr.get())); |
| if (value_attr) { |
| // This is an EagerTensor wrapped inside a TFShouldUse wrapped object. |
| (input_tensor_handles)[i] = EagerTensor_Handle(value_attr.get()); |
| } else { |
| // This is a subclass of EagerTensor that we don't support. |
| PyErr_Clear(); |
| tensorflow::ThrowTypeError( |
| tensorflow::strings::StrCat( |
| "Saw an object that is an instance of a strict subclass of " |
| "EagerTensor, which is not supported. Item ", |
| i, " is type: ", elem->ob_type->tp_name) |
| .c_str()); |
| } |
| } else if (tensorflow::swig::IsTensor(elem)) { |
| // If it isnt an EagerTensor, but is still a Tensor, it must be a graph |
| // tensor. |
| tensorflow::Safe_PyObjectPtr name_attr( |
| PyObject_GetAttrString(elem, "name")); |
| tensorflow::ThrowTypeError( |
| tensorflow::strings::StrCat( |
| "An op outside of the function building code is being passed\n" |
| "a \"Graph\" tensor. It is possible to have Graph tensors\n" |
| "leak out of the function building context by including a\n" |
| "tf.init_scope in your function building code.\n" |
| "For example, the following function will fail:\n", |
| " @tf.function\n", " def has_init_scope():\n", |
| " my_constant = tf.constant(1.)\n", |
| " with tf.init_scope():\n", |
| " added = my_constant * 2\n", |
| "The graph tensor has name: ", |
| name_attr ? TFE_GetPythonString(name_attr.get()) : "<unknown>") |
| .c_str()); |
| } else { |
| tensorflow::ThrowTypeError( |
| tensorflow::strings::StrCat( |
| "provided list of inputs contains objects other " |
| "than 'EagerTensor'. Item ", |
| i, " is type: ", elem->ob_type->tp_name) |
| .c_str()); |
| } |
| } |
| } |
| return input_tensor_handles; |
| } |
| |
| // These functions are typemaps from the Python side. I did not use |
| // a custom type caster since the logic is slightly harder to follow. This |
| // converter is also only used once in `TFE_Py_ExecuteCancelable_wrapper`. |
| // This function actually takes a number rather than an output Tensor holder. |
| TFE_OutputTensorHandles InputTFE_OutputTensorHandles( |
| const py::handle& num_outputs) { |
| TFE_OutputTensorHandles output_tensor_handles; |
| #if PY_MAJOR_VERSION < 3 |
| if (!PyInt_Check(num_outputs.ptr())) { |
| #else |
| if (!PyLong_Check(num_outputs.ptr())) { |
| #endif |
| PyErr_SetString(PyExc_TypeError, |
| "expected an integer value (size of the number of " |
| "outputs of the operation)"); |
| throw py::error_already_set(); |
| } |
| #if PY_MAJOR_VERSION < 3 |
| long sz = PyInt_AsLong(num_outputs.ptr()); // NOLINT |
| #else |
| long sz = PyLong_AsLong(num_outputs.ptr()); // NOLINT |
| #endif |
| if (sz > 0) { |
| #if PY_MAJOR_VERSION < 3 |
| output_tensor_handles.resize(PyInt_AsLong(num_outputs.ptr()), nullptr); |
| #else |
| output_tensor_handles.resize(PyLong_AsLong(num_outputs.ptr()), nullptr); |
| #endif |
| } |
| return output_tensor_handles; |
| } |
| |
| // Packs multiple `EagerTensor`s of the same dtype and shape into one |
| // `EagerTensor`. |
| py::object TFE_Py_PackEagerTensors_wrapper(const py::handle& context, |
| const py::handle& tensors) { |
| TFE_Context* ctx = tensorflow::InputTFE_Context(context); |
| TFE_InputTensorHandles handles = InputTFE_InputTensorHandles(tensors); |
| tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); |
| int size = handles.size(); |
| TFE_TensorHandle* packed_handle = |
| TFE_CreatePackedTensorHandle(ctx, handles.data(), &size, status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| PyObject* packed_tensor = |
| EagerTensorFromHandle(packed_handle, /*is_packed=*/true); |
| return tensorflow::PyoOrThrow(packed_tensor); |
| } |
| |
| // This function was created from fusing the typemap logic in platform/base.i. |
| py::object TFE_Py_ExecuteCancelable_wrapper( |
| const py::handle& context, const char* device_name, const char* op_name, |
| const py::handle& inputs, const py::handle& attrs, |
| TFE_CancellationManager* cancellation_manager, |
| const py::handle& num_outputs) { |
| TFE_Context* ctx = tensorflow::InputTFE_Context(context); |
| TFE_InputTensorHandles input_tensor_handles = |
| InputTFE_InputTensorHandles(inputs); |
| TFE_OutputTensorHandles output_tensor_handles = |
| InputTFE_OutputTensorHandles(num_outputs); |
| tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); |
| TFE_Py_ExecuteCancelable(ctx, device_name, op_name, &input_tensor_handles, |
| attrs.ptr(), cancellation_manager, |
| &output_tensor_handles, status.get()); |
| |
| int output_len = output_tensor_handles.size(); |
| PyObject* output_list = PyList_New(output_len); |
| for (int i = 0; i < output_len; ++i) { |
| PyObject* output; |
| output = EagerTensorFromHandle(output_tensor_handles.at(i)); |
| PyList_SetItem(output_list, i, output); |
| } |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return tensorflow::PyoOrThrow(output_list); |
| } |
| |
| static py::object TF_ListPhysicalDevices() { |
| std::vector<string> devices; |
| tensorflow::Status s = |
| tensorflow::DeviceFactory::ListAllPhysicalDevices(&devices); |
| MaybeRaiseRegisteredFromStatus(s); |
| PyObject* result = PyList_New(devices.size()); |
| int i = 0; |
| for (auto& dev : devices) { |
| PyObject* dev_obj = PyBytes_FromStringAndSize(dev.data(), dev.size()); |
| PyList_SetItem(result, i, dev_obj); |
| ++i; |
| } |
| return tensorflow::PyoOrThrow(result); |
| } |
| |
| static std::unordered_map<string, string> TF_GetDeviceDetails(int index) { |
| tensorflow::Safe_TF_StatusPtr status = tensorflow::make_safe(TF_NewStatus()); |
| std::unordered_map<string, string> device_details; |
| tensorflow::Status s = |
| tensorflow::DeviceFactory::GetAnyDeviceDetails(index, &device_details); |
| tensorflow::Set_TF_Status_from_Status(status.get(), s); |
| MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return device_details; |
| } |
| |
| static py::object TFE_ClearScalarCache() { |
| tensorflow::TFE_TensorHandleCache::Get()->Clear(); |
| return py::none(); |
| } |
| |
| } // namespace tensorflow |
| |
// py::return_value_policy::reference behaves as described in the pybind11
// documentation:
// https://pybind11.readthedocs.io/en/stable/advanced/functions.html#return-value-policies
// With this policy C++ keeps ownership of the returned object, so we only
// apply it to functions that return opaque types.
| |
| PYBIND11_MODULE(_pywrap_tfe, m) { |
| py::class_<TFE_Executor> TFE_Executor_class(m, "TFE_Executor"); |
| py::class_<TFE_ContextOptions> TFE_ContextOptions_class(m, |
| "TFE_ContextOptions"); |
| py::class_<TFE_MonitoringCounter0> TFE_MonitoringCounter0_class( |
| m, "TFE_MonitoringCounter0"); |
| py::class_<TFE_MonitoringCounter1> TFE_MonitoringCounter1_class( |
| m, "TFE_MonitoringCounter1"); |
| py::class_<TFE_MonitoringCounter2> TFE_MonitoringCounter2_class( |
| m, "TFE_MonitoringCounter2"); |
| py::class_<TFE_MonitoringStringGauge0> TFE_MonitoringStringGauge0_class( |
| m, "TFE_MonitoringStringGauge0"); |
| py::class_<TFE_MonitoringStringGauge1> TFE_MonitoringStringGauge1_class( |
| m, "TFE_MonitoringStringGauge1"); |
| py::class_<TFE_MonitoringStringGauge2> TFE_MonitoringStringGauge2_class( |
| m, "TFE_MonitoringStringGauge2"); |
| py::class_<TFE_MonitoringIntGauge0> TFE_MonitoringIntGauge0_class( |
| m, "TFE_MonitoringIntGauge0"); |
| py::class_<TFE_MonitoringIntGauge1> TFE_MonitoringIntGauge1_class( |
| m, "TFE_MonitoringIntGauge1"); |
| py::class_<TFE_MonitoringIntGauge2> TFE_MonitoringIntGauge2_class( |
| m, "TFE_MonitoringIntGauge2"); |
| py::class_<TFE_MonitoringBoolGauge0> TFE_MonitoringBoolGauge0_class( |
| m, "TFE_MonitoringBoolGauge0"); |
| py::class_<TFE_MonitoringBoolGauge1> TFE_MonitoringBoolGauge1_class( |
| m, "TFE_MonitoringBoolGauge1"); |
| py::class_<TFE_MonitoringBoolGauge2> TFE_MonitoringBoolGauge2_class( |
| m, "TFE_MonitoringBoolGauge2"); |
| py::class_<TFE_MonitoringCounterCell> TFE_MonitoringCounterCell_class( |
| m, "TFE_MonitoringCounterCell"); |
| py::class_<TFE_MonitoringIntGaugeCell> TFE_MonitoringIntGaugeCell_class( |
| m, "TFE_MonitoringIntGaugeCell"); |
| py::class_<TFE_MonitoringStringGaugeCell> TFE_MonitoringStringGaugeCell_class( |
| m, "TFE_MonitoringStringGaugeCell"); |
| py::class_<TFE_MonitoringBoolGaugeCell> TFE_MonitoringBoolGaugeCell_class( |
| m, "TFE_MonitoringBoolGaugeCell"); |
| py::class_<TFE_MonitoringSamplerCell> TFE_MonitoringSamplerCell_class( |
| m, "TFE_MonitoringSamplerCell"); |
| py::class_<TFE_MonitoringBuckets> TFE_MonitoringBuckets_class( |
| m, "TFE_MonitoringBuckets"); |
| py::class_<TFE_MonitoringSampler0> TFE_MonitoringSampler0_class( |
| m, "TFE_MonitoringSampler0"); |
| py::class_<TFE_MonitoringSampler1> TFE_MonitoringSampler1_class( |
| m, "TFE_MonitoringSampler1"); |
| py::class_<TFE_MonitoringSampler2> TFE_MonitoringSampler2_class( |
| m, "TFE_MonitoringSampler2"); |
| py::class_<TFE_CancellationManager> TFE_CancellationManager_class( |
| m, "TFE_CancellationManager"); |
| |
| py::class_<TF_DeviceList> TF_DeviceList_class(m, "TF_DeviceList"); |
| py::class_<TF_Function> TF_Function_class(m, "TF_Function"); |
| |
| m.def("TFE_Py_RegisterExceptionClass", [](const py::handle& e) { |
| return tensorflow::PyoOrThrow(TFE_Py_RegisterExceptionClass(e.ptr())); |
| }); |
| m.def("TFE_Py_RegisterFallbackExceptionClass", [](const py::handle& e) { |
| return tensorflow::PyoOrThrow( |
| TFE_Py_RegisterFallbackExceptionClass(e.ptr())); |
| }); |
| |
| m.def( |
| "TFE_GetTotalMemoryUsage", [](py::handle& ctx, const char* device_name) { |
| tensorflow::EagerContext* context = tensorflow::ContextFromInterface( |
| reinterpret_cast<tensorflow::ImmediateExecutionContext*>( |
| tensorflow::InputTFE_Context(ctx))); |
| |
| tensorflow::DeviceNameUtils::ParsedName input_device_name; |
| if (!tensorflow::DeviceNameUtils::ParseFullName(device_name, |
| &input_device_name) && |
| !tensorflow::DeviceNameUtils::ParseLocalName(device_name, |
| &input_device_name)) { |
| tensorflow::ThrowValueError( |
| absl::StrFormat("Failed parsing device name: '%s'", device_name) |
| .c_str()); |
| } |
| |
| std::vector<tensorflow::Device*> devices = |
| context->local_device_mgr()->ListDevices(); |
| |
| tensorflow::Device* matched_device = nullptr; |
| for (int device_idx = 0; device_idx < devices.size(); device_idx++) { |
| tensorflow::Device* device = devices[device_idx]; |
| |
| if (tensorflow::DeviceNameUtils::AreCompatibleDevNames( |
| input_device_name, device->parsed_name())) { |
| if (device->device_type() == tensorflow::DEVICE_CPU) { |
| tensorflow::ThrowValueError( |
| "CPU does not support getting allocator information"); |
| } |
| |
| if (matched_device != nullptr) { |
| tensorflow::ThrowValueError( |
| absl::StrFormat( |
| "Multiple devices matching the provided string " |
| "'%s': '%s' and " |
| "'%s' ", |
| device_name, matched_device->name(), device->name()) |
| .c_str()); |
| } |
| matched_device = device; |
| } |
| } |
| |
| if (matched_device == nullptr) { |
| tensorflow::ThrowValueError( |
| absl::StrFormat("No matching devices found for '%s'", device_name) |
| .c_str()); |
| } |
| |
| tensorflow::AllocatorAttributes attrs; |
| tensorflow::Allocator* allocator = matched_device->GetAllocator(attrs); |
| |
| if (absl::optional<tensorflow::AllocatorStats> stats = |
| allocator->GetStats()) { |
| return stats->bytes_in_use; |
| } |
| |
| tensorflow::ThrowTypeError( |
| absl::StrFormat("Allocator stats not available for device '%s'", |
| matched_device->name()) |
| .c_str()); |
| }); |
| |
| // XLA Eager Logic |
| m.def("TF_SetXlaEnableLazyCompilation", &TF_SetXlaEnableLazyCompilation); |
| m.def("TF_SetTfXlaCpuGlobalJit", &TF_SetTfXlaCpuGlobalJit); |
| m.def("TF_SetXlaAutoJitMode", &TF_SetXlaAutoJitMode); |
| m.def("TF_SetXlaConstantFoldingDisabled", &TF_SetXlaConstantFoldingDisabled); |
| m.def("TF_GetXlaConstantFoldingDisabled", &TF_GetXlaConstantFoldingDisabled); |
| m.def("TF_SetXlaMinClusterSize", &TF_SetXlaMinClusterSize); |
| |
| // MLIR Logic |
| m.def("TF_IsMlirBridgeEnabled", [] { |
| return tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge; |
| }); |
| m.def("TF_EnableMlirBridge", [](bool enabled) { |
| tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = enabled; |
| }); |
| m.def("TF_EnableXlaDevices", [] { |
| tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true; |
| }); |
| |
| // // TFE_Context Logic |
| m.def( |
| "TFE_NewContext", |
| [](const TFE_ContextOptions* opts) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| TFE_Context* context = TFE_NewContext(opts, status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return tensorflow::PyoOrThrow(tensorflow::OutputTFE_Context(context)); |
| }, |
| py::return_value_policy::reference); |
| m.def("TFE_DeleteContext", [](py::handle& o) { |
| TFE_DeleteContext(tensorflow::InputTFE_Context(o)); |
| }); |
| m.def( |
| "TFE_ContextListDevices", |
| [](py::handle& o) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = TFE_ContextListDevices(tensorflow::InputTFE_Context(o), |
| status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }, |
| py::return_value_policy::reference); |
| m.def("TFE_HostAddressSpace", [](py::handle& o, TF_Buffer& buf) { |
| TFE_HostAddressSpace(tensorflow::InputTFE_Context(o), &buf); |
| }); |
| m.def("TFE_ContextAddFunction", [](py::handle& ctx, TF_Function* func) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| TFE_ContextAddFunction(tensorflow::InputTFE_Context(ctx), func, |
| status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| }); |
| m.def("TFE_ContextAddFunctionDef", |
| [](py::handle& ctx, const char* serialized_function_def, size_t size) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| TFE_ContextAddFunctionDef(tensorflow::InputTFE_Context(ctx), |
| serialized_function_def, size, |
| status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| }); |
| m.def("TFE_ContextGetFunctionDef", |
| [](py::handle& ctx, const char* function_name, TF_Buffer& buf) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| TFE_ContextGetFunctionDef(tensorflow::InputTFE_Context(ctx), |
| function_name, &buf, status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| }); |
| m.def("TFE_ContextRemoveFunction", [](py::handle& ctx, const char* name) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| TFE_ContextRemoveFunction(tensorflow::InputTFE_Context(ctx), name, |
| status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| }); |
| m.def("TFE_ContextHasFunction", [](py::handle& ctx, const char* name) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = |
| TFE_ContextHasFunction(tensorflow::InputTFE_Context(ctx), name); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }); |
| m.def("TFE_ContextEnableRunMetadata", [](py::handle& ctx) { |
| TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx)); |
| }); |
| m.def("TFE_ContextDisableRunMetadata", [](py::handle& ctx) { |
| TFE_ContextEnableRunMetadata(tensorflow::InputTFE_Context(ctx)); |
| }); |
| m.def("TFE_ContextEnableGraphCollection", [](py::handle& ctx) { |
| TFE_ContextEnableGraphCollection(tensorflow::InputTFE_Context(ctx)); |
| }); |
| m.def("TFE_ContextDisableGraphCollection", [](py::handle& ctx) { |
| TFE_ContextDisableGraphCollection(tensorflow::InputTFE_Context(ctx)); |
| }); |
| m.def("TFE_ContextExportRunMetadata", [](py::handle& ctx, TF_Buffer& buf) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| TFE_ContextExportRunMetadata(tensorflow::InputTFE_Context(ctx), &buf, |
| status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| }); |
| m.def("TFE_ContextClearCaches", [](py::handle& o) { |
| TFE_ContextClearCaches(tensorflow::InputTFE_Context(o)); |
| }); |
| m.def("TFE_GetContextId", [](py::handle& ctx) { |
| return TFE_GetContextId(tensorflow::InputTFE_Context(ctx)); |
| }); |
| m.def("TFE_ContextGetDevicePlacementPolicy", [](py::handle& ctx) { |
| return TFE_ContextGetDevicePlacementPolicy( |
| tensorflow::InputTFE_Context(ctx)); |
| }); |
| m.def("TFE_ContextSetThreadLocalDevicePlacementPolicy", |
| [](py::handle& ctx, TFE_ContextDevicePlacementPolicy policy) { |
| TFE_ContextSetThreadLocalDevicePlacementPolicy( |
| tensorflow::InputTFE_Context(ctx), policy); |
| }); |
| m.def("TFE_ContextSetServerDef", [](py::handle& ctx, int keep_alive_secs, |
| py::bytes proto) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| tensorflow::Safe_TF_BufferPtr buf = |
| tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr())); |
| TFE_ContextSetServerDef(tensorflow::InputTFE_Context(ctx), keep_alive_secs, |
| buf.get()->data, buf.get()->length, status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| }); |
| m.def("TFE_ContextUpdateServerDef", [](py::handle& ctx, int keep_alive_secs, |
| py::bytes proto) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| tensorflow::Safe_TF_BufferPtr buf = |
| tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr())); |
| Py_BEGIN_ALLOW_THREADS; |
| TFE_ContextUpdateServerDef(tensorflow::InputTFE_Context(ctx), |
| keep_alive_secs, buf.get()->data, |
| buf.get()->length, status.get()); |
| Py_END_ALLOW_THREADS; |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| }); |
| m.def("TFE_ContextCheckAlive", [](py::handle& ctx, const char* worker_name) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| bool output = TFE_ContextCheckAlive(tensorflow::InputTFE_Context(ctx), |
| worker_name, status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }); |
| m.def("TFE_ContextSyncExecutors", [](py::handle& ctx) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| }); |
| m.def("TFE_ContextClearExecutors", [](py::handle& ctx) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| TFE_ContextAsyncWait(tensorflow::InputTFE_Context(ctx), status.get()); |
| // NOTE: different from TFE_ContextSyncExecutors that raises potential |
| // errors, deliberately ignore executor statuses in cleanup. |
| }); |
| m.def("TFE_ContextSetSoftDevicePlacement", [](py::handle& ctx, bool enable) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable, |
| status.get()); |
| }); |
| m.def("TFE_ContextSetLogDevicePlacement", [](py::handle& ctx, bool enable) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| TFE_ContextSetSoftDevicePlacement(tensorflow::InputTFE_Context(ctx), enable, |
| status.get()); |
| }); |
| |
| // TFE_Executor logic |
| m.def( |
| "TFE_NewExecutor", |
| [](const bool is_async) { |
| TFE_Executor* exc = TFE_NewExecutor(is_async); |
| return exc; |
| }, |
| py::return_value_policy::reference); |
| m.def("TFE_DeleteExecutor", &TFE_DeleteExecutor); |
| m.def("TFE_ExecutorIsAsync", &TFE_ExecutorIsAsync); |
| m.def("TFE_ExecutorWaitForAllPendingNodes", [](TFE_Executor& exc) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| // NOTE: release Python GIL for pending PyFunc ops to be executed properly. |
| Py_BEGIN_ALLOW_THREADS; |
| TFE_ExecutorWaitForAllPendingNodes(&exc, status.get()); |
| Py_END_ALLOW_THREADS; |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| }); |
| m.def("TFE_ExecutorClearError", &TFE_ExecutorClearError); |
| m.def("TFE_ContextSetExecutorForThread", [](py::handle& ctx, |
| TFE_Executor& exc) { |
| TFE_ContextSetExecutorForThread(tensorflow::InputTFE_Context(ctx), &exc); |
| }); |
| m.def( |
| "TFE_ContextGetExecutorForThread", |
| [](py::handle& o) { |
| return TFE_ContextGetExecutorForThread(tensorflow::InputTFE_Context(o)); |
| }, |
| py::return_value_policy::reference); |
| |
| m.def("TFE_OpNameGetAttrType", |
| [](py::handle& ctx, const char* op_or_function_name, |
| const char* attr_name) { |
| int temp = 0; |
| unsigned char* is_list = reinterpret_cast<unsigned char*>(&temp); |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = TFE_OpNameGetAttrType(tensorflow::InputTFE_Context(ctx), |
| op_or_function_name, attr_name, |
| is_list, status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| #if PY_MAJOR_VERSION < 3 |
| PyObject* output_pyo = PyInt_FromLong(output); |
| #else |
| PyObject* output_pyo = PyLong_FromLong(output); |
| #endif |
| if (*is_list == 1) { |
| PyObject* list = PyList_New(1); |
| PyList_SetItem(list, 0, output_pyo); |
| return tensorflow::PyoOrThrow(list); |
| } |
| return tensorflow::PyoOrThrow(output_pyo); |
| }); |
| m.def("TFE_Py_InitEagerTensor", [](const py::handle& o) { |
| return tensorflow::PyoOrThrow(TFE_Py_InitEagerTensor(o.ptr())); |
| }); |
| m.def("TFE_Py_PackEagerTensors", |
| [](const py::handle& context, const py::handle& handles) { |
| return tensorflow::TFE_Py_PackEagerTensors_wrapper(context, handles); |
| }); |
| m.def("TFE_Py_SetEagerTensorProfiler", &TFE_Py_SetEagerTensorProfiler); |
| m.def("TFE_Py_RegisterJVPFunction", [](const py::handle& o) { |
| return tensorflow::PyoOrThrow(TFE_Py_RegisterJVPFunction(o.ptr())); |
| }); |
| m.def("TFE_Py_RegisterGradientFunction", [](const py::handle& o) { |
| return tensorflow::PyoOrThrow(TFE_Py_RegisterGradientFunction(o.ptr())); |
| }); |
| m.def("TFE_Py_Execute", |
| [](const py::handle& context, const char* device_name, |
| const char* op_name, const py::handle& inputs, |
| const py::handle& attrs, const py::handle& num_outputs) { |
| return tensorflow::TFE_Py_ExecuteCancelable_wrapper( |
| context, device_name, op_name, inputs, attrs.ptr(), nullptr, |
| num_outputs); |
| }); |
| m.def( |
| "TFE_Py_ExecuteCancelable", |
| [](const py::handle& context, const char* device_name, |
| const char* op_name, const py::handle& inputs, const py::handle& attrs, |
| TFE_CancellationManager& cancellation_manager, |
| const py::handle& num_outputs) { |
| return tensorflow::TFE_Py_ExecuteCancelable_wrapper( |
| context, device_name, op_name, inputs, attrs.ptr(), |
| &cancellation_manager, num_outputs); |
| }); |
| m.def("TFE_Py_FastPathExecute", [](const py::args args) { |
| // TFE_Py_FastPathExecute requires error checking prior to returning. |
| return tensorflow::PyoOrThrow(TFE_Py_FastPathExecute_C(args.ptr())); |
| }); |
| m.def("TFE_Py_RecordGradient", |
| [](const py::handle& op_name, const py::handle& inputs, |
| const py::handle& attrs, const py::handle& results, |
| const py::handle& forward_pass_name_scope) { |
| return tensorflow::PyoOrThrow(TFE_Py_RecordGradient( |
| op_name.ptr(), inputs.ptr(), attrs.ptr(), results.ptr(), |
| forward_pass_name_scope.ptr())); |
| }); |
| m.def("TFE_Py_UID", []() { return tensorflow::PyoOrThrow(TFE_Py_UID()); }); |
| |
| // TFE_Py_Tape Logic |
| m.def("TFE_Py_TapeSetNew", [](const py::handle& persistent, |
| const py::handle& watch_accessed_variables) { |
| return tensorflow::PyoOrThrow( |
| TFE_Py_TapeSetNew(persistent.ptr(), watch_accessed_variables.ptr())); |
| }); |
| m.def("TFE_Py_TapeSetAdd", |
| [](const py::handle& tape) { TFE_Py_TapeSetAdd(tape.ptr()); }); |
| m.def("TFE_Py_TapeSetRemove", |
| [](const py::handle& tape) { TFE_Py_TapeSetRemove(tape.ptr()); }); |
| m.def("TFE_Py_TapeSetStopOnThread", &TFE_Py_TapeSetStopOnThread); |
| m.def("TFE_Py_TapeSetRestartOnThread", &TFE_Py_TapeSetRestartOnThread); |
| m.def("TFE_Py_TapeSetIsStopped", |
| []() { return tensorflow::PyoOrThrow(TFE_Py_TapeSetIsStopped()); }); |
| m.def("TFE_Py_TapeSetIsEmpty", |
| []() { return tensorflow::PyoOrThrow(TFE_Py_TapeSetIsEmpty()); }); |
| m.def("TFE_Py_TapeSetShouldRecordBackprop", [](const py::handle& tensors) { |
| return tensorflow::PyoOrThrow( |
| TFE_Py_TapeSetShouldRecordBackprop(tensors.ptr())); |
| }); |
| m.def("TFE_Py_TapeSetPossibleGradientTypes", [](const py::handle& tensors) { |
| return tensorflow::PyoOrThrow( |
| TFE_Py_TapeSetPossibleGradientTypes(tensors.ptr())); |
| }); |
| m.def("TFE_Py_TapeSetDeleteTrace", &TFE_Py_TapeSetDeleteTrace); |
| m.def("TFE_Py_TapeSetRecordOperation", |
| [](const py::handle& op_type, const py::handle& output_tensors, |
| const py::handle& input_tensors, const py::handle& backward_function, |
| const py::handle& forward_function) { |
| return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperation( |
| op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(), |
| backward_function.ptr(), forward_function.ptr())); |
| }); |
| m.def( |
| "TFE_Py_TapeSetRecordOperationBackprop", |
| [](const py::handle& op_type, const py::handle& output_tensors, |
| const py::handle& input_tensors, const py::handle& backward_function) { |
| return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperationBackprop( |
| op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(), |
| backward_function.ptr())); |
| }); |
| m.def( |
| "TFE_Py_TapeSetRecordOperationForwardprop", |
| [](const py::handle& op_type, const py::handle& output_tensors, |
| const py::handle& input_tensors, const py::handle& backward_function, |
| const py::handle& forwardprop_output_indices) { |
| return tensorflow::PyoOrThrow(TFE_Py_TapeSetRecordOperationForwardprop( |
| op_type.ptr(), output_tensors.ptr(), input_tensors.ptr(), |
| backward_function.ptr(), forwardprop_output_indices.ptr())); |
| }); |
| m.def("TFE_Py_TapeGradient", |
| [](const py::handle& tape, const py::handle& target, |
| const py::handle& sources, const py::handle& output_gradients, |
| const py::handle& sources_raw, |
| const py::handle& unconnected_gradients) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| PyObject* output = TFE_Py_TapeGradient( |
| tape.ptr(), target.ptr(), sources.ptr(), output_gradients.ptr(), |
| sources_raw.ptr(), unconnected_gradients.ptr(), status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return tensorflow::PyoOrThrow(output); |
| }); |
| |
| m.def("TFE_Py_TapeVariableAccessed", [](const py::handle& variable) { |
| TFE_Py_TapeVariableAccessed(variable.ptr()); |
| }); |
| m.def("TFE_Py_TapeWatch", |
| [](const py::handle& tape, const py::handle& tensor) { |
| TFE_Py_TapeWatch(tape.ptr(), tensor.ptr()); |
| }); |
| m.def("TFE_Py_TapeWatchVariable", |
| [](const py::handle& tape, const py::handle& variable) { |
| TFE_Py_TapeWatchVariable(tape.ptr(), variable.ptr()); |
| }); |
| m.def("TFE_Py_TapeWatchedVariables", [](const py::handle& tape) { |
| return tensorflow::PyoOrThrow(TFE_Py_TapeWatchedVariables(tape.ptr())); |
| }); |
| |
| // TFE_Py_VariableWatcher logic. |
| m.def("TFE_Py_VariableWatcherNew", |
| []() { return tensorflow::PyoOrThrow(TFE_Py_VariableWatcherNew()); }); |
| m.def("TFE_Py_VariableWatcherRemove", [](const py::handle& variable_watcher) { |
| TFE_Py_VariableWatcherRemove(variable_watcher.ptr()); |
| }); |
| m.def("TFE_Py_VariableWatcherVariableAccessed", |
| [](const py::handle& variable) { |
| TFE_Py_VariableWatcherVariableAccessed(variable.ptr()); |
| }); |
| m.def("TFE_Py_VariableWatcherWatchedVariables", |
| [](const py::handle& variable_watcher) { |
| return tensorflow::PyoOrThrow( |
| TFE_Py_VariableWatcherWatchedVariables(variable_watcher.ptr())); |
| }); |
| |
| // TFE_Py_ForwardAccumulator logic. |
| m.def("TFE_Py_ForwardAccumulatorNew", [](bool use_batch) { |
| return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorNew(use_batch)); |
| }); |
| |
| m.def("TFE_Py_ForwardAccumulatorSetAdd", [](const py::handle& accumulator) { |
| return tensorflow::PyoOrThrow( |
| TFE_Py_ForwardAccumulatorSetAdd(accumulator.ptr())); |
| }); |
| m.def("TFE_Py_ForwardAccumulatorSetRemove", |
| [](const py::handle& accumulator) { |
| TFE_Py_ForwardAccumulatorSetRemove(accumulator.ptr()); |
| }); |
| |
| m.def("TFE_Py_ForwardAccumulatorWatch", |
| [](const py::handle& accumulator, const py::handle& tensor, |
| const py::handle& tangent) { |
| TFE_Py_ForwardAccumulatorWatch(accumulator.ptr(), tensor.ptr(), |
| tangent.ptr()); |
| }); |
| m.def("TFE_Py_ForwardAccumulatorJVP", |
| [](const py::handle& accumulator, const py::handle& tensor) { |
| return tensorflow::PyoOrThrow( |
| TFE_Py_ForwardAccumulatorJVP(accumulator.ptr(), tensor.ptr())); |
| }); |
| m.def("TFE_Py_ForwardAccumulatorPushState", []() { |
| return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorPushState()); |
| }); |
| m.def("TFE_Py_ForwardAccumulatorPopState", []() { |
| return tensorflow::PyoOrThrow(TFE_Py_ForwardAccumulatorPopState()); |
| }); |
| m.def("TFE_Py_PackJVPs", [](const py::handle& tensors) { |
| return tensorflow::PyoOrThrow(TFE_Py_PackJVPs(tensors.ptr())); |
| }); |
| |
| // TFE_ContextOptions Logic |
| m.def("TFE_NewContextOptions", &TFE_NewContextOptions, |
| py::return_value_policy::reference); |
| m.def("TFE_ContextOptionsSetConfig", [](TFE_ContextOptions* options, |
| py::bytes proto) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| tensorflow::Safe_TF_BufferPtr buf = |
| tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr())); |
| TFE_ContextOptionsSetConfig(options, buf.get()->data, buf.get()->length, |
| status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| }); |
| m.def("TFE_ContextOptionsSetDevicePlacementPolicy", |
| &TFE_ContextOptionsSetDevicePlacementPolicy); |
| m.def("TFE_ContextOptionsSetLazyRemoteInputsCopy", |
| &TFE_ContextOptionsSetLazyRemoteInputsCopy); |
| m.def("TFE_ContextOptionsSetTfrt", &TFE_ContextOptionsSetTfrt); |
| m.def("TFE_ContextOptionsSetAsync", &TFE_ContextOptionsSetAsync); |
| m.def("TFE_DeleteContextOptions", &TFE_DeleteContextOptions, |
| py::return_value_policy::reference); |
| |
| // TFE_Py_TensorShape Logic |
| m.def("TFE_Py_TensorShapeSlice", |
| [](const py::handle& tensors, int slice_dim) { |
| return tensorflow::PyoOrThrow( |
| TFE_Py_TensorShapeSlice(tensors.ptr(), slice_dim)); |
| }); |
| m.def("TFE_Py_TensorShapeOnDevice", [](const py::handle& tensors, |
| int slice_dim) { |
| return tensorflow::PyoOrThrow(TFE_Py_TensorShapeOnDevice(tensors.ptr())); |
| }); |
| m.def("TFE_Py_EnableInteractivePythonLogging", |
| &TFE_Py_EnableInteractivePythonLogging); |
| |
| // Additional Context Logic |
| m.def("TFE_Py_SetEagerContext", [](const py::handle& o) { |
| return tensorflow::PyoOrThrow(TFE_Py_SetEagerContext(o.ptr())); |
| }); |
| m.def("TFE_ContextStartStep", [](py::handle& o) { |
| TFE_ContextStartStep(tensorflow::InputTFE_Context(o.ptr())); |
| }); |
| m.def("TFE_ContextEndStep", [](py::handle& o) { |
| TFE_ContextEndStep(tensorflow::InputTFE_Context(o.ptr())); |
| }); |
| m.def("TFE_Py_RegisterVSpace", [](const py::handle& o) { |
| return tensorflow::PyoOrThrow(TFE_Py_RegisterVSpace(o.ptr())); |
| }); |
| m.def("TFE_Py_EncodeArg", |
| [](const py::handle& o, bool include_tensor_ranks_only) { |
| return tensorflow::PyoOrThrow( |
| TFE_Py_EncodeArg(o.ptr(), include_tensor_ranks_only)); |
| }); |
| m.def("TFE_EnableCollectiveOps", [](const py::handle& ctx, py::bytes proto) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| tensorflow::Safe_TF_BufferPtr buf = |
| tensorflow::make_safe(tensorflow::ProtoStringToTFBuffer(proto.ptr())); |
| TFE_EnableCollectiveOps(tensorflow::InputTFE_Context(ctx), buf.get()->data, |
| buf.get()->length, status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| }); |
| m.def("TFE_AbortCollectiveOps", [](const py::handle& ctx, int code, |
| const char* message) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| TF_SetStatus(status.get(), static_cast<TF_Code>(code), message); |
| TFE_AbortCollectiveOps(tensorflow::InputTFE_Context(ctx), status.get()); |
| }); |
| m.def("TFE_CollectiveOpsCheckPeerHealth", |
| [](const py::handle& ctx, const char* task) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| TFE_CollectiveOpsCheckPeerHealth(tensorflow::InputTFE_Context(ctx), |
| task, status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| }); |
| m.def("TF_ListPhysicalDevices", &tensorflow::TF_ListPhysicalDevices); |
| m.def("TF_GetDeviceDetails", &tensorflow::TF_GetDeviceDetails); |
| m.def("TF_DeleteDeviceList", &TF_DeleteDeviceList, |
| py::return_value_policy::reference); |
| m.def("TF_DeviceListCount", &TF_DeviceListCount); |
| m.def("TF_DeviceListName", [](const TF_DeviceList* list, int index) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = TF_DeviceListName(list, index, status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }); |
| m.def("TF_DeviceListType", [](const TF_DeviceList* list, int index) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = TF_DeviceListType(list, index, status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }); |
| |
| m.def("TF_PickUnusedPortOrDie", &TF_PickUnusedPortOrDie); |
| |
| // TFE_MonitoringCounter Logic |
| m.def("TFE_MonitoringCounterCellIncrementBy", |
| &TFE_MonitoringCounterCellIncrementBy); |
| m.def("TFE_MonitoringCounterCellValue", &TFE_MonitoringCounterCellValue); |
| m.def( |
| "TFE_MonitoringNewCounter0", |
| [](const char* name, const char* description) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = |
| TFE_MonitoringNewCounter0(name, status.get(), description); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringDeleteCounter0", &TFE_MonitoringDeleteCounter0, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringGetCellCounter0", &TFE_MonitoringGetCellCounter0, |
| py::return_value_policy::reference); |
| m.def( |
| "TFE_MonitoringNewCounter1", |
| [](const char* name, const char* description, const char* label1) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = |
| TFE_MonitoringNewCounter1(name, status.get(), description, label1); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringDeleteCounter1", &TFE_MonitoringDeleteCounter1, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringGetCellCounter1", &TFE_MonitoringGetCellCounter1, |
| py::return_value_policy::reference); |
| m.def( |
| "TFE_MonitoringNewCounter2", |
| [](const char* name, const char* description, const char* label1, |
| const char* label2) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = TFE_MonitoringNewCounter2(name, status.get(), description, |
| label1, label2); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringDeleteCounter2", &TFE_MonitoringDeleteCounter2, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringGetCellCounter2", &TFE_MonitoringGetCellCounter2, |
| py::return_value_policy::reference); |
| |
| // TFE_MonitoringIntGauge Logic |
| m.def("TFE_MonitoringIntGaugeCellSet", &TFE_MonitoringIntGaugeCellSet); |
| m.def("TFE_MonitoringIntGaugeCellValue", &TFE_MonitoringIntGaugeCellValue); |
| m.def( |
| "TFE_MonitoringNewIntGauge0", |
| [](const char* name, const char* description) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = |
| TFE_MonitoringNewIntGauge0(name, status.get(), description); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringDeleteIntGauge0", &TFE_MonitoringDeleteIntGauge0, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringGetCellIntGauge0", &TFE_MonitoringGetCellIntGauge0, |
| py::return_value_policy::reference); |
| m.def( |
| "TFE_MonitoringNewIntGauge1", |
| [](const char* name, const char* description, const char* label1) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = |
| TFE_MonitoringNewIntGauge1(name, status.get(), description, label1); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringDeleteIntGauge1", &TFE_MonitoringDeleteIntGauge1, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringGetCellIntGauge1", &TFE_MonitoringGetCellIntGauge1, |
| py::return_value_policy::reference); |
| m.def( |
| "TFE_MonitoringNewIntGauge2", |
| [](const char* name, const char* description, const char* label1, |
| const char* label2) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = TFE_MonitoringNewIntGauge2(name, status.get(), |
| description, label1, label2); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringDeleteIntGauge2", &TFE_MonitoringDeleteIntGauge2, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringGetCellIntGauge2", &TFE_MonitoringGetCellIntGauge2, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringStringGaugeCellSet", &TFE_MonitoringStringGaugeCellSet); |
| m.def("TFE_MonitoringStringGaugeCellValue", |
| &TFE_MonitoringStringGaugeCellValue); |
| m.def( |
| "TFE_MonitoringNewStringGauge0", |
| [](const char* name, const char* description) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = |
| TFE_MonitoringNewStringGauge0(name, status.get(), description); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }, |
| py::return_value_policy::reference); |
| |
| // TFE_MonitoringStringGauge Logic |
| m.def("TFE_MonitoringDeleteStringGauge0", &TFE_MonitoringDeleteStringGauge0); |
| m.def("TFE_MonitoringGetCellStringGauge0", &TFE_MonitoringGetCellStringGauge0, |
| py::return_value_policy::reference); |
| m.def( |
| "TFE_MonitoringNewStringGauge1", |
| [](const char* name, const char* description, const char* label1) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = TFE_MonitoringNewStringGauge1(name, status.get(), |
| description, label1); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringDeleteStringGauge1", &TFE_MonitoringDeleteStringGauge1); |
| m.def("TFE_MonitoringGetCellStringGauge1", &TFE_MonitoringGetCellStringGauge1, |
| py::return_value_policy::reference); |
| m.def( |
| "TFE_MonitoringNewStringGauge2", |
| [](const char* name, const char* description, const char* label1, |
| const char* label2) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = TFE_MonitoringNewStringGauge2( |
| name, status.get(), description, label1, label2); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringDeleteStringGauge2", &TFE_MonitoringDeleteStringGauge2); |
| m.def("TFE_MonitoringGetCellStringGauge2", &TFE_MonitoringGetCellStringGauge2, |
| py::return_value_policy::reference); |
| |
| // TFE_MonitoringBoolGauge Logic |
| m.def("TFE_MonitoringBoolGaugeCellSet", &TFE_MonitoringBoolGaugeCellSet); |
| m.def("TFE_MonitoringBoolGaugeCellValue", &TFE_MonitoringBoolGaugeCellValue); |
| m.def( |
| "TFE_MonitoringNewBoolGauge0", |
| [](const char* name, const char* description) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = |
| TFE_MonitoringNewBoolGauge0(name, status.get(), description); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringDeleteBoolGauge0", &TFE_MonitoringDeleteBoolGauge0, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringGetCellBoolGauge0", &TFE_MonitoringGetCellBoolGauge0, |
| py::return_value_policy::reference); |
| m.def( |
| "TFE_MonitoringNewBoolGauge1", |
| [](const char* name, const char* description, const char* label1) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = TFE_MonitoringNewBoolGauge1(name, status.get(), |
| description, label1); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringDeleteBoolGauge1", &TFE_MonitoringDeleteBoolGauge1, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringGetCellBoolGauge1", &TFE_MonitoringGetCellBoolGauge1, |
| py::return_value_policy::reference); |
| m.def( |
| "TFE_MonitoringNewBoolGauge2", |
| [](const char* name, const char* description, const char* label1, |
| const char* label2) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = TFE_MonitoringNewBoolGauge2(name, status.get(), |
| description, label1, label2); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringDeleteBoolGauge2", &TFE_MonitoringDeleteBoolGauge2, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringGetCellBoolGauge2", &TFE_MonitoringGetCellBoolGauge2, |
| py::return_value_policy::reference); |
| |
| // TFE_MonitoringSampler Logic |
| m.def("TFE_MonitoringSamplerCellAdd", &TFE_MonitoringSamplerCellAdd); |
| m.def("TFE_MonitoringSamplerCellValue", &TFE_MonitoringSamplerCellValue); |
| m.def("TFE_MonitoringNewExponentialBuckets", |
| &TFE_MonitoringNewExponentialBuckets, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringDeleteBuckets", &TFE_MonitoringDeleteBuckets, |
| py::return_value_policy::reference); |
| m.def( |
| "TFE_MonitoringNewSampler0", |
| [](const char* name, TFE_MonitoringBuckets* buckets, |
| const char* description) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = |
| TFE_MonitoringNewSampler0(name, buckets, status.get(), description); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringDeleteSampler0", &TFE_MonitoringDeleteSampler0, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringGetCellSampler0", &TFE_MonitoringGetCellSampler0, |
| py::return_value_policy::reference); |
| m.def( |
| "TFE_MonitoringNewSampler1", |
| [](const char* name, TFE_MonitoringBuckets* buckets, |
| const char* description, const char* label1) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = TFE_MonitoringNewSampler1(name, buckets, status.get(), |
| description, label1); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringDeleteSampler1", &TFE_MonitoringDeleteSampler1, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringGetCellSampler1", &TFE_MonitoringGetCellSampler1, |
| py::return_value_policy::reference); |
| m.def( |
| "TFE_MonitoringNewSampler2", |
| [](const char* name, TFE_MonitoringBuckets* buckets, |
| const char* description, const char* label1, const char* label2) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| auto output = TFE_MonitoringNewSampler2(name, buckets, status.get(), |
| description, label1, label2); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| return output; |
| }, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringDeleteSampler2", &TFE_MonitoringDeleteSampler2, |
| py::return_value_policy::reference); |
| m.def("TFE_MonitoringGetCellSampler2", &TFE_MonitoringGetCellSampler2, |
| py::return_value_policy::reference); |
| |
| // TFE_CancellationManager Logic |
| m.def("TFE_NewCancellationManager", &TFE_NewCancellationManager, |
| py::return_value_policy::reference); |
| m.def("TFE_CancellationManagerIsCancelled", |
| &TFE_CancellationManagerIsCancelled); |
| m.def("TFE_CancellationManagerStartCancel", |
| &TFE_CancellationManagerStartCancel); |
| m.def("TFE_DeleteCancellationManager", &TFE_DeleteCancellationManager, |
| py::return_value_policy::reference); |
| |
| m.def("TFE_ClearScalarCache", &tensorflow::TFE_ClearScalarCache); |
| |
| // Util buffer helper functions |
| m.def("TF_NewBufferFromString", &TF_NewBufferFromString, |
| py::return_value_policy::reference); |
| |
| // DLPack functions |
| m.def("TFE_ToDlpackCapsule", [](py::handle& o) { |
| PyObject* eager_tensor_pyobject_ptr = o.ptr(); |
| TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr); |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| |
| py::capsule capsule( |
| dlm_ptr, tensorflow::kDlTensorCapsuleName, [](PyObject* capsule) { |
| if (PyCapsule_IsValid(capsule, tensorflow::kDlTensorCapsuleName)) { |
| void* dlm_rptr = |
| PyCapsule_GetPointer(capsule, tensorflow::kDlTensorCapsuleName); |
| if (dlm_rptr) { |
| tensorflow::TFE_CallDLManagedTensorDeleter(dlm_rptr); |
| PyCapsule_SetDestructor(capsule, nullptr); |
| } |
| } |
| }); |
| return capsule; |
| }); |
| |
| m.def("TFE_FromDlpackCapsule", [](const py::capsule& pycapsule, |
| const py::handle& context) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| if (absl::string_view(pycapsule.name()) != |
| tensorflow::kDlTensorCapsuleName) { |
| status->status = tensorflow::errors::InvalidArgument( |
| "DLPack tensor must be a capsule with name \"dltensor\", got \"%s\". " |
| "Note that a DLPack tensor may be consumed at most once.", |
| absl::string_view(pycapsule.name())); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| } |
| |
| TFE_TensorHandle* thandle = tensorflow::TFE_HandleFromDLPack( |
| pycapsule, status.get(), tensorflow::InputTFE_Context(context)); |
| |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| |
| PyCapsule_SetName(pycapsule.ptr(), "used_dltensor"); |
| PyCapsule_SetDestructor(pycapsule.ptr(), nullptr); |
| |
| PyObject* pyhandle = EagerTensorFromHandle(thandle); |
| return tensorflow::PyoOrThrow(pyhandle); |
| }); |
| |
| m.def("TFE_Py_RegisterCustomDevice", [](const py::handle& context, |
| const py::capsule& device, |
| const char* device_name, |
| const py::capsule& device_info) { |
| tensorflow::Safe_TF_StatusPtr status = |
| tensorflow::make_safe(TF_NewStatus()); |
| if (absl::string_view(device.name()) != "TFE_CustomDevice") { |
| status->status = tensorflow::errors::InvalidArgument( |
| "Expected a capsule named 'TFE_CustomDevice' for the `device` " |
| "argument, got ", |
| absl::string_view(device.name())); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| } |
| if (absl::string_view(device_info.name()) != |
| "TFE_CustomDevice_DeviceInfo") { |
| status->status = tensorflow::errors::InvalidArgument( |
| "Expected a capsule named 'TFE_CustomDevice_DeviceInfo' for " |
| "the `device_info` argument, got ", |
| absl::string_view(device_info.name())); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| } |
| // TFE_RegisterCustomDevice takes ownership |
| PyCapsule_SetDestructor(device_info.ptr(), nullptr); |
| TFE_RegisterCustomDevice( |
| tensorflow::InputTFE_Context(context), |
| *reinterpret_cast<TFE_CustomDevice*>( |
| PyCapsule_GetPointer(device.ptr(), "TFE_CustomDevice")), |
| device_name, |
| PyCapsule_GetPointer(device_info.ptr(), "TFE_CustomDevice_DeviceInfo"), |
| status.get()); |
| tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
| }); |
| |
| // C API Enum |
| |
| py::enum_<TFE_ContextDevicePlacementPolicy>( |
| m, "TFE_ContextDevicePlacementPolicy") |
| .value("TFE_DEVICE_PLACEMENT_EXPLICIT", TFE_DEVICE_PLACEMENT_EXPLICIT) |
| .value("TFE_DEVICE_PLACEMENT_WARN", TFE_DEVICE_PLACEMENT_WARN) |
| .value("TFE_DEVICE_PLACEMENT_SILENT", TFE_DEVICE_PLACEMENT_SILENT) |
| .value("TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32", |
| TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32) |
| .export_values(); |
| |
| py::enum_<TF_AttrType>(m, "TF_AttrType") |
| .value("TF_ATTR_STRING", TF_ATTR_STRING) |
| .value("TF_ATTR_INT", TF_ATTR_INT) |
| .value("TF_ATTR_FLOAT", TF_ATTR_FLOAT) |
| .value("TF_ATTR_BOOL", TF_ATTR_BOOL) |
| .value("TF_ATTR_TYPE", TF_ATTR_TYPE) |
| .value("TF_ATTR_SHAPE", TF_ATTR_SHAPE) |
| .value("TF_ATTR_TENSOR", TF_ATTR_TENSOR) |
| .value("TF_ATTR_PLACEHOLDER", TF_ATTR_PLACEHOLDER) |
| .value("TF_ATTR_FUNC", TF_ATTR_FUNC) |
| .export_values(); |
| }; |