Have an explicit registration mechanism for Serializables

PiperOrigin-RevId: 459370338
diff --git a/tensorflow/core/function/trace_type/BUILD b/tensorflow/core/function/trace_type/BUILD
index a9e57d6..c1f5747 100644
--- a/tensorflow/core/function/trace_type/BUILD
+++ b/tensorflow/core/function/trace_type/BUILD
@@ -19,6 +19,7 @@
     visibility = ["//tensorflow:internal"],
     deps = [
         ":default_types",
+        ":serialization",
         ":util",
         "//tensorflow/python/types",
     ],
diff --git a/tensorflow/core/function/trace_type/__init__.py b/tensorflow/core/function/trace_type/__init__.py
index 93eed8d..472f512 100644
--- a/tensorflow/core/function/trace_type/__init__.py
+++ b/tensorflow/core/function/trace_type/__init__.py
@@ -25,8 +25,11 @@
 Other implementations of TraceType include tf.TypeSpec and its subclasses.
 """
 
-
+from tensorflow.core.function.trace_type.serialization import deserialize
+from tensorflow.core.function.trace_type.serialization import register_serializable
+from tensorflow.core.function.trace_type.serialization import Serializable
+from tensorflow.core.function.trace_type.serialization import serialize
+from tensorflow.core.function.trace_type.serialization import SerializedTraceType
 from tensorflow.core.function.trace_type.trace_type_builder import from_object
 from tensorflow.core.function.trace_type.trace_type_builder import InternalTracingContext
 from tensorflow.core.function.trace_type.trace_type_builder import WeakrefDeletionObserver
-
diff --git a/tensorflow/core/function/trace_type/default_types.py b/tensorflow/core/function/trace_type/default_types.py
index cafab0c..99e12dd 100644
--- a/tensorflow/core/function/trace_type/default_types.py
+++ b/tensorflow/core/function/trace_type/default_types.py
@@ -603,3 +603,11 @@
   def __repr__(self):
     return (f"{self.__class__.__name__}(base={self.base!r}, "
             f"identifier={self.identifier!r})")
+
+serialization.register_serializable(Literal)
+serialization.register_serializable(Tuple)
+serialization.register_serializable(List)
+serialization.register_serializable(NamedTuple)
+serialization.register_serializable(Attrs)
+serialization.register_serializable(Dict)
+serialization.register_serializable(Reference)
diff --git a/tensorflow/core/function/trace_type/serialization.py b/tensorflow/core/function/trace_type/serialization.py
index 7a8cb3b..a1943e8 100644
--- a/tensorflow/core/function/trace_type/serialization.py
+++ b/tensorflow/core/function/trace_type/serialization.py
@@ -28,18 +28,6 @@
 class Serializable(metaclass=abc.ABCMeta):
   """TraceTypes implementing this additional interface are portable."""
 
-  def __init_subclass__(cls, **kwargs):
-    super().__init_subclass__(**kwargs)
-    if cls.experimental_type_proto() in PROTO_CLASS_TO_PY_CLASS:
-      raise ValueError(
-          "Existing Python class " +
-          PROTO_CLASS_TO_PY_CLASS[cls.experimental_type_proto()].__name__ +
-          " already has " + cls.experimental_type_proto().__name__ +
-          " as its associated proto representation. Please ensure " +
-          cls.__name__ + " has a unique proto representation.")
-
-    PROTO_CLASS_TO_PY_CLASS[cls.experimental_type_proto()] = cls
-
   @classmethod
   @abc.abstractmethod
   def experimental_type_proto(cls) -> Type[message.Message]:
@@ -58,6 +46,25 @@
     raise NotImplementedError
 
 
+def register_serializable(cls: Type[Serializable]):
+  """Registers a Python class to support serialization.
+
+  Only register standard TF types. Custom types should NOT be registered.
+
+  Args:
+    cls: Python class to register.
+  """
+  if cls.experimental_type_proto() in PROTO_CLASS_TO_PY_CLASS:
+    raise ValueError(
+        "Existing Python class " +
+        PROTO_CLASS_TO_PY_CLASS[cls.experimental_type_proto()].__name__ +
+        " already has " + cls.experimental_type_proto().__name__ +
+        " as its associated proto representation. Please ensure " +
+        cls.__name__ + " has a unique proto representation.")
+
+  PROTO_CLASS_TO_PY_CLASS[cls.experimental_type_proto()] = cls
+
+
 def serialize(to_serialize: Serializable) -> SerializedTraceType:
   """Converts Serializable to a proto SerializedTraceType."""
 
diff --git a/tensorflow/core/function/trace_type/serialization_test.py b/tensorflow/core/function/trace_type/serialization_test.py
index 4afbd01..ab7a229 100644
--- a/tensorflow/core/function/trace_type/serialization_test.py
+++ b/tensorflow/core/function/trace_type/serialization_test.py
@@ -39,6 +39,9 @@
     return proto
 
 
+serialization.register_serializable(MyCustomClass)
+
+
 class MyCompositeClass(serialization.Serializable):
 
   def __init__(self, *elements):
@@ -62,6 +65,9 @@
     return proto
 
 
+serialization.register_serializable(MyCompositeClass)
+
+
 class SerializeTest(test.TestCase):
 
   def testCustomClassSerialization(self):
@@ -124,25 +130,26 @@
     self.assertEqual(deserialized.elements[2].name, "name_3")
 
   def testNonUniqueProto(self):
+    class ClassThatReusesProto(serialization.Serializable):
+
+      @classmethod
+      def experimental_type_proto(cls):
+        return serialization_test_pb2.MyCustomRepresentation
+
+      @classmethod
+      def experimental_from_proto(cls, proto):
+        raise NotImplementedError
+
+      def experimental_as_proto(self):
+        raise NotImplementedError
+
     with self.assertRaisesRegex(
         ValueError,
         ("Existing Python class MyCustomClass already has "
          "MyCustomRepresentation as its associated proto representation. "
          "Please ensure ClassThatReusesProto has a unique proto representation."
         )):
-
-      class ClassThatReusesProto(serialization.Serializable):  # pylint: disable=unused-variable
-
-        @classmethod
-        def experimental_type_proto(cls):
-          return serialization_test_pb2.MyCustomRepresentation
-
-        @classmethod
-        def experimental_from_proto(cls, proto):
-          raise NotImplementedError
-
-        def experimental_as_proto(self):
-          raise NotImplementedError
+      serialization.register_serializable(ClassThatReusesProto)
 
   def testWrongProto(self):