# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""StructuredTensor array ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Sequence
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged.row_partition import RowPartition
from tensorflow.python.ops.structured.structured_tensor import StructuredTensor
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
@dispatch.dispatch_for_types(array_ops.expand_dims, StructuredTensor)
@deprecation.deprecated_args(None, 'Use the `axis` argument instead', 'dim')
def expand_dims(input, axis=None, name=None, dim=None): # pylint: disable=redefined-builtin
"""Creates a StructuredTensor with a length 1 axis inserted at index `axis`.
This is an implementation of tf.expand_dims for StructuredTensor. Note
that the `axis` must be less than or equal to rank.
>>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
>>> tf.expand_dims(st, 0).to_pyval()
[[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
>>> tf.expand_dims(st, 1).to_pyval()
[[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
>>> tf.expand_dims(st, 2).to_pyval()
[[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
>>> tf.expand_dims(st, -1).to_pyval() # -1 is the same as 2
[[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
Args:
input: the original StructuredTensor.
axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank`
name: the name of the op.
dim: deprecated: use axis.
Returns:
a new structured tensor with larger rank.
Raises:
an error if `axis < -(rank + 1)` or `rank < axis`.
"""
axis = deprecation.deprecated_argument_lookup('axis', axis, 'dim', dim)
return _expand_dims_impl(input, axis, name=name)
@dispatch.dispatch_for_types(array_ops.expand_dims_v2, StructuredTensor)
def expand_dims_v2(input, axis, name=None): # pylint: disable=redefined-builtin
"""Creates a StructuredTensor with a length 1 axis inserted at index `axis`.
This is an implementation of tf.expand_dims for StructuredTensor. Note
that the `axis` must be less than or equal to rank.
>>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
>>> tf.expand_dims(st, 0).to_pyval()
[[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
>>> tf.expand_dims(st, 1).to_pyval()
[[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
>>> tf.expand_dims(st, 2).to_pyval()
[[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
>>> tf.expand_dims(st, -1).to_pyval() # -1 is the same as 2
[[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
Args:
input: the original StructuredTensor.
axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank`
name: the name of the op.
Returns:
a new structured tensor with larger rank.
Raises:
an error if `axis < -(rank + 1)` or `rank < axis`.
"""
return _expand_dims_impl(input, axis, name=name)
@dispatch.dispatch_for_types(array_ops.gather, StructuredTensor)
def gather(params,
indices,
validate_indices=None,
name=None,
axis=None,
batch_dims=0):
"""tf.gather for structured tensors.
  Does not yet support checks on illegal axis values, etc.
Indices must be a ragged or dense tensor.
Args:
params: a structured tensor to be gathered
indices: a ragged tensor or tensor to gather by.
validate_indices: whether to validate the indices
name: the name of the op(s).
axis: the axis in params to gather on.
batch_dims: the number of batch dimensions.
Returns:
the params reorganized according to indices.
"""
if name is None:
name = 'gather'
with ops.name_scope(name):
if axis is None:
axis = batch_dims
    ndims = params.shape.rank
    axis = array_ops.get_positive_axis(axis, ndims)
indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
indices, name='indices')
def leaf_op(p):
return array_ops.gather(
p,
indices,
validate_indices=validate_indices,
axis=axis,
batch_dims=batch_dims,
name=None)
return _extend_op_single(params, leaf_op)
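# Illustrative sketch of `gather` above (hypothetical values, not taken from
# the test suite): gathering along axis 0 reorders the outer rows of every
# field.
#   st = StructuredTensor.from_pyval([{"x": [1]}, {"x": [2]}, {"x": [3]}])
#   gather(st, [2, 0])  # equivalent to [{"x": [3]}, {"x": [1]}]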
@dispatch.dispatch_for_types(array_ops.concat, StructuredTensor)
def concat(values, axis, name: str = 'concat'):
"""tf.concat for structured tensors.
  Does not yet support checks on illegal axis values, etc.
Args:
values: a sequence of StructuredTensors.
axis: an axis to concatenate upon.
name: the name of the op(s).
Returns:
    the StructuredTensors in `values` concatenated along `axis`.
"""
if name is None:
name = 'concat'
_assert_concat_compatible_structured_tensors(values)
def leaf_op(values):
return array_ops.concat(values, axis)
# TODO(martinz): handle axis when it is a tensor.
axis = array_ops.get_positive_axis(axis, values[0].rank)
with ops.name_scope(name, 'StructuredConcat', values):
return _extend_op(values, leaf_op)
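# Illustrative sketch of `concat` above (hypothetical values): concatenation is
# applied field by field, so along axis 0 the outer rows are stacked.
#   a = StructuredTensor.from_pyval([{"x": [1]}])
#   b = StructuredTensor.from_pyval([{"x": [2]}, {"x": [3]}])
#   concat([a, b], 0)  # equivalent to [{"x": [1]}, {"x": [2]}, {"x": [3]}]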
@dispatch.dispatch_for_types(random_ops.random_shuffle, StructuredTensor)
def random_shuffle(value, seed=None, name=None):
"""Shuffle a structured tensor on the zeroth axis.
Args:
value: a structured tensor of rank at least one.
seed: the seed for shuffling.
name: the name for shuffle.
Returns:
The shuffled structured tensor.
"""
with ops.name_scope(name, 'shuffle', [value, seed]):
if value.rank == 0:
raise ValueError('Cannot shuffle a scalar StructuredTensor')
first_dimension = value.nrows()
index = random_ops.random_shuffle(math_ops.range(first_dimension),
seed=seed)
return gather(value, index, axis=0)
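# Illustrative sketch of `random_shuffle` above (hypothetical values): only the
# outermost rows are permuted; the nested structure within each row is kept.
#   st = StructuredTensor.from_pyval([{"x": [1]}, {"x": [2]}, {"x": [3]}])
#   random_shuffle(st, seed=5)  # some permutation of the three rows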
# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.zeros_like, StructuredTensor)
def zeros_like(tensor, dtype=None, name=None, optimize=True):
"""Implementation of zeros_like for StructuredTensor for TF v1."""
del optimize
return zeros_like_v2(tensor, dtype=dtype, name=name)
# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.zeros_like_v2, StructuredTensor)
def zeros_like_v2(input, dtype=None, name=None): # pylint: disable=redefined-builtin
"""Replace every object with a zero.
Example:
>>> st = StructuredTensor.from_pyval([{"x":[3]}, {"x":[4,5]}])
>>> tf.zeros_like(st)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0., 0.], dtype=float32)>
>>> st = StructuredTensor.from_pyval([[{"x":[3]}], [{"x":[4,5]}, {"x":[]}]])
>>> tf.zeros_like(st, dtype=tf.int32)
<tf.RaggedTensor [[0], [0, 0]]>
Args:
input: a structured tensor.
dtype: the dtype of the resulting zeros. (default is tf.float32)
name: a name for the op.
Returns:
a tensor of zeros of the same shape.
"""
if dtype is None:
dtype = dtypes.float32
with ops.name_scope(name, 'zeros_like', [input]) as name:
if not input._row_partitions:
if input._nrows is not None:
return array_ops.zeros([input._nrows], dtype) # vector.
else:
return array_ops.zeros([], dtype) # scalar.
# 2D and up.
last_row_partition = input._row_partitions[-1]
result = ragged_tensor.RaggedTensor._from_nested_row_partitions(
array_ops.zeros(last_row_partition.nvals(), dtype=dtype),
input._row_partitions)
return result
# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.ones_like, StructuredTensor)
def ones_like(tensor, dtype=None, name=None, optimize=True):
"""Implementation of zeros_like for StructuredTensor for TF v1."""
del optimize
return ones_like_v2(tensor, dtype=dtype, name=name)
# pylint: disable=protected-access
@dispatch.dispatch_for_types(array_ops.ones_like_v2, StructuredTensor)
def ones_like_v2(input, dtype=None, name=None): # pylint: disable=redefined-builtin
"""Replace every object with a zero.
Example:
>>> st = StructuredTensor.from_pyval([{"x":[3]}, {"x":[4,5]}])
>>> tf.ones_like(st)
  <tf.Tensor: shape=(2,), dtype=float32, numpy=array([1., 1.], dtype=float32)>
>>> st = StructuredTensor.from_pyval([[{"x":[3]}], [{"x":[4,5]}, {"x":[]}]])
>>> tf.ones_like(st, dtype=tf.int32)
<tf.RaggedTensor [[1], [1, 1]]>
Args:
input: a structured tensor.
    dtype: the dtype of the resulting ones. (default is tf.float32)
    name: a name for the op.
  Returns:
    a tensor of ones of the same shape.
"""
if dtype is None:
dtype = dtypes.float32
with ops.name_scope(name, 'ones_like', [input]) as name:
if not input._row_partitions:
if input._nrows is not None:
return array_ops.ones([input._nrows], dtype) # vector.
else:
return array_ops.ones([], dtype) # scalar.
# 2D and up.
last_row_partition = input._row_partitions[-1]
result = ragged_tensor.RaggedTensor._from_nested_row_partitions(
array_ops.ones(last_row_partition.nvals(), dtype=dtype),
input._row_partitions)
return result
def _expand_dims_impl(st, axis, name=None): # pylint: disable=redefined-builtin
"""Creates a StructuredTensor with a length 1 axis inserted at index `axis`.
This is an implementation of tf.expand_dims for StructuredTensor. Note
that the `axis` must be less than or equal to rank.
>>> st = StructuredTensor.from_pyval([[{"x": 1}, {"x": 2}], [{"x": 3}]])
>>> tf.expand_dims(st, 0).to_pyval()
[[[{'x': 1}, {'x': 2}], [{'x': 3}]]]
>>> tf.expand_dims(st, 1).to_pyval()
[[[{'x': 1}, {'x': 2}]], [[{'x': 3}]]]
>>> tf.expand_dims(st, 2).to_pyval()
[[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
>>> tf.expand_dims(st, -1).to_pyval() # -1 is the same as 2
[[[{'x': 1}], [{'x': 2}]], [[{'x': 3}]]]
Args:
st: the original StructuredTensor.
axis: the axis to insert the dimension: `-(rank + 1) <= axis <= rank`
name: the name of the op.
Returns:
a new structured tensor with larger rank.
Raises:
an error if `axis < -(rank + 1)` or `rank < axis`.
"""
axis = array_ops.get_positive_axis(
axis, st.rank + 1, axis_name='axis', ndims_name='rank(st)')
with ops.name_scope(name, 'ExpandDims', [st, axis]):
new_fields = {
k: array_ops.expand_dims(v, axis) for (k, v) in st._fields.items()
}
new_shape = st.shape[:axis] + (1,) + st.shape[axis:]
new_row_partitions = _expand_st_row_partitions(st, axis)
new_nrows = st.nrows() if (axis > 0) else 1
return StructuredTensor.from_fields(
new_fields,
shape=new_shape,
row_partitions=new_row_partitions,
nrows=new_nrows)
def _expand_st_row_partitions(st, axis):
"""Create the row_partitions for expand_dims."""
if axis == 0:
if st.shape.rank == 0:
return ()
nvals = st.nrows()
new_partition = RowPartition.from_uniform_row_length(
nvals, nvals, nrows=1, validate=False)
return (new_partition,) + st.row_partitions
elif axis == st.rank:
nvals = (
st.row_partitions[axis - 2].nvals() if (axis - 2 >= 0) else st.nrows())
return st.row_partitions + (RowPartition.from_uniform_row_length(
1, nvals, nrows=nvals, validate=False),)
else:
nvals = (
st.row_partitions[axis - 1].nrows() if (axis - 1 >= 0) else st.nrows())
return st.row_partitions[:axis - 1] + (RowPartition.from_uniform_row_length(
1, nvals, nrows=nvals, validate=False),) + st.row_partitions[axis - 1:]
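# Worked sketch of the partition arithmetic above (assumed values): for a
# StructuredTensor of shape [2, None] with row lengths [2, 1] (nrows=2,
# nvals=3):
#   axis=0 prepends RowPartition.from_uniform_row_length(2, 2, nrows=1), i.e.
#          one new row covering both existing rows, giving shape [1, 2, None];
#   axis=2 (== rank) appends RowPartition.from_uniform_row_length(1, 3,
#          nrows=3), i.e. one length-1 row per innermost value, giving shape
#          [2, None, 1].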
# TODO(martinz): consider allowing values to be nested.
def _extend_op(values, leaf_op, empty_st_op=None):
"""Extend an op from RaggedTensor and Tensor to StructuredTensor.
Visits all children of the structured tensor, and children of children,
applying leaf_op whenever it reaches a leaf, and empty_st_op whenever
it reaches an internal node without children.
Args:
values: a list of structured tensors, ragged tensors, or tensors. All must
have the same type. If they are structured tensors, they must have the
same paths.
leaf_op: an op for handling non-structured tensor.
empty_st_op: op to create a structured tensor without fields.
Returns:
the result of the extended op (a StructuredTensor, RaggedTensor, or Tensor)
Raises:
ValueError:
If values is not a Sequence or is empty.
"""
if not isinstance(values, Sequence):
raise ValueError('Expected a list')
if not values:
raise ValueError('List cannot be empty')
if empty_st_op is None:
empty_st_op = empty_st_op_like_zeros(leaf_op)
# Use the structure of the first StructuredTensor. They are all assumed to
# be the same.
value = values[0]
if isinstance(value, StructuredTensor):
# TODO(martinz): Calling empty_st_op may add unnecessary ops. Revisit later.
empty_result = empty_st_op(values)
if not value.field_names():
return empty_result
new_fields = {}
for k in value.field_names():
new_fields[k] = _extend_op([v.field_value(k) for v in values], leaf_op,
empty_st_op)
return StructuredTensor.from_fields(new_fields, shape=empty_result.shape)
else:
return leaf_op(values)
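# Illustrative sketch of `_extend_op` above (hypothetical values and leaf_op):
#   a = StructuredTensor.from_pyval([{"x": [1]}])
#   b = StructuredTensor.from_pyval([{"x": [2]}])
#   _extend_op([a, b], lambda vs: array_ops.concat(vs, 0))
# recurses into the shared field "x", concatenates the two RaggedTensor leaves
# with leaf_op, and rebuilds a StructuredTensor equivalent to
# [{"x": [1]}, {"x": [2]}].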
def _extend_op_single(value, leaf_op, empty_st_op=None):
"""Extend an op to a value instead of a list of values."""
def to_list_op(element_op):
if element_op is None:
return None
def list_op(values):
[value] = values
return element_op(value)
return list_op
return _extend_op([value], to_list_op(leaf_op), to_list_op(empty_st_op))
def empty_st_op_like_zeros(leaf_op):
  """Returns an `empty_st_op` that infers the result's shape via `leaf_op`."""
  def empty_st_op(values):
    # Replace each StructuredTensor with zeros of its shape (dropping fields),
    # apply leaf_op to determine the result's shape, then build a fieldless
    # StructuredTensor with that shape.
    as_zeros = [
        zeros_like_v2(value, dtype=dtypes.int32) for value in values
    ]
    result = leaf_op(as_zeros)
    return _structured_tensor_like(result)
return empty_st_op
def _structured_tensor_from_dense_tensor(t):
"""Create a structured tensor with the shape of a dense tensor."""
  # Note: a rank-0 tensor either has a fully defined shape or an unknown rank,
  # so the first two branches below cover it; the remaining branches only need
  # to handle rank >= 1.
if t.shape.is_fully_defined():
return StructuredTensor.from_fields({}, shape=t.shape)
elif t.shape.rank is None:
raise ValueError("Can't build StructuredTensor w/ unknown rank")
elif t.shape.rank == 1:
return StructuredTensor.from_fields({}, shape=t.shape,
nrows=array_ops.shape(t)[0])
else:
rt = ragged_tensor.RaggedTensor.from_tensor(t)
return _structured_tensor_from_row_partitions(t.shape,
rt._nested_row_partitions)
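# Illustrative note for `_structured_tensor_from_dense_tensor` above (assumed
# inputs): a dense Tensor with statically known shape [2, 3] is handled by the
# first branch and yields a fieldless StructuredTensor of shape [2, 3]; when
# the shape is only partially known (e.g. [None, 3]) but rank >= 2, the
# RaggedTensor.from_tensor path derives the row partitions dynamically.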
def _structured_tensor_from_row_partitions(shape, row_partitions):
return StructuredTensor.from_fields({},
shape=shape,
row_partitions=row_partitions)
# pylint: disable=protected-access
def _all_nested_row_partitions(rt):
"""Returns all nested row partitions in rt, including for dense dimensions."""
if isinstance(rt, ops.Tensor):
if rt.shape.rank <= 1:
return ()
else:
rt2 = ragged_tensor.RaggedTensor.from_tensor(rt)
return rt2._nested_row_partitions
else:
tail_partitions = _all_nested_row_partitions(rt.flat_values)
    head_partitions = rt._nested_row_partitions  # pylint: disable=protected-access
return head_partitions + tail_partitions
def _structured_tensor_like(t):
"""Create a StructuredTensor with the shape of a (composite) tensor."""
if isinstance(t, ops.Tensor):
return _structured_tensor_from_dense_tensor(t)
if ragged_tensor.is_ragged(t):
return StructuredTensor.from_fields(
{}, shape=t.get_shape(), row_partitions=_all_nested_row_partitions(t))
# here, it is a StructuredTensor
return StructuredTensor.from_fields({},
shape=t.shape,
row_partitions=t.row_partitions,
nrows=t.nrows())
def _get_all_paths(st):
"""Get all the paths from a StructuredTensor."""
fields = st.field_names()
all_paths = {()}
for k in fields:
v = st.field_value(k)
if isinstance(v, StructuredTensor):
all_paths = all_paths.union([(k,) + p for p in _get_all_paths(v)])
else:
all_paths.add((k,))
return all_paths
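# Illustrative sketch of `_get_all_paths` above (hypothetical value):
#   st = StructuredTensor.from_pyval([{"a": {"b": 1}, "c": [2]}])
#   _get_all_paths(st)  # {(), ('a',), ('a', 'b'), ('c',)}
# The empty tuple is the root; both leaf fields and submessages contribute
# their paths.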
def _get_all_ranks(st):
"""Get ranks of all submessages of a StructuredTensor."""
fields = st.field_names()
all_ranks = {(): st.rank}
for k in fields:
v = st.field_value(k)
if isinstance(v, StructuredTensor):
for (k2, v2) in _get_all_ranks(v).items():
all_ranks[(k,) + k2] = v2
return all_ranks
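# Illustrative sketch of `_get_all_ranks` above (same hypothetical value):
#   st = StructuredTensor.from_pyval([{"a": {"b": 1}, "c": [2]}])
#   _get_all_ranks(st)  # {(): 1, ('a',): 1}
# Only the root and submessage paths are ranked; leaf fields such as "c" are
# not included.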
def _assert_all_paths_match(values):
"""Raises an error if the paths are not identical."""
paths = [_get_all_paths(st) for st in values]
path_diff = set()
for other_paths in paths[1:]:
path_diff = path_diff.union(paths[0].symmetric_difference(other_paths))
if path_diff:
raise ValueError(
'Some paths are present in some, but not all, structured tensors: %r' %
(path_diff,))
def _assert_all_ranks_match(values):
"""Raises an error if the ranks of submessages are not identical."""
ranks = [_get_all_ranks(st) for st in values]
for other_ranks in ranks[1:]:
if other_ranks != ranks[0]:
# TODO(martinz): If this becomes common, we can provide more detail.
# e.g.: which path is inconsistent.
      raise ValueError('Ranks of sub-messages do not match')
def _assert_concat_compatible_structured_tensors(values):
"""Sometimes raises an error if concat doesn't make sense statically on values.
values must be a sequence, and each element in values must be a structured
tensor, and must have the same paths. Additionally, each path that is a
submessage must have the same rank.
These constraints are sufficient for concat on the fields to be the same
as concat on structured tensors. This is meant to capture scenarios like
paths that are not in the first structured tensor, but are in later
structured tensors, which will just be ignored by the recursive algorithm.
If the rank of a submessage was different for two structured tensors,
then that is also a non-sensical merge.
Note that all of these checks are static, as paths and submessage ranks
are known.
Args:
values: a Sequence of StructuredTensors.
Raises:
ValueError: if there is any inconsistency as described above.
"""
if not isinstance(values, Sequence):
    raise ValueError(
        'values must be a sequence of StructuredTensors, not a single value')
if not values:
raise ValueError('values must not be an empty list')
for st in values:
if not isinstance(st, StructuredTensor):
raise ValueError('values must be a list of StructuredTensors')
_assert_all_paths_match(values)
_assert_all_ranks_match(values)
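# Illustrative sketch of the compatibility check above (hypothetical values):
#   _assert_concat_compatible_structured_tensors(
#       [StructuredTensor.from_pyval([{"x": [1]}]),
#        StructuredTensor.from_pyval([{"y": [2]}])])
# raises ValueError because the field paths {('x',)} and {('y',)} differ.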