blob: 76b637bc260566b716cf1313fc31bcf3f941b7c4 [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import collections
import copy
import json
import os
import pickle
from absl.testing import parameterized
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import core as non_keras_core
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import data_structures
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import nest
from tensorflow.python.util import serialization
class ListTests(test.TestCase):
def testJSONSerialization(self):
obj = tracking.AutoTrackable()
obj.l = [1]
json.dumps(obj.l, default=serialization.get_json_type)
def testNotTrackable(self):
class NotTrackable(object):
pass
with self.assertRaises(ValueError):
data_structures.List([NotTrackable()])
def testCallNotImplemented(self):
with self.assertRaisesRegex(TypeError, "not callable"):
data_structures.List()(1.) # pylint: disable=not-callable
def testNoPop(self):
with self.assertRaises(AttributeError):
data_structures.List().pop()
def testNesting(self):
with context.graph_mode():
inner = data_structures.List()
outer = data_structures.List([inner])
inner.append(non_keras_core.Dense(1))
inner[0](array_ops.ones([2, 3]))
self.assertEqual(2, len(outer.variables))
self.assertIsInstance(
outer.variables[0],
resource_variable_ops.ResourceVariable)
def testNonLayerVariables(self):
v = resource_variable_ops.ResourceVariable([1.])
l = data_structures.List([v])
self.assertTrue(l.trainable)
self.assertEqual([], l.layers)
self.assertEqual([v], l.variables)
self.assertEqual([v], l.trainable_weights)
self.assertEqual([], l.non_trainable_variables)
l.trainable = False
self.assertEqual([v], l.variables)
self.assertEqual([], l.trainable_variables)
self.assertEqual([v], l.non_trainable_variables)
l.trainable = True
v2 = resource_variable_ops.ResourceVariable(1., trainable=False)
l.append(v2)
self.assertEqual([v, v2], l.weights)
self.assertEqual([v], l.trainable_weights)
self.assertEqual([v2], l.non_trainable_weights)
def testCopy(self):
v1 = resource_variable_ops.ResourceVariable(1.)
v2 = resource_variable_ops.ResourceVariable(1.)
v3 = resource_variable_ops.ResourceVariable(1.)
l1 = data_structures.List([v1, v2])
l2 = l1.copy()
l2.append(v3)
self.assertEqual(list(l1), [v1, v2])
self.assertEqual(list(l2), [v1, v2, v3])
def testSlicing(self):
v1 = resource_variable_ops.ResourceVariable(1.)
v2 = resource_variable_ops.ResourceVariable(1.)
v3 = resource_variable_ops.ResourceVariable(1.)
v4 = resource_variable_ops.ResourceVariable(1.)
l = data_structures.List([v1, v2, v3, v4])
self.assertEqual(l[1:], [v2, v3, v4])
self.assertEqual(l[1:-1], [v2, v3])
self.assertEqual(l[:-1], [v1, v2, v3])
def testHash(self):
has_sequences = {data_structures.List(), data_structures.List()}
self.assertEqual(2, len(has_sequences))
self.assertNotIn(data_structures.List(), has_sequences)
def testIMul_zero(self):
l = data_structures.List([])
with self.assertRaisesRegex(ValueError, "List only supports append"):
l *= 0
def testIMul(self):
v = resource_variable_ops.ResourceVariable(1.)
l = data_structures.List([v])
l *= 2
self.assertEqual(list(l), [v] * 2)
def testMul(self):
v = resource_variable_ops.ResourceVariable(1.)
l = data_structures.List([v, v, v])
self.assertEqual(list(l * 2), [v, v, v] * 2)
def testRMul(self):
v = resource_variable_ops.ResourceVariable(1.)
l = data_structures.List([v, v, v])
self.assertEqual(list(2 * l), [v, v, v] * 2)
class ListWrapperTest(test.TestCase):
IGNORED = ("__new__", "__init__", "__subclasshook__", "__getattribute__")
def test_overrides_all_list_methods(self):
not_overridden = []
for name in dir(list):
if name in ListWrapperTest.IGNORED:
continue
list_method = getattr(list, name)
if not callable(list_method):
continue
object_method = getattr(object, name, None)
if object_method is not None and object_method == list_method:
# Skip methods that aren't overridden from object.
continue
if list_method == getattr(data_structures.ListWrapper, name):
not_overridden.append(name)
if not_overridden:
self.fail("ListWrapper does not override %s" % (not_overridden))
def testPickle(self):
original = data_structures.ListWrapper([1, 2])
serialized = pickle.dumps(original)
del original
deserialized = pickle.loads(serialized)
self.assertEqual([1, 2], deserialized)
def testSameStructure(self):
l = [1]
nest.assert_same_structure(l, data_structures.ListWrapper(copy.copy(l)))
def testMutateWithoutTrackableComponents(self):
m = module.Module()
m.l = [1, 2]
m.l.insert(0, 0)
self.assertEqual(m.l, [0, 1, 2])
self.assertEqual(m.l._trackable_children(), {})
def testFunctionCaching(self):
@def_function.function
def f(list_input):
return list_input[0] + constant_op.constant(1.)
first_trace = f.get_concrete_function([constant_op.constant(2.)])
second_trace = f.get_concrete_function(
data_structures.ListWrapper([constant_op.constant(3.)]))
self.assertIs(first_trace, second_trace)
def testListWrapperBasic(self):
# ListWrapper, unlike List, compares like the built-in list type (since it
# is used to automatically replace lists).
a = tracking.AutoTrackable()
b = tracking.AutoTrackable()
self.assertEqual([a, a],
[a, a])
self.assertEqual(data_structures.ListWrapper([a, a]),
data_structures.ListWrapper([a, a]))
self.assertEqual([a, a],
data_structures.ListWrapper([a, a]))
self.assertEqual(data_structures.ListWrapper([a, a]),
[a, a])
self.assertNotEqual([a, a],
[b, a])
self.assertNotEqual(data_structures.ListWrapper([a, a]),
data_structures.ListWrapper([b, a]))
self.assertNotEqual([a, a],
data_structures.ListWrapper([b, a]))
self.assertLess([a], [a, b])
self.assertLess(data_structures.ListWrapper([a]),
data_structures.ListWrapper([a, b]))
self.assertLessEqual([a], [a, b])
self.assertLessEqual(data_structures.ListWrapper([a]),
data_structures.ListWrapper([a, b]))
self.assertGreater([a, b], [a])
self.assertGreater(data_structures.ListWrapper([a, b]),
data_structures.ListWrapper([a]))
self.assertGreaterEqual([a, b], [a])
self.assertGreaterEqual(data_structures.ListWrapper([a, b]),
data_structures.ListWrapper([a]))
self.assertEqual([a], data_structures.ListWrapper([a]))
self.assertEqual([a], list(data_structures.List([a])))
self.assertEqual([a, a], data_structures.ListWrapper([a]) + [a])
self.assertEqual([a, a], [a] + data_structures.ListWrapper([a]))
self.assertIsInstance(data_structures.ListWrapper([a]), list)
self.assertEqual(
tensor_shape.TensorShape([None, 2]).as_list(),
(data_structures.ListWrapper([None])
+ tensor_shape.TensorShape([2])).as_list())
def testAcceptsNonTrackableContent(self):
l = data_structures.ListWrapper([1, 2, 3])
self.assertEqual(l, [1, 2, 3])
def testWrapperChangesList(self):
l = []
l_wrapper = data_structures.ListWrapper(l)
l_wrapper.append(1)
self.assertEqual([1], l)
def testListChangesWrapper(self):
l = []
l_wrapper = data_structures.ListWrapper(l)
l.append(1)
self.assertEqual([1], l_wrapper)
def testNotHashable(self):
with self.assertRaises(TypeError):
hash(data_structures.ListWrapper()) # pylint: disable=no-value-for-parameter
def testDelItem(self):
l = data_structures.ListWrapper([1, 2, 3, [4]])
del l[0]
self.assertEqual(l, [2, 3, [4]])
self.assertUnableToSave(l, "Unable to save .*__delitem__")
def testDelSlice(self):
l = data_structures.ListWrapper([1, 2, 3, [4]])
del l[2:3]
self.assertEqual(l, [1, 2, [4]])
self.assertUnableToSave(l, "Unable to save .*__delslice__")
def testSetSlice_canSaveForNonTrackableItems(self):
l = data_structures.ListWrapper([1, 2, 3, 4])
l[:] = 2, 8, 9, 0
self.assertEqual(l, [2, 8, 9, 0])
l._maybe_initialize_trackable() # pylint: disable=protected-access
self.assertEqual(len(l._trackable_children()), 0) # pylint: disable=protected-access
def testSetSlice_cannotSaveIfTrackableModified(self):
v1 = resource_variable_ops.ResourceVariable(1.)
v2 = resource_variable_ops.ResourceVariable(1.)
l = data_structures.ListWrapper([1, 2, v1, v2])
l[:] = 2, 8, 9, v2
self.assertEqual(l, [2, 8, 9, v2])
self.assertUnableToSave(l, "Unable to save .*__setslice__")
def testSetSlice_truncate(self):
l = data_structures.ListWrapper([1, 2, 3, 4])
l[:] = []
self.assertEqual(l, [])
def testSetSlice_extend(self):
l = data_structures.ListWrapper([1, 2, 3, 4])
l[2:] = 1, 2, 3, 4
self.assertEqual(l, [1, 2, 1, 2, 3, 4])
def testIMulNegative(self):
l = data_structures.ListWrapper([1, 2, 3, [4]])
l *= -1
self.assertEqual(l, [1, 2, 3, [4]] * -1)
self.assertUnableToSave(l, "Unable to save")
def testIMulPositive(self):
v = variables.Variable(1.)
l = data_structures.ListWrapper([1, 2, 3, 4, v])
self.assertDictEqual({"4": v}, l._trackable_children())
root = util.Checkpoint(l=l)
prefix = os.path.join(self.get_temp_dir(), "ckpt")
path = root.save(prefix)
v.assign(5.)
l *= 2
self.assertEqual(l, [1, 2, 3, 4, v, 1, 2, 3, 4, v])
self.assertDictEqual({"4": v, "9": v}, l._trackable_children())
root.restore(path)
self.assertAllClose(1., v.numpy())
def testSort(self):
l = data_structures.ListWrapper([[1], [2], [3], [4]])
l.sort()
self.assertAllEqual(l, [[1], [2], [3], [4]])
# Regardless of being a no-op for the input list, we still refuse to save.
# This is intentional since otherwise we would end up with a hard to debug
# case for users (e.g. sometimes sort on a ListWrapper is trackable and
# other times it is not).
self.assertUnableToSave(l, "Unable to save .*sort")
def assertUnableToSave(self, l, msg):
l._maybe_initialize_trackable() # pylint: disable=protected-access
with self.assertRaisesRegex(ValueError, msg):
return l._trackable_children() # pylint: disable=protected-access
class MappingTests(test.TestCase):
def testJSONSerialization(self):
obj = tracking.AutoTrackable()
obj.d = {"a": 2}
json.dumps(obj.d, default=serialization.get_json_type)
def testNoOverwrite(self):
mapping = data_structures.Mapping()
original = data_structures.List()
mapping["a"] = original
with self.assertRaises(ValueError):
mapping["a"] = data_structures.List()
self.assertIs(original, mapping["a"])
with self.assertRaises(AttributeError):
del mapping["a"] # pylint: disable=unsupported-delete-operation
mapping.update(b=data_structures.Mapping())
with self.assertRaises(ValueError):
mapping.update({"b": data_structures.Mapping()})
def testNonStringKeys(self):
mapping = data_structures.Mapping()
with self.assertRaises(TypeError):
mapping[1] = data_structures.List()
def testHashing(self):
has_mappings = set([data_structures.Mapping(),
data_structures.Mapping()])
self.assertEqual(2, len(has_mappings))
self.assertNotIn(data_structures.Mapping(), has_mappings)
# In contrast to Mapping, dict wrappers are not hashable
a = tracking.AutoTrackable()
a.d = {}
self.assertEqual({}, a.d)
self.assertFalse({} != a.d) # pylint: disable=g-explicit-bool-comparison
self.assertNotEqual({1: 2}, a.d)
with self.assertRaisesRegex(TypeError, "unhashable"):
set([a.d])
def testListShallowCopy(self):
root = tracking.AutoTrackable()
orig_list = [[1.]]
root.a = orig_list
copied = copy.copy(root.a)
self.assertAllEqual([[1.]], copied)
self.assertIsNot(root.a, copied)
self.assertIs(root.a[0], copied[0])
# Dirtiness should be inherited
util.list_objects(root.a)
orig_list.append(1.)
with self.assertRaises(ValueError):
util.list_objects(root.a)
with self.assertRaises(ValueError):
util.list_objects(copy.copy(root.a))
def testListDeepCopy(self):
root = tracking.AutoTrackable()
orig_list = [[1.]]
root.a = orig_list
copied = copy.deepcopy(root.a)
self.assertAllEqual([[1.]], copied)
self.assertIsNot(root.a, copied)
self.assertIsNot(root.a[0], copied[0])
# Dirtiness should be inherited
util.list_objects(root.a)
orig_list.append(1.)
with self.assertRaises(ValueError):
util.list_objects(root.a)
with self.assertRaises(ValueError):
util.list_objects(copy.deepcopy(root.a))
def testDictShallowCopy(self):
root = tracking.AutoTrackable()
orig_dict = {"a": [1.]}
root.a = orig_dict
copied = copy.copy(root.a)
self.assertAllEqual([1.], copied["a"])
self.assertIsNot(root.a, copied)
self.assertIs(root.a["a"], copied["a"])
copied = root.a.copy()
self.assertAllEqual([1.], copied["a"])
self.assertIsNot(root.a, copied)
self.assertIs(root.a["a"], copied["a"])
# Dirtiness should be inherited
util.list_objects(root.a)
orig_dict["b"] = []
with self.assertRaises(ValueError):
util.list_objects(root.a)
with self.assertRaises(ValueError):
util.list_objects(copy.copy(root.a))
def testDictDeepCopy(self):
root = tracking.AutoTrackable()
orig_dict = {"a": [1.]}
root.a = orig_dict
copied = copy.deepcopy(root.a)
self.assertAllEqual([1.], copied["a"])
self.assertIsNot(root.a, copied)
self.assertIsNot(root.a["a"], copied["a"])
# Dirtiness should be inherited
util.list_objects(root.a)
orig_dict["b"] = []
with self.assertRaises(ValueError):
util.list_objects(root.a)
with self.assertRaises(ValueError):
util.list_objects(copy.deepcopy(root.a))
def testShallowCopyTrackable(self):
original = tracking.AutoTrackable()
original_sub = tracking.AutoTrackable()
original.a = [[1.]]
original.b = {"a": original_sub}
shallow_copied = copy.copy(original)
self.assertIs(original_sub, shallow_copied.b["a"])
self.assertIsNot(original, shallow_copied)
self.assertEqual([[1.]], shallow_copied.a)
shallow_deps = util.list_objects(shallow_copied)
self.assertIn(shallow_copied.a, shallow_deps)
self.assertIn(shallow_copied.b, shallow_deps)
self.assertIn(shallow_copied.b["a"], shallow_deps)
def testDeepCopyTrackable(self):
original = tracking.AutoTrackable()
original_sub = tracking.AutoTrackable()
original.a = [[1.]]
original.b = {"a": original_sub}
self.assertIsInstance(original.b, dict)
deep_copied = copy.deepcopy(original)
self.assertIsInstance(deep_copied.b, dict)
self.assertIsNot(original, deep_copied)
self.assertIsNot(original_sub, deep_copied.b["a"])
self.assertEqual([[1.]], deep_copied.a)
self.assertIsInstance(deep_copied.b["a"], tracking.AutoTrackable)
deps = util.list_objects(deep_copied)
self.assertIn(deep_copied.a, deps)
self.assertIn(deep_copied.b, deps)
self.assertIn(deep_copied.b["a"], deps)
self.assertNotIn(original_sub, deps)
def testConstructableFromSequence(self):
result = data_structures._DictWrapper([(1, 2), (3, 4)])
self.assertIsInstance(result, dict)
self.assertEqual({1: 2, 3: 4}, result)
def testPickle(self):
original = data_structures._DictWrapper(dict(a=1, b=2))
serialized = pickle.dumps(original)
del original
deserialized = pickle.loads(serialized)
self.assertEqual(dict(a=1, b=2), deserialized)
def testListAddOrder(self):
self.assertEqual([1., 2.],
data_structures.ListWrapper([1.])
+ data_structures.ListWrapper([2.]))
self.assertEqual([1., 2.],
data_structures.ListWrapper([1.])
+ [2.])
self.assertEqual([1., 2.],
[1.]
+ data_structures.ListWrapper([2.]))
def testSameStructure(self):
d = {1: "a"}
nest.assert_same_structure(d, data_structures._DictWrapper(d.copy()))
def testFunctionCaching(self):
@def_function.function
def f(dict_input):
return dict_input["x"] + constant_op.constant(1.)
first_trace = f.get_concrete_function({"x": constant_op.constant(2.)})
second_trace = f.get_concrete_function(
data_structures._DictWrapper({"x": constant_op.constant(3.)}))
self.assertIs(first_trace, second_trace)
class TupleTests(test.TestCase, parameterized.TestCase):
def testJSONSerialization(self):
obj = tracking.AutoTrackable()
obj.l = (1,)
json.dumps(obj.l, default=serialization.get_json_type)
def testNonLayerVariables(self):
v = resource_variable_ops.ResourceVariable([1.])
l = data_structures._TupleWrapper((v,))
self.assertEqual([], l.layers)
self.assertEqual([v], l.variables)
self.assertEqual([v], l.trainable_weights)
self.assertEqual([], l.non_trainable_variables)
def testCopy(self):
v1 = resource_variable_ops.ResourceVariable(1.)
v2 = resource_variable_ops.ResourceVariable(1.)
l1 = data_structures._TupleWrapper((v1, v2))
l2 = copy.copy(l1)
self.assertEqual(l1, (v1, v2))
self.assertEqual(l2, (v1, v2))
self.assertIs(l1[0], l2[0])
l2_deep = copy.deepcopy(l1)
self.assertIsNot(l1[0], l2_deep[0])
with self.assertRaises(AttributeError):
l2.append(v1)
def testSlicing(self):
v1 = resource_variable_ops.ResourceVariable(1.)
v2 = resource_variable_ops.ResourceVariable(1.)
v3 = resource_variable_ops.ResourceVariable(1.)
v4 = resource_variable_ops.ResourceVariable(1.)
l = data_structures._TupleWrapper((v1, v2, v3, v4))
self.assertEqual(l[1:], (v2, v3, v4))
self.assertEqual(l[1:-1], (v2, v3))
self.assertEqual(l[:-1], (v1, v2, v3))
def testHash(self):
has_sequences = set([data_structures._TupleWrapper(),
data_structures._TupleWrapper()])
self.assertLen(has_sequences, 1)
self.assertIn(data_structures._TupleWrapper(), has_sequences)
def testIMul_zero(self):
l = data_structures._TupleWrapper((1,))
l *= 0
self.assertEqual((), l)
def testIMul(self):
# Note: tuple behavior differs from list behavior. Lists are mutated by
# imul/iadd, tuples assign a new object to the left hand side of the
# expression.
v = resource_variable_ops.ResourceVariable(1.)
l = data_structures._TupleWrapper((v,))
original = l
l *= 2
self.assertEqual(l, (v,) * 2)
self.assertNotEqual(original, (v,) * 2)
def testIAdd(self):
v = resource_variable_ops.ResourceVariable(1.)
l = data_structures._TupleWrapper((v,))
original = l
l += (1,)
self.assertEqual(l, (v, 1))
self.assertNotEqual(original, (v, 1))
self.assertEqual(original, (v,))
def testMul(self):
v = resource_variable_ops.ResourceVariable(1.)
l = data_structures._TupleWrapper((v, v, v))
self.assertEqual(l * 2, (v, v, v) * 2)
def testRMul(self):
v = resource_variable_ops.ResourceVariable(1.)
l = data_structures._TupleWrapper((v, v, v))
self.assertEqual(2 * l, (v, v, v) * 2)
def testPickle(self):
original = data_structures._TupleWrapper((1, 2))
serialized = pickle.dumps(original)
del original
deserialized = pickle.loads(serialized)
self.assertEqual((1, 2), deserialized)
def testNamedTuple(self):
named = collections.namedtuple("Named", ("x", "y"))
v = variables.Variable(2)
nt = named(x=v, y=2)
m = module.Module()
m.nt = nt
self.assertIs(v, m.nt.x)
self.assertIs(v, m.nt[0])
self.assertIs(
v, m._trackable_children()["nt"]._trackable_children()["x"])
self.assertEqual(2, m.nt.y)
def testNamedTupleConflictingAttributes(self):
named = collections.namedtuple("Named", ("x", "weights"))
v = variables.Variable(2)
nt = named(x=v, weights=3)
m = module.Module()
m.nt = nt
self.assertEqual(3, m.nt.weights)
def testNamedSubclassing(self):
named = collections.namedtuple("Named", ("x", "y"))
v = variables.Variable(2)
class NamedSubclass(named):
def __new__(cls, x, y):
del y # unused
return super(NamedSubclass, cls).__new__(cls, x, 3)
@property
def summed(self):
return self.x + self.y
nt = NamedSubclass(x=v, y=2)
m = module.Module()
m.nt = nt
self.assertEqual(3, m.nt.y)
self.assertIs(v, m.nt.x)
self.assertIn(v,
m._trackable_children()["nt"]._trackable_children().values())
self.assertIn("x", m.nt._trackable_children())
self.assertIn("0", m.nt._trackable_children())
self.assertEqual(5, self.evaluate(m.nt.summed))
def testUnnamedSubclassing(self):
v = variables.Variable(2)
class UnnamedSubclass(tuple):
@property
def summed(self):
return self[0] + self[1]
unt = UnnamedSubclass([v, 2])
m = module.Module()
m.unt = unt
self.assertIn("0", m.unt._trackable_children())
self.assertLen(m.unt._trackable_children(), 1)
self.assertEqual(4, self.evaluate(m.unt.summed))
nest.assert_same_structure(
[m.unt], nest.map_structure(lambda x: x, [m.unt]))
def testNamedtupleSubclassWithCustomNew(self):
class SubclassWithDifferentArgs(collections.namedtuple("A", ["x"])):
def __new__(cls):
return super(SubclassWithDifferentArgs, cls).__new__(cls, [])
nt = SubclassWithDifferentArgs()
m = module.Module()
m.nt = nt
m.nt.x.append(variables.Variable(1.))
prefix = os.path.join(self.get_temp_dir(), "ckpt")
ckpt = util.Checkpoint(m=m)
with self.assertRaises(ValueError):
ckpt.save(prefix)
def testSameStructure(self):
t = (variables.Variable(1.),)
m = module.Module()
m.t = t
nest.assert_same_structure(t, m.t)
nest.assert_same_structure(m.t, t)
nt_type = collections.namedtuple("nt", ["x", "y"])
nt = nt_type(x=1, y=2)
m.nt = nt
nest.assert_same_structure(m.nt, nt)
with self.assertRaises(TypeError): # pylint: disable=g-error-prone-assert-raises
nest.assert_same_structure(m.nt, m.t)
def testFlatten(self):
t = data_structures._TupleWrapper((1, data_structures._TupleWrapper((2,))))
self.assertEqual([1, 2], nest.flatten(t))
self.assertEqual(
nest.flatten_with_tuple_paths((1, (2,))),
nest.flatten_with_tuple_paths(t))
self.assertEqual((3, (4,)),
nest.pack_sequence_as(t, [3, 4]))
nt_type = collections.namedtuple("nt", ["x", "y"])
nt = nt_type(1., 2.)
wrapped_nt = data_structures._TupleWrapper(nt)
self.assertEqual(
nest.flatten_with_tuple_paths(nt),
nest.flatten_with_tuple_paths(wrapped_nt))
self.assertEqual((3, 4,),
nest.pack_sequence_as(wrapped_nt, [3, 4]))
self.assertEqual(3, nest.pack_sequence_as(wrapped_nt, [3, 4]).x)
def testFunctionCaching(self):
@def_function.function
def f(tuple_input):
return tuple_input[0] + constant_op.constant(1.)
first_trace = f.get_concrete_function((constant_op.constant(2.),))
second_trace = f.get_concrete_function(
data_structures._TupleWrapper((constant_op.constant(3.),)))
self.assertIs(first_trace, second_trace)
def testPythonMapImpl(self):
t = data_structures._TupleWrapper((1, data_structures._TupleWrapper((2,))))
self.assertEqual(
(4, (5,)),
nest.map_structure_up_to((None, (None,)), lambda x: x + 3, t,
check_types=True))
nest.assert_shallow_structure((None, None), t)
def testDatasetMap(self):
dataset = dataset_ops.Dataset.from_tensor_slices(
constant_op.constant([1, 2, 3]))
dataset = dataset.map(lambda x: data_structures._TupleWrapper((x,)))
for index, element in enumerate(dataset):
self.assertEqual((index + 1,), self.evaluate(element))
def testDatasetMapNamed(self):
nt_type = collections.namedtuple("A", ["x"])
dataset = dataset_ops.Dataset.from_tensor_slices(
constant_op.constant([1, 2, 3]))
dataset = dataset.map(lambda x: data_structures._TupleWrapper(nt_type(x)))
for index, element in enumerate(dataset):
self.assertEqual((index + 1,), self.evaluate(element))
def testLoopAssignedModule(self):
m = module.Module()
m.s = (m,)
self.assertLen(m._trackable_children(), 1)
self.assertIn("s", m._trackable_children())
self.assertIs(m.s, m._trackable_children()["s"])
self.assertEqual((), m.trainable_variables)
if __name__ == "__main__":
test.main()