blob: 4c672426d6d6e16c79fe54897d7dedc4af78e125 [file] [log] [blame]
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorArray: a dynamically sized array of Tensors."""
# Mixture of pep8 and non-pep8 names, so disable pylint bad-name
# pylint: disable=g-bad-name
import contextlib
import traceback
import weakref
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export
# _GraphTensorArray accesses many of the hidden generated ops, but is in
# fact built to wrap these methods.
# pylint: disable=protected-access
class _GraphTensorArray:
"""Graph-mode implementation of TensorArray."""
def __init__(self,
dtype,
size=None,
dynamic_size=None,
clear_after_read=None,
tensor_array_name=None,
handle=None,
flow=None,
infer_shape=True,
element_shape=None,
colocate_with_first_write_call=True,
name=None):
"""Constructs a graph mode TensorArray.
Args:
dtype: (required) data type of the TensorArray.
size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
Required if handle is not provided.
dynamic_size: (optional) Python bool: If true, writes to the TensorArray
can grow the TensorArray past its initial size. Default: False.
clear_after_read: Boolean (optional, default: True). If True, clear
TensorArray values after reading them. This disables read-many
semantics, but allows early release of memory.
tensor_array_name: (optional) Python string: the name of the TensorArray.
This is used when creating the TensorArray handle. If this value is
set, handle should be None.
handle: (optional) A `Tensor` handle to an existing TensorArray. If this
is set, tensor_array_name should be None. Only supported in graph mode.
flow: (optional) A float `Tensor` scalar coming from an existing
`TensorArray.flow`. Only supported in graph mode.
infer_shape: (optional, default: True) If True, shape inference is
enabled. In this case, all elements must have the same shape.
element_shape: (optional, default: None) A `TensorShape` object specifying
the shape constraints of each of the elements of the TensorArray. Need
not be fully defined.
colocate_with_first_write_call: If `True`, the TensorArray will be
colocated on the same device as the Tensor used on its first write
(write operations include `write`, `unstack`, and `split`). If `False`,
the TensorArray will be placed on the device determined by the device
context available during its initialization.
name: A name for the operation (optional).
Raises:
ValueError: if both handle and tensor_array_name are provided.
TypeError: if handle is provided but is not a Tensor.
"""
if handle is not None and tensor_array_name:
raise ValueError(
"Cannot provide both `handle` and `tensor_array_name` arguments at "
"the same time.")
if handle is not None and not isinstance(handle, ops.Tensor):
raise TypeError(
f"Expected `handle` to be a Tensor, but got `{handle}` of type "
f"`{type(handle)}` instead.")
if handle is None and size is None:
raise ValueError(
"Argument `size` must be provided if handle is not provided.")
if handle is not None and size is not None:
raise ValueError("Cannot provide both a `handle` and `size` arguments "
"at the same time.")
if handle is not None and element_shape is not None:
raise ValueError(
"Cannot provide both `handle` and `element_shape` arguments "
"at the same time.")
if handle is not None and dynamic_size is not None:
raise ValueError(
"Cannot provide both `handle` and `dynamic_size` arguments "
"at the same time.")
if handle is not None and clear_after_read is not None:
raise ValueError(
"Cannot provide both `handle` and `clear_after_read` arguments "
"at the same time.")
if clear_after_read is None:
clear_after_read = True
self._dynamic_size = dynamic_size or False
self._dtype = dtypes.as_dtype(dtype).base_dtype
# Used to keep track of what tensors the TensorArray should be
# colocated with. We choose to colocate the TensorArray with the
# first tensor written to it.
self._colocate_with_first_write_call = colocate_with_first_write_call
if colocate_with_first_write_call:
self._colocate_with = []
else:
self._colocate_with = None
# Record the current static shape for the array elements. The element
# shape is defined either by `element_shape` or the shape of the tensor
# of the first write. If `infer_shape` is true, all writes checks for
# shape equality.
self._element_shape = [tensor_shape.as_shape(element_shape)]
self._infer_shape = infer_shape
self._size = size
with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope:
if handle is not None:
self._handle = handle
if flow is None:
raise ValueError("flow must not be None if handle is not None.")
self._flow = flow
else:
# Construct the TensorArray with an empty device. The first
# write into the TensorArray from a Tensor with a set device
# will retroactively set the device value of this op.
def create():
"""Create the TensorArray op."""
return gen_data_flow_ops.tensor_array_v3(
dtype=dtype,
size=size,
element_shape=element_shape,
identical_element_shapes=infer_shape,
dynamic_size=self._dynamic_size,
clear_after_read=clear_after_read,
tensor_array_name=tensor_array_name,
name=scope)
if colocate_with_first_write_call:
with ops.device(None), ops.colocate_with(None, ignore_existing=True):
self._handle, self._flow = create()
else:
self._handle, self._flow = create()
@property
def flow(self):
return self._flow
@property
def dtype(self):
return self._dtype
@property
def handle(self):
return self._handle
@property
def element_shape(self):
return self._element_shape[0]
def _check_element_shape(self, shape):
"""Changes the element shape of the array given a shape to merge with.
Args:
shape: A `TensorShape` object to merge with.
Raises:
ValueError: if the provided shape is incompatible with the current
element shape of the `TensorArray`.
"""
if not shape.is_compatible_with(self.element_shape):
raise ValueError("Inconsistent shapes: saw %s but expected %s " %
(shape, self.element_shape))
if self._infer_shape:
self._element_shape[0] = self.element_shape.merge_with(shape)
@contextlib.contextmanager
def _maybe_colocate_with(self, value):
"""Colocate operations with an internal colocation group or `value`.
Args:
value: `Tensor`, the tensor to try to colocate with.
Yields:
Does not yield anything, but the new context is a colocation context.
If no internal colocation group is set, colocate with `value` and set
the internal colocation group to be value.
"""
if not self._colocate_with_first_write_call:
yield
else:
if not self._colocate_with:
self._colocate_with.append(value)
with ops.colocate_with(self._colocate_with[0]):
yield
def identity(self):
"""See TensorArray."""
flow = array_ops.identity(self._flow)
return build_ta_with_new_flow(self, flow)
def grad(self, source, flow=None, name=None):
"""See TensorArray."""
# tensor_array_grad requires a flow input when forward
# TensorArrays are dynamically sized. This forces the creation
# of the grad TensorArray only once the final forward array's size
# is fixed.
if flow is None:
flow = self.flow
with ops.name_scope(name, "TensorArrayGrad", [self._handle]):
with ops.colocate_with(self._handle):
g_handle, unused_flow = gen_data_flow_ops.tensor_array_grad_v3(
handle=self._handle, source=source, flow_in=flow, name=name)
with ops.control_dependencies([g_handle]):
flow = array_ops.identity(flow, name="gradient_flow")
g = TensorArray(
dtype=self._dtype,
handle=g_handle,
flow=flow,
infer_shape=self._infer_shape,
colocate_with_first_write_call=False)
# pylint: disable=protected-access
g._implementation._element_shape = self._element_shape
# pylint: enable=protected-access
return g
def read(self, index, name=None):
"""See TensorArray."""
value = gen_data_flow_ops.tensor_array_read_v3(
handle=self._handle,
index=index,
flow_in=self._flow,
dtype=self._dtype,
name=name)
if self._element_shape:
value.set_shape(self._element_shape[0].dims)
return value
def write(self, index, value, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]):
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
self._check_element_shape(value.shape)
with self._maybe_colocate_with(value):
flow_out = gen_data_flow_ops.tensor_array_write_v3(
handle=self._handle,
index=index,
value=value,
flow_in=self._flow,
name=name)
return build_ta_with_new_flow(self, flow_out)
def stack(self, name=None):
"""See TensorArray."""
with ops.colocate_with(self._handle):
with ops.name_scope(name, "TensorArrayStack", [self._handle]):
value = self.gather(math_ops.range(0, self.size()), name=name)
if (self.element_shape and not self._dynamic_size and
self._size is not None):
value.set_shape([tensor_util.constant_value(self._size)] +
self.element_shape.dims)
return value
def gather(self, indices, name=None):
"""See TensorArray."""
if self._element_shape:
element_shape = self._element_shape[0]
else:
element_shape = tensor_shape.unknown_shape(None)
value = gen_data_flow_ops.tensor_array_gather_v3(
handle=self._handle,
indices=indices,
flow_in=self._flow,
dtype=self._dtype,
name=name,
element_shape=element_shape)
if self.element_shape:
value.set_shape([None] + self.element_shape.dims)
return value
def concat(self, name=None):
"""See TensorArray."""
value, _ = gen_data_flow_ops.tensor_array_concat_v3(
handle=self._handle,
flow_in=self._flow,
dtype=self._dtype,
name=name,
element_shape_except0=self.element_shape[1:])
if self.element_shape:
value.set_shape([None] + self.element_shape.dims[1:])
return value
@tf_should_use.should_use_result
def unstack(self, value, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayUnstack", [self._handle, value]):
num_elements = array_ops.shape(value)[0]
return self.scatter(
indices=math_ops.range(0, num_elements), value=value, name=name)
@tf_should_use.should_use_result
def scatter(self, indices, value, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayScatter",
[self._handle, value, indices]):
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
if not context.executing_eagerly():
self._check_element_shape(value.shape[1:])
with self._maybe_colocate_with(value):
flow_out = gen_data_flow_ops.tensor_array_scatter_v3(
handle=self._handle,
indices=indices,
value=value,
flow_in=self._flow,
name=name)
return build_ta_with_new_flow(self, flow_out)
@tf_should_use.should_use_result
def split(self, value, lengths, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArraySplit",
[self._handle, value, lengths]):
value = ops.convert_to_tensor(value, dtype=self._dtype, name="value")
with self._maybe_colocate_with(value):
lengths_64 = math_ops.cast(lengths, dtypes.int64)
if not context.executing_eagerly():
clengths = tensor_util.constant_value(lengths_64)
if value.shape.dims is not None and clengths is not None:
if clengths.shape and clengths.max() == clengths.min():
self._check_element_shape(
tensor_shape.TensorShape([clengths[0]
]).concatenate(value.shape[1:]))
flow_out = gen_data_flow_ops.tensor_array_split_v3(
handle=self._handle,
value=value,
lengths=lengths_64,
flow_in=self._flow,
name=name)
return build_ta_with_new_flow(self, flow_out)
def size(self, name=None):
"""See TensorArray."""
if not self._dynamic_size and self._size is not None:
return ops.convert_to_tensor(self._size, dtype=dtypes.int32)
else:
return gen_data_flow_ops.tensor_array_size_v3(
handle=self._handle, flow_in=self.flow, name=name)
@tf_should_use.should_use_result
def close(self, name=None):
"""See TensorArray."""
return gen_data_flow_ops.tensor_array_close_v3(
handle=self._handle, name=name)
class _GraphTensorArrayV2:
"""Graph-mode implementation of TensorArray backed by TensorLists.
The backing tensor of this TensorArray is a TensorList variant tensor which is
stored in the `flow`. The `handle` is always none here. The reason we use the
`flow` field and not the `handle` field is to ensure backwards compatibility
with legacy control flow.
"""
def __init__(self,
dtype,
size=None,
dynamic_size=None,
clear_after_read=None,
tensor_array_name=None,
handle=None,
flow=None,
infer_shape=True,
element_shape=None,
colocate_with_first_write_call=True,
name=None):
"""Constructs a graph mode TensorArray.
Args:
dtype: (required) data type of the TensorArray.
size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
Required if flow is not provided.
dynamic_size: (optional) Python bool: If true, writes to the TensorArray
can grow the TensorArray past its initial size. Default: False.
clear_after_read: (optional) unused. Not supported in TensorLists.
tensor_array_name: (optional) unused.
handle: (optional) Must always be None.
flow: (optional) A variant `Tensor` scalar for a TensorList.
infer_shape: (optional, default: True) If True, shape inference is
enabled. In this case, all elements must have the same shape.
element_shape: (optional, default: None) A `TensorShape` object specifying
the shape constraints of each of the elements of the TensorArray. Need
not be fully defined.
colocate_with_first_write_call: (optional). unused.
name: (optional) A name for the operation.
Raises:
ValueError: if both handle and tensor_array_name are provided.
TypeError: if handle is provided but is not a Tensor.
"""
assert handle is None
del handle
del clear_after_read
del tensor_array_name
del colocate_with_first_write_call
self._dynamic_size = dynamic_size
self._size = size
if (flow is not None and
(not isinstance(flow, ops.Tensor) or flow.dtype != dtypes.variant)):
raise TypeError(
f"Expected `flow` to be a variant tensor, but received `{flow.dtype}` "
f"instead.")
if flow is None and size is None:
raise ValueError("Argument `size` must be provided if argument `flow` "
"is not provided.")
if flow is not None and size is not None:
raise ValueError("Cannot provide both `flow` and `size` arguments "
"at the same time.")
if flow is not None and element_shape is not None:
raise ValueError(
"Cannot provide both `flow` and `element_shape` arguments"
"at the same time.")
self._dtype = dtypes.as_dtype(dtype).base_dtype
# Record the current static shape for the array elements. The element
# shape is defined either by `element_shape` or the shape of the tensor
# of the first write. If `infer_shape` is true, all writes checks for
# shape equality.
self._element_shape = [tensor_shape.as_shape(element_shape)]
self._infer_shape = infer_shape
with ops.name_scope(name, "TensorArrayV2", [size, flow]) as scope:
if flow is None:
self._flow = list_ops.tensor_list_reserve(
element_shape=element_shape,
num_elements=size,
element_dtype=dtype,
name=scope)
else:
self._flow = flow
# For backwards compatibility.
self._colocate_with_first_write_call = None
self._colocate_with = None
@property
def flow(self):
return self._flow
@property
def dtype(self):
return self._dtype
@property
def element_shape(self):
return self._element_shape[0]
@property
def handle(self):
# We intentionally do not raise an error so that legacy while_loop does not
# complain.
return None
def _check_element_shape(self, shape):
"""Changes the element shape of the array given a shape to merge with.
Args:
shape: A `TensorShape` object to merge with.
Raises:
ValueError: if the provided shape is incompatible with the current
element shape of the `TensorArray`.
"""
if not shape.is_compatible_with(self.element_shape):
raise ValueError("Inconsistent shapes: saw %s but expected %s " %
(shape, self.element_shape))
if self._infer_shape:
self._element_shape[0] = self.element_shape.merge_with(shape)
def identity(self):
"""See TensorArray."""
flow = array_ops.identity(self._flow)
return build_ta_with_new_flow(self, flow)
def grad(self, source, flow=None, name=None):
"""Not supported."""
raise NotImplementedError()
def read(self, index, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayV2Read", [self._flow, index]):
value = list_ops.tensor_list_get_item(
input_handle=self._flow,
index=index,
element_dtype=self._dtype,
element_shape=self.element_shape,
name=name)
return value
def write(self, index, value, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayV2Write", [self._flow, index, value]):
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
self._check_element_shape(value.shape)
flow_out = list_ops.tensor_list_set_item(
input_handle=self._flow,
index=index,
item=value,
resize_if_index_out_of_bounds=self._dynamic_size,
name=name)
return build_ta_with_new_flow(self, flow_out)
def stack(self, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayV2Stack", [self._flow]):
# TODO(b/139941163): remove constant_value after changing num_elements to regular input
if not self._dynamic_size and self._size is not None:
ta_size = tensor_util.constant_value(self._size)
else:
ta_size = -1
value = list_ops.tensor_list_stack(
input_handle=self._flow,
element_dtype=self._dtype,
num_elements=ta_size,
element_shape=self.element_shape)
return value
def gather(self, indices, name=None):
"""See TensorArray."""
value = list_ops.tensor_list_gather(
input_handle=self._flow,
indices=indices,
element_dtype=self._dtype,
element_shape=self.element_shape,
name=name)
return value
def concat(self, name=None):
"""See TensorArray."""
if self.element_shape:
element_shape = [None] + self.element_shape.dims[1:]
else:
element_shape = None
value = list_ops.tensor_list_concat(
input_handle=self._flow,
element_dtype=self._dtype,
element_shape=element_shape,
name=name)
return value
@tf_should_use.should_use_result
def unstack(self, value, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayUnstack", [self._flow, value]):
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
self._check_element_shape(value.shape[1:])
flow_out = list_ops.tensor_list_from_tensor(
tensor=value, element_shape=value.shape[1:])
return build_ta_with_new_flow(self, flow_out)
@tf_should_use.should_use_result
def scatter(self, indices, value, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayScatter",
[self._flow, value, indices]):
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
self._check_element_shape(value.shape[1:])
flow_out = list_ops.tensor_list_scatter(
tensor=value,
indices=indices,
element_shape=self.element_shape,
input_handle=self._flow)
return build_ta_with_new_flow(self, flow_out)
@tf_should_use.should_use_result
def split(self, value, lengths, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArraySplit", [self._flow, value, lengths]):
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
lengths_64 = math_ops.cast(lengths, dtypes.int64)
if not context.executing_eagerly():
clengths = tensor_util.constant_value(lengths_64)
if value.shape.dims is not None and clengths is not None:
if clengths.shape and clengths.max() == clengths.min():
self._check_element_shape(
tensor_shape.TensorShape([clengths[0]
]).concatenate(value.shape[1:]))
flow_out = list_ops.tensor_list_split(
tensor=value,
lengths=lengths_64,
element_shape=self.element_shape,
name=name)
return build_ta_with_new_flow(self, flow_out)
def size(self, name=None):
"""See TensorArray."""
if not self._dynamic_size and self._size is not None:
return ops.convert_to_tensor(self._size, dtype=dtypes.int32)
else:
return list_ops.tensor_list_length(input_handle=self._flow, name=name)
def close(self, name=None):
"""See TensorArray."""
return gen_control_flow_ops.no_op(name=name)
# pylint: enable=protected-access
class _EagerTensorArray:
"""Eager-compatible implementation of TensorArray."""
def __init__(self,
dtype,
size=None,
dynamic_size=None,
clear_after_read=None,
tensor_array_name=None,
handle=None,
flow=None,
infer_shape=True,
element_shape=None,
colocate_with_first_write_call=True,
name=None):
"""Constructs a TensorArray compatible with eager execution.
Args:
dtype: (required) data type of the TensorArray.
size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
Required if handle is not provided.
dynamic_size: (optional) Python bool: If true, writes to the TensorArray
can grow the TensorArray past its initial size. Default: False.
clear_after_read: Boolean (optional, default: True). If True, clear
TensorArray values after reading them. This disables read-many
semantics, but allows early release of memory.
tensor_array_name: unused.
handle: unsupported.
flow: unsupported.
infer_shape: used for error checking, same semantics as TensorArray.
element_shape: used for error checking, same semantics as TensorArray.
colocate_with_first_write_call: unsupported.
name: unsupported.
Raises:
ValueError: handle or flow are supplied, or if size is not supplied.
"""
del (flow, tensor_array_name, name) # Unused.
if handle is not None:
raise ValueError("TensorArray handles are not supported when eager "
"execution is enabled.")
if size is None:
raise ValueError("Size must be declared for TensorArrays when eager "
"execution is enabled.")
# These attributes are not meaningful when eager is enabled, but some
# library functions (e.g., those in control_flow_ops.py) access them to
# create new tensor arrays; as such, we define them for the sake of
# compatibility.
self._handle = None
# we assign a dummy value to _flow in case other code assumes it to be
# a Tensor
self._flow = constant_op.constant(0, dtype=dtypes.int32)
self._infer_shape = infer_shape
self._element_shape = tensor_shape.as_shape(element_shape)
self._colocate_with_first_write_call = colocate_with_first_write_call
self._dtype = dtypes.as_dtype(dtype).base_dtype
self._dynamic_size = dynamic_size or False
self._clear_after_read = (True
if clear_after_read is None else clear_after_read)
self._previously_read_indices = []
if isinstance(size, ops.EagerTensor):
size = size.numpy()
self._tensor_array = [None for _ in range(size)]
@property
def flow(self):
"""For compatibility; flows are not meaningful when eager is enabled."""
return self._flow
@property
def dtype(self):
return self._dtype
@property
def handle(self):
"""For compatibility; handles are not meaningful when eager is enabled."""
return self._handle
@property
def element_shape(self):
return self._element_shape
def identity(self):
"""See TensorArray."""
return self.parent()
def grad(self, source, flow=None, name=None):
raise NotImplementedError(
"TensorArray.grad is not supported when executing eagerly; eager's "
"gradient implementation does not use/need this function to compute "
"gradients of operations that use TensorArrays.")
def read(self, index, name=None):
"""See TensorArray."""
del name # not meaningful when executing eagerly.
if isinstance(index, ops.EagerTensor):
index = index.numpy()
if index < 0:
raise errors_impl.OutOfRangeError(
None, None,
"Reading from negative indices (index %d) is not allowed." % index)
if index >= len(self._tensor_array):
raise errors_impl.OutOfRangeError(
None, None, "Tried to read from index %d but array size is: %d " %
(index, len(self._tensor_array)))
tensor = self._tensor_array[index]
if tensor is None:
if index in self._previously_read_indices:
raise errors_impl.InvalidArgumentError(
None, None,
"Could not read index %d twice because it was cleared after "
"a previous read (perhaps try setting clear_after_read = false?)" %
index)
else:
tensor = self._maybe_zero(index)
if self._clear_after_read:
self._tensor_array[index] = None
self._previously_read_indices.append(index)
return tensor
def _write(self, index, value):
"""Writes `value` into index named by `index`.
Args:
index: 0-D. int32 scalar with the index to write to.
value: N-D. Tensor of type `dtype`. The `Tensor` to write to `index`.
Raises:
errors_impl.InvalidArgumentError: `value` dtype does not match dtype.
errors_impl.OutOfRangeError: `index` is out of bounds.
ValueError: shape of `value` is not consistent with inferred shape.
"""
if isinstance(index, ops.EagerTensor):
index = index.numpy()
if index < 0:
raise errors_impl.OutOfRangeError(
None, None,
"Writing to negative indices (index %d) is not allowed." % index)
size = len(self._tensor_array)
if index >= size:
if not self._dynamic_size:
raise errors_impl.OutOfRangeError(
None, None,
"Tried to write to index %d but array is not resizeable and size "
"is: %d " % (index, size))
self._tensor_array.extend(None for _ in range(index - size + 1))
if not isinstance(value, ops.EagerTensor):
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
if self._dtype != value.dtype:
raise errors_impl.InvalidArgumentError(
None, None,
"TensorArray dtype is %s but Op is trying to write dtype %s " %
(self._dtype.name, value.dtype.name))
if not self._element_shape.is_compatible_with(value.shape):
raise ValueError("Incompatible shape for value (%s), expected (%s)" %
(value.shape, self._element_shape))
if self._infer_shape:
self._element_shape = self._element_shape.merge_with(value.shape)
self._tensor_array[index] = value
def write(self, index, value, name=None):
"""See TensorArray."""
del name # not meaningful when executing eagerly.
self._write(index, value)
return self.parent()
def _maybe_zero(self, ix):
val = self._tensor_array[ix]
if val is None:
val = self._tensor_array[ix] = array_ops.zeros(
shape=self._element_shape, dtype=self._dtype)
return val
def stack(self, name=None):
"""See TensorArray."""
if self._tensor_array:
for ix in range(len(self._tensor_array)):
self._maybe_zero(ix)
if not self._tensor_array and self._element_shape.is_fully_defined():
return ops.convert_to_tensor(
np.ndarray([0] + self._element_shape), name=name, dtype=self._dtype)
else:
return ops.convert_to_tensor(
self._tensor_array, name=name, dtype=self._dtype)
def gather(self, indices, name=None):
"""See TensorArray."""
del name # not meaningful when executing eagerly.
if isinstance(indices, ops.EagerTensor):
indices = indices.numpy()
return array_ops.stack([self._maybe_zero(i) for i in indices])
def concat(self, name=None):
"""See TensorArray."""
try:
return array_ops.concat(
[self._maybe_zero(ix) for ix in range(len(self._tensor_array))],
0,
name=name)
except errors_impl.OpError:
# Reproduce a subset of the error-handling for graph-mode TensorArrays.
shapes = [t.shape for t in self._tensor_array]
ndims = [s.ndims for s in shapes]
if 0 in ndims:
idx = ndims.index(0)
raise errors_impl.InvalidArgumentError(
None, None, "Concat saw a scalar shape at index %d but requires "
"at least vectors." % idx)
else:
raise
def unstack(self, value, name=None):
"""See TensorArray."""
tensors = array_ops.unstack(value, name=name)
if len(tensors) > len(self._tensor_array) and not self._dynamic_size:
raise ValueError(
"Cannot unstack %d tensors into a TensorArray of static size %d " %
(len(tensors), len(self._tensor_array)))
self._tensor_array = tensors
return self.parent()
def scatter(self, indices, value, name=None):
"""See TensorArray."""
del name # not meaningful when executing eagerly.
if isinstance(indices, ops.EagerTensor):
indices = indices.numpy()
for index, val in zip(indices, array_ops.unstack(value)):
self._write(index, val) # pylint: disable=protected-access
return self.parent()
def split(self, value, lengths, name=None):
"""See TensorArray."""
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
lengths = ops.convert_to_tensor(lengths)
sum_lengths = math_ops.reduce_sum(lengths)
if lengths.shape.ndims != 1:
raise errors_impl.InvalidArgumentError(
None, None, "Expected lengths to be a vector, received shape: %s " %
lengths.shape.as_list())
elif value.shape.ndims == 0:
raise errors_impl.InvalidArgumentError(
None, None, "Expected value to be at least a vector, "
"but received shape: %s " % value.shape.as_list())
elif sum_lengths.numpy() != value.shape.as_list()[0]:
raise errors_impl.InvalidArgumentError(
None, None, "Expected sum of lengths to be equal to "
"values.shape[0], but sum of lengths is %d and "
"value's shape is: %s " % (sum_lengths.numpy(),
value.shape.as_list()))
elif not self._dynamic_size and lengths.shape[0] != len(self._tensor_array):
raise errors_impl.InvalidArgumentError(
None, None, "TensorArray's size is not equal to the size of "
"lengths (%d vs. %d), and the TensorArray is not marked as "
"dynamically resizeable." %
(len(self._tensor_array), lengths.shape[0]))
else:
self._tensor_array = array_ops.split(value, lengths, name=name)
return self.parent()
def size(self, name=None):
"""See TensorArray."""
del name # not meaningful when executing eagerly.
return constant_op.constant(len(self._tensor_array))
def close(self, name=None):
del name # not meaningful when executing eagerly.
del self._tensor_array[:]
# TensorArray is designed to hide an underlying implementation object
# and as such accesses many of that object's hidden fields.
# pylint: disable=protected-access
# pylint:disable=line-too-long
@tf_export("TensorArray")
class TensorArray:
"""Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays.
This class is meant to be used with dynamic iteration primitives such as
`while_loop` and `map_fn`. It supports gradient back-propagation via special
"flow" control flow dependencies.
Example 1: Plain reading and writing.
>>> ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True, clear_after_read=False)
>>> ta = ta.write(0, 10)
>>> ta = ta.write(1, 20)
>>> ta = ta.write(2, 30)
>>>
>>> ta.read(0)
<tf.Tensor: shape=(), dtype=float32, numpy=10.0>
>>> ta.read(1)
<tf.Tensor: shape=(), dtype=float32, numpy=20.0>
>>> ta.read(2)
<tf.Tensor: shape=(), dtype=float32, numpy=30.0>
>>> ta.stack()
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([10., 20., 30.],
dtype=float32)>
Example 2: Fibonacci sequence algorithm that writes in a loop then returns.
>>> @tf.function
... def fibonacci(n):
... ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
... ta = ta.unstack([0., 1.])
...
... for i in range(2, n):
... ta = ta.write(i, ta.read(i - 1) + ta.read(i - 2))
...
... return ta.stack()
>>>
>>> fibonacci(7)
<tf.Tensor: shape=(7,), dtype=float32,
numpy=array([0., 1., 1., 2., 3., 5., 8.], dtype=float32)>
Example 3: A simple loop interacting with a `tf.Variable`.
>>> v = tf.Variable(1)
>>> @tf.function
... def f(x):
... ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
... for i in tf.range(x):
... v.assign_add(i)
... ta = ta.write(i, v)
... return ta.stack()
>>> f(5)
<tf.Tensor: shape=(5,), dtype=int32, numpy=array([ 1, 2, 4, 7, 11],
dtype=int32)>
"""
def __init__(self,
dtype,
size=None,
dynamic_size=None,
clear_after_read=None,
tensor_array_name=None,
handle=None,
flow=None,
infer_shape=True,
element_shape=None,
colocate_with_first_write_call=True,
name=None):
"""Construct a new TensorArray or wrap an existing TensorArray handle.
A note about the parameter `name`:
The name of the `TensorArray` (even if passed in) is uniquified: each time
a new `TensorArray` is created at runtime it is assigned its own name for
the duration of the run. This avoids name collisions if a `TensorArray`
is created within a `while_loop`.
Args:
dtype: (required) data type of the TensorArray.
size: (optional) int32 scalar `Tensor`: the size of the TensorArray.
Required if handle is not provided.
dynamic_size: (optional) Python bool: If true, writes to the TensorArray
can grow the TensorArray past its initial size. Default: False.
clear_after_read: Boolean (optional, default: True). If True, clear
TensorArray values after reading them. This disables read-many
semantics, but allows early release of memory.
tensor_array_name: (optional) Python string: the name of the TensorArray.
This is used when creating the TensorArray handle. If this value is
set, handle should be None.
handle: (optional) A `Tensor` handle to an existing TensorArray. If this
is set, tensor_array_name should be None. Only supported in graph mode.
flow: (optional) A float `Tensor` scalar coming from an existing
`TensorArray.flow`. Only supported in graph mode.
infer_shape: (optional, default: True) If True, shape inference is
enabled. In this case, all elements must have the same shape.
element_shape: (optional, default: None) A `TensorShape` object specifying
the shape constraints of each of the elements of the TensorArray. Need
not be fully defined.
colocate_with_first_write_call: If `True`, the TensorArray will be
colocated on the same device as the Tensor used on its first write
(write operations include `write`, `unstack`, and `split`). If `False`,
the TensorArray will be placed on the device determined by the device
context available during its initialization.
name: A name for the operation (optional).
Raises:
ValueError: if both handle and tensor_array_name are provided.
TypeError: if handle is provided but is not a Tensor.
"""
if (context.executing_eagerly() and
(flow is None or flow.dtype != dtypes.variant)):
# It is possible to create a Variant-style TensorArray even in eager mode,
# and this is fine but can have performance implications in eager.
# An example of when this happens is if a tf.function returns a
# TensorArray in its output; its flow variant object is returned to Eager.
# This can be wrapped back up in a Variant-style TensorArray.
implementation = _EagerTensorArray
elif (flow is not None and flow.dtype == dtypes.variant or
control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
implementation = _GraphTensorArrayV2
else:
implementation = _GraphTensorArray
self._implementation = implementation(
dtype,
size=size,
dynamic_size=dynamic_size,
clear_after_read=clear_after_read,
tensor_array_name=tensor_array_name,
handle=handle,
flow=flow,
infer_shape=infer_shape,
element_shape=element_shape,
colocate_with_first_write_call=colocate_with_first_write_call,
name=name)
self._implementation.parent = weakref.ref(self)
@property
def flow(self):
"""The flow `Tensor` forcing ops leading to this TensorArray state."""
return self._implementation._flow
@property
def dtype(self):
"""The data type of this TensorArray."""
return self._implementation._dtype
@property
def handle(self):
"""The reference to the TensorArray."""
return self._implementation.handle
@property
def element_shape(self):
"""The `tf.TensorShape` of elements in this TensorArray."""
return self._implementation.element_shape
@property
def dynamic_size(self):
"""Python bool; if `True` the TensorArray can grow dynamically."""
return self._implementation._dynamic_size
@property
def _infer_shape(self):
# TODO(slebedev): consider making public or changing TensorArrayStructure
# to access _implementation directly. Note that dynamic_size is also
# only used by TensorArrayStructure.
return self._implementation._infer_shape
def identity(self):
"""Returns a TensorArray with the same content and properties.
Returns:
A new TensorArray object with flow that ensures the control dependencies
from the contexts will become control dependencies for writes, reads, etc.
Use this object for all subsequent operations.
"""
return self._implementation.identity()
def grad(self, source, flow=None, name=None):
return self._implementation.grad(source, flow=flow, name=name)
def read(self, index, name=None):
"""Read the value at location `index` in the TensorArray.
Args:
index: 0-D. int32 tensor with the index to read from.
name: A name for the operation (optional).
Returns:
The tensor at index `index`.
"""
return self._implementation.read(index, name=name)
@tf_should_use.should_use_result(warn_in_eager=True)
def write(self, index, value, name=None):
"""Write `value` into index `index` of the TensorArray.
Args:
index: 0-D. int32 scalar with the index to write to.
value: N-D. Tensor of type `dtype`. The Tensor to write to this index.
name: A name for the operation (optional).
Returns:
A new TensorArray object with flow that ensures the write occurs.
Use this object for all subsequent operations.
Raises:
ValueError: if there are more writers than specified.
"""
return self._implementation.write(index, value, name=name)
def stack(self, name=None):
"""Return the values in the TensorArray as a stacked `Tensor`.
All of the values must have been written and their shapes must all match.
If input shapes have rank-`R`, then output shape will have rank-`(R+1)`.
For example:
>>> ta = tf.TensorArray(tf.int32, size=3)
>>> ta.write(0, tf.constant([1, 2]))
>>> ta.write(1, tf.constant([3, 4]))
>>> ta.write(2, tf.constant([5, 6]))
>>> ta.stack()
<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[1, 2],
[3, 4],
[5, 6]], dtype=int32)>
Args:
name: A name for the operation (optional).
Returns:
All the tensors in the TensorArray stacked into one tensor.
"""
return self._implementation.stack(name=name)
def gather(self, indices, name=None):
"""Return selected values in the TensorArray as a packed `Tensor`.
All of selected values must have been written and their shapes
must all match.
Args:
indices: A `1-D` `Tensor` taking values in `[0, max_value)`. If the
`TensorArray` is not dynamic, `max_value=size()`.
name: A name for the operation (optional).
Returns:
The tensors in the `TensorArray` selected by `indices`, packed into one
tensor.
"""
return self._implementation.gather(indices, name=name)
def concat(self, name=None):
"""Return the values in the TensorArray as a concatenated `Tensor`.
All of the values must have been written, their ranks must match, and
and their shapes must all match for all dimensions except the first.
Args:
name: A name for the operation (optional).
Returns:
All the tensors in the TensorArray concatenated into one tensor.
"""
return self._implementation.concat(name=name)
@tf_should_use.should_use_result
def unstack(self, value, name=None):
"""Unstack the values of a `Tensor` in the TensorArray.
If input value shapes have rank-`R`, then the output TensorArray will
contain elements whose shapes are rank-`(R-1)`.
Args:
value: (N+1)-D. Tensor of type `dtype`. The Tensor to unstack.
name: A name for the operation (optional).
Returns:
A new TensorArray object with flow that ensures the unstack occurs.
Use this object for all subsequent operations.
Raises:
ValueError: if the shape inference fails.
"""
return self._implementation.unstack(value, name=name)
@tf_should_use.should_use_result
def scatter(self, indices, value, name=None):
"""Scatter the values of a `Tensor` in specific indices of a `TensorArray`.
Args:
indices: A `1-D` `Tensor` taking values in `[0, max_value)`. If the
`TensorArray` is not dynamic, `max_value=size()`.
value: (N+1)-D. Tensor of type `dtype`. The Tensor to unpack.
name: A name for the operation (optional).
Returns:
A new TensorArray object with flow that ensures the scatter occurs.
Use this object for all subsequent operations.
Raises:
ValueError: if the shape inference fails.
"""
return self._implementation.scatter(indices, value, name=name)
@tf_should_use.should_use_result
def split(self, value, lengths, name=None):
"""Split the values of a `Tensor` into the TensorArray.
Args:
value: (N+1)-D. Tensor of type `dtype`. The Tensor to split.
lengths: 1-D. int32 vector with the lengths to use when splitting `value`
along its first dimension.
name: A name for the operation (optional).
Returns:
A new TensorArray object with flow that ensures the split occurs.
Use this object for all subsequent operations.
Raises:
ValueError: if the shape inference fails.
"""
return self._implementation.split(value, lengths, name=name)
def size(self, name=None):
"""Return the size of the TensorArray."""
return self._implementation.size(name=name)
@tf_should_use.should_use_result
def close(self, name=None):
"""Close the current TensorArray."""
return self._implementation.close(name=name)
def build_ta_with_new_flow(old_ta, flow):
"""Builds a TensorArray with a new `flow` tensor."""
# Sometimes we get old_ta as the implementation, sometimes it's the
# TensorArray wrapper object.
impl = (old_ta._implementation if isinstance(old_ta, TensorArray) else old_ta)
if not context.executing_eagerly():
if (not isinstance(impl, _GraphTensorArrayV2) and
control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
raise NotImplementedError("Attempting to build a graph-mode TF2-style "
"TensorArray from either an eager-mode "
"TensorArray or a TF1-style TensorArray. "
"This is not currently supported. You may be "
"attempting to capture a TensorArray "
"inside a tf.function or tf.data map function. "
"Instead, construct a new TensorArray inside "
"the function.")
new_ta = TensorArray(
dtype=impl.dtype,
handle=impl.handle,
flow=flow,
infer_shape=impl._infer_shape,
colocate_with_first_write_call=impl._colocate_with_first_write_call)
new_impl = new_ta._implementation
new_impl._dynamic_size = impl._dynamic_size
new_impl._size = impl._size
new_impl._colocate_with = impl._colocate_with
new_impl._element_shape = impl._element_shape # Share _element_shape.
return new_ta
# pylint: enable=protected-access
def _check_dtypes(value, dtype):
if value.dtype != dtype:
logging.error("Error: Input value {} has dtype {}, but expected dtype {}. "
"This leads to undefined behavior and will be an error "
"in future versions of TensorFlow. Traceback:\n{}".format(
value, str(value.dtype), str(dtype),
"".join(traceback.format_stack())))
@tf_export("TensorArraySpec")
@type_spec.register("tf.TensorArraySpec")
class TensorArraySpec(type_spec.TypeSpec):
"""Type specification for a `tf.TensorArray`."""
__slots__ = ["_element_shape", "_dtype", "_dynamic_size", "_infer_shape"]
value_type = property(lambda self: TensorArray)
def __init__(self,
element_shape=None,
dtype=dtypes.float32,
dynamic_size=False,
infer_shape=True):
"""Constructs a type specification for a `tf.TensorArray`.
Args:
element_shape: The shape of each element in the `TensorArray`.
dtype: Data type of the `TensorArray`.
dynamic_size: Whether the `TensorArray` can grow past its initial size.
infer_shape: Whether shape inference is enabled.
"""
self._element_shape = tensor_shape.as_shape(element_shape)
self._dtype = dtypes.as_dtype(dtype)
self._dynamic_size = dynamic_size
self._infer_shape = infer_shape
def is_subtype_of(self, other):
# pylint: disable=protected-access
return (isinstance(other, TensorArraySpec) and
self._dtype == other._dtype and
self._dynamic_size == other._dynamic_size)
def most_specific_common_supertype(self, others):
"""Returns the most specific supertype of `self` and `others`.
Args:
others: A Sequence of `TypeSpec`.
Returns `None` if a supertype does not exist.
"""
# pylint: disable=protected-access
if not all(isinstance(other, TensorArraySpec) for other in others):
return False
common_shape = self._element_shape.most_specific_common_supertype(
other._element_shape for other in others)
if common_shape is None:
return None
if not all(self._dtype == other._dtype for other in others):
return None
if not all(self._dynamic_size == other._dynamic_size for other in others):
return None
infer_shape = self._infer_shape and all(
other._infer_shape for other in others)
return TensorArraySpec(common_shape, self._dtype, self._dynamic_size,
infer_shape)
def is_compatible_with(self, other):
# pylint: disable=protected-access
if not isinstance(other, type_spec.TypeSpec):
other = type_spec.type_spec_from_value(other)
# Note: we intentionally exclude infer_shape in this check.
return (isinstance(other, TensorArraySpec) and
self._dtype.is_compatible_with(other._dtype) and
self._element_shape.is_compatible_with(other._element_shape) and
self._dynamic_size == other._dynamic_size)
# TODO(b/221472813): Migrate logic to most_specific_common_supertype.
def most_specific_compatible_type(self, other):
"""Deprecated. Use most_specific_common_supertype instead."""
# pylint: disable=protected-access
if not self.is_compatible_with(other):
raise ValueError(f"Type `{self}` is not compatible with `{other}`.")
infer_shape = self._infer_shape and other._infer_shape
return TensorArraySpec(
self._element_shape.most_specific_compatible_shape(
other._element_shape), self._dtype, self._dynamic_size, infer_shape)
def _serialize(self):
return (self._element_shape, self._dtype, self._dynamic_size,
self._infer_shape)
@property
def _component_specs(self):
return [tensor_spec.TensorSpec([], dtypes.variant)]
def _to_components(self, value):
if not isinstance(value, TensorArray):
raise TypeError("Expected value to be a TensorArray, but got: `{}`".format(
type(value)))
if value.flow is not None and value.flow.dtype == dtypes.variant:
return [value.flow]
else:
# Convert to a TF2-style TensorArray.
# TODO(ebrevdo): Add an "_as_variant" method to TensorArray class, or
# "implementation / as_variant" arg to TensorArray constructor.
with ops.name_scope("convert_tensor_array"):
flow = list_ops.tensor_list_from_tensor(
tensor=value.stack(), element_shape=value.element_shape)
return [flow]
def _from_components(self, tensor_list):
# This will return a TF2 Graph-style TensorArray because tensor_list[0] is
# a variant object. size == -1 implies unknown size.
ret = TensorArray(
dtype=self._dtype,
flow=tensor_list[0],
dynamic_size=self._dynamic_size,
infer_shape=self._infer_shape)
ret._implementation._element_shape = [self._element_shape] # pylint: disable=protected-access
return ret
@staticmethod
def from_value(value):
if not isinstance(value, TensorArray):
raise TypeError("Expected value to be a TensorArray, but got: `{}`".format(
type(value)))
return TensorArraySpec(
dtype=value.dtype,
element_shape=value.element_shape,
dynamic_size=value.dynamic_size,
infer_shape=value._infer_shape) # pylint: disable=protected-access
def _to_legacy_output_types(self):
return self._dtype
def _to_legacy_output_shapes(self):
# Sneak the dynamic_size and infer_shape values into the legacy shape.
return (tensor_shape.TensorShape([self._dynamic_size, self._infer_shape
]).concatenate(self._element_shape))
def _to_legacy_output_classes(self):
return TensorArray
# Register the TypeSpec for TensorArray. If TensorArray is updated to be a
# CompositeTensor, then this registration can be deleted.
type_spec.register_type_spec_from_value_converter(
TensorArray, TensorArraySpec.from_value, allow_subclass=True)