# Source blob: 52f5586e96a0fc25954831eaf925987905783e50
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for default_types."""
import collections
from tensorflow.core.function.trace_type import default_types
from tensorflow.core.function.trace_type import serialization
from tensorflow.python.platform import test
from tensorflow.python.types import trace
class MockSupertypes2With3(trace.TraceType):
  """Mock trace type wrapping a value, where 2 is a subtype of 3.

  Used by the tests in this file to drive the container types with a small,
  fully predictable subtype/supertype relation.
  """

  def __init__(self, obj):
    """Stores `obj` as the value this mock type represents."""
    self._object = obj

  def is_subtype_of(self, other):
    """Returns True only when this mock wraps 2 and `other` wraps 3."""
    return self._object == 2 and other._object == 3

  def most_specific_common_supertype(self, others):
    """Returns the mocked common supertype of `self` and `others`.

    Args:
      others: Sequence of mock types; only the first element is inspected.

    Returns:
      `self` when `others` is empty; a new mock wrapping 3 when this mock
      wraps 2 and the first element of `others` wraps an int; otherwise None.
    """
    if not others:
      return self

    if self._object == 2 and isinstance(others[0]._object, int):
      return MockSupertypes2With3(3)
    else:
      return None

  def __eq__(self, other) -> bool:
    """Equal when `other` is an instance of this type wrapping the same value."""
    return isinstance(other, type(self)) and self._object == other._object

  def __hash__(self) -> int:
    """Hashes the wrapped value so that equal mocks hash equally."""
    # The previous implementation read `self._object_hash`, which is never
    # assigned, so hashing an instance raised AttributeError.
    return hash(self._object)
class Mock2AsTopType(MockSupertypes2With3):
  """Mock trace type in which the instance wrapping 2 is the top type."""

  def is_subtype_of(self, other):
    """Returns True exactly when `other` wraps the value 2."""
    return other._object == 2

  def most_specific_common_supertype(self, others):
    """Returns the mocked common supertype of `self` and `others`.

    Args:
      others: Sequence of types to combine with this one.

    Returns:
      None if any element of `others` is not a `Mock2AsTopType`; `self` if
      every element wraps the same value as this mock (including when
      `others` is empty); otherwise a new `Mock2AsTopType` wrapping 2.
    """
    if any(not isinstance(peer, Mock2AsTopType) for peer in others):
      return None

    same_value = all(self._object == peer._object for peer in others)
    if same_value:
      return self
    return Mock2AsTopType(2)
class DefaultTypesTest(test.TestCase):
  """Tests for the trace types provided by `default_types`."""

  @staticmethod
  def _mocks(*values):
    """Wraps each value in a `MockSupertypes2With3`, preserving order."""
    return [MockSupertypes2With3(value) for value in values]

  def _assertSurvivesSerialization(self, trace_type):
    """Asserts that serializing then deserializing yields an equal type."""
    restored = serialization.deserialize(serialization.serialize(trace_type))
    self.assertEqual(restored, trace_type)

  def testLiteralSupertypes(self):
    """Literals only share a supertype with literals of equal value."""
    one = default_types.Literal(1)
    two = default_types.Literal(2)
    another_one = default_types.Literal(1)

    self.assertEqual(one, one.most_specific_common_supertype([]))
    for equal_peer in (one, another_one):
      self.assertEqual(one, one.most_specific_common_supertype([equal_peer]))
    self.assertIsNone(one.most_specific_common_supertype([two]))

  def testLiteralSerialization(self):
    """Bool, int, float and str literals round-trip through serialization."""
    literals = [default_types.Literal(value) for value in (True, 1, 1.2, 'a')]
    for literal in literals:
      self._assertSurvivesSerialization(literal)

  def testListSupertype(self):
    """List supertypes are computed element-wise from the mocks."""
    mixed = default_types.List(*self._mocks(1, 2, 3))
    all_twos = default_types.List(*self._mocks(2, 2, 2))
    all_threes = default_types.List(*self._mocks(3, 3, 3))

    self.assertEqual(mixed, mixed.most_specific_common_supertype([]))
    self.assertIsNone(mixed.most_specific_common_supertype([all_twos]))
    self.assertEqual(
        all_twos.most_specific_common_supertype([mixed]), all_threes)

  def testListSerialization(self):
    """A List of literals round-trips through serialization."""
    components = [default_types.Literal(value) for value in (1, 2, 3)]
    self._assertSurvivesSerialization(default_types.List(*components))

  def testTupleSupertype(self):
    """Tuple supertypes are computed element-wise from the mocks."""
    mixed = default_types.Tuple(*self._mocks(1, 2, 3))
    all_twos = default_types.Tuple(*self._mocks(2, 2, 2))
    all_threes = default_types.Tuple(*self._mocks(3, 3, 3))

    self.assertEqual(mixed, mixed.most_specific_common_supertype([]))
    self.assertIsNone(mixed.most_specific_common_supertype([all_twos]))
    self.assertEqual(
        all_twos.most_specific_common_supertype([mixed]), all_threes)

  def testTupleSerialization(self):
    """A Tuple of literals round-trips through serialization."""
    components = [default_types.Literal(value) for value in (1, 2, 3)]
    self._assertSurvivesSerialization(default_types.Tuple(*components))

  def testNamedTupleSupertype(self):
    """NamedTuple supertypes are computed attribute-wise from the mocks."""
    record_cls = collections.namedtuple('MyNamedTuple', 'x y z')
    mixed = default_types.NamedTuple(record_cls, tuple(self._mocks(1, 2, 3)))
    all_twos = default_types.NamedTuple(
        record_cls, tuple(self._mocks(2, 2, 2)))
    all_threes = default_types.NamedTuple(
        record_cls, tuple(self._mocks(3, 3, 3)))

    self.assertEqual(mixed, mixed.most_specific_common_supertype([]))
    self.assertIsNone(mixed.most_specific_common_supertype([all_twos]))
    self.assertEqual(
        all_twos.most_specific_common_supertype([mixed]), all_threes)

  def testDictTypeSubtype(self):
    """Dict subtyping needs matching keys and value-wise subtyping."""
    make_dict = default_types.Dict
    ones_abc = make_dict({key: Mock2AsTopType(1) for key in 'abc'})
    twos_abc = make_dict({key: Mock2AsTopType(2) for key in 'abc'})
    ones_ab = make_dict({key: Mock2AsTopType(1) for key in 'ab'})

    self.assertTrue(ones_abc.is_subtype_of(twos_abc))
    self.assertFalse(ones_ab.is_subtype_of(twos_abc))
    self.assertFalse(ones_ab.is_subtype_of(ones_abc))

  def testDictTypeSupertype(self):
    """Dict supertypes are computed value-wise from the mocks."""
    make_dict = default_types.Dict
    mixed = make_dict(dict(zip('abc', self._mocks(1, 2, 3))))
    all_twos = make_dict(dict(zip('abc', self._mocks(2, 2, 2))))
    all_threes = make_dict(dict(zip('abc', self._mocks(3, 3, 3))))

    self.assertEqual(mixed, mixed.most_specific_common_supertype([]))
    self.assertIsNone(mixed.most_specific_common_supertype([all_twos]))
    self.assertEqual(
        all_twos.most_specific_common_supertype([mixed]), all_threes)

  def testDictSerialization(self):
    """A Dict of literals round-trips through serialization."""
    mapping = {
        key: default_types.Literal(value)
        for key, value in (('a', 1), ('b', 2), ('c', 3))
    }
    self._assertSurvivesSerialization(default_types.Dict(mapping))

  def testListTupleInequality(self):
    """A List and a Tuple with equal components never compare equal."""

    def fresh_literals():
      return [default_types.Literal(value) for value in (1, 2, 3)]

    first_list = default_types.List(*fresh_literals())
    second_list = default_types.List(*fresh_literals())
    first_tuple = default_types.Tuple(*fresh_literals())
    second_tuple = default_types.Tuple(*fresh_literals())

    self.assertEqual(first_list, second_list)
    self.assertEqual(first_tuple, second_tuple)
    self.assertNotEqual(first_list, first_tuple)
    self.assertNotEqual(first_tuple, first_list)

  def testDictTypeEquality(self):
    """Dict equality ignores insertion order but not missing entries."""
    make_dict = default_types.Dict
    lit = default_types.Literal

    both_entries = make_dict({lit(1): lit(2), lit(3): lit(4)})
    single_entry = make_dict({lit(1): lit(2)})
    both_entries_reordered = make_dict({lit(3): lit(4), lit(1): lit(2)})

    self.assertEqual(both_entries, both_entries_reordered)
    self.assertNotEqual(both_entries, single_entry)

  def testReferenceSubtype(self):
    """Reference subtyping needs the same id and a subtype base."""
    base_ref = default_types.Reference(Mock2AsTopType(3), 1)
    same_ref = default_types.Reference(Mock2AsTopType(3), 1)
    other_id_ref = default_types.Reference(Mock2AsTopType(3), 2)
    top_ref = default_types.Reference(Mock2AsTopType(2), 1)
    not_a_ref = default_types.Literal(1)

    self.assertEqual(base_ref, same_ref)
    self.assertFalse(base_ref.is_subtype_of(other_id_ref))
    self.assertTrue(base_ref.is_subtype_of(top_ref))
    self.assertFalse(top_ref.is_subtype_of(base_ref))
    self.assertFalse(base_ref.is_subtype_of(not_a_ref))

  def testReferenceSupertype(self):
    """References only share a supertype with equal-id References."""
    base_ref = default_types.Reference(Mock2AsTopType(3), 1)
    same_ref = default_types.Reference(Mock2AsTopType(3), 1)
    other_id_ref = default_types.Reference(Mock2AsTopType(3), 2)
    top_ref = default_types.Reference(Mock2AsTopType(2), 1)
    not_a_ref = default_types.Literal(1)

    self.assertEqual(top_ref.most_specific_common_supertype([]), top_ref)
    self.assertEqual(
        base_ref.most_specific_common_supertype([same_ref]), base_ref)
    self.assertIsNone(base_ref.most_specific_common_supertype([other_id_ref]))
    self.assertIsNone(base_ref.most_specific_common_supertype([not_a_ref]))

  # NOTE: the method name (spelling included) is kept unchanged so existing
  # test filters that reference it keep working.
  def testReferencetSerialization(self):
    """A Reference to a literal round-trips through serialization."""
    reference = default_types.Reference(default_types.Literal(3), 1)
    self._assertSurvivesSerialization(reference)
# When this file is executed directly as a script (rather than imported),
# hand control to `test.main()` from `tensorflow.python.platform.test`.
if __name__ == '__main__':
  test.main()