# blob: 80b69336bdcd88a75b57a0e3f0769bfe9458f9ce [file] [log] [blame]
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Type specifications for TensorFlow APIs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import collections
import re
import numpy as np
import six
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export
# Use LazyLoader to avoid circular dependencies.
tensor_spec = LazyLoader(
"tensor_spec", globals(),
"tensorflow.python.framework.tensor_spec")
ops = LazyLoader(
"ops", globals(),
"tensorflow.python.framework.ops")
@tf_export("TypeSpec", v1=["TypeSpec", "data.experimental.Structure"])
@six.add_metaclass(abc.ABCMeta)
class TypeSpec(object):
  """Specifies a TensorFlow value type.

  A `tf.TypeSpec` provides metadata describing an object accepted or returned
  by TensorFlow APIs.  Concrete subclasses, such as `tf.TensorSpec` and
  `tf.RaggedTensorSpec`, are used to describe different value types.

  For example, `tf.function`'s `input_signature` argument accepts a list
  (or nested structure) of `TypeSpec`s.

  Creating new subclasses of `TypeSpec` (outside of TensorFlow core) is not
  currently supported.  In particular, we may make breaking changes to the
  private methods and properties defined by this base class.

  Example:

  >>> spec = tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int32)
  >>> @tf.function(input_signature=[spec])
  ... def double(x):
  ...   return x * 2
  >>> print(double(tf.ragged.constant([[1, 2], [3]])))
  <tf.RaggedTensor [[2, 4], [6]]>
  """
  # === Subclassing ===
  #
  # Each `TypeSpec` subclass must define:
  #
  #   * A "component encoding" for values.
  #   * A "serialization" for types.
  #
  # The component encoding for a value is a nested structure of `tf.Tensor`
  # or `CompositeTensor` that can be used by the `TypeSpec` to reconstruct
  # the value.  Each individual `TypeSpec` must use the same nested structure
  # for all values -- this structure is defined by the `component_specs`
  # attribute.  Decomposing values into components, and reconstructing them
  # from those components, should be inexpensive.  In particular, it should
  # *not* require any TensorFlow ops.
  #
  # The serialization for a `TypeSpec` is a nested tuple of values that can
  # be used to reconstruct the `TypeSpec`.  See the documentation for
  # `_serialize()` for more information.

  __slots__ = []

  @abc.abstractproperty
  def value_type(self):
    """The Python type for values that are compatible with this TypeSpec.

    In particular, all values that are compatible with this TypeSpec must be an
    instance of this type.
    """
    raise NotImplementedError("%s.value_type" % type(self).__name__)

  def is_compatible_with(self, spec_or_value):
    """Returns true if `spec_or_value` is compatible with this TypeSpec.

    Args:
      spec_or_value: A `TypeSpec`, or a value that `type_spec_from_value` can
        turn into one.

    Returns:
      `False` if the (converted) spec has a different Python type than `self`;
      otherwise the result of comparing the two type serializations.
    """
    # === Subclassing ===
    # If not overridden by subclasses, the default behavior is to convert
    # `spec_or_value` to a `TypeSpec` (if it isn't already); and then to
    # consider two `TypeSpec`s compatible if they have the same type, and
    # the values returned by `_serialize` are compatible (where
    # `tf.TensorShape`, `tf.TensorSpec`, and `tf.DType` are checked for
    # compatibility using their `is_compatible_with` method; and all other
    # types are considered compatible if they are equal).
    if not isinstance(spec_or_value, TypeSpec):
      spec_or_value = type_spec_from_value(spec_or_value)
    if type(self) is not type(spec_or_value):
      return False
    return self.__is_compatible(self._serialize(),
                                spec_or_value._serialize())  # pylint: disable=protected-access

  def most_specific_compatible_type(self, other):
    """Returns the most specific TypeSpec compatible with `self` and `other`.

    Args:
      other: A `TypeSpec`.

    Returns:
      A `TypeSpec` built by deserializing the merged serializations of `self`
      and `other`.

    Raises:
      ValueError: If there is no TypeSpec that is compatible with both `self`
        and `other`.
    """
    # === Subclassing ===
    # If not overridden by a subclass, the default behavior is to raise a
    # `ValueError` if `self` and `other` have different types, or if their type
    # serializations differ by anything other than `TensorShape`s.  Otherwise,
    # the two type serializations are combined (using
    # `most_specific_compatible_shape` to combine `TensorShape`s), and the
    # result is used to construct and return a new `TypeSpec`.
    if type(self) is not type(other):
      raise ValueError("No TypeSpec is compatible with both %s and %s" %
                       (self, other))
    merged = self.__most_specific_compatible_type_serialization(
        self._serialize(), other._serialize())  # pylint: disable=protected-access
    return self._deserialize(merged)

  def _with_tensor_ranks_only(self):
    """Returns a TypeSpec compatible with `self`, with tensor shapes relaxed.

    Returns:
      A `TypeSpec` that is compatible with `self`, where any `TensorShape`
      information has been relaxed to include only tensor rank (and not
      the dimension sizes for individual axes).
    """
    # === Subclassing ===
    # If not overridden by a subclass, the default behavior is to serialize
    # this TypeSpec, relax any TensorSpec or TensorShape values, and
    # deserialize the result.

    def relax(value):
      if isinstance(value, TypeSpec):
        return value._with_tensor_ranks_only()  # pylint: disable=protected-access
      elif (isinstance(value, tensor_shape.TensorShape) and
            value.rank is not None):
        # Keep the rank, forget every dimension size.
        return tensor_shape.TensorShape([None] * value.rank)
      else:
        return value

    return self._deserialize(nest.map_structure(relax, self._serialize()))

  # === Component encoding for values ===

  @abc.abstractmethod
  def _to_components(self, value):
    """Encodes `value` as a nested structure of `Tensor` or `CompositeTensor`.

    Args:
      value: A value compatible with this `TypeSpec`.  (Caller is responsible
        for ensuring compatibility.)

    Returns:
      A nested structure of `tf.Tensor` or `tf.CompositeTensor` compatible with
      `self._component_specs`, which can be used to reconstruct `value`.
    """
    # === Subclassing ===
    # This method must be inexpensive (do not call TF ops).
    raise NotImplementedError("%s._to_components()" % type(self).__name__)

  @abc.abstractmethod
  def _from_components(self, components):
    """Reconstructs a value from a nested structure of Tensor/CompositeTensor.

    Args:
      components: A nested structure of `tf.Tensor` or `tf.CompositeTensor`,
        compatible with `self._component_specs`.  (Caller is responsible for
        ensuring compatibility.)

    Returns:
      A value that is compatible with this `TypeSpec`.
    """
    # === Subclassing ===
    # This method must be inexpensive (do not call TF ops).
    raise NotImplementedError("%s._from_components()" % type(self).__name__)

  @abc.abstractproperty
  def _component_specs(self):
    """A nested structure of TypeSpecs for this type's components.

    Returns:
      A nested structure describing the component encodings that are returned
      by this TypeSpec's `_to_components` method.  In particular, for a
      TypeSpec `spec` and a compatible value `value`:

      ```
      nest.map_structure(lambda t, c: assert t.is_compatible_with(c),
                         spec._component_specs, spec._to_components(value))
      ```
    """
    raise NotImplementedError("%s._component_specs()" % type(self).__name__)

  # === Tensor list encoding for values ===

  def _to_tensor_list(self, value):
    """Encodes `value` as a flat list of `tf.Tensor`.

    By default, this just flattens `self._to_components(value)` using
    `nest.flatten`.  However, subclasses may override this to return a
    different tensor encoding for values.  In particular, some subclasses
    of `BatchableTypeSpec` override this method to return a "boxed" encoding
    for values, which then can be batched or unbatched.  See
    `BatchableTypeSpec` for more details.

    Args:
      value: A value compatible with this `TypeSpec`.  (Caller is responsible
        for ensuring compatibility.)

    Returns:
      A list of `tf.Tensor`, compatible with `self._flat_tensor_specs`, which
      can be used to reconstruct `value`.
    """
    return nest.flatten(self._to_components(value), expand_composites=True)

  def _from_tensor_list(self, tensor_list):
    """Reconstructs a value from a flat list of `tf.Tensor`.

    Args:
      tensor_list: A flat list of `tf.Tensor`, compatible with
        `self._flat_tensor_specs`.

    Returns:
      A value that is compatible with this `TypeSpec`.

    Raises:
      ValueError: If `tensor_list` is not compatible with
      `self._flat_tensor_specs`.
    """
    self.__check_tensor_list(tensor_list)
    return self._from_compatible_tensor_list(tensor_list)

  def _from_compatible_tensor_list(self, tensor_list):
    """Reconstructs a value from a compatible flat list of `tf.Tensor`.

    Args:
      tensor_list: A flat list of `tf.Tensor`, compatible with
        `self._flat_tensor_specs`.  (Caller is responsible for ensuring
        compatibility.)

    Returns:
      A value that is compatible with this `TypeSpec`.
    """
    return self._from_components(nest.pack_sequence_as(
        self._component_specs, tensor_list, expand_composites=True))

  @property
  def _flat_tensor_specs(self):
    """A list of TensorSpecs compatible with self._to_tensor_list(v)."""
    return nest.flatten(self._component_specs, expand_composites=True)

  # === Serialization for types ===

  @abc.abstractmethod
  def _serialize(self):
    """Returns a nested tuple containing the state of this TypeSpec.

    The serialization may contain the following value types: boolean,
    integer, string, float, None, `TensorSpec`, `tf.TensorShape`, `tf.DType`,
    `np.ndarray`, `TypeSpec`, and nested tuples, namedtuples, dicts, and
    OrderedDicts of any of the above.

    This method is used to provide default definitions for: equality
    testing (__eq__, __ne__), hashing (__hash__), pickling (__reduce__),
    string representation (__repr__), `self.is_compatible_with()`,
    `self.most_specific_compatible_type()`, and protobuf serialization
    (e.g. TensorInfo and StructuredValue).
    """
    raise NotImplementedError("%s._serialize()" % type(self).__name__)

  @classmethod
  def _deserialize(cls, serialization):
    """Reconstructs a TypeSpec from a value returned by `_serialize`.

    Args:
      serialization: The sequence returned by `_serialize`; its items are
        passed positionally to the class constructor.

    Returns:
      A new instance of `cls`.
    """
    return cls(*serialization)

  # === Operators ===

  def __eq__(self, other):
    """Two specs are equal iff they share a type and a comparison key."""
    # pylint: disable=protected-access
    return (type(other) is type(self) and
            self.__get_cmp_key() == other.__get_cmp_key())

  def __ne__(self, other):
    return not self == other

  def __hash__(self):
    return hash(self.__get_cmp_key())

  def __reduce__(self):
    """Pickle support: rebuild as `type(self)(*self._serialize())`."""
    return type(self), self._serialize()

  def __repr__(self):
    return "%s%r" % (type(self).__name__, self._serialize())

  # === Legacy Output ===
  # TODO(b/133606651) Document and/or deprecate the legacy_output methods.
  # (These are used by tf.data.)

  def _to_legacy_output_types(self):
    raise NotImplementedError("%s._to_legacy_output_types()" %
                              type(self).__name__)

  def _to_legacy_output_shapes(self):
    raise NotImplementedError("%s._to_legacy_output_shapes()" %
                              type(self).__name__)

  def _to_legacy_output_classes(self):
    return self.value_type

  # === Private Helper Methods ===

  def __check_tensor_list(self, tensor_list):
    """Raises ValueError unless `tensor_list` matches `_flat_tensor_specs`.

    Args:
      tensor_list: A flat list of values accepted by `type_spec_from_value`.

    Raises:
      ValueError: If the list has the wrong length, or if the spec of any
        entry is incompatible with the expected spec at the same position.
    """
    expected = self._flat_tensor_specs
    specs = [type_spec_from_value(t) for t in tensor_list]
    if len(specs) != len(expected):
      raise ValueError("Incompatible input: wrong number of tensors")
    for i, (s1, s2) in enumerate(zip(specs, expected)):
      if not s1.is_compatible_with(s2):
        raise ValueError("Incompatible input: tensor %d (%s) is incompatible "
                         "with %s" % (i, tensor_list[i], s2))

  def __get_cmp_key(self):
    """Returns a hashable eq-comparable key for `self`."""
    # TODO(b/133606651): Decide whether to cache this value.
    return (type(self), self.__make_cmp_key(self._serialize()))

  def __make_cmp_key(self, value):
    """Converts `value` to a hashable key.

    Args:
      value: A (possibly nested) value taken from `self._serialize()`.

    Returns:
      A hashable object built from `value`.

    Raises:
      ValueError: If `value` has a type that is not supported in
        serializations.
    """
    if isinstance(value, (int, float, bool, dtypes.DType, TypeSpec)):
      return value
    if isinstance(value, compat.bytes_or_text_types):
      return value
    if value is None:
      return value
    if isinstance(value, dict):
      # Keys are sorted so the key does not depend on insertion order.
      return tuple([
          tuple([self.__make_cmp_key(key),
                 self.__make_cmp_key(value[key])])
          for key in sorted(value.keys())
      ])
    if isinstance(value, tuple):
      return tuple([self.__make_cmp_key(v) for v in value])
    if isinstance(value, list):
      # Tagged with `list` so a list never collides with an equal tuple.
      return (list, tuple([self.__make_cmp_key(v) for v in value]))
    if isinstance(value, tensor_shape.TensorShape):
      if value.ndims is None:
        # Note: we include a type object in the tuple, to ensure we can't get
        # false-positive matches (since users can't include type objects).
        return (tensor_shape.TensorShape, None)
      return (tensor_shape.TensorShape, tuple(value.as_list()))
    if isinstance(value, np.ndarray):
      return (np.ndarray, value.shape,
              TypeSpec.__nested_list_to_tuple(value.tolist()))
    raise ValueError("Unsupported value type %s returned by "
                     "%s._serialize" %
                     (type(value).__name__, type(self).__name__))

  @staticmethod
  def __nested_list_to_tuple(value):
    """Converts a nested list to a corresponding nested tuple."""
    if isinstance(value, list):
      return tuple(TypeSpec.__nested_list_to_tuple(v) for v in value)
    return value

  @staticmethod
  def __is_compatible(a, b):
    """Returns true if the given type serializations are compatible.

    Args:
      a: A serialized TypeSpec or nested component from a serialized TypeSpec.
      b: A serialized TypeSpec or nested component from a serialized TypeSpec.

    Returns:
      A boolean.
    """
    if isinstance(a, TypeSpec):
      return a.is_compatible_with(b)
    if type(a) is not type(b):
      return False
    if isinstance(a, (list, tuple)):
      return (len(a) == len(b) and
              all(TypeSpec.__is_compatible(x, y) for (x, y) in zip(a, b)))
    if isinstance(a, dict):
      return (len(a) == len(b) and sorted(a.keys()) == sorted(b.keys()) and all(
          TypeSpec.__is_compatible(a[k], b[k]) for k in a.keys()))
    if isinstance(a, (tensor_shape.TensorShape, dtypes.DType)):
      return a.is_compatible_with(b)
    if isinstance(a, np.ndarray):
      # `a == b` is elementwise for arrays and its truth value is ambiguous;
      # compare shape and contents as a whole instead.
      return bool(np.array_equal(a, b))
    return a == b

  @staticmethod
  def __most_specific_compatible_type_serialization(a, b):
    """Helper for most_specific_compatible_type.

    Combines two type serializations as follows:

    * If they are both tuples (or lists) of the same length, then recursively
      combine the respective elements.  The result is a tuple, except that
      namedtuples are rebuilt with their original namedtuple type.
    * If they are both dicts with the same keys, then recursively combine
      the respective dict elements.
    * If they are both TypeSpecs (this includes TensorSpecs), then combine
      using TypeSpec.most_specific_compatible_type.
    * If they are both TensorShapes, then combine using
      TensorShape.most_specific_compatible_shape.
    * If they are equal, then return a.
    * If none of the above, then raise a ValueError.

    Args:
      a: A serialized TypeSpec or nested component from a serialized TypeSpec.
      b: A serialized TypeSpec or nested component from a serialized TypeSpec.

    Returns:
      A value with the same type and structure as `a` and `b`.

    Raises:
      ValueError: If `a` and `b` are incompatible.
    """
    if type(a) is not type(b):
      raise ValueError("Types are not compatible: %r vs %r" % (a, b))
    if isinstance(a, (list, tuple)):
      if len(a) != len(b):
        raise ValueError("Types are not compatible: %r vs %r" % (a, b))
      merged = tuple(
          TypeSpec.__most_specific_compatible_type_serialization(x, y)
          for (x, y) in zip(a, b))
      if isinstance(a, tuple) and hasattr(a, "_fields"):
        # Namedtuple: keep the type so `_deserialize` sees the same structure
        # that `_serialize` produced.
        return type(a)(*merged)
      return merged
    if isinstance(a, collections.OrderedDict):
      a_keys, b_keys = a.keys(), b.keys()
      if len(a) != len(b) or a_keys != b_keys:
        raise ValueError("Types are not compatible: %r vs %r" % (a, b))
      return collections.OrderedDict([
          (k,
           TypeSpec.__most_specific_compatible_type_serialization(a[k], b[k]))
          for k in a_keys
      ])
    if isinstance(a, dict):
      a_keys, b_keys = sorted(a.keys()), sorted(b.keys())
      if len(a) != len(b) or a_keys != b_keys:
        raise ValueError("Types are not compatible: %r vs %r" % (a, b))
      return {
          k: TypeSpec.__most_specific_compatible_type_serialization(a[k], b[k])
          for k in a_keys
      }
    if isinstance(a, tensor_shape.TensorShape):
      return a.most_specific_compatible_shape(b)
    if isinstance(a, TypeSpec):
      return a.most_specific_compatible_type(b)
    if isinstance(a, np.ndarray):
      # `a != b` is elementwise for arrays; using it in an `if` raises for
      # multi-element arrays, which made identical specs look incompatible.
      if not np.array_equal(a, b):
        raise ValueError("Types are not compatible: %r vs %r" % (a, b))
      return a
    if a != b:
      raise ValueError("Types are not compatible: %r vs %r" % (a, b))
    return a
class BatchableTypeSpec(TypeSpec):
  """TypeSpec with a batchable tensor encoding.

  The batchable tensor encoding is a list of `tf.Tensor`s that supports
  batching and unbatching.  In particular, stacking (or unstacking)
  values with the same `TypeSpec` must be equivalent to stacking (or
  unstacking) each of their tensor lists.  Unlike the component encoding
  (returned by `self._to_components`), the batchable tensor encoding
  may require using encoding/decoding ops.

  If a subclass's batchable tensor encoding is not simply a flattened version
  of the component encoding, then the subclass must override `_to_tensor_list`,
  `_from_tensor_list`, and `_flat_tensor_specs`.
  """

  __slots__ = []

  @abc.abstractmethod
  def _batch(self, batch_size):
    """Returns a TypeSpec representing a batch of objects with this TypeSpec.

    Args:
      batch_size: An `int` representing the number of elements in a batch,
        or `None` if the batch size may vary.

    Returns:
      A `TypeSpec` representing a batch of objects with this TypeSpec.
    """
    raise NotImplementedError("%s._batch" % type(self).__name__)

  @abc.abstractmethod
  def _unbatch(self):
    """Returns a TypeSpec representing a single element of this TypeSpec.

    Returns:
      A `TypeSpec` representing a single element of objects with this TypeSpec.
    """
    raise NotImplementedError("%s._unbatch" % type(self).__name__)

  def _to_batched_tensor_list(self, value):
    """Returns a tensor list encoding for value with rank>0.

    Args:
      value: A value compatible with this `TypeSpec`.

    Returns:
      The list produced by `self._to_tensor_list(value)`.

    Raises:
      ValueError: If any tensor in the encoding is a scalar (rank 0).
    """
    tensor_list = self._to_tensor_list(value)
    if any(t.shape.ndims == 0 for t in tensor_list):
      # Wrap `value` in a 1-tuple so that a tuple-valued `value` cannot be
      # misread as several format arguments.
      raise ValueError("Value %s has insufficient rank for batching." %
                       (value,))
    return tensor_list
@tf_export("type_spec_from_value")
def type_spec_from_value(value):
  """Builds the `tf.TypeSpec` describing `value`.

  Examples:

    >>> tf.type_spec_from_value(tf.constant([1, 2, 3]))
    TensorSpec(shape=(3,), dtype=tf.int32, name=None)
    >>> tf.type_spec_from_value(np.array([4.0, 5.0], np.float64))
    TensorSpec(shape=(2,), dtype=tf.float64, name=None)
    >>> tf.type_spec_from_value(tf.ragged.constant([[1, 2], [3, 4, 5]]))
    RaggedTensorSpec(TensorShape([2, None]), tf.int32, 1, tf.int64)

    >>> example_input = tf.ragged.constant([[1, 2], [3]])
    >>> @tf.function(input_signature=[tf.type_spec_from_value(example_input)])
    ... def f(x):
    ...   return tf.reduce_sum(x, axis=1)

  Args:
    value: A value that can be accepted or returned by TensorFlow APIs.
      Accepted types for `value` include `tf.Tensor`, any value that can be
      converted to `tf.Tensor` using `tf.convert_to_tensor`, and any subclass
      of `CompositeTensor` (such as `tf.RaggedTensor`).

  Returns:
    A `TypeSpec` that is compatible with `value`.

  Raises:
    TypeError: If a TypeSpec cannot be built for `value`, because its type
      is not supported.
  """
  result = _type_spec_from_value(value)

  if result is None:
    # No direct match: fall back to converting the value to a tensor and
    # describing that tensor instead.
    try:
      result = _type_spec_from_value(ops.convert_to_tensor(value))
    except (ValueError, TypeError) as e:
      logging.vlog(
          3, "Failed to convert %r to tensor: %s" % (type(value).__name__, e))

  if result is None:
    raise TypeError("Could not build a TypeSpec for %r with type %s" %
                    (value, type(value).__name__))
  return result
def _type_spec_from_value(value):
  """Returns a `TypeSpec` for `value`, or `None` if no rule applies.

  Rules are tried in order: `Tensor`, `CompositeTensor`, non-empty list whose
  first element has a `BatchableTypeSpec`, then the converter registry (most
  recent registration first).
  """
  if isinstance(value, ops.Tensor):
    # Note: we do not include Tensor names when constructing TypeSpecs.
    return tensor_spec.TensorSpec(value.shape, value.dtype)

  if isinstance(value, composite_tensor.CompositeTensor):
    return value._type_spec  # pylint: disable=protected-access

  # A list whose elements all share one batchable type spec can be described
  # by a single batched spec, which is more accurate than what the
  # `convert_to_tensor` fallback would give.
  if isinstance(value, list) and value:
    element_specs = [_type_spec_from_value(element) for element in value]
    combined = element_specs[0]
    if isinstance(combined, BatchableTypeSpec):
      try:
        for element_spec in element_specs[1:]:
          combined = combined.most_specific_compatible_type(element_spec)
        return combined._batch(len(element_specs))  # pylint: disable=protected-access
      except (ValueError, TypeError):
        pass  # incompatible subspecs

  for type_object, converter_fn, allow_subclass in reversed(
      _TYPE_CONVERSION_FUNCTION_REGISTRY):
    exact_match = type(value) is type_object  # pylint: disable=unidiomatic-typecheck
    if exact_match or (allow_subclass and isinstance(value, type_object)):
      return converter_fn(value)

  return None
# Registered `(type_object, converter_fn, allow_subclass)` entries, in
# registration order.
_TYPE_CONVERSION_FUNCTION_REGISTRY = []


def register_type_spec_from_value_converter(type_object, converter_fn,
                                            allow_subclass=False):
  """Adds a value-to-TypeSpec converter for values of a given Python type.

  When several registered `type_object`s match a value, the most recently
  registered one wins.  Do not register converters for `CompositeTensor`s;
  they should implement `CompositeTensor._type_spec` instead.

  Args:
    type_object: A Python `type` object representing the type of values
      accepted by `converter_fn`.
    converter_fn: A function that takes one argument (an instance of the
      type represented by `type_object`) and returns a `TypeSpec`.
    allow_subclass: If true, then use `isinstance(value, type_object)` to
      check for matches.  If false, then use `type(value) is type_object`.
  """
  # Store the undecorated type so lookups compare against the real class.
  _, undecorated_type = tf_decorator.unwrap(type_object)
  registry_entry = (undecorated_type, converter_fn, allow_subclass)
  _TYPE_CONVERSION_FUNCTION_REGISTRY.append(registry_entry)
# Register the `TypeSpec` class with `_pywrap_utils` under the name "TypeSpec".
# NOTE(review): presumably this lets the native utilities recognise TypeSpec
# instances -- confirm against `_pywrap_utils`.
_pywrap_utils.RegisterType("TypeSpec", TypeSpec)
# Two-way registry filled in by `register()`: TypeSpec subclass -> registered
# name, and registered name -> TypeSpec subclass.
_TYPE_SPEC_TO_NAME = {}
_NAME_TO_TYPE_SPEC = {}
# Regular expression for valid TypeSpec names: one or more dot-terminated
# word segments followed by a final word segment.  The pattern ends with `\Z`
# rather than `$`, because `$` also matches just before a trailing newline
# and would let a name such as "my_project.MyTypeSpec\n" through.
_REGISTERED_NAME_RE = re.compile(r"^(\w+\.)+\w+\Z")
# TODO(b/173744905) tf_export this as "tf.register_type_spec". (And add a
# usage example to the docstring, once the API is public.)
#
# TODO(b/173744905) Update this decorator to apply to ExtensionType rather than
# TypeSpec (once we do refactoring to move to_components/from_components from
# TypeSpec to ExtensionType).
def register(name):
  """Class decorator factory that binds a globally unique name to a TypeSpec.

  Args:
    name: The name of the type spec.  Must be globally unique.  Must have
      the form `"{project_name}.{type_name}"`.  E.g. `"my_project.MyTypeSpec"`.

  Returns:
    A class decorator that registers the decorated class with the given name.

  Raises:
    TypeError: If `name` is not a string.  (The returned decorator raises
      `TypeError` if it is applied to something other than a `TypeSpec`
      subclass.)
    ValueError: If `name` does not have the required form.  (The returned
      decorator raises `ValueError` if the class or the name has already been
      registered.)
  """
  if not isinstance(name, str):
    raise TypeError("Expected `name` to be a string; got %r" % (name,))
  if _REGISTERED_NAME_RE.match(name) is None:
    raise ValueError(
        "Registered name must have the form '{project_name}.{type_name}' "
        "(e.g. 'my_project.MyTypeSpec'); got %r." % name)

  def decorator_fn(cls):
    if not isinstance(cls, type) or not issubclass(cls, TypeSpec):
      raise TypeError("Expected `cls` to be a TypeSpec; got %r" % (cls,))

    # A class may carry only one name ...
    if cls in _TYPE_SPEC_TO_NAME:
      previous_name = _TYPE_SPEC_TO_NAME[cls]
      raise ValueError("Class %s.%s has already been registered with name %s."
                       % (cls.__module__, cls.__name__, previous_name))

    # ... and a name may refer to only one class.
    if name in _NAME_TO_TYPE_SPEC:
      owner = _NAME_TO_TYPE_SPEC[name]
      raise ValueError("Name %s has already been registered for class %s.%s."
                       % (name, owner.__module__, owner.__name__))

    _TYPE_SPEC_TO_NAME[cls] = name
    _NAME_TO_TYPE_SPEC[name] = cls
    return cls

  return decorator_fn
# TODO(edloper) tf_export this as "tf.get_type_spec_name" (or some similar name)
def get_name(cls):
  """Looks up the name under which TypeSpec class `cls` was registered.

  Args:
    cls: A `TypeSpec` subclass.

  Returns:
    The registered name of `cls`.

  Raises:
    TypeError: If `cls` is not a `TypeSpec` subclass.
    ValueError: If `cls` has not been registered.
  """
  if not isinstance(cls, type) or not issubclass(cls, TypeSpec):
    raise TypeError("Expected `cls` to be a TypeSpec; got %r" % (cls,))
  if cls in _TYPE_SPEC_TO_NAME:
    return _TYPE_SPEC_TO_NAME[cls]
  raise ValueError("TypeSpec %s.%s has not been registered." %
                   (cls.__module__, cls.__name__))
# TODO(edloper) tf_export this as "tf.lookup_type_spec" (or some similar name)
def lookup(name):
  """Finds the TypeSpec class that was registered under `name`.

  Args:
    name: A registered TypeSpec name (a string).

  Returns:
    The `TypeSpec` subclass registered with that name.

  Raises:
    TypeError: If `name` is not a string.
    ValueError: If nothing has been registered under `name`.
  """
  if not isinstance(name, str):
    raise TypeError("Expected `name` to be a string; got %r" % (name,))
  if name in _NAME_TO_TYPE_SPEC:
    return _NAME_TO_TYPE_SPEC[name]
  raise ValueError("No TypeSpec has been registered with name %r" % (name,))