blob: 25b5d9474c88b2753c236fd9432694f3b3f8c062 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/util/util.h"
#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>
#include "absl/memory/memory.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
namespace tensorflow {
namespace swig {
// Returns the registry that maps a string key to a registered Python object.
// The map is heap-allocated on first use and never deleted, so the returned
// pointer stays valid for the rest of the process.
std::unordered_map<string, PyObject*>* PythonTypesMap() {
  static std::unordered_map<string, PyObject*>* const registry =
      new std::unordered_map<string, PyObject*>;
  return registry;
}
// Looks up `key` in the registry populated by RegisterType(). Returns the
// stored object as a borrowed reference (no INCREF is performed), or nullptr
// when nothing is registered under `key`. Never sets a Python exception.
PyObject* GetRegisteredType(const string& key) {
  const auto* registry = PythonTypesMap();
  const auto found = registry->find(key);
  return found != registry->end() ? found->second : nullptr;
}
// Registers `type` under the name `type_name` so it can later be fetched with
// GetRegisteredType(). `type_name` must be a `bytes` object or, on Python 3,
// a `str`.
//
// Returns a new reference to None on success. Returns nullptr with a Python
// exception set when `type` is not a type object, when `type_name` is not a
// string (or cannot be encoded as UTF-8), or when a type is already
// registered under the same name. On success the registry takes a strong
// reference to `type`; this function never releases it.
PyObject* RegisterType(PyObject* type_name, PyObject* type) {
  if (!PyType_Check(type)) {
    PyErr_SetString(PyExc_TypeError,
                    tensorflow::strings::StrCat("Expecting a type, got ",
                                                Py_TYPE(type)->tp_name)
                        .c_str());
    return nullptr;
  }
  string key;
  if (PyBytes_Check(type_name)) {
    key = PyBytes_AsString(type_name);
#if PY_MAJOR_VERSION >= 3
  } else if (PyUnicode_Check(type_name)) {
    // PyUnicode_AsUTF8 returns nullptr (with an exception set) for strings
    // that cannot be encoded; building a std::string from nullptr is UB.
    const char* utf8 = PyUnicode_AsUTF8(type_name);
    if (utf8 == nullptr) return nullptr;
    key = utf8;
#endif
  } else {
    // Without this check the type would be registered under the empty key.
    PyErr_SetString(
        PyExc_TypeError,
        tensorflow::strings::StrCat("Expecting a string type name, got ",
                                    Py_TYPE(type_name)->tp_name)
            .c_str());
    return nullptr;
  }
  auto* registry = PythonTypesMap();
  if (registry->find(key) != registry->end()) {
    PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
                                         "Type already registered for ", key)
                                         .c_str());
    return nullptr;
  }
  Py_INCREF(type);
  registry->emplace(key, type);
  Py_RETURN_NONE;
}
namespace {
// Upper bound on the number of distinct types memoized by each
// CachedTypeCheck instance.
constexpr int kMaxItemsInCache = 1024;
// Flipped to true once the "sets are not sequences" warning has been logged,
// so that the warning is only logged once.
bool WarnedThatSetIsNotSequence = false;
// Returns true if `o` is a bytes or unicode object (or, on Python 2, a `str`).
bool IsString(PyObject* o) {
  if (PyBytes_Check(o) || PyUnicode_Check(o)) return true;
#if PY_MAJOR_VERSION < 3
  if (PyString_Check(o)) return true;
#endif
  return false;
}
// Returns the equivalent of Python's `o.__class__.__name__`: the part of the
// type's `tp_name` after its last '.', or the whole `tp_name` if it contains
// no '.'. The returned StringPiece points into `tp_name`.
//
// `__class__` is only set for new-style classes. A lot of TensorFlow code uses
// `__class__` without checks, so only new-style classes are supported here.
StringPiece GetClassName(PyObject* o) {
  // For new-style classes `o.__class__` is `type(o)`, i.e. PyObject_Type(o),
  // which is `o->ob_type` plus an INCREF that is not needed here.
  const char* const full_name = o->ob_type->tp_name;
  StringPiece short_name(full_name);
  const size_t last_dot = short_name.rfind('.');
  if (last_dot == StringPiece::npos) return short_name;
  short_name.remove_prefix(last_dot + 1);
  return short_name;
}
// Formats `o` for error messages as "type=<class name> str=<str(o)>".
//
// Returns "<null object>" for a null `o`. Returns "<failed to execute str() on
// object>" when `str(o)` fails or its result cannot be converted to UTF-8; in
// that case the Python exception raised by the failed call is left set.
string PyObjectToString(PyObject* o) {
  if (o == nullptr) {
    return "<null object>";
  }
  PyObject* str = PyObject_Str(o);
  if (str == nullptr) {
    return "<failed to execute str() on object>";
  }
#if PY_MAJOR_VERSION < 3
  const char* chars = PyString_AS_STRING(str);
#else
  // PyUnicode_AsUTF8 returns nullptr (with an exception set) for strings that
  // cannot be encoded, e.g. ones containing lone surrogates. Constructing a
  // std::string from nullptr is undefined behavior.
  const char* chars = PyUnicode_AsUTF8(str);
#endif
  if (chars == nullptr) {
    Py_DECREF(str);
    return "<failed to execute str() on object>";
  }
  // Copy the characters before releasing `str`, which owns the buffer.
  string s(chars);
  Py_DECREF(str);
  return tensorflow::strings::StrCat("type=", GetClassName(o), " str=", s);
}
class CachedTypeCheck {
public:
explicit CachedTypeCheck(std::function<int(PyObject*)> ternary_predicate)
: ternary_predicate_(std::move(ternary_predicate)) {}
~CachedTypeCheck() {
mutex_lock l(type_to_sequence_map_mu_);
for (const auto& pair : type_to_sequence_map_) {
Py_DECREF(pair.first);
}
}
// Caches successful executions of the one-argument (PyObject*) callable
// "ternary_predicate" based on the type of "o". -1 from the callable
// indicates an unsuccessful check (not cached), 0 indicates that "o"'s type
// does not match the predicate, and 1 indicates that it does. Used to avoid
// calling back into Python for expensive isinstance checks.
int CachedLookup(PyObject* o) {
// Try not to return to Python - see if the type has already been seen
// before.
auto* type = Py_TYPE(o);
{
tf_shared_lock l(type_to_sequence_map_mu_);
auto it = type_to_sequence_map_.find(type);
if (it != type_to_sequence_map_.end()) {
return it->second;
}
}
int check_result = ternary_predicate_(o);
if (check_result == -1) {
return -1; // Type check error, not cached.
}
// NOTE: This is never decref'd as long as the object lives, which is likely
// forever, but we don't want the type to get deleted as long as it is in
// the map. This should not be too much of a leak, as there should only be a
// relatively small number of types in the map, and an even smaller number
// that are eligible for decref. As a precaution, we limit the size of the
// map to 1024.
{
mutex_lock l(type_to_sequence_map_mu_);
if (type_to_sequence_map_.size() < kMaxItemsInCache) {
Py_INCREF(type);
auto insert_result = type_to_sequence_map_.insert({type, check_result});
if (!insert_result.second) {
// The type was added to the cache by a concurrent thread after we
// looked it up above.
Py_DECREF(type);
}
}
}
return check_result;
}
private:
std::function<int(PyObject*)> ternary_predicate_;
mutex type_to_sequence_map_mu_;
std::unordered_map<PyTypeObject*, bool> type_to_sequence_map_
GUARDED_BY(type_to_sequence_map_mu_);
};
// Checks whether `obj` is an instance of the type registered (via
// RegisterType) under `type_name`.
//
// Returns 1 if it is and 0 if it is not. Returns -1 with a Python exception
// set if an error occurred, including a RuntimeError when nothing has been
// registered under `type_name`.
int IsInstanceOfRegisteredType(PyObject* obj, const char* type_name) {
  PyObject* const registered_type = GetRegisteredType(type_name);
  if (TF_PREDICT_TRUE(registered_type != nullptr)) {
    return PyObject_IsInstance(obj, registered_type);
  }
  const string message = tensorflow::strings::StrCat(
      type_name,
      " type has not been set. "
      "Please register the type with the identifier \"",
      type_name, "\" using RegisterType.");
  PyErr_SetString(PyExc_RuntimeError, message.c_str());
  return -1;
}
// Returns 1 if `o` is considered a mapping for the purposes of Flatten(): a
// dict, or an instance of the type registered as "Mapping". Returns 0
// otherwise, and -1 if an error occurred.
int IsMappingHelper(PyObject* o) {
  // Plain dicts are answered directly, without consulting the type cache.
  if (PyDict_Check(o)) return 1;
  static CachedTypeCheck* const mapping_cache =
      new CachedTypeCheck([](PyObject* candidate) {
        return IsInstanceOfRegisteredType(candidate, "Mapping");
      });
  return mapping_cache->CachedLookup(o);
}
// Returns 1 if `o` is an instance of the type registered as "MappingView"
// (treated as a sequence by Flatten()), 0 otherwise, and -1 on error.
int IsMappingViewHelper(PyObject* o) {
  static CachedTypeCheck* const mapping_view_cache =
      new CachedTypeCheck([](PyObject* candidate) {
        return IsInstanceOfRegisteredType(candidate, "MappingView");
      });
  return mapping_view_cache->CachedLookup(o);
}
// Returns 1 if `o` is an instance of an attrs-decorated class, i.e. its
// `__class__` has an `__attrs_attrs__` attribute, and 0 otherwise. A failure
// to read `o.__class__` is cleared and reported as 0.
int IsAttrsHelper(PyObject* o) {
  static CachedTypeCheck* const attrs_cache =
      new CachedTypeCheck([](PyObject* candidate) {
        Safe_PyObjectPtr klass(PyObject_GetAttrString(candidate, "__class__"));
        if (!klass) {
          // PyObject_GetAttrString returns null (and sets an error) on failure.
          PyErr_Clear();
          return 0;
        }
        return PyObject_HasAttrString(klass.get(), "__attrs_attrs__");
      });
  return attrs_cache->CachedLookup(o);
}
// Returns 1 if `o` is an instance of the type registered as "IndexedSlices",
// 0 if it is not, and -1 if an error occurred.
int IsIndexedSlicesHelper(PyObject* o) {
  static CachedTypeCheck* const indexed_slices_cache =
      new CachedTypeCheck([](PyObject* candidate) {
        return IsInstanceOfRegisteredType(candidate, "IndexedSlices");
      });
  return indexed_slices_cache->CachedLookup(o);
}
// Returns 1 if `o` is an instance of the type registered as "Tensor", 0 if it
// is not, and -1 if an error occurred.
int IsTensorHelper(PyObject* o) {
  static CachedTypeCheck* const tensor_cache =
      new CachedTypeCheck([](PyObject* candidate) {
        return IsInstanceOfRegisteredType(candidate, "Tensor");
      });
  return tensor_cache->CachedLookup(o);
}
// Returns 1 if `o` is an instance of the type registered as
// "ResourceVariable", 0 if it is not, and -1 if an error occurred.
int IsResourceVariableHelper(PyObject* o) {
  static CachedTypeCheck* const resource_variable_cache =
      new CachedTypeCheck([](PyObject* candidate) {
        return IsInstanceOfRegisteredType(candidate, "ResourceVariable");
      });
  return resource_variable_cache->CachedLookup(o);
}
// Returns 1 if `o` is an instance of the type registered as "Variable".
// Returns 0 otherwise.
// Returns -1 if an error occurred.
int IsVariableHelper(PyObject* o) {
  static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
    return IsInstanceOfRegisteredType(to_check, "Variable");
  });
  return check_cache->CachedLookup(o);
}
// Returns 1 if `o` is considered a sequence for the purposes of Flatten():
// mappings, mapping views, attrs-decorated objects, and non-string instances
// of the type registered as "Sequence".
// Returns 0 otherwise.
// Returns -1 (with a Python exception set) if an error occurred.
int IsSequenceHelper(PyObject* o) {
  // We treat dicts and other mappings as special cases of sequences. Each of
  // these helpers returns -1 on error. That must be propagated, not treated as
  // a truthy "is a sequence" answer with the exception still pending.
  const int is_mapping = IsMappingHelper(o);
  if (is_mapping != 0) return is_mapping;
  const int is_mapping_view = IsMappingViewHelper(o);
  if (is_mapping_view != 0) return is_mapping_view;
  const int is_attrs = IsAttrsHelper(o);
  if (is_attrs != 0) return is_attrs;
  if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
    LOG(WARNING) << "Sets are not currently considered sequences, "
                    "but this may change in the future, "
                    "so consider avoiding using them.";
    WarnedThatSetIsNotSequence = true;
  }
  static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
    int is_instance = IsInstanceOfRegisteredType(to_check, "Sequence");
    // Don't cache a failed is_instance check.
    if (is_instance == -1) return -1;
    return static_cast<int>(is_instance != 0 && !IsString(to_check));
  });
  return check_cache->CachedLookup(o);
}
// Abstract iterator over the immediate children of a nested structure.
//
// next() returns an owned reference to the next child, or a null pointer once
// there are no more children. Implementations call invalidate() when they
// could not be set up; callers check valid() before calling next().
class ValueIterator {
 public:
  virtual ~ValueIterator() = default;

  // Returns the next child, or a null pointer when there is none.
  virtual Safe_PyObjectPtr next() = 0;

  // False if the iterator could not be constructed.
  bool valid() const { return is_valid_; }

 protected:
  // Marks this iterator as unusable.
  void invalidate() { is_valid_ = false; }

 private:
  bool is_valid_ = true;
};

using ValueIteratorPtr = std::unique_ptr<ValueIterator>;
// Yields a dict's values in a deterministic order, obtained by sorting the
// keys. This deliberately ignores the insertion order of `OrderedDict`
// instances, to avoid bugs caused by mixing ordered and plain dicts (e.g.
// flattening a dict but packing it back with a corresponding `OrderedDict`).
//
// The iterator is invalidated if the keys cannot be sorted.
class DictValueIterator : public ValueIterator {
 public:
  explicit DictValueIterator(PyObject* dict)
      : dict_(dict), keys_(PyDict_Keys(dict)) {
    if (PyList_Sort(keys_.get()) != -1) {
      iter_.reset(PyObject_GetIter(keys_.get()));
    } else {
      invalidate();
    }
  }

  Safe_PyObjectPtr next() override {
    Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
    if (!key) return Safe_PyObjectPtr();
    // PyDict_GetItem returns a borrowed reference, so take our own below.
    PyObject* const value = PyDict_GetItem(dict_, key.get());
    if (value == nullptr) {
      PyErr_SetString(PyExc_RuntimeError,
                      "Dictionary was modified during iteration over it");
      return Safe_PyObjectPtr();
    }
    Py_INCREF(value);
    return Safe_PyObjectPtr(value);
  }

 private:
  PyObject* dict_;         // Not owned: no reference is taken here.
  Safe_PyObjectPtr keys_;  // The dict's keys, sorted.
  Safe_PyObjectPtr iter_;  // Iterator over `keys_`.
};
// Yields the values of a generic mapping in sorted-key order. Keys come from
// MappingKeys() and values from `mapping[key]`. The iterator is invalidated
// if the keys cannot be fetched or sorted.
class MappingValueIterator : public ValueIterator {
 public:
  explicit MappingValueIterator(PyObject* mapping)
      : mapping_(mapping), keys_(MappingKeys(mapping)) {
    const bool keys_sorted =
        keys_ != nullptr && PyList_Sort(keys_.get()) != -1;
    if (keys_sorted) {
      iter_.reset(PyObject_GetIter(keys_.get()));
    } else {
      invalidate();
    }
  }

  Safe_PyObjectPtr next() override {
    Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
    if (!key) return Safe_PyObjectPtr();
    // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
    Safe_PyObjectPtr value(PyObject_GetItem(mapping_, key.get()));
    if (!value) {
      PyErr_SetString(PyExc_RuntimeError,
                      "Mapping was modified during iteration over it");
    }
    return value;
  }

 private:
  PyObject* mapping_;      // Not owned: no reference is taken here.
  Safe_PyObjectPtr keys_;  // The mapping's keys, sorted.
  Safe_PyObjectPtr iter_;  // Iterator over `keys_`.
};
// Iterates over the items of an iterable, by index. The iterable is first
// materialized with PySequence_Fast. If that fails, the iterator is
// invalidated so callers stop and propagate the error PySequence_Fast raised,
// instead of treating the object as an empty sequence.
class SequenceValueIterator : public ValueIterator {
 public:
  explicit SequenceValueIterator(PyObject* iterable)
      : seq_(PySequence_Fast(iterable, "")),
        size_(seq_.get() ? PySequence_Fast_GET_SIZE(seq_.get()) : 0),
        index_(0) {
    if (!seq_) invalidate();
  }

  Safe_PyObjectPtr next() override {
    Safe_PyObjectPtr result;
    if (index_ < size_) {
      // PySequence_Fast_GET_ITEM returns a borrowed reference.
      PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_);
      ++index_;
      if (elem) {
        Py_INCREF(elem);
        result.reset(elem);
      }
    }
    return result;
  }

 private:
  Safe_PyObjectPtr seq_;
  const Py_ssize_t size_;
  Py_ssize_t index_;
};
// Yields exactly one object: the one given to the constructor. A reference is
// taken at construction and handed to the caller by the first next(); every
// later call returns a null pointer.
class SingleValueIterator : public ValueIterator {
 public:
  explicit SingleValueIterator(PyObject* x) : x_(x) { Py_INCREF(x); }

  Safe_PyObjectPtr next() override {
    Safe_PyObjectPtr handed_out;
    handed_out.swap(x_);
    return handed_out;
  }

 private:
  Safe_PyObjectPtr x_;  // Owned until handed out; null afterwards.
};
// Returns nullptr (to raise an exception) when next() is called. Caller
// should have already called PyErr_SetString.
class ErrorValueIterator : public ValueIterator {
public:
ErrorValueIterator() {}
Safe_PyObjectPtr next() override { return nullptr; }
};
// Iterates over the attribute values of an attrs-decorated object, in the
// order of its class's `__attrs_attrs__`. Each element of `__attrs_attrs__`
// provides a `name` attribute, which is then looked up on the object.
//
// The iterator is invalidated if `__class__`, `__attrs_attrs__` or its
// iterator cannot be obtained. next() returns nullptr (with the exception left
// set) if an attribute lookup fails.
class AttrsValueIterator : public ValueIterator {
 public:
  explicit AttrsValueIterator(PyObject* nested) : nested_(nested) {
    Py_INCREF(nested);
    cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__"));
    if (cls_) {
      attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__"));
      if (attrs_) {
        iter_.reset(PyObject_GetIter(attrs_.get()));
      }
    }
    if (!iter_ || PyErr_Occurred()) invalidate();
  }

  Safe_PyObjectPtr next() override {
    Safe_PyObjectPtr result;
    Safe_PyObjectPtr item(PyIter_Next(iter_.get()));
    if (item) {
      Safe_PyObjectPtr name(PyObject_GetAttrString(item.get(), "name"));
      // PyObject_GetAttr must not be called with a null name. If reading
      // `item.name` failed, return nullptr and leave its exception set.
      if (name) {
        result.reset(PyObject_GetAttr(nested_.get(), name.get()));
      }
    }
    return result;
  }

 private:
  Safe_PyObjectPtr nested_;
  Safe_PyObjectPtr cls_;
  Safe_PyObjectPtr attrs_;
  Safe_PyObjectPtr iter_;
};
// Returns true if `o` is an instance of the type registered as
// "SparseTensorValue" (or of a subtype). Returns false, without setting an
// error, when that type has not been registered.
bool IsSparseTensorValueType(PyObject* o) {
  PyObject* const registered = GetRegisteredType("SparseTensorValue");
  return registered != nullptr &&
         PyObject_TypeCheck(o, reinterpret_cast<PyTypeObject*>(registered)) ==
             1;
}
// Returns 1 if `o` is an instance of CompositeTensor.
// Returns 0 otherwise.
// Returns -1 if an error occurred.
//
// Returns int rather than bool. With a bool return type, -1 would be
// converted to `true`, so errors would look like a positive answer and the
// callers' `== -1` checks could never fire.
int IsCompositeTensorHelper(PyObject* o) {
  static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
    return IsInstanceOfRegisteredType(to_check, "CompositeTensor");
  });
  return check_cache->CachedLookup(o);
}
// Returns 1 if `o` is an instance of TypeSpec, but is not TensorSpec.
// Returns 0 otherwise.
// Returns -1 if an error occurred.
//
// Returns int rather than bool so that -1 is not converted to `true`.
int IsTypeSpecHelper(PyObject* o) {
  static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
    const int is_type_spec = IsInstanceOfRegisteredType(to_check, "TypeSpec");
    // Don't call back into Python while an exception is pending.
    if (is_type_spec == -1) return -1;
    const int is_tensor_spec =
        IsInstanceOfRegisteredType(to_check, "TensorSpec");
    if (is_tensor_spec == -1) return -1;
    return static_cast<int>(is_type_spec && !is_tensor_spec);
  });
  return check_cache->CachedLookup(o);
}
// Returns 1 if `o` is a (non-string) sequence or CompositeTensor or
// (non-TensorSpec) TypeSpec.
// Returns 0 otherwise.
// Returns -1 if an error occurred.
int IsSequenceOrCompositeHelper(PyObject* o) {
  // Stop at the first error. The later checks call back into Python, which
  // must not happen while an exception is pending. The result is -1 either way.
  const int is_sequence = IsSequenceHelper(o);
  if (is_sequence == -1) return -1;
  const int is_composite = IsCompositeTensorHelper(o);
  if (is_composite == -1) return -1;
  const int is_type_spec = IsTypeSpecHelper(o);
  if (is_type_spec == -1) return -1;
  return (is_sequence || is_composite || is_type_spec) ? 1 : 0;
}
// Returns 1 if `o` is a sequence for the tf.data flavor of Flatten() and
// AssertSameStructure(): a sequence per IsSequenceHelper() that is neither a
// list nor a SparseTensorValue. Returns 0 otherwise, and -1 (with a Python
// exception set) if IsSequenceHelper() failed.
int IsSequenceForDataHelper(PyObject* o) {
  const int is_sequence = IsSequenceHelper(o);
  // Propagate errors instead of reporting them as "not a sequence", which
  // would leave the exception pending while the caller carries on.
  if (is_sequence != 1) return is_sequence;
  return (!PyList_Check(o) && !IsSparseTensorValueType(o)) ? 1 : 0;
}
// Returns an iterator over the children of `nested`. Dicts and other mappings
// yield their values in sorted-key order, attrs objects yield their attribute
// values, and anything else is iterated as a sequence.
ValueIteratorPtr GetValueIterator(PyObject* nested) {
  if (PyDict_Check(nested)) {
    return absl::make_unique<DictValueIterator>(nested);
  }
  if (IsMappingHelper(nested)) {
    return absl::make_unique<MappingValueIterator>(nested);
  }
  if (IsAttrsHelper(nested)) {
    return absl::make_unique<AttrsValueIterator>(nested);
  }
  return absl::make_unique<SequenceValueIterator>(nested);
}
// Like GetValueIterator, but for the tf.data functions: a SparseTensorValue is
// not iterated into and is yielded as a single value.
ValueIteratorPtr GetValueIteratorForData(PyObject* nested) {
  if (PyDict_Check(nested)) {
    return absl::make_unique<DictValueIterator>(nested);
  }
  if (IsMappingHelper(nested)) {
    return absl::make_unique<MappingValueIterator>(nested);
  }
  if (IsAttrsHelper(nested)) {
    return absl::make_unique<AttrsValueIterator>(nested);
  }
  if (IsSparseTensorValueType(nested)) {
    return absl::make_unique<SingleValueIterator>(nested);
  }
  return absl::make_unique<SequenceValueIterator>(nested);
}
// Like GetValueIterator, but also expands CompositeTensors and TypeSpecs.
// A CompositeTensor yields `nested._type_spec._to_components(nested)` as a
// single value, and a TypeSpec yields its `_component_specs`. If any of those
// Python calls fail, an ErrorValueIterator is returned and the exception is
// left set.
ValueIteratorPtr GetValueIteratorForComposite(PyObject* nested) {
  if (IsCompositeTensor(nested)) {
    Safe_PyObjectPtr type_spec(PyObject_GetAttrString(nested, "_type_spec"));
    if (PyErr_Occurred() || !type_spec) {
      return absl::make_unique<ErrorValueIterator>();
    }
    // Mutable buffers because Python 2's PyObject_CallMethod takes `char*`.
    static char method_name[] = "_to_components";
    static char call_format[] = "(O)";
    Safe_PyObjectPtr components(PyObject_CallMethod(
        type_spec.get(), method_name, call_format, nested));
    if (PyErr_Occurred() || !components) {
      return absl::make_unique<ErrorValueIterator>();
    }
    return absl::make_unique<SingleValueIterator>(components.get());
  }
  if (!IsTypeSpec(nested)) {
    return GetValueIterator(nested);
  }
  Safe_PyObjectPtr component_specs(
      PyObject_GetAttrString(nested, "_component_specs"));
  if (PyErr_Occurred() || !component_specs) {
    return absl::make_unique<ErrorValueIterator>();
  }
  return absl::make_unique<SingleValueIterator>(component_specs.get());
}
// Appends the leaves of `nested` to `list`, depth first. An object for which
// `is_sequence_helper` returns 0 is a leaf. Otherwise its children, produced by
// `value_iterator_getter`, are flattened recursively.
//
// Returns true on success. Returns false with a Python exception set if a
// helper fails, the iterator cannot be built, an iterator stops because of an
// error, or the recursion limit is exceeded.
bool FlattenHelper(
    PyObject* nested, PyObject* list,
    const std::function<int(PyObject*)>& is_sequence_helper,
    const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter) {
  // if nested is not a sequence, append itself and exit
  int is_seq = is_sequence_helper(nested);
  if (is_seq == -1) return false;
  if (!is_seq) {
    return PyList_Append(list, nested) != -1;
  }
  ValueIteratorPtr iter = value_iterator_getter(nested);
  if (!iter->valid()) return false;
  for (Safe_PyObjectPtr item = iter->next(); item; item = iter->next()) {
    if (Py_EnterRecursiveCall(" in flatten")) {
      return false;
    }
    const bool success = FlattenHelper(item.get(), list, is_sequence_helper,
                                       value_iterator_getter);
    Py_LeaveRecursiveCall();
    if (!success) {
      return false;
    }
  }
  // next() returns nullptr both when the children are exhausted and on error
  // (e.g. a failed item lookup). Don't report the latter as success, which
  // would hand Python a result with an exception still set.
  return PyErr_Occurred() == nullptr;
}
// Fills `*error_msg` with a message listing the keys of both `dict1` and
// `dict2`, for when their key sets differ. It is called for dicts and for
// other mappings; keys are fetched with MappingKeys().
//
// On success also sets `*is_type_error` to false. If fetching either set of
// keys fails, a generic message is stored instead, `*is_type_error` is left
// unchanged, and the Python error from fetching the keys stays set.
void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
                           bool* is_type_error) {
  Safe_PyObjectPtr keys1(MappingKeys(dict1));
  const bool got_keys1 = !PyErr_Occurred() && keys1.get() != nullptr;
  // Only fetch the second set of keys if the first fetch succeeded.
  Safe_PyObjectPtr keys2(got_keys1 ? MappingKeys(dict2) : nullptr);
  if (!got_keys1 || PyErr_Occurred() || keys2.get() == nullptr) {
    *error_msg =
        ("The two dictionaries don't have the same set of keys. Failed to "
         "fetch keys.");
    return;
  }
  *is_type_error = false;
  *error_msg = tensorflow::strings::StrCat(
      "The two dictionaries don't have the same set of keys. "
      "First structure has keys ",
      PyObjectToString(keys1.get()), ", while second structure has keys ",
      PyObjectToString(keys2.get()));
}
// Recursively checks that `o1` and `o2` have the same nested structure.
//
// Returns true iff there were no "internal" errors, i.e. errors that have
// nothing to do with structure checking. If an internal error occurred, the
// appropriate Python error is set and the caller can propagate it directly to
// the user.
//
// Both `error_msg` and `is_type_error` must be non-null, and `error_msg` must
// be empty on entry. `error_msg` is left empty if the structures match.
// Otherwise it is filled with an appropriate message, and `is_type_error` is
// set to true iff the error to be raised should be a TypeError.
//
// `check_types` additionally requires matching sequence types. Namedtuples are
// compared by class name and fields. Two lists, or two mappings, are accepted
// even if their exact types differ.
// `is_sequence_helper` / `value_iterator_getter` decide what counts as a
// sequence and how its children are iterated.
// When `check_composite_tensor_type_spec` is set, CompositeTensors are
// compared by calling `most_specific_compatible_type` on their TypeSpecs.
bool AssertSameStructureHelper(
    PyObject* o1, PyObject* o2, bool check_types, string* error_msg,
    bool* is_type_error,
    const std::function<int(PyObject*)>& is_sequence_helper,
    const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter,
    bool check_composite_tensor_type_spec) {
  DCHECK(error_msg);
  DCHECK(is_type_error);
  const bool is_seq1 = is_sequence_helper(o1);
  const bool is_seq2 = is_sequence_helper(o2);
  if (PyErr_Occurred()) return false;
  if (is_seq1 != is_seq2) {
    string seq_str = is_seq1 ? PyObjectToString(o1) : PyObjectToString(o2);
    string non_seq_str = is_seq1 ? PyObjectToString(o2) : PyObjectToString(o1);
    *is_type_error = false;
    *error_msg = tensorflow::strings::StrCat(
        "Substructure \"", seq_str, "\" is a sequence, while substructure \"",
        non_seq_str, "\" is not");
    return true;
  }

  // Got to objects that are considered non-sequences. Note that in tf.data
  // use case lists and sparse_tensors are not considered sequences. So finished
  // checking, structures are the same.
  if (!is_seq1) return true;

  if (check_types) {
    const PyTypeObject* type1 = o1->ob_type;
    const PyTypeObject* type2 = o2->ob_type;

    // We treat two different namedtuples with identical name and fields
    // as having the same type.
    const PyObject* o1_tuple = IsNamedtuple(o1, true);
    if (o1_tuple == nullptr) return false;
    const PyObject* o2_tuple = IsNamedtuple(o2, true);
    if (o2_tuple == nullptr) {
      Py_DECREF(o1_tuple);
      return false;
    }
    bool both_tuples = o1_tuple == Py_True && o2_tuple == Py_True;
    Py_DECREF(o1_tuple);
    Py_DECREF(o2_tuple);

    if (both_tuples) {
      const PyObject* same_tuples = SameNamedtuples(o1, o2);
      if (same_tuples == nullptr) return false;
      bool not_same_tuples = same_tuples != Py_True;
      Py_DECREF(same_tuples);
      if (not_same_tuples) {
        *is_type_error = true;
        *error_msg = tensorflow::strings::StrCat(
            "The two namedtuples don't have the same sequence type. "
            "First structure ",
            PyObjectToString(o1), " has type ", type1->tp_name,
            ", while second structure ", PyObjectToString(o2), " has type ",
            type2->tp_name);
        return true;
      }
    } else if (type1 != type2
               /* If both sequences are list types, don't complain. This allows
                  one to be a list subclass (e.g. _ListWrapper used for
                  automatic dependency tracking.) */
               && !(PyList_Check(o1) && PyList_Check(o2))
               /* Two mapping types will also compare equal, making _DictWrapper
                  and dict compare equal. */
               && !(IsMappingHelper(o1) && IsMappingHelper(o2))
               /* For CompositeTensor & TypeSpec, we check below. */
               && !(check_composite_tensor_type_spec &&
                    (IsCompositeTensor(o1) || IsCompositeTensor(o2)) &&
                    (IsTypeSpec(o1) || IsTypeSpec(o2)))) {
      *is_type_error = true;
      *error_msg = tensorflow::strings::StrCat(
          "The two namedtuples don't have the same sequence type. "
          "First structure ",
          PyObjectToString(o1), " has type ", type1->tp_name,
          ", while second structure ", PyObjectToString(o2), " has type ",
          type2->tp_name);
      return true;
    }

    if (PyDict_Check(o1) && PyDict_Check(o2)) {
      if (PyDict_Size(o1) != PyDict_Size(o2)) {
        SetDifferentKeysError(o1, o2, error_msg, is_type_error);
        return true;
      }

      PyObject* key;
      Py_ssize_t pos = 0;
      while (PyDict_Next(o1, &pos, &key, nullptr)) {
        if (PyDict_GetItem(o2, key) == nullptr) {
          SetDifferentKeysError(o1, o2, error_msg, is_type_error);
          return true;
        }
      }
    } else if (IsMappingHelper(o1)) {
      // Fallback for custom mapping types. Instead of using PyDict methods
      // which stay in C, we call iter(o1).
      // PyMapping_Size returns -1 (with an exception set) on failure.
      const Py_ssize_t size1 = PyMapping_Size(o1);
      if (size1 == -1) return false;
      const Py_ssize_t size2 = PyMapping_Size(o2);
      if (size2 == -1) return false;
      if (size1 != size2) {
        SetDifferentKeysError(o1, o2, error_msg, is_type_error);
        return true;
      }

      Safe_PyObjectPtr iter(PyObject_GetIter(o1));
      // PyIter_Next must not be given a null iterator.
      if (iter == nullptr) return false;
      PyObject* key;
      while ((key = PyIter_Next(iter.get())) != nullptr) {
        if (!PyMapping_HasKey(o2, key)) {
          SetDifferentKeysError(o1, o2, error_msg, is_type_error);
          Py_DECREF(key);
          return true;
        }
        Py_DECREF(key);
      }
      // PyIter_Next also returns nullptr when iteration raised an exception.
      if (PyErr_Occurred()) return false;
    }
  }

  if (check_composite_tensor_type_spec &&
      (IsCompositeTensor(o1) || IsCompositeTensor(o2))) {
    Safe_PyObjectPtr owned_type_spec_1;
    PyObject* type_spec_1 = o1;
    if (IsCompositeTensor(o1)) {
      owned_type_spec_1.reset(PyObject_GetAttrString(o1, "_type_spec"));
      type_spec_1 = owned_type_spec_1.get();
    }

    Safe_PyObjectPtr owned_type_spec_2;
    PyObject* type_spec_2 = o2;
    if (IsCompositeTensor(o2)) {
      owned_type_spec_2.reset(PyObject_GetAttrString(o2, "_type_spec"));
      type_spec_2 = owned_type_spec_2.get();
    }

    // Two composite tensors are considered to have the same structure if
    // there is some type spec that is compatible with both of them. Thus,
    // we use most_specific_compatible_type(), and check if it raises an
    // exception. We do *not* use is_compatible_with, since that would
    // prevent us from e.g. using a cond statement where the two sides have
    // different shapes.
    static char compatible_type[] = "most_specific_compatible_type";
    static char argspec[] = "(O)";
    Safe_PyObjectPtr struct_compatible(PyObject_CallMethod(
        type_spec_1, compatible_type, argspec, type_spec_2));
    if (PyErr_Occurred() || struct_compatible == nullptr) {
      PyErr_Clear();
      *is_type_error = false;
      *error_msg = tensorflow::strings::StrCat(
          "Incompatible CompositeTensor TypeSpecs: ",
          PyObjectToString(type_spec_1), " vs. ",
          PyObjectToString(type_spec_2));
      return true;
    }
  }

  ValueIteratorPtr iter1 = value_iterator_getter(o1);
  ValueIteratorPtr iter2 = value_iterator_getter(o2);
  if (!iter1->valid() || !iter2->valid()) return false;

  while (true) {
    Safe_PyObjectPtr v1 = iter1->next();
    Safe_PyObjectPtr v2 = iter2->next();
    if (v1 && v2) {
      if (Py_EnterRecursiveCall(" in assert_same_structure")) {
        return false;
      }
      bool no_internal_errors = AssertSameStructureHelper(
          v1.get(), v2.get(), check_types, error_msg, is_type_error,
          is_sequence_helper, value_iterator_getter,
          check_composite_tensor_type_spec);
      Py_LeaveRecursiveCall();
      if (!no_internal_errors) return false;
      if (!error_msg->empty()) return true;
    } else if (!v1 && !v2) {
      // Done with all recursive calls. Structure matched.
      return true;
    } else {
      *is_type_error = false;
      *error_msg = tensorflow::strings::StrCat(
          "The two structures don't have the same number of elements. ",
          "First structure: ", PyObjectToString(o1),
          ". Second structure: ", PyObjectToString(o2));
      return true;
    }
  }
}
} // namespace
// Public boolean predicates. Each wraps the corresponding ternary helper and
// returns true only for a definite "yes" (1). Both "no" (0) and errors (-1)
// map to false.
bool IsSequence(PyObject* o) {
  return IsSequenceHelper(o) == 1;
}
bool IsMapping(PyObject* o) {
  return IsMappingHelper(o) == 1;
}
bool IsMappingView(PyObject* o) {
  return IsMappingViewHelper(o) == 1;
}
bool IsAttrs(PyObject* o) {
  return IsAttrsHelper(o) == 1;
}
bool IsTensor(PyObject* o) {
  return IsTensorHelper(o) == 1;
}
bool IsResourceVariable(PyObject* o) {
  return IsResourceVariableHelper(o) == 1;
}
bool IsVariable(PyObject* o) {
  return IsVariableHelper(o) == 1;
}
bool IsIndexedSlices(PyObject* o) {
  return IsIndexedSlicesHelper(o) == 1;
}
// Returns a new reference to a *list* of the keys of mapping `o`, or nullptr
// (with a Python exception set) on failure.
//
// From Python 3.7 on, PyMapping_Keys always returns a list. Before 3.7 it
// returns whatever `o.keys()` returns (e.g. a KeysView), and on Python 2 it is
// a macro that triggers writable-strings warnings. So on older versions
// `keys()` is called directly and the result converted to a list, because
// MappingValueIterator sorts the keys in place with PyList_Sort.
PyObject* MappingKeys(PyObject* o) {
#if PY_VERSION_HEX >= 0x03070000
  return PyMapping_Keys(o);
#else
  static char key_method_name[] = "keys";
  Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr));
  if (PyErr_Occurred() || raw_result.get() == nullptr) {
    return nullptr;
  }
  // PySequence_Fast returns a list or a tuple (converting other iterables to a
  // list). A tuple still has to be converted, since callers need a list.
  Safe_PyObjectPtr keys(PySequence_Fast(
      raw_result.get(),
      "The '.keys()' method of a custom mapping returned a non-sequence."));
  if (keys.get() == nullptr) return nullptr;
  if (PyList_Check(keys.get())) return keys.release();
  return PySequence_List(keys.get());
#endif
}
// Flattens `nested` into a new Python list of its leaves, in depth-first
// order. When `expand_composites` is true, CompositeTensors and
// (non-TensorSpec) TypeSpecs are expanded into their components too.
//
// Returns a new reference to the list, or nullptr with a Python exception set
// on failure.
PyObject* Flatten(PyObject* nested, bool expand_composites) {
  PyObject* list = PyList_New(0);
  // PyList_New returns nullptr (with an exception set) on allocation failure.
  // The Py_DECREF in the failure branch below must never see a null pointer.
  if (list == nullptr) return nullptr;
  const std::function<int(PyObject*)>& is_sequence_helper =
      expand_composites ? IsSequenceOrCompositeHelper : IsSequenceHelper;
  const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator =
      expand_composites ? GetValueIteratorForComposite : GetValueIterator;
  if (FlattenHelper(nested, list, is_sequence_helper, get_value_iterator)) {
    return list;
  }
  Py_DECREF(list);
  return nullptr;
}
// Public boolean predicates. Each wraps a ternary helper and returns true
// only for a definite "yes" (1). Both "no" (0) and errors (-1) map to false.
bool IsSequenceOrComposite(PyObject* o) {
  const int verdict = IsSequenceOrCompositeHelper(o);
  return verdict == 1;
}
bool IsCompositeTensor(PyObject* o) {
  const int verdict = IsCompositeTensorHelper(o);
  return verdict == 1;
}
bool IsTypeSpec(PyObject* o) {
  const int verdict = IsTypeSpecHelper(o);
  return verdict == 1;
}
bool IsSequenceForData(PyObject* o) {
  const int verdict = IsSequenceForDataHelper(o);
  return verdict == 1;
}
// tf.data flavor of Flatten(). Lists and SparseTensorValues are treated as
// leaves (see IsSequenceForDataHelper). Returns a new reference to the list of
// leaves, or nullptr with a Python exception set on failure.
PyObject* FlattenForData(PyObject* nested) {
  PyObject* list = PyList_New(0);
  // PyList_New returns nullptr (with an exception set) on allocation failure.
  // The Py_DECREF below must never see a null pointer.
  if (list == nullptr) return nullptr;
  if (FlattenHelper(nested, list, IsSequenceForDataHelper,
                    GetValueIteratorForData)) {
    return list;
  }
  Py_DECREF(list);
  return nullptr;
}
// Returns a new reference to Py_True if `o` looks like a namedtuple, and
// Py_False otherwise. "Looks like" means: a tuple (subclass) with a `_fields`
// attribute that is an instance of the registered "Sequence" type and whose
// elements are all strings. With `strict`, `o.__class__.__base__` must also be
// exactly `tuple`. Returns nullptr with a Python exception set on error.
PyObject* IsNamedtuple(PyObject* o, bool strict) {
  // Must be subclass of tuple
  if (!PyTuple_Check(o)) {
    Py_RETURN_FALSE;
  }
  // If strict, o.__class__.__base__ must be tuple
  if (strict) {
    PyObject* klass = PyObject_GetAttrString(o, "__class__");
    if (klass == nullptr) return nullptr;
    PyObject* base = PyObject_GetAttrString(klass, "__base__");
    Py_DECREF(klass);
    if (base == nullptr) return nullptr;
    const PyTypeObject* base_type = reinterpret_cast<PyTypeObject*>(base);
    // built-in object types are singletons
    bool tuple_base = base_type == &PyTuple_Type;
    Py_DECREF(base);
    if (!tuple_base) {
      Py_RETURN_FALSE;
    }
  }
  // o must have attribute '_fields' and every element in
  // '_fields' must be a string.
  int has_fields = PyObject_HasAttrString(o, "_fields");
  if (!has_fields) {
    Py_RETURN_FALSE;
  }
  Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields"));
  // The attribute is looked up a second time here, and that lookup can still
  // fail. Never pass a null object on to the isinstance check.
  if (fields == nullptr) return nullptr;
  int is_instance = IsInstanceOfRegisteredType(fields.get(), "Sequence");
  if (is_instance == 0) {
    Py_RETURN_FALSE;
  } else if (is_instance == -1) {
    return nullptr;
  }
  Safe_PyObjectPtr seq = make_safe(PySequence_Fast(fields.get(), ""));
  // PySequence_Fast fails for objects it cannot iterate, and
  // PySequence_Fast_GET_SIZE must not be given nullptr.
  if (seq == nullptr) return nullptr;
  const Py_ssize_t s = PySequence_Fast_GET_SIZE(seq.get());
  for (Py_ssize_t i = 0; i < s; ++i) {
    // PySequence_Fast_GET_ITEM returns borrowed ref
    PyObject* elem = PySequence_Fast_GET_ITEM(seq.get(), i);
    if (!IsString(elem)) {
      Py_RETURN_FALSE;
    }
  }
  Py_RETURN_TRUE;
}
// Returns a new reference to Py_True if namedtuple-like objects `o1` and `o2`
// have equal `_fields` and the same class name (the type name after its last
// '.'), and to Py_False otherwise. Returns nullptr with a Python exception set
// if either object lacks `_fields` or comparing the fields raises.
PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) {
  Safe_PyObjectPtr f1 = make_safe(PyObject_GetAttrString(o1, "_fields"));
  Safe_PyObjectPtr f2 = make_safe(PyObject_GetAttrString(o2, "_fields"));
  if (f1 == nullptr || f2 == nullptr) {
    PyErr_SetString(
        PyExc_RuntimeError,
        "Expected namedtuple-like objects (that have _fields attr)");
    return nullptr;
  }
  // PyObject_RichCompareBool returns -1 on error. That must not be mistaken
  // for "not equal", which would return False with an exception pending.
  const int fields_differ =
      PyObject_RichCompareBool(f1.get(), f2.get(), Py_NE);
  if (fields_differ == -1) return nullptr;
  if (fields_differ) {
    Py_RETURN_FALSE;
  }
  if (GetClassName(o1).compare(GetClassName(o2)) == 0) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
}
// Checks that `o1` and `o2` have the same nested structure. With
// `expand_composites`, CompositeTensors and TypeSpecs are expanded and compared
// via their TypeSpecs.
//
// Returns a new reference to None if the structures match. Otherwise returns
// nullptr with either the Python error raised while checking, or a
// TypeError/ValueError describing the mismatch.
PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types,
                              bool expand_composites) {
  string error_msg;
  bool is_type_error = false;
  if (expand_composites) {
    AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
                              IsSequenceOrCompositeHelper,
                              GetValueIteratorForComposite,
                              /*check_composite_tensor_type_spec=*/true);
  } else {
    AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
                              IsSequenceHelper, GetValueIterator,
                              /*check_composite_tensor_type_spec=*/false);
  }
  // Don't hide Python exceptions raised while checking (e.g. errors fetching
  // keys from custom mappings).
  if (PyErr_Occurred()) return nullptr;
  if (error_msg.empty()) {
    Py_RETURN_NONE;
  }
  PyObject* const exception_type =
      is_type_error ? PyExc_TypeError : PyExc_ValueError;
  const string message = tensorflow::strings::StrCat(
      "The two structures don't have the same nested structure.\n\n",
      "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
      PyObjectToString(o2), "\n\nMore specifically: ", error_msg);
  PyErr_SetString(exception_type, message.c_str());
  return nullptr;
}
// tf.data flavor of AssertSameStructure(): lists and SparseTensorValues are
// treated as leaves (see IsSequenceForDataHelper), and composites are not
// expanded.
//
// Returns a new reference to None if the structures match. Otherwise returns
// nullptr with either the Python error raised while checking, or a
// TypeError/ValueError describing the mismatch.
PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
                                     bool check_types) {
  string error_msg;
  bool is_type_error = false;
  AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
                            IsSequenceForDataHelper, GetValueIterator,
                            /*check_composite_tensor_type_spec=*/false);
  // Don't hide Python exceptions raised while checking (e.g. errors fetching
  // keys from custom mappings).
  if (PyErr_Occurred()) return nullptr;
  if (error_msg.empty()) {
    Py_RETURN_NONE;
  }
  PyObject* const exception_type =
      is_type_error ? PyExc_TypeError : PyExc_ValueError;
  const string message = tensorflow::strings::StrCat(
      "The two structures don't have the same nested structure.\n\n",
      "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
      PyObjectToString(o2), "\n\nMore specifically: ", error_msg);
  PyErr_SetString(exception_type, message.c_str());
  return nullptr;
}
} // namespace swig
} // namespace tensorflow