/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <atomic>
#include <cstring>
#include <unordered_map>
#include "absl/strings/str_cat.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/tape.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/python/eager/pywrap_gradient_exclusions.h"
#include "tensorflow/python/eager/pywrap_tensor.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
#include "tensorflow/python/util/util.h"
using tensorflow::string;
using tensorflow::strings::Printf;
namespace {
// NOTE: Items are retrieved from and returned to these unique_ptrs, and they
// act as arenas. This is important if the same thread requests 2 items without
// releasing one.
// The following sequence of events on the same thread will still succeed:
// - GetOp <- Returns existing.
// - GetOp <- Allocates and returns a new pointer.
// - ReleaseOp <- Sets the item in the unique_ptr.
// - ReleaseOp <- Sets the item in the unique_ptr, deleting the old one.
// This occurs when a PyFunc kernel is run. This behavior makes it safe in that
// case, as well as in the case where Python reuses the same underlying C++
// thread for two Python threads.
struct OpDeleter {
void operator()(TFE_Op* op) const { TFE_DeleteOp(op); }
};
thread_local std::unordered_map<TFE_Context*,
std::unique_ptr<TFE_Op, OpDeleter>>
thread_local_eager_operation_map; // NOLINT
thread_local std::unique_ptr<TF_Status> thread_local_tf_status = // NOLINT
nullptr;
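// Removes the cached thread-local op for `ctx` (if any) and returns it,
// transferring ownership to the caller.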
std::unique_ptr<TFE_Op, OpDeleter> ReleaseThreadLocalOp(TFE_Context* ctx) {
auto it = thread_local_eager_operation_map.find(ctx);
if (it == thread_local_eager_operation_map.end()) {
return nullptr;
}
return std::move(it->second);
}
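// Returns an op for `op_or_function_name` on `raw_device_name`, reusing the
// thread-local cached op for `ctx` when one is available. Returns nullptr with
// `status` set on failure; the caller should hand the op back via ReturnOp.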
TFE_Op* GetOp(TFE_Context* ctx, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) {
auto op = ReleaseThreadLocalOp(ctx);
if (!op) {
op.reset(new TFE_Op{ctx->context->CreateOperation()});
}
status->status = op->operation->Reset(op_or_function_name, raw_device_name);
if (!status->status.ok()) {
op.reset();
}
return op.release();
}
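// Clears `op` and stores it back in the thread-local cache for `ctx` so a
// later GetOp call can reuse it.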
void ReturnOp(TFE_Context* ctx, TFE_Op* op) {
if (op) {
op->operation->Clear();
thread_local_eager_operation_map[ctx].reset(op);
}
}
TF_Status* ReleaseThreadLocalStatus() {
if (thread_local_tf_status == nullptr) {
return nullptr;
}
return thread_local_tf_status.release();
}
struct InputInfo {
InputInfo(int i, bool is_list) : i(i), is_list(is_list) {}
int i;
bool is_list = false;
};
// Takes in output gradients, returns input gradients.
typedef std::function<PyObject*(PyObject*,
const std::vector<tensorflow::int64>&)>
PyBackwardFunction;
using AttrToInputsMap =
tensorflow::gtl::FlatMap<string,
tensorflow::gtl::InlinedVector<InputInfo, 4>>;
tensorflow::gtl::FlatMap<string, AttrToInputsMap*>* GetAllAttrToInputsMaps() {
static auto* all_attr_to_input_maps =
new tensorflow::gtl::FlatMap<string, AttrToInputsMap*>;
return all_attr_to_input_maps;
}
// This function doesn't use a lock, since we depend on the GIL directly.
AttrToInputsMap* GetAttrToInputsMapHoldingGIL(const tensorflow::OpDef& op_def) {
#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
DCHECK(PyGILState_Check())
<< "This function needs to hold the GIL when called.";
#endif
auto* all_attr_to_input_maps = GetAllAttrToInputsMaps();
auto* output =
tensorflow::gtl::FindPtrOrNull(*all_attr_to_input_maps, op_def.name());
if (output != nullptr) {
return output;
}
std::unique_ptr<AttrToInputsMap> m(new AttrToInputsMap);
// Build a map from each type attr name to the list of inputs that use it.
for (int i = 0; i < op_def.input_arg_size(); i++) {
if (!op_def.input_arg(i).type_attr().empty()) {
auto it = m->find(op_def.input_arg(i).type_attr());
if (it == m->end()) {
it = m->insert({op_def.input_arg(i).type_attr(), {}}).first;
}
it->second.emplace_back(i, !op_def.input_arg(i).number_attr().empty());
}
}
auto* retval = m.get();
(*all_attr_to_input_maps)[op_def.name()] = m.release();
return retval;
}
// This function doesn't use a lock, since we depend on the GIL directly.
tensorflow::gtl::FlatMap<
string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>*
GetAllAttrToDefaultsMaps() {
static auto* all_attr_to_defaults_maps = new tensorflow::gtl::FlatMap<
string, tensorflow::gtl::FlatMap<string, tensorflow::DataType>*>;
return all_attr_to_defaults_maps;
}
tensorflow::gtl::FlatMap<string, tensorflow::DataType>*
GetAttrToDefaultsMapHoldingGIL(const tensorflow::OpDef& op_def) {
#if PY_MAJOR_VERSION >= 3 && PY_MINOR_VERSION >= 4
DCHECK(PyGILState_Check())
<< "This function needs to hold the GIL when called.";
#endif
auto* all_attr_to_defaults_maps = GetAllAttrToDefaultsMaps();
auto* output =
tensorflow::gtl::FindPtrOrNull(*all_attr_to_defaults_maps, op_def.name());
if (output != nullptr) {
return output;
}
auto* new_map = new tensorflow::gtl::FlatMap<string, tensorflow::DataType>;
for (const auto& attr : op_def.attr()) {
if (attr.type() == "type" && attr.has_default_value()) {
new_map->insert({attr.name(), attr.default_value().type()});
}
}
(*all_attr_to_defaults_maps)[op_def.name()] = new_map;
return new_map;
}
struct FastPathOpExecInfo {
TFE_Context* ctx;
const char* device_name;
bool run_callbacks;
bool run_post_exec_callbacks;
bool run_gradient_callback;
// The op name of the main op being executed.
PyObject* name;
// The op type name of the main op being executed.
PyObject* op_name;
PyObject* callbacks;
// All the args passed into the FastPathOpExecInfo.
PyObject* args;
// DTypes can come from another input that has the same attr. So build that
// map.
const AttrToInputsMap* attr_to_inputs_map;
const tensorflow::gtl::FlatMap<string, tensorflow::DataType>* default_dtypes;
tensorflow::gtl::FlatMap<string, tensorflow::DataType> cached_dtypes;
};
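// Expands to a fn_name(key, py_value, status, value) helper that converts
// `py_value` to `type` with `parse_fn` when `check_fn` accepts it, and
// otherwise sets an InvalidArgument status and returns false.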
#define PARSE_VALUE(fn_name, type, check_fn, parse_fn) \
bool fn_name(const string& key, PyObject* py_value, TF_Status* status, \
type* value) { \
if (check_fn(py_value)) { \
*value = static_cast<type>(parse_fn(py_value)); \
return true; \
} else { \
TF_SetStatus(status, TF_INVALID_ARGUMENT, \
tensorflow::strings::StrCat( \
"Expecting " #type " value for attr ", key, ", got ", \
py_value->ob_type->tp_name) \
.c_str()); \
return false; \
} \
}
#if PY_MAJOR_VERSION >= 3
PARSE_VALUE(ParseIntValue, int, PyLong_Check, PyLong_AsLong)
PARSE_VALUE(ParseInt64Value, int64_t, PyLong_Check, PyLong_AsLong)
#else
PARSE_VALUE(ParseIntValue, int, PyInt_Check, PyInt_AsLong)
#endif
PARSE_VALUE(ParseFloatValue, float, PyFloat_Check, PyFloat_AsDouble)
#undef PARSE_VALUE
#if PY_MAJOR_VERSION < 3
bool ParseInt64Value(const string& key, PyObject* py_value, TF_Status* status,
int64_t* value) {
if (PyInt_Check(py_value)) {
*value = static_cast<int64_t>(PyInt_AsLong(py_value));
return true;
} else if (PyLong_Check(py_value)) {
*value = static_cast<int64_t>(PyLong_AsLong(py_value));
return true;
}
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat("Expecting int or long value for attr ", key,
", got ", py_value->ob_type->tp_name)
.c_str());
return false;
}
#endif
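// Returns the number of dimensions of `value` via PySequence_Size, clearing
// the Python error and returning -1 when the size is unavailable (e.g. a
// TensorShape of unknown rank).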
Py_ssize_t TensorShapeNumDims(PyObject* value) {
const auto size = PySequence_Size(value);
if (size == -1) {
// TensorShape.__len__ raises an error when the shape is unknown; that error
// needs to be cleared here.
// TODO(nareshmodi): ensure that this is actually a TensorShape.
PyErr_Clear();
}
return size;
}
bool IsInteger(PyObject* py_value) {
#if PY_MAJOR_VERSION >= 3
return PyLong_Check(py_value);
#else
return PyInt_Check(py_value) || PyLong_Check(py_value);
#endif
}
// This function considers a Dimension._value of None to be valid, and sets the
// value to be -1 in that case.
bool ParseDimensionValue(const string& key, PyObject* py_value,
TF_Status* status, int64_t* value) {
if (IsInteger(py_value)) {
return ParseInt64Value(key, py_value, status, value);
}
tensorflow::Safe_PyObjectPtr dimension_value(
PyObject_GetAttrString(py_value, "_value"));
if (dimension_value == nullptr) {
PyErr_Clear();
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat("Expecting a Dimension for attr ", key,
", got ", py_value->ob_type->tp_name)
.c_str());
return false;
}
if (dimension_value.get() == Py_None) {
*value = -1;
return true;
}
return ParseInt64Value(key, dimension_value.get(), status, value);
}
bool ParseStringValue(const string& key, PyObject* py_value, TF_Status* status,
tensorflow::StringPiece* value) {
if (PyBytes_Check(py_value)) {
Py_ssize_t size = 0;
char* buf = nullptr;
if (PyBytes_AsStringAndSize(py_value, &buf, &size) < 0) return false;
*value = tensorflow::StringPiece(buf, size);
return true;
}
#if PY_MAJOR_VERSION >= 3
if (PyUnicode_Check(py_value)) {
Py_ssize_t size = 0;
const char* buf = PyUnicode_AsUTF8AndSize(py_value, &size);
if (buf == nullptr) return false;
*value = tensorflow::StringPiece(buf, size);
return true;
}
#endif
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat("Expecting a string value for attr ", key,
", got ", py_value->ob_type->tp_name)
.c_str());
return false;
}
bool ParseBoolValue(const string& key, PyObject* py_value, TF_Status* status,
unsigned char* value) {
*value = PyObject_IsTrue(py_value);
return true;
}
// The passed in py_value is expected to be an object of the python type
// dtypes.DType or an int.
bool ParseTypeValue(const string& key, PyObject* py_value, TF_Status* status,
int* value) {
if (IsInteger(py_value)) {
return ParseIntValue(key, py_value, status, value);
}
tensorflow::Safe_PyObjectPtr py_type_enum(
PyObject_GetAttrString(py_value, "_type_enum"));
if (py_type_enum == nullptr) {
PyErr_Clear();
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat("Expecting a DType.dtype for attr ", key,
", got ", py_value->ob_type->tp_name)
.c_str());
return false;
}
return ParseIntValue(key, py_type_enum.get(), status, value);
}
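// Sets the list-valued attr `key` on `op` from the Python sequence `py_list`,
// recording the list length in `attr_list_sizes` when it is non-null. Returns
// false (with `status` set) on failure.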
bool SetOpAttrList(
TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_list,
TF_AttrType type,
tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
TF_Status* status) {
if (!PySequence_Check(py_list)) {
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat("Expecting sequence value for attr ", key,
", got ", py_list->ob_type->tp_name)
.c_str());
return false;
}
const int num_values = PySequence_Size(py_list);
if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = num_values;
#define PARSE_LIST(c_type, parse_fn) \
std::unique_ptr<c_type[]> values(new c_type[num_values]); \
for (int i = 0; i < num_values; ++i) { \
tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i)); \
if (!parse_fn(key, py_value.get(), status, &values[i])) return false; \
}
if (type == TF_ATTR_STRING) {
std::unique_ptr<const void*[]> values(new const void*[num_values]);
std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
for (int i = 0; i < num_values; ++i) {
tensorflow::StringPiece value;
tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
if (!ParseStringValue(key, py_value.get(), status, &value)) return false;
values[i] = value.data();
lengths[i] = value.size();
}
TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
} else if (type == TF_ATTR_INT) {
PARSE_LIST(int64_t, ParseInt64Value);
TFE_OpSetAttrIntList(op, key, values.get(), num_values);
} else if (type == TF_ATTR_FLOAT) {
PARSE_LIST(float, ParseFloatValue);
TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
} else if (type == TF_ATTR_BOOL) {
PARSE_LIST(unsigned char, ParseBoolValue);
TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
} else if (type == TF_ATTR_TYPE) {
PARSE_LIST(int, ParseTypeValue);
TFE_OpSetAttrTypeList(op, key,
reinterpret_cast<const TF_DataType*>(values.get()),
num_values);
} else if (type == TF_ATTR_SHAPE) {
// Make one pass through the input counting the total number of
// dims across all the input lists.
int total_dims = 0;
for (int i = 0; i < num_values; ++i) {
tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
if (py_value.get() != Py_None) {
if (!PySequence_Check(py_value.get())) {
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"Expecting None or sequence value for element", i,
" of attr ", key, ", got ", py_value->ob_type->tp_name)
.c_str());
return false;
}
const auto size = TensorShapeNumDims(py_value.get());
if (size >= 0) {
total_dims += size;
}
}
}
// Allocate a buffer that can fit all of the dims together.
std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
// Copy the input dims into the buffer and set dims to point to
// the start of each list's dims.
std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
std::unique_ptr<int[]> num_dims(new int[num_values]);
int64_t* offset = buffer.get();
for (int i = 0; i < num_values; ++i) {
tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
if (py_value.get() == Py_None) {
dims[i] = nullptr;
num_dims[i] = -1;
} else {
const auto size = TensorShapeNumDims(py_value.get());
if (size == -1) {
dims[i] = nullptr;
num_dims[i] = -1;
continue;
}
dims[i] = offset;
num_dims[i] = size;
for (int j = 0; j < size; ++j) {
tensorflow::Safe_PyObjectPtr inner_py_value(
PySequence_ITEM(py_value.get(), j));
if (inner_py_value.get() == Py_None) {
*offset = -1;
} else if (!ParseDimensionValue(key, inner_py_value.get(), status,
offset)) {
return false;
}
++offset;
}
}
}
TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
status);
if (!status->status.ok()) return false;
} else if (type == TF_ATTR_FUNC) {
std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
for (int i = 0; i < num_values; ++i) {
tensorflow::Safe_PyObjectPtr py_value(PySequence_ITEM(py_list, i));
// Allow:
// (1) String function name, OR
// (2) A Python object with a .name attribute
// (A crude test for being a
// tensorflow.python.framework.function._DefinedFunction)
// (which is what the various "defun" or "Defun" decorators do).
// And in the future also allow an object that can encapsulate
// the function name and its attribute values.
tensorflow::StringPiece func_name;
if (!ParseStringValue(key, py_value.get(), status, &func_name)) {
PyObject* name_attr = PyObject_GetAttrString(py_value.get(), "name");
if (name_attr == nullptr ||
!ParseStringValue(key, name_attr, status, &func_name)) {
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"unable to set function value attribute from a ",
py_value.get()->ob_type->tp_name,
" object. If you think this is an error, please file an "
"issue at "
"https://github.com/tensorflow/tensorflow/issues/new")
.c_str());
return false;
}
}
funcs[i] = TFE_NewOp(ctx, func_name.data(), status);
if (!status->status.ok()) return false;
}
TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
if (!status->status.ok()) return false;
} else {
TF_SetStatus(status, TF_UNIMPLEMENTED,
tensorflow::strings::StrCat("Attr ", key,
" has unhandled list type ", type)
.c_str());
return false;
}
#undef PARSE_LIST
return true;
}
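// Creates a TFE_Op for the function named by `func` and copies the function's
// attr values onto it. Returns nullptr with `status` set on failure.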
TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
TF_Status* status) {
TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
for (const auto& attr : func.attr()) {
if (!status->status.ok()) return nullptr;
SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
if (!status->status.ok()) return nullptr;
}
return func_op;
}
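// Sets the list-valued attr `key` on `op` from the default value stored in the
// OpDef's AttrDef, recording the list length in `attr_list_sizes`.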
void SetOpAttrListDefault(
TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
const char* key, TF_AttrType type,
tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
TF_Status* status) {
if (type == TF_ATTR_STRING) {
int num_values = attr.default_value().list().s_size();
std::unique_ptr<const void*[]> values(new const void*[num_values]);
std::unique_ptr<size_t[]> lengths(new size_t[num_values]);
(*attr_list_sizes)[key] = num_values;
for (int i = 0; i < num_values; i++) {
const string& v = attr.default_value().list().s(i);
values[i] = v.data();
lengths[i] = v.size();
}
TFE_OpSetAttrStringList(op, key, values.get(), lengths.get(), num_values);
} else if (type == TF_ATTR_INT) {
int num_values = attr.default_value().list().i_size();
std::unique_ptr<int64_t[]> values(new int64_t[num_values]);
(*attr_list_sizes)[key] = num_values;
for (int i = 0; i < num_values; i++) {
values[i] = attr.default_value().list().i(i);
}
TFE_OpSetAttrIntList(op, key, values.get(), num_values);
} else if (type == TF_ATTR_FLOAT) {
int num_values = attr.default_value().list().f_size();
std::unique_ptr<float[]> values(new float[num_values]);
(*attr_list_sizes)[key] = num_values;
for (int i = 0; i < num_values; i++) {
values[i] = attr.default_value().list().f(i);
}
TFE_OpSetAttrFloatList(op, key, values.get(), num_values);
} else if (type == TF_ATTR_BOOL) {
int num_values = attr.default_value().list().b_size();
std::unique_ptr<unsigned char[]> values(new unsigned char[num_values]);
(*attr_list_sizes)[key] = num_values;
for (int i = 0; i < num_values; i++) {
values[i] = attr.default_value().list().b(i);
}
TFE_OpSetAttrBoolList(op, key, values.get(), num_values);
} else if (type == TF_ATTR_TYPE) {
int num_values = attr.default_value().list().type_size();
std::unique_ptr<int[]> values(new int[num_values]);
(*attr_list_sizes)[key] = num_values;
for (int i = 0; i < num_values; i++) {
values[i] = attr.default_value().list().type(i);
}
TFE_OpSetAttrTypeList(op, key,
reinterpret_cast<const TF_DataType*>(values.get()),
attr.default_value().list().type_size());
} else if (type == TF_ATTR_SHAPE) {
int num_values = attr.default_value().list().shape_size();
(*attr_list_sizes)[key] = num_values;
int total_dims = 0;
for (int i = 0; i < num_values; ++i) {
if (!attr.default_value().list().shape(i).unknown_rank()) {
total_dims += attr.default_value().list().shape(i).dim_size();
}
}
// Allocate a buffer that can fit all of the dims together.
std::unique_ptr<int64_t[]> buffer(new int64_t[total_dims]);
// Copy the input dims into the buffer and set dims to point to
// the start of each list's dims.
std::unique_ptr<const int64_t*[]> dims(new const int64_t*[num_values]);
std::unique_ptr<int[]> num_dims(new int[num_values]);
int64_t* offset = buffer.get();
for (int i = 0; i < num_values; ++i) {
const auto& shape = attr.default_value().list().shape(i);
if (shape.unknown_rank()) {
dims[i] = nullptr;
num_dims[i] = -1;
} else {
for (int j = 0; j < shape.dim_size(); j++) {
*offset = shape.dim(j).size();
++offset;
}
}
}
TFE_OpSetAttrShapeList(op, key, dims.get(), num_dims.get(), num_values,
status);
} else if (type == TF_ATTR_FUNC) {
int num_values = attr.default_value().list().func_size();
(*attr_list_sizes)[key] = num_values;
std::unique_ptr<const TFE_Op*[]> funcs(new const TFE_Op*[num_values]);
for (int i = 0; i < num_values; i++) {
funcs[i] = GetFunc(ctx, attr.default_value().list().func(i), status);
}
TFE_OpSetAttrFunctionList(op, key, funcs.get(), num_values);
} else {
TF_SetStatus(status, TF_UNIMPLEMENTED,
"Lists of tensors are not yet implemented for default valued "
"attributes for an operation.");
}
}
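// Sets the scalar attr `key` on `op` from the Python value `py_value`. Returns
// false (with `status` set) on failure.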
bool SetOpAttrScalar(
TFE_Context* ctx, TFE_Op* op, const char* key, PyObject* py_value,
TF_AttrType type,
tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
TF_Status* status) {
if (type == TF_ATTR_STRING) {
tensorflow::StringPiece value;
if (!ParseStringValue(key, py_value, status, &value)) return false;
TFE_OpSetAttrString(op, key, value.data(), value.size());
} else if (type == TF_ATTR_INT) {
int64_t value;
if (!ParseInt64Value(key, py_value, status, &value)) return false;
TFE_OpSetAttrInt(op, key, value);
// attr_list_sizes is set for all int attributes (since at this point we do
// not know whether that attribute will be used to calculate the size of an
// output list).
if (attr_list_sizes != nullptr) (*attr_list_sizes)[key] = value;
} else if (type == TF_ATTR_FLOAT) {
float value;
if (!ParseFloatValue(key, py_value, status, &value)) return false;
TFE_OpSetAttrFloat(op, key, value);
} else if (type == TF_ATTR_BOOL) {
unsigned char value;
if (!ParseBoolValue(key, py_value, status, &value)) return false;
TFE_OpSetAttrBool(op, key, value);
} else if (type == TF_ATTR_TYPE) {
int value;
if (!ParseTypeValue(key, py_value, status, &value)) return false;
TFE_OpSetAttrType(op, key, static_cast<TF_DataType>(value));
} else if (type == TF_ATTR_SHAPE) {
if (py_value == Py_None) {
TFE_OpSetAttrShape(op, key, nullptr, -1, status);
} else {
if (!PySequence_Check(py_value)) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"Expecting None or sequence value for attr", key,
", got ", py_value->ob_type->tp_name)
.c_str());
return false;
}
const auto num_dims = TensorShapeNumDims(py_value);
if (num_dims == -1) {
TFE_OpSetAttrShape(op, key, nullptr, -1, status);
return true;
}
std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
for (int i = 0; i < num_dims; ++i) {
tensorflow::Safe_PyObjectPtr inner_py_value(
PySequence_ITEM(py_value, i));
if (inner_py_value.get() == Py_None) {
dims[i] = -1;
} else if (!ParseDimensionValue(key, inner_py_value.get(), status,
&dims[i])) {
return false;
}
}
TFE_OpSetAttrShape(op, key, dims.get(), num_dims, status);
}
if (!status->status.ok()) return false;
} else if (type == TF_ATTR_FUNC) {
// Allow:
// (1) String function name, OR
// (2) A Python object with a .name attribute
// (A crude test for being a
// tensorflow.python.framework.function._DefinedFunction)
// (which is what the various "defun" or "Defun" decorators do).
// And in the future also allow an object that can encapsulate
// the function name and its attribute values.
tensorflow::StringPiece func_name;
if (!ParseStringValue(key, py_value, status, &func_name)) {
PyObject* name_attr = PyObject_GetAttrString(py_value, "name");
if (name_attr == nullptr ||
!ParseStringValue(key, name_attr, status, &func_name)) {
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
tensorflow::strings::StrCat(
"unable to set function value attribute from a ",
py_value->ob_type->tp_name,
" object. If you think this is an error, please file an issue "
"at https://github.com/tensorflow/tensorflow/issues/new")
.c_str());
return false;
}
}
TF_SetStatus(status, TF_OK, "");
TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size());
} else {
TF_SetStatus(
status, TF_UNIMPLEMENTED,
tensorflow::strings::StrCat("Attr ", key, " has unhandled type ", type)
.c_str());
return false;
}
return true;
}
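// Sets the scalar attr `attr_name` on `op` from its default AttrValue. Integer
// defaults are also recorded in `attr_list_sizes`, since they may determine
// the length of an output list.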
void SetOpAttrScalarDefault(
TFE_Context* ctx, TFE_Op* op, const tensorflow::AttrValue& default_value,
const char* attr_name,
tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
TF_Status* status) {
SetOpAttrValueScalar(ctx, op, default_value, attr_name, status);
if (default_value.value_case() == tensorflow::AttrValue::kI) {
(*attr_list_sizes)[attr_name] = default_value.i();
}
}
// start_index is the index at which the Tuple/List attrs will start getting
// processed.
void SetOpAttrs(TFE_Context* ctx, TFE_Op* op, PyObject* attrs, int start_index,
TF_Status* out_status) {
if (attrs == Py_None) return;
Py_ssize_t len = PyTuple_GET_SIZE(attrs) - start_index;
if ((len & 1) != 0) {
TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
"Expecting attrs tuple to have even length.");
return;
}
// Parse attrs
for (Py_ssize_t i = 0; i < len; i += 2) {
PyObject* py_key = PyTuple_GET_ITEM(attrs, start_index + i);
PyObject* py_value = PyTuple_GET_ITEM(attrs, start_index + i + 1);
#if PY_MAJOR_VERSION >= 3
const char* key = PyBytes_Check(py_key) ? PyBytes_AsString(py_key)
: PyUnicode_AsUTF8(py_key);
#else
const char* key = PyBytes_AsString(py_key);
#endif
unsigned char is_list = 0;
const TF_AttrType type = TFE_OpGetAttrType(op, key, &is_list, out_status);
if (!out_status->status.ok()) return;
if (is_list != 0) {
if (!SetOpAttrList(ctx, op, key, py_value, type, nullptr, out_status))
return;
} else {
if (!SetOpAttrScalar(ctx, op, key, py_value, type, nullptr, out_status))
return;
}
}
}
// This function will set the op attrs required. If an attr has the value of
// None, then it will read the AttrDef to get the default value and set that
// instead. Any failure in this function will simply fall back to the slow
// path.
void SetOpAttrWithDefaults(
TFE_Context* ctx, TFE_Op* op, const tensorflow::OpDef::AttrDef& attr,
const char* attr_name, PyObject* attr_value,
tensorflow::gtl::FlatMap<string, tensorflow::int64>* attr_list_sizes,
TF_Status* status) {
unsigned char is_list = 0;
const TF_AttrType type = TFE_OpGetAttrType(op, attr_name, &is_list, status);
if (!status->status.ok()) return;
if (attr_value == Py_None) {
if (is_list != 0) {
SetOpAttrListDefault(ctx, op, attr, attr_name, type, attr_list_sizes,
status);
} else {
SetOpAttrScalarDefault(ctx, op, attr.default_value(), attr_name,
attr_list_sizes, status);
}
} else {
if (is_list != 0) {
SetOpAttrList(ctx, op, attr_name, attr_value, type, attr_list_sizes,
status);
} else {
SetOpAttrScalar(ctx, op, attr_name, attr_value, type, attr_list_sizes,
status);
}
}
}
PyObject* GetPythonObjectFromInt(int num) {
#if PY_MAJOR_VERSION >= 3
return PyLong_FromLong(num);
#else
return PyInt_FromLong(num);
#endif
}
// Python subclass of Exception that is raised when a Status is not OK.
tensorflow::mutex exception_class_mutex(tensorflow::LINKER_INITIALIZED);
PyObject* exception_class TF_GUARDED_BY(exception_class_mutex) = nullptr;
// Python subclass of Exception that is created to signal fallback.
PyObject* fallback_exception_class = nullptr;
// Python function that returns input gradients given output gradients.
PyObject* gradient_function = nullptr;
// Python function that returns output gradients given input gradients.
PyObject* forward_gradient_function = nullptr;
static std::atomic<int64_t> _uid;
} // namespace
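// Returns a TF_Status for the caller to use, reusing the thread-local cached
// status when one is available; ReturnStatus below clears it and hands it back
// to the thread-local cache.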
TF_Status* GetStatus() {
TF_Status* maybe_status = ReleaseThreadLocalStatus();
if (maybe_status) {
TF_SetStatus(maybe_status, TF_OK, "");
return maybe_status;
} else {
return TF_NewStatus();
}
}
void ReturnStatus(TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
thread_local_tf_status.reset(status);
}
void TFE_Py_Execute(TFE_Context* ctx, const char* device_name,
const char* op_name, TFE_InputTensorHandles* inputs,
PyObject* attrs, TFE_OutputTensorHandles* outputs,
TF_Status* out_status) {
TFE_Py_ExecuteCancelable(ctx, device_name, op_name, inputs, attrs,
/*cancellation_manager=*/nullptr, outputs,
out_status);
}
void TFE_Py_ExecuteCancelable(TFE_Context* ctx, const char* device_name,
const char* op_name,
TFE_InputTensorHandles* inputs, PyObject* attrs,
TFE_CancellationManager* cancellation_manager,
TFE_OutputTensorHandles* outputs,
TF_Status* out_status) {
TFE_Op* op = GetOp(ctx, op_name, device_name, out_status);
auto cleaner = tensorflow::gtl::MakeCleanup([ctx, op] { ReturnOp(ctx, op); });
if (!out_status->status.ok()) return;
for (int i = 0; i < inputs->size() && out_status->status.ok(); ++i) {
TFE_OpAddInput(op, inputs->at(i), out_status);
}
if (cancellation_manager && out_status->status.ok()) {
TFE_OpSetCancellationManager(op, cancellation_manager, out_status);
}
if (out_status->status.ok()) {
SetOpAttrs(ctx, op, attrs, 0, out_status);
}
Py_BEGIN_ALLOW_THREADS;
if (out_status->status.ok()) {
int num_outputs = outputs->size();
TFE_Execute(op, outputs->data(), &num_outputs, out_status);
outputs->resize(num_outputs);
}
if (!out_status->status.ok()) {
TF_SetStatus(out_status, TF_GetCode(out_status),
tensorflow::strings::StrCat(TF_Message(out_status),
" [Op:", op_name, "]")
.c_str());
}
Py_END_ALLOW_THREADS;
}
PyObject* TFE_Py_RegisterExceptionClass(PyObject* e) {
tensorflow::mutex_lock l(exception_class_mutex);
if (exception_class != nullptr) {
Py_DECREF(exception_class);
}
if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
exception_class = nullptr;
PyErr_SetString(PyExc_TypeError,
"TFE_Py_RegisterExceptionClass: "
"Registered class should be subclass of Exception.");
return nullptr;
}
Py_INCREF(e);
exception_class = e;
Py_RETURN_NONE;
}
PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) {
if (fallback_exception_class != nullptr) {
Py_DECREF(fallback_exception_class);
}
if (PyObject_IsSubclass(e, PyExc_Exception) <= 0) {
fallback_exception_class = nullptr;
PyErr_SetString(PyExc_TypeError,
"TFE_Py_RegisterFallbackExceptionClass: "
"Registered class should be subclass of Exception.");
return nullptr;
} else {
Py_INCREF(e);
fallback_exception_class = e;
Py_RETURN_NONE;
}
}
PyObject* TFE_Py_RegisterGradientFunction(PyObject* e) {
if (gradient_function != nullptr) {
Py_DECREF(gradient_function);
}
if (!PyCallable_Check(e)) {
gradient_function = nullptr;
PyErr_SetString(PyExc_TypeError,
"TFE_Py_RegisterGradientFunction: "
"Registered object should be function.");
return nullptr;
} else {
Py_INCREF(e);
gradient_function = e;
Py_RETURN_NONE;
}
}
PyObject* TFE_Py_RegisterJVPFunction(PyObject* e) {
if (forward_gradient_function != nullptr) {
Py_DECREF(forward_gradient_function);
}
if (!PyCallable_Check(e)) {
forward_gradient_function = nullptr;
PyErr_SetString(PyExc_TypeError,
"TFE_Py_RegisterJVPFunction: "
"Registered object should be function.");
return nullptr;
} else {
Py_INCREF(e);
forward_gradient_function = e;
Py_RETURN_NONE;
}
}
void RaiseFallbackException(const char* message) {
if (fallback_exception_class != nullptr) {
PyErr_SetString(fallback_exception_class, message);
return;
}
PyErr_SetString(
PyExc_RuntimeError,
tensorflow::strings::StrCat(
"Fallback exception type not set, attempting to fallback due to ",
message)
.data());
}
int MaybeRaiseExceptionFromTFStatus(TF_Status* status, PyObject* exception) {
if (status->status.ok()) return 0;
const char* msg = TF_Message(status);
if (exception == nullptr) {
tensorflow::mutex_lock l(exception_class_mutex);
if (exception_class != nullptr) {
tensorflow::Safe_PyObjectPtr val(
Py_BuildValue("si", msg, TF_GetCode(status)));
if (PyErr_Occurred()) {
// NOTE: This hides the actual error (i.e. the reason `status` was not
// TF_OK), but there is nothing we can do at this point since we can't
// generate a reasonable error from the status.
// Consider adding a message explaining this.
return -1;
}
PyErr_SetObject(exception_class, val.get());
return -1;
} else {
exception = PyExc_RuntimeError;
}
}
// Maybe update an already-set exception.
PyErr_SetString(exception, msg);
return -1;
}
int MaybeRaiseExceptionFromStatus(const tensorflow::Status& status,
PyObject* exception) {
if (status.ok()) return 0;
const char* msg = status.error_message().c_str();
if (exception == nullptr) {
tensorflow::mutex_lock l(exception_class_mutex);
if (exception_class != nullptr) {
tensorflow::Safe_PyObjectPtr val(Py_BuildValue("si", msg, status.code()));
PyErr_SetObject(exception_class, val.get());
return -1;
} else {
exception = PyExc_RuntimeError;
}
}
// Maybe update an already-set exception.
PyErr_SetString(exception, msg);
return -1;
}
const char* TFE_GetPythonString(PyObject* o) {
#if PY_MAJOR_VERSION >= 3
if (PyBytes_Check(o)) {
return PyBytes_AsString(o);
} else {
return PyUnicode_AsUTF8(o);
}
#else
return PyBytes_AsString(o);
#endif
}
int64_t get_uid() { return _uid++; }
PyObject* TFE_Py_UID() { return PyLong_FromLongLong(get_uid()); }
void TFE_DeleteContextCapsule(PyObject* context) {
TFE_Context* ctx =
reinterpret_cast<TFE_Context*>(PyCapsule_GetPointer(context, nullptr));
auto op = ReleaseThreadLocalOp(ctx);
op.reset();
TFE_DeleteContext(ctx);
}
static tensorflow::int64 MakeInt(PyObject* integer) {
#if PY_MAJOR_VERSION >= 3
return PyLong_AsLong(integer);
#else
return PyInt_AsLong(integer);
#endif
}
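// Returns the id of `tensor`, using the fast path for EagerTensors and
// otherwise reading the Python `_id` attribute; returns -1 (leaving the Python
// error set) if the attribute is missing.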
static tensorflow::int64 FastTensorId(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
return PyEagerTensor_ID(tensor);
}
PyObject* id_field = PyObject_GetAttrString(tensor, "_id");
if (id_field == nullptr) {
return -1;
}
tensorflow::int64 id = MakeInt(id_field);
Py_DECREF(id_field);
return id;
}
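// Returns the dtype of `tensor`, using the fast path for EagerTensors and
// otherwise reading `tensor.dtype._type_enum`; returns DT_INVALID on failure.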
static tensorflow::DataType FastTensorDtype(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
return PyEagerTensor_Dtype(tensor);
}
PyObject* dtype_field = PyObject_GetAttrString(tensor, "dtype");
if (dtype_field == nullptr) {
return tensorflow::DT_INVALID;
}
PyObject* enum_field = PyObject_GetAttrString(dtype_field, "_type_enum");
Py_DECREF(dtype_field);
if (enum_field == nullptr) {
return tensorflow::DT_INVALID;
}
tensorflow::int64 id = MakeInt(enum_field);
Py_DECREF(enum_field);
return static_cast<tensorflow::DataType>(id);
}
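// A tape-side description of a tensor: its id, dtype and shape. The shape is
// stored either as a TensorShape or, for variant-dtype or partially-defined
// tensors, as the Python tensor object itself (see the note on shape_ below).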
class PyTapeTensor {
public:
PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
const tensorflow::TensorShape& shape)
: id_(id), dtype_(dtype), shape_(shape) {}
PyTapeTensor(tensorflow::int64 id, tensorflow::DataType dtype,
PyObject* shape)
: id_(id), dtype_(dtype), shape_(shape) {
Py_INCREF(absl::get<1>(shape_));
}
PyTapeTensor(const PyTapeTensor& other) {
id_ = other.id_;
dtype_ = other.dtype_;
shape_ = other.shape_;
if (shape_.index() == 1) {
Py_INCREF(absl::get<1>(shape_));
}
}
~PyTapeTensor() {
if (shape_.index() == 1) {
Py_DECREF(absl::get<1>(shape_));
}
}
PyObject* GetShape() const;
PyObject* GetPyDType() const { return PyLong_FromLong(dtype_); }
tensorflow::int64 GetID() const { return id_; }
tensorflow::DataType GetDType() const { return dtype_; }
PyObject* OnesLike() const;
PyObject* ZerosLike() const;
private:
tensorflow::int64 id_;
tensorflow::DataType dtype_;
// Note that if shape_.index() == 1, meaning shape_ contains a PyObject, that
// PyObject is the tensor itself. This is used to support tf.shape(tensor) for
// partially-defined shapes and tf.zeros_like(tensor) for variant-dtype
// tensors.
absl::variant<tensorflow::TensorShape, PyObject*> shape_;
};
static PyTapeTensor TapeTensorFromTensor(PyObject* tensor);
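// VSpace implementation that forwards the tape's numeric operations
// (aggregating gradients, building zeros/ones, etc.) to the Python callbacks
// registered via TFE_Py_RegisterVSpace.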
class PyVSpace : public tensorflow::eager::VSpace<PyObject, PyBackwardFunction,
PyTapeTensor> {
public:
explicit PyVSpace(PyObject* py_vspace) : py_vspace_(py_vspace) {
Py_INCREF(py_vspace_);
}
tensorflow::Status Initialize() {
num_elements_ = PyObject_GetAttrString(py_vspace_, "num_elements_fn");
if (num_elements_ == nullptr) {
return tensorflow::errors::InvalidArgument("invalid vspace");
}
aggregate_fn_ = PyObject_GetAttrString(py_vspace_, "aggregate_fn");
if (aggregate_fn_ == nullptr) {
return tensorflow::errors::InvalidArgument("invalid vspace");
}
zeros_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_fn");
if (zeros_fn_ == nullptr) {
return tensorflow::errors::InvalidArgument("invalid vspace");
}
zeros_like_fn_ = PyObject_GetAttrString(py_vspace_, "zeros_like_fn");
if (zeros_like_fn_ == nullptr) {
return tensorflow::errors::InvalidArgument("invalid vspace");
}
ones_fn_ = PyObject_GetAttrString(py_vspace_, "ones_fn");
if (ones_fn_ == nullptr) {
return tensorflow::errors::InvalidArgument("invalid vspace");
}
ones_like_fn_ = PyObject_GetAttrString(py_vspace_, "ones_like_fn");
if (ones_like_fn_ == nullptr) {
return tensorflow::errors::InvalidArgument("invalid vspace");
}
graph_shape_fn_ = PyObject_GetAttrString(py_vspace_, "graph_shape_fn");
if (graph_shape_fn_ == nullptr) {
return tensorflow::errors::InvalidArgument("invalid vspace");
}
return tensorflow::Status::OK();
}
~PyVSpace() override {
Py_XDECREF(num_elements_);
Py_XDECREF(aggregate_fn_);
Py_XDECREF(zeros_fn_);
Py_XDECREF(zeros_like_fn_);
Py_XDECREF(ones_fn_);
Py_XDECREF(ones_like_fn_);
Py_XDECREF(graph_shape_fn_);
Py_DECREF(py_vspace_);
}
tensorflow::int64 NumElements(PyObject* tensor) const final {
if (EagerTensor_CheckExact(tensor)) {
return PyEagerTensor_NumElements(tensor);
}
PyObject* arglist =
Py_BuildValue("(O)", reinterpret_cast<PyObject*>(tensor));
PyObject* result = PyEval_CallObject(num_elements_, arglist);
Py_DECREF(arglist);
if (result == nullptr) {
// The caller detects whether a python exception has been raised.
return -1;
}
tensorflow::int64 r = MakeInt(result);
Py_DECREF(result);
return r;
}
PyObject* AggregateGradients(
tensorflow::gtl::ArraySlice<PyObject*> gradient_tensors) const final {
PyObject* list = PyList_New(gradient_tensors.size());
for (int i = 0; i < gradient_tensors.size(); ++i) {
// Note: stealing a reference to the gradient tensors.
CHECK(gradient_tensors[i] != nullptr);
CHECK(gradient_tensors[i] != Py_None);
PyList_SET_ITEM(list, i,
reinterpret_cast<PyObject*>(gradient_tensors[i]));
}
PyObject* arglist = Py_BuildValue("(O)", list);
CHECK(arglist != nullptr);
PyObject* result = PyEval_CallObject(aggregate_fn_, arglist);
Py_DECREF(arglist);
Py_DECREF(list);
return result;
}
tensorflow::int64 TensorId(PyObject* tensor) const final {
return FastTensorId(tensor);
}
void MarkAsResult(PyObject* gradient) const final { Py_INCREF(gradient); }
PyObject* Ones(PyObject* shape, PyObject* dtype) const {
if (PyErr_Occurred()) {
return nullptr;
}
PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
PyObject* result = PyEval_CallObject(ones_fn_, arg_list);
Py_DECREF(arg_list);
return result;
}
PyObject* OnesLike(PyObject* tensor) const {
if (PyErr_Occurred()) {
return nullptr;
}
return PyObject_CallFunctionObjArgs(ones_like_fn_, tensor, NULL);
}
PyObject* Zeros(PyObject* shape, PyObject* dtype) const {
if (PyErr_Occurred()) {
return nullptr;
}
PyObject* arg_list = Py_BuildValue("OO", shape, dtype);
PyObject* result = PyEval_CallObject(zeros_fn_, arg_list);
Py_DECREF(arg_list);
return result;
}
PyObject* ZerosLike(PyObject* tensor) const {
if (PyErr_Occurred()) {
return nullptr;
}
return PyObject_CallFunctionObjArgs(zeros_like_fn_, tensor, NULL);
}
PyObject* GraphShape(PyObject* tensor) const {
PyObject* arg_list = Py_BuildValue("(O)", tensor);
PyObject* result = PyEval_CallObject(graph_shape_fn_, arg_list);
Py_DECREF(arg_list);
return result;
}
tensorflow::Status CallBackwardFunction(
PyBackwardFunction* backward_function,
const std::vector<tensorflow::int64>& unneeded_gradients,
tensorflow::gtl::ArraySlice<PyObject*> output_gradients,
std::vector<PyObject*>* result) const final {
PyObject* grads = PyTuple_New(output_gradients.size());
for (int i = 0; i < output_gradients.size(); ++i) {
if (output_gradients[i] == nullptr) {
Py_INCREF(Py_None);
PyTuple_SET_ITEM(grads, i, Py_None);
} else {
PyTuple_SET_ITEM(grads, i,
reinterpret_cast<PyObject*>(output_gradients[i]));
}
}
PyObject* py_result = (*backward_function)(grads, unneeded_gradients);
Py_DECREF(grads);
if (py_result == nullptr) {
return tensorflow::errors::Internal("gradient function threw exceptions");
}
result->clear();
PyObject* seq =
PySequence_Fast(py_result, "expected a sequence of gradients");
if (seq == nullptr) {
return tensorflow::errors::InvalidArgument(
"gradient function did not return a list");
}
int len = PySequence_Fast_GET_SIZE(seq);
PyObject** seq_array = PySequence_Fast_ITEMS(seq);
VLOG(1) << "Gradient length is " << len;
result->reserve(len);
for (int i = 0; i < len; ++i) {
PyObject* item = seq_array[i];
if (item == Py_None) {
result->push_back(nullptr);
} else {
Py_INCREF(item);
result->push_back(item);
}
}
Py_DECREF(seq);
Py_DECREF(py_result);
return tensorflow::Status::OK();
}
void DeleteGradient(PyObject* tensor) const final { Py_XDECREF(tensor); }
PyTapeTensor TapeTensorFromGradient(PyObject* tensor) const final {
return TapeTensorFromTensor(tensor);
}
private:
PyObject* py_vspace_;
PyObject* num_elements_;
PyObject* aggregate_fn_;
PyObject* zeros_fn_;
PyObject* zeros_like_fn_;
PyObject* ones_fn_;
PyObject* ones_like_fn_;
PyObject* graph_shape_fn_;
};
PyVSpace* py_vspace = nullptr;
bool HasAccumulator();
PyObject* TFE_Py_RegisterVSpace(PyObject* e) {
if (py_vspace != nullptr) {
if (HasAccumulator()) {
// Accumulators reference py_vspace, so we can't swap it out while one is
// active. This is unlikely to ever happen.
MaybeRaiseExceptionFromStatus(
tensorflow::errors::Internal(
"Can't change the vspace implementation while a "
"forward accumulator is active."),
nullptr);
}
delete py_vspace;
}
py_vspace = new PyVSpace(e);
auto status = py_vspace->Initialize();
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
delete py_vspace;
return nullptr;
}
Py_RETURN_NONE;
}
PyObject* PyTapeTensor::GetShape() const {
if (shape_.index() == 0) {
auto& shape = absl::get<0>(shape_);
PyObject* py_shape = PyTuple_New(shape.dims());
for (int i = 0; i < shape.dims(); ++i) {
PyTuple_SET_ITEM(py_shape, i, PyLong_FromLong(shape.dim_size(i)));
}
return py_shape;
}
return py_vspace->GraphShape(absl::get<1>(shape_));
}
PyObject* PyTapeTensor::OnesLike() const {
if (shape_.index() == 1) {
PyObject* tensor = absl::get<1>(shape_);
return py_vspace->OnesLike(tensor);
}
PyObject* py_shape = GetShape();
PyObject* py_dtype = GetPyDType();
PyObject* result = py_vspace->Ones(py_shape, py_dtype);
Py_DECREF(py_dtype);
Py_DECREF(py_shape);
return result;
}
PyObject* PyTapeTensor::ZerosLike() const {
if (shape_.index() == 1) {
PyObject* tensor = absl::get<1>(shape_);
return py_vspace->ZerosLike(tensor);
}
PyObject* py_shape = GetShape();
PyObject* py_dtype = GetPyDType();
PyObject* result = py_vspace->Zeros(py_shape, py_dtype);
Py_DECREF(py_dtype);
Py_DECREF(py_shape);
return result;
}
// Keeps track of all variables that have been accessed during execution.
class VariableWatcher {
public:
VariableWatcher() {}
~VariableWatcher() {
for (const IdAndVariable& v : watched_variables_) {
Py_DECREF(v.variable);
}
}
tensorflow::int64 WatchVariable(PyObject* v) {
tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(v, "handle"));
if (handle == nullptr) {
return -1;
}
tensorflow::int64 id = FastTensorId(handle.get());
tensorflow::mutex_lock l(watched_variables_mu_);
auto insert_result = watched_variables_.emplace(id, v);
if (insert_result.second) {
// Only increment the reference count if we aren't already watching this
// variable.
Py_INCREF(v);
}
return id;
}
PyObject* GetVariablesAsPyTuple() {
tensorflow::mutex_lock l(watched_variables_mu_);
PyObject* result = PyTuple_New(watched_variables_.size());
Py_ssize_t pos = 0;
for (const IdAndVariable& id_and_variable : watched_variables_) {
PyTuple_SET_ITEM(result, pos++, id_and_variable.variable);
Py_INCREF(id_and_variable.variable);
}
return result;
}
private:
// We store an IdAndVariable in the set since the set needs to be locked
// during insert, but we should not call back into Python during insert to
// avoid deadlocking on the GIL.
struct IdAndVariable {
tensorflow::int64 id;
PyObject* variable;
IdAndVariable(tensorflow::int64 id, PyObject* variable)
: id(id), variable(variable) {}
};
struct CompareById {
bool operator()(const IdAndVariable& lhs, const IdAndVariable& rhs) const {
return lhs.id < rhs.id;
}
};
tensorflow::mutex watched_variables_mu_;
std::set<IdAndVariable, CompareById> watched_variables_
TF_GUARDED_BY(watched_variables_mu_);
};
class GradientTape
: public tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
PyTapeTensor> {
public:
explicit GradientTape(bool persistent, bool watch_accessed_variables)
: tensorflow::eager::GradientTape<PyObject, PyBackwardFunction,
PyTapeTensor>(persistent),
watch_accessed_variables_(watch_accessed_variables) {}
virtual ~GradientTape() {}
void VariableAccessed(PyObject* v) {
if (watch_accessed_variables_) {
WatchVariable(v);
}
}
void WatchVariable(PyObject* v) {
tensorflow::int64 id = variable_watcher_.WatchVariable(v);
if (!PyErr_Occurred()) {
this->Watch(id);
}
}
PyObject* GetVariablesAsPyTuple() {
return variable_watcher_.GetVariablesAsPyTuple();
}
private:
bool watch_accessed_variables_;
VariableWatcher variable_watcher_;
};
typedef tensorflow::eager::ForwardAccumulator<PyObject, PyBackwardFunction,
PyTapeTensor>
ForwardAccumulator;
// Incremented when a GradientTape or accumulator is newly added to a set, and
// used to enforce an ordering between them.
std::atomic_uint_fast64_t tape_nesting_id_counter(0);
typedef struct {
PyObject_HEAD
/* Type-specific fields go here. */
GradientTape* tape;
// A nesting order between GradientTapes and ForwardAccumulators, used to
// ensure that GradientTapes do not watch the products of outer
// ForwardAccumulators.
tensorflow::int64 nesting_id;
} TFE_Py_Tape;
static void TFE_Py_Tape_Delete(PyObject* tape) {
delete reinterpret_cast<TFE_Py_Tape*>(tape)->tape;
Py_TYPE(tape)->tp_free(tape);
}
static PyTypeObject TFE_Py_Tape_Type = {
PyVarObject_HEAD_INIT(nullptr, 0) "tfe.Tape", /* tp_name */
sizeof(TFE_Py_Tape), /* tp_basicsize */
0, /* tp_itemsize */
&TFE_Py_Tape_Delete, /* tp_dealloc */
0, /* tp_print */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
"TFE_Py_Tape objects", /* tp_doc */
};
typedef struct {
PyObject_HEAD
/* Type-specific fields go here. */
ForwardAccumulator* accumulator;
// A nesting order between GradientTapes and ForwardAccumulators, used to
// ensure that GradientTapes do not watch the products of outer
// ForwardAccumulators.
tensorflow::int64 nesting_id;
} TFE_Py_ForwardAccumulator;
static void TFE_Py_ForwardAccumulatorDelete(PyObject* accumulator) {
delete reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)->accumulator;
Py_TYPE(accumulator)->tp_free(accumulator);
}
static PyTypeObject TFE_Py_ForwardAccumulator_Type = {
PyVarObject_HEAD_INIT(nullptr, 0) "ForwardAccumulator", /* tp_name */
sizeof(TFE_Py_ForwardAccumulator), /* tp_basicsize */
0, /* tp_itemsize */
&TFE_Py_ForwardAccumulatorDelete, /* tp_dealloc */
0, /* tp_print */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
"TFE_Py_ForwardAccumulator objects", /* tp_doc */
};
typedef struct {
PyObject_HEAD
/* Type-specific fields go here. */
VariableWatcher* variable_watcher;
} TFE_Py_VariableWatcher;
static void TFE_Py_VariableWatcher_Delete(PyObject* variable_watcher) {
delete reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
->variable_watcher;
Py_TYPE(variable_watcher)->tp_free(variable_watcher);
}
static PyTypeObject TFE_Py_VariableWatcher_Type = {
PyVarObject_HEAD_INIT(nullptr, 0) "tfe.VariableWatcher", /* tp_name */
sizeof(TFE_Py_VariableWatcher), /* tp_basicsize */
0, /* tp_itemsize */
&TFE_Py_VariableWatcher_Delete, /* tp_dealloc */
0, /* tp_print */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
nullptr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
nullptr, /* tp_hash */
nullptr, /* tp_call */
nullptr, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
"TFE_Py_VariableWatcher objects", /* tp_doc */
};
// Note: in the current design no mutex is needed here because of the Python
// GIL, which is always held when any TFE_Py_* methods are called. We should
// revisit this if/when we decide not to hold the GIL while manipulating the
// tape stack.
tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
thread_local std::unique_ptr<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>
tape_set = nullptr;
if (tape_set == nullptr) {
tape_set.reset(new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>);
}
return tape_set.get();
}
tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>*
GetVariableWatcherSet() {
thread_local std::unique_ptr<
tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>
variable_watcher_set = nullptr;
if (variable_watcher_set == nullptr) {
variable_watcher_set.reset(
new tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>);
}
return variable_watcher_set.get();
}
// A linked hash set, where iteration is in insertion order.
//
// Nested accumulators rely on op recording happening in insertion order, so an
// unordered data structure like CompactPointerSet is not suitable. Outer
// accumulators need to observe operations first so they know to watch the inner
// accumulator's jvp computation.
//
// Not thread safe.
class AccumulatorSet {
public:
// Returns true if `element` was newly inserted, false if it already exists.
bool insert(TFE_Py_ForwardAccumulator* element) {
if (map_.find(element) != map_.end()) {
return false;
}
ListType::iterator it = ordered_.insert(ordered_.end(), element);
map_.insert(std::make_pair(element, it));
return true;
}
void erase(TFE_Py_ForwardAccumulator* element) {
MapType::iterator existing = map_.find(element);
if (existing == map_.end()) {
return;
}
ListType::iterator list_position = existing->second;
map_.erase(existing);
ordered_.erase(list_position);
}
bool empty() const { return ordered_.empty(); }
size_t size() const { return ordered_.size(); }
private:
typedef std::list<TFE_Py_ForwardAccumulator*> ListType;
typedef tensorflow::gtl::FlatMap<TFE_Py_ForwardAccumulator*,
ListType::iterator>
MapType;
public:
typedef ListType::const_iterator const_iterator;
typedef ListType::const_reverse_iterator const_reverse_iterator;
const_iterator begin() const { return ordered_.begin(); }
const_iterator end() const { return ordered_.end(); }
const_reverse_iterator rbegin() const { return ordered_.rbegin(); }
const_reverse_iterator rend() const { return ordered_.rend(); }
private:
MapType map_;
ListType ordered_;
};
AccumulatorSet* GetAccumulatorSet() {
thread_local std::unique_ptr<AccumulatorSet> accumulator_set{nullptr};
if (accumulator_set == nullptr) {
accumulator_set.reset(new AccumulatorSet);
}
return accumulator_set.get();
}
inline bool HasAccumulator() { return !GetAccumulatorSet()->empty(); }
inline bool HasGradientTape() { return !GetTapeSet()->empty(); }
inline bool HasAccumulatorOrTape() {
return HasGradientTape() || HasAccumulator();
}
// A safe copy of a set, used for tapes and accumulators. The copy is not
// affected by other python threads changing the set of active tapes.
template <typename ContainerType>
class SafeSetCopy {
public:
explicit SafeSetCopy(const ContainerType& to_copy) : set_copy_(to_copy) {
for (auto* member : set_copy_) {
Py_INCREF(member);
}
}
~SafeSetCopy() {
for (auto* member : set_copy_) {
Py_DECREF(member);
}
}
typename ContainerType::const_iterator begin() const {
return set_copy_.begin();
}
typename ContainerType::const_iterator end() const { return set_copy_.end(); }
bool empty() const { return set_copy_.empty(); }
size_t size() const { return set_copy_.size(); }
protected:
ContainerType set_copy_;
};
class SafeTapeSet
: public SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>> {
public:
SafeTapeSet()
: SafeSetCopy<tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>>(
*GetTapeSet()) {}
};
class SafeAccumulatorSet : public SafeSetCopy<AccumulatorSet> {
public:
SafeAccumulatorSet() : SafeSetCopy<AccumulatorSet>(*GetAccumulatorSet()) {}
typename AccumulatorSet::const_reverse_iterator rbegin() const {
return set_copy_.rbegin();
}
typename AccumulatorSet::const_reverse_iterator rend() const {
return set_copy_.rend();
}
};
class SafeVariableWatcherSet
: public SafeSetCopy<
tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>> {
public:
SafeVariableWatcherSet()
: SafeSetCopy<
tensorflow::gtl::CompactPointerSet<TFE_Py_VariableWatcher*>>(
*GetVariableWatcherSet()) {}
};
bool* ThreadTapeIsStopped() {
thread_local bool thread_tape_is_stopped{false};
return &thread_tape_is_stopped;
}
void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }
void TFE_Py_TapeSetRestartOnThread() { *ThreadTapeIsStopped() = false; }
PyObject* TFE_Py_TapeSetIsStopped() {
if (*ThreadTapeIsStopped()) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}
PyObject* TFE_Py_TapeSetNew(PyObject* persistent,
PyObject* watch_accessed_variables) {
TFE_Py_Tape_Type.tp_new = PyType_GenericNew;
if (PyType_Ready(&TFE_Py_Tape_Type) < 0) return nullptr;
TFE_Py_Tape* tape = PyObject_NEW(TFE_Py_Tape, &TFE_Py_Tape_Type);
tape->tape = new GradientTape(persistent == Py_True,
watch_accessed_variables == Py_True);
Py_INCREF(tape);
tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
GetTapeSet()->insert(tape);
return reinterpret_cast<PyObject*>(tape);
}
void TFE_Py_TapeSetAdd(PyObject* tape) {
Py_INCREF(tape);
TFE_Py_Tape* tfe_tape = reinterpret_cast<TFE_Py_Tape*>(tape);
if (!GetTapeSet()->insert(tfe_tape).second) {
// Already exists in the tape set.
Py_DECREF(tape);
} else {
tfe_tape->nesting_id = tape_nesting_id_counter.fetch_add(1);
}
}
PyObject* TFE_Py_TapeSetIsEmpty() {
if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
}
void TFE_Py_TapeSetRemove(PyObject* tape) {
auto* stack = GetTapeSet();
stack->erase(reinterpret_cast<TFE_Py_Tape*>(tape));
// We kept a reference to the tape in the set to ensure it wouldn't get
// deleted under us; cleaning it up here.
Py_DECREF(tape);
}
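// Converts a Python sequence of integers into a vector of int64, mapping
// non-integer entries to -1; returns an empty vector for None or on error.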
static std::vector<tensorflow::int64> MakeIntList(PyObject* list) {
if (list == Py_None) {
return {};
}
PyObject* seq = PySequence_Fast(list, "expected a sequence");
if (seq == nullptr) {
return {};
}
int len = PySequence_Size(list);
PyObject** seq_array = PySequence_Fast_ITEMS(seq);
std::vector<tensorflow::int64> tensor_ids;
tensor_ids.reserve(len);
for (int i = 0; i < len; ++i) {
PyObject* item = seq_array[i];
#if PY_MAJOR_VERSION >= 3
if (PyLong_Check(item)) {
#else
if (PyLong_Check(item) || PyInt_Check(item)) {
#endif
tensorflow::int64 id = MakeInt(item);
tensor_ids.push_back(id);
} else {
tensor_ids.push_back(-1);
}
}
Py_DECREF(seq);
return tensor_ids;
}
// Fill `tensor_ids` and `dtypes` from `tensors`, none of which may be
// null. Returns true on success and false on a Python exception.
bool TensorShapesAndDtypes(PyObject* tensors,
std::vector<tensorflow::int64>* tensor_ids,
std::vector<tensorflow::DataType>* dtypes) {
tensorflow::Safe_PyObjectPtr seq(
PySequence_Fast(tensors, "expected a sequence"));
if (seq == nullptr) {
return false;
}
int len = PySequence_Fast_GET_SIZE(seq.get());
PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
tensor_ids->reserve(len);
dtypes->reserve(len);
for (int i = 0; i < len; ++i) {
PyObject* item = seq_array[i];
tensor_ids->push_back(FastTensorId(item));
dtypes->push_back(FastTensorDtype(item));
}
return true;
}
bool TapeCouldPossiblyRecord(PyObject* tensors) {
if (tensors == Py_None) {
return false;
}
if (*ThreadTapeIsStopped()) {
return false;
}
if (!HasAccumulatorOrTape()) {
return false;
}
return true;
}
bool CouldBackprop() { return !*ThreadTapeIsStopped() && HasGradientTape(); }
bool CouldForwardprop() { return !*ThreadTapeIsStopped() && HasAccumulator(); }
PyObject* TFE_Py_TapeSetShouldRecordBackprop(PyObject* tensors) {
if (!TapeCouldPossiblyRecord(tensors) || !CouldBackprop()) {
Py_RETURN_FALSE;
}
// TODO(apassos) consider not building a list and changing the API to check
// each tensor individually.
std::vector<tensorflow::int64> tensor_ids;
std::vector<tensorflow::DataType> dtypes;
if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
return nullptr;
}
auto tape_set = *GetTapeSet();
for (TFE_Py_Tape* tape : tape_set) {
if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
Py_RETURN_TRUE;
}
}
Py_RETURN_FALSE;
}
PyObject* TFE_Py_ForwardAccumulatorPushState() {
auto forward_accumulators = *GetAccumulatorSet();
for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
accumulator->accumulator->PushState();
}
Py_RETURN_NONE;
}
PyObject* TFE_Py_ForwardAccumulatorPopState() {
auto forward_accumulators = *GetAccumulatorSet();
for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
accumulator->accumulator->PopState();
}
Py_RETURN_NONE;
}
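// Classifies the gradient machinery that would record `tensors`:
//   0 - nothing would record, so no tape gradients are available.
//   1 - exactly one non-persistent tape (or a single accumulator) would
//       record; first-order gradients only.
//   2 - a persistent tape or multiple recorders would record; higher-order
//       gradients are possible.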
PyObject* TFE_Py_TapeSetPossibleGradientTypes(PyObject* tensors) {
if (!TapeCouldPossiblyRecord(tensors)) {
return GetPythonObjectFromInt(0);
}
std::vector<tensorflow::int64> tensor_ids;
std::vector<tensorflow::DataType> dtypes;
if (!TensorShapesAndDtypes(tensors, &tensor_ids, &dtypes)) {
return nullptr;
}
// If there is a persistent tape watching, or if there are multiple tapes
// watching, we'll return immediately indicating that higher-order tape
// gradients are possible.
bool some_tape_watching = false;
if (CouldBackprop()) {
auto tape_set = *GetTapeSet();
for (TFE_Py_Tape* tape : tape_set) {
if (tape->tape->ShouldRecord(tensor_ids, dtypes)) {
if (tape->tape->IsPersistent() || some_tape_watching) {
// Either this is the second tape watching, or this tape is
// persistent: higher-order gradients are possible.
return GetPythonObjectFromInt(2);
}
some_tape_watching = true;
}
}
}
if (CouldForwardprop()) {
auto forward_accumulators = *GetAccumulatorSet();
for (TFE_Py_ForwardAccumulator* accumulator : forward_accumulators) {
if (accumulator->accumulator->ShouldRecord(tensor_ids, dtypes)) {
if (some_tape_watching) {
// This is the second tape watching: higher-order gradients are
// possible. Note that there's no equivalent of persistence for
// forward-mode.
return GetPythonObjectFromInt(2);
}
some_tape_watching = true;
}
}
}
if (some_tape_watching) {
// There's exactly one non-persistent tape. The user can request first-order
// gradients but won't be able to get higher-order tape gradients.
return GetPythonObjectFromInt(1);
} else {
// There are no tapes. The user can't request tape gradients.
return GetPythonObjectFromInt(0);
}
}
void TFE_Py_TapeWatch(PyObject* tape, PyObject* tensor) {
if (!CouldBackprop()) {
return;
}
tensorflow::int64 tensor_id = FastTensorId(tensor);
if (PyErr_Occurred()) {
return;
}
reinterpret_cast<TFE_Py_Tape*>(tape)->tape->Watch(tensor_id);
}
bool ListContainsNone(PyObject* list) {
if (list == Py_None) return true;
tensorflow::Safe_PyObjectPtr seq(
PySequence_Fast(list, "expected a sequence"));
if (seq == nullptr) {
return false;
}
int len = PySequence_Fast_GET_SIZE(seq.get());
PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
for (int i = 0; i < len; ++i) {
PyObject* item = seq_array[i];
if (item == Py_None) return true;
}
return false;
}
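// Builds a PyTapeTensor (id, dtype, shape) describing `tensor`. EagerTensors
// read their metadata directly from the handle; other tensors fall back to
// Python attribute access. Variant tensors, and non-eager tensors with
// partially unknown shapes, keep a reference to the Python object instead of
// a static shape.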
static PyTapeTensor TapeTensorFromTensor(PyObject* tensor) {
if (EagerTensor_CheckExact(tensor)) {
TFE_TensorHandle* t = EagerTensor_Handle(tensor);
tensorflow::int64 id = PyEagerTensor_ID(tensor);
tensorflow::DataType dtype =
static_cast<tensorflow::DataType>(t->handle->DataType());
if (dtype == tensorflow::DT_VARIANT) {
return PyTapeTensor(id, dtype, tensor);
}
tensorflow::TensorShape tensor_shape;
int num_dims;
tensorflow::Status status = t->handle->NumDims(&num_dims);
if (status.ok()) {
for (int i = 0; i < num_dims; ++i) {
tensorflow::int64 dim_size;
status = t->handle->Dim(i, &dim_size);
if (!status.ok()) break;
tensor_shape.AddDim(dim_size);
}
}
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
tensorflow::TensorShape({}));
} else {
return PyTapeTensor(id, dtype, tensor_shape);
}
}
tensorflow::int64 id = FastTensorId(tensor);
if (PyErr_Occurred()) {
return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
tensorflow::TensorShape({}));
}
PyObject* dtype_object = PyObject_GetAttrString(tensor, "dtype");
PyObject* dtype_enum = PyObject_GetAttrString(dtype_object, "_type_enum");
Py_DECREF(dtype_object);
tensorflow::DataType dtype =
static_cast<tensorflow::DataType>(MakeInt(dtype_enum));
Py_DECREF(dtype_enum);
if (PyErr_Occurred()) {
return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
tensorflow::TensorShape({}));
}
static char _shape_tuple[] = "_shape_tuple";
tensorflow::Safe_PyObjectPtr shape_tuple(
PyObject_CallMethod(tensor, _shape_tuple, nullptr));
if (PyErr_Occurred()) {
return PyTapeTensor(id, static_cast<tensorflow::DataType>(0),
tensorflow::TensorShape({}));
}
if (ListContainsNone(shape_tuple.get()) || dtype == tensorflow::DT_VARIANT) {
return PyTapeTensor(id, dtype, tensor);
}
auto l = MakeIntList(shape_tuple.get());
// Replace any -1 values (accidental Nones, which can occur in graph mode)
// with 0 so they don't break shape construction.
for (auto& c : l) {
if (c < 0) {
c = 0;
}
}
tensorflow::TensorShape shape(l);
return PyTapeTensor(id, dtype, shape);
}
// Populates output_info from output_seq, which must come from PySequence_Fast.
//
// Does not take ownership of output_seq. Returns true on success and false if a
// Python exception has been set.
bool TapeTensorsFromTensorSequence(PyObject* output_seq,
std::vector<PyTapeTensor>* output_info) {
Py_ssize_t output_len = PySequence_Fast_GET_SIZE(output_seq);
PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
output_info->reserve(output_len);
for (Py_ssize_t i = 0; i < output_len; ++i) {
output_info->push_back(TapeTensorFromTensor(output_seq_array[i]));
if (PyErr_Occurred() != nullptr) {
return false;
}
}
return true;
}
std::vector<tensorflow::int64> MakeTensorIDList(PyObject* tensors) {
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
if (seq == nullptr) {
return {};
}
int len = PySequence_Fast_GET_SIZE(seq);
PyObject** seq_array = PySequence_Fast_ITEMS(seq);
std::vector<tensorflow::int64> list;
list.reserve(len);
for (int i = 0; i < len; ++i) {
PyObject* tensor = seq_array[i];
list.push_back(FastTensorId(tensor));
if (PyErr_Occurred()) {
Py_DECREF(seq);
return list;
}
}
Py_DECREF(seq);
return list;
}
void TFE_Py_TapeVariableAccessed(PyObject* variable) {
if (!CouldBackprop()) {
return;
}
for (TFE_Py_Tape* tape : SafeTapeSet()) {
tape->tape->VariableAccessed(variable);
}
}
void TFE_Py_TapeWatchVariable(PyObject* tape, PyObject* variable) {
if (!CouldBackprop()) {
return;
}
reinterpret_cast<TFE_Py_Tape*>(tape)->tape->WatchVariable(variable);
}
PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) {
return reinterpret_cast<TFE_Py_Tape*>(tape)->tape->GetVariablesAsPyTuple();
}
PyObject* TFE_Py_VariableWatcherNew() {
TFE_Py_VariableWatcher_Type.tp_new = PyType_GenericNew;
if (PyType_Ready(&TFE_Py_VariableWatcher_Type) < 0) return nullptr;
TFE_Py_VariableWatcher* variable_watcher =
PyObject_NEW(TFE_Py_VariableWatcher, &TFE_Py_VariableWatcher_Type);
variable_watcher->variable_watcher = new VariableWatcher();
Py_INCREF(variable_watcher);
GetVariableWatcherSet()->insert(variable_watcher);
return reinterpret_cast<PyObject*>(variable_watcher);
}
void TFE_Py_VariableWatcherRemove(PyObject* variable_watcher) {
auto* stack = GetVariableWatcherSet();
stack->erase(reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher));
// We kept a reference to the variable watcher in the set to ensure it
// wouldn't get deleted under us; cleaning it up here.
Py_DECREF(variable_watcher);
}
void TFE_Py_VariableWatcherVariableAccessed(PyObject* variable) {
for (TFE_Py_VariableWatcher* variable_watcher : SafeVariableWatcherSet()) {
variable_watcher->variable_watcher->WatchVariable(variable);
}
}
PyObject* TFE_Py_VariableWatcherWatchedVariables(PyObject* variable_watcher) {
return reinterpret_cast<TFE_Py_VariableWatcher*>(variable_watcher)
->variable_watcher->GetVariablesAsPyTuple();
}
namespace {
std::vector<tensorflow::DataType> MakeTensorDtypeList(PyObject* tensors) {
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
if (seq == nullptr) {
return {};
}
int len = PySequence_Fast_GET_SIZE(seq);
PyObject** seq_array = PySequence_Fast_ITEMS(seq);
std::vector<tensorflow::DataType> list;
list.reserve(len);
for (int i = 0; i < len; ++i) {
PyObject* tensor = seq_array[i];
list.push_back(FastTensorDtype(tensor));
}
Py_DECREF(seq);
return list;
}
PyObject* ForwardAccumulatorDeleteGradient(PyObject* tensor_id,
PyObject* weak_tensor_ref) {
tensorflow::int64 parsed_tensor_id = MakeInt(tensor_id);
for (TFE_Py_ForwardAccumulator* accumulator : *GetAccumulatorSet()) {
accumulator->accumulator->DeleteGradient(parsed_tensor_id);
}
Py_DECREF(weak_tensor_ref);
Py_DECREF(tensor_id);
Py_INCREF(Py_None);
return Py_None;
}
static PyMethodDef forward_accumulator_delete_gradient_method_def = {
"ForwardAccumulatorDeleteGradient", ForwardAccumulatorDeleteGradient,
METH_O, "ForwardAccumulatorDeleteGradient"};
void RegisterForwardAccumulatorCleanup(PyObject* tensor,
tensorflow::int64 tensor_id) {
tensorflow::Safe_PyObjectPtr callback(
PyCFunction_New(&forward_accumulator_delete_gradient_method_def,
PyLong_FromLong(tensor_id)));
// We need to keep a reference to the weakref active if we want our callback
// called. The callback itself now owns the weakref object and the tensor ID
// object.
PyWeakref_NewRef(tensor, callback.get());
}
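// Records the operation on every active gradient tape whose nesting_id is
// strictly below `max_gradient_tape_id` (see TapeSetRecordForwardprop for how
// that bound is chosen).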
void TapeSetRecordBackprop(
const string& op_type, const std::vector<PyTapeTensor>& output_info,
const std::vector<tensorflow::int64>& input_ids,
const std::vector<tensorflow::DataType>& input_dtypes,
const std::function<PyBackwardFunction*()>& backward_function_getter,
const std::function<void(PyBackwardFunction*)>& backward_function_killer,
tensorflow::uint64 max_gradient_tape_id) {
if (!CouldBackprop()) {
return;
}
for (TFE_Py_Tape* tape : SafeTapeSet()) {
if (tape->nesting_id < max_gradient_tape_id) {
tape->tape->RecordOperation(op_type, output_info, input_ids, input_dtypes,
backward_function_getter,
backward_function_killer);
}
}
}
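// Forward-mode counterpart of TapeSetRecordBackprop. If
// `forwardprop_output_indices` is provided, the op's outputs are already jvps
// and are watched directly by the corresponding accumulators; otherwise each
// active accumulator runs Accumulate for the op. On return,
// `*max_gradient_tape_id` bounds which tapes should also record the op.
// Returns false if a Python exception has been set.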
bool TapeSetRecordForwardprop(
const string& op_type, PyObject* output_seq,
const std::vector<PyTapeTensor>& output_info, PyObject* input_tensors,
const std::vector<tensorflow::int64>& input_ids,
const std::vector<tensorflow::DataType>& input_dtypes,
const std::function<PyBackwardFunction*()>& backward_function_getter,
const std::function<void(PyBackwardFunction*)>& backward_function_killer,
const tensorflow::eager::ForwardFunction<PyObject>* forward_function,
PyObject* forwardprop_output_indices,
tensorflow::uint64* max_gradient_tape_id) {
*max_gradient_tape_id = std::numeric_limits<tensorflow::uint64>::max();
if (!CouldForwardprop()) {
return true;
}
auto accumulator_set = SafeAccumulatorSet();
tensorflow::Safe_PyObjectPtr input_seq(
PySequence_Fast(input_tensors, "expected a sequence of tensors"));
if (input_seq == nullptr || PyErr_Occurred()) return false;
Py_ssize_t input_len = PySequence_Fast_GET_SIZE(input_seq.get());
PyObject** output_seq_array = PySequence_Fast_ITEMS(output_seq);
for (int i = 0; i < output_info.size(); ++i) {
RegisterForwardAccumulatorCleanup(output_seq_array[i],
output_info[i].GetID());
}
if (forwardprop_output_indices != nullptr &&
forwardprop_output_indices != Py_None) {
tensorflow::Safe_PyObjectPtr indices_fast(PySequence_Fast(
forwardprop_output_indices, "Expected a sequence of indices"));
if (indices_fast == nullptr || PyErr_Occurred()) {
return false;
}
if (PySequence_Fast_GET_SIZE(indices_fast.get()) !=
static_cast<Py_ssize_t>(accumulator_set.size())) {
MaybeRaiseExceptionFromStatus(
tensorflow::errors::Internal(
"Accumulators were added or removed from the active set "
"between packing and unpacking."),
nullptr);
// A Python exception is now set; bail out before indexing `indices_fast`
// with a mismatched accumulator count.
return false;
}
PyObject** indices_fast_array = PySequence_Fast_ITEMS(indices_fast.get());
Py_ssize_t accumulator_index = 0;
for (AccumulatorSet::const_reverse_iterator it = accumulator_set.rbegin();
it != accumulator_set.rend(); ++it, ++accumulator_index) {
tensorflow::Safe_PyObjectPtr jvp_index_seq(
PySequence_Fast(indices_fast_array[accumulator_index],
"Expected a sequence of jvp indices."));
if (jvp_index_seq == nullptr || PyErr_Occurred()) {
return false;
}
Py_ssize_t num_jvps = PySequence_Fast_GET_SIZE(jvp_index_seq.get());
PyObject** jvp_index_seq_array =
PySequence_Fast_ITEMS(jvp_index_seq.get());
for (Py_ssize_t jvp_index = 0; jvp_index < num_jvps; ++jvp_index) {
PyObject* tuple = jvp_index_seq_array[jvp_index];
tensorflow::int64 primal_tensor_id =
output_info[MakeInt(PyTuple_GetItem(tuple, 0))].GetID();
(*it)->accumulator->Watch(
primal_tensor_id,
output_seq_array[MakeInt(PyTuple_GetItem(tuple, 1))]);
}
}
} else {
std::vector<PyTapeTensor> input_info;
input_info.reserve(input_len);
PyObject** input_seq_array = PySequence_Fast_ITEMS(input_seq.get());
for (Py_ssize_t i = 0; i < input_len; ++i) {
input_info.push_back(TapeTensorFromTensor(input_seq_array[i]));
}
for (TFE_Py_ForwardAccumulator* accumulator : accumulator_set) {
tensorflow::Status status = accumulator->accumulator->Accumulate(
op_type, input_info, output_info, input_ids, input_dtypes,
forward_function, backward_function_getter, backward_function_killer);
if (PyErr_Occurred()) return false; // Don't swallow Python exceptions.
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
return false;
}
if (accumulator->accumulator->BusyAccumulating()) {
// Ensure inner accumulators don't see outer accumulators' jvps. This
// mostly happens on its own, with some potentially surprising
// exceptions, so the blanket policy is for consistency.
*max_gradient_tape_id = accumulator->nesting_id;
break;
}
}
}
return true;
}
PyObject* TangentsAsPyTuple(const std::vector<PyObject*>& input_tangents) {
PyObject* py_input_tangents = PyTuple_New(input_tangents.size());
for (int i = 0; i < input_tangents.size(); ++i) {
PyObject* element;
if (input_tangents[i] == nullptr) {
element = Py_None;
} else {
element = input_tangents[i];
}
Py_INCREF(element);
PyTuple_SET_ITEM(py_input_tangents, i, element);
}
return py_input_tangents;
}
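// Unpacks the sequence returned by a forward gradient function into
// `output_tangents`; None entries become nullptr, everything else is
// INCREF'd.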
tensorflow::Status ParseTangentOutputs(
PyObject* user_output, std::vector<PyObject*>* output_tangents) {
if (user_output == Py_None) {
// No connected gradients.
return tensorflow::Status::OK();
}
tensorflow::Safe_PyObjectPtr fast_result(
PySequence_Fast(user_output, "expected a sequence of forward gradients"));
if (fast_result == nullptr) {
return tensorflow::errors::InvalidArgument(
"forward gradient function did not return a sequence.");
}
int len = PySequence_Fast_GET_SIZE(fast_result.get());
PyObject** fast_result_array = PySequence_Fast_ITEMS(fast_result.get());
output_tangents->reserve(len);
for (int i = 0; i < len; ++i) {
PyObject* item = fast_result_array[i];
if (item == Py_None) {
output_tangents->push_back(nullptr);
} else {
Py_INCREF(item);
output_tangents->push_back(item);
}
}
return tensorflow::Status::OK();
}
// Calls the registered forward_gradient_function, computing `output_tangents`
// from `input_tangents`. `output_tangents` must not be null.
//
// `op_name`, `attrs`, `inputs`, and `results` describe the operation for which
// the forward function is being called.
tensorflow::Status CallJVPFunction(PyObject* op_name, PyObject* attrs,
PyObject* inputs, PyObject* results,
const std::vector<PyObject*>& input_tangents,
std::vector<PyObject*>* output_tangents) {
if (forward_gradient_function == nullptr) {
return tensorflow::errors::Internal(
"No forward gradient function registered.");
}
tensorflow::Safe_PyObjectPtr py_input_tangents(
TangentsAsPyTuple(input_tangents));
// Normalize the input sequence to a tuple so it works with function
// caching; otherwise it may be an opaque _InputList object.
tensorflow::Safe_PyObjectPtr input_tuple(PySequence_Tuple(inputs));
tensorflow::Safe_PyObjectPtr callback_args(
Py_BuildValue("OOOOO", op_name, attrs, input_tuple.get(), results,
py_input_tangents.get()));
tensorflow::Safe_PyObjectPtr py_result(
PyObject_CallObject(forward_gradient_function, callback_args.get()));
if (py_result == nullptr || PyErr_Occurred()) {
return tensorflow::errors::Internal(
"forward gradient function threw exceptions");
}
return ParseTangentOutputs(py_result.get(), output_tangents);
}
// Like CallJVPFunction, but calls a pre-bound forward function.
// These are passed in from a record_gradient argument.
tensorflow::Status CallOpSpecificJVPFunction(
PyObject* op_specific_forward_function,
const std::vector<PyObject*>& input_tangents,
std::vector<PyObject*>* output_tangents) {
tensorflow::Safe_PyObjectPtr py_input_tangents(
TangentsAsPyTuple(input_tangents));
tensorflow::Safe_PyObjectPtr py_result(PyObject_CallObject(
op_specific_forward_function, py_input_tangents.get()));
if (py_result == nullptr || PyErr_Occurred()) {
return tensorflow::errors::Internal(
"forward gradient function threw exceptions");
}
return ParseTangentOutputs(py_result.get(), output_tangents);
}
bool ParseOpTypeString(PyObject* op_type, string* op_type_string) {
if (PyBytes_Check(op_type)) {
*op_type_string = PyBytes_AsString(op_type);
} else if (PyUnicode_Check(op_type)) {
#if PY_MAJOR_VERSION >= 3
*op_type_string = PyUnicode_AsUTF8(op_type);
#else
PyObject* py_str = PyUnicode_AsUTF8String(op_type);
if (py_str == nullptr) {
return false;
}
*op_type_string = PyBytes_AS_STRING(py_str);
Py_DECREF(py_str);
#endif
} else {
PyErr_SetString(PyExc_RuntimeError, "op_type should be a string.");
return false;
}
return true;
}
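// Records the operation with both the forward accumulators and the gradient
// tapes. Returns false (with a Python exception set) on failure.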
bool TapeSetRecordOperation(
PyObject* op_type, PyObject* input_tensors, PyObject* output_tensors,
const std::vector<tensorflow::int64>& input_ids,
const std::vector<tensorflow::DataType>& input_dtypes,
const std::function<PyBackwardFunction*()>& backward_function_getter,
const std::function<void(PyBackwardFunction*)>& backward_function_killer,
const tensorflow::eager::ForwardFunction<PyObject>* forward_function) {
std::vector<PyTapeTensor> output_info;
tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
output_tensors, "expected a sequence of integer tensor ids"));
if (PyErr_Occurred() ||
!TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
return false;
}
string op_type_str;
if (!ParseOpTypeString(op_type, &op_type_str)) {
return false;
}
tensorflow::uint64 max_gradient_tape_id;
if (!TapeSetRecordForwardprop(
op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
input_dtypes, backward_function_getter, backward_function_killer,
forward_function, nullptr /* No special-cased jvps. */,
&max_gradient_tape_id)) {
return false;
}
TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
backward_function_getter, backward_function_killer,
max_gradient_tape_id);
return true;
}
} // namespace
PyObject* TFE_Py_TapeSetRecordOperation(PyObject* op_type,
PyObject* output_tensors,
PyObject* input_tensors,
PyObject* backward_function,
PyObject* forward_function) {
if (!HasAccumulatorOrTape() || *ThreadTapeIsStopped()) {
Py_RETURN_NONE;
}
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
if (PyErr_Occurred()) return nullptr;
std::vector<tensorflow::DataType> input_dtypes =
MakeTensorDtypeList(input_tensors);
if (PyErr_Occurred()) return nullptr;
std::function<PyBackwardFunction*()> backward_function_getter(
[backward_function]() {
Py_INCREF(backward_function);
PyBackwardFunction* function = new PyBackwardFunction(
[backward_function](PyObject* out_grads,
const std::vector<tensorflow::int64>& unused) {
return PyObject_CallObject(backward_function, out_grads);
});
return function;
});
std::function<void(PyBackwardFunction*)> backward_function_killer(
[backward_function](PyBackwardFunction* py_backward_function) {
Py_DECREF(backward_function);
delete py_backward_function;
});
if (forward_function == Py_None) {
if (!TapeSetRecordOperation(
op_type, input_tensors, output_tensors, input_ids, input_dtypes,
backward_function_getter, backward_function_killer,
nullptr /* No special-cased forward function */)) {
return nullptr;
}
} else {
tensorflow::eager::ForwardFunction<PyObject> wrapped_forward_function(
[forward_function](const std::vector<PyObject*>& input_tangents,
std::vector<PyObject*>* output_tangents) {
return CallOpSpecificJVPFunction(forward_function, input_tangents,
output_tangents);
});
if (!TapeSetRecordOperation(
op_type, input_tensors, output_tensors, input_ids, input_dtypes,
backward_function_getter, backward_function_killer,
&wrapped_forward_function)) {
return nullptr;
}
}
Py_RETURN_NONE;
}
PyObject* TFE_Py_TapeSetRecordOperationForwardprop(
PyObject* op_type, PyObject* output_tensors, PyObject* input_tensors,
PyObject* backward_function, PyObject* forwardprop_output_indices) {
if (!HasAccumulator() || *ThreadTapeIsStopped()) {
Py_RETURN_NONE;
}
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
if (PyErr_Occurred()) return nullptr;
std::vector<tensorflow::DataType> input_dtypes =
MakeTensorDtypeList(input_tensors);
if (PyErr_Occurred()) return nullptr;
std::function<PyBackwardFunction*()> backward_function_getter(
[backward_function]() {
Py_INCREF(backward_function);
PyBackwardFunction* function = new PyBackwardFunction(
[backward_function](PyObject* out_grads,
const std::vector<tensorflow::int64>& unused) {
return PyObject_CallObject(backward_function, out_grads);
});
return function;
});
std::function<void(PyBackwardFunction*)> backward_function_killer(
[backward_function](PyBackwardFunction* py_backward_function) {
Py_DECREF(backward_function);
delete py_backward_function;
});
std::vector<PyTapeTensor> output_info;
tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
output_tensors, "expected a sequence of integer tensor ids"));
if (PyErr_Occurred() ||
!TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
return nullptr;
}
string op_type_str;
if (!ParseOpTypeString(op_type, &op_type_str)) {
return nullptr;
}
tensorflow::uint64 max_gradient_tape_id;
if (!TapeSetRecordForwardprop(
op_type_str, output_seq.get(), output_info, input_tensors, input_ids,
input_dtypes, backward_function_getter, backward_function_killer,
nullptr /* no special-cased forward function */,
forwardprop_output_indices, &max_gradient_tape_id)) {
return nullptr;
}
Py_RETURN_NONE;
}
PyObject* TFE_Py_TapeSetRecordOperationBackprop(PyObject* op_type,
PyObject* output_tensors,
PyObject* input_tensors,
PyObject* backward_function) {
if (!CouldBackprop()) {
Py_RETURN_NONE;
}
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(input_tensors);
if (PyErr_Occurred()) return nullptr;
std::vector<tensorflow::DataType> input_dtypes =
MakeTensorDtypeList(input_tensors);
if (PyErr_Occurred()) return nullptr;
std::function<PyBackwardFunction*()> backward_function_getter(
[backward_function]() {
Py_INCREF(backward_function);
PyBackwardFunction* function = new PyBackwardFunction(
[backward_function](PyObject* out_grads,
const std::vector<tensorflow::int64>& unused) {
return PyObject_CallObject(backward_function, out_grads);
});
return function;
});
std::function<void(PyBackwardFunction*)> backward_function_killer(
[backward_function](PyBackwardFunction* py_backward_function) {
Py_DECREF(backward_function);
delete py_backward_function;
});
std::vector<PyTapeTensor> output_info;
tensorflow::Safe_PyObjectPtr output_seq(PySequence_Fast(
output_tensors, "expected a sequence of integer tensor ids"));
if (PyErr_Occurred() ||
!TapeTensorsFromTensorSequence(output_seq.get(), &output_info)) {
return nullptr;
}
string op_type_str;
if (!ParseOpTypeString(op_type, &op_type_str)) {
return nullptr;
}
TapeSetRecordBackprop(op_type_str, output_info, input_ids, input_dtypes,
backward_function_getter, backward_function_killer,
// No filtering based on relative ordering with forward
// accumulators.
std::numeric_limits<tensorflow::uint64>::max());
Py_RETURN_NONE;
}
void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) {
for (TFE_Py_Tape* tape : *GetTapeSet()) {
tape->tape->DeleteTrace(tensor_id);
}
}
std::vector<PyObject*> MakeTensorList(PyObject* tensors) {
PyObject* seq = PySequence_Fast(tensors, "expected a sequence");
if (seq == nullptr) {
return {};
}
int len = PySequence_Fast_GET_SIZE(seq);
PyObject** seq_array = PySequence_Fast_ITEMS(seq);
std::vector<PyObject*> list(seq_array, seq_array + len);
Py_DECREF(seq);
return list;
}
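// Computes gradients of `target` with respect to `sources` from the
// operations recorded on `tape`. Querying a non-persistent tape while it is
// still recording is an error; `unconnected_gradients` selects whether
// missing gradients come back as None or as zeros shaped like the source.
// Roughly, this is what backs the Python-level pattern (illustrative sketch,
// not the exact generated code):
//   with tf.GradientTape() as tape:
//     y = f(x)
//   grads = tape.gradient(y, [x])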
PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* target,
PyObject* sources, PyObject* output_gradients,
PyObject* sources_raw,
PyObject* unconnected_gradients,
TF_Status* status) {
TFE_Py_Tape* tape_obj = reinterpret_cast<TFE_Py_Tape*>(tape);
if (!tape_obj->tape->IsPersistent()) {
auto* tape_set = GetTapeSet();
if (tape_set->find(tape_obj) != tape_set->end()) {
PyErr_SetString(PyExc_RuntimeError,
"gradient() cannot be invoked within the "
"GradientTape context (i.e., while operations are being "
"recorded). Either move the call to gradient() to be "
"outside the 'with tf.GradientTape' block, or "
"use a persistent tape: "
"'with tf.GradientTape(persistent=true)'");
return nullptr;
}
}
std::vector<tensorflow::int64> target_vec = MakeTensorIDList(target);
if (PyErr_Occurred()) {
return nullptr;
}
std::vector<tensorflow::int64> sources_vec = MakeTensorIDList(sources);
if (PyErr_Occurred()) {
return nullptr;
}
tensorflow::gtl::FlatSet<tensorflow::int64> sources_set(sources_vec.begin(),
sources_vec.end());
tensorflow::Safe_PyObjectPtr seq =
tensorflow::make_safe(PySequence_Fast(target, "expected a sequence"));
int len = PySequence_Fast_GET_SIZE(seq.get());
PyObject** seq_array = PySequence_Fast_ITEMS(seq.get());
std::unordered_map<tensorflow::int64, PyTapeTensor>
source_tensors_that_are_targets;
for (int i = 0; i < len; ++i) {
tensorflow::int64 target_id = target_vec[i];
if (sources_set.find(target_id) != sources_set.end()) {
auto tensor = seq_array[i];
source_tensors_that_are_targets.insert(
std::make_pair(target_id, TapeTensorFromTensor(tensor)));
}
if (PyErr_Occurred()) {
return nullptr;
}
}
if (PyErr_Occurred()) {
return nullptr;
}
std::vector<PyObject*> outgrad_vec;
if (output_gradients != Py_None) {
outgrad_vec = MakeTensorList(output_gradients);
if (PyErr_Occurred()) {
return nullptr;
}
for (PyObject* tensor : outgrad_vec) {
// Calling the backward function will eat a reference to the tensors in
// outgrad_vec, so we need to increase their reference count.
Py_INCREF(tensor);
}
}
std::vector<PyObject*> result;
status->status = tape_obj->tape->ComputeGradient(
*py_vspace, target_vec, sources_vec, source_tensors_that_are_targets,
outgrad_vec, &result);
if (!status->status.ok()) {
if (PyErr_Occurred()) {
// Do not propagate the erroneous status as that would swallow the
// exception which caused the problem.
status->status = tensorflow::Status::OK();
}
return nullptr;
}
bool unconnected_gradients_zero =
strcmp(TFE_GetPythonString(unconnected_gradients), "zero") == 0;
std::vector<PyObject*> sources_obj;
if (unconnected_gradients_zero) {
// Uses the "raw" sources here so it can properly make a zeros tensor even
// if there are resource variables as sources.
sources_obj = MakeTensorList(sources_raw);
}
if (!result.empty()) {
PyObject* py_result = PyList_New(result.size());
tensorflow::gtl::FlatSet<PyObject*> seen_results(result.size());
for (int i = 0; i < result.size(); ++i) {
if (result[i] == nullptr) {
if (unconnected_gradients_zero) {
// generate a zeros tensor in the shape of sources[i]
tensorflow::DataType dtype = FastTensorDtype(sources_obj[i]);
PyTapeTensor tensor =
PyTapeTensor(sources_vec[i], dtype, sources_obj[i]);
result[i] = tensor.ZerosLike();
} else {
Py_INCREF(Py_None);
result[i] = Py_None;
}
} else if (seen_results.find(result[i]) != seen_results.end()) {
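// PyList_SET_ITEM steals a reference; add one per repeated gradient object.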
Py_INCREF(result[i]);
}
seen_results.insert(result[i]);
PyList_SET_ITEM(py_result, i, reinterpret_cast<PyObject*>(result[i]));
}
return py_result;
}
return PyList_New(0);
}
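// Creates the C++ state backing a forward-mode accumulator; requires a
// PyVSpace to have been registered. Roughly, the Python API driving these
// bindings looks like (illustrative sketch):
//   with tf.autodiff.ForwardAccumulator(primals=x, tangents=dx) as acc:
//     y = f(x)
//   jvp = acc.jvp(y)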
PyObject* TFE_Py_ForwardAccumulatorNew() {
TFE_Py_ForwardAccumulator_Type.tp_new = PyType_GenericNew;
if (PyType_Ready(&TFE_Py_ForwardAccumulator_Type) < 0) return nullptr;
// Check for a registered vspace before allocating so we never dereference a
// null py_vspace below.
if (py_vspace == nullptr) {
MaybeRaiseExceptionFromStatus(
tensorflow::errors::Internal(
"ForwardAccumulator requires a PyVSpace to be registered."),
nullptr);
return nullptr;
}
TFE_Py_ForwardAccumulator* accumulator =
PyObject_NEW(TFE_Py_ForwardAccumulator, &TFE_Py_ForwardAccumulator_Type);
accumulator->accumulator = new ForwardAccumulator(*py_vspace);
return reinterpret_cast<PyObject*>(accumulator);
}
PyObject* TFE_Py_ForwardAccumulatorSetAdd(PyObject* accumulator) {
TFE_Py_ForwardAccumulator* c_accumulator(
reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
c_accumulator->nesting_id = tape_nesting_id_counter.fetch_add(1);
if (GetAccumulatorSet()->insert(c_accumulator)) {
Py_INCREF(accumulator);
Py_RETURN_NONE;
} else {
MaybeRaiseExceptionFromStatus(
tensorflow::errors::Internal(
"A ForwardAccumulator was added to the active set twice."),
nullptr);
return nullptr;
}
}
void TFE_Py_ForwardAccumulatorSetRemove(PyObject* accumulator) {
GetAccumulatorSet()->erase(
reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator));
Py_DECREF(accumulator);
}
void TFE_Py_ForwardAccumulatorWatch(PyObject* accumulator, PyObject* tensor,
PyObject* tangent) {
tensorflow::int64 tensor_id = FastTensorId(tensor);
reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
->accumulator->Watch(tensor_id, tangent);
RegisterForwardAccumulatorCleanup(tensor, tensor_id);
}
// Returns a new reference to the JVP Tensor.
PyObject* TFE_Py_ForwardAccumulatorJVP(PyObject* accumulator,
PyObject* tensor) {
PyObject* jvp = reinterpret_cast<TFE_Py_ForwardAccumulator*>(accumulator)
->accumulator->FetchJVP(FastTensorId(tensor));
if (jvp == nullptr) {
jvp = Py_None;
}
Py_INCREF(jvp);
return jvp;
}
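// Collects jvps of `tensors` from the active accumulators. Returns a 2-tuple:
// per-accumulator index pairs (innermost accumulator first) mapping each
// watched input to the position of its jvp in the augmented input list, and
// the list of jvp tensors to append to the op's inputs.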
PyObject* TFE_Py_PackJVPs(PyObject* tensors) {
if (!TapeCouldPossiblyRecord(tensors)) {
tensorflow::Safe_PyObjectPtr empty_tuple(PyTuple_New(0));
tensorflow::Safe_PyObjectPtr empty_list(PyList_New(0));
return PyTuple_Pack(2, empty_tuple.get(), empty_list.get());
}
auto accumulators = *GetAccumulatorSet();
tensorflow::Safe_PyObjectPtr tensors_fast(
PySequence_Fast(tensors, "Expected a sequence of input Tensors."));
if (tensors_fast == nullptr || PyErr_Occurred()) {
return nullptr;
}
std::vector<tensorflow::int64> augmented_input_ids;
int len = PySequence_Fast_GET_SIZE(tensors_fast.get());
PyObject** tensors_fast_array = PySequence_Fast_ITEMS(tensors_fast.get());
for (Py_ssize_t position = 0; position < len; ++position) {
PyObject* input = tensors_fast_array[position];
if (input == Py_None) {
continue;
}
tensorflow::DataType input_dtype(FastTensorDtype(input));
if (input_dtype == tensorflow::DT_INVALID) {
return nullptr;
}
augmented_input_ids.push_back(FastTensorId(input));
}
if (PyErr_Occurred()) {
return nullptr;
}
// Find the innermost accumulator such that all outer accumulators are
// recording. Any more deeply nested accumulators will not have their JVPs
// saved.
AccumulatorSet::const_iterator innermost_all_recording = accumulators.begin();
for (; innermost_all_recording != accumulators.end();
++innermost_all_recording) {
if ((*innermost_all_recording)->accumulator->BusyAccumulating()) {
break;
}
}
AccumulatorSet::const_reverse_iterator reverse_innermost_all_recording(
innermost_all_recording);
bool saving_jvps = false;
tensorflow::Safe_PyObjectPtr all_indices(PyTuple_New(accumulators.size()));
std::vector<PyObject*> new_tensors;
Py_ssize_t accumulator_index = 0;
// Start with the innermost accumulators to give outer accumulators a chance
// to find their higher-order JVPs.
for (AccumulatorSet::const_reverse_iterator it = accumulators.rbegin();
it != accumulators.rend(); ++it, ++accumulator_index) {
std::vector<tensorflow::int64> new_input_ids;
std::vector<std::pair<tensorflow::int64, tensorflow::int64>>
accumulator_indices;
if (it == reverse_innermost_all_recording) {
saving_jvps = true;
}
if (saving_jvps) {
for (int input_index = 0; input_index < augmented_input_ids.size();
++input_index) {
tensorflow::int64 existing_input = augmented_input_ids[input_index];
PyObject* jvp = (*it)->accumulator->FetchJVP(existing_input);
if (jvp != nullptr) {
new_tensors.push_back(jvp);
new_input_ids.push_back(FastTensorId(jvp));
accumulator_indices.emplace_back(
input_index,
augmented_input_ids.size() + new_input_ids.size() - 1);
}
}
}
tensorflow::Safe_PyObjectPtr accumulator_indices_py(
PyTuple_New(accumulator_indices.size()));
for (int i = 0; i < accumulator_indices.size(); ++i) {
tensorflow::Safe_PyObjectPtr from_index(
GetPythonObjectFromInt(accumulator_indices[i].first));
tensorflow::Safe_PyObjectPtr to_index(
GetPythonObjectFromInt(accumulator_indices[i].second));
PyTuple_SetItem(accumulator_indices_py.get(), i,
PyTuple_Pack(2, from_index.get(), to_index.get()));
}
PyTuple_SetItem(all_indices.get(), accumulator_index,
accumulator_indices_py.release());
augmented_input_ids.insert(augmented_input_ids.end(), new_input_ids.begin(),
new_input_ids.end());
}
tensorflow::Safe_PyObjectPtr new_tensors_py(PyList_New(new_tensors.size()));
for (int i = 0; i < new_tensors.size(); ++i) {
PyObject* jvp = new_tensors[i];
Py_INCREF(jvp);
PyList_SET_ITEM(new_tensors_py.get(), i, jvp);
}
return PyTuple_Pack(2, all_indices.get(), new_tensors_py.get());
}
namespace {
static const int kFastPathExecuteInputStartIndex = 5;
PyObject* GetPythonObjectFromString(tensorflow::StringPiece s) {
#if PY_MAJOR_VERSION >= 3
return PyUnicode_FromStringAndSize(s.data(), s.size());
#else
return PyBytes_FromStringAndSize(s.data(), s.size());
#endif
}
bool CheckResourceVariable(PyObject* item) {
if (tensorflow::swig::IsResourceVariable(item)) {
tensorflow::Safe_PyObjectPtr handle(
PyObject_GetAttrString(item, "_handle"));
return EagerTensor_CheckExact(handle.get());
}
return false;
}
bool IsNumberType(PyObject* item) {
#if PY_MAJOR_VERSION >= 3
return PyFloat_Check(item) || PyLong_Check(item);
#else
return PyFloat_Check(item) || PyInt_Check(item) || PyLong_Check(item);
#endif
}
bool CheckOneInput(PyObject* item) {
if (EagerTensor_CheckExact(item) || CheckResourceVariable(item) ||
PyArray_Check(item) || IsNumberType(item)) {
return true;
}
// Sequences are not properly handled. Sequences with purely python numeric
// types work, but sequences with mixes of EagerTensors and python numeric
// types don't work.
// TODO(nareshmodi): fix
return false;
}
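// Checks that every op input in `seq` (starting at `start_index`) is
// something the fast path understands; logs the mismatch and returns false
// so the caller can fall back to the slow path.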
bool CheckInputsOk(PyObject* seq, int start_index,
const tensorflow::OpDef& op_def) {
for (int i = 0; i < op_def.input_arg_size(); i++) {
PyObject* item = PyTuple_GET_ITEM(seq, i + start_index);
if (!op_def.input_arg(i).number_attr().empty() ||
!op_def.input_arg(i).type_list_attr().empty()) {
// This item should be a seq input.
if (!PySequence_Check(item)) {
VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
<< "\", Input \"" << op_def.input_arg(i).name()
<< "\" since we expected a sequence, but got "
<< item->ob_type->tp_name;
return false;
}
tensorflow::Safe_PyObjectPtr fast_item(
PySequence_Fast(item, "Could not parse sequence."));
if (fast_item.get() == nullptr) {
return false;
}
int len = PySequence_Fast_GET_SIZE(fast_item.get());
PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
for (Py_ssize_t j = 0; j < len; j++) {
PyObject* inner_item = fast_item_array[j];
if (!CheckOneInput(inner_item)) {
VLOG(1) << "Falling back to slow path for Op \"" << op_def.name()
<< "\", Input \"" << op_def.input_arg(i).name()
<< "\", Index " << j
<< " since we expected an EagerTensor/ResourceVariable, "
"but got "
<< inner_item->ob_type->tp_name;
return false;
}
}
} else if (!CheckOneInput(item)) {
VLOG(1)
<< "Falling back to slow path for Op \"" << op_def.name()
<< "\", Input \"" << op_def.input_arg(i).name()
<< "\" since we expected an EagerTensor/ResourceVariable, but got "
<< item->ob_type->tp_name;
return false;
}
}
return true;
}
tensorflow::DataType MaybeGetDType(PyObject* item) {
if (EagerTensor_CheckExact(item) || CheckResourceVariable(item)) {
return FastTensorDtype(item);
}
return tensorflow::DT_INVALID;
}
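// Infers the dtype associated with type attr `attr`, checking dtypes cached
// from earlier conversions, then other inputs tied to the same attr, and
// finally the op's default for that attr.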
tensorflow::DataType MaybeGetDTypeForAttr(const string& attr,
FastPathOpExecInfo* op_exec_info) {
auto cached_it = op_exec_info->cached_dtypes.find(attr);
if (cached_it != op_exec_info->cached_dtypes.end()) {
return cached_it->second;
}
auto it = op_exec_info->attr_to_inputs_map->find(attr);
if (it == op_exec_info->attr_to_inputs_map->end()) {
// No other inputs - this should never happen.
return tensorflow::DT_INVALID;
}
for (const auto& input_info : it->second) {
PyObject* item = PyTuple_GET_ITEM(
op_exec_info->args, kFastPathExecuteInputStartIndex + input_info.i);
if (input_info.is_list) {
tensorflow::Safe_PyObjectPtr fast_item(
PySequence_Fast(item, "Unable to allocate"));
int len = PySequence_Fast_GET_SIZE(fast_item.get());
PyObject** fast_item_array = PySequence_Fast_ITEMS(fast_item.get());
for (int i = 0; i < len; i++) {
auto dtype = MaybeGetDType(fast_item_array[i]);
if (dtype != tensorflow::DT_INVALID) return dtype;
}
} else {
auto dtype = MaybeGetDType(item);
if (dtype != tensorflow::DT_INVALID) return dtype;
}
}
auto default_it = op_exec_info->default_dtypes->find(attr);
if (default_it != op_exec_info->default_dtypes->end()) {
return default_it->second;
}
return tensorflow::DT_INVALID;
}
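// Returns a tuple copy of `seq` with the elements at `indices` replaced by
// None. Used by RecordGradient to drop op inputs/outputs that the registered
// gradient does not need.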
PyObject* CopySequenceSettingIndicesToNull(
PyObject* seq, const tensorflow::gtl::FlatSet<int>& indices) {
tensorflow::Safe_PyObjectPtr fast_seq(
PySequence_Fast(seq, "unable to allocate"));
int len = PySequence_Fast_GET_SIZE(fast_seq.get());
PyObject** fast_seq_array = PySequence_Fast_ITEMS(fast_seq.get());
PyObject* result = PyTuple_New(len);
for (int i = 0; i < len; i++) {
PyObject* item;
if (indices.find(i) != indices.end()) {
item = Py_None;
} else {
item = fast_seq_array[i];
}
Py_INCREF(item);
PyTuple_SET_ITEM(result, i, item);
}
return result;
}
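// Records an executed op on any active tapes and accumulators so that
// gradient() and jvp() can later replay it. Inputs/outputs the registered
// gradient does not use are replaced with None before being captured.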
PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
PyObject* results,
PyObject* forward_pass_name_scope = nullptr) {
std::vector<tensorflow::int64> input_ids = MakeTensorIDList(inputs);
if (PyErr_Occurred()) return nullptr;
std::vector<tensorflow::DataType> input_dtypes = MakeTensorDtypeList(inputs);
if (PyErr_Occurred()) return nullptr;
bool should_record = false;
for (TFE_Py_Tape* tape : SafeTapeSet()) {
if (tape->tape->ShouldRecord(input_ids, input_dtypes)) {
should_record = true;
break;
}
}
if (!should_record) {
for (TFE_Py_ForwardAccumulator* accumulator : SafeAccumulatorSet()) {
if (accumulator->accumulator->ShouldRecord(input_ids, input_dtypes)) {
should_record = true;
break;
}
}
}
if (!should_record) Py_RETURN_NONE;
string c_op_name = TFE_GetPythonString(op_name);
PyObject* op_outputs;
bool op_outputs_tuple_created = false;
if (const auto unused_output_indices =
OpGradientUnusedOutputIndices(c_op_name)) {
if (unused_output_indices->empty()) {
op_outputs = Py_None;
} else {
op_outputs_tuple_created = true;
op_outputs =
CopySequenceSettingIndicesToNull(results, *unused_output_indices);
}
} else {
op_outputs = results;
}
PyObject* op_inputs;
bool op_inputs_tuple_created = false;
if (const auto unused_input_indices =
OpGradientUnusedInputIndices(c_op_name)) {
if (unused_input_indices->empty()) {
op_inputs = Py_None;
} else {
op_inputs_tuple_created = true;
op_inputs =
CopySequenceSettingIndicesToNull(inputs, *unused_input_indices);
}
} else {
op_inputs = inputs;
}
tensorflow::eager::ForwardFunction<PyObject> py_forward_function(
[op_name, attrs, inputs, results](
const std::vector<PyObject*>& input_tangents,
std::vector<PyObject*>* output_tangents) {
return CallJVPFunction(op_name, attrs, inputs, results, input_tangents,
output_tangents);
});
tensorflow::eager::ForwardFunction<PyObject>* forward_function;
if (c_op_name == "While" || c_op_name == "StatelessWhile" ||
c_op_name == "If" || c_op_name == "StatelessIf") {
// Control flow contains non-hashable attributes. Handling them in Python is
// a headache, so instead we'll stay as close to GradientTape's handling as
// possible (a null forward function means the accumulator forwards to a
// tape).
//
// This is safe to do since we'll only see control flow when graph building,
// in which case we can rely on pruning.
forward_function = nullptr;
} else {
forward_function = &py_forward_function;
}
PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs));
if (!forward_pass_name_scope) forward_pass_name_scope = Py_None;
TapeSetRecordOperation(
op_name, inputs, results, input_ids, input_dtypes,
[op_name, attrs, num_inputs, op_inputs, op_outputs,
forward_pass_name_scope]() {
Py_INCREF(op_name);
Py_INCREF(attrs);
Py_INCREF(num_inputs);
Py_INCREF(op_inputs);
Py_INCREF(op_outputs);
Py_INCREF(forward_pass_name_scope);
PyBackwardFunction* function = new PyBackwardFunction(
[op_name, attrs, num_inputs, op_inputs, op_outputs,
forward_pass_name_scope](
PyObject* output_grads,
const std::vector<tensorflow::int64>& unneeded_gradients) {
if (PyErr_Occurred()) {
return static_cast<PyObject*>(nullptr);
}
tensorflow::Safe_PyObjectPtr skip_input_indices;
if (!unneeded_gradients.empty()) {
skip_input_indices.reset(
PyTuple_New(unneeded_gradients.size()));
for (int i = 0; i < unneeded_gradients.size(); i++) {
PyTuple_SET_ITEM(
skip_input_indices.get(), i,
GetPythonObjectFromInt(unneeded_gradients[i]));
}
} else {
Py_INCREF(Py_None);
skip_input_indices.reset(Py_None);
}
tensorflow::Safe_PyObjectPtr callback_args(Py_BuildValue(
"OOOOOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs,
output_grads, skip_input_indices.get(),
forward_pass_name_scope));
tensorflow::Safe_PyObjectPtr result(
PyObject_CallObject(gradient_function, callback_args.get()));
if (PyErr_Occurred()) return static_cast<PyObject*>(nullptr);
return tensorflow::swig::Flatten(result.get());
});
return function;
},
[op_name, attrs, num_inputs, op_inputs, op_outputs,
forward_pass_name_scope](PyBackwardFunction* backward_function) {
Py_DECREF(op_name);
Py_DECREF(attrs);
Py_DECREF(num_inputs);
Py_DECREF(op_inputs);
Py_DECREF(op_outputs);
Py_DECREF(forward_pass_name_scope);
delete backward_function;
},
forward_function);
Py_DECREF(num_inputs);
if (op_outputs_tuple_created) Py_DECREF(op_outputs);
if (op_inputs_tuple_created) Py_DECREF(op_inputs);
if (PyErr_Occurred()) {
return nullptr;
}
Py_RETURN_NONE;
}
void MaybeNotifyVariableAccessed(PyObject* input) {
DCHECK(CheckResourceVariable(input));
DCHECK(PyObject_HasAttrString(input, "_trainable"));
tensorflow::Safe_PyObjectPtr trainable(
PyObject_GetAttrString(input, "_trainable"));
if (trainable.get() == Py_False) return;
TFE_Py_TapeVariableAccessed(input);
TFE_Py_VariableWatcherVariableAccessed(input);
}
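// Reads the current value of a ResourceVariable by executing ReadVariableOp,
// notifying tapes and variable watchers, and recording a gradient for the
// read when a gradient callback is active.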
bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
PyObject* input, tensorflow::Safe_PyObjectPtr* output,
TF_Status* status) {
MaybeNotifyVariableAccessed(input);
TFE_Op* op = TFE_NewOp(parent_op_exec_info.ctx, "ReadVariableOp", status);
auto cleaner = tensorflow::gtl::MakeCleanup([op] { TFE_DeleteOp(op); });
if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
// Set dtype
DCHECK(PyObject_HasAttrString(input, "_dtype"));
tensorflow::Safe_PyObjectPtr dtype(PyObject_GetAttrString(input, "_dtype"));
int value;
if (!ParseTypeValue("_dtype", dtype.get(), status, &value)) {
return false;
}
TFE_OpSetAttrType(op, "dtype", static_cast<TF_DataType>(value));
TFE_OpSetDevice(op, parent_op_exec_info.device_name, status);
if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
// Get handle
tensorflow::Safe_PyObjectPtr handle(PyObject_GetAttrString(input, "_handle"));
if (!EagerTensor_CheckExact(handle.get())) return false;
TFE_OpAddInput(op, EagerTensor_Handle(handle.get()), status);
if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
int num_retvals = 1;
TFE_TensorHandle* output_handle;
TFE_Execute(op, &output_handle, &num_retvals, status);
if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) return false;
// Always create the py object (and correctly DECREF it) from the returned
// value, else the data will leak.
output->reset(EagerTensorFromHandle(output_handle));
// TODO(nareshmodi): Should we run post exec callbacks here?
if (parent_op_exec_info.run_gradient_callback) {
tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(1));
PyTuple_SET_ITEM(inputs.get(), 0, handle.release());
tensorflow::Safe_PyObjectPtr outputs(PyTuple_New(1));
Py_INCREF(output->get()); // stay alive after since tuple steals.
PyTuple_SET_ITEM(outputs.get(), 0, output->get());
tensorflow::Safe_PyObjectPtr op_string(
GetPythonObjectFromString("ReadVariableOp"));
if (!RecordGradient(op_string.get(), inputs.get(), Py_None,
outputs.get())) {
return false;
}
}
return true;
}
// Converts `input` to an EagerTensor. Supports 3 cases at the moment:
// i) input is an EagerTensor: it is passed through unchanged.
// ii) input is a ResourceVariable: its value is read via ReadVariableOp.
// iii) input is an arbitrary python object (e.g. a list/tuple or scalar):
// it is converted with ConvertToEagerTensor (note, this handling doesn't
// support packing).
//
// NOTE: dtype_hint_getter supplies a dtype to use as a conversion hint for
// case iii); it should return DT_INVALID when no hint is available.
// dtype_setter is called with the dtype of the converted tensor so it can be
// cached for later inputs that share the same type attr.
//
// NOTE: On failure this function sets a Python error directly and returns
// false. TF_Status is only passed since we don't want to have to reallocate
// it.
bool ConvertToTensor(
const FastPathOpExecInfo& op_exec_info, PyObject* input,
tensorflow::Safe_PyObjectPtr* output_handle,
// This gets a hint for this particular input.
const std::function<tensorflow::DataType()>& dtype_hint_getter,
// This sets the dtype after conversion is complete.
const std::function<void(const tensorflow::DataType dtype)>& dtype_setter,
TF_Status* status) {
if (EagerTensor_CheckExact(input)) {
Py_INCREF(input);
output_handle->reset(input);
return true;
} else if (CheckResourceVariable(input)) {
return ReadVariableOp(op_exec_info, input, output_handle, status);
}
// The hint comes from a supposedly similarly typed tensor.
tensorflow::DataType dtype_hint = dtype_hint_getter();
TFE_TensorHandle* handle = tensorflow::ConvertToEagerTensor(
op_exec_info.ctx, input, dtype_hint, op_exec_info.device_name);
if (handle == nullptr) {
return MaybeRaiseExceptionFromTFStatus(status, nullptr);
}
output_handle->reset(EagerTensorFromHandle(handle));
dtype_setter(
static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(handle)));
return true;
}
// Adds input and type attr to the op, and to the list of flattened
// inputs/attrs.
bool AddInputToOp(FastPathOpExecInfo* op_exec_info, PyObject* input,
const bool add_type_attr,
const tensorflow::OpDef::ArgDef& input_arg,
std::vector<tensorflow::Safe_PyObjectPtr>* flattened_attrs,
std::vector<tensorflow::Safe_PyObjectPtr>* flattened_inputs,
TFE_Op* op, TF_Status* status) {
// py_eager_tensor's ownership is transferred to flattened_inputs if it is
// required; otherwise it is DECREF'd when it goes out of scope at the end of
// this function.
tensorflow::Safe_PyObjectPtr py_eager_tensor = nullptr;
if (!ConvertToTensor(
*op_exec_info, input, &py_eager_tensor,
[&]() {
if (input_arg.type() != tensorflow::DataType::DT_INVALID) {
return input_arg.type();
}
return MaybeGetDTypeForAttr(input_arg.type_attr(), op_exec_info);
},
[&](const tensorflow::DataType dtype) {
op_exec_info->cached_dtypes[input_arg.type_attr()] = dtype;
},
status)) {
return false;
}
TFE_TensorHandle* input_handle = EagerTensor_Handle(py_eager_tensor.get());
if (add_type_attr && !input_arg.type_attr().empty()) {
auto dtype = TFE_TensorHandleDataType(input_handle);
TFE_OpSetAttrType(op, input_arg.type_attr().data(), dtype);
if (flattened_attrs != nullptr) {
flattened_attrs->emplace_back(
GetPythonObjectFromString(input_arg.type_attr()));
flattened_attrs->emplace_back(PyLong_FromLong(dtype));
}
}
if (flattened_inputs != nullptr) {
flattened_inputs->emplace_back(std::move(py_eager_tensor));
}
TFE_OpAddInput(op, input_handle, status);
if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
return false;
}
return true;
}
const char* GetDeviceName(PyObject* py_device_name) {
if (py_device_name != Py_None) {
return TFE_GetPythonString(py_device_name);
}
return nullptr;
}
bool RaiseIfNotPySequence(PyObject* seq, const string& attr_name) {
if (!PySequence_Check(seq)) {
PyErr_SetString(PyExc_TypeError,
Printf("expected a sequence for attr %s, got %s instead",
attr_name.data(), seq->ob_type->tp_name)
.data());
return false;
}
if (PyArray_Check(seq) &&
PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)) != 1) {
PyErr_SetString(PyExc_ValueError,
Printf("expected a sequence for attr %s, got an ndarray "
"with rank %d instead",
attr_name.data(),
PyArray_NDIM(reinterpret_cast<PyArrayObject*>(seq)))
.data());
return false;
}
return true;
}
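// Runs the gradient-recording callback and any post-execution callbacks over
// the flattened inputs, attrs, and results. Returns false if any callback
// fails.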
bool RunCallbacks(
const FastPathOpExecInfo& op_exec_info, PyObject* args,
int num_inferred_attrs,
const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_inputs,
const std::vector<tensorflow::Safe_PyObjectPtr>& flattened_attrs,
PyObject* flattened_result) {
DCHECK(op_exec_info.run_callbacks);
tensorflow::Safe_PyObjectPtr inputs(PyTuple_New(flattened_inputs.size()));
for (int i = 0; i < flattened_inputs.size(); i++) {
PyObject* input = flattened_inputs[i].get();
Py_INCREF(input);
PyTuple_SET_ITEM(inputs.get(), i, input);
}
int num_non_inferred_attrs = PyTuple_GET_SIZE(args) - num_inferred_attrs;
int num_attrs = flattened_attrs.size() + num_non_inferred_attrs;
tensorflow::Safe_PyObjectPtr attrs(PyTuple_New(num_attrs));
for (int i = 0; i < num_non_inferred_attrs; i++) {
auto* attr = PyTuple_GET_ITEM(args, num_inferred_attrs + i);
Py_INCREF(attr);
PyTuple_SET_ITEM(attrs.get(), i, attr);
}
for (int i = num_non_inferred_attrs; i < num_attrs; i++) {
PyObject* attr_or_name =
flattened_attrs.at(i - num_non_inferred_attrs).get();
Py_INCREF(attr_or_name);
PyTuple_SET_ITEM(attrs.get(), i, attr_or_name);
}
if (op_exec_info.run_gradient_callback) {
if (!RecordGradient(op_exec_info.op_name, inputs.get(), attrs.get(),
flattened_result)) {
return false;
}
}
if (op_exec_info.run_post_exec_callbacks) {
tensorflow::Safe_PyObjectPtr callback_args(
Py_BuildValue("OOOOO", op_exec_info.op_name, inputs.get(), attrs.get(),
flattened_result, op_exec_info.name));
for (Py_ssize_t i = 0; i < PyList_Size(op_exec_info.callbacks); i++) {
PyObject* callback_fn = PyList_GET_ITEM(op_exec_info.callbacks, i);
if (!PyCallable_Check(callback_fn)) {
PyErr_SetString(
PyExc_TypeError,
Printf("expected a function for "
"post execution callback in index %ld, got %s instead",
i, callback_fn->ob_type->tp_name)
.c_str());
return false;
}
PyObject* callback_result =
PyObject_CallObject(callback_fn, callback_args.get());
if (!callback_result) {
return false;
}
Py_DECREF(callback_result);
}
}
return true;
}
} // namespace
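// Fast-path execution of an eager op. `args` is laid out as (context capsule,
// device name, op name, name, callbacks, *inputs, *non-inferred attr
// name/value pairs); inputs start at kFastPathExecuteInputStartIndex.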
PyObject* TFE_Py_FastPathExecute_C(PyObject* args) {
tensorflow::profiler::TraceMe activity(
"TFE_Py_FastPathExecute_C", tensorflow::profiler::TraceMeLevel::kInfo);
Py_ssize_t args_size = PyTuple_GET_SIZE(args);
if (args_size < kFastPathExecuteInputStartIndex) {
PyErr_SetString(
PyExc_ValueError,
Printf("There must be at least %d items in the input tuple.",
kFastPathExecuteInputStartIndex)
.c_str());
return nullptr;
}
FastPathOpExecInfo op_exec_info;
TFE_Context* ctx = reinterpret_cast<TFE_Context*>(
PyCapsule_GetPointer(PyTuple_GET_ITEM(args, 0), nullptr));
op_exec_info.ctx = ctx;
op_exec_info.args = args;
if (ctx == nullptr) {
// The context hasn't been initialized; it will be initialized on the slow
// path.
RaiseFallbackException(
"This function does not handle the case of the path where "
"all inputs are not already EagerTensors.");
return nullptr;
}
op_exec_info.device_name = GetDeviceName(PyTuple_GET_ITEM(args, 1));
op_exec_info.op_name = PyTuple_GET_ITEM(args, 2);
op_exec_info.name = PyTuple_GET_ITEM(args, 3);
op_exec_info.callbacks = PyTuple_GET_ITEM(args, 4);
// TODO(nareshmodi): Add a benchmark for the fast-path with gradient callbacks
// (similar to benchmark_tf_gradient_function_*). Also consider using an
// InlinedVector for flattened_attrs and flattened_inputs if the benchmarks
// point out problems with heap allocs.
op_exec_info.run_gradient_callback =
!*ThreadTapeIsStopped() && HasAccumulatorOrTape();
op_exec_info.run_post_exec_callbacks =
op_exec_info.callbacks != Py_None &&
PyList_Size(op_exec_info.callbacks) > 0;
op_exec_info.run_callbacks = op_exec_info.run_gradient_callback ||
op_exec_info.run_post_exec_callbacks;
TF_Status* status = GetStatus();
const char* op_name = TFE_GetPythonString(op_exec_info.op_name);
if (op_name == nullptr) {
PyErr_SetString(PyExc_TypeError,
Printf("expected a string for op_name, got %s instead",
op_exec_info.op_name->ob_type->tp_name)
.c_str());
return nullptr;
}
TFE_Op* op = GetOp(ctx, op_name, op_exec_info.device_name, status);
auto cleaner = tensorflow::gtl::MakeCleanup([status, ctx, op] {
ReturnStatus(status);
ReturnOp(ctx, op);
});
if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
return nullptr;
}
const tensorflow::OpDef* op_def = op->operation->OpDef();
if (op_def == nullptr) return nullptr;
if (args_size < kFastPathExecuteInputStartIndex + op_def->input_arg_size()) {
PyErr_SetString(
PyExc_ValueError,
Printf("Tuple size smaller than intended. Expected to be at least %d, "
"was %ld",
kFastPathExecuteInputStartIndex + op_def->input_arg_size(),
args_size)
.c_str());
return nullptr;
}
if (!CheckInputsOk(args, kFastPathExecuteInputStartIndex, *op_def)) {
RaiseFallbackException(
"This function does not handle the case of the path where "
"all inputs are not already EagerTensors.");
return nullptr;
}
op_exec_info.attr_to_inputs_map = GetAttrToInputsMapHoldingGIL(*op_def);
op_exec_info.default_dtypes = GetAttrToDefaultsMapHoldingGIL(*op_def);
// Mapping of attr name to size - used to calculate the number of values
// to be expected by the TFE_Execute run.
tensorflow::gtl::FlatMap<string, tensorflow::int64> attr_list_sizes;
// Set non-inferred attrs, including setting defaults if the attr is passed in
// as None.
for (int i = kFastPathExecuteInputStartIndex + op_def->input_arg_size();
i < args_size; i += 2) {
PyObject* py_attr_name = PyTuple_GET_ITEM(args, i);
const char* attr_name = TFE_GetPythonString(py_attr_name);
PyObject* py_attr_value = PyTuple_GET_ITEM(args, i + 1);
// Not creating an index since most of the time there are not more than a
// few attrs.
// TODO(nareshmodi): Maybe include the index as part of the
// OpRegistrationData.
for (const auto& attr : op_def->attr()) {
if (tensorflow::StringPiece(attr_name) == attr.name()) {
SetOpAttrWithDefaults(ctx, op, attr, attr_name, py_attr_value,
&attr_list_sizes, status);
if (!status->status.ok()) {
VLOG(1) << "Falling back to slow path for Op \"" << op_def->name()
<< "\" since we are unable to set the value for attr \""
<< attr.name() << "\" due to: " << TF_Message(status);
RaiseFallbackException(TF_Message(status));
return nullptr;
}
break;
}
}
}
// Flat attrs and inputs as required by the record_gradient call. The attrs
// here only contain inferred attrs (non-inferred attrs are added directly
// from the input args).
// All items in flattened_attrs and flattened_inputs contain
// Safe_PyObjectPtr - any time something steals a reference to this, it must
// INCREF.
// TODO(nareshmodi): figure out why PyList_New/PyList_Append don't work
// directly.
std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_attrs =
nullptr;
std::unique_ptr<std::vector<tensorflow::Safe_PyObjectPtr>> flattened_inputs =
nullptr;
// TODO(nareshmodi): Encapsulate callbacks information into a struct.
if (op_exec_info.run_callbacks) {
flattened_attrs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
flattened_inputs.reset(new std::vector<tensorflow::Safe_PyObjectPtr>);
}
// Add inferred attrs and inputs.
// The following code might set duplicate type attrs. This will result in
// the CacheKey for the generated AttrBuilder possibly differing from
// those where the type attrs are correctly set. Inconsistent CacheKeys
// for ops means that there might be unnecessarily duplicated kernels.
// TODO(nareshmodi): Fix this.
for (int i = 0; i < op_def->input_arg_size(); i++) {
const auto& input_arg = op_def->input_arg(i);
PyObject* input =
PyTuple_GET_ITEM(args, kFastPathExecuteInputStartIndex + i);
if (!input_arg.number_attr().empty()) {
// The item is a homogeneous list.
if (!RaiseIfNotPySequence(input, input_arg.number_attr())) return nullptr;
tensorflow::Safe_PyObjectPtr fast_input(
PySequence_Fast(input, "Could not parse sequence."));
if (fast_input.get() == nullptr) {
return nullptr;
}
Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());
TFE_OpSetAttrInt(op, input_arg.number_attr().data(), len);
if (op_exec_info.run_callbacks) {
flattened_attrs->emplace_back(
GetPythonObjectFromString(input_arg.number_attr()));
flattened_attrs->emplace_back(PyLong_FromLong(len));
}
attr_list_sizes[input_arg.number_attr()] = len;
if (len > 0) {
// First item adds the type attr.
if (!AddInputToOp(&op_exec_info, fast_input_array[0], true, input_arg,
flattened_attrs.get(), flattened_inputs.get(), op,
status)) {
return nullptr;
}
for (Py_ssize_t j = 1; j < len; j++) {
// Since the list is homogeneous, we don't need to re-add the attr.
if (!AddInputToOp(&op_exec_info, fast_input_array[j], false,
input_arg, nullptr /* flattened_attrs */,
flattened_inputs.get(), op, status)) {
return nullptr;
}
}
}
} else if (!input_arg.type_list_attr().empty()) {
// The item is a heterogeneous list.
if (!RaiseIfNotPySequence(input, input_arg.type_list_attr())) {
return nullptr;
}
tensorflow::Safe_PyObjectPtr fast_input(
PySequence_Fast(input, "Could not parse sequence."));
if (fast_input.get() == nullptr) {
return nullptr;
}
const string& attr_name = input_arg.type_list_attr();
Py_ssize_t len = PySequence_Fast_GET_SIZE(fast_input.get());
PyObject** fast_input_array = PySequence_Fast_ITEMS(fast_input.get());
tensorflow::gtl::InlinedVector<TF_DataType, 4> attr_value(len);
PyObject* py_attr_value = nullptr;
if (op_exec_info.run_callbacks) {
py_attr_value = PyTuple_New(len);
}
for (Py_ssize_t j = 0; j < len; j++) {
PyObject* py_input = fast_input_array[j];
tensorflow::Safe_PyObjectPtr py_eager_tensor;
if (!ConvertToTensor(
op_exec_info, py_input, &py_eager_tensor,
[]() { return tensorflow::DT_INVALID; },
[](const tensorflow::DataType dtype) {}, status)) {
return nullptr;
}
TFE_TensorHandle* input_handle =
EagerTensor_Handle(py_eager_tensor.get());
attr_value[j] = TFE_TensorHandleDataType(input_handle);
TFE_OpAddInput(op, input_handle, status);
if (MaybeRaiseExceptionFromTFStatus(status, nullptr)) {
return nullptr;
}
if (op_exec_info.run_callbacks) {
flattened_inputs->emplace_back(std::move(py_eager_tensor));
PyTuple_SET_ITEM(py_attr_value, j, PyLong_FromLong(attr_value[j]));
}
}
if (op_exec_info.run_callbacks) {
flattened_attrs->emplace_back(GetPythonObjectFromString(attr_name));
flattened_attrs->emplace_back(py_attr_value);
}
TFE_OpSetAttrTypeList(op, attr_name.data(), attr_value.data(),
attr_value.size());
attr_list_sizes[attr_name] = len;
} else {
// The item is a single item.
if (!AddInputToOp(&op_exec_info, input, true, input_arg,
flattened_attrs.get(), flattened_inputs.get(), op,
status)) {
return nullptr;
}
}
}
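  // Determine how many flattened return values to expect: each list-valued
  // output contributes the length recorded in attr_list_sizes for its
  // number_attr or type_list_attr, and every other output contributes one.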
int num_retvals = 0;
for (int i = 0; i < op_def->output_arg_size(); i++) {
const auto& output_arg = op_def->output_arg(i);
int delta = 1;
if (!output_arg.number_attr().empty()) {
delta = attr_list_sizes[output_arg.number_attr()];
} else if (!output_arg.type_list_attr().empty()) {
delta = attr_list_sizes[output_arg.type_list_attr()];
}
if (delta < 0) {
RaiseFallbackException(
"Attributes suggest that the size of an output list is less than 0");
return nullptr;
}
num_retvals += delta;
}
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals(num_retvals);
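  // Release the GIL while the op executes so other Python threads can run.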
Py_BEGIN_ALLOW_THREADS;
TFE_Execute(op, retvals.data(), &num_retvals, status);
Py_END_ALLOW_THREADS;
if (!status->status.ok()) {
    // Augment the status with the op_name for easier debugging, similar to
    // what TFE_Py_Execute does.
TF_SetStatus(status, TF_GetCode(status),
tensorflow::strings::StrCat(
TF_Message(status),
" [Op:", TFE_GetPythonString(op_exec_info.op_name), "]")
.c_str());
MaybeRaiseExceptionFromTFStatus(status, nullptr);
return nullptr;
}
tensorflow::Safe_PyObjectPtr flat_result(PyList_New(num_retvals));
for (int i = 0; i < num_retvals; ++i) {
PyList_SET_ITEM(flat_result.get(), i, EagerTensorFromHandle(retvals[i]));
}
if (op_exec_info.run_callbacks) {
if (!RunCallbacks(
op_exec_info, args,
kFastPathExecuteInputStartIndex + op_def->input_arg_size(),
*flattened_inputs, *flattened_attrs, flat_result.get())) {
return nullptr;
}
}
// Unflatten results.
if (op_def->output_arg_size() == 0) {
Py_RETURN_NONE;
}
if (op_def->output_arg_size() == 1) {
if (!op_def->output_arg(0).number_attr().empty() ||
!op_def->output_arg(0).type_list_attr().empty()) {
return flat_result.release();
} else {
auto* result = PyList_GET_ITEM(flat_result.get(), 0);
Py_INCREF(result);
return result;
}
}
  // Unflatten the results into one entry per output arg; list-valued outputs
  // become nested lists (the Python caller may wrap this list in a
  // namedtuple).
PyObject* result = PyList_New(op_def->output_arg_size());
int flat_result_index = 0;
for (int i = 0; i < op_def->output_arg_size(); i++) {
    const auto& output_arg = op_def->output_arg(i);
    if (!output_arg.number_attr().empty() ||
        !output_arg.type_list_attr().empty()) {
      // List-valued output: nest the corresponding handles in an inner list.
      const string& list_attr = !output_arg.number_attr().empty()
                                    ? output_arg.number_attr()
                                    : output_arg.type_list_attr();
      int list_length = attr_list_sizes[list_attr];
      PyObject* inner_list = PyList_New(list_length);
      for (int j = 0; j < list_length; j++) {
        PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
        Py_INCREF(obj);
        PyList_SET_ITEM(inner_list, j, obj);
      }
      PyList_SET_ITEM(result, i, inner_list);
} else {
PyObject* obj = PyList_GET_ITEM(flat_result.get(), flat_result_index++);
Py_INCREF(obj);
PyList_SET_ITEM(result, i, obj);
}
}
return result;
}
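// Python-visible wrapper around RecordGradient: returns None without
// recording anything when taping is stopped on this thread or when no tape
// or accumulator is active.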
PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
PyObject* attrs, PyObject* results,
PyObject* forward_pass_name_scope) {
if (*ThreadTapeIsStopped() || !HasAccumulatorOrTape()) {
Py_RETURN_NONE;
}
return RecordGradient(op_name, inputs, attrs, results,
forward_pass_name_scope);
}
namespace {
const char kTensor[] = "T";
const char kList[] = "L";
const char kListEnd[] = "l";
const char kTuple[] = "U";
const char kTupleEnd[] = "u";
const char kDict[] = "D";
const char kRaw[] = "R";
const char kShape[] = "s";
const char kShapeDelim[] = "-";
const char kDType[] = "d";
const char kNone[] = "n";
const char kCompositeTensor[] = "C";
const char kAttrs[] = "A";
const char kAttrsEnd[] = "a";
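// Illustrative sketch (not normative): encoding a Python list containing one
// float32 EagerTensor of shape [2, 3] followed by None would produce a string
// along the lines of "LTd1s2-3-nl" -- kList, then kTensor with its dtype
// (assuming DT_FLOAT's enum value of 1) and "-"-delimited shape, then kNone
// for the None entry, then kListEnd.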
struct EncodeResult {
string str;
std::vector<PyObject*> objects;
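  // Builds a (str, objects) 2-tuple. PyTuple_SET_ITEM steals references, so
  // ownership of every PyObject* accumulated in `objects` transfers to the
  // returned tuple.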
PyObject* ToPyTuple() {
PyObject* result = PyTuple_New(2);
PyTuple_SET_ITEM(result, 0, GetPythonObjectFromString(str));
if (objects.empty()) {
Py_INCREF(Py_None);
PyTuple_SET_ITEM(result, 1, Py_None);
} else {
PyObject* objects_tuple = PyTuple_New(objects.size());
for (int i = 0; i < objects.size(); i++) {
PyTuple_SET_ITEM(objects_tuple, i, objects[i]);
}
PyTuple_SET_ITEM(result, 1, objects_tuple);
}
return result;
}
};
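// Appends the dtype and shape (or, if include_tensor_ranks_only is set, just
// the rank) of `arg` to result->str. EagerTensors are read through their C
// handle; other ops.Tensor objects are inspected via their Python `dtype` and
// `_shape_tuple()` attributes.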
tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg,
bool include_tensor_ranks_only,
EncodeResult* result) {
if (EagerTensor_CheckExact(arg)) {
TFE_TensorHandle* t = EagerTensor_Handle(arg);
absl::StrAppend(&result->str, kDType,
static_cast<tensorflow::DataType>(t->handle->DataType()));
absl::StrAppend(&result->str, kShape);
int num_dims;
tensorflow::Status status = t->handle->NumDims(&num_dims);
if (!status.ok()) return status;
if (include_tensor_ranks_only) {
absl::StrAppend(&result->str, num_dims);
} else {
for (int i = 0; i < num_dims; ++i) {
tensorflow::int64 dim_size;
status = t->handle->Dim(i, &dim_size);
if (!status.ok()) return status;
absl::StrAppend(&result->str, dim_size, kShapeDelim);
}
}
return tensorflow::Status::OK();
}
tensorflow::Safe_PyObjectPtr dtype_object(
PyObject_GetAttrString(arg, "dtype"));
if (dtype_object == nullptr) {
return tensorflow::errors::InvalidArgument(
"ops.Tensor object doesn't have dtype() attr.");
}
tensorflow::Safe_PyObjectPtr dtype_enum(
PyObject_GetAttrString(dtype_object.get(), "_type_enum"));
if (dtype_enum == nullptr) {
return tensorflow::errors::InvalidArgument(
"ops.Tensor's dtype object doesn't have _type_enum() attr.");
}
tensorflow::DataType dtype =
static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get()));
absl::StrAppend(&result->str, kDType, dtype);
static char _shape_tuple[] = "_shape_tuple";
tensorflow::Safe_PyObjectPtr shape_tuple(
PyObject_CallMethod(arg, _shape_tuple, nullptr));
if (shape_tuple == nullptr) {
return tensorflow::errors::InvalidArgument(
"ops.Tensor object doesn't have _shape_tuple() method.");
}
if (shape_tuple.get() == Py_None) {
// Unknown shape, encode that directly.
absl::StrAppend(&result->str, kNone);
return tensorflow::Status::OK();
}
absl::StrAppend(&result->str, kShape);
tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast(
shape_tuple.get(), "shape_tuple didn't return a sequence"));
int len = PySequence_Fast_GET_SIZE(shape_seq.get());
PyObject** shape_seq_array = PySequence_Fast_ITEMS(shape_seq.get());
if (include_tensor_ranks_only) {
absl::StrAppend(&result->str, len);
} else {
for (int i = 0; i < len; ++i) {
PyObject* item = shape_seq_array[i];
if (item == Py_None) {
absl::StrAppend(&result->str, kNone);
} else {
absl::StrAppend(&result->str, MakeInt(item));
}
}
}
return tensorflow::Status::OK();
}
tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
bool include_tensor_ranks_only,
EncodeResult* result);
// Encodes a Python list/tuple-like sequence: appends the `type` marker, then
// encodes each item in order (None items are encoded as kNone), and finally
// appends the `end_type` marker.
tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
const char* end_type,
bool include_tensor_ranks_only,
EncodeResult* result) {
tensorflow::Safe_PyObjectPtr arg_seq(
PySequence_Fast(arg, "unable to create seq from list/tuple"));
absl::StrAppend(&result->str, type);
int len = PySequence_Fast_GET_SIZE(arg_seq.get());
PyObject** arg_seq_array = PySequence_Fast_ITEMS(arg_seq.get());
for (int i = 0; i < len; ++i) {
PyObject* item = arg_seq_array[i];
if (item == Py_None) {
absl::StrAppend(&result->str, kNone);
} else {
TF_RETURN_IF_ERROR(
TFE_Py_EncodeArgHelper(item, include_tensor_ranks_only, result));
}
}
absl::StrAppend(&result->str, end_type);
return tensorflow::Status::OK();
}
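// Recursively encodes a single Python argument: tensors by dtype/shape,
// lists/tuples/dicts/attrs element by element, composite tensors by their
// TypeSpec, and anything else as an opaque object (held via a weak reference
// when possible).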
tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg,
bool include_tensor_ranks_only,
EncodeResult* result) {
if (tensorflow::swig::IsTensor(arg)) {
absl::StrAppend(&result->str, kTensor);
TF_RETURN_IF_ERROR(
TFE_Py_EncodeTensor(arg, include_tensor_ranks_only, result));
} else if (PyList_Check(arg)) {
TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
arg, kList, kListEnd, include_tensor_ranks_only, result));
} else if (tensorflow::swig::IsTuple(arg)) {
TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(
arg, kTuple, kTupleEnd, include_tensor_ranks_only, result));
} else if (tensorflow::swig::IsMapping(arg)) {
tensorflow::Safe_PyObjectPtr keys(tensorflow::swig::MappingKeys(arg));
if (PyList_Sort(keys.get()) == -1) {
return tensorflow::errors::Internal("Unable to sort keys");
}
absl::StrAppend(&result->str, kDict);
int len = PyList_Size(keys.get());
for (int i = 0; i < len; i++) {
PyObject* key = PyList_GetItem(keys.get(), i);
TF_RETURN_IF_ERROR(
TFE_Py_EncodeArgHelper(key, include_tensor_ranks_only, result));
tensorflow::Safe_PyObjectPtr value(PyObject_GetItem(arg, key));
TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(
value.get(), include_tensor_ranks_only, result));
}
} else if (tensorflow::swig::IsCompositeTensor(arg)) {
absl::StrAppend(&result->str, kCompositeTensor);
// Add the typespec to the list of objects. (Do *not* use a weakref,
// since the type spec is often a temporary object.)
PyObject* type_spec(PyObject_GetAttrString(arg, "_type_spec"));
if (type_spec == nullptr) {
return tensorflow::errors::InvalidArgument(
"Error while reading CompositeTensor._type_spec.");
}
result->objects.push_back(type_spec);
} else if (tensorflow::swig::IsAttrs(arg)) {
absl::StrAppend(&result->str, kAttrs);
tensorflow::Safe_PyObjectPtr attrs(
PyObject_GetAttrString(arg, "__attrs_attrs__"));
tensorflow::Safe_PyObjectPtr iter(PyObject_GetIter(attrs.get()));
for (tensorflow::Safe_PyObjectPtr item(PyIter_Next(iter.get())); item;
item.reset(PyIter_Next(iter.get()))) {
tensorflow::Safe_PyObjectPtr name(
PyObject_GetAttrString(item.get(), "name"));
tensorflow::Safe_PyObjectPtr attr_arg(PyObject_GetAttr(arg, name.get()));
TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(
attr_arg.get(), include_tensor_ranks_only, result));
}
absl::StrAppend(&result->str, kAttrsEnd);
} else {
PyObject* object = PyWeakref_NewRef(arg, nullptr);
if (object == nullptr) {
PyErr_Clear();
object = arg;
Py_INCREF(object);
}
absl::StrAppend(&result->str, kRaw);
result->objects.push_back(object);
}
return tensorflow::Status::OK();
}
} // namespace
// `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
// are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
// are used both for performance, since much TensorFlow code specializes on
// known shapes to produce slimmer graphs, and for correctness, since some
// high-level APIs require shapes to be fully known.
//
// `include_tensor_ranks_only` allows caching on argument ranks rather than
// full shapes, so that a relaxed-shape slow path can rely on a cache key that
// excludes shape information.
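// Returns a (string, objects) tuple as produced by EncodeResult::ToPyTuple;
// tensors contribute only to the string component, while composite-tensor
// TypeSpecs and opaque objects are returned in the objects component.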
PyObject* TFE_Py_EncodeArg(PyObject* arg, bool include_tensor_ranks_only) {
EncodeResult result;
const auto status =
TFE_Py_EncodeArgHelper(arg, include_tensor_ranks_only, &result);
if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
return nullptr;
}
return result.ToPyTuple();
}
// Prints incoming messages directly to Python's stdout using Python's C API.
// This is necessary in Jupyter notebooks and Colabs, where messages written
// to the C stdout don't reach the notebook cell outputs, but writes to
// Python's stdout do.
void PrintToPythonStdout(const char* msg) {
if (Py_IsInitialized()) {
PyGILState_STATE py_threadstate;
py_threadstate = PyGILState_Ensure();
string string_msg = msg;
// PySys_WriteStdout truncates strings over 1000 bytes, so
// we write the message in chunks small enough to not be truncated.
    constexpr size_t kChunkSize = 900;
    const size_t len = string_msg.length();
    for (size_t i = 0; i < len; i += kChunkSize) {
      PySys_WriteStdout("%s", string_msg.substr(i, kChunkSize).c_str());
}
    // Force a flush to make sure printed newlines aren't interleaved in some
    // Colab environments.
PyRun_SimpleString("import sys; sys.stdout.flush()");
PyGILState_Release(py_threadstate);
}
}
// Register PrintToPythonStdout as a log listener so that logging works in
// Colabs and Jupyter notebooks.
void TFE_Py_EnableInteractivePythonLogging() {
static bool enabled_interactive_logging = false;
if (!enabled_interactive_logging) {
enabled_interactive_logging = true;
TF_RegisterLogListener(PrintToPythonStdout);
}
}
namespace {
// Weak reference to the currently active Python eager Context object.
PyObject* weak_eager_context = nullptr;
} // namespace
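// Stores a weak reference to `py_context` so the C++ layer can look up the
// active Python eager Context without keeping it alive.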
PyObject* TFE_Py_SetEagerContext(PyObject* py_context) {
Py_XDECREF(weak_eager_context);
weak_eager_context = PyWeakref_NewRef(py_context, nullptr);
if (weak_eager_context == nullptr) {
return nullptr;
}
Py_RETURN_NONE;
}
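// Returns a new reference to the active Python eager Context, or sets a
// RuntimeError and returns nullptr if no context was ever set or it has
// already been garbage collected.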
PyObject* GetPyEagerContext() {
if (weak_eager_context == nullptr) {
PyErr_SetString(PyExc_RuntimeError, "Python eager context is not set");
return nullptr;
}
PyObject* py_context = PyWeakref_GET_OBJECT(weak_eager_context);
if (py_context == Py_None) {
PyErr_SetString(PyExc_RuntimeError, "Eager context has been destroyed");
return nullptr;
}
Py_INCREF(py_context);
return py_context;
}