# blob: 000320b00580a52ff44f92aa53c066eb23c381ed [file] [log] [blame]
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for utilities working with arbitrarily nested structures."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import time
from absl.testing import parameterized
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
try:
import attr # pylint:disable=g-import-not-at-top
except ImportError:
attr = None
class _CustomMapping(collections_abc.Mapping):
def __init__(self, *args, **kwargs):
self._wrapped = dict(*args, **kwargs)
def __getitem__(self, key):
return self._wrapped[key]
def __iter__(self):
return iter(self._wrapped)
def __len__(self):
return len(self._wrapped)
class _CustomSequenceThatRaisesException(collections.Sequence):
def __len__(self):
return 1
def __getitem__(self, item):
raise ValueError("Cannot get item: %s" % item)
class NestTest(parameterized.TestCase, test.TestCase):
# Namedtuple fixture with fields `x` and `y`; its type name is "Point".
PointXY = collections.namedtuple("Point", ["x", "y"])  # pylint: disable=invalid-name

# The attrs-based fixtures are only defined when the optional `attr` import
# succeeded (`attr` is set to None when the import fails).
if attr:
  class BadAttr(object):
    """Class that has a non-iterable __attrs_attrs__."""
    __attrs_attrs__ = None

  @attr.s
  class SampleAttr(object):
    # Two attrs fields, declared in name order.
    field1 = attr.ib()
    field2 = attr.ib()

  @attr.s
  class UnsortedSampleAttr(object):
    # Fields are declared in non-alphabetical order (field3 comes first).
    field3 = attr.ib()
    field1 = attr.ib()
    field2 = attr.ib()
@test_util.assert_no_new_pyobjects_executing_eagerly
def testAttrsFlattenAndPack(self):
  """Round-trips an attrs instance through `flatten`/`pack_sequence_as`."""
  if attr is None:
    self.skipTest("attr module is unavailable.")
  field_values = [1, 2]
  sample_attr = NestTest.SampleAttr(*field_values)
  # A plain list is not an attrs object; the decorated instance is.
  self.assertFalse(nest._is_attrs(field_values))
  self.assertTrue(nest._is_attrs(sample_attr))
  flat = nest.flatten(sample_attr)
  self.assertEqual(field_values, flat)
  # Packing must rebuild the same attrs type with equal field values.
  restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
  self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
  self.assertEqual(restructured_from_flat, sample_attr)
  # Check that flatten fails if attributes are not iterable
  with self.assertRaisesRegexp(TypeError, "object is not iterable"):
    flat = nest.flatten(NestTest.BadAttr())
@parameterized.parameters(
    {"values": [1, 2, 3]},
    {"values": [{"B": 10, "A": 20}, [1, 2], 3]},
    {"values": [(1, 2), [3, 4], 5]},
    {"values": [PointXY(1, 2), 3, 4]},
)
@test_util.assert_no_new_pyobjects_executing_eagerly
def testAttrsMapStructure(self, values):
  """Identity `map_structure` over an attrs instance yields an equal one.

  Args:
    values: Three field values (possibly nested) used to build an
      `UnsortedSampleAttr`.
  """
  if attr is None:
    self.skipTest("attr module is unavailable.")
  structure = NestTest.UnsortedSampleAttr(*values)
  new_structure = nest.map_structure(lambda x: x, structure)
  self.assertEqual(structure, new_structure)
@test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenAndPack(self):
  """Exercises `flatten`/`pack_sequence_as` on tuples, namedtuples, scalars."""
  # Nested tuples: leaves come out in depth-first order.
  structure = ((3, 4), 5, (6, 7, (9, 10), 8))
  flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
  self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
  self.assertEqual(
      nest.pack_sequence_as(structure, flat), (("a", "b"), "c",
                                               ("d", "e", ("f", "g"), "h")))
  # Namedtuples nested inside tuples keep their type and field access.
  structure = (NestTest.PointXY(x=4, y=2),
               ((NestTest.PointXY(x=1, y=0),),))
  flat = [4, 2, 1, 0]
  self.assertEqual(nest.flatten(structure), flat)
  restructured_from_flat = nest.pack_sequence_as(structure, flat)
  self.assertEqual(restructured_from_flat, structure)
  self.assertEqual(restructured_from_flat[0].x, 4)
  self.assertEqual(restructured_from_flat[0].y, 2)
  self.assertEqual(restructured_from_flat[1][0][0].x, 1)
  self.assertEqual(restructured_from_flat[1][0][0].y, 0)

  # Scalars (including numpy arrays and strings) are single leaves.
  self.assertEqual([5], nest.flatten(5))
  self.assertEqual([np.array([5])], nest.flatten(np.array([5])))

  self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
  self.assertEqual(
      np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))

  # Error cases: too many leaves for a scalar, a non-sequence flat_sequence,
  # and a leaf count that does not match the structure.
  with self.assertRaisesRegexp(ValueError, "Structure is a scalar"):
    nest.pack_sequence_as("scalar", [4, 5])

  with self.assertRaisesRegexp(TypeError, "flat_sequence"):
    nest.pack_sequence_as([4, 5], "bad_sequence")

  with self.assertRaises(ValueError):
    nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
@parameterized.parameters({"mapping_type": collections.OrderedDict},
                          {"mapping_type": _CustomMapping})
@test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenDictOrder(self, mapping_type):
  """`flatten` orders dicts by key, including OrderedDicts."""
  # Keys are inserted out of order; values are chosen so that sorting by key
  # yields [0, 1, 2, 3].
  ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
  plain = {"d": 3, "b": 1, "a": 0, "c": 2}
  ordered_flat = nest.flatten(ordered)
  plain_flat = nest.flatten(plain)
  self.assertEqual([0, 1, 2, 3], ordered_flat)
  self.assertEqual([0, 1, 2, 3], plain_flat)
@parameterized.parameters({"mapping_type": collections.OrderedDict},
                          {"mapping_type": _CustomMapping})
def testPackDictOrder(self, mapping_type):
  """`pack_sequence_as` assigns leaves to dict keys in sorted-key order.

  Args:
    mapping_type: Mapping class under test; the packed result must be an
      instance of it.
  """
  custom_template = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
  plain_template = {"d": 0, "b": 0, "a": 0, "c": 0}
  flat_values = [0, 1, 2, 3]

  repacked_custom = nest.pack_sequence_as(custom_template, flat_values)
  repacked_plain = nest.pack_sequence_as(plain_template, flat_values)

  # The container types are preserved.
  self.assertIsInstance(repacked_custom, mapping_type)
  self.assertIsInstance(repacked_plain, dict)

  # Values follow sorted key order: a -> 0, b -> 1, c -> 2, d -> 3.
  expected_custom = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
  self.assertEqual(expected_custom, repacked_custom)
  expected_plain = {"d": 3, "b": 1, "a": 0, "c": 2}
  self.assertEqual(expected_plain, repacked_plain)
@test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenAndPackMappingViews(self):
  """`flatten`/`pack_sequence_as` handle dict views, keeping the view order.

  Unlike flattening the dict itself, the expected values below follow the
  insertion order of the OrderedDict's views rather than sorted-key order.
  """
  ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)])

  # test flattening
  ordered_keys_flat = nest.flatten(ordered.keys())
  ordered_values_flat = nest.flatten(ordered.values())
  ordered_items_flat = nest.flatten(ordered.items())
  self.assertEqual([3, 1, 0, 2], ordered_values_flat)
  self.assertEqual(["d", "b", "a", "c"], ordered_keys_flat)
  self.assertEqual(["d", 3, "b", 1, "a", 0, "c", 2], ordered_items_flat)

  # test packing
  self.assertEqual([("d", 3), ("b", 1), ("a", 0), ("c", 2)],
                   nest.pack_sequence_as(ordered.items(), ordered_items_flat))
# Namedtuple fixture with fields `b` and `c`; its type name is "A".
Abc = collections.namedtuple("A", ("b", "c"))  # pylint: disable=invalid-name
@test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenAndPack_withDicts(self):
  """Flattens and re-packs a structure mixing lists, tuples and mappings."""
  # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
  mess = [
      "z",
      NestTest.Abc(3, 4), {
          "d": _CustomMapping({
              41: 4
          }),
          "c": [
              1,
              collections.OrderedDict([
                  ("b", 3),
                  ("a", 2),
              ]),
          ],
          "b": 5
      }, 17
  ]

  flattened = nest.flatten(mess)
  # Dict entries are visited by sorted key ("b", "c", "d"), and the inner
  # OrderedDict likewise ("a" before "b").
  self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17])

  # Same shape as `mess` but with different leaf values.
  structure_of_mess = [
      14,
      NestTest.Abc("a", True),
      {
          "d": _CustomMapping({
              41: 42
          }),
          "c": [
              0,
              collections.OrderedDict([
                  ("b", 9),
                  ("a", 8),
              ]),
          ],
          "b": 3
      },
      "hi everybody",
  ]

  unflattened = nest.pack_sequence_as(structure_of_mess, flattened)
  self.assertEqual(unflattened, mess)

  # Check also that the OrderedDict was created, with the correct key order.
  unflattened_ordered_dict = unflattened[2]["c"][1]
  self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
  self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])

  # The custom mapping type must survive the round trip as well.
  unflattened_custom_mapping = unflattened[2]["d"]
  self.assertIsInstance(unflattened_custom_mapping, _CustomMapping)
  self.assertEqual(list(unflattened_custom_mapping.keys()), [41])
def testFlatten_numpyIsNotFlattened(self):
  """A numpy array is treated as one leaf, not as a sequence of elements."""
  leaves = nest.flatten(np.array([1, 2, 3]))
  self.assertLen(leaves, 1)
def testFlatten_stringIsNotFlattened(self):
  """A string is one leaf: flatten wraps it and packing returns it intact."""
  text = "lots of letters"
  leaves = nest.flatten(text)
  self.assertLen(leaves, 1)
  # Packing into another string "structure" hands back the original leaf.
  repacked = nest.pack_sequence_as("goodbye", leaves)
  self.assertEqual(text, repacked)
def testPackSequenceAs_notIterableError(self):
  """Passing a string as `flat_sequence` is rejected with a TypeError."""
  expected_error = "flat_sequence must be a sequence"
  with self.assertRaisesRegexp(TypeError, expected_error):
    nest.pack_sequence_as("hi", "bye")
def testPackSequenceAs_wrongLengthsError(self):
  """A leaf-count mismatch between structure and flat_sequence is an error."""
  expected_error = (
      "Structure had 2 elements, but flat_sequence had 3 elements.")
  two_slot_structure = ["hello", "world"]
  three_leaves = ["and", "goodbye", "again"]
  with self.assertRaisesRegexp(ValueError, expected_error):
    nest.pack_sequence_as(two_slot_structure, three_leaves)
@test_util.assert_no_new_pyobjects_executing_eagerly
def testIsNested(self):
  """Checks which Python/TF/numpy values `is_nested` reports as nested."""
  # Strings are not nested; lists, tuples, dicts and dict views are.
  self.assertFalse(nest.is_nested("1234"))
  self.assertTrue(nest.is_nested([1, 3, [4, 5]]))
  self.assertTrue(nest.is_nested(((7, 8), (5, 6))))
  self.assertTrue(nest.is_nested([]))
  self.assertTrue(nest.is_nested({"a": 1, "b": 2}))
  self.assertTrue(nest.is_nested({"a": 1, "b": 2}.keys()))
  self.assertTrue(nest.is_nested({"a": 1, "b": 2}.values()))
  self.assertTrue(nest.is_nested({"a": 1, "b": 2}.items()))
  # Sets, tensors and numpy arrays are not nested.
  self.assertFalse(nest.is_nested(set([1, 2])))
  ones = array_ops.ones([2, 3])
  self.assertFalse(nest.is_nested(ones))
  self.assertFalse(nest.is_nested(math_ops.tanh(ones)))
  self.assertFalse(nest.is_nested(np.ones((4, 5))))
@parameterized.parameters({"mapping_type": _CustomMapping},
                          {"mapping_type": dict})
def testFlattenDictItems(self, mapping_type):
  """`flatten_dict_items` pairs up nested keys with nested values.

  Args:
    mapping_type: Mapping class used to build the input dictionaries.
  """
  nested_key = (4, 5, (6, 8))

  # Happy path: key and value nest identically.
  matching_items = mapping_type({nested_key: ("a", "b", ("c", "d"))})
  expected_flat = {4: "a", 5: "b", 6: "c", 8: "d"}
  self.assertEqual(nest.flatten_dict_items(matching_items), expected_flat)

  # A non-dictionary argument is rejected.
  with self.assertRaises(TypeError):
    nest.flatten_dict_items(4)

  # The flattened key 4 appears twice.
  duplicate_key_items = mapping_type(
      {(4, 5, (4, 8)): ("a", "b", ("c", "d"))})
  with self.assertRaisesRegexp(ValueError, "not unique"):
    nest.flatten_dict_items(duplicate_key_items)

  # The value nests deeper than the key, so the element counts differ.
  mismatched_items = mapping_type({
      nested_key: ("a", "b", ("c", ("d", "e")))
  })
  length_error = "Key had [0-9]* elements, but value had [0-9]* elements"
  with self.assertRaisesRegexp(ValueError, length_error):
    nest.flatten_dict_items(mismatched_items)
# Namedtuple fixtures for the structure-comparison tests. They vary the type
# name and the field names independently of each other.
# pylint does not correctly recognize these as class names and
# suggests to use variable style under_score naming.
# pylint: disable=invalid-name
Named0ab = collections.namedtuple("named_0", ("a", "b"))
Named1ab = collections.namedtuple("named_1", ("a", "b"))
# Same type name and same fields, but two distinct classes.
SameNameab = collections.namedtuple("same_name", ("a", "b"))
SameNameab2 = collections.namedtuple("same_name", ("a", "b"))
# Same type name as above, different fields.
SameNamexy = collections.namedtuple("same_name", ("x", "y"))
SameName1xy = collections.namedtuple("same_name_1", ("x", "y"))
SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y"))
# Same fields as `SameNameab`, different type name.
NotSameName = collections.namedtuple("not_same_name", ("a", "b"))
# pylint: enable=invalid-name

class SameNamedType1(SameNameab):
  # Subclass of `SameNameab` that adds nothing of its own.
  pass
@test_util.assert_no_new_pyobjects_executing_eagerly
def testAssertSameStructure(self):
  """Covers passing and failing cases of `nest.assert_same_structure`."""
  structure1 = (((1, 2), 3), 4, (5, 6))
  structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
  structure_different_num_elements = ("spam", "eggs")
  structure_different_nesting = (((1, 2), 3), 4, 5, (6,))
  # Same nesting with different leaf values passes; so do pairs of scalars.
  nest.assert_same_structure(structure1, structure2)
  nest.assert_same_structure("abc", 1.0)
  nest.assert_same_structure("abc", np.array([0, 1]))
  nest.assert_same_structure("abc", constant_op.constant([0, 1]))

  # Mismatched nesting: the full error message is checked, including the
  # summaries of both structures.
  with self.assertRaisesRegexp(
      ValueError,
      ("The two structures don't have the same nested structure\\.\n\n"
       "First structure:.*?\n\n"
       "Second structure:.*\n\n"
       "More specifically: Substructure "
       r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while '
       'substructure "type=str str=spam" is not\n'
       "Entire first structure:\n"
       r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n"
       "Entire second structure:\n"
       r"\(\., \.\)")):
    nest.assert_same_structure(structure1, structure_different_num_elements)

  # A list compared against a numpy array (a leaf).
  with self.assertRaisesRegexp(
      ValueError,
      ("The two structures don't have the same nested structure\\.\n\n"
       "First structure:.*?\n\n"
       "Second structure:.*\n\n"
       r'More specifically: Substructure "type=list str=\[0, 1\]" '
       r'is a sequence, while substructure "type=ndarray str=\[0 1\]" '
       "is not")):
    nest.assert_same_structure([0, 1], np.array([0, 1]))

  # A scalar compared against a list.
  with self.assertRaisesRegexp(
      ValueError,
      ("The two structures don't have the same nested structure\\.\n\n"
       "First structure:.*?\n\n"
       "Second structure:.*\n\n"
       r'More specifically: Substructure "type=list str=\[0, 1\]" '
       'is a sequence, while substructure "type=int str=0" '
       "is not")):
    nest.assert_same_structure(0, [0, 1])

  # Tuple vs. list is a type mismatch (TypeError), not a nesting mismatch.
  self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1])

  with self.assertRaisesRegexp(
      ValueError,
      ("don't have the same nested structure\\.\n\n"
       "First structure: .*?\n\nSecond structure: ")):
    nest.assert_same_structure(structure1, structure_different_nesting)

  # Plain tuple vs. namedtuple, and namedtuples with different type names.
  self.assertRaises(TypeError, nest.assert_same_structure, (0, 1),
                    NestTest.Named0ab("a", "b"))

  nest.assert_same_structure(NestTest.Named0ab(3, 4),
                             NestTest.Named0ab("a", "b"))

  self.assertRaises(TypeError, nest.assert_same_structure,
                    NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4))

  with self.assertRaisesRegexp(
      ValueError,
      ("don't have the same nested structure\\.\n\n"
       "First structure: .*?\n\nSecond structure: ")):
    nest.assert_same_structure(NestTest.Named0ab(3, 4),
                               NestTest.Named0ab([3], 4))

  with self.assertRaisesRegexp(
      ValueError,
      ("don't have the same nested structure\\.\n\n"
       "First structure: .*?\n\nSecond structure: ")):
    nest.assert_same_structure([[3], 4], [3, [4]])

  # Sequence-type mismatches are only reported when check_types is on.
  structure1_list = [[[1, 2], 3], 4, [5, 6]]
  with self.assertRaisesRegexp(TypeError,
                               "don't have the same sequence type"):
    nest.assert_same_structure(structure1, structure1_list)
  nest.assert_same_structure(structure1, structure2, check_types=False)
  nest.assert_same_structure(structure1, structure1_list, check_types=False)

  with self.assertRaisesRegexp(ValueError,
                               "don't have the same set of keys"):
    nest.assert_same_structure({"a": 1}, {"b": 1})

  nest.assert_same_structure(NestTest.SameNameab(0, 1),
                             NestTest.SameNameab2(2, 3))

  # This assertion is expected to pass: two namedtuples with the same
  # name and field names are considered to be identical.
  nest.assert_same_structure(
      NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2),
      NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4))

  # Equivalent namedtuple types, but nested in different positions.
  expected_message = "The two structures don't have the same.*"
  with self.assertRaisesRegexp(ValueError, expected_message):
    nest.assert_same_structure(
        NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)),
        NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2))

  # Different type name, different field names, and a subclass are all
  # reported as type mismatches.
  self.assertRaises(TypeError, nest.assert_same_structure,
                    NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3))

  self.assertRaises(TypeError, nest.assert_same_structure,
                    NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3))

  self.assertRaises(TypeError, nest.assert_same_structure,
                    NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3))
# Namedtuple fixture with no fields; its type name is "empty_nt".
EmptyNT = collections.namedtuple("empty_nt", "")  # pylint: disable=invalid-name
def testHeterogeneousComparison(self):
  """A dict and a `_CustomMapping` with equal keys match, in either order."""
  structure_pairs = (
      ({"a": 4}, _CustomMapping(a=3)),
      (_CustomMapping(b=3), {"b": 4}),
  )
  for first, second in structure_pairs:
    nest.assert_same_structure(first, second)
@test_util.assert_no_new_pyobjects_executing_eagerly
def testMapStructure(self):
  """Covers `nest.map_structure` results and its argument validation."""
  structure1 = (((1, 2), 3), 4, (5, 6))
  structure2 = (((7, 8), 9), 10, (11, 12))
  # Single structure: the output keeps the nesting of the input.
  structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
  nest.assert_same_structure(structure1, structure1_plus1)
  self.assertAllEqual(
      [2, 3, 4, 5, 6, 7],
      nest.flatten(structure1_plus1))
  # Two structures: the function receives corresponding leaves.
  structure1_plus_structure2 = nest.map_structure(
      lambda x, y: x + y, structure1, structure2)
  self.assertEqual(
      (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
      structure1_plus_structure2)

  # Scalars are mapped directly.
  self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))

  self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4))

  # defaultdict input.
  structure3 = collections.defaultdict(list)
  structure3["a"] = [1, 2, 3, 4]
  structure3["b"] = [2, 3, 4, 5]

  expected_structure3 = collections.defaultdict(list)
  expected_structure3["a"] = [2, 3, 4, 5]
  expected_structure3["b"] = [3, 4, 5, 6]
  self.assertEqual(expected_structure3,
                   nest.map_structure(lambda x: x + 1, structure3))

  # Empty structures
  self.assertEqual((), nest.map_structure(lambda x: x + 1, ()))
  self.assertEqual([], nest.map_structure(lambda x: x + 1, []))
  self.assertEqual({}, nest.map_structure(lambda x: x + 1, {}))
  self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1,
                                                          NestTest.EmptyNT()))

  # This is checking actual equality of types, empty list != empty tuple
  self.assertNotEqual((), nest.map_structure(lambda x: x + 1, []))

  # Argument validation: non-callable func, missing structures, and
  # structures that disagree in length, nesting or sequence type.
  with self.assertRaisesRegexp(TypeError, "callable"):
    nest.map_structure("bad", structure1_plus1)

  with self.assertRaisesRegexp(ValueError, "at least one structure"):
    nest.map_structure(lambda x: x)

  with self.assertRaisesRegexp(ValueError, "same number of elements"):
    nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5))

  with self.assertRaisesRegexp(ValueError, "same nested structure"):
    nest.map_structure(lambda x, y: None, 3, (3,))

  with self.assertRaisesRegexp(TypeError, "same sequence type"):
    nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])

  with self.assertRaisesRegexp(ValueError, "same nested structure"):
    nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))

  structure1_list = [[[1, 2], 3], 4, [5, 6]]
  with self.assertRaisesRegexp(TypeError, "same sequence type"):
    nest.map_structure(lambda x, y: None, structure1, structure1_list)

  # check_types=False tolerates tuple-vs-list, but not a nesting mismatch.
  nest.map_structure(lambda x, y: None, structure1, structure1_list,
                     check_types=False)

  with self.assertRaisesRegexp(ValueError, "same nested structure"):
    nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)),
                       check_types=False)

  # Unknown keyword arguments are rejected, with or without check_types.
  with self.assertRaisesRegexp(ValueError,
                               "Only valid keyword argument.*foo"):
    nest.map_structure(lambda x: None, structure1, foo="a")

  with self.assertRaisesRegexp(ValueError,
                               "Only valid keyword argument.*foo"):
    nest.map_structure(lambda x: None, structure1, check_types=False, foo="a")
# Namedtuple fixture with fields `a` and `b`; its type name is "ab_tuple".
ABTuple = collections.namedtuple("ab_tuple", "a, b")  # pylint: disable=invalid-name
@test_util.assert_no_new_pyobjects_executing_eagerly
def testMapStructureWithStrings(self):
  """`map_structure` passes string leaves whole instead of iterating them."""
  inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz"))
  inp_b = NestTest.ABTuple(a=2, b=(1, 3))
  # Each string in `inp_a` is repeated by the matching count in `inp_b`.
  out = nest.map_structure(lambda string, repeats: string * repeats,
                           inp_a,
                           inp_b)
  self.assertEqual("foofoo", out.a)
  self.assertEqual("bar", out.b[0])
  self.assertEqual("bazbazbaz", out.b[1])

  nt = NestTest.ABTuple(a=("something", "something_else"),
                        b="yet another thing")
  rev_nt = nest.map_structure(lambda x: x[::-1], nt)
  # Check the output is the correct structure, and all strings are reversed.
  nest.assert_same_structure(nt, rev_nt)
  self.assertEqual(nt.a[0][::-1], rev_nt.a[0])
  self.assertEqual(nt.a[1][::-1], rev_nt.a[1])
  self.assertEqual(nt.b[::-1], rev_nt.b)
@test_util.run_deprecated_v1
def testMapStructureOverPlaceholders(self):
  """Maps an elementwise add over two tuples of placeholders and runs it."""
  inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
           array_ops.placeholder(dtypes.float32, shape=[3, 7]))
  inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]),
           array_ops.placeholder(dtypes.float32, shape=[3, 7]))

  output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b)

  # The result mirrors the input nesting and the per-element static shapes.
  nest.assert_same_structure(output, inp_a)
  self.assertShapeEqual(np.zeros((3, 4)), output[0])
  self.assertShapeEqual(np.zeros((3, 7)), output[1])

  # The feed dict is keyed by the placeholder tuples themselves, each mapped
  # to a tuple of arrays with matching shapes.
  feed_dict = {
      inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)),
      inp_b: (np.random.randn(3, 4), np.random.randn(3, 7))
  }

  with self.cached_session() as sess:
    output_np = sess.run(output, feed_dict=feed_dict)
  self.assertAllClose(output_np[0],
                      feed_dict[inp_a][0] + feed_dict[inp_b][0])
  self.assertAllClose(output_np[1],
                      feed_dict[inp_a][1] + feed_dict[inp_b][1])
def testAssertShallowStructure(self):
  """Covers length, type and key checks of `assert_shallow_structure`."""
  # Length mismatch: the shallow tree has three entries, the input two.
  two_items = ["a", "b"]
  three_items = ["a", "b", "c"]
  length_error = nest._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
      input_length=len(two_items),
      shallow_length=len(three_items))
  with self.assertRaisesWithLiteralMatch(  # pylint: disable=g-error-prone-assert-raises
      ValueError, length_error):
    nest.assert_shallow_structure(three_items, two_items)

  # Type mismatch: list entries in the shallow tree, tuples in the input.
  tuple_entries = [(1, 1), (2, 2)]
  list_entries = [[1, 1], [2, 2]]
  type_error = nest._STRUCTURES_HAVE_MISMATCHING_TYPES.format(
      shallow_type=type(list_entries[0]),
      input_type=type(tuple_entries[0]))
  with self.assertRaisesWithLiteralMatch(TypeError, type_error):
    nest.assert_shallow_structure(list_entries, tuple_entries)
  # The same pair is accepted once type checking is switched off.
  nest.assert_shallow_structure(list_entries, tuple_entries,
                                check_types=False)

  # Key mismatch: the shallow tree holds key "d", which the input lacks.
  input_dict = {"a": (1, 1), "b": {"c": (2, 2)}}
  shallow_dict = {"a": (1, 1), "b": {"d": (2, 2)}}
  key_error = nest._SHALLOW_TREE_HAS_INVALID_KEYS.format(["d"])
  with self.assertRaisesWithLiteralMatch(ValueError, key_error):
    nest.assert_shallow_structure(shallow_dict, input_dict)

  # OrderedDicts with the same keys in a different order are accepted.
  ordered_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))])
  ordered_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)])
  nest.assert_shallow_structure(ordered_ab, ordered_ba)

  # Two namedtuple classes sharing a type name and field names are treated
  # as the same type, so both calls below are expected to pass.
  shallow_namedtuple = NestTest.SameNameab(1, 2)
  deep_namedtuple = NestTest.SameNameab2(1, [1, 2, 3])
  for check_types in (False, True):
    nest.assert_shallow_structure(shallow_namedtuple, deep_namedtuple,
                                  check_types=check_types)
def testFlattenUpTo(self):
  """Covers `nest.flatten_up_to` on nested inputs, scalars and bad inputs.

  `input_tree` and `shallow_tree` are rebound for each scenario, so the
  statements must be read in order.
  """
  # Shallow tree ends at scalar.
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]
  flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
  self.assertEqual(flattened_shallow_tree, [True, True, False, True])

  # Shallow tree ends at string.
  input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
  shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                            input_tree)
  input_tree_flattened = nest.flatten(input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
  self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])

  # Make sure dicts are correctly flattened, yielding values, not keys.
  input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
  shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                            input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [1, {"c": 2}, 3, (4, 5)])

  # Namedtuples.
  ab_tuple = NestTest.ABTuple
  input_tree = ab_tuple(a=[0, 1], b=2)
  shallow_tree = ab_tuple(a=0, b=1)
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                            input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [[0, 1], 2])

  # Nested dicts, OrderedDicts and namedtuples.
  input_tree = collections.OrderedDict(
      [("a", ab_tuple(a=[0, {"b": 1}], b=2)),
       ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
  # Using the input itself as the shallow tree flattens it completely.
  shallow_tree = input_tree
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                            input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
  # Progressively shallower trees leave larger substructures intact.
  shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                            input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [ab_tuple(a=[0, {"b": 1}], b=2),
                    3,
                    collections.OrderedDict([("f", 4)])])
  shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
  input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
                                                            input_tree)
  self.assertEqual(input_tree_flattened_as_shallow_tree,
                   [ab_tuple(a=[0, {"b": 1}], b=2),
                    {"d": 3, "e": collections.OrderedDict([("f", 4)])}])

  ## Shallow non-list edge-case.
  # Using iterable elements.
  input_tree = ["input_tree"]
  shallow_tree = "shallow_tree"
  flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  self.assertEqual(flattened_input_tree, [input_tree])
  self.assertEqual(flattened_shallow_tree, [shallow_tree])

  input_tree = ["input_tree_0", "input_tree_1"]
  shallow_tree = "shallow_tree"
  flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  self.assertEqual(flattened_input_tree, [input_tree])
  self.assertEqual(flattened_shallow_tree, [shallow_tree])

  # Using non-iterable elements.
  input_tree = [0]
  shallow_tree = 9
  flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  self.assertEqual(flattened_input_tree, [input_tree])
  self.assertEqual(flattened_shallow_tree, [shallow_tree])

  input_tree = [0, 1]
  shallow_tree = 9
  flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  self.assertEqual(flattened_input_tree, [input_tree])
  self.assertEqual(flattened_shallow_tree, [shallow_tree])

  ## Both non-list edge-case.
  # Using iterable elements.
  input_tree = "input_tree"
  shallow_tree = "shallow_tree"
  flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  self.assertEqual(flattened_input_tree, [input_tree])
  self.assertEqual(flattened_shallow_tree, [shallow_tree])

  # Using non-iterable elements.
  input_tree = 0
  shallow_tree = 0
  flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  self.assertEqual(flattened_input_tree, [input_tree])
  self.assertEqual(flattened_shallow_tree, [shallow_tree])

  ## Input non-list edge-case.
  # Using iterable elements.
  input_tree = "input_tree"
  shallow_tree = ["shallow_tree"]
  # The pattern accepts both "<type 'str'>" and "<class 'str'>" spellings.
  expected_message = ("If shallow structure is a sequence, input must also "
                      "be a sequence. Input has type: <(type|class) 'str'>.")
  with self.assertRaisesRegexp(TypeError, expected_message):
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  self.assertEqual(flattened_shallow_tree, shallow_tree)

  input_tree = "input_tree"
  shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
  with self.assertRaisesRegexp(TypeError, expected_message):
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  self.assertEqual(flattened_shallow_tree, shallow_tree)

  # Using non-iterable elements.
  input_tree = 0
  shallow_tree = [9]
  expected_message = ("If shallow structure is a sequence, input must also "
                      "be a sequence. Input has type: <(type|class) 'int'>.")
  with self.assertRaisesRegexp(TypeError, expected_message):
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  self.assertEqual(flattened_shallow_tree, shallow_tree)

  input_tree = 0
  shallow_tree = [9, 8]
  with self.assertRaisesRegexp(TypeError, expected_message):
    flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
  self.assertEqual(flattened_shallow_tree, shallow_tree)

  # Length mismatch between shallow tree and input, checked through
  # `assert_shallow_structure` directly.
  input_tree = [(1,), (2,), 3]
  shallow_tree = [(1,), (2,)]
  expected_message = nest._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
      input_length=len(input_tree), shallow_length=len(shallow_tree))
  with self.assertRaisesRegexp(ValueError, expected_message):  # pylint: disable=g-error-prone-assert-raises
    nest.assert_shallow_structure(shallow_tree, input_tree)
def testFlattenWithTuplePathsUpTo(self):
  """Exercises `nest.flatten_with_tuple_paths_up_to` on many structures.

  Covers lists, dicts, namedtuples and OrderedDicts, followed by the edge
  cases where the shallow tree, the input tree, or both are not sequences.
  """

  def get_paths_and_values(shallow_tree, input_tree,
                           check_subtrees_length=True):
    """Returns the flattened paths and values as two parallel lists."""
    paths, values = [], []
    for path, value in nest.flatten_with_tuple_paths_up_to(
        shallow_tree, input_tree,
        check_subtrees_length=check_subtrees_length):
      paths.append(path)
      values.append(value)
    return paths, values

  def check_flattens_to_single_leaf(shallow_tree, input_tree):
    """Asserts that both trees flatten to one leaf with the empty path."""
    input_paths, input_values = get_paths_and_values(shallow_tree, input_tree)
    shallow_paths, shallow_values = get_paths_and_values(
        shallow_tree, shallow_tree)
    self.assertEqual(input_paths, [()])
    self.assertEqual(input_values, [input_tree])
    self.assertEqual(shallow_paths, [()])
    self.assertEqual(shallow_values, [shallow_tree])

  def check_rejects_non_sequence_input(shallow_tree, input_tree,
                                       expected_shallow_paths):
    """Asserts a TypeError for `input_tree`, then flattens `shallow_tree`."""
    with self.assertRaisesWithLiteralMatch(
        TypeError,
        nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))):
      get_paths_and_values(shallow_tree, input_tree)
    shallow_paths, shallow_values = get_paths_and_values(
        shallow_tree, shallow_tree)
    self.assertEqual(shallow_paths, expected_shallow_paths)
    self.assertEqual(shallow_values, shallow_tree)

  # Shallow tree ends at scalar.
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]
  input_paths, input_values = get_paths_and_values(shallow_tree, input_tree)
  shallow_paths, shallow_values = get_paths_and_values(
      shallow_tree, shallow_tree)
  self.assertEqual(input_paths, [(0, 0), (0, 1), (1, 0), (1, 1)])
  self.assertEqual(input_values, [[2, 2], [3, 3], [4, 9], [5, 5]])
  self.assertEqual(shallow_paths, [(0, 0), (0, 1), (1, 0), (1, 1)])
  self.assertEqual(shallow_values, [True, True, False, True])

  # Shallow tree ends at string.
  input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
  shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
  up_to_paths, up_to_values = get_paths_and_values(shallow_tree, input_tree)
  full_paths = [
      path for path, _ in nest.flatten_with_tuple_paths(input_tree)]
  full_values = nest.flatten(input_tree)
  self.assertEqual(up_to_paths,
                   [(0, 0), (0, 1, 0), (0, 1, 1, 0), (0, 1, 1, 1, 0)])
  self.assertEqual(up_to_values, [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
  self.assertEqual(full_paths,
                   [(0, 0, 0), (0, 0, 1),
                    (0, 1, 0, 0), (0, 1, 0, 1),
                    (0, 1, 1, 0, 0), (0, 1, 1, 0, 1),
                    (0, 1, 1, 1, 0, 0), (0, 1, 1, 1, 0, 1)])
  self.assertEqual(full_values, ["a", 1, "b", 2, "c", 3, "d", 4])

  # Make sure dicts are correctly flattened, yielding values, not keys.
  input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
  shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
  up_to_paths, up_to_values = get_paths_and_values(shallow_tree, input_tree)
  self.assertEqual(up_to_paths, [("a",), ("b",), ("d", 0), ("d", 1)])
  self.assertEqual(up_to_values, [1, {"c": 2}, 3, (4, 5)])

  # Namedtuples.
  ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  input_tree = ab_tuple(a=[0, 1], b=2)
  shallow_tree = ab_tuple(a=0, b=1)
  up_to_paths, up_to_values = get_paths_and_values(shallow_tree, input_tree)
  self.assertEqual(up_to_paths, [("a",), ("b",)])
  self.assertEqual(up_to_values, [[0, 1], 2])

  # Nested dicts, OrderedDicts and namedtuples.
  input_tree = collections.OrderedDict([
      ("a", ab_tuple(a=[0, {"b": 1}], b=2)),
      ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])}),
  ])
  # Using the input itself as the shallow tree flattens it completely.
  shallow_tree = input_tree
  up_to_paths, up_to_values = get_paths_and_values(shallow_tree, input_tree)
  self.assertEqual(up_to_paths,
                   [("a", "a", 0),
                    ("a", "a", 1, "b"),
                    ("a", "b"),
                    ("c", "d"),
                    ("c", "e", "f")])
  self.assertEqual(up_to_values, [0, 1, 2, 3, 4])

  shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
  up_to_paths, up_to_values = get_paths_and_values(shallow_tree, input_tree)
  self.assertEqual(up_to_paths, [("a",), ("c", "d"), ("c", "e")])
  self.assertEqual(up_to_values,
                   [ab_tuple(a=[0, {"b": 1}], b=2),
                    3,
                    collections.OrderedDict([("f", 4)])])

  shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
  up_to_paths, up_to_values = get_paths_and_values(shallow_tree, input_tree)
  self.assertEqual(up_to_paths, [("a",), ("c",)])
  self.assertEqual(up_to_values,
                   [ab_tuple(a=[0, {"b": 1}], b=2),
                    {"d": 3, "e": collections.OrderedDict([("f", 4)])}])

  ## Shallow non-list edge-case.
  # Using iterable elements.
  check_flattens_to_single_leaf("shallow_tree", ["input_tree"])
  check_flattens_to_single_leaf("shallow_tree",
                                ["input_tree_0", "input_tree_1"])

  # Test case where len(shallow_tree) < len(input_tree)
  input_tree = {"a": "A", "b": "B", "c": "C"}
  shallow_tree = {"a": 1, "c": 2}
  mismatch_message = nest._STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
      input_length=len(input_tree), shallow_length=len(shallow_tree))
  with self.assertRaisesWithLiteralMatch(  # pylint: disable=g-error-prone-assert-raises
      ValueError, mismatch_message):
    get_paths_and_values(shallow_tree, input_tree)

  # Disabling the length check flattens only the keys of the shallow tree.
  input_paths, input_values = get_paths_and_values(
      shallow_tree, input_tree, check_subtrees_length=False)
  shallow_paths, shallow_values = get_paths_and_values(
      shallow_tree, shallow_tree)
  self.assertEqual(input_paths, [("a",), ("c",)])
  self.assertEqual(input_values, ["A", "C"])
  self.assertEqual(shallow_paths, [("a",), ("c",)])
  self.assertEqual(shallow_values, [1, 2])

  # Using non-iterable elements.
  check_flattens_to_single_leaf(9, [0])
  check_flattens_to_single_leaf(9, [0, 1])

  ## Both non-list edge-case.
  # Using iterable elements.
  check_flattens_to_single_leaf("shallow_tree", "input_tree")
  # Using non-iterable elements.
  check_flattens_to_single_leaf(0, 0)

  ## Input non-list edge-case.
  # Using iterable elements.
  check_rejects_non_sequence_input(["shallow_tree"], "input_tree", [(0,)])
  check_rejects_non_sequence_input(
      ["shallow_tree_9", "shallow_tree_8"], "input_tree", [(0,), (1,)])
  # Using non-iterable elements.
  check_rejects_non_sequence_input([9], 0, [(0,)])
  check_rejects_non_sequence_input([9, 8], 0, [(0,), (1,)])
def testMapStructureUpTo(self):
  """Checks `nest.map_structure_up_to` on tuples, lists, dicts and mappings.

  Also verifies that a ValueError is raised when the shallow dict has a key
  that the deeper structure lacks.
  """

  def apply_ops_by_attr(val, ops):
    return (val + ops.add) * ops.mul

  def apply_ops_by_key(val, ops):
    return (val + ops["add"]) * ops["mul"]

  def describe_section(name, sec):
    return "first_{}_{}".format(len(sec), name)

  # Named tuples.
  ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  op_tuple = collections.namedtuple("op_tuple", "add, mul")
  values = ab_tuple(a=2, b=3)
  operations = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
  result = nest.map_structure_up_to(
      values, apply_ops_by_attr, values, operations)
  self.assertEqual(result.a, 6)
  self.assertEqual(result.b, 15)

  # Lists.
  data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
  name_list = ["evens", ["odds", "primes"]]
  result = nest.map_structure_up_to(
      name_list, describe_section, name_list, data_list)
  self.assertEqual(result,
                   ["first_4_evens", ["first_5_odds", "first_3_primes"]])

  # Dicts.
  values = {"a": 2, "b": 3}
  operations = {"a": {"add": 1, "mul": 2}, "b": {"add": 2, "mul": 3}}
  result = nest.map_structure_up_to(
      values, apply_ops_by_key, values, operations)
  self.assertEqual(result["a"], 6)
  self.assertEqual(result["b"], 15)

  # Non-equal dicts.
  values = {"a": 2, "b": 3}
  operations = {"a": {"add": 1, "mul": 2}, "c": {"add": 2, "mul": 3}}
  invalid_keys_message = nest._SHALLOW_TREE_HAS_INVALID_KEYS.format(["b"])
  with self.assertRaisesWithLiteralMatch(ValueError, invalid_keys_message):
    nest.map_structure_up_to(values, apply_ops_by_key, values, operations)

  # Dict+custom mapping.
  values = {"a": 2, "b": 3}
  operations = _CustomMapping(a={"add": 1, "mul": 2},
                              b={"add": 2, "mul": 3})
  result = nest.map_structure_up_to(
      values, apply_ops_by_key, values, operations)
  self.assertEqual(result["a"], 6)
  self.assertEqual(result["b"], 15)

  # Non-equal dict/mapping.
  values = {"a": 2, "b": 3}
  operations = _CustomMapping(a={"add": 1, "mul": 2},
                              c={"add": 2, "mul": 3})
  invalid_keys_message = nest._SHALLOW_TREE_HAS_INVALID_KEYS.format(["b"])
  with self.assertRaisesWithLiteralMatch(ValueError, invalid_keys_message):
    nest.map_structure_up_to(values, apply_ops_by_key, values, operations)
def testGetTraverseShallowStructure(self):
  """Checks `nest.get_traverse_shallow_structure` results and its errors."""

  def is_not_tuple(substructure):
    return not isinstance(substructure, tuple)

  def traverse_tuples_partially(substructure):
    if isinstance(substructure, tuple):
      return (True, False)
    return True

  # Traverse function returning a plain bool per substructure.
  scalar_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []]
  scalar_result = nest.get_traverse_shallow_structure(
      is_not_tuple, scalar_input)
  self.assertEqual(scalar_result,
                   [True, True, False, [True, True], {"a": False}, []])
  nest.assert_shallow_structure(scalar_result, scalar_input)

  # Traverse function returning a structure of bools for tuples.
  structure_input = [(1, [2]), ([1], 2)]
  structure_result = nest.get_traverse_shallow_structure(
      traverse_tuples_partially, structure_input)
  self.assertEqual(structure_result, [(True, False), ([True], False)])
  nest.assert_shallow_structure(structure_result, structure_input)

  # Each entry: (expected message regex, traverse function, structure).
  failing_cases = (
      ("returned structure", lambda _: [True], 0),
      ("returned a non-bool scalar", lambda _: 1, [1]),
      ("didn't return a depth=1 structure of bools", lambda _: [1], [1]),
  )
  for message_regex, traverse_fn, structure in failing_cases:
    with self.assertRaisesRegexp(TypeError, message_regex):
      nest.get_traverse_shallow_structure(traverse_fn, structure)
def testYieldFlatStringPaths(self):
  """Checks the paths produced by `nest.yield_flat_paths`."""
  # Each entry pairs a structure with the list of paths it should yield.
  cases = (
      ([], []),
      (3, [()]),
      ([3], [(0,)]),
      ({"a": 3}, [("a",)]),
      ({"a": {"b": 4}}, [("a", "b")]),
      ([{"a": 2}], [(0, "a")]),
      ([{"a": [2]}], [(0, "a", 0)]),
      ([{"a": [(23, 42)]}], [(0, "a", 0, 0), (0, "a", 0, 1)]),
      ([{"a": ([23], 42)}], [(0, "a", 0, 0), (0, "a", 1)]),
      ({"a": {"a": 2}, "c": [[[4]]]}, [("a", "a"), ("c", 0, 0, 0)]),
      ({"0": [{"1": 23}]}, [("0", 0, "1")]),
  )
  for inputs, expected in cases:
    self.assertEqual(list(nest.yield_flat_paths(inputs)), expected)
# Namedtuple types shared by the parameterized test cases below.
# We cannot define namedtuples within @parameterized argument lists, so they
# are created here, before the decorators that reference them.
# pylint: disable=invalid-name
Foo = collections.namedtuple("Foo", ["a", "b"])
Bar = collections.namedtuple("Bar", ["c", "d"])
# pylint: enable=invalid-name
@parameterized.parameters([
    {"inputs": [], "expected": []},
    {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]},
    {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]},
    {"inputs": Foo(a=3, b=Bar(c=23, d=42)),
     "expected": [("a", 3), ("b/c", 23), ("b/d", 42)]},
    {"inputs": Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="thing")),
     "expected": [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "thing")]},
    {"inputs": Bar(c=42, d=43),
     "expected": [("c", 42), ("d", 43)]},
    {"inputs": Bar(c=[42], d=43),
     "expected": [("c/0", 42), ("d", 43)]},
])
def testFlattenWithStringPaths(self, inputs, expected):
  """Checks `flatten_with_joined_string_paths` using "/" as separator."""
  path_value_pairs = nest.flatten_with_joined_string_paths(
      inputs, separator="/")
  self.assertEqual(path_value_pairs, expected)
@parameterized.parameters([
    {"inputs": [], "expected": []},
    {"inputs": [23, "42"], "expected": [((0,), 23), ((1,), "42")]},
    {"inputs": [[[[108]]]], "expected": [((0, 0, 0, 0), 108)]},
    {"inputs": Foo(a=3, b=Bar(c=23, d=42)),
     "expected": [(("a",), 3), (("b", "c"), 23), (("b", "d"), 42)]},
    {"inputs": Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="thing")),
     "expected": [(("a", "c"), 23), (("a", "d"), 42), (("b", "c"), 0),
                  (("b", "d"), "thing")]},
    {"inputs": Bar(c=42, d=43),
     "expected": [(("c",), 42), (("d",), 43)]},
    {"inputs": Bar(c=[42], d=43),
     "expected": [(("c", 0), 42), (("d",), 43)]},
])
def testFlattenWithTuplePaths(self, inputs, expected):
  """Checks the (tuple path, value) pairs from `flatten_with_tuple_paths`."""
  path_value_pairs = nest.flatten_with_tuple_paths(inputs)
  self.assertEqual(path_value_pairs, expected)
@parameterized.named_parameters(
    ("tuples",
     (1, 2), (3, 4), True,
     (("0", 4), ("1", 6))),
    ("dicts",
     {"a": 1, "b": 2}, {"b": 4, "a": 3}, True,
     {"a": ("a", 4), "b": ("b", 6)}),
    ("mixed",
     (1, 2), [3, 4], False,
     (("0", 4), ("1", 6))),
    ("nested",
     {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True,
     {"a": [("a/0", 10), ("a/1", 12)],
      "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]}))
def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected):
  """Maps a (path, sum) function over two structures of matching shape."""

  def path_with_total(path, *values):
    total = sum(values)
    return (path, total)

  mapped = nest.map_structure_with_paths(
      path_with_total, s1, s2, check_types=check_types)
  self.assertEqual(expected, mapped)
@parameterized.named_parameters(
    ("tuples", (1, 2, 3), (4, 5), ValueError),
    ("dicts", {"a": 1}, {"b": 2}, ValueError),
    ("mixed", (1, 2), [3, 4], TypeError),
    ("nested",
     {"a": [2, 3, 4], "b": [1, 3]},
     {"b": [5, 6], "a": [8, 9]},
     ValueError))
def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type):
  """Expects `map_structure_with_paths` to raise on mismatched structures."""

  def always_zero(path, *values):
    return 0

  with self.assertRaises(error_type):
    nest.map_structure_with_paths(always_zero, s1, s2)
@parameterized.named_parameters([
    {"testcase_name": "Tuples",
     "s1": (1, 2),
     "s2": (3, 4),
     "check_types": True,
     "expected": (((0,), 4), ((1,), 6))},
    {"testcase_name": "Dicts",
     "s1": {"a": 1, "b": 2},
     "s2": {"b": 4, "a": 3},
     "check_types": True,
     "expected": {"a": (("a",), 4), "b": (("b",), 6)}},
    {"testcase_name": "Mixed",
     "s1": (1, 2),
     "s2": [3, 4],
     "check_types": False,
     "expected": (((0,), 4), ((1,), 6))},
    {"testcase_name": "Nested",
     "s1": {"a": [2, 3], "b": [1, 2, 3]},
     "s2": {"b": [5, 6, 7], "a": [8, 9]},
     "check_types": True,
     "expected": {"a": [(("a", 0), 10), (("a", 1), 12)],
                  "b": [(("b", 0), 6), (("b", 1), 8), (("b", 2), 10)]}},
])
def testMapWithTuplePathsCompatibleStructures(
    self, s1, s2, check_types, expected):
  """Maps a (tuple path, sum) function over two matching structures."""

  def path_with_total(path, *values):
    total = sum(values)
    return path, total

  mapped = nest.map_structure_with_tuple_paths(
      path_with_total, s1, s2, check_types=check_types)
  self.assertEqual(expected, mapped)
@parameterized.named_parameters([
    {"testcase_name": "Tuples",
     "s1": (1, 2, 3),
     "s2": (4, 5),
     "error_type": ValueError},
    {"testcase_name": "Dicts",
     "s1": {"a": 1},
     "s2": {"b": 2},
     "error_type": ValueError},
    {"testcase_name": "Mixed",
     "s1": (1, 2),
     "s2": [3, 4],
     "error_type": TypeError},
    {"testcase_name": "Nested",
     "s1": {"a": [2, 3, 4], "b": [1, 3]},
     "s2": {"b": [5, 6], "a": [8, 9]},
     "error_type": ValueError},
])
def testMapWithTuplePathsIncompatibleStructures(self, s1, s2, error_type):
  """Expects `map_structure_with_tuple_paths` to raise on a mismatch."""

  def always_zero(path, *values):
    return 0

  with self.assertRaises(error_type):
    nest.map_structure_with_tuple_paths(always_zero, s1, s2)
def testFlattenCustomSequenceThatRaisesException(self):  # b/140746865
  """Expects the sequence's own ValueError to surface from `nest.flatten`."""
  raising_sequence = _CustomSequenceThatRaisesException()
  with self.assertRaisesRegexp(ValueError, "Cannot get item"):
    nest.flatten(raising_sequence)
class NestBenchmark(test.Benchmark):
  """Benchmarks for `nest.assert_same_structure`."""

  def run_and_report(self, s1, s2, name):
    """Times `nest.assert_same_structure(s1, s2)` and reports the result.

    Args:
      s1: First structure passed to `nest.assert_same_structure`.
      s2: Second structure passed to `nest.assert_same_structure`.
      name: Name under which the benchmark is reported.
    """
    warmup_iters = 100
    timed_iters = 30000

    # Untimed warm-up calls before the measured loop.
    for _ in xrange(warmup_iters):
      nest.assert_same_structure(s1, s2)

    start = time.time()
    for _ in xrange(timed_iters):
      nest.assert_same_structure(s1, s2)
    end = time.time()

    # Report the mean wall time of a single call.
    self.report_benchmark(
        iters=timed_iters, wall_time=(end - start) / timed_iters, name=name)

  def benchmark_assert_structure(self):
    """Benchmarks structures with 6 and with 60 leaf elements."""
    int_structure = (((1, 2), 3), 4, (5, 6))
    str_structure = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
    # Repeating the 6-leaf tuples 10 times yields the 60-leaf structures.
    for repeats, name in ((1, "assert_same_structure_6_elem"),
                          (10, "assert_same_structure_60_elem")):
      self.run_and_report(
          int_structure * repeats, str_structure * repeats, name)
if __name__ == "__main__":
  # Run the tests via the TensorFlow test entry point when executed directly.
  test.main()