Introduce TensorArray most_specific_common_supertype
PiperOrigin-RevId: 432999844
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index 4be41b0..4c67242 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -1362,6 +1362,41 @@
self._dynamic_size = dynamic_size
self._infer_shape = infer_shape
+ def is_subtype_of(self, other):
+ # pylint: disable=protected-access
+ return (isinstance(other, TensorArraySpec) and
+ self._dtype == other._dtype and
+ self._dynamic_size == other._dynamic_size)
+
+ def most_specific_common_supertype(self, others):
+ """Returns the most specific supertype of `self` and `others`.
+
+ Args:
+ others: A Sequence of `TypeSpec`.
+
+ Returns `None` if a supertype does not exist.
+ """
+ # pylint: disable=protected-access
+ if not all(isinstance(other, TensorArraySpec) for other in others):
+ return False
+
+ common_shape = self._element_shape.most_specific_common_supertype(
+ other._element_shape for other in others)
+ if common_shape is None:
+ return None
+
+ if not all(self._dtype == other._dtype for other in others):
+ return None
+
+ if not all(self._dynamic_size == other._dynamic_size for other in others):
+ return None
+
+ infer_shape = self._infer_shape and all(
+ other._infer_shape for other in others)
+
+ return TensorArraySpec(common_shape, self._dtype, self._dynamic_size,
+ infer_shape)
+
def is_compatible_with(self, other):
# pylint: disable=protected-access
if not isinstance(other, type_spec.TypeSpec):
@@ -1373,7 +1408,9 @@
self._element_shape.is_compatible_with(other._element_shape) and
self._dynamic_size == other._dynamic_size)
+ # TODO(b/221472813): Migrate logic to most_specific_common_supertype.
def most_specific_compatible_type(self, other):
+ """Deprecated. Use most_specific_common_supertype instead."""
# pylint: disable=protected-access
if not self.is_compatible_with(other):
raise ValueError(f"Type `{self}` is not compatible with `{other}`.")