Fix segfault if sequence raises an exception when nest.flatten accesses an element.
PiperOrigin-RevId: 269665689
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index 6ec4a5d..000320b 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -55,6 +55,15 @@
return len(self._wrapped)
+class _CustomSequenceThatRaisesException(collections.Sequence):
+
+ def __len__(self):
+ return 1
+
+ def __getitem__(self, item):
+ raise ValueError("Cannot get item: %s" % item)
+
+
class NestTest(parameterized.TestCase, test.TestCase):
PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name
@@ -1209,6 +1218,11 @@
with self.assertRaises(error_type):
nest.map_structure_with_tuple_paths(lambda path, *s: 0, s1, s2)
+ def testFlattenCustomSequenceThatRaisesException(self): # b/140746865
+ seq = _CustomSequenceThatRaisesException()
+ with self.assertRaisesRegexp(ValueError, "Cannot get item"):
+ nest.flatten(seq)
+
class NestBenchmark(test.Benchmark):
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index e8a9e5e..25b5d94 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -407,7 +407,7 @@
public:
explicit SequenceValueIterator(PyObject* iterable)
: seq_(PySequence_Fast(iterable, "")),
- size_(PySequence_Fast_GET_SIZE(seq_.get())),
+ size_(seq_.get() ? PySequence_Fast_GET_SIZE(seq_.get()) : 0),
index_(0) {}
Safe_PyObjectPtr next() override {
@@ -416,8 +416,10 @@
// PySequence_Fast_GET_ITEM returns a borrowed reference.
PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_);
++index_;
- Py_INCREF(elem);
- result.reset(elem);
+ if (elem) {
+ Py_INCREF(elem);
+ result.reset(elem);
+ }
}
return result;