Reverts "Implement Serialization for Reference and Dict types"
PiperOrigin-RevId: 454219175
diff --git a/tensorflow/core/function/trace_type/BUILD b/tensorflow/core/function/trace_type/BUILD
index a9e57d6..088616a 100644
--- a/tensorflow/core/function/trace_type/BUILD
+++ b/tensorflow/core/function/trace_type/BUILD
@@ -136,9 +136,6 @@
"default_types.proto",
],
cc_api_version = 2,
- protodeps = [
- ":serialization_proto",
- ],
visibility = ["//tensorflow:internal"],
)
diff --git a/tensorflow/core/function/trace_type/default_types.proto b/tensorflow/core/function/trace_type/default_types.proto
index ee46e74..e7e88aa 100644
--- a/tensorflow/core/function/trace_type/default_types.proto
+++ b/tensorflow/core/function/trace_type/default_types.proto
@@ -2,8 +2,6 @@
package tensorflow.core.function.trace_type.default_types;
-import "tensorflow/core/function/trace_type/serialization.proto";
-
// Represents a serialized Literal type.
message SerializedLiteral {
oneof value {
@@ -13,17 +11,3 @@
string str_value = 4;
}
}
-
-// Represents a serialized Dict type.
-message SerializedDict {
- repeated SerializedLiteral keys = 1;
- repeated tensorflow.core.function.trace_type.serialization.SerializedTraceType
- values = 2;
-}
-
-// Represents a serialized Reference type.
-message SerializedReference {
- optional SerializedLiteral identifier = 1;
- optional tensorflow.core.function.trace_type.serialization.SerializedTraceType
- base = 2;
-}
diff --git a/tensorflow/core/function/trace_type/default_types.py b/tensorflow/core/function/trace_type/default_types.py
index f437def..ab7adef 100644
--- a/tensorflow/core/function/trace_type/default_types.py
+++ b/tensorflow/core/function/trace_type/default_types.py
@@ -28,7 +28,7 @@
"""Represents a Literal type like bool, int or string."""
def __init__(self, value: Any):
- self.value = value
+ self._value = value
self._value_hash = hash(value)
def is_subtype_of(self, other: trace.TraceType) -> bool:
@@ -59,35 +59,35 @@
raise ValueError("Malformed Literal proto can not be deserialized")
def to_proto(self) -> default_types_pb2.SerializedLiteral:
- if isinstance(self.value, bool):
- return default_types_pb2.SerializedLiteral(bool_value=self.value)
+ if isinstance(self._value, bool):
+ return default_types_pb2.SerializedLiteral(bool_value=self._value)
- if isinstance(self.value, int):
- return default_types_pb2.SerializedLiteral(int_value=self.value)
+ if isinstance(self._value, int):
+ return default_types_pb2.SerializedLiteral(int_value=self._value)
- if isinstance(self.value, float):
- return default_types_pb2.SerializedLiteral(float_value=self.value)
+ if isinstance(self._value, float):
+ return default_types_pb2.SerializedLiteral(float_value=self._value)
- if isinstance(self.value, str):
- return default_types_pb2.SerializedLiteral(str_value=self.value)
+ if isinstance(self._value, str):
+ return default_types_pb2.SerializedLiteral(str_value=self._value)
raise ValueError("Can not serialize Literal of type " +
- type(self.value).__name__)
+ type(self._value).__name__)
def _placeholder_value(self) -> Any:
- return self.value
+ return self._value
def __eq__(self, other) -> bool:
if not isinstance(other, trace.TraceType):
return NotImplemented
- return isinstance(other, Literal) and self.value == other.value
+ return isinstance(other, Literal) and self._value == other._value
def __hash__(self) -> int:
return self._value_hash
def __repr__(self):
- return f"{self.__class__.__name__}(value={self.value!r})"
+ return f"{self.__class__.__name__}(value={self._value!r})"
class Weakref(trace.TraceType):
@@ -275,7 +275,7 @@
"""
-class Dict(trace.TraceType, serialization.Serializable):
+class Dict(trace.TraceType):
"""Represents a dictionary of TraceType objects.
Attributes:
@@ -321,22 +321,6 @@
return Dict(new_mapping)
- @classmethod
- def type_proto(cls) -> Type[default_types_pb2.SerializedDict]:
- return default_types_pb2.SerializedDict
-
- @classmethod
- def from_proto(cls, proto: default_types_pb2.SerializedDict) -> "Dict":
- return Dict({
- Literal.from_proto(k).value: serialization.deserialize(v)
- for k, v in zip(proto.keys, proto.values)
- })
-
- def to_proto(self) -> default_types_pb2.SerializedDict:
- return default_types_pb2.SerializedDict(
- keys=[Literal(k).to_proto() for k in self.mapping.keys()],
- values=[serialization.serialize(v) for v in self.mapping.values()])
-
def _placeholder_value(self) -> Any:
return {
key: value._placeholder_value() # pylint: disable=protected-access
@@ -359,7 +343,7 @@
return f"{self.__class__.__name__}(mapping={self.mapping!r})"
-class Reference(trace.TraceType, serialization.Serializable):
+class Reference(trace.TraceType):
"""Represents a resource with an identifier.
Resource identifiers are useful to denote identical resources, that is,
@@ -388,22 +372,6 @@
return Reference(base_supertype, self.identifier)
return None
- @classmethod
- def type_proto(cls) -> Type[default_types_pb2.SerializedReference]:
- return default_types_pb2.SerializedReference
-
- @classmethod
- def from_proto(cls,
- proto: default_types_pb2.SerializedReference) -> "Reference":
- return Reference(
- serialization.deserialize(proto.base),
- Literal.from_proto(proto.identifier).value)
-
- def to_proto(self) -> default_types_pb2.SerializedReference:
- return default_types_pb2.SerializedReference(
- identifier=Literal(self.identifier).to_proto(),
- base=serialization.serialize(self.base))
-
def _placeholder_value(self) -> Any:
return self.base._placeholder_value() # pylint: disable=protected-access
diff --git a/tensorflow/core/function/trace_type/default_types_test.py b/tensorflow/core/function/trace_type/default_types_test.py
index 66ef271..240627c 100644
--- a/tensorflow/core/function/trace_type/default_types_test.py
+++ b/tensorflow/core/function/trace_type/default_types_test.py
@@ -181,17 +181,6 @@
'c': MockSupertypes2With3(3)
}))
- def testDictSerialization(self):
- dict_original = default_types.Dict({
- 'a': default_types.Literal(1),
- 'b': default_types.Literal(2),
- 'c': default_types.Literal(3)
- })
-
- self.assertEqual(
- serialization.deserialize(serialization.serialize(dict_original)),
- dict_original)
-
def testListTupleInequality(self):
literal = default_types.Literal
@@ -242,11 +231,6 @@
self.assertIsNone(original.most_specific_common_supertype([different_id]))
self.assertIsNone(original.most_specific_common_supertype([different_type]))
- def testReferencetSerialization(self):
- ref_original = default_types.Reference(default_types.Literal(3), 1)
- self.assertEqual(
- serialization.deserialize(serialization.serialize(ref_original)),
- ref_original)
if __name__ == '__main__':
test.main()