Have an explicit registration mechanism for Serializables
PiperOrigin-RevId: 459370338
diff --git a/tensorflow/core/function/trace_type/BUILD b/tensorflow/core/function/trace_type/BUILD
index a9e57d6..c1f5747 100644
--- a/tensorflow/core/function/trace_type/BUILD
+++ b/tensorflow/core/function/trace_type/BUILD
@@ -19,6 +19,7 @@
visibility = ["//tensorflow:internal"],
deps = [
":default_types",
+ ":serialization",
":util",
"//tensorflow/python/types",
],
diff --git a/tensorflow/core/function/trace_type/__init__.py b/tensorflow/core/function/trace_type/__init__.py
index 93eed8d..472f512 100644
--- a/tensorflow/core/function/trace_type/__init__.py
+++ b/tensorflow/core/function/trace_type/__init__.py
@@ -25,8 +25,11 @@
Other implementations of TraceType include tf.TypeSpec and its subclasses.
"""
-
+from tensorflow.core.function.trace_type.serialization import deserialize
+from tensorflow.core.function.trace_type.serialization import register_serializable
+from tensorflow.core.function.trace_type.serialization import Serializable
+from tensorflow.core.function.trace_type.serialization import serialize
+from tensorflow.core.function.trace_type.serialization import SerializedTraceType
from tensorflow.core.function.trace_type.trace_type_builder import from_object
from tensorflow.core.function.trace_type.trace_type_builder import InternalTracingContext
from tensorflow.core.function.trace_type.trace_type_builder import WeakrefDeletionObserver
-
diff --git a/tensorflow/core/function/trace_type/default_types.py b/tensorflow/core/function/trace_type/default_types.py
index cafab0c..99e12dd 100644
--- a/tensorflow/core/function/trace_type/default_types.py
+++ b/tensorflow/core/function/trace_type/default_types.py
@@ -603,3 +603,11 @@
def __repr__(self):
return (f"{self.__class__.__name__}(base={self.base!r}, "
f"identifier={self.identifier!r})")
+
+serialization.register_serializable(Literal)
+serialization.register_serializable(Tuple)
+serialization.register_serializable(List)
+serialization.register_serializable(NamedTuple)
+serialization.register_serializable(Attrs)
+serialization.register_serializable(Dict)
+serialization.register_serializable(Reference)
diff --git a/tensorflow/core/function/trace_type/serialization.py b/tensorflow/core/function/trace_type/serialization.py
index 7a8cb3b..a1943e8 100644
--- a/tensorflow/core/function/trace_type/serialization.py
+++ b/tensorflow/core/function/trace_type/serialization.py
@@ -28,18 +28,6 @@
class Serializable(metaclass=abc.ABCMeta):
"""TraceTypes implementing this additional interface are portable."""
- def __init_subclass__(cls, **kwargs):
- super().__init_subclass__(**kwargs)
- if cls.experimental_type_proto() in PROTO_CLASS_TO_PY_CLASS:
- raise ValueError(
- "Existing Python class " +
- PROTO_CLASS_TO_PY_CLASS[cls.experimental_type_proto()].__name__ +
- " already has " + cls.experimental_type_proto().__name__ +
- " as its associated proto representation. Please ensure " +
- cls.__name__ + " has a unique proto representation.")
-
- PROTO_CLASS_TO_PY_CLASS[cls.experimental_type_proto()] = cls
-
@classmethod
@abc.abstractmethod
def experimental_type_proto(cls) -> Type[message.Message]:
@@ -58,6 +46,25 @@
raise NotImplementedError
+def register_serializable(cls: Type[Serializable]):
+ """Registers a Python class to support serialization.
+
+ Only register standard TF types. Custom types should NOT be registered.
+
+ Args:
+ cls: Python class to register.
+ """
+ if cls.experimental_type_proto() in PROTO_CLASS_TO_PY_CLASS:
+ raise ValueError(
+ "Existing Python class " +
+ PROTO_CLASS_TO_PY_CLASS[cls.experimental_type_proto()].__name__ +
+ " already has " + cls.experimental_type_proto().__name__ +
+ " as its associated proto representation. Please ensure " +
+ cls.__name__ + " has a unique proto representation.")
+
+ PROTO_CLASS_TO_PY_CLASS[cls.experimental_type_proto()] = cls
+
+
def serialize(to_serialize: Serializable) -> SerializedTraceType:
"""Converts Serializable to a proto SerializedTraceType."""
diff --git a/tensorflow/core/function/trace_type/serialization_test.py b/tensorflow/core/function/trace_type/serialization_test.py
index 4afbd01..ab7a229 100644
--- a/tensorflow/core/function/trace_type/serialization_test.py
+++ b/tensorflow/core/function/trace_type/serialization_test.py
@@ -39,6 +39,9 @@
return proto
+serialization.register_serializable(MyCustomClass)
+
+
class MyCompositeClass(serialization.Serializable):
def __init__(self, *elements):
@@ -62,6 +65,9 @@
return proto
+serialization.register_serializable(MyCompositeClass)
+
+
class SerializeTest(test.TestCase):
def testCustomClassSerialization(self):
@@ -124,25 +130,26 @@
self.assertEqual(deserialized.elements[2].name, "name_3")
def testNonUniqueProto(self):
+ class ClassThatReusesProto(serialization.Serializable):
+
+ @classmethod
+ def experimental_type_proto(cls):
+ return serialization_test_pb2.MyCustomRepresentation
+
+ @classmethod
+ def experimental_from_proto(cls, proto):
+ raise NotImplementedError
+
+ def experimental_as_proto(self):
+ raise NotImplementedError
+
with self.assertRaisesRegex(
ValueError,
("Existing Python class MyCustomClass already has "
"MyCustomRepresentation as its associated proto representation. "
"Please ensure ClassThatReusesProto has a unique proto representation."
)):
-
- class ClassThatReusesProto(serialization.Serializable): # pylint: disable=unused-variable
-
- @classmethod
- def experimental_type_proto(cls):
- return serialization_test_pb2.MyCustomRepresentation
-
- @classmethod
- def experimental_from_proto(cls, proto):
- raise NotImplementedError
-
- def experimental_as_proto(self):
- raise NotImplementedError
+ serialization.register_serializable(ClassThatReusesProto)
def testWrongProto(self):