Fix segfault if sequence raises an exception when nest.flatten accesses an element.

PiperOrigin-RevId: 269665689
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index 6ec4a5d..000320b 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -55,6 +55,15 @@
     return len(self._wrapped)
 
 
+class _CustomSequenceThatRaisesException(collections.Sequence):
+
+  def __len__(self):
+    return 1
+
+  def __getitem__(self, item):
+    raise ValueError("Cannot get item: %s" % item)
+
+
 class NestTest(parameterized.TestCase, test.TestCase):
 
   PointXY = collections.namedtuple("Point", ["x", "y"])  # pylint: disable=invalid-name
@@ -1209,6 +1218,11 @@
     with self.assertRaises(error_type):
       nest.map_structure_with_tuple_paths(lambda path, *s: 0, s1, s2)
 
+  def testFlattenCustomSequenceThatRaisesException(self):  # b/140746865
+    seq = _CustomSequenceThatRaisesException()
+    with self.assertRaisesRegexp(ValueError, "Cannot get item"):
+      nest.flatten(seq)
+
 
 class NestBenchmark(test.Benchmark):
 
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index e8a9e5e..25b5d94 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -407,7 +407,7 @@
  public:
   explicit SequenceValueIterator(PyObject* iterable)
       : seq_(PySequence_Fast(iterable, "")),
-        size_(PySequence_Fast_GET_SIZE(seq_.get())),
+        size_(seq_.get() ? PySequence_Fast_GET_SIZE(seq_.get()) : 0),
         index_(0) {}
 
   Safe_PyObjectPtr next() override {
@@ -416,8 +416,10 @@
       // PySequence_Fast_GET_ITEM returns a borrowed reference.
       PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_);
       ++index_;
-      Py_INCREF(elem);
-      result.reset(elem);
+      if (elem) {
+        Py_INCREF(elem);
+        result.reset(elem);
+      }
     }
 
     return result;