blob: 68c232010daac6eb5ca05fc911290cef9463ed9d [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A TensorSpec class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util.tf_export import tf_export
class DenseSpec(type_spec.TypeSpec):
  """Base spec for dense values, characterized by a shape, a dtype and a name.

  Subclasses are expected to supply `value_type`, which `is_compatible_with`
  and `_to_legacy_output_classes` rely on.
  """

  __slots__ = ["_shape", "_shape_tuple", "_dtype", "_name"]

  # A dense spec is its own (single) component spec.
  _component_specs = property(lambda self: self)

  def __init__(self, shape, dtype=dtypes.float32, name=None):
    """Initializes the spec.

    Args:
      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
      dtype: Value convertible to `tf.DType`. The type of the tensor values.
      name: Optional name for the Tensor.

    Raises:
      TypeError: If shape is not convertible to a `tf.TensorShape`, or dtype is
        not convertible to a `tf.DType`.
    """
    self._shape = tensor_shape.TensorShape(shape)
    # Hashable/comparable form of the shape used by `__eq__` and `__hash__`.
    # When `as_list()` raises ValueError the tuple form is recorded as None.
    try:
      dims = self.shape.as_list()
    except ValueError:
      self._shape_tuple = None
    else:
      self._shape_tuple = tuple(dims)
    self._dtype = dtypes.as_dtype(dtype)
    self._name = name

  @property
  def shape(self):
    """The `TensorShape` describing the tensor's shape."""
    return self._shape

  @property
  def dtype(self):
    """The `dtype` of the tensor's elements."""
    return self._dtype

  @property
  def name(self):
    """The name given at construction time, or None."""
    return self._name

  def is_compatible_with(self, spec_or_value):
    # Only dense specs and values of this spec's `value_type` can match.
    if not isinstance(spec_or_value, (DenseSpec, self.value_type)):
      return False
    return (self._dtype.is_compatible_with(spec_or_value.dtype) and
            self._shape.is_compatible_with(spec_or_value.shape))

  def __repr__(self):
    template = "{}(shape={}, dtype={}, name={})"
    class_name = type(self).__name__
    return template.format(class_name, self.shape, repr(self.dtype),
                           repr(self.name))

  def __hash__(self):
    # The name is deliberately left out of the hash; `__eq__` still checks it.
    key = (self._shape_tuple, self.dtype)
    return hash(key)

  def __eq__(self, other):
    # pylint: disable=protected-access
    if type(self) is not type(other):
      return False
    return (self._shape_tuple == other._shape_tuple and
            self._dtype == other._dtype and
            self._name == other._name)

  def __ne__(self, other):
    return not (self == other)

  def most_specific_compatible_type(self, other):
    if type(self) is not type(other) or self._dtype != other.dtype:
      raise ValueError("Types are not compatible: %r vs %r" % (self, other))
    merged_shape = self._shape.most_specific_compatible_shape(other.shape)
    # Keep the name only when both specs agree on it.
    shared_name = None
    if self._name == other.name:
      shared_name = self._name
    return type(self)(merged_shape, self._dtype, shared_name)

  def _serialize(self):
    # (shape, dtype, name) in the order accepted by `__init__`.
    return (self._shape, self._dtype, self._name)

  def _to_legacy_output_types(self):
    return self._dtype

  def _to_legacy_output_shapes(self):
    return self._shape

  def _to_legacy_output_classes(self):
    return self.value_type
@tf_export("TensorSpec")
@type_spec.register("tf.TensorSpec")
class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec):
  """Describes a tf.Tensor.

  Metadata for describing the `tf.Tensor` objects accepted or returned
  by some TensorFlow APIs.
  """

  __slots__ = []

  def is_compatible_with(self, spec_or_tensor):  # pylint:disable=useless-super-delegation
    """Returns True if spec_or_tensor is compatible with this TensorSpec.

    A spec or tensor is compatible when its dtype is compatible with this
    spec's dtype and its shape is compatible with this spec's shape (see
    `tf.TensorShape.is_compatible_with`).

    Args:
      spec_or_tensor: A tf.TensorSpec or a tf.Tensor

    Returns:
      True if spec_or_tensor is compatible with self.
    """
    return super(TensorSpec, self).is_compatible_with(spec_or_tensor)

  @classmethod
  def from_spec(cls, spec, name=None):
    """Returns a `TensorSpec` with the same shape and dtype as `spec`.

    >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="OriginalName")
    >>> tf.TensorSpec.from_spec(spec, "NewName")
    TensorSpec(shape=(8, 3), dtype=tf.int32, name='NewName')

    Args:
      spec: The `TypeSpec` used to create the new `TensorSpec`.
      name: The name for the new `TensorSpec`. Defaults to `spec.name`.
    """
    shape, dtype = spec.shape, spec.dtype
    return cls(shape, dtype, name or spec.name)

  @classmethod
  def from_tensor(cls, tensor, name=None):
    """Returns a `TensorSpec` that describes `tensor`.

    >>> tf.TensorSpec.from_tensor(tf.constant([1, 2, 3]))
    TensorSpec(shape=(3,), dtype=tf.int32, name=None)

    Args:
      tensor: The `tf.Tensor` that should be described.
      name: A name for the `TensorSpec`. For non-eager tensors this defaults
        to `tensor.op.name`; for eager tensors it defaults to None.

    Returns:
      A `TensorSpec` that describes `tensor`.

    Raises:
      ValueError: If `tensor` is not a `tf.Tensor`.
    """
    # EagerTensor is checked first: the op-name fallback below is only
    # applied to the non-eager branch.
    if isinstance(tensor, ops.EagerTensor):
      return TensorSpec(tensor.shape, tensor.dtype, name)
    if isinstance(tensor, ops.Tensor):
      return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name)
    raise ValueError("`tensor` should be a tf.Tensor")

  @property
  def value_type(self):
    """The Python type for values that are compatible with this TypeSpec."""
    return ops.Tensor

  def _to_components(self, value):
    def _not_convertible(offending):
      return ValueError("Value %r is not convertible to a tensor with dtype %s "
                        "and shape %s." % (offending, self._dtype, self._shape))

    try:
      tensor = ops.convert_to_tensor(value, self._dtype)
    except (TypeError, ValueError):
      raise _not_convertible(value)
    if not tensor.shape.is_compatible_with(self._shape):
      raise _not_convertible(tensor)
    return tensor

  def _from_components(self, components):
    # A tensor is its own single component.
    return components

  def _from_compatible_tensor_list(self, tensor_list):
    # TODO(b/112266545): It would be cleaner to create a new `ensure_shape()`
    # op here and return that, instead of mutating the input's shape using
    # `Tensor.set_shape()`. However, that would add extra ops, which could
    # impact performance. When this bug is resolved, we should be able to add
    # the `ensure_shape()` ops and optimize them away using contextual shape
    # information.
    assert len(tensor_list) == 1
    tensor = tensor_list[0]
    tensor.set_shape(self._shape)
    return tensor

  def _to_batchable_tensor_list(self, value, batched=False):
    if batched:
      merged = self._shape.merge_with(value.shape)
      if merged.ndims == 0:
        raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    return self._to_components(value)

  def _batch(self, batch_size):
    # Prepend the batch dimension; the name is not carried over.
    batch_dim = tensor_shape.TensorShape([batch_size])
    return TensorSpec(batch_dim.concatenate(self._shape), self._dtype)

  def _unbatch(self):
    if self._shape.ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    # Drop the leading (batch) dimension; the name is not carried over.
    inner_shape = self._shape[1:]
    return TensorSpec(inner_shape, self._dtype)
# TODO(b/133606651): Should is_compatible_with check min/max bounds?
class BoundedTensorSpec(TensorSpec):
  """A `TensorSpec` that specifies minimum and maximum values.

  Example usage:
  ```python
  spec = tensor_spec.BoundedTensorSpec((1, 2, 3), tf.float32, 0, (5, 5, 5))
  tf_minimum = tf.convert_to_tensor(spec.minimum, dtype=spec.dtype)
  tf_maximum = tf.convert_to_tensor(spec.maximum, dtype=spec.dtype)
  ```

  Bounds are meant to be inclusive. This is especially important for
  integer types. The following spec will be satisfied by tensors
  with values in the set {0, 1, 2}:
  ```python
  spec = tensor_spec.BoundedTensorSpec((3, 5), tf.int32, 0, 2)
  ```
  """

  __slots__ = ("_minimum", "_maximum")

  def __init__(self, shape, dtype, minimum, maximum, name=None):
    """Initializes a new `BoundedTensorSpec`.

    Args:
      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
      dtype: Value convertible to `tf.DType`. The type of the tensor values.
      minimum: Number or sequence specifying the minimum element bounds
        (inclusive). Must be broadcastable to `shape`.
      maximum: Number or sequence specifying the maximum element bounds
        (inclusive). Must be broadcastable to `shape`.
      name: Optional string containing a semantic name for the corresponding
        array. Defaults to `None`.

    Raises:
      ValueError: If `minimum` or `maximum` are not provided or not
        broadcastable to `shape`.
      TypeError: If the shape is not an iterable or if the `dtype` is an invalid
        numpy dtype.
    """
    super(BoundedTensorSpec, self).__init__(shape, dtype, name)

    if minimum is None or maximum is None:
      raise ValueError("minimum and maximum must be provided; but saw "
                       "'%s' and '%s'" % (minimum, maximum))

    # Validate that each bound can be broadcast to the spec's shape.
    try:
      minimum_shape = np.shape(minimum)
      common_shapes.broadcast_shape(
          tensor_shape.TensorShape(minimum_shape), self.shape)
    except ValueError as exception:
      raise ValueError("minimum is not compatible with shape. "
                       "Message: {!r}.".format(exception))

    try:
      maximum_shape = np.shape(maximum)
      common_shapes.broadcast_shape(
          tensor_shape.TensorShape(maximum_shape), self.shape)
    except ValueError as exception:
      raise ValueError("maximum is not compatible with shape. "
                       "Message: {!r}.".format(exception))

    # Store the bounds as read-only NumPy arrays of the spec's dtype.
    self._minimum = np.array(minimum, dtype=self.dtype.as_numpy_dtype)
    self._minimum.setflags(write=False)

    self._maximum = np.array(maximum, dtype=self.dtype.as_numpy_dtype)
    self._maximum.setflags(write=False)

  @classmethod
  def from_spec(cls, spec, name=None):
    """Returns a `BoundedTensorSpec` with the same shape and dtype as `spec`.

    If `spec` is a `BoundedTensorSpec`, then the new spec's bounds are set to
    `spec.minimum` and `spec.maximum`; otherwise, the bounds are set to
    `spec.dtype.min` and `spec.dtype.max`.

    >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="x")
    >>> BoundedTensorSpec.from_spec(spec)
    BoundedTensorSpec(shape=(8, 3), dtype=tf.int32, name='x',
        minimum=array(-2147483648, dtype=int32),
        maximum=array(2147483647, dtype=int32))

    Args:
      spec: The `TypeSpec` used to create the new `BoundedTensorSpec`.
      name: The name for the new `BoundedTensorSpec`. Defaults to `spec.name`.
    """
    dtype = dtypes.as_dtype(spec.dtype)
    # Only fall back to the dtype's range when the spec carries no bounds of
    # its own: `dtype.min` / `dtype.max` may raise for some dtypes (NOTE:
    # e.g. non-numeric ones -- confirm against DType), and that must not
    # break copying a spec that already has bounds.
    minimum = spec.minimum if hasattr(spec, "minimum") else dtype.min
    maximum = spec.maximum if hasattr(spec, "maximum") else dtype.max
    return BoundedTensorSpec(spec.shape, dtype, minimum, maximum,
                             name or spec.name)

  @property
  def minimum(self):
    """Returns a NumPy array specifying the minimum bounds (inclusive)."""
    return self._minimum

  @property
  def maximum(self):
    """Returns a NumPy array specifying the maximum bounds (inclusive)."""
    return self._maximum

  def __repr__(self):
    s = "BoundedTensorSpec(shape={}, dtype={}, name={}, minimum={}, maximum={})"
    return s.format(self.shape, repr(self.dtype), repr(self.name),
                    repr(self.minimum), repr(self.maximum))

  def _bounds_equal(self, lhs, rhs):
    """Returns True if two (broadcast-compatible) bound arrays are equal.

    Floating-point and complex bounds are compared with `np.allclose` to
    tolerate round-off. All other dtypes (integers, bool, ...) are compared
    exactly: `np.allclose` applies a relative tolerance and casts to float,
    which makes distinct large integer bounds (e.g. int64 max vs. max - 1)
    compare equal and cannot handle non-numeric values.
    """
    if self.dtype.is_floating or self.dtype.is_complex:
      return np.allclose(lhs, rhs)
    return bool(np.all(lhs == rhs))

  def __eq__(self, other):
    tensor_spec_eq = super(BoundedTensorSpec, self).__eq__(other)
    if not tensor_spec_eq:
      return False
    # Same type, shape, dtype and name; now compare the bounds.
    return (self._bounds_equal(self.minimum, other.minimum) and
            self._bounds_equal(self.maximum, other.maximum))

  def __hash__(self):
    # Must be redefined: overriding `__eq__` would otherwise make instances
    # unhashable. Bounds are omitted, which is consistent with `__eq__`.
    return hash((self._shape_tuple, self.dtype))

  def __reduce__(self):
    return BoundedTensorSpec, (self._shape, self._dtype, self._minimum,
                               self._maximum, self._name)

  def _serialize(self):
    # Same argument order as `__init__`.
    return (self._shape, self._dtype, self._minimum, self._maximum, self._name)
_pywrap_utils.RegisterType("TensorSpec", TensorSpec)


def _tensor_spec_from_value(value):
  """Builds an unnamed `TensorSpec` from `value.shape` and `value.dtype`."""
  return TensorSpec(value.shape, value.dtype)


# Note: we do not include Tensor names when constructing TypeSpecs.
type_spec.register_type_spec_from_value_converter(
    ops.Tensor, _tensor_spec_from_value)
type_spec.register_type_spec_from_value_converter(
    np.ndarray, _tensor_spec_from_value)