| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include <atomic> |
| #include <cstring> |
| #include <unordered_map> |
| |
| #include "absl/strings/str_cat.h" |
| #include "absl/types/variant.h" |
| #include "tensorflow/c/c_api.h" |
| #include "tensorflow/c/c_api_internal.h" |
| #include "tensorflow/c/eager/c_api.h" |
| #include "tensorflow/c/eager/c_api_internal.h" |
| #include "tensorflow/c/eager/tape.h" |
| #include "tensorflow/c/eager/tfe_context_internal.h" |
| #include "tensorflow/c/eager/tfe_op_internal.h" |
| #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" |
| #include "tensorflow/c/tf_status.h" |
| #include "tensorflow/core/framework/types.pb.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/gtl/cleanup.h" |
| #include "tensorflow/core/lib/gtl/compactptrset.h" |
| #include "tensorflow/core/lib/gtl/flatmap.h" |
| #include "tensorflow/core/lib/gtl/flatset.h" |
| #include "tensorflow/core/lib/strings/strcat.h" |
| #include "tensorflow/core/lib/strings/stringprintf.h" |
| #include "tensorflow/core/platform/casts.h" |
| #include "tensorflow/core/platform/mutex.h" |
| #include "tensorflow/core/platform/protobuf.h" |
| #include "tensorflow/core/platform/types.h" |
| #include "tensorflow/core/profiler/lib/traceme.h" |
| #include "tensorflow/core/util/abstract_stack_trace.h" |
| #include "tensorflow/python/eager/pywrap_gradient_exclusions.h" |
| #include "tensorflow/python/eager/pywrap_tensor.h" |
| #include "tensorflow/python/eager/pywrap_tfe.h" |
| #include "tensorflow/python/lib/core/py_util.h" |
| #include "tensorflow/python/lib/core/safe_ptr.h" |
| #include "tensorflow/python/util/stack_trace.h" |
| #include "tensorflow/python/util/util.h" |
| |
| using tensorflow::string; |
| using tensorflow::strings::Printf; |
| |
| namespace { |
// NOTE: Items are retrieved from and returned to these unique_ptrs, which act
// as arenas. This matters when the same thread requests two items without
// releasing the first. The following sequence of events on the same thread
// will still succeed:
// - GetOp <- Returns existing.
// - GetOp <- Allocates and returns a new pointer.
// - ReleaseOp <- Sets the item in the unique_ptr.
// - ReleaseOp <- Sets the item in the unique_ptr, deleting the old one.
// This occurs when a PyFunc kernel is run. It also keeps the pattern safe
// when Python reuses a single underlying C++ thread for two Python threads.
| struct OpDeleter { |
| void operator()(TFE_Op* op) const { TFE_DeleteOp(op); } |
| }; |
| thread_local std::unordered_map<TFE_Context*, |
| std::unique_ptr<TFE_Op, OpDeleter>> |
| thread_local_eager_operation_map; // NOLINT |
| thread_local std::unique_ptr<TF_Status> thread_local_tf_status = // NOLINT |
| nullptr; |
| |
| std::unique_ptr<TFE_Op, OpDeleter> ReleaseThreadLocalOp(TFE_Context* ctx) { |
| auto it = thread_local_eager_operation_map.find(ctx); |
| if (it == thread_local_eager_operation_map.end()) { |
| return nullptr; |
| } |
| return std::move(it->second); |
| } |
| |
| TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name, |
| const char* raw_device_name, TF_Status* status) { |
| auto op = ReleaseThreadLocalOp(ctx); |
| if (!op) { |
| op.reset(tensorflow::wrap(tensorflow::unwrap(ctx)->CreateOperation())); |
| } |
| status->status = |
| tensorflow::unwrap(op.get())->Reset(op_or_function_name, raw_device_name); |
| if (!status->status.ok()) { |
| op.reset(); |
| } |
| return op.release(); |
| } |
| |
| void ReturnOp(TFE_Context* ctx, TFE_Op* op) { |
| if (op) { |
| tensorflow::unwrap(op)->Clear(); |
| thread_local_eager_operation_map[ctx].reset(op); |
| } |
| } |
| |
| TF_Status* ReleaseThreadLocalStatus() { |
| if (thread_local_tf_status == nullptr) { |
| return nullptr; |
| } |
| return thread_local_tf_status.release(); |
| } |
| |
| struct InputInfo { |
| InputInfo(int i, bool is_list) : i(i), is_list(is_list) {} |
| |
| int i; |
| bool is_list = false; |
| }; |
| |
| // Takes in output gradients, returns input gradients. |
| typedef std::function<PyObject*(PyObject*, |
| const std::vector<tensorflow::int64>&)> |
| PyBackwardFunction; |
| |
| using AttrToInputsMap = |
| tensorflow::gtl::FlatMap<string, |
| tensorflow::gtl::InlinedVector<InputInfo, 4>>; |
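
// Illustrative example: for an AddN-style OpDef with
//   input_arg { name: "inputs" type_attr: "T" number_attr: "N" }
// the AttrToInputsMap is {"T" -> [InputInfo(0, /*is_list=*/true)]}, i.e. the
// dtype for attr "T" can be inferred from input 0, which is a list input.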
| |
| tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() { |
| static auto* all_attr_to_input_maps = |
| new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>; |
| return all_attr_to_input_maps; |
| } |
| |
| // This function doesn't use a lock, since we depend on the GIL directly. |
| AttrToInputsMap* GetAttrToInputsMapHoldingGIL(const tensorflow::OpDef& op_def) { |
| #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4 |
| DCHECK(PyGILState_Check()) |
| << "This function needs to hold the GIL when called."; |
| #endif |
| auto* all_attr_to_input_maps = GetAllAttrToInputsMaps(); |
| auto* output = |
| tensorflow::gtl::FindPtrOrNull(*all_attr_to_input_maps, op_def.name()); |
| if (output != nullptr) { |
| return output; |
| } |
| |
| std::unique_ptr<AttrToInputsMap> m(new AttrToInputsMap); |
| |
| // Store a list of InputIndex -> List of corresponding inputs. |
| for (int i = 0; i < op_def.input_arg_size(); i++) { |
| if (!op_def.input_arg(i).type_attr().empty()) { |
| auto it = m->find(op_def.input_arg(i).type_attr()); |
| if (it == m->end()) { |
| it = m->insert({op_def.input_arg(i).type_attr(), {}}).first; |
| } |
| it->second.emplace_back(i, !op_def.input_arg(i).number_attr().empty()); |
| } |
| } |
| |
| auto* retval = m.get(); |
| (*all_attr_to_input_maps)[op_def.name()] = m.release(); |
| |
| return retval; |
| } |
| |
| // This function doesn't use a lock, since we depend on the GIL directly. |
| tensorflow::gtl::FlatMap< |
| string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>* |
| GetAllAttrToDefaultsMaps() { |
| static auto* all_attr_to_defaults_maps = new tensorflow::gtl::FlatMap< |
| string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>; |
| return all_attr_to_defaults_maps; |
| } |
| |
| tensorflow::gtl::FlatMap<string, tensorflow::DataType>* |
| GetAttrToDefaultsMapHoldingGIL(const tensorflow::OpDef& op_def) { |
| #if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4 |
| DCHECK(PyGILState_Check()) |
| << "This function needs to hold the GIL when called."; |
| #endif |
| auto* all_attr_to_defaults_maps = GetAllAttrToDefaultsMaps(); |
| auto* output = |
| tensorflow::gtl::FindPtrOrNull(*all_attr_to_defaults_maps, op_def.name()); |
| if (output != nullptr) { |
| return output; |
| } |
| |
| auto* new_map = new tensorflow::gtl::FlatMap<string, tensorflow::DataType>; |
| |
| for (const auto& attr : op_def.attr()) { |
| if (attr.type() == "type" && attr.has_default_value()) { |
| new_map->insert({attr.name(), attr.default_value().type()}); |
| } |
| } |
| |
| (*all_attr_to_defaults_maps)[op_def.name()] = new_map; |
| |
| return new_map; |
| } |
| |
| struct FastPathOpExecInfo { |
| TFE_Context* ctx; |
| const char* device_name; |
| |
| bool run_callbacks; |
| bool run_post_exec_callbacks; |
| bool run_gradient_callback; |
| |
| // The op name of the main op being executed. |
| PyObject* name; |
| // The op type name of the main op being executed. |
| PyObject* op_name; |
| PyObject* callbacks; |
| |
| // All the args passed into the FastPathOpExecInfo. |
| PyObject* args; |
| |
  // DTypes can come from another input that has the same attr, so we build a
  // map from attr name to the inputs it governs.
| const AttrToInputsMap* attr_to_inputs_map; |
| const tensorflow::gtl::FlatMap<string, tensorflow::DataType>* default_dtypes; |
| tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes; |
| }; |
| |
| #define PARSE_VALUE(fn_name, type, check_fn, parse_fn) \ |
| bool fn_name(const string& key, PyObject* py_value, TF_Status* status, \ |
| type* value) { \ |
| if (check_fn(py_value)) { \ |
| *value = static_cast<type>(parse_fn(py_value)); \ |
| return true; \ |
| } else { \ |
| TF_SetStatus(status, TF_INVALID_ARGUMENT, \ |
| tensorflow::strings::StrCat( \ |
| "Expecting " #type " value for attr ", key, ", got ", \ |
| py_value->ob_type->tp_name) \ |
| .c_str()); \ |
| return false; \ |
| } \ |
| } |
| |
| #if PY_MAJOR_VERSION >= 3 |
| PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong) |
| PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLong) |
| #else |
| PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong) |
| #endif |
| PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble) |
| #undef PARSE_VALUE |
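
// For reference, PARSE_VALUE(ParseFloatValue, float, PyFloat_Check,
// PyFloat_AsDouble) expands to (roughly):
//
//   bool ParseFloatValue(const string& key, PyObject* py_value,
//                        TF_Status* status, float* value) {
//     if (PyFloat_Check(py_value)) {
//       *value = static_cast<float>(PyFloat_AsDouble(py_value));
//       return true;
//     }
//     // ... else set TF_INVALID_ARGUMENT on `status` and return false.
//   }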
| |
| #if PY_MAJOR_VERSION < 3 |
| bool ParseInt64Value(const string& key, PyObject* py_value, TF_Status* status, |
| int64_t* value) { |
| if (PyInt_Check(py_value)) { |
| *value = static_cast<int64_t>(PyInt_AsLong(py_value)); |
| return true; |
| } else if (PyLong_Check(py_value)) { |
| *value = static_cast<int64_t>(PyLong_AsLong(py_value)); |
| return true; |
| } |
| TF_SetStatus( |
| status, TF_INVALID_ARGUMENT, |
| tensorflow::strings::StrCat("Expecting int or long value for attr ", key, |
| ", got ", py_value->ob_type->tp_name) |
| .c_str()); |
| return false; |
| } |
| #endif |
| |
| Py_ssize_t TensorShapeNumDims(PyObject* value) { |
| const auto size = PySequence_Size(value); |
| if (size == -1) { |
    // TensorShape.__len__ raises an error when the shape is unknown; that
    // error needs to be cleared here.
| // TODO(nareshmodi): ensure that this is actually a TensorShape. |
| PyErr_Clear(); |
| } |
| return size; |
| } |
| |
| bool IsInteger(PyObject* py_value) { |
| #if PY_MAJOR_VERSION >= 3 |
| return PyLong_Check(py_value); |
| #else |
| return PyInt_Check(py_value) || PyLong_Check(py_value); |
| #endif |
| } |
| |
| // This function considers a Dimension._value of None to be valid, and sets the |
| // value to be -1 in that case. |
| bool ParseDimensionValue(const string& key, PyObject* py_value, |
| TF_Status* status, int64_t* value) { |
| if (IsInteger(py_value)) { |
| return ParseInt64Value(key, py_value, status, value); |
| } |
| |
| tensorflow::Safe_PyObjectPtr dimension_value( |
| PyObject_GetAttrString(py_value, "_value")); |
| if (dimension_value == nullptr) { |
| PyErr_Clear(); |
| TF_SetStatus( |
| status, TF_INVALID_ARGUMENT, |
| tensorflow::strings::StrCat("Expecting a Dimension for attr ", key, |
| ", got ", py_value->ob_type->tp_name) |
| .c_str()); |
| return false; |
| } |
| |
| if (dimension_value.get() == Py_None) { |
| *value = -1; |
| return true; |
| } |
| |
| return ParseInt64Value(key, dimension_value.get(), status, value); |
| } |
| |
| bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status, |
| tensorflow::StringPiece* value) { |
| if (PyBytes_Check(py_value)) { |
| Py_ssize_t size = 0; |
| char* buf = nullptr; |
| if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false; |
| *value = tensorflow::StringPiece(buf, size); |
| return true; |
| } |
| #if PY_MAJOR_VERSION >= 3 |
| if (PyUnicode_Check(py_value)) { |
| Py_ssize_t size = 0; |
| const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size); |
| if (buf == nullptr) return false; |
| *value = tensorflow::StringPiece(buf, size); |
| return true; |
| } |
| #endif |
| TF_SetStatus( |
| status, TF_INVALID_ARGUMENT, |
| tensorflow::strings::StrCat("Expecting a string value for attr ", key, |
| ", got ", py_value->ob_type->tp_name) |
| .c_str()); |
| return false; |
| } |
| |
bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
                    unsigned char* value) {
  const int is_true = PyObject_IsTrue(py_value);  // Returns -1 on error.
  if (is_true == -1) return false;  // Leave the Python error pending.
  *value = static_cast<unsigned char>(is_true);
  return true;
}
| |
| // The passed in py_value is expected to be an object of the python type |
| // dtypes.DType or an int. |
| bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status, |
| int* value) { |
| if (IsInteger(py_value)) { |
| return ParseIntValue(key, py_value, status, value); |
| } |
| |
| tensorflow::Safe_PyObjectPtr py_type_enum( |
| PyObject_GetAttrString(py_value, "_type_enum")); |
| if (py_type_enum == nullptr) { |
| PyErr_Clear(); |
| TF_SetStatus( |
| status, TF_INVALID_ARGUMENT, |
| tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key, |
| ", got ", py_value->ob_type->tp_name) |
| .c_str()); |
| return false; |
| } |
| |
| return ParseIntValue(key, py_type_enum.get(), status, value); |
| } |
| |
| bool SetOpAttrList( |
| TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_list, |
| TF_AttrType type, |
| tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes, |
| TF_Status* status) { |
| if (!PySequence_Check(py_list)) { |
| TF_SetStatus( |
| status, TF_INVALID_ARGUMENT, |
| tensorflow::strings::StrCat("Expecting sequence value for attr ", key, |
| ", got ", py_list->ob_type->tp_name) |
| .c_str()); |
| return false; |
| } |
| const int num_values = PySequence_Size(py_list); |
| if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values; |
| |
| #define PARSE_LIST(c_type, parse_fn) \ |
| std::unique_ptr<c_type[]> values(new c_type[num_values]); \ |
| for (int i = 0; i < num_values; ++i) { \ |
| tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); \ |
| if (!parse_fn(key, py_value.get(), status, &values[i])) return false; \ |
| } |
| |
| if (type == TF_ATTR_STRING) { |
| std::unique_ptr<const void*[]> values(new const void*[num_values]); |
| std::unique_ptr<size_t[]> lengths(new size_t[num_values]); |
| for (int i = 0; i < num_values; ++i) { |
| tensorflow::StringPiece value; |
| tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); |
| if (!ParseStringValue(key, py_value.get(), status, &value)) return false; |
| values[i] = value.data(); |
| lengths[i] = value.size(); |
| } |
| TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values); |
| } else if (type == TF_ATTR_INT) { |
| PARSE_LIST(int64_t, ParseInt64Value); |
| TFE_OpSetAttrIntList(op, key, values.get(), num_values); |
| } else if (type == TF_ATTR_FLOAT) { |
| PARSE_LIST(float, ParseFloatValue); |
| TFE_OpSetAttrFloatList(op, key, values.get(), num_values); |
| } else if (type == TF_ATTR_BOOL) { |
| PARSE_LIST(unsigned char, ParseBoolValue); |
| TFE_OpSetAttrBoolList(op, key, values.get(), num_values); |
| } else if (type == TF_ATTR_TYPE) { |
| PARSE_LIST(int, ParseTypeValue); |
| TFE_OpSetAttrTypeList(op, key, |
| reinterpret_cast<const TF_DataType*>(values.get()), |
| num_values); |
| } else if (type == TF_ATTR_SHAPE) { |
| // Make one pass through the input counting the total number of |
| // dims across all the input lists. |
| int total_dims = 0; |
| for (int i = 0; i < num_values; ++i) { |
| tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); |
| if (py_value.get() != Py_None) { |
| if (!PySequence_Check(py_value.get())) { |
| TF_SetStatus( |
| status, TF_INVALID_ARGUMENT, |
| tensorflow::strings::StrCat( |
| "Expecting None or sequence value for element", i, |
| " of attr ", key, ", got ", py_value->ob_type->tp_name) |
| .c_str()); |
| return false; |
| } |
| const auto size = TensorShapeNumDims(py_value.get()); |
| if (size >= 0) { |
| total_dims += size; |
| } |
| } |
| } |
| // Allocate a buffer that can fit all of the dims together. |
| std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]); |
| // Copy the input dims into the buffer and set dims to point to |
| // the start of each list's dims. |
| std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]); |
| std::unique_ptr<int[]> num_dims(new int[num_values]); |
| int64_t* offset = buffer.get(); |
| for (int i = 0; i < num_values; ++i) { |
| tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); |
| if (py_value.get() == Py_None) { |
| dims[i] = nullptr; |
| num_dims[i] = -1; |
| } else { |
| const auto size = TensorShapeNumDims(py_value.get()); |
| if (size == -1) { |
| dims[i] = nullptr; |
| num_dims[i] = -1; |
| continue; |
| } |
| dims[i] = offset; |
| num_dims[i] = size; |
| for (int j = 0; j < size; ++j) { |
| tensorflow::Safe_PyObjectPtr inner_py_value( |
| PySequence_ITEM(py_value.get(), j)); |
| if (inner_py_value.get() == Py_None) { |
| *offset = -1; |
| } else if (!ParseDimensionValue(key, inner_py_value.get(), status, |
| offset)) { |
| return false; |
| } |
| ++offset; |
| } |
| } |
| } |
| TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values, |
| status); |
| if (!status->status.ok()) return false; |
| } else if (type == TF_ATTR_FUNC) { |
| std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]); |
| for (int i = 0; i < num_values; ++i) { |
| tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); |
| // Allow: |
| // (1) String function name, OR |
| // (2) A Python object with a .name attribute |
| // (A crude test for being a |
| // tensorflow.python.framework.function._DefinedFunction) |
| // (which is what the various "defun" or "Defun" decorators do). |
| // And in the future also allow an object that can encapsulate |
| // the function name and its attribute values. |
| tensorflow::StringPiece func_name; |
| if (!ParseStringValue(key, py_value.get(), status, &func_name)) { |
      // Hold the attribute in a Safe_PyObjectPtr so the reference is released
      // on every path out of this block.
      tensorflow::Safe_PyObjectPtr name_attr(
          PyObject_GetAttrString(py_value.get(), "name"));
      if (name_attr == nullptr ||
          !ParseStringValue(key, name_attr.get(), status, &func_name)) {
| TF_SetStatus( |
| status, TF_INVALID_ARGUMENT, |
| tensorflow::strings::StrCat( |
| "unable to set function value attribute from a ", |
| py_value.get()->ob_type->tp_name, |
| " object. If you think this is an error, please file an " |
| "issue at " |
| "https://github.com/tensorflow/tensorflow/issues/new") |
| .c_str()); |
| return false; |
| } |
| } |
| funcs[i] = TFE_NewOp(ctx, func_name.data(), status); |
| if (!status->status.ok()) return false; |
| } |
| TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values); |
| if (!status->status.ok()) return false; |
| } else { |
| TF_SetStatus(status, TF_UNIMPLEMENTED, |
| tensorflow::strings::StrCat("Attr ", key, |
| " has unhandled list type ", type) |
| .c_str()); |
| return false; |
| } |
| #undef PARSE_LIST |
| return true; |
| } |
| |
| TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func, |
| TF_Status* status) { |
| TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status); |
| for (const auto& attr : func.attr()) { |
| if (!status->status.ok()) return nullptr; |
| SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status); |
| if (!status->status.ok()) return nullptr; |
| } |
| return func_op; |
| } |
| |
| void SetOpAttrListDefault( |
| TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr, |
| const char* key, TF_AttrType type, |
| tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes, |
| TF_Status* status) { |
| if (type == TF_ATTR_STRING) { |
| int num_values = attr.default_value().list().s_size(); |
| std::unique_ptr<const void*[]> values(new const void*[num_values]); |
| std::unique_ptr<size_t[]> lengths(new size_t[num_values]); |
| (*attr_list_sizes)[key] = num_values; |
| for (int i = 0; i < num_values; i++) { |
| const string& v = attr.default_value().list().s(i); |
| values[i] = v.data(); |
| lengths[i] = v.size(); |
| } |
| TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values); |
| } else if (type == TF_ATTR_INT) { |
| int num_values = attr.default_value().list().i_size(); |
| std::unique_ptr<int64_t[]> values(new int64_t[num_values]); |
| (*attr_list_sizes)[key] = num_values; |
| for (int i = 0; i < num_values; i++) { |
| values[i] = attr.default_value().list().i(i); |
| } |
| TFE_OpSetAttrIntList(op, key, values.get(), num_values); |
| } else if (type == TF_ATTR_FLOAT) { |
| int num_values = attr.default_value().list().f_size(); |
| std::unique_ptr<float[]> values(new float[num_values]); |
| (*attr_list_sizes)[key] = num_values; |
| for (int i = 0; i < num_values; i++) { |
| values[i] = attr.default_value().list().f(i); |
| } |
| TFE_OpSetAttrFloatList(op, key, values.get(), num_values); |
| } else if (type == TF_ATTR_BOOL) { |
| int num_values = attr.default_value().list().b_size(); |
| std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]); |
| (*attr_list_sizes)[key] = num_values; |
| for (int i = 0; i < num_values; i++) { |
| values[i] = attr.default_value().list().b(i); |
| } |
| TFE_OpSetAttrBoolList(op, key, values.get(), num_values); |
| } else if (type == TF_ATTR_TYPE) { |
| int num_values = attr.default_value().list().type_size(); |
| std::unique_ptr<int[]> values(new int[num_values]); |
| (*attr_list_sizes)[key] = num_values; |
| for (int i = 0; i < num_values; i++) { |
| values[i] = attr.default_value().list().type(i); |
| } |
    TFE_OpSetAttrTypeList(op, key,
                          reinterpret_cast<const TF_DataType*>(values.get()),
                          num_values);
| } else if (type == TF_ATTR_SHAPE) { |
| int num_values = attr.default_value().list().shape_size(); |
| (*attr_list_sizes)[key] = num_values; |
| int total_dims = 0; |
| for (int i = 0; i < num_values; ++i) { |
| if (!attr.default_value().list().shape(i).unknown_rank()) { |
| total_dims += attr.default_value().list().shape(i).dim_size(); |
| } |
| } |
| // Allocate a buffer that can fit all of the dims together. |
| std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]); |
| // Copy the input dims into the buffer and set dims to point to |
| // the start of each list's dims. |
| std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]); |
| std::unique_ptr<int[]> num_dims(new int[num_values]); |
| int64_t* offset = buffer.get(); |
| for (int i = 0; i < num_values; ++i) { |
| const auto& shape = attr.default_value().list().shape(i); |
| if (shape.unknown_rank()) { |
| dims[i] = nullptr; |
| num_dims[i] = -1; |
      } else {
        dims[i] = offset;
        num_dims[i] = shape.dim_size();
        for (int j = 0; j < shape.dim_size(); j++) {
          *offset = shape.dim(j).size();
          ++offset;
        }
      }
| } |
| TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values, |
| status); |
| } else if (type == TF_ATTR_FUNC) { |
| int num_values = attr.default_value().list().func_size(); |
| (*attr_list_sizes)[key] = num_values; |
| std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]); |
| for (int i = 0; i < num_values; i++) { |
| funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status); |
| } |
| TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values); |
| } else { |
| TF_SetStatus(status, TF_UNIMPLEMENTED, |
| "Lists of tensors are not yet implemented for default valued " |
| "attributes for an operation."); |
| } |
| } |
| |
| bool SetOpAttrScalar( |
| TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_value, |
| TF_AttrType type, |
| tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes, |
| TF_Status* status) { |
| if (type == TF_ATTR_STRING) { |
| tensorflow::StringPiece value; |
| if (!ParseStringValue(key, py_value, status, &value)) return false; |
| TFE_OpSetAttrString(op, key, value.data(), value.size()); |
| } else if (type == TF_ATTR_INT) { |
| int64_t value; |
| if (!ParseInt64Value(key, py_value, status, &value)) return false; |
| TFE_OpSetAttrInt(op, key, value); |
    // attr_list_sizes is set for all int attributes (at this point we don't
    // know whether this attribute will be used to calculate the size of an
    // output list).
| if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value; |
| } else if (type == TF_ATTR_FLOAT) { |
| float value; |
| if (!ParseFloatValue(key, py_value, status, &value)) return false; |
| TFE_OpSetAttrFloat(op, key, value); |
| } else if (type == TF_ATTR_BOOL) { |
| unsigned char value; |
| if (!ParseBoolValue(key, py_value, status, &value)) return false; |
| TFE_OpSetAttrBool(op, key, value); |
| } else if (type == TF_ATTR_TYPE) { |
| int value; |
| if (!ParseTypeValue(key, py_value, status, &value)) return false; |
| TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value)); |
| } else if (type == TF_ATTR_SHAPE) { |
| if (py_value == Py_None) { |
| TFE_OpSetAttrShape(op, key, nullptr, -1, status); |
| } else { |
| if (!PySequence_Check(py_value)) { |
| TF_SetStatus(status, TF_INVALID_ARGUMENT, |
| tensorflow::strings::StrCat( |
| "Expecting None or sequence value for attr", key, |
| ", got ", py_value->ob_type->tp_name) |
| .c_str()); |
| return false; |
| } |
| const auto num_dims = TensorShapeNumDims(py_value); |
| if (num_dims == -1) { |
| TFE_OpSetAttrShape(op, key, nullptr, -1, status); |
| return true; |
| } |
| std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]); |
| for (int i = 0; i < num_dims; ++i) { |
| tensorflow::Safe_PyObjectPtr inner_py_value( |
| PySequence_ITEM(py_value, i)); |
| if (inner_py_value.get() == Py_None) { |
| dims[i] = -1; |
| } else if (!ParseDimensionValue(key, inner_py_value.get(), status, |
| &dims[i])) { |
| return false; |
| } |
| } |
| TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status); |
| } |
| if (!status->status.ok()) return false; |
| } else if (type == TF_ATTR_FUNC) { |
| // Allow: |
| // (1) String function name, OR |
| // (2) A Python object with a .name attribute |
| // (A crude test for being a |
| // tensorflow.python.framework.function._DefinedFunction) |
| // (which is what the various "defun" or "Defun" decorators do). |
| // And in the future also allow an object that can encapsulate |
| // the function name and its attribute values. |
| tensorflow::StringPiece func_name; |
| if (!ParseStringValue(key, py_value, status, &func_name)) { |
      tensorflow::Safe_PyObjectPtr name_attr(
          PyObject_GetAttrString(py_value, "name"));
      if (name_attr == nullptr ||
          !ParseStringValue(key, name_attr.get(), status, &func_name)) {
| TF_SetStatus( |
| status, TF_INVALID_ARGUMENT, |
| tensorflow::strings::StrCat( |
| "unable to set function value attribute from a ", |
| py_value->ob_type->tp_name, |
| " object. If you think this is an error, please file an issue " |
| "at https://github.com/tensorflow/tensorflow/issues/new") |
| .c_str()); |
| return false; |
| } |
| } |
| TF_SetStatus(status, TF_OK, ""); |
| TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size()); |
| } else { |
| TF_SetStatus( |
| status, TF_UNIMPLEMENTED, |
| tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type) |
| .c_str()); |
| return false; |
| } |
| return true; |
| } |
| |
| void SetOpAttrScalarDefault( |
| TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value, |
| const char* attr_name, |
| tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes, |
| TF_Status* status) { |
| SetOpAttrValueScalar(ctx, op, default_value, attr_name, status); |
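  // Mirror SetOpAttrScalar above: int attrs are recorded in attr_list_sizes
  // since they may determine the length of an output list.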
| if (default_value.value_case() == tensorflow::AttrValue::kI) { |
| (*attr_list_sizes)[attr_name] = default_value.i(); |
| } |
| } |
| |
| // start_index is the index at which the Tuple/List attrs will start getting |
| // processed. |
| void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index, |
| TF_Status* out_status) { |
| if (attrs == Py_None) return; |
| Py_ssize_t len = PyTuple_GET_SIZE(attrs) - start_index; |
| if ((len & 1) != 0) { |
| TF_SetStatus(out_status, TF_INVALID_ARGUMENT, |
| "Expecting attrs tuple to have even length."); |
| return; |
| } |
  // Parse the flat tuple of (key1, value1, key2, value2, ...) attr pairs.
| for (Py_ssize_t i = 0; i < len; i += 2) { |
| PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i); |
| PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1); |
| #if PY_MAJOR_VERSION >= 3 |
| const char* key = PyBytes_Check(py_key) ? PyBytes_AsString(py_key) |
| : PyUnicode_AsUTF8(py_key); |
| #else |
| const char* key = PyBytes_AsString(py_key); |
| #endif |
| unsigned char is_list = 0; |
| const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status); |
| if (!out_status->status.ok()) return; |
| if (is_list != 0) { |
| if (!SetOpAttrList(ctx, op, key, py_value, type, nullptr, out_status)) |
| return; |
| } else { |
| if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status)) |
| return; |
| } |
| } |
| } |
| |
// This function sets the required op attrs. If an attr has the value None, it
// reads the AttrDef to get the default value and sets that instead. Any
// failure in this function simply falls back to the slow path.
| void SetOpAttrWithDefaults( |
| TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr, |
| const char* attr_name, PyObject* attr_value, |
| tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes, |
| TF_Status* status) { |
| unsigned char is_list = 0; |
| const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status); |
| if (!status->status.ok()) return; |
| if (attr_value == Py_None) { |
| if (is_list != 0) { |
| SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes, |
| status); |
| } else { |
| SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name, |
| attr_list_sizes, status); |
| } |
| } else { |
| if (is_list != 0) { |
| SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes, |
| status); |
| } else { |
| SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes, |
| status); |
| } |
| } |
| } |
| |
| PyObject* GetPythonObjectFromInt(int num) { |
| #if PY_MAJOR_VERSION >= 3 |
| return PyLong_FromLong(num); |
| #else |
| return PyInt_FromLong(num); |
| #endif |
| } |
| |
// Python subclass of Exception that is raised when a Status is not OK.
| tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED); |
| PyObject* exception_class TF_GUARDED_BY(exception_class_mutex) = nullptr; |
| |
| // Python subclass of Exception that is created to signal fallback. |
| PyObject* fallback_exception_class = nullptr; |
| |
| // Python function that returns input gradients given output gradients. |
| PyObject* gradient_function = nullptr; |
| |
| // Python function that returns output gradients given input gradients. |
| PyObject* forward_gradient_function = nullptr; |
| |
| static std::atomic<int64_t> _uid; |
| |
| } // namespace |
| |
| TF_Status* GetStatus() { |
| TF_Status* maybe_status = ReleaseThreadLocalStatus(); |
| if (maybe_status) { |
| TF_SetStatus(maybe_status, TF_OK, ""); |
| return maybe_status; |
| } else { |
| return TF_NewStatus(); |
| } |
| } |
| |
| void ReturnStatus(TF_Status* status) { |
| TF_SetStatus(status, TF_OK, ""); |
| thread_local_tf_status.reset(status); |
| } |
| |
| void TFE_Py_Execute(TFE_Context* ctx, const char* device_name, |
| const char* op_name, TFE_InputTensorHandles* inputs, |
| PyObject* attrs, TFE_OutputTensorHandles* outputs, |
| TF_Status* out_status) { |
| TFE_Py_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs, |
| /*cancellation_manager=*/nullptr, outputs, |
| out_status); |
| } |
| |
| void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name, |
| const char* op_name, |
| TFE_InputTensorHandles* inputs, PyObject* attrs, |
| TFE_CancellationManager* cancellation_manager, |
| TFE_OutputTensorHandles* outputs, |
| TF_Status* out_status) { |
| tensorflow::profiler::TraceMe activity( |
| "TFE_Py_ExecuteCancelable", tensorflow::profiler::TraceMeLevel::kInfo); |
| |
| TFE_Op* op = GetOp(ctx, op_name, device_name, out_status); |
| |
| auto cleaner = tensorflow::gtl::MakeCleanup([ctx, op] { ReturnOp(ctx, op); }); |
| if (!out_status->status.ok()) return; |
| |
| tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace()); |
| |
| for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) { |
| TFE_OpAddInput(op, inputs->at(i), out_status); |
| } |
| if (cancellation_manager && out_status->status.ok()) { |
| TFE_OpSetCancellationManager(op, cancellation_manager, out_status); |
| } |
| if (out_status->status.ok()) { |
| SetOpAttrs(ctx, op, attrs, 0, out_status); |
| } |
| Py_BEGIN_ALLOW_THREADS; |
| |
| int num_outputs = outputs->size(); |
| |
| if (out_status->status.ok()) { |
| TFE_Execute(op, outputs->data(), &num_outputs, out_status); |
| } |
| |
| if (out_status->status.ok()) { |
| outputs->resize(num_outputs); |
| } else { |
| TF_SetStatus(out_status, TF_GetCode(out_status), |
| tensorflow::strings::StrCat(TF_Message(out_status), |
| " [Op:", op_name, "]") |
| .c_str()); |
| } |
| |
| Py_END_ALLOW_THREADS; |
| } |
| |
| PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) { |
| tensorflow::mutex_lock l(exception_class_mutex); |
| if (exception_class != nullptr) { |
| Py_DECREF(exception_class); |
| } |
| if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) { |
| exception_class = nullptr; |
| PyErr_SetString(PyExc_TypeError, |
| "TFE_Py_RegisterExceptionClass: " |
| "Registered class should be subclass of Exception."); |
| return nullptr; |
| } |
| |
| Py_INCREF(e); |
| exception_class = e; |
| Py_RETURN_NONE; |
| } |
| |
| PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) { |
| if (fallback_exception_class != nullptr) { |
| Py_DECREF(fallback_exception_class); |
| } |
| if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) { |
| fallback_exception_class = nullptr; |
| PyErr_SetString(PyExc_TypeError, |
| "TFE_Py_RegisterFallbackExceptionClass: " |
| "Registered class should be subclass of Exception."); |
| return nullptr; |
| } else { |
| Py_INCREF(e); |
| fallback_exception_class = e; |
| Py_RETURN_NONE; |
| } |
| } |
| |
| PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) { |
| if (gradient_function != nullptr) { |
| Py_DECREF(gradient_function); |
| } |
| if (!PyCallable_Check(e)) { |
| gradient_function = nullptr; |
| PyErr_SetString(PyExc_TypeError, |
| "TFE_Py_RegisterGradientFunction: " |
| "Registered object should be function."); |
| return nullptr; |
| } else { |
| Py_INCREF(e); |
| gradient_function = e; |
| Py_RETURN_NONE; |
| } |
| } |
| |
| PyObject* TFE_Py_RegisterJVPFunction(PyObject* e) { |
| if (forward_gradient_function != nullptr) { |
| Py_DECREF(forward_gradient_function); |
| } |
| if (!PyCallable_Check(e)) { |
| forward_gradient_function = nullptr; |
| PyErr_SetString(PyExc_TypeError, |
| "TFE_Py_RegisterJVPFunction: " |
| "Registered object should be function."); |
| return nullptr; |
| } else { |
| Py_INCREF(e); |
| forward_gradient_function = e; |
| Py_RETURN_NONE; |
| } |
| } |
| |
| void RaiseFallbackException(const char* message) { |
| if (fallback_exception_class != nullptr) { |
| PyErr_SetString(fallback_exception_class, message); |
| return; |
| } |
| |
| PyErr_SetString( |
| PyExc_RuntimeError, |
| tensorflow::strings::StrCat( |
| "Fallback exception type not set, attempting to fallback due to ", |
| message) |
| .data()); |
| } |
| |
| // Format and return `status`' error message with the attached stack trace if |
| // available. `status` must have an error. |
| std::string FormatErrorStatusStackTrace(const tensorflow::Status& status) { |
| tensorflow::DCheckPyGilState(); |
| DCHECK(!status.ok()); |
| |
| if (status.stack_trace().empty()) return status.error_message(); |
| |
| const std::vector<tensorflow::StackFrame>& stack_trace = status.stack_trace(); |
| |
| PyObject* linecache = PyImport_ImportModule("linecache"); |
| PyObject* getline = |
| PyObject_GetAttr(linecache, PyUnicode_FromString("getline")); |
| DCHECK(getline); |
| |
| std::ostringstream result; |
| result << "Exception originated from\n\n"; |
| |
| for (const tensorflow::StackFrame& stack_frame : stack_trace) { |
| PyObject* line_str_obj = PyObject_CallFunction( |
| getline, const_cast<char*>("si"), stack_frame.file_name.c_str(), |
| stack_frame.line_number); |
| tensorflow::StringPiece line_str = TFE_GetPythonString(line_str_obj); |
| tensorflow::str_util::RemoveWhitespaceContext(&line_str); |
| result << " File \"" << stack_frame.file_name << "\", line " |
| << stack_frame.line_number << ", in " << stack_frame.function_name |
| << '\n'; |
| |
| if (!line_str.empty()) result << " " << line_str << '\n'; |
| Py_XDECREF(line_str_obj); |
| } |
| |
| Py_DecRef(getline); |
| Py_DecRef(linecache); |
| |
| result << '\n' << status.error_message(); |
| return result.str(); |
| } |
| |
| int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) { |
| if (status->status.ok()) return 0; |
| const char* msg = TF_Message(status); |
| if (exception == nullptr) { |
| tensorflow::mutex_lock l(exception_class_mutex); |
| if (exception_class != nullptr) { |
| tensorflow::Safe_PyObjectPtr val(Py_BuildValue( |
| "si", FormatErrorStatusStackTrace(status->status).c_str(), |
| TF_GetCode(status))); |
| if (PyErr_Occurred()) { |
| // NOTE: This hides the actual error (i.e. the reason `status` was not |
| // TF_OK), but there is nothing we can do at this point since we can't |
| // generate a reasonable error from the status. |
| // Consider adding a message explaining this. |
| return -1; |
| } |
| PyErr_SetObject(exception_class, val.get()); |
| return -1; |
| } else { |
| exception = PyExc_RuntimeError; |
| } |
| } |
  // Maybe update an already-set exception.
| PyErr_SetString(exception, msg); |
| return -1; |
| } |
| |
| int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status, |
| PyObject* exception) { |
| if (status.ok()) return 0; |
| const char* msg = status.error_message().c_str(); |
| if (exception == nullptr) { |
| tensorflow::mutex_lock l(exception_class_mutex); |
| if (exception_class != nullptr) { |
| tensorflow::Safe_PyObjectPtr val(Py_BuildValue( |
| "si", FormatErrorStatusStackTrace(status).c_str(), status.code())); |
| PyErr_SetObject(exception_class, val.get()); |
| return -1; |
| } else { |
| exception = PyExc_RuntimeError; |
| } |
| } |
  // Maybe update an already-set exception.
| PyErr_SetString(exception, msg); |
| return -1; |
| } |
| |
| const char* TFE_GetPythonString(PyObject* o) { |
| #if PY_MAJOR_VERSION >= 3 |
| if (PyBytes_Check(o)) { |
| return PyBytes_AsString(o); |
| } else { |
| return PyUnicode_AsUTF8(o); |
| } |
| #else |
| return PyBytes_AsString(o); |
| #endif |
| } |
| |
| int64_t get_uid() { return _uid++; } |
| |
| PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); } |
| |
| void TFE_DeleteContextCapsule(PyObject* context) { |
| TFE_Context* ctx = |
| reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr)); |
| auto op = ReleaseThreadLocalOp(ctx); |
| op.reset(); |
| TFE_DeleteContext(ctx); |
| } |
| |
| static tensorflow::int64 MakeInt(PyObject* integer) { |
| #if PY_MAJOR_VERSION >= 3 |
| return PyLong_AsLong(integer); |
| #else |
| return PyInt_AsLong(integer); |
| #endif |
| } |
| |
| static tensorflow::int64 FastTensorId(PyObject* tensor) { |
| if (EagerTensor_CheckExact(tensor)) { |
| return PyEagerTensor_ID(tensor); |
| } |
| PyObject* id_field = PyObject_GetAttrString(tensor, "_id"); |
| if (id_field == nullptr) { |
| return -1; |
| } |
| tensorflow::int64 id = MakeInt(id_field); |
| Py_DECREF(id_field); |
| return id; |
| } |
| |
| static tensorflow::DataType FastTensorDtype(PyObject* tensor) { |
| if (EagerTensor_CheckExact(tensor)) { |
| return PyEagerTensor_Dtype(tensor); |
| } |
| PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype"); |
| if (dtype_field == nullptr) { |
| return tensorflow::DT_INVALID; |
| } |
| PyObject* enum_field = PyObject_GetAttrString(dtype_field, "_type_enum"); |
| Py_DECREF(dtype_field); |
  if (enum_field == nullptr) {
| return tensorflow::DT_INVALID; |
| } |
| tensorflow::int64 id = MakeInt(enum_field); |
| Py_DECREF(enum_field); |
| return static_cast<tensorflow::DataType>(id); |
| } |
| |
| class PyTapeTensor { |
| public: |
| PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype, |
| const tensorflow::TensorShape& shape) |
| : id_(id), dtype_(dtype), shape_(shape) {} |
| PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype, |
| PyObject* shape) |
| : id_(id), dtype_(dtype), shape_(shape) { |
| Py_INCREF(absl::get<1>(shape_)); |
| } |
| PyTapeTensor(const PyTapeTensor& other) { |
| id_ = other.id_; |
| dtype_ = other.dtype_; |
| shape_ = other.shape_; |
| if (shape_.index() == 1) { |
| Py_INCREF(absl::get<1>(shape_)); |
| } |
| } |
| |
| ~PyTapeTensor() { |
| if (shape_.index() == 1) { |
| Py_DECREF(absl::get<1>(shape_)); |
| } |
| } |
| PyObject* GetShape() const; |
| PyObject* GetPyDType() const { return PyLong_FromLong(dtype_); } |
| tensorflow::int64 GetID() const { return id_; } |
| tensorflow::DataType GetDType() const { return dtype_; } |
| |
| PyObject* OnesLike() const; |
| PyObject* ZerosLike() const; |
| |
| private: |
| tensorflow::int64 id_; |
| tensorflow::DataType dtype_; |
| |
| // Note that if shape_.index() == 1, meaning shape_ contains a PyObject, that |
| // PyObject is the tensor itself. This is used to support tf.shape(tensor) for |
| // partially-defined shapes and tf.zeros_like(tensor) for variant-dtype |
| // tensors. |
| absl::variant<tensorflow::TensorShape, PyObject*> shape_; |
| }; |
| |
| static PyTapeTensor TapeTensorFromTensor(PyObject* tensor); |
| |
| class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction, |
| PyTapeTensor> { |
| public: |
| explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) { |
| Py_INCREF(py_vspace_); |
| } |
| |
| tensorflow::Status Initialize() { |
| num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn"); |
| if (num_elements_ == nullptr) { |
| return tensorflow::errors::InvalidArgument("invalid vspace"); |
| } |
| aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn"); |
| if (aggregate_fn_ == nullptr) { |
| return tensorflow::errors::InvalidArgument("invalid vspace"); |
| } |
| zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn"); |
| if (zeros_fn_ == nullptr) { |
| return tensorflow::errors::InvalidArgument("invalid vspace"); |
| } |
| zeros_like_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_like_fn"); |
| if (zeros_like_fn_ == nullptr) { |
| return tensorflow::errors::InvalidArgument("invalid vspace"); |
| } |
| ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn"); |
| if (ones_fn_ == nullptr) { |
| return tensorflow::errors::InvalidArgument("invalid vspace"); |
| } |
| ones_like_fn_ = PyObject_GetAttrString(py_vspace_, "ones_like_fn"); |
| if (ones_like_fn_ == nullptr) { |
| return tensorflow::errors::InvalidArgument("invalid vspace"); |
| } |
| graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn"); |
| if (graph_shape_fn_ == nullptr) { |
| return tensorflow::errors::InvalidArgument("invalid vspace"); |
| } |
| return tensorflow::Status::OK(); |
| } |
| |
| ~PyVSpace() override { |
| Py_XDECREF(num_elements_); |
| Py_XDECREF(aggregate_fn_); |
| Py_XDECREF(zeros_fn_); |
| Py_XDECREF(zeros_like_fn_); |
| Py_XDECREF(ones_fn_); |
| Py_XDECREF(ones_like_fn_); |
| Py_XDECREF(graph_shape_fn_); |
| |
| Py_DECREF(py_vspace_); |
| } |
| |
| tensorflow::int64 NumElements(PyObject* tensor) const final { |
| if (EagerTensor_CheckExact(tensor)) { |
| return PyEagerTensor_NumElements(tensor); |
| } |
| PyObject* arglist = |
| Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor)); |
| PyObject* result = PyEval_CallObject(num_elements_, arglist); |
| Py_DECREF(arglist); |
| if (result == nullptr) { |
| // The caller detects whether a python exception has been raised. |
| return -1; |
| } |
| tensorflow::int64 r = MakeInt(result); |
| Py_DECREF(result); |
| return r; |
| } |
| |
| PyObject* AggregateGradients( |
| tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final { |
| PyObject* list = PyList_New(gradient_tensors.size()); |
| for (int i = 0; i < gradient_tensors.size(); ++i) { |
| // Note: stealing a reference to the gradient tensors. |
| CHECK(gradient_tensors[i] != nullptr); |
| CHECK(gradient_tensors[i] != Py_None); |
| PyList_SET_ITEM(list, i, |
| reinterpret_cast<PyObject*>(gradient_tensors[i])); |
| } |
| PyObject* arglist = Py_BuildValue("(O)", list); |
| CHECK(arglist != nullptr); |
| PyObject* result = PyEval_CallObject(aggregate_fn_, arglist); |
| Py_DECREF(arglist); |
| Py_DECREF(list); |
| return result; |
| } |
| |
| tensorflow::int64 TensorId(PyObject* tensor) const final { |
| return FastTensorId(tensor); |
| } |
| |
| void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); } |
| |
| PyObject* Ones(PyObject* shape, PyObject* dtype) const { |
| if (PyErr_Occurred()) { |
| return nullptr; |
| } |
| PyObject* arg_list = Py_BuildValue("OO", shape, dtype); |
| PyObject* result = PyEval_CallObject(ones_fn_, arg_list); |
| Py_DECREF(arg_list); |
| return result; |
| } |
| |
| PyObject* OnesLike(PyObject* tensor) const { |
| if (PyErr_Occurred()) { |
| return nullptr; |
| } |
| return PyObject_CallFunctionObjArgs(ones_like_fn_, tensor, NULL); |
| } |
| |
| PyObject* Zeros(PyObject* shape, PyObject* dtype) const { |
| if (PyErr_Occurred()) { |
| return nullptr; |
| } |
| PyObject* arg_list = Py_BuildValue("OO", shape, dtype); |
| PyObject* result = PyEval_CallObject(zeros_fn_, arg_list); |
| Py_DECREF(arg_list); |
| return result; |
| } |
| |
| PyObject* ZerosLike(PyObject* tensor) const { |
| if (PyErr_Occurred()) { |
| return nullptr; |
| } |
| return PyObject_CallFunctionObjArgs(zeros_like_fn_, tensor, NULL); |
| } |
| |
| PyObject* GraphShape(PyObject* tensor) const { |
| PyObject* arg_list = Py_BuildValue("(O)", tensor); |
| PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list); |
| Py_DECREF(arg_list); |
| return result; |
| } |
| |
| tensorflow::Status CallBackwardFunction( |
| PyBackwardFunction* backward_function, |
| const std::vector<tensorflow::int64>& unneeded_gradients, |
| tensorflow::gtl::ArraySlice<PyObject*> output_gradients, |
| std::vector<PyObject*>* result) const final { |
| PyObject* grads = PyTuple_New(output_gradients.size()); |
| for (int i = 0; i < output_gradients.size(); ++i) { |
| if (output_gradients[i] == nullptr) { |
| Py_INCREF(Py_None); |
| PyTuple_SET_ITEM(grads, i, Py_None); |
| } else { |
| PyTuple_SET_ITEM(grads, i, |
| reinterpret_cast<PyObject*>(output_gradients[i])); |
| } |
| } |
| PyObject* py_result = (*backward_function)(grads, unneeded_gradients); |
| Py_DECREF(grads); |
| if (py_result == nullptr) { |
| return tensorflow::errors::Internal("gradient function threw exceptions"); |
| } |
| result->clear(); |
| PyObject* seq = |
| PySequence_Fast(py_result, "expected a sequence of gradients"); |
| if (seq == nullptr) { |
| return tensorflow::errors::InvalidArgument( |
| "gradient function did not return a list"); |
| } |
| int len = PySequence_Fast_GET_SIZE(seq); |
| PyObject** seq_array = PySequence_Fast_ITEMS(seq); |
| VLOG(1) << "Gradient length is " << len; |
| result->reserve(len); |
| for (int i = 0; i < len; ++i) { |
| PyObject* item = seq_array[i]; |
| if (item == Py_None) { |
| result->push_back(nullptr); |
| } else { |
| Py_INCREF(item); |
| result->push_back(item); |
| } |
| } |
| Py_DECREF(seq); |
| Py_DECREF(py_result); |
| return tensorflow::Status::OK(); |
| } |
| |
| void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); } |
| |
| PyTapeTensor TapeTensorFromGradient(PyObject* tensor) const final { |
| return TapeTensorFromTensor(tensor); |
| } |
| |
| private: |
| PyObject* py_vspace_; |
| |
| PyObject* num_elements_; |
| PyObject* aggregate_fn_; |
| PyObject* zeros_fn_; |
| PyObject* zeros_like_fn_; |
| PyObject* ones_fn_; |
| PyObject* ones_like_fn_; |
| PyObject* graph_shape_fn_; |
| }; |
| PyVSpace* py_vspace = nullptr; |
| |
| bool HasAccumulator(); |
| |
| PyObject* TFE_Py_RegisterVSpace(PyObject* e) { |
| if (py_vspace != nullptr) { |
| if (HasAccumulator()) { |
| // Accumulators reference py_vspace, so we can't swap it out while one is |
| // active. This is unlikely to ever happen. |
| MaybeRaiseExceptionFromStatus( |
| tensorflow::errors::Internal( |
| "Can't change the vspace implementation while a " |
| "forward accumulator is active."), |
| nullptr); |
| } |
| delete py_vspace; |
| } |
| |
| py_vspace = new PyVSpace(e); |
| auto status = py_vspace->Initialize(); |
| if (MaybeRaiseExceptionFromStatus(status, nullptr)) { |
| delete py_vspace; |
| return nullptr; |
| } |
| |
| Py_RETURN_NONE; |
| } |
| |
| PyObject* PyTapeTensor::GetShape() const { |
| if (shape_.index() == 0) { |
| auto& shape = absl::get<0>(shape_); |
| PyObject* py_shape = PyTuple_New(shape.dims()); |
| for (int i = 0; i < shape.dims(); ++i) { |
| PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i))); |
| } |
| |
| return py_shape; |
| } |
| |
| return py_vspace->GraphShape(absl::get<1>(shape_)); |
| } |
| |
| PyObject* PyTapeTensor::OnesLike() const { |
| if (shape_.index() == 1) { |
| PyObject* tensor = absl::get<1>(shape_); |
| return py_vspace->OnesLike(tensor); |
| } |
| PyObject* py_shape = GetShape(); |
| PyObject* py_dtype = GetPyDType(); |
| PyObject* result = py_vspace->Ones(py_shape, py_dtype); |
| Py_DECREF(py_dtype); |
| Py_DECREF(py_shape); |
| return result; |
| } |
| |
| PyObject* PyTapeTensor::ZerosLike() const { |
| if (shape_.index() == 1) { |
| PyObject* tensor = absl::get<1>(shape_); |
| return py_vspace->ZerosLike(tensor); |
| } |
| PyObject* py_shape = GetShape(); |
| PyObject* py_dtype = GetPyDType(); |
| PyObject* result = py_vspace->Zeros(py_shape, py_dtype); |
| Py_DECREF(py_dtype); |
| Py_DECREF(py_shape); |
| return result; |
| } |
| |
| // Keeps track of all variables that have been accessed during execution. |
| class VariableWatcher { |
| public: |
| VariableWatcher() {} |
| |
| ~VariableWatcher() { |
| for (const IdAndVariable& v : watched_variables_) { |
| Py_DECREF(v.variable); |
| } |
| } |
| |
| tensorflow::int64 WatchVariable(PyObject* v) { |
| tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle")); |
| if (handle == nullptr) { |
| return -1; |
| } |
| tensorflow::int64 id = FastTensorId(handle.get()); |
| |
| tensorflow::mutex_lock l(watched_variables_mu_); |
| auto insert_result = watched_variables_.emplace(id, v); |
| |
| if (insert_result.second) { |
| // Only increment the reference count if we aren't already watching this |
| // variable. |
| Py_INCREF(v); |
| } |
| |
| return id; |
| } |
| |
| PyObject* GetVariablesAsPyTuple() { |
| tensorflow::mutex_lock l(watched_variables_mu_); |
| PyObject* result = PyTuple_New(watched_variables_.size()); |
| Py_ssize_t pos = 0; |
| for (const IdAndVariable& id_and_variable : watched_variables_) { |
| PyTuple_SET_ITEM(result, pos++, id_and_variable.variable); |
| Py_INCREF(id_and_variable.variable); |
| } |
| return result; |
| } |
| |
| private: |
| // We store an IdAndVariable in the map since the map needs to be locked |
| // during insert, but should not call back into python during insert to avoid |
| // deadlocking with the GIL. |
| struct IdAndVariable { |
| tensorflow::int64 id; |
| PyObject* variable; |
| |
| IdAndVariable(tensorflow::int64 id, PyObject* variable) |
| : id(id), variable(variable) {} |
| }; |
| struct CompareById { |
| bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const { |
| return lhs.id < rhs.id; |
| } |
| }; |
| |
| tensorflow::mutex watched_variables_mu_; |
| std::set<IdAndVariable, CompareById> watched_variables_ |
| TF_GUARDED_BY(watched_variables_mu_); |
| }; |
| |
| class GradientTape |
| : public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction, |
| PyTapeTensor> { |
| public: |
| explicit GradientTape(bool persistent, bool watch_accessed_variables) |
| : tensorflow::eager::GradientTape<PyObject, PyBackwardFunction, |
| PyTapeTensor>(persistent), |
| watch_accessed_variables_(watch_accessed_variables) {} |
| |
| virtual ~GradientTape() {} |
| |
| void VariableAccessed(PyObject* v) { |
| if (watch_accessed_variables_) { |
| WatchVariable(v); |
| } |
| } |
| |
| void WatchVariable(PyObject* v) { |
| tensorflow::int64 id = variable_watcher_.WatchVariable(v); |
| |
| if (!PyErr_Occurred()) { |
| this->Watch(id); |
| } |
| } |
| |
| PyObject* GetVariablesAsPyTuple() { |
| return variable_watcher_.GetVariablesAsPyTuple(); |
| } |
| |
| private: |
| bool watch_accessed_variables_; |
| VariableWatcher variable_watcher_; |
| }; |
| |
| typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction, |
| PyTapeTensor> |
| ForwardAccumulator; |
| |
| // Incremented when a GradientTape or accumulator is newly added to a set, and |
| // used to enforce an ordering between them. |
| std::atomic_uint_fast64_t tape_nesting_id_counter(0); |
| |
| typedef struct { |
| PyObject_HEAD |
| /* Type-specific fields go here. */ |
| GradientTape* tape; |
| // A nesting order between GradientTapes and ForwardAccumulators, used to |
| // ensure that GradientTapes do not watch the products of outer |
| // ForwardAccumulators. |
| tensorflow::int64 nesting_id; |
| } TFE_Py_Tape; |
| |
| static void TFE_Py_Tape_Delete(PyObject* tape) { |
| delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape; |
| Py_TYPE(tape)->tp_free(tape); |
| } |
| |
| static PyTypeObject TFE_Py_Tape_Type = { |
| PyVarObject_HEAD_INIT(nullptr, 0) "tfe.Tape", /* tp_name */ |
| sizeof(TFE_Py_Tape), /* tp_basicsize */ |
| 0, /* tp_itemsize */ |
| &TFE_Py_Tape_Delete, /* tp_dealloc */ |
| #if PY_VERSION_HEX < 0x03080000 |
| nullptr, /* tp_print */ |
| #else |
| 0, /* tp_vectorcall_offset */ |
| #endif |
| nullptr, /* tp_getattr */ |
| nullptr, /* tp_setattr */ |
| nullptr, /* tp_reserved */ |
| nullptr, /* tp_repr */ |
| nullptr, /* tp_as_number */ |
| nullptr, /* tp_as_sequence */ |
| nullptr, /* tp_as_mapping */ |
| nullptr, /* tp_hash */ |
| nullptr, /* tp_call */ |
| nullptr, /* tp_str */ |
| nullptr, /* tp_getattro */ |
| nullptr, /* tp_setattro */ |
| nullptr, /* tp_as_buffer */ |
| Py_TPFLAGS_DEFAULT, /* tp_flags */ |
| "TFE_Py_Tape objects", /* tp_doc */ |
| }; |
| |
| typedef struct { |
| PyObject_HEAD |
| /* Type-specific fields go here. */ |
| ForwardAccumulator* accumulator; |
| // A nesting order between GradientTapes and ForwardAccumulators, used to |
| // ensure that GradientTapes do not watch the products of outer |
| // ForwardAccumulators. |
| tensorflow::int64 nesting_id; |
| } TFE_Py_ForwardAccumulator; |
| |
| static void TFE_Py_ForwardAccumulatorDelete(PyObject* accumulator) { |
| delete reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)->accumulator; |
| Py_TYPE(accumulator)->tp_free(accumulator); |
| } |
| |
| static PyTypeObject TFE_Py_ForwardAccumulator_Type = { |
| PyVarObject_HEAD_INIT(nullptr, 0) "ForwardAccumulator", /* tp_name */ |
| sizeof(TFE_Py_ForwardAccumulator), /* tp_basicsize */ |
| 0, /* tp_itemsize */ |
| &TFE_Py_ForwardAccumulatorDelete, /* tp_dealloc */ |
| #if PY_VERSION_HEX < 0x03080000 |
| nullptr, /* tp_print */ |
| #else |
| 0, /* tp_vectorcall_offset */ |
| #endif |
| nullptr, /* tp_getattr */ |
| nullptr, /* tp_setattr */ |
| nullptr, /* tp_reserved */ |
| nullptr, /* tp_repr */ |
| nullptr, /* tp_as_number */ |
| nullptr, /* tp_as_sequence */ |
| nullptr, /* tp_as_mapping */ |
| nullptr, /* tp_hash */ |
| nullptr, /* tp_call */ |
| nullptr, /* tp_str */ |
| nullptr, /* tp_getattro */ |
| nullptr, /* tp_setattro */ |
| nullptr, /* tp_as_buffer */ |
| Py_TPFLAGS_DEFAULT, /* tp_flags */ |
| "TFE_Py_ForwardAccumulator objects", /* tp_doc */ |
| }; |
| |
| typedef struct { |
| PyObject_HEAD |
| /* Type-specific fields go here. */ |
| VariableWatcher* variable_watcher; |
| } TFE_Py_VariableWatcher; |
| |
| static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) { |
| delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher) |
| ->variable_watcher; |
| Py_TYPE(variable_watcher)->tp_free(variable_watcher); |
| } |
| |
| static PyTypeObject TFE_Py_VariableWatcher_Type = { |
| PyVarObject_HEAD_INIT(nullptr, 0) "tfe.VariableWatcher", /* tp_name */ |
| sizeof(TFE_Py_VariableWatcher), /* tp_basicsize */ |
| 0, /* tp_itemsize */ |
| &TFE_Py_VariableWatcher_Delete, /* tp_dealloc */ |
| #if PY_VERSION_HEX < 0x03080000 |
| nullptr, /* tp_print */ |
| #else |
| 0, /* tp_vectorcall_offset */ |
| #endif |
| nullptr, /* tp_getattr */ |
| nullptr, /* tp_setattr */ |
| nullptr, /* tp_reserved */ |
| nullptr, /* tp_repr */ |
| nullptr, /* tp_as_number */ |
| nullptr, /* tp_as_sequence */ |
| nullptr, /* tp_as_mapping */ |
| nullptr, /* tp_hash */ |
| nullptr, /* tp_call */ |
| nullptr, /* tp_str */ |
| nullptr, /* tp_getattro */ |
| nullptr, /* tp_setattro */ |
| nullptr, /* tp_as_buffer */ |
| Py_TPFLAGS_DEFAULT, /* tp_flags */ |
| "TFE_Py_VariableWatcher objects", /* tp_doc */ |
| }; |
| |
// Note: in the current design no mutex is needed here because of the python
// GIL, which is always held when any TFE_Py_* methods are called. We should
// revisit this if/when we decide not to hold the GIL while manipulating the
// tape stack.
| tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() { |
| thread_local std::unique_ptr<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>> |
| tape_set = nullptr; |
| if (tape_set == nullptr) { |
| tape_set.reset(new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>); |
| } |
| return tape_set.get(); |
| } |
| |
| tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>* |
| GetVariableWatcherSet() { |
| thread_local std::unique_ptr< |
| tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> |
| variable_watcher_set = nullptr; |
| if (variable_watcher_set == nullptr) { |
| variable_watcher_set.reset( |
| new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>); |
| } |
| return variable_watcher_set.get(); |
| } |
| |
| // A linked hash set, where iteration is in insertion order. |
| // |
| // Nested accumulators rely on op recording happening in insertion order, so an |
| // unordered data structure like CompactPointerSet is not suitable. Outer |
| // accumulators need to observe operations first so they know to watch the inner |
| // accumulator's jvp computation. |
| // |
| // Not thread safe. |
| class AccumulatorSet { |
| public: |
| // Returns true if `element` was newly inserted, false if it already exists. |
| bool insert(TFE_Py_ForwardAccumulator* element) { |
| if (map_.find(element) != map_.end()) { |
| return false; |
| } |
| ListType::iterator it = ordered_.insert(ordered_.end(), element); |
| map_.insert(std::make_pair(element, it)); |
| return true; |
| } |
| |
| void erase(TFE_Py_ForwardAccumulator* element) { |
| MapType::iterator existing = map_.find(element); |
| if (existing == map_.end()) { |
| return; |
| } |
| ListType::iterator list_position = existing->second; |
| map_.erase(existing); |
| ordered_.erase(list_position); |
| } |
| |
| bool empty() const { return ordered_.empty(); } |
| |
| size_t size() const { return ordered_.size(); } |
| |
| private: |
| typedef std::list<TFE_Py_ForwardAccumulator*> ListType; |
| typedef tensorflow::gtl::FlatMap<TFE_Py_ForwardAccumulator*, |
| ListType::iterator> |
| MapType; |
| |
| public: |
| typedef ListType::const_iterator const_iterator; |
| typedef ListType::const_reverse_iterator const_reverse_iterator; |
| |
| const_iterator begin() const { return ordered_.begin(); } |
| const_iterator end() const { return ordered_.end(); } |
| |
| const_reverse_iterator rbegin() const { return ordered_.rbegin(); } |
| const_reverse_iterator rend() const { return ordered_.rend(); } |
| |
| private: |
| MapType map_; |
| ListType ordered_; |
| }; |
| |
| AccumulatorSet* GetAccumulatorSet() { |
| thread_local std::unique_ptr<AccumulatorSet> accumulator_set{nullptr}; |
| if (accumulator_set == nullptr) { |
| accumulator_set.reset(new AccumulatorSet); |
| } |
| return accumulator_set.get(); |
| } |
| |
| inline bool HasAccumulator() { return !GetAccumulatorSet()->empty(); } |
| |
| inline bool HasGradientTape() { return !GetTapeSet()->empty(); } |
| |
| inline bool HasAccumulatorOrTape() { |
| return HasGradientTape() || HasAccumulator(); |
| } |
| |
| // A safe copy of a set, used for tapes and accumulators. The copy is not |
| // affected by other python threads changing the set of active tapes. |
| template <typename ContainerType> |
| class SafeSetCopy { |
| public: |
| explicit SafeSetCopy(const ContainerType& to_copy) : set_copy_(to_copy) { |
| for (auto* member : set_copy_) { |
| Py_INCREF(member); |
| } |
| } |
| |
| ~SafeSetCopy() { |
| for (auto* member : set_copy_) { |
| Py_DECREF(member); |
| } |
| } |
| |
| typename ContainerType::const_iterator begin() const { |
| return set_copy_.begin(); |
| } |
| |
| typename ContainerType::const_iterator end() const { return set_copy_.end(); } |
| |
| bool empty() const { return set_copy_.empty(); } |
| size_t size() const { return set_copy_.size(); } |
| |
| protected: |
| ContainerType set_copy_; |
| }; |
| |
| class SafeTapeSet |
| : public SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>> { |
| public: |
| SafeTapeSet() |
| : SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>( |
| *GetTapeSet()) {} |
| }; |
| |
| class SafeAccumulatorSet : public SafeSetCopy<AccumulatorSet> { |
| public: |
| SafeAccumulatorSet() : SafeSetCopy<AccumulatorSet>(*GetAccumulatorSet()) {} |
| |
| typename AccumulatorSet::const_reverse_iterator rbegin() const { |
| return set_copy_.rbegin(); |
| } |
| |
| typename AccumulatorSet::const_reverse_iterator rend() const { |
| return set_copy_.rend(); |
| } |
| }; |
| |
| class SafeVariableWatcherSet |
| : public SafeSetCopy< |
| tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> { |
| public: |
| SafeVariableWatcherSet() |
| : SafeSetCopy< |
| tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>( |
| *GetVariableWatcherSet()) {} |
| }; |
| |
| bool* ThreadTapeIsStopped() { |
| thread_local bool thread_tape_is_stopped{false}; |
| return &thread_tape_is_stopped; |
| } |
| |
| void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; } |
| |
| void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; } |
| |
| PyObject* TFE_Py_TapeSetIsStopped() { |
| if (*ThreadTapeIsStopped()) { |
| Py_RETURN_TRUE; |
| } |
| Py_RETURN_FALSE; |
| } |
| |
| PyObject* TFE_Py_TapeSetNew(PyObject* persistent, |
| PyObject* watch_accessed_variables) { |
| TFE_Py_Tape_Type.tp_new = PyType_GenericNew; |
| if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr; |
| TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type); |
| tape->tape = new GradientTape(persistent == Py_True, |
| watch_accessed_variables == Py_True); |
| Py_INCREF(tape); |
| tape->nesting_id = tape_nesting_id_counter.fetch_add(1); |
| GetTapeSet()->insert(tape); |
| return reinterpret_cast<PyObject*>(tape); |
| } |
| |
| void TFE_Py_TapeSetAdd(PyObject* tape) { |
| Py_INCREF(tape); |
| TFE_Py_Tape* tfe_tape = reinterpret_cast<TFE_Py_Tape*>(tape); |
| if (!GetTapeSet()->insert(tfe_tape).second) { |
| // Already exists in the tape set. |
| Py_DECREF(tape); |
| } else { |
| tfe_tape->nesting_id = tape_nesting_id_counter.fetch_add(1); |
| } |
| } |
| |
| PyObject* TFE_Py_TapeSetIsEmpty() { |
| if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) { |
| Py_RETURN_TRUE; |
| } |
| Py_RETURN_FALSE; |
| } |
| |
| void TFE_Py_TapeSetRemove(PyObject* tape) { |
| auto* stack = GetTapeSet(); |
| stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape)); |
| // We kept a reference to the tape in the set to ensure it wouldn't get |
| // deleted under us; cleaning it up here. |
| Py_DECREF(tape); |
| } |
| |
| static std::vector<tensorflow::int64> MakeIntList(PyObject* list) { |
| if (list == Py_None) { |
| return {}; |
| } |
| PyObject* seq = PySequence_Fast(list, "expected a sequence"); |
| if (seq == nullptr) { |
| return {}; |
| } |
| int len = PySequence_Size(list); |
| PyObject** seq_array = PySequence_Fast_ITEMS(seq); |
| std::vector<tensorflow::int64> tensor_ids; |
| tensor_ids.reserve(len); |
| for (int i = 0; i < len; ++i) { |
| PyObject* item = seq_array[i]; |
| #if PY_MAJOR_VERSION >= 3 |
| if (PyLong_Check(item)) { |
| #else |
| if (PyLong_Check(item) || PyInt_Check(item)) { |
| #endif |
| tensorflow::int64 id = MakeInt(item); |
| tensor_ids.push_back(id); |
| } else { |
| tensor_ids.push_back(-1); |
| } |
| } |
| Py_DECREF(seq); |
| return tensor_ids; |
| } |
| |
| // Fill `tensor_ids` and `dtypes` from `tensors`, none of which may be |
| // null. Returns true on success and false on a Python exception. |
| bool TensorShapesAndDtypes(PyObject* tensors, |
| std::vector<tensorflow::int64>* tensor_ids, |
| std::vector<tensorflow::DataType>* dtypes) { |
| tensorflow::Safe_PyObjectPtr seq( |
| PySequence_Fast(tensors, "expected a sequence")); |
| if (seq == nullptr) { |
| return false; |
| } |
| int len = PySequence_Fast_GET_SIZE(seq.get()); |
| PyObject** seq_array = PySequence_Fast_ITEMS(seq.get()); |
| tensor_ids->reserve(len); |
| dtypes->reserve(len); |
| for (int i = 0; i < len; ++i) { |
| PyObject* item = seq_array[i]; |
| tensor_ids->push_back(FastTensorId(item)); |
| dtypes->push_back(FastTensorDtype(item)); |
| } |
| return true; |
| } |
| |
| bool TapeCouldPossiblyRecord(PyObject* tensors) { |
| if (tensors == Py_None) { |
| return false; |
| } |
| if (*ThreadTapeIsStopped()) { |
| return false; |
| } |
| if (!HasAccumulatorOrTape()) { |
| return false; |
| } |
| return true; |
| } |
| |
| bool CouldBackprop() { return !*ThreadTapeIsStopped() && HasGradientTape(); } |
| |
| bool CouldForwardprop() { return !*ThreadTapeIsStopped() && HasAccumulator(); } |
| |
| PyObject* TFE_Py_TapeSetShouldRecordBackprop(PyObject* tensors) { |
| if (!TapeCouldPossiblyRecord(tensors) || !CouldBackprop()) { |
| Py_RETURN_FALSE; |
| } |
| // TODO(apassos) consider not building a list and changing the API to check |
| // each tensor individually. |
| std::vector<tensorflow::int64> tensor_ids; |
| std::vector<tensorflow::DataType> dtypes; |
| if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) { |
| return nullptr; |
| } |
| auto tape_set = *GetTapeSet(); |
| for (TFE_Py_Tape* tape : tape_set) { |
| if (tape->tape->ShouldRecord(tensor_ids, dtypes)) { |
| Py_RETURN_TRUE; |
| } |
| } |
| |
| Py_RETURN_FALSE; |
| } |
| |
| PyObject* TFE_Py_ForwardAccumulatorPushState() { |
| auto forward_accumulators = *GetAccumulatorSet(); |
| for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) { |
| accumulator->accumulator->PushState(); |
| } |
| Py_RETURN_NONE; |
| } |
| |
| PyObject* TFE_Py_ForwardAccumulatorPopState() { |
| auto forward_accumulators = *GetAccumulatorSet(); |
| for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) { |
| accumulator->accumulator->PopState(); |
| } |
| Py_RETURN_NONE; |
| } |
| |
| PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) { |
| if (!TapeCouldPossiblyRecord(tensors)) { |
| return GetPythonObjectFromInt(0); |
| } |
| std::vector<tensorflow::int64> tensor_ids; |
| std::vector<tensorflow::DataType> dtypes; |
| if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) { |
| return nullptr; |
| } |
| |
| // If there is a persistent tape watching, or if there are multiple tapes |
| // watching, we'll return immediately indicating that higher-order tape |
| // gradients are possible. |
| bool some_tape_watching = false; |
| if (CouldBackprop()) { |
| auto tape_set = *GetTapeSet(); |
| for (TFE_Py_Tape* tape : tape_set) { |
| if (tape->tape->ShouldRecord(tensor_ids, dtypes)) { |
| if (tape->tape->IsPersistent() || some_tape_watching) { |
| // Either this is the second tape watching, or this tape is |
| // persistent: higher-order gradients are possible. |
| return GetPythonObjectFromInt(2); |
| } |
| some_tape_watching = true; |
| } |
| } |
| } |
| if (CouldForwardprop()) { |
| auto forward_accumulators = *GetAccumulatorSet(); |
| for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) { |
| if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) { |
| if (some_tape_watching) { |
| // This is the second tape watching: higher-order gradients are |
| // possible. Note that there's no equivalent of persistence for |
| // forward-mode. |
| return GetPythonObjectFromInt(2); |
| } |
| some_tape_watching = true; |
| } |
| } |
| } |
| if (some_tape_watching) { |
| // There's exactly one non-persistent tape. The user can request first-order |
| // gradients but won't be able to get higher-order tape gradients. |
| return GetPythonObjectFromInt(1); |
| } else { |
| // There are no tapes. The user can't request tape gradients. |
| return GetPythonObjectFromInt(0); |
| } |
| } |
| |
| void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) { |
| if (!CouldBackprop()) { |
| return; |
| } |
| tensorflow::int64 tensor_id = FastTensorId(tensor); |
| if (PyErr_Occurred()) { |
| return; |
| } |
| reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id); |
| } |
| |
| bool ListContainsNone(PyObject* list) { |
| if (list == Py_None) return true; |
| tensorflow::Safe_PyObjectPtr seq( |
| PySequence_Fast(list, "expected a sequence")); |
| if (seq == nullptr) { |
| return false; |
| } |
| |
| int len = PySequence_Size(list); |
| PyObject** seq_array = PySequence_Fast_ITEMS(seq.get()); |
| for (int i = 0; i < len; ++i) { |
| PyObject* item = seq_array[i]; |
| if (item == Py_None) return true; |
| } |
| |
| return false; |
| } |
| |
| static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) { |
| if (EagerTensor_CheckExact(tensor)) { |
| tensorflow::ImmediateExecutionTensorHandle* handle = |
| tensorflow::unwrap(EagerTensor_Handle(tensor)); |
| tensorflow::int64 id = PyEagerTensor_ID(tensor); |
| tensorflow::DataType dtype = |
| static_cast<tensorflow::DataType>(handle->DataType()); |
| if (dtype == tensorflow::DT_VARIANT) { |
| return PyTapeTensor(id, dtype, tensor); |
| } |
| |
| tensorflow::TensorShape tensor_shape; |
| int num_dims; |
| tensorflow::Status status = handle->NumDims(&num_dims); |
| if (status.ok()) { |
| for (int i = 0; i < num_dims; ++i) { |
| tensorflow::int64 dim_size; |
| status = handle->Dim(i, &dim_size); |
| if (!status.ok()) break; |
| tensor_shape.AddDim(dim_size); |
| } |
| } |
| |
| if (MaybeRaiseExceptionFromStatus(status, nullptr)) { |
| return PyTapeTensor(id, static_cast<tensorflow::DataType>(0), |
| tensorflow::TensorShape({})); |
| } else { |
| return PyTapeTensor(id, dtype, tensor_shape); |
| } |
| } |
| tensorflow::int64 id = FastTensorId(tensor); |
| if (PyErr_Occurred()) { |
| return PyTapeTensor(id, static_cast<tensorflow::DataType>(0), |
| tensorflow::TensorShape({})); |
| } |
| PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype"); |
| PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum"); |
| Py_DECREF(dtype_object); |
| tensorflow::DataType dtype = |
| static_cast<tensorflow::DataType>(MakeInt(dtype_enum)); |
| Py_DECREF(dtype_enum); |
| if (PyErr_Occurred()) { |
| return PyTapeTensor(id, static_cast<tensorflow::DataType>(0), |
| tensorflow::TensorShape({})); |
| } |
| static char _shape_tuple[] = "_shape_tuple"; |
| tensorflow::Safe_PyObjectPtr shape_tuple( |
| PyObject_CallMethod(tensor, _shape_tuple, nullptr)); |
| if (PyErr_Occurred()) { |
| return PyTapeTensor(id, static_cast<tensorflow::DataType>(0), |
| tensorflow::TensorShape({})); |
| } |
| |
| if (ListContainsNone(shape_tuple.get()) || dtype == tensorflow::DT_VARIANT) { |
| return PyTapeTensor(id, dtype, tensor); |
| } |
| |
| auto l = MakeIntList(shape_tuple.get()); |
  // Replace any -1s (accidental Nones, which can occur in graph mode and can
  // cause errors in shape construction) with 0s.
| for (auto& c : l) { |
| if (c < 0) { |
| c = 0; |
| } |
| } |
| tensorflow::TensorShape shape(l); |
| return PyTapeTensor(id, dtype, shape); |
| } |
| |
| // Populates output_info from output_seq, which must come from PySequence_Fast. |
| // |
| // Does not take ownership of output_seq. Returns true on success and false if a |
| // Python exception has been set. |
| bool TapeTensorsFromTensorSequence(PyObject* output_seq, |
| std::vector<PyTapeTensor>* output_info) { |
| Py_ssize_t output_len = PySequence_Fast_GET_SIZE(output_seq); |
| PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq); |
| output_info->reserve(output_len); |
| for (Py_ssize_t i = 0; i < output_len; ++i) { |
| output_info->push_back(TapeTensorFromTensor(output_seq_array[i])); |
| if (PyErr_Occurred() != nullptr) { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) { |
| PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); |
| if (seq == nullptr) { |
| return {}; |
| } |
| int len = PySequence_Fast_GET_SIZE(seq); |
| PyObject** seq_array = PySequence_Fast_ITEMS(seq); |
| std::vector<tensorflow::int64> list; |
| list.reserve(len); |
| for (int i = 0; i < len; ++i) { |
| PyObject* tensor = seq_array[i]; |
| list.push_back(FastTensorId(tensor)); |
| if (PyErr_Occurred()) { |
| Py_DECREF(seq); |
| return list; |
| } |
| } |
| Py_DECREF(seq); |
| return list; |
| } |
| |
| void TFE_Py_TapeVariableAccessed(PyObject* variable) { |
| if (!CouldBackprop()) { |
| return; |
| } |
| for (TFE_Py_Tape* tape : SafeTapeSet()) { |
| tape->tape->VariableAccessed(variable); |
| } |
| } |
| |
| void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) { |
| if (!CouldBackprop()) { |
| return; |
| } |
| reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable); |
| } |
| |
| PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) { |
| return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple(); |
| } |
| |
| PyObject* TFE_Py_VariableWatcherNew() { |
| TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew; |
| if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr; |
| TFE_Py_VariableWatcher* variable_watcher = |
| PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type); |
| variable_watcher->variable_watcher = new VariableWatcher(); |
| Py_INCREF(variable_watcher); |
| GetVariableWatcherSet()->insert(variable_watcher); |
| return reinterpret_cast<PyObject*>(variable_watcher); |
| } |
| |
| void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) { |
| auto* stack = GetVariableWatcherSet(); |
| stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)); |
| // We kept a reference to the variable watcher in the set to ensure it |
| // wouldn't get deleted under us; cleaning it up here. |
| Py_DECREF(variable_watcher); |
| } |
| |
| void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) { |
| for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) { |
| variable_watcher->variable_watcher->WatchVariable(variable); |
| } |
| } |
| |
| PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) { |
| return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher) |
| ->variable_watcher->GetVariablesAsPyTuple(); |
| } |
| |
| namespace { |
| std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) { |
| PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); |
| if (seq == nullptr) { |
| return {}; |
| } |
| int len = PySequence_Fast_GET_SIZE(seq); |
| PyObject** seq_array = PySequence_Fast_ITEMS(seq); |
| std::vector<tensorflow::DataType> list; |
| list.reserve(len); |
| for (int i = 0; i < len; ++i) { |
| PyObject* tensor = seq_array[i]; |
| list.push_back(FastTensorDtype(tensor)); |
| } |
| Py_DECREF(seq); |
| return list; |
| } |
| |
| PyObject* ForwardAccumulatorDeleteGradient(PyObject* tensor_id, |
| PyObject* weak_tensor_ref) { |
| tensorflow::int64 parsed_tensor_id = MakeInt(tensor_id); |
| for (TFE_Py_ForwardAccumulator* accumulator : *GetAccumulatorSet()) { |
| accumulator->accumulator->DeleteGradient(parsed_tensor_id); |
| } |
| Py_DECREF(weak_tensor_ref); |
| Py_DECREF(tensor_id); |
| Py_INCREF(Py_None); |
| return Py_None; |
| } |
| |
| static PyMethodDef forward_accumulator_delete_gradient_method_def = { |
| "ForwardAccumulatorDeleteGradient", ForwardAccumulatorDeleteGradient, |
| METH_O, "ForwardAccumulatorDeleteGradient"}; |
| |
| void RegisterForwardAccumulatorCleanup(PyObject* tensor, |
| tensorflow::int64 tensor_id) { |
| tensorflow::Safe_PyObjectPtr callback( |
| PyCFunction_New(&forward_accumulator_delete_gradient_method_def, |
| PyLong_FromLong(tensor_id))); |
| // We need to keep a reference to the weakref active if we want our callback |
| // called. The callback itself now owns the weakref object and the tensor ID |
| // object. |
| PyWeakref_NewRef(tensor, callback.get()); |
| } |
| |
| void TapeSetRecordBackprop( |
| const string& op_type, const std::vector<PyTapeTensor>& output_info, |
| const std::vector<tensorflow::int64>& input_ids, |
| const std::vector<tensorflow::DataType>& input_dtypes, |
| const std::function<PyBackwardFunction*()>& backward_function_getter, |
| const std::function<void(PyBackwardFunction*)>& backward_function_killer, |
| tensorflow::uint64 max_gradient_tape_id) { |
| if (!CouldBackprop()) { |
| return; |
| } |
| for (TFE_Py_Tape* tape : SafeTapeSet()) { |
| if (tape->nesting_id < max_gradient_tape_id) { |
| tape->tape->RecordOperation(op_type, output_info, input_ids, input_dtypes, |
| backward_function_getter, |
| backward_function_killer); |
| } |
| } |
| } |
| |
| bool TapeSetRecordForwardprop( |
| const string& op_type, PyObject* output_seq, |
| const std::vector<PyTapeTensor>& output_info, PyObject* input_tensors, |
| const std::vector<tensorflow::int64>& input_ids, |
| const std::vector<tensorflow::DataType>& input_dtypes, |
| const std::function<PyBackwardFunction*()>& backward_function_getter, |
| const std::function<void(PyBackwardFunction*)>& backward_function_killer, |
| const tensorflow::eager::ForwardFunction<PyObject>* forward_function, |
| PyObject* forwardprop_output_indices, |
| tensorflow::uint64* max_gradient_tape_id) { |
| *max_gradient_tape_id = std::numeric_limits<tensorflow::uint64>::max(); |
| if (!CouldForwardprop()) { |
| return true; |
| } |
| auto accumulator_set = SafeAccumulatorSet(); |
| tensorflow::Safe_PyObjectPtr input_seq( |
| PySequence_Fast(input_tensors, "expected a sequence of tensors")); |
| if (input_seq == nullptr || PyErr_Occurred()) return false; |
| Py_ssize_t input_len = PySequence_Fast_GET_SIZE(input_seq.get()); |
| PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq); |
| for (int i = 0; i < output_info.size(); ++i) { |
| RegisterForwardAccumulatorCleanup(output_seq_array[i], |
| output_info[i].GetID()); |
| } |
| if (forwardprop_output_indices != nullptr && |
| forwardprop_output_indices != Py_None) { |
| tensorflow::Safe_PyObjectPtr indices_fast(PySequence_Fast( |
| forwardprop_output_indices, "Expected a sequence of indices")); |
| if (indices_fast == nullptr || PyErr_Occurred()) { |
| return false; |
| } |
    if (PySequence_Fast_GET_SIZE(indices_fast.get()) !=
        accumulator_set.size()) {
      MaybeRaiseExceptionFromStatus(
          tensorflow::errors::Internal(
              "Accumulators were added or removed from the active set "
              "between packing and unpacking."),
          nullptr);
      return false;
    }
| PyObject** indices_fast_array = PySequence_Fast_ITEMS(indices_fast.get()); |
| Py_ssize_t accumulator_index = 0; |
| for (AccumulatorSet::const_reverse_iterator it = accumulator_set.rbegin(); |
| it != accumulator_set.rend(); ++it, ++accumulator_index) { |
| tensorflow::Safe_PyObjectPtr jvp_index_seq( |
| PySequence_Fast(indices_fast_array[accumulator_index], |
| "Expected a sequence of jvp indices.")); |
| if (jvp_index_seq == nullptr || PyErr_Occurred()) { |
| return false; |
| } |
| Py_ssize_t num_jvps = PySequence_Fast_GET_SIZE(jvp_index_seq.get()); |
| PyObject** jvp_index_seq_array = |
| PySequence_Fast_ITEMS(jvp_index_seq.get()); |
| for (Py_ssize_t jvp_index = 0; jvp_index < num_jvps; ++jvp_index) { |
| PyObject* tuple = jvp_index_seq_array[jvp_index]; |
| tensorflow::int64 primal_tensor_id = |
| output_info[MakeInt(PyTuple_GetItem(tuple, 0))].GetID(); |
| (*it)->accumulator->Watch( |
| primal_tensor_id, |
| output_seq_array[MakeInt(PyTuple_GetItem(tuple, 1))]); |
| } |
| } |
| } else { |
| std::vector<PyTapeTensor> input_info; |
| input_info.reserve(input_len); |
| PyObject** input_seq_array = PySequence_Fast_ITEMS(input_seq.get()); |
| for (Py_ssize_t i = 0; i < input_len; ++i) { |
| input_info.push_back(TapeTensorFromTensor(input_seq_array[i])); |
| } |
| for (TFE_Py_ForwardAccumulator* accumulator : accumulator_set) { |
| tensorflow::Status status = accumulator->accumulator->Accumulate( |
| op_type, input_info, output_info, input_ids, input_dtypes, |
| forward_function, backward_function_getter, backward_function_killer); |
| if (PyErr_Occurred()) return false; // Don't swallow Python exceptions. |
| if (MaybeRaiseExceptionFromStatus(status, nullptr)) { |
| return false; |
| } |
| if (accumulator->accumulator->BusyAccumulating()) { |
| // Ensure inner accumulators don't see outer accumulators' jvps. This |
| // mostly happens on its own, with some potentially surprising |
| // exceptions, so the blanket policy is for consistency. |
| *max_gradient_tape_id = accumulator->nesting_id; |
| break; |
| } |
| } |
| } |
| return true; |
| } |
| |
| PyObject* TangentsAsPyTuple(const std::vector<PyObject*>& input_tangents) { |
| PyObject* py_input_tangents = PyTuple_New(input_tangents.size()); |
| for (int i = 0; i < input_tangents.size(); ++i) { |
| PyObject* element; |
| if (input_tangents[i] == nullptr) { |
| element = Py_None; |
| } else { |
| element = input_tangents[i]; |
| } |
| Py_INCREF(element); |
| PyTuple_SET_ITEM(py_input_tangents, i, element); |
| } |
| return py_input_tangents; |
| } |
| |
| tensorflow::Status ParseTangentOutputs( |
| PyObject* user_output, std::vector<PyObject*>* output_tangents) { |
| if (user_output == Py_None) { |
| // No connected gradients. |
| return tensorflow::Status::OK(); |
| } |
| tensorflow::Safe_PyObjectPtr fast_result( |
| PySequence_Fast(user_output, "expected a sequence of forward gradients")); |
| if (fast_result == nullptr) { |
| return tensorflow::errors::InvalidArgument( |
| "forward gradient function did not return a sequence."); |
| } |
| int len = PySequence_Fast_GET_SIZE(fast_result.get()); |
| PyObject** fast_result_array = PySequence_Fast_ITEMS(fast_result.get()); |
| output_tangents->reserve(len); |
| for (int i = 0; i < len; ++i) { |
| PyObject* item = fast_result_array[i]; |
| if (item == Py_None) { |
| output_tangents->push_back(nullptr); |
| } else { |
| Py_INCREF(item); |
| output_tangents->push_back(item); |
| } |
| } |
| return tensorflow::Status::OK(); |
| } |
| |
| // Calls the registered forward_gradient_function, computing `output_tangents` |
| // from `input_tangents`. `output_tangents` must not be null. |
| // |
| // `op_name`, `attrs`, `inputs`, and `results` describe the operation for which |
| // the forward function is being called. |
| tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs, |
| PyObject* inputs, PyObject* results, |
| const std::vector<PyObject*>& input_tangents, |
| std::vector<PyObject*>* output_tangents, |
| bool use_batch) { |
| if (forward_gradient_function == nullptr) { |
| return tensorflow::errors::Internal( |
| "No forward gradient function registered."); |
| } |
| tensorflow::Safe_PyObjectPtr py_input_tangents( |
| TangentsAsPyTuple(input_tangents)); |
| |
| // Normalize the input sequence to a tuple so it works with function |
| // caching; otherwise it may be an opaque _InputList object. |
| tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs)); |
| PyObject* to_batch = (use_batch) ? Py_True : Py_False; |
| tensorflow::Safe_PyObjectPtr callback_args( |
| Py_BuildValue("OOOOOO", op_name, attrs, input_tuple.get(), results, |
| py_input_tangents.get(), to_batch)); |
| tensorflow::Safe_PyObjectPtr py_result( |
| PyObject_CallObject(forward_gradient_function, callback_args.get())); |
| if (py_result == nullptr || PyErr_Occurred()) { |
| return tensorflow::errors::Internal( |
| "forward gradient function threw exceptions"); |
| } |
| return ParseTangentOutputs(py_result.get(), output_tangents); |
| } |
| |
| // Like CallJVPFunction, but calls a pre-bound forward function. |
| // These are passed in from a record_gradient argument. |
| tensorflow::Status CallOpSpecificJVPFunction( |
| PyObject* op_specific_forward_function, |
| const std::vector<PyObject*>& input_tangents, |
| std::vector<PyObject*>* output_tangents) { |
| tensorflow::Safe_PyObjectPtr py_input_tangents( |
| TangentsAsPyTuple(input_tangents)); |
| |
| tensorflow::Safe_PyObjectPtr py_result(PyObject_CallObject( |
| op_specific_forward_function, py_input_tangents.get())); |
| if (py_result == nullptr || PyErr_Occurred()) { |
| return tensorflow::errors::Internal( |
| "forward gradient function threw exceptions"); |
| } |
| return ParseTangentOutputs(py_result.get(), output_tangents); |
| } |
| |
| bool ParseOpTypeString(PyObject* op_type, string* op_type_string) { |
| if (PyBytes_Check(op_type)) { |
| *op_type_string = PyBytes_AsString(op_type); |
| } else if (PyUnicode_Check(op_type)) { |
| #if PY_MAJOR_VERSION >= 3 |
| *op_type_string = PyUnicode_AsUTF8(op_type); |
| #else |
| PyObject* py_str = PyUnicode_AsUTF8String(op_type); |
| if (py_str == nullptr) { |
| return false; |
| } |
| *op_type_string = PyBytes_AS_STRING(py_str); |
| Py_DECREF(py_str); |
| #endif |
| } else { |
| PyErr_SetString(PyExc_RuntimeError, "op_type should be a string."); |
| return false; |
| } |
| return true; |
| } |
| |
| bool TapeSetRecordOperation( |
| PyObject* op_type, PyObject* input_tensors, PyObject* output_tensors, |
| const std::vector<tensorflow::int64>& input_ids, |
| const std::vector<tensorflow::DataType>& input_dtypes, |
| const std::function<PyBackwardFunction*()>& backward_function_getter, |
| const std::function<void(PyBackwardFunction*)>& backward_function_killer, |
| const tensorflow::eager::ForwardFunction<PyObject>* forward_function) { |
| std::vector<PyTapeTensor> output_info; |
| tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast( |
| output_tensors, "expected a sequence of integer tensor ids")); |
| if (PyErr_Occurred() || |
| !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) { |
| return false; |
| } |
| string op_type_str; |
| if (!ParseOpTypeString(op_type, &op_type_str)) { |
| return false; |
| } |
| tensorflow::uint64 max_gradient_tape_id; |
| if (!TapeSetRecordForwardprop( |
| op_type_str, output_seq.get(), output_info, input_tensors, input_ids, |
| input_dtypes, backward_function_getter, backward_function_killer, |
| forward_function, nullptr /* No special-cased jvps. */, |
| &max_gradient_tape_id)) { |
| return false; |
| } |
| TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes, |
| backward_function_getter, backward_function_killer, |
| max_gradient_tape_id); |
| return true; |
| } |
| } // namespace |
| |
| PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type, |
| PyObject* output_tensors, |
| PyObject* input_tensors, |
| PyObject* backward_function, |
| PyObject* forward_function) { |
| if (!HasAccumulatorOrTape() || *ThreadTapeIsStopped()) { |
| Py_RETURN_NONE; |
| } |
| std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors); |
| if (PyErr_Occurred()) return nullptr; |
| |
| std::vector<tensorflow::DataType> input_dtypes = |
| MakeTensorDtypeList(input_tensors); |
| if (PyErr_Occurred()) return nullptr; |
| |
| std::function<PyBackwardFunction*()> backward_function_getter( |
| [backward_function]() { |
| Py_INCREF(backward_function); |
| PyBackwardFunction* function = new PyBackwardFunction( |
| [backward_function](PyObject* out_grads, |
| const std::vector<tensorflow::int64>& unused) { |
| return PyObject_CallObject(backward_function, out_grads); |
| }); |
| return function; |
| }); |
| std::function<void(PyBackwardFunction*)> backward_function_killer( |
| [backward_function](PyBackwardFunction* py_backward_function) { |
| Py_DECREF(backward_function); |
| delete py_backward_function; |
| }); |
| |
| if (forward_function == Py_None) { |
| if (!TapeSetRecordOperation( |
| op_type, input_tensors, output_tensors, input_ids, input_dtypes, |
| backward_function_getter, backward_function_killer, |
| nullptr /* No special-cased forward function */)) { |
| return nullptr; |
| } |
| } else { |
| tensorflow::eager::ForwardFunction<PyObject> wrapped_forward_function( |
| [forward_function](const std::vector<PyObject*>& input_tangents, |
| std::vector<PyObject*>* output_tangents, bool use_batch=false) { |
| return CallOpSpecificJVPFunction(forward_function, input_tangents, |
| output_tangents); |
| }); |
| if (!TapeSetRecordOperation( |
| op_type, input_tensors, output_tensors, input_ids, input_dtypes, |
| backward_function_getter, backward_function_killer, |
| &wrapped_forward_function)) { |
| return nullptr; |
| } |
| } |
| Py_RETURN_NONE; |
| } |
| |
| PyObject* TFE_Py_TapeSetRecordOperationForwardprop( |
| PyObject* op_type, PyObject* output_tensors, PyObject* input_tensors, |
| PyObject* backward_function, PyObject* forwardprop_output_indices) { |
| if (!HasAccumulator() || *ThreadTapeIsStopped()) { |
| Py_RETURN_NONE; |
| } |
| std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors); |
| if (PyErr_Occurred()) return nullptr; |
| |
| std::vector<tensorflow::DataType> input_dtypes = |
| MakeTensorDtypeList(input_tensors); |
| if (PyErr_Occurred()) return nullptr; |
| |
| std::function<PyBackwardFunction*()> backward_function_getter( |
| [backward_function]() { |
| Py_INCREF(backward_function); |
| PyBackwardFunction* function = new PyBackwardFunction( |
| [backward_function](PyObject* out_grads, |
| const std::vector<tensorflow::int64>& unused) { |
| return PyObject_CallObject(backward_function, out_grads); |
| }); |
| return function; |
| }); |
| std::function<void(PyBackwardFunction*)> backward_function_killer( |
| [backward_function](PyBackwardFunction* py_backward_function) { |
| Py_DECREF(backward_function); |
| delete py_backward_function; |
| }); |
| std::vector<PyTapeTensor> output_info; |
| tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast( |
| output_tensors, "expected a sequence of integer tensor ids")); |
| if (PyErr_Occurred() || |
| !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) { |
| return nullptr; |
| } |
| string op_type_str; |
| if (!ParseOpTypeString(op_type, &op_type_str)) { |
| return nullptr; |
| } |
| tensorflow::uint64 max_gradient_tape_id; |
| if (!TapeSetRecordForwardprop( |
| op_type_str, output_seq.get(), output_info, input_tensors, input_ids, |
| input_dtypes, backward_function_getter, backward_function_killer, |
| nullptr /* no special-cased forward function */, |
| forwardprop_output_indices, &max_gradient_tape_id)) { |
| return nullptr; |
| } |
| Py_RETURN_NONE; |
| } |
| |
| PyObject* TFE_Py_TapeSetRecordOperationBackprop(PyObject* op_type, |
| PyObject* output_tensors, |
| PyObject* input_tensors, |
| PyObject* backward_function) { |
| if (!CouldBackprop()) { |
| Py_RETURN_NONE; |
| } |
| std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors); |
| if (PyErr_Occurred()) return nullptr; |
| |
| std::vector<tensorflow::DataType> input_dtypes = |
| MakeTensorDtypeList(input_tensors); |
| if (PyErr_Occurred()) return nullptr; |
| |
| std::function<PyBackwardFunction*()> backward_function_getter( |
| [backward_function]() { |
| Py_INCREF(backward_function); |
| PyBackwardFunction* function = new PyBackwardFunction( |
| [backward_function](PyObject* out_grads, |
| const std::vector<tensorflow::int64>& unused) { |
| return PyObject_CallObject(backward_function, out_grads); |
| }); |
| return function; |
| }); |
| std::function<void(PyBackwardFunction*)> backward_function_killer( |
| [backward_function](PyBackwardFunction* py_backward_function) { |
| Py_DECREF(backward_function); |
| delete py_backward_function; |
| }); |
| std::vector<PyTapeTensor> output_info; |
| tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast( |
| output_tensors, "expected a sequence of integer tensor ids")); |
| if (PyErr_Occurred() || |
| !TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) { |
| return nullptr; |
| } |
| string op_type_str; |
| if (!ParseOpTypeString(op_type, &op_type_str)) { |
| return nullptr; |
| } |
| TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes, |
| backward_function_getter, backward_function_killer, |
| // No filtering based on relative ordering with forward |
| // accumulators. |
| std::numeric_limits<tensorflow::uint64>::max()); |
| Py_RETURN_NONE; |
| } |
| |
| void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) { |
| for (TFE_Py_Tape* tape : *GetTapeSet()) { |
| tape->tape->DeleteTrace(tensor_id); |
| } |
| } |
| |
| std::vector<PyObject*> MakeTensorList(PyObject* tensors) { |
| PyObject* seq = PySequence_Fast(tensors, "expected a sequence"); |
| if (seq == nullptr) { |
| return {}; |
| } |
| int len = PySequence_Fast_GET_SIZE(seq); |
| PyObject** seq_array = PySequence_Fast_ITEMS(seq); |
| std::vector<PyObject*> list(seq_array, seq_array + len); |
| Py_DECREF(seq); |
| return list; |
| } |
| |
| PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target, |
| PyObject* sources, PyObject* output_gradients, |
| PyObject* sources_raw, |
| PyObject* unconnected_gradients, |
| TF_Status* status) { |
| TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape); |
| if (!tape_obj->tape->IsPersistent()) { |
| auto* tape_set = GetTapeSet(); |
| if (tape_set->find(tape_obj) != tape_set->end()) { |
| PyErr_SetString(PyExc_RuntimeError, |
| "gradient() cannot be invoked within the " |
| "GradientTape context (i.e., while operations are being " |
| "recorded). Either move the call to gradient() to be " |
| "outside the 'with tf.GradientTape' block, or " |
| "use a persistent tape: " |
| "'with tf.GradientTape(persistent=true)'"); |
| return nullptr; |
| } |
| } |
| |
| std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target); |
| if (PyErr_Occurred()) { |
| return nullptr; |
| } |
| std::vector<tensorflow::int64> sources_vec = MakeTensorIDList(sources); |
| if (PyErr_Occurred()) { |
| return nullptr; |
| } |
| tensorflow::gtl::FlatSet<tensorflow::int64> sources_set(sources_vec.begin(), |
| sources_vec.end()); |
| |
| tensorflow::Safe_PyObjectPtr seq = |
| tensorflow::make_safe(PySequence_Fast(target, "expected a sequence")); |
| int len = PySequence_Fast_GET_SIZE(seq.get()); |
| PyObject** seq_array = PySequence_Fast_ITEMS(seq.get()); |
| std::unordered_map<tensorflow::int64, PyTapeTensor> |
| source_tensors_that_are_targets; |
| for (int i = 0; i < len; ++i) { |
| tensorflow::int64 target_id = target_vec[i]; |
| if (sources_set.find(target_id) != sources_set.end()) { |
| auto tensor = seq_array[i]; |
| source_tensors_that_are_targets.insert( |
| std::make_pair(target_id, TapeTensorFromTensor(tensor))); |
| } |
| if (PyErr_Occurred()) { |
| return nullptr; |
| } |
| } |
| if (PyErr_Occurred()) { |
| return nullptr; |
| } |
| |
| std::vector<PyObject*> outgrad_vec; |
| if (output_gradients != Py_None) { |
| outgrad_vec = MakeTensorList(output_gradients); |
| if (PyErr_Occurred()) { |
| return nullptr; |
| } |
| for (PyObject* tensor : outgrad_vec) { |
| // Calling the backward function will eat a reference to the tensors in |
| // outgrad_vec, so we need to increase their reference count. |
| Py_INCREF(tensor); |
| } |
| } |
| std::vector<PyObject*> result; |
| status->status = tape_obj->tape->ComputeGradient( |
| *py_vspace, target_vec, sources_vec, source_tensors_that_are_targets, |
| outgrad_vec, &result); |
| if (!status->status.ok()) { |
| if (PyErr_Occurred()) { |
| // Do not propagate the erroneous status as that would swallow the |
| // exception which caused the problem. |
| status->status = tensorflow::Status::OK(); |
| } |
| return nullptr; |
| } |
| |
| bool unconnected_gradients_zero = |
| strcmp(TFE_GetPythonString(unconnected_gradients), "zero") == 0; |
| std::vector<PyObject*> sources_obj; |
| if (unconnected_gradients_zero) { |
| // Uses the "raw" sources here so it can properly make a zeros tensor even |
| // if there are resource variables as sources. |
| sources_obj = MakeTensorList(sources_raw); |
| } |
| |
| if (!result.empty()) { |
| PyObject* py_result = PyList_New(result.size()); |
| tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size()); |
| for (int i = 0; i < result.size(); ++i) { |
| if (result[i] == nullptr) { |
| if (unconnected_gradients_zero) { |
| // generate a zeros tensor in the shape of sources[i] |
| tensorflow::DataType dtype = FastTensorDtype(sources_obj[i]); |
| PyTapeTensor tensor = |
| PyTapeTensor(sources_vec[i], dtype, sources_obj[i]); |
| result[i] = tensor.ZerosLike(); |
| } else { |
| Py_INCREF(Py_None); |
| result[i] = Py_None; |
| } |
| } else if (seen_results.find(result[i]) != seen_results.end()) { |
| Py_INCREF(result[i]); |
| } |
| seen_results.insert(result[i]); |
| PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i])); |
| } |
| return py_result; |
| } |
| return PyList_New(0); |
| } |
| |
| PyObject* TFE_Py_ForwardAccumulatorNew(bool use_batch) { |
| TFE_Py_ForwardAccumulator_Type.tp_new = PyType_GenericNew; |
| if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr; |
| TFE_Py_ForwardAccumulator* accumulator = |
| PyObject_NEW(TFE_Py_ForwardAccumulator, &TFE_Py_ForwardAccumulator_Type); |
| if (py_vspace == nullptr) { |
| MaybeRaiseExceptionFromStatus( |
| tensorflow::errors::Internal( |
| "ForwardAccumulator requires a PyVSpace to be registered."), |
| nullptr); |
| } |
| accumulator->accumulator = new ForwardAccumulator(*py_vspace, use_batch); |
| return reinterpret_cast<PyObject*>(accumulator); |
| } |
| |
| PyObject* TFE_Py_ForwardAccumulatorSetAdd(PyObject* accumulator) { |
| TFE_Py_ForwardAccumulator* c_accumulator( |
| reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)); |
| c_accumulator->nesting_id = tape_nesting_id_counter.fetch_add(1); |
| if (GetAccumulatorSet()->insert(c_accumulator)) { |
| Py_INCREF(accumulator); |
| Py_RETURN_NONE; |
| } else { |
| MaybeRaiseExceptionFromStatus( |
| tensorflow::errors::Internal( |
| "A ForwardAccumulator was added to the active set twice."), |
| nullptr); |
| return nullptr; |
| } |
| } |
| |
| void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator) { |
| GetAccumulatorSet()->erase( |
| reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)); |
| Py_DECREF(accumulator); |
| } |
| |
| void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor, |
| PyObject* tangent) { |
| tensorflow::int64 tensor_id = FastTensorId(tensor); |
| reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator) |
| ->accumulator->Watch(tensor_id, tangent); |
| RegisterForwardAccumulatorCleanup(tensor, tensor_id); |
| } |
| |
| // Returns a new reference to the JVP Tensor. |
| PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator, |
| PyObject* tensor) { |
| PyObject* jvp = reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator) |
| ->accumulator->FetchJVP(FastTensorId(tensor)); |
| if (jvp == nullptr) { |
| jvp = Py_None; |
| } |
| Py_INCREF(jvp); |
| return jvp; |
| } |
| |
| PyObject* TFE_Py_PackJVPs(PyObject* tensors) { |
| if (!TapeCouldPossiblyRecord(tensors)) { |
| tensorflow::Safe_PyObjectPtr empty_tuple(PyTuple_New(0)); |
| tensorflow::Safe_PyObjectPtr empty_list(PyList_New(0)); |
| return PyTuple_Pack(2, empty_tuple.get(), empty_list.get()); |
| } |
| auto accumulators = *GetAccumulatorSet(); |
| tensorflow::Safe_PyObjectPtr tensors_fast( |
| PySequence_Fast(tensors, "Expected a sequence of input Tensors.")); |
| if (tensors_fast == nullptr || PyErr_Occurred()) { |
| return nullptr; |
| } |
| std::vector<tensorflow::int64> augmented_input_ids; |
| int len = PySequence_Fast_GET_SIZE(tensors_fast.get()); |
| PyObject** tensors_fast_array = PySequence_Fast_ITEMS(tensors_fast.get()); |
| for (Py_ssize_t position = 0; position < len; ++position) { |
| PyObject* input = tensors_fast_array[position]; |
| if (input == Py_None) { |
| continue; |
| } |
| tensorflow::DataType input_dtype(FastTensorDtype(input)); |
| if (input_dtype == tensorflow::DT_INVALID) { |
| return nullptr; |
| } |
| augmented_input_ids.push_back(FastTensorId(input)); |
| } |
| if (PyErr_Occurred()) { |
| return nullptr; |
| } |
| // Find the innermost accumulator such that all outer accumulators are |
| // recording. Any more deeply nested accumulators will not have their JVPs |
| // saved. |
| AccumulatorSet::const_iterator innermost_all_recording = accumulators.begin(); |
| for (; innermost_all_recording != accumulators.end(); |
| ++innermost_all_recording) { |
| if ((*innermost_all_recording)->accumulator->BusyAccumulating()) { |
| break; |
| } |
| } |
| AccumulatorSet::const_reverse_iterator reverse_innermost_all_recording( |
| innermost_all_recording); |
| |
| bool saving_jvps = false; |
| tensorflow::Safe_PyObjectPtr all_indices(PyTuple_New(accumulators.size())); |
| std::vector<PyObject*> new_tensors; |
| Py_ssize_t accumulator_index = 0; |
| // Start with the innermost accumulators to give outer accumulators a chance |
| // to find their higher-order JVPs. |
| for (AccumulatorSet::const_reverse_iterator it = accumulators.rbegin(); |
| it != accumulators.rend(); ++it, ++accumulator_index) { |
| std::vector<tensorflow::int64> new_input_ids; |
| std::vector<std::pair<tensorflow::int64, tensorflow::int64>> |
| accumulator_indices; |
| if (it == reverse_innermost_all_recording) { |
| saving_jvps = true; |
| } |
| if (saving_jvps) { |
| for (int input_index = 0; input_index < augmented_input_ids.size(); |
| ++input_index) { |
| tensorflow::int64 existing_input = augmented_input_ids[input_index]; |
| PyObject* jvp = (*it)->accumulator->FetchJVP(existing_input); |
| if (jvp != nullptr) { |
| new_tensors.push_back(jvp); |
| new_input_ids.push_back(FastTensorId(jvp)); |
| accumulator_indices.emplace_back( |
| input_index, |
| augmented_input_ids.size() + new_input_ids.size() - 1); |
| } |
| } |
| } |
| tensorflow::Safe_PyObjectPtr accumulator_indices_py( |
| PyTuple_New(accumulator_indices.size())); |
| for (int i = 0; i < accumulator_indices.size(); ++i) { |
| tensorflow::Safe_PyObjectPtr from_index( |
| GetPythonObjectFromInt(accumulator_indices[i].first)); |
| tensorflow::Safe_PyObjectPtr to_index( |
| GetPythonObjectFromInt(accumulator_indices[i].second)); |
| PyTuple_SetItem(accumulator_indices_py.get(), i, |
| PyTuple_Pack(2, from_index.get(), to_index.get())); |
| } |
| PyTuple_SetItem(all_indices.get(), accumulator_index, |
| accumulator_indices_py.release()); |
| augmented_input_ids.insert(augmented_input_ids.end(), new_input_ids.begin(), |
| new_input_ids.end()); |
| } |
| |
| tensorflow::Safe_PyObjectPtr new_tensors_py(PyList_New(new_tensors.size())); |
| for (int i = 0; i < new_tensors.size(); ++i) { |
| PyObject* jvp = new_tensors[i]; |
| Py_INCREF(jvp); |
| PyList_SET_ITEM(new_tensors_py.get(), i, jvp); |
| } |
| return PyTuple_Pack(2, all_indices.get(), new_tensors_py.get()); |
| } |
| |
| namespace { |
| static const int kFastPathExecuteInputStartIndex = 5; |
| |
| PyObject* GetPythonObjectFromString(tensorflow::StringPiece s) { |
| #if PY_MAJOR_VERSION >= 3 |
| return PyUnicode_FromStringAndSize(s.data(), s.size()); |
| #else |
| return PyBytes_FromStringAndSize(s.data(), s.size()); |
| #endif |
| } |
| |
| bool CheckResourceVariable(PyObject* item) { |
| if (tensorflow::swig::IsResourceVariable(item)) { |
| tensorflow::Safe_PyObjectPtr handle( |
| PyObject_GetAttrString(item, "_handle")); |
| return EagerTensor_CheckExact(handle.get()); |
| } |
| |
| return false; |
| } |
| |
| bool IsNumberType(PyObject* item) { |
| #if PY_MAJOR_VERSION >= 3 |
| return PyFloat_Check(item) || PyLong_Check(item); |
| #else |
| return PyFloat_Check(item) || PyInt_Check(item) || PyLong_Check(item); |
| #endif |
| } |
| |
| bool CheckOneInput(PyObject* item) { |
| if (EagerTensor_CheckExact(item) || CheckResourceVariable(item) || |
| PyArray_Check(item) || IsNumberType(item)) { |
| return true; |
| } |
| |
| // Sequences are not properly handled. Sequences with purely python numeric |
| // types work, but sequences with mixes of EagerTensors and python numeric |
| // types don't work. |
| // TODO(nareshmodi): fix |
| return false; |
| } |
| |
| bool CheckInputsOk(PyObject* seq, int start_index, |
| const tensorflow::OpDef& op_def) { |
| for (int i = 0; i < op_def.input_arg_size(); i++) { |
| PyObject* item = PyTuple_GET_ITEM(seq, i + start_index); |
| if (!op_def.input_arg(i).number_attr().empty() || |
| !op_def.input_arg(i).type_list_attr().empty()) { |
| // This item should be a seq input. |
| if (!PySequence_Check(item)) { |
| VLOG(1) << "Falling back to slow path for Op \"" << op_def.name() |
| << "\", Input \"" << op_def.input_arg(i).name() |
| << "\" since we expected a sequence, but got " |
| << item->ob_type->tp_name; |
| return false; |
| } |
| tensorflow::Safe_PyObjectPtr fast_item( |
| PySequence_Fast(item, "Could not parse sequence.")); |
| if (fast_item.get() == nullptr) { |
| return false; |
| } |
| int len = PySequence_Fast_GET_SIZE(fast_item.get()); |
| PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get()); |
| for (Py_ssize_t j = 0; j < len; j++) { |
| PyObject* inner_item = fast_item_array[j]; |
| if (!CheckOneInput(inner_item)) { |
| VLOG(1) << "Falling back to slow path for Op \"" << op_def.name() |
| << "\", Input \"" << op_def.input_arg(i).name() |
| << "\", Index " << j |
| << " since we expected an EagerTensor/ResourceVariable, " |
| "but got " |
| << inner_item->ob_type->tp_name; |
| return false; |
| } |
| } |
| } else if (!CheckOneInput(item)) { |
| VLOG(1) |
| << "Falling back to slow path for Op \"" << op_def.name() |
| << "\", Input \"" << op_def.input_arg(i).name() |
| << "\" since we expected an EagerTensor/ResourceVariable, but got " |
| << item->ob_type->tp_name; |
| return false; |
| } |
| } |
| |
| return true; |
| } |
| |
| tensorflow::DataType MaybeGetDType(PyObject* item) { |
| if (EagerTensor_CheckExact(item) || CheckResourceVariable(item)) { |
| return FastTensorDtype(item); |
| } |
| |
| return tensorflow::DT_INVALID; |
| } |
| |
| tensorflow::DataType MaybeGetDTypeForAttr(const string& attr, |
| FastPathOpExecInfo* op_exec_info) { |
| auto cached_it = op_exec_info->cached_dtypes.find(attr); |
| if (cached_it != op_exec_info->cached_dtypes.end()) { |
| return cached_it->second; |
| } |
| |
| auto it = op_exec_info->attr_to_inputs_map->find(attr); |
| if (it == op_exec_info->attr_to_inputs_map->end()) { |
| // No other inputs - this should never happen. |
| return tensorflow::DT_INVALID; |
| } |
| |
| for (const auto& input_info : it->second) { |
| PyObject* item = PyTuple_GET_ITEM( |
| op_exec_info->args, kFastPathExecuteInputStartIndex + input_info.i); |
| if (input_info.is_list) { |
| tensorflow::Safe_PyObjectPtr fast_item( |
| PySequence_Fast(item, "Unable to allocate")); |
| int len = PySequence_Fast_GET_SIZE(fast_item.get()); |
| PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get()); |
| for (int i = 0; i < len; i++) { |
| auto dtype = MaybeGetDType(fast_item_array[i]); |
| if (dtype != tensorflow::DT_INVALID) return dtype; |
| } |
| } else { |
| auto dtype = MaybeGetDType(item); |
| if (dtype != tensorflow::DT_INVALID) return dtype; |
| } |
| } |
| |
| auto default_it = op_exec_info->default_dtypes->find(attr); |
| if (default_it != op_exec_info->default_dtypes->end()) { |
| return default_it->second; |
| } |
| |
| return tensorflow::DT_INVALID; |
| } |
| |
| PyObject* CopySequenceSettingIndicesToNull( |
| PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) { |
| tensorflow::Safe_PyObjectPtr fast_seq( |
| PySequence_Fast(seq, "unable to allocate")); |
| int len = PySequence_Fast_GET_SIZE(fast_seq.get()); |
| PyObject** fast_seq_array = PySequence_Fast_ITEMS(fast_seq.get()); |
| PyObject* result = PyTuple_New(len); |
| for (int i = 0; i < len; i++) { |
| PyObject* item; |
| if (indices.find(i) != indices.end()) { |
| item = Py_None; |
| } else { |
| item = fast_seq_array[i]; |
| } |
| Py_INCREF(item); |
| PyTuple_SET_ITEM(result, i, item); |
| } |
| return result; |
| } |
| |
| PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, |
| PyObject* results, |
| PyObject* forward_pass_name_scope = nullptr) { |
| std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs); |
| if (PyErr_Occurred()) return nullptr; |
| std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs); |
| if (PyErr_Occurred()) return nullptr; |
| |
| bool should_record = false; |
| for (TFE_Py_Tape* tape : SafeTapeSet()) { |
| if (tape->tape->ShouldRecord(input_ids, input_dtypes)) { |
| should_record = true; |
| break; |
| } |
| } |
| if (!should_record) { |
| for (TFE_Py_ForwardAccumulator* accumulator : SafeAccumulatorSet()) { |
| if (accumulator->accumulator->ShouldRecord(input_ids, input_dtypes)) { |
| should_record = true; |
| break; |
| } |
| } |
| } |
| if (!should_record) Py_RETURN_NONE; |
| |
| string c_op_name = TFE_GetPythonString(op_name); |
| |
| PyObject* op_outputs; |
| bool op_outputs_tuple_created = false; |
| |
| if (const auto unused_output_indices = |
| OpGradientUnusedOutputIndices(c_op_name)) { |
| if (unused_output_indices->empty()) { |
| op_outputs = Py_None; |
| } else { |
| op_outputs_tuple_created = true; |
| op_outputs = |
| CopySequenceSettingIndicesToNull(results, *unused_output_indices); |
| } |
| } else { |
| op_outputs = results; |
| } |
| |
| PyObject* op_inputs; |
| bool op_inputs_tuple_created = false; |
| |
| if (const auto unused_input_indices = |
| OpGradientUnusedInputIndices(c_op_name)) { |
| if (unused_input_indices->empty()) { |
| op_inputs = Py_None; |
| } else { |
| op_inputs_tuple_created = true; |
| op_inputs = |
| CopySequenceSettingIndicesToNull(inputs, *unused_input_indices); |
| } |
| } else { |
| op_inputs = inputs; |
| } |
| |
| tensorflow::eager::ForwardFunction<PyObject> py_forward_function( |
| [op_name, attrs, inputs, results]( |
| const std::vector<PyObject*>& input_tangents, |
| std::vector<PyObject*>* output_tangents, bool use_batch) { |
| return CallJVPFunction(op_name, attrs, inputs, results, input_tangents, |
| output_tangents, use_batch); |
| }); |
| tensorflow::eager::ForwardFunction<PyObject>* forward_function; |
| if (c_op_name == "While" || c_op_name == "StatelessWhile" || |
| c_op_name == "If" || c_op_name == "StatelessIf") { |
| // Control flow contains non-hashable attributes. Handling them in Python is |
| // a headache, so instead we'll stay as close to GradientTape's handling as |
| // possible (a null forward function means the accumulator forwards to a |
| // tape). |
| // |
| // This is safe to do since we'll only see control flow when graph building, |
| // in which case we can rely on pruning. |
| forward_function = nullptr; |
| } else { |
| forward_function = &py_forward_function; |
| } |
| |
| PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs)); |
| |
| if (!forward_pass_name_scope) forward_pass_name_scope = Py_None; |
| |
| TapeSetRecordOperation( |
| op_name, inputs, results, input_ids, input_dtypes, |
| [op_name, attrs, num_inputs, op_inputs, op_outputs, |
| forward_pass_name_scope]() { |
| Py_INCREF(op_name); |
| Py_INCREF(attrs); |
| Py_INCREF(num_inputs); |
| Py_INCREF(op_inputs); |
| Py_INCREF(op_outputs); |
| Py_INCREF(forward_pass_name_scope); |
| PyBackwardFunction* function = new PyBackwardFunction( |
| [op_name, attrs, num_inputs, op_inputs, op_outputs, |
| forward_pass_name_scope]( |
| PyObject* output_grads, |
| const std::vector<tensorflow::int64>& unneeded_gradients) { |
| if (PyErr_Occurred()) { |
| return static_cast<PyObject*>(nullptr); |
| } |
| tensorflow::Safe_PyObjectPtr skip_input_indices; |
| if (!unneeded_gradients.empty()) { |
| skip_input_indices.reset( |
| PyTuple_New(unneeded_gradients.size())); |
| for (int i = 0; i < unneeded_gradients.size(); i++) { |
| PyTuple_SET_ITEM( |
| skip_input_indices.get(), i, |
| GetPythonObjectFromInt(unneeded_gradients[i])); |
| } |
| } else { |
| Py_INCREF(Py_None); |
| skip_input_indices.reset(Py_None); |
| } |
| tensorflow::Safe_PyObjectPtr callback_args(Py_BuildValue( |
| "OOOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs, |
| output_grads, skip_input_indices.get(), |
| forward_pass_name_scope)); |
| |
| tensorflow::Safe_PyObjectPtr result( |
| PyObject_CallObject(gradient_function, callback_args.get())); |
| |
| if (PyErr_Occurred()) return static_cast<PyObject*>(nullptr); |
| |
| return tensorflow::swig::Flatten(result.get()); |
| }); |
| return function; |
| }, |
| [op_name, attrs, num_inputs, op_inputs, op_outputs, |
| forward_pass_name_scope](PyBackwardFunction* backward_function) { |
| Py_DECREF(op_name); |
| Py_DECREF(attrs); |
| Py_DECREF(num_inputs); |
| Py_DECREF(op_inputs); |
| Py_DECREF(op_outputs); |
| Py_DECREF(forward_pass_name_scope); |
| |
| delete backward_function; |
| }, |
| forward_function); |
| |
| Py_DECREF(num_inputs); |
| if (op_outputs_tuple_created) Py_DECREF(op_outputs); |
| if (op_inputs_tuple_created) Py_DECREF(op_inputs); |
| |
| if (PyErr_Occurred()) { |
| return nullptr; |
| } |
| |
| Py_RETURN_NONE; |
| } |
| |
| void MaybeNotifyVariableAccessed(PyObject* input) { |
| DCHECK(CheckResourceVariable(input)); |
| DCHECK(PyObject_HasAttrString(input, "_trainable")); |
| |
| tensorflow::Safe_PyObjectPtr trainable( |
| PyObject_GetAttrString(input, "_trainable")); |
| if (trainable.get() == Py_False) return; |
| TFE_Py_TapeVariableAccessed(input); |
| TFE_Py_VariableWatcherVariableAccessed(input); |
| } |
| |
| bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info, |
| PyObject* input, tensorflow::Safe_PyObjectPtr* output, |
| TF_Status* status) { |
| MaybeNotifyVariableAccessed(input); |
| |
| TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status); |
| auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); }); |
| if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false; |
| |
| TFE_OpSetDevice(op, parent_op_exec_info.device_name, status); |
| if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false; |
| |
| // Set dtype |
| DCHECK(PyObject_HasAttrString(input, "_dtype")); |
| tensorflow::Safe_PyObjectPtr dtype(PyObject_GetAttrString(input, "_dtype")); |
| int value; |
| if (!ParseTypeValue("_dtype", dtype.get(), status, &value)) { |
| return false; |
| } |
| TFE_OpSetAttrType(op, "dtype", static_cast<TF_DataType>(value)); |
| |
| // Get handle |
| tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(input, "_handle")); |
| if (!EagerTensor_CheckExact(handle.get())) return false; |
| TFE_OpAddInput(op, EagerTensor_Handle(handle.get()), status); |
| if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false; |
| |
| int num_retvals = 1; |
| TFE_TensorHandle* output_handle; |
| TFE_Execute(op, &output_handle, &num_retvals, status); |
| if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false; |
| |
| // Always create the py object from the returned handle (so it is correctly
| // DECREF'd later); otherwise the underlying data will leak.
| output->reset(EagerTensorFromHandle(output_handle)); |
| |
| // TODO(nareshmodi): Should we run post exec callbacks here? |
| if (parent_op_exec_info.run_gradient_callback) { |
| tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(1)); |
| PyTuple_SET_ITEM(inputs.get(), 0, handle.release()); |
| |
| tensorflow::Safe_PyObjectPtr outputs(PyTuple_New(1)); |
| Py_INCREF(output->get()); // PyTuple_SET_ITEM steals a reference, take one.
| PyTuple_SET_ITEM(outputs.get(), 0, output->get()); |
| |
| tensorflow::Safe_PyObjectPtr op_string( |
| GetPythonObjectFromString("ReadVariableOp")); |
| if (!RecordGradient(op_string.get(), inputs.get(), Py_None, |
| outputs.get())) { |
| return false; |
| } |
| } |
| |
| return true; |
| } |
| |
| // Supports 3 cases at the moment:
| // i) input is an EagerTensor.
| // ii) input is a ResourceVariable - in this case, the variable's value is
| //     read via ReadVariableOp.
| // iii) input is an arbitrary python list/tuple (note, this handling doesn't
| //      support packing).
| //
| // NOTE: dtype_hint_getter returns the dtype to use as a conversion hint; if
| // no hint is available, it must return DT_INVALID.
| //
| // NOTE: On failure, this function sets a python error directly and returns
| // false. TF_Status is only passed in so that we don't have to reallocate it.
| bool ConvertToTensor( |
| const FastPathOpExecInfo& op_exec_info, PyObject* input, |
| tensorflow::Safe_PyObjectPtr* output_handle, |
| // This gets a hint for this particular input. |
| const std::function<tensorflow::DataType()>& dtype_hint_getter, |
| // This sets the dtype after conversion is complete. |
| const std::function<void(const tensorflow::DataType dtype)>& dtype_setter, |
| TF_Status* status) { |
| if (EagerTensor_CheckExact(input)) { |
| Py_INCREF(input); |
| output_handle->reset(input); |
| return true; |
| } else if (CheckResourceVariable(input)) { |
| return ReadVariableOp(op_exec_info, input, output_handle, status); |
| } |
| |
| // The hint comes from another input that is expected to have a similar type.
| tensorflow::DataType dtype_hint = dtype_hint_getter(); |
| |
| TFE_TensorHandle* handle = tensorflow::ConvertToEagerTensor( |
| op_exec_info.ctx, input, dtype_hint, op_exec_info.device_name); |
| if (handle == nullptr) { |
| return MaybeRaiseExceptionFromTFStatus(status, nullptr); |
| } |
| |
| output_handle->reset(EagerTensorFromHandle(handle)); |
| dtype_setter( |
| static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(handle))); |
| |
| return true; |
| } |
| |
| // Adds input and type attr to the op, and to the list of flattened |
| // inputs/attrs. |
| bool AddInputToOp(FastPathOpExecInfo* op_exec_info, PyObject* input, |
| const bool add_type_attr, |
| const tensorflow::OpDef::ArgDef& input_arg, |
| std::vector<tensorflow::Safe_PyObjectPtr>* flattened_attrs, |
| std::vector<tensorflow::Safe_PyObjectPtr>* flattened_inputs, |
| TFE_Op* op, TF_Status* status) { |
| // py_eager_tensor's ownership is transferred to flattened_inputs if the
| // latter is provided; otherwise it is DECREF'd when it goes out of scope at
| // the end of this function.
| tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr; |
| |
| if (!ConvertToTensor( |
| *op_exec_info, input, &py_eager_tensor, |
| [&]() { |
| if (input_arg.type() != tensorflow::DataType::DT_INVALID) { |
| return input_arg.type(); |
| } |
| return MaybeGetDTypeForAttr(input_arg.type_attr(), op_exec_info); |
| }, |
| [&](const tensorflow::DataType dtype) { |
| op_exec_info->cached_dtypes[input_arg.type_attr()] = dtype; |
| }, |
| status)) { |
| return false; |
| } |
| |
| TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get()); |
| |
| if (add_type_attr && !input_arg.type_attr().empty()) { |
| auto dtype = TFE_TensorHandleDataType(input_handle); |
| TFE_OpSetAttrType(op, input_arg.type_attr().data(), dtype); |
| if (flattened_attrs != nullptr) { |
| flattened_attrs->emplace_back( |
| GetPythonObjectFromString(input_arg.type_attr())); |
| flattened_attrs->emplace_back(PyLong_FromLong(dtype)); |
| } |
| } |
| |
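| // Entries appended to flattened_attrs/flattened_inputs are later spliced by
| // RunCallbacks into the attrs and inputs tuples used for gradient recording.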
| if (flattened_inputs != nullptr) { |
| flattened_inputs->emplace_back(std::move(py_eager_tensor)); |
| } |
| |
| TFE_OpAddInput(op, input_handle, status); |
| if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { |
| return false; |
| } |
| |
| return true; |
| } |
| |
| const char* GetDeviceName(PyObject* py_device_name) { |
| if (py_device_name != Py_None) { |
| return TFE_GetPythonString(py_device_name); |
| } |
| return nullptr; |
| } |
| |
| bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) { |
| if (!PySequence_Check(seq)) { |
| PyErr_SetString(PyExc_TypeError, |
| Printf("expected a sequence for attr %s, got %s instead", |
| attr_name.data(), seq->ob_type->tp_name) |
| .data()); |
| |
| return false; |
| } |
| if (PyArray_Check(seq) && |
| PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)) != 1) { |
| PyErr_SetString(PyExc_ValueError, |
| Printf("expected a sequence for attr %s, got an ndarray " |
| "with rank %d instead", |
| attr_name.data(), |
| PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq))) |
| .data()); |
| return false; |
| } |
| return true; |
| } |
| |
| bool RunCallbacks( |
| const FastPathOpExecInfo& op_exec_info, PyObject* args, |
| int num_inferred_attrs, |
| const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_inputs, |
| const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_attrs, |
| PyObject* flattened_result) { |
| DCHECK(op_exec_info.run_callbacks); |
| |
| tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs.size())); |
| for (int i = 0; i < flattened_inputs.size(); i++) { |
| PyObject* input = flattened_inputs[i].get(); |
| Py_INCREF(input); |
| PyTuple_SET_ITEM(inputs.get(), i, input); |
| } |
| |
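| // Assemble the attrs tuple: the caller's non-inferred (name, value) pairs
| // come first, followed by the inferred pairs collected in flattened_attrs.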
| int num_non_inferred_attrs = PyTuple_GET_SIZE(args) - num_inferred_attrs; |
| int num_attrs = flattened_attrs.size() + num_non_inferred_attrs; |
| tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs)); |
| |
| for (int i = 0; i < num_non_inferred_attrs; i++) { |
| auto* attr = PyTuple_GET_ITEM(args, num_inferred_attrs + i); |
| Py_INCREF(attr); |
| PyTuple_SET_ITEM(attrs.get(), i, attr); |
| } |
| |
| for (int i = num_non_inferred_attrs; i < num_attrs; i++) { |
| PyObject* attr_or_name = |
| flattened_attrs.at(i - num_non_inferred_attrs).get(); |
| Py_INCREF(attr_or_name); |
| PyTuple_SET_ITEM(attrs.get(), i, attr_or_name); |
| } |
| |
| if (op_exec_info.run_gradient_callback) { |
| if (!RecordGradient(op_exec_info.op_name, inputs.get(), attrs.get(), |
| flattened_result)) { |
| return false; |
| } |
| } |
| |
| if (op_exec_info.run_post_exec_callbacks) { |
| tensorflow::Safe_PyObjectPtr callback_args( |
| Py_BuildValue("OOOOO", op_exec_info.op_name, inputs.get(), attrs.get(), |
| flattened_result, op_exec_info.name)); |
| for (Py_ssize_t i = 0; i < PyList_Size(op_exec_info.callbacks); i++) { |
| PyObject* callback_fn = PyList_GET_ITEM(op_exec_info.callbacks, i); |
| if (!PyCallable_Check(callback_fn)) { |
| PyErr_SetString( |
| PyExc_TypeError, |
| Printf("expected a function for "
| "post execution callback at index %ld, got %s instead",
| i, callback_fn->ob_type->tp_name) |
| .c_str()); |
| return false; |
| } |
| PyObject* callback_result = |
| PyObject_CallObject(callback_fn, callback_args.get()); |
| if (!callback_result) { |
| return false; |
| } |
| Py_DECREF(callback_result); |
| } |
| } |
| |
| return true; |
| } |
| |
| } // namespace |
| |
| PyObject* TFE_Py_FastPathExecute_C(PyObject* args) { |
| tensorflow::profiler::TraceMe activity( |
| "TFE_Py_FastPathExecute_C", tensorflow::profiler::TraceMeLevel::kInfo); |
| Py_ssize_t args_size = PyTuple_GET_SIZE(args); |
| if (args_size < kFastPathExecuteInputStartIndex) { |
| PyErr_SetString( |
| PyExc_ValueError, |
| Printf("There must be at least %d items in the input tuple.", |
| kFastPathExecuteInputStartIndex) |
| .c_str()); |
| return nullptr; |
| } |
| |
| FastPathOpExecInfo op_exec_info; |
| |
| TFE_Context* ctx = reinterpret_cast<TFE_Context*>( |
| PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr)); |
| op_exec_info.ctx = ctx; |
| op_exec_info.args = args; |
| |
| if (ctx == nullptr) { |
| // The context hasn't been initialized yet; it will be initialized in the
| // slow path.
| RaiseFallbackException(
| "This function does not handle the case where "
| "not all inputs are already EagerTensors.");
| return nullptr; |
| } |
| |
| op_exec_info.device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1)); |
| op_exec_info.op_name = PyTuple_GET_ITEM(args, 2); |
| op_exec_info.name = PyTuple_GET_ITEM(args, 3); |
| op_exec_info.callbacks = PyTuple_GET_ITEM(args, 4); |
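| // Layout of `args` as read above and below: index 0 holds the context
| // capsule, 1 the device name, 2 the op name, 3 the name passed to callbacks,
| // and 4 the callbacks list. The op inputs follow, starting at
| // kFastPathExecuteInputStartIndex, and the remaining items are alternating
| // (attr name, attr value) pairs for the non-inferred attrs.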
| |
| // TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks |
| // (similar to benchmark_tf_gradient_function_*). Also consider using an |
| // InlinedVector for flattened_attrs and flattened_inputs if the benchmarks |
| // point out problems with heap allocs. |
| op_exec_info.run_gradient_callback = |
| !*ThreadTapeIsStopped() && HasAccumulatorOrTape(); |
| op_exec_info.run_post_exec_callbacks = |
| op_exec_info.callbacks != Py_None && |
| PyList_Size(op_exec_info.callbacks) > 0; |
| op_exec_info.run_callbacks = op_exec_info.run_gradient_callback || |
| op_exec_info.run_post_exec_callbacks; |
| |
| TF_Status* status = GetStatus(); |
| const char* op_name = TFE_GetPythonString(op_exec_info.op_name); |
| if (op_name == nullptr) { |
| PyErr_SetString(PyExc_TypeError, |
| Printf("expected a string for op_name, got %s instead", |
| op_exec_info.op_name->ob_type->tp_name) |
| .c_str()); |
| return nullptr; |
| } |
| |
| TFE_Op* op = GetOp(ctx, op_name, op_exec_info.device_name, status); |
| tensorflow::unwrap(op)->SetStackTrace(tensorflow::GetStackTrace()); |
| |
| auto cleaner = tensorflow::gtl::MakeCleanup([status, ctx, op] { |
| ReturnStatus(status); |
| ReturnOp(ctx, op); |
| }); |
| if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { |
| return nullptr; |
| } |
| |
| const tensorflow::OpDef* op_def = tensorflow::unwrap(op)->OpDef(); |
| if (op_def == nullptr) return nullptr; |
| |
| if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) { |
| PyErr_SetString( |
| PyExc_ValueError, |
| Printf("Tuple size smaller than intended. Expected to be at least %d, " |
| "was %ld", |
| kFastPathExecuteInputStartIndex + op_def->input_arg_size(), |
| args_size) |
| .c_str()); |
| return nullptr; |
| } |
| |
| if (!CheckInputsOk(args, kFastPathExecuteInputStartIndex, *op_def)) { |
| RaiseFallbackException(
| "This function does not handle the case where "
| "not all inputs are already EagerTensors.");
| return nullptr; |
| } |
| |
| op_exec_info.attr_to_inputs_map = GetAttrToInputsMapHoldingGIL(*op_def); |
| op_exec_info.default_dtypes = GetAttrToDefaultsMapHoldingGIL(*op_def); |
| |
| // Mapping of attr name to size - used to calculate the number of values |
| // to be expected by the TFE_Execute run. |
| tensorflow::gtl::FlatMap<string, tensorflow::int64> attr_list_sizes; |
| |
| // Set non-inferred attrs, including setting defaults if the attr is passed in |
| // as None. |
| for (int i = kFastPathExecuteInputStartIndex + op_def->input_arg_size(); |
| i < args_size; i += 2) { |
| PyObject* py_attr_name = PyTuple_GET_ITEM(args, i); |
| const char* attr_name = TFE_GetPythonString(py_attr_name); |
| PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1); |
| |
| // Not creating an index since most of the time there are not more than a |
| // few attrs. |
| // TODO(nareshmodi): Maybe include the index as part of the |
| // OpRegistrationData. |
| for (const auto& attr : op_def->attr()) { |
| if (tensorflow::StringPiece(attr_name) == attr.name()) { |
| SetOpAttrWithDefaults(ctx, op, attr, attr_name, py_attr_value, |
| &attr_list_sizes, status); |
| |
| if (!status->status.ok()) { |
| VLOG(1) << "Falling back to slow path for Op \"" << op_def->name() |
| << "\" since we are unable to set the value for attr \"" |
| << attr.name() << "\" due to: " << TF_Message(status); |
| RaiseFallbackException(TF_Message(status)); |
| return nullptr; |
| } |
| |
| break; |
| } |
| } |
| } |
| |
| // Flat attrs and inputs as required by the record_gradient call. The attrs |
| // here only contain inferred attrs (non-inferred attrs are added directly |
| // from the input args). |
| // All items in flattened_attrs and flattened_inputs contain |
| // Safe_PyObjectPtr - any time something steals a reference to this, it must |
| // INCREF. |
| // TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work |
| // directly. |
| std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_attrs = |
| nullptr; |
| std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_inputs = |
| nullptr; |
| |
| // TODO(nareshmodi): Encapsulate callbacks information into a struct. |
| if (op_exec_info.run_callbacks) { |
| flattened_attrs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>); |
| flattened_inputs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>); |
| } |
| |
| // Add inferred attrs and inputs. |
| // The following code might set duplicate type attrs. This will result in |
| // the CacheKey for the generated AttrBuilder possibly differing from |
| // those where the type attrs are correctly set. Inconsistent CacheKeys |
| // for ops means that there might be unnecessarily duplicated kernels. |
| // TODO(nareshmodi): Fix this. |
| for (int i = 0; i < op_def->input_arg_size(); i++) { |
| const auto& input_arg = op_def->input_arg(i); |
| |
| PyObject* input = |
| PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + i); |
| if (!input_arg.number_attr().empty()) { |
| // The item is a homogeneous list. |
| if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr; |
| tensorflow::Safe_PyObjectPtr fast_input( |
| PySequence_Fast(input, "Could not parse sequence.")); |
| if (fast_input.get() == nullptr) { |
| return nullptr; |
| } |
| Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get()); |
| PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get()); |
| |
| TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len); |
| if (op_exec_info.run_callbacks) { |
| flattened_attrs->emplace_back( |
| GetPythonObjectFromString(input_arg.number_attr())); |
| flattened_attrs->emplace_back(PyLong_FromLong(len)); |
| } |
| attr_list_sizes[input_arg.number_attr()] = len; |
| |
| if (len > 0) { |
| // First item adds the type attr. |
| if (!AddInputToOp(&op_exec_info, fast_input_array[0], true, input_arg, |
| flattened_attrs.get(), flattened_inputs.get(), op, |
| status)) { |
| return nullptr; |
| } |
| |
| for (Py_ssize_t j = 1; j < len; j++) { |
| // Since the list is homogeneous, we don't need to re-add the attr. |
| if (!AddInputToOp(&op_exec_info, fast_input_array[j], false, |
| input_arg, nullptr /* flattened_attrs */, |
| flattened_inputs.get(), op, status)) { |
| return nullptr; |
| } |
| } |
| } |
| } else if (!input_arg.type_list_attr().empty()) { |
| // The item is a heterogeneous list. |
| if (!RaiseIfNotPySequence(input, input_arg.type_list_attr())) { |
| return nullptr; |
| } |
| tensorflow::Safe_PyObjectPtr fast_input( |
| PySequence_Fast(input, "Could not parse sequence.")); |
| if (fast_input.get() == nullptr) { |
| return nullptr; |
| } |
| const string& attr_name = input_arg.type_list_attr(); |
| Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get()); |
| PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get()); |
| tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len); |
| PyObject* py_attr_value = nullptr; |
| if (op_exec_info.run_callbacks) { |
| py_attr_value = PyTuple_New(len); |
| } |
| for (Py_ssize_t j = 0; j < len; j++) { |
| PyObject* py_input = fast_input_array[j]; |
| tensorflow::Safe_PyObjectPtr py_eager_tensor; |
| if (!ConvertToTensor( |
| op_exec_info, py_input, &py_eager_tensor, |
| []() { return tensorflow::DT_INVALID; }, |
| [](const tensorflow::DataType dtype) {}, status)) { |
| return nullptr; |
| } |
| |
| TFE_TensorHandle* input_handle = |
| EagerTensor_Handle(py_eager_tensor.get()); |
| |
| attr_value[j] = TFE_TensorHandleDataType(input_handle); |
| |
| TFE_OpAddInput(op, input_handle, status); |
| if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) { |
| return nullptr; |
| } |
| |
| if (op_exec_info.run_callbacks) { |
| flattened_inputs->emplace_back(std::move(py_eager_tensor)); |
| |
| PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j])); |
| } |
| } |
| if (op_exec_info.run_callbacks) { |
| flattened_attrs->emplace_back(GetPythonObjectFromString(attr_name)); |
| flattened_attrs->emplace_back(py_attr_value); |
| } |
| TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(), |
| attr_value.size()); |
| attr_list_sizes[attr_name] = len; |
| } else { |
| // The item is a single tensor input.
| if (!AddInputToOp(&op_exec_info, input, true, input_arg, |
| flattened_attrs.get(), flattened_inputs.get(), op, |
| status)) { |
| return nullptr; |
| } |
| } |
| } |
| |
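| // Compute the number of flat outputs to expect: a plain output contributes
| // one value, while an output tied to a number_attr or type_list_attr
| // contributes as many values as were recorded for that attr in
| // attr_list_sizes.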
| int num_retvals = 0; |
| for (int i = 0; i < op_def->output_arg_size(); i++) { |
| const auto& output_arg = op_def->output_arg(i); |
| int delta = 1; |
| if (!output_arg.number_attr().empty()) { |
| delta = attr_list_sizes[output_arg.number_attr()]; |
| } else if (!output_arg.type_list_attr().empty()) { |
| delta = attr_list_sizes[output_arg.type_list_attr()]; |
| } |
| if (delta < 0) { |
| RaiseFallbackException( |
| "Attributes suggest that the size of an output list is less than 0"); |
| return nullptr; |
| } |
| num_retvals += delta; |
| } |
| |
| tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals); |
| |
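| // Release the GIL while the op executes so other Python threads can make
| // progress in the meantime.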
| Py_BEGIN_ALLOW_THREADS; |
| TFE_Execute(op, retvals.data(), &num_retvals, status); |
| Py_END_ALLOW_THREADS; |
| |
| if (!status->status.ok()) { |
| // Augment the status with the op_name for easier debugging, similar to
| // TFE_Py_Execute.
| std::vector<tensorflow::StackFrame> stack_trace = |
| status->status.stack_trace(); |
| status->status = tensorflow::Status( |
| status->status.code(), |
| tensorflow::strings::StrCat( |
| TF_Message(status), |
| " [Op:", TFE_GetPythonString(op_exec_info.op_name), "]"), |
| std::move(stack_trace)); |
| |
| MaybeRaiseExceptionFromTFStatus(status, nullptr); |
| return nullptr; |
| } |
| |
| tensorflow::Safe_PyObjectPtr flat_result(PyList_New(num_retvals)); |
| for (int i = 0; i < num_retvals; ++i) { |
| PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i])); |
| } |
| |
| if (op_exec_info.run_callbacks) { |
| if (!RunCallbacks( |
| op_exec_info, args, |
| kFastPathExecuteInputStartIndex + op_def->input_arg_size(), |
| *flattened_inputs, *flattened_attrs, flat_result.get())) { |
| return nullptr; |
| } |
| } |
| |
| // Unflatten results. |
| if (op_def->output_arg_size() == 0) { |
| Py_RETURN_NONE; |
| } |
| |
| if (op_def->output_arg_size() == 1) { |
| if (!op_def->output_arg(0).number_attr().empty() || |
| !op_def->output_arg(0).type_list_attr().empty()) { |
| return flat_result.release(); |
| } else { |
| auto* result = PyList_GET_ITEM(flat_result.get(), 0); |
| Py_INCREF(result); |
| return result; |
| } |
| } |
| |
| // Unflatten the results: one entry per output arg, with list-valued outputs
| // grouped into nested lists. The Python caller packs this into a namedtuple.
| PyObject* result = PyList_New(op_def->output_arg_size()); |
| int flat_result_index = 0; |
| for (int i = 0; i < op_def->output_arg_size(); i++) { |
| if (!op_def->output_arg(i).number_attr().empty()) { |
| int list_length = attr_list_sizes[op_def->output_arg(i).number_attr()]; |
| PyObject* inner_list = PyList_New(list_length); |
| for (int j = 0; j < list_length; j++) { |
| PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++); |
| Py_INCREF(obj); |
| PyList_SET_ITEM(inner_list, j, obj); |
| } |
| PyList_SET_ITEM(result, i, inner_list); |
| } else if (!op_def->output_arg(i).type_list_attr().empty()) { |
| int list_length = attr_list_sizes[op_def->output_arg(i).type_list_attr()]; |
| PyObject* inner_list = PyList_New(list_length); |
| for (int j = 0; j < list_length; j++) { |
| PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++); |
| Py_INCREF(obj); |
| PyList_SET_ITEM(inner_list, j, obj); |
| } |
| PyList_SET_ITEM(result, i, inner_list); |
| } else { |
| PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++); |
| Py_INCREF(obj); |
| PyList_SET_ITEM(result, i, obj); |
| } |
| } |
| return result; |
| } |
| |
| PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs, |
| PyObject* attrs, PyObject* results, |
| PyObject* forward_pass_name_scope) { |
| if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) { |
| Py_RETURN_NONE; |
| } |
| |
| return RecordGradient(op_name, inputs, attrs, results, |
| forward_pass_name_scope); |
| } |
| |
| namespace { |
| const char kTensor[] = "T"; |
| const char kList[] = "L"; |
| const char kListEnd[] = "l"; |
| const char kTuple[] = "U"; |
| const char kTupleEnd[] = "u"; |
| const char kDict[] = "D"; |
| const char kRaw[] = "R"; |
| const char kShape[] = "s"; |
| const char kShapeDelim[] = "-"; |
| const char kDType[] = "d"; |
| const char kNone[] = "n"; |
| const char kCompositeTensor[] = "C"; |
| const char kAttrs[] = "A"; |
| const char kAttrsEnd[] = "a"; |
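|
| // A minimal sketch of the encoding, assuming DT_FLOAT's enum value of 1: a
| // float32 EagerTensor of shape [2, 3] encodes as "d1s2-3-" (dtype, then one
| // "-"-delimited entry per dimension), and a Python list of two such tensors
| // encodes as "LTd1s2-3-Td1s2-3-l" (each tensor prefixed with kTensor, the
| // list wrapped in kList/kListEnd).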
| |
| struct EncodeResult { |
| string str; |
| std::vector<PyObject*> objects; |
| |
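| // Builds the (string, objects) result tuple. PyTuple_SET_ITEM steals
| // references, so ownership of every pointer in `objects` transfers to the
| // returned tuple.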
| PyObject* ToPyTuple() { |
| PyObject* result = PyTuple_New(2); |
| |
| PyTuple_SET_ITEM(result, 0, GetPythonObjectFromString(str)); |
| |
| if (objects.empty()) { |
| Py_INCREF(Py_None); |
| PyTuple_SET_ITEM(result, 1, Py_None); |
| } else { |
| PyObject* objects_tuple = PyTuple_New(objects.size()); |
| |
| for (int i = 0; i < objects.size(); i++) { |
| PyTuple_SET_ITEM(objects_tuple, i, objects[i]); |
| } |
| |
| PyTuple_SET_ITEM(result, 1, objects_tuple); |
| } |
| |
| return result; |
| } |
| }; |
| |
| tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, |
| bool include_tensor_ranks_only, |
| EncodeResult* result) { |
| if (EagerTensor_CheckExact(arg)) { |
| tensorflow::ImmediateExecutionTensorHandle* handle = |
| tensorflow::unwrap(EagerTensor_Handle(arg)); |
| |
| absl::StrAppend(&result->str, kDType, |
| static_cast<tensorflow::DataType>(handle->DataType())); |
| absl::StrAppend(&result->str, kShape); |
| |
| int num_dims; |
| tensorflow::Status status = handle->NumDims(&num_dims); |
| if (!status.ok()) return status; |
| |
| if (include_tensor_ranks_only) { |
| absl::StrAppend(&result->str, num_dims); |
| } else { |
| for (int i = 0; i < num_dims; ++i) { |
| tensorflow::int64 dim_size; |
| status = handle->Dim(i, &dim_size); |
| if (!status.ok()) return status; |
| absl::StrAppend(&result->str, dim_size, kShapeDelim); |
| } |
| } |
| return tensorflow::Status::OK(); |
| } |
| |
| tensorflow::Safe_PyObjectPtr dtype_object( |
| PyObject_GetAttrString(arg, "dtype")); |
| |
| if (dtype_object == nullptr) { |
| return tensorflow::errors::InvalidArgument(
| "ops.Tensor object doesn't have a dtype attr.");
| } |
| |
| tensorflow::Safe_PyObjectPtr dtype_enum( |
| PyObject_GetAttrString(dtype_object.get(), "_type_enum")); |
| |
| if (dtype_enum == nullptr) { |
| return tensorflow::errors::InvalidArgument(
| "ops.Tensor's dtype object doesn't have a _type_enum attr.");
| } |
| |
| tensorflow::DataType dtype = |
| static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get())); |
| |
| absl::StrAppend(&result->str, kDType, dtype); |
| |
| static char _shape_tuple[] = "_shape_tuple"; |
| tensorflow::Safe_PyObjectPtr shape_tuple( |
| PyObject_CallMethod(arg, _shape_tuple, nullptr)); |
| |
| if (shape_tuple == nullptr) { |
| return tensorflow::errors::InvalidArgument( |
| "ops.Tensor object doesn't have _shape_tuple() method."); |
| } |
| |
| if (shape_tuple.get() == Py_None) { |
| // Unknown shape, encode that directly. |
| absl::StrAppend(&result->str, kNone); |
| return tensorflow::Status::OK(); |
| } |
| |
| absl::StrAppend(&result->str, kShape); |
| tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast( |
| shape_tuple.get(), "shape_tuple didn't return a sequence")); |
| |
| int len = PySequence_Fast_GET_SIZE(shape_seq.get()); |
| PyObject** shape_seq_array = PySequence_Fast_ITEMS(shape_seq.get()); |
| |
| if (include_tensor_ranks_only) { |
| absl::StrAppend(&result->str, len); |
| } else { |
| for (int i = 0; i < len; ++i) { |
| PyObject* item = shape_seq_array[i]; |
| if (item == Py_None) { |
| absl::StrAppend(&result->str, kNone); |
| } else { |
| absl::StrAppend(&result->str, MakeInt(item)); |
| } |
| } |
| } |
| return tensorflow::Status::OK(); |
| } |
| |
| tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, |
| bool include_tensor_ranks_only, |
| EncodeResult* result); |
| |
| // Encodes each item of `arg` between the given `type` and `end_type`
| // markers (Py_None items encode as kNone). The caller supplies the markers,
| // so this function doesn't itself check what kind of sequence it was given.
| tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type, |
| const char* end_type, |
| bool include_tensor_ranks_only, |
| EncodeResult* result) { |
| tensorflow::Safe_PyObjectPtr arg_seq( |
| PySequence_Fast(arg, "unable to create seq from list/tuple")); |
| |
| absl::StrAppend(&result->str, type); |
| int len = PySequence_Fast_GET_SIZE(arg_seq.get()); |
| PyObject** arg_seq_array = PySequence_Fast_ITEMS(arg_seq.get()); |
| for (int i = 0; i < len; ++i) { |
| PyObject* item = arg_seq_array[i]; |
| if (item == Py_None) { |
| absl::StrAppend(&result->str, kNone); |
| } else { |
| TF_RETURN_IF_ERROR( |
| TFE_Py_EncodeArgHelper(item, include_tensor_ranks_only, result)); |
| } |
| } |
| absl::StrAppend(&result->str, end_type); |
| |
| return tensorflow::Status::OK(); |
| } |
| |
| tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, |
| bool include_tensor_ranks_only, |
| EncodeResult* result) { |
| if (tensorflow::swig::IsTensor(arg)) { |
| absl::StrAppend(&result->str, kTensor); |
| TF_RETURN_IF_ERROR( |
| TFE_Py_EncodeTensor(arg, include_tensor_ranks_only, result)); |
| } else if (PyList_Check(arg)) { |
| TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence( |
| arg, kList, kListEnd, include_tensor_ranks_only, result)); |
| } else if (tensorflow::swig::IsTuple(arg)) { |
| TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence( |
| arg, kTuple, kTupleEnd, include_tensor_ranks_only, result)); |
| } else if (tensorflow::swig::IsMapping(arg)) { |
| tensorflow::Safe_PyObjectPtr keys(tensorflow::swig::MappingKeys(arg)); |
| if (PyList_Sort(keys.get()) == -1) { |
| return tensorflow::errors::Internal("Unable to sort keys"); |
| } |
| |
| absl::StrAppend(&result->str, kDict); |
| int len = PyList_Size(keys.get()); |
| |
| for (int i = 0; i < len; i++) { |
| PyObject* key = PyList_GetItem(keys.get(), i); |
| TF_RETURN_IF_ERROR( |
| TFE_Py_EncodeArgHelper(key, include_tensor_ranks_only, result)); |
| tensorflow::Safe_PyObjectPtr value(PyObject_GetItem(arg, key)); |
| TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper( |
| value.get(), include_tensor_ranks_only, result)); |
| } |
| } else if (tensorflow::swig::IsCompositeTensor(arg)) { |
| absl::StrAppend(&result->str, kCompositeTensor); |
| |
| // Add the typespec to the list of objects. (Do *not* use a weakref, |
| // since the type spec is often a temporary object.) |
| PyObject* type_spec(PyObject_GetAttrString(arg, "_type_spec")); |
| if (type_spec == nullptr) { |
| return tensorflow::errors::InvalidArgument( |
| "Error while reading CompositeTensor._type_spec."); |
| } |
| result->objects.push_back(type_spec); |
| } else if (tensorflow::swig::IsTypeSpec(arg)) { |
| // Add the typespec (not a weakref) in case it's a temporary object. |
| absl::StrAppend(&result->str, kRaw); |
| Py_INCREF(arg); |
| result->objects.push_back(arg); |
| } else if (tensorflow::swig::IsAttrs(arg)) { |
| absl::StrAppend(&result->str, kAttrs); |
| tensorflow::Safe_PyObjectPtr attrs( |
| PyObject_GetAttrString(arg, "__attrs_attrs__")); |
| tensorflow::Safe_PyObjectPtr iter(PyObject_GetIter(attrs.get())); |
| for (tensorflow::Safe_PyObjectPtr item(PyIter_Next(iter.get())); item; |
| item.reset(PyIter_Next(iter.get()))) { |
| tensorflow::Safe_PyObjectPtr name( |
| PyObject_GetAttrString(item.get(), "name")); |
| tensorflow::Safe_PyObjectPtr attr_arg(PyObject_GetAttr(arg, name.get())); |
| TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper( |
| attr_arg.get(), include_tensor_ranks_only, result)); |
| } |
| absl::StrAppend(&result->str, kAttrsEnd); |
| } else { |
| PyObject* object = PyWeakref_NewRef(arg, nullptr); |
| |
| if (object == nullptr) { |
| PyErr_Clear(); |
| |
| object = arg; |
| Py_INCREF(object); |
| } |
| |
| absl::StrAppend(&result->str, kRaw); |
| result->objects.push_back(object); |
| } |
| |
| return tensorflow::Status::OK(); |
| } |
| |
| } // namespace |
| |
| // `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes |
| // are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes |
| // are used for both performance reasons, as much TensorFlow code specializes |
| // on known shapes to produce slimmer graphs, and correctness, as some |
| // high-level APIs require shapes to be fully-known. |
| // |
| // `include_tensor_ranks_only` allows caching on arguments excluding shape info, |
| // so that a slow path using relaxed shape can rely on a cache key that excludes |
| // shapes. |
| PyObject* TFE_Py_EncodeArg(PyObject* arg, bool include_tensor_ranks_only) { |
| EncodeResult result; |
| const auto status = |
| TFE_Py_EncodeArgHelper(arg, include_tensor_ranks_only, &result); |
| if (MaybeRaiseExceptionFromStatus(status, nullptr)) { |
| return nullptr; |
| } |
| |
| return result.ToPyTuple(); |
| } |
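|
| // A usage sketch, again assuming DT_FLOAT encodes as 1: calling
| // TFE_Py_EncodeArg on the tuple (t, "x"), where t is a float32 EagerTensor
| // of shape [2, 3], yields ("UTd1s2-3-Ru", objects) with `objects` holding
| // the raw-encoded "x" (strings can't be weakref'd, so the object itself is
| // retained).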
| |
| // Prints incoming messages directly to Python's stdout using Python's
| // C API. This is necessary in Jupyter notebooks and colabs, where messages
| // written to the C stdout don't go to the notebook cell outputs, but calls
| // to Python's stdout do.
| void PrintToPythonStdout(const char* msg) { |
| if (Py_IsInitialized()) { |
| PyGILState_STATE py_threadstate; |
| py_threadstate = PyGILState_Ensure(); |
| |
| string string_msg = msg; |
| // PySys_WriteStdout truncates strings over 1000 bytes, so |
| // we write the message in chunks small enough to not be truncated. |
| int CHUNK_SIZE = 900; |
| auto len = string_msg.length(); |
| for (int i = 0; i < len; i += CHUNK_SIZE) { |
| PySys_WriteStdout("%s", string_msg.substr(i, CHUNK_SIZE).c_str()); |
| } |
| |
| // Force flushing to make sure print newlines aren't interleaved in |
| // some colab environments |
| PyRun_SimpleString("import sys; sys.stdout.flush()"); |
| |
| PyGILState_Release(py_threadstate); |
| } |
| } |
| |
| // Register PrintToPythonStdout as a log listener, to allow |
| // printing in colabs and jupyter notebooks to work. |
| void TFE_Py_EnableInteractivePythonLogging() { |
| static bool enabled_interactive_logging = false; |
| if (!enabled_interactive_logging) { |
| enabled_interactive_logging = true; |
| TF_RegisterLogListener(PrintToPythonStdout); |
| } |
| } |
| |
| namespace { |
| // Weak reference to the currently active Python eager Context object.
| PyObject* weak_eager_context = nullptr; |
| } // namespace |
| |
| PyObject* TFE_Py_SetEagerContext(PyObject* py_context) { |
| Py_XDECREF(weak_eager_context); |
| weak_eager_context = PyWeakref_NewRef(py_context, nullptr); |
| if (weak_eager_context == nullptr) { |
| return nullptr; |
| } |
| Py_RETURN_NONE; |
| } |
| |
| PyObject* GetPyEagerContext() { |
| if (weak_eager_context == nullptr) { |
| PyErr_SetString(PyExc_RuntimeError, "Python eager context is not set"); |
| return nullptr; |
| } |
| PyObject* py_context = PyWeakref_GET_OBJECT(weak_eager_context); |
| if (py_context == Py_None) { |
| PyErr_SetString(PyExc_RuntimeError, "Eager context has been destroyed"); |
| return nullptr; |
| } |
| Py_INCREF(py_context); |
| return py_context; |
| } |