Fix Reference Type Part 2

PiperOrigin-RevId: 433272635
diff --git a/tensorflow/core/function/trace_type/default_types.py b/tensorflow/core/function/trace_type/default_types.py
index 9fa76a1..b76b713 100644
--- a/tensorflow/core/function/trace_type/default_types.py
+++ b/tensorflow/core/function/trace_type/default_types.py
@@ -245,8 +245,10 @@
     if all(
         isinstance(other, Reference) and self.identifier == other.identifier
         for other in types):
-      return Reference(self.base.most_specific_common_supertype(
-          [other.base for other in types]), self.identifier)
+      base_supertype = self.base.most_specific_common_supertype(
+          [other.base for other in types])
+      if base_supertype is not None:
+        return Reference(base_supertype, self.identifier)
     return None
 
   def __eq__(self, other: Any) -> bool:
diff --git a/tensorflow/core/function/trace_type/default_types_test.py b/tensorflow/core/function/trace_type/default_types_test.py
index 814e99f..af6545a 100644
--- a/tensorflow/core/function/trace_type/default_types_test.py
+++ b/tensorflow/core/function/trace_type/default_types_test.py
@@ -169,30 +169,45 @@
     self.assertEqual(dict_a, dict_c)
     self.assertNotEqual(dict_a, dict_b)
 
-  def testReferenceType(self):
+  def testReferenceSubtype(self):
 
     class MockSubtypeOf2(default_types.Generic):
 
       def is_subtype_of(self, other):
         return other._object == 2
 
-      def most_specific_common_supertype(self, types):
-        return self if all(self._object == other._object
-                           for other in types) else MockSubtypeOf2(2)
-
     original = default_types.Reference(MockSubtypeOf2(3), 1)
     clone = default_types.Reference(MockSubtypeOf2(3), 1)
     different_id = default_types.Reference(MockSubtypeOf2(3), 2)
     supertype = default_types.Reference(MockSubtypeOf2(2), 1)
+    different_type = default_types.Generic(1)
 
     self.assertEqual(original, clone)
     self.assertFalse(original.is_subtype_of(different_id))
     self.assertTrue(original.is_subtype_of(supertype))
     self.assertFalse(supertype.is_subtype_of(original))
+    self.assertFalse(original.is_subtype_of(different_type))
+
+  def testReferenceSupertype(self):
+
+    class Mock2AsTopType(default_types.Generic):
+
+      def most_specific_common_supertype(self, types):
+        if not all(isinstance(other, Mock2AsTopType) for other in types):
+          return None
+        return self if all(self._object == other._object
+                           for other in types) else Mock2AsTopType(2)
+
+    original = default_types.Reference(Mock2AsTopType(3), 1)
+    clone = default_types.Reference(Mock2AsTopType(3), 1)
+    different_id = default_types.Reference(Mock2AsTopType(3), 2)
+    supertype = default_types.Reference(Mock2AsTopType(2), 1)
+    different_type = default_types.Generic(1)
 
     self.assertEqual(supertype.most_specific_common_supertype([]), supertype)
     self.assertEqual(original.most_specific_common_supertype([clone]), original)
     self.assertIsNone(original.most_specific_common_supertype([different_id]))
+    self.assertIsNone(original.most_specific_common_supertype([different_type]))
 
 if __name__ == '__main__':
   test.main()