Fix Reference Type Part 2
PiperOrigin-RevId: 433272635
diff --git a/tensorflow/core/function/trace_type/default_types.py b/tensorflow/core/function/trace_type/default_types.py
index 9fa76a1..b76b713 100644
--- a/tensorflow/core/function/trace_type/default_types.py
+++ b/tensorflow/core/function/trace_type/default_types.py
@@ -245,8 +245,10 @@
if all(
isinstance(other, Reference) and self.identifier == other.identifier
for other in types):
- return Reference(self.base.most_specific_common_supertype(
- [other.base for other in types]), self.identifier)
+ base_supertype = self.base.most_specific_common_supertype(
+ [other.base for other in types])
+ if base_supertype is not None:
+ return Reference(base_supertype, self.identifier)
return None
def __eq__(self, other: Any) -> bool:
diff --git a/tensorflow/core/function/trace_type/default_types_test.py b/tensorflow/core/function/trace_type/default_types_test.py
index 814e99f..af6545a 100644
--- a/tensorflow/core/function/trace_type/default_types_test.py
+++ b/tensorflow/core/function/trace_type/default_types_test.py
@@ -169,30 +169,45 @@
self.assertEqual(dict_a, dict_c)
self.assertNotEqual(dict_a, dict_b)
- def testReferenceType(self):
+ def testReferenceSubtype(self):
class MockSubtypeOf2(default_types.Generic):
def is_subtype_of(self, other):
return other._object == 2
- def most_specific_common_supertype(self, types):
- return self if all(self._object == other._object
- for other in types) else MockSubtypeOf2(2)
-
original = default_types.Reference(MockSubtypeOf2(3), 1)
clone = default_types.Reference(MockSubtypeOf2(3), 1)
different_id = default_types.Reference(MockSubtypeOf2(3), 2)
supertype = default_types.Reference(MockSubtypeOf2(2), 1)
+ different_type = default_types.Generic(1)
self.assertEqual(original, clone)
self.assertFalse(original.is_subtype_of(different_id))
self.assertTrue(original.is_subtype_of(supertype))
self.assertFalse(supertype.is_subtype_of(original))
+ self.assertFalse(original.is_subtype_of(different_type))
+
+ def testReferenceSupertype(self):
+
+ class Mock2AsTopType(default_types.Generic):
+
+ def most_specific_common_supertype(self, types):
+ if not all(isinstance(other, Mock2AsTopType) for other in types):
+ return None
+ return self if all(self._object == other._object
+ for other in types) else Mock2AsTopType(2)
+
+ original = default_types.Reference(Mock2AsTopType(3), 1)
+ clone = default_types.Reference(Mock2AsTopType(3), 1)
+ different_id = default_types.Reference(Mock2AsTopType(3), 2)
+ supertype = default_types.Reference(Mock2AsTopType(2), 1)
+ different_type = default_types.Generic(1)
self.assertEqual(supertype.most_specific_common_supertype([]), supertype)
self.assertEqual(original.most_specific_common_supertype([clone]), original)
self.assertIsNone(original.most_specific_common_supertype([different_id]))
+ self.assertIsNone(original.most_specific_common_supertype([different_type]))
if __name__ == '__main__':
test.main()