| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Shapes & broadcasting for RaggedTensors.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| from tensorflow.python.framework import constant_op |
| from tensorflow.python.framework import dtypes |
| from tensorflow.python.framework import ops |
| from tensorflow.python.framework import tensor_shape |
| from tensorflow.python.framework import tensor_util |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import control_flow_ops |
| from tensorflow.python.ops import math_ops |
| from tensorflow.python.ops.ragged import ragged_array_ops |
| from tensorflow.python.ops.ragged import ragged_config |
| from tensorflow.python.ops.ragged import ragged_tensor |
| from tensorflow.python.ops.ragged import ragged_util |
| |
| |
| class RaggedTensorDynamicShape(object): |
| """A collection of tensors encoding the shape of a potentially ragged tensor. |
| |
| Each `RaggedTensorDynamicShape` consists of an ordered list of dimension |
| sizes. There are two dimension types: |
| |
| * "Uniform dimensions" are dimensions where all slices have the same |
| length. `RaggedTensorDynamicShape` records the size of each uniform |
| dimension using a single scalar integer. |
| |
| * "Ragged dimensions" are dimensions whose slices may have different |
| lengths. `RaggedTensorDynamicShape` records the size of each ragged |
| dimension using an integer vector containing the slice lengths for all |
| the slices across that dimension. |
| |
| Furthermore, there are two ways a dimension might be encoded: |
| |
| * "Partitioned dimensions" are dimensions that are encoded using a |
| `RowPartition`. The outermost partitioned dimension must be uniform. |
| |
| * "Inner dimensions" are dimensions that are encoded using a |
| `RaggedTensor`'s `flat_values`. Inner dimensions are always uniform. |
| |
| The sizes of partitioned dimensions are recorded using `partitioned_dim_sizes` |
| and `inner_dim_sizes`: |
| |
| * `partitioned_dim_sizes` is a list of tensors (one for each partitioned |
| dimension). |
| |
| * For uniform dimensions, the tensor is an integer scalar specifying the |
| size of all slices across that dimension. |
| * For ragged dimensions, the tensor is an integer vector specifying the |
| size of each slice across that dimension. |
| |
| * `inner_dim_sizes` is a single integer vector, where each element |
| specifies the size of a single inner dimension. |
| |
| Examples: |
| |
| Tensor | Ragged Rank | Partitioned Dim Sizes | Inner Dim Sizes |
| ------------------------------ | ----------- | ---------------------- | --------------- |
| `[[1, 2, 3], [4, 5, 6]]` | 0 | | `2, 3` |
| `[[1, 2], [], [3, 4, 5]]` | 1 | `3, (2, 0, 3)` | |
| `[[[1, 2], [3, 4]], [[5, 6]]]` | 1 | `2, (2, 1)` | 2 |
| `[[[1, 2], [3]], [[4, 5]]]` | 2 | `2, (2, 1), (2, 1, 2)` | |
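| |
| For example (a sketch, assuming eager execution and the usual |
| `import tensorflow as tf`): |
| |
| rt = tf.ragged.constant([[1, 2], [], [3, 4, 5]]) |
| shape = RaggedTensorDynamicShape.from_tensor(rt) |
| shape.rank # -> 2 |
| shape.is_ragged(1) # -> True |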
| """ |
| |
| def __init__(self, partitioned_dim_sizes, inner_dim_sizes, |
| dim_size_dtype=None): |
| """Creates a RaggedTensorDynamicShape. |
| |
| Args: |
| partitioned_dim_sizes: A `list` of 0-D or 1-D integer `Tensor`, one for |
| each partitioned dimension. If dimension `d` is uniform, then |
| `partitioned_dim_sizes[d]` must be an integer scalar, specifying the |
| size of all slices across dimension `d`. If dimension `d` is ragged, |
| then `partitioned_dim_sizes[d]` must be an integer vector, specifying |
| the size of each slice across dimension `d`. |
| inner_dim_sizes: A 1-D integer `Tensor`, whose length is equal to the |
| number of inner dimensions. `inner_dim_sizes[n]` is the size of all |
| slices across the `n`th inner dimension (which is the |
| `(len(partitioned_dim_sizes)+n)`th dimension in the overall tensor). |
| dim_size_dtype: dtype for dimension sizes. If not specified, then it |
| is chosen based on the dtypes of `partitioned_dim_sizes` and |
| `inner_dim_sizes`. |
| """ |
| assert isinstance(partitioned_dim_sizes, (list, tuple)) |
| |
| with ops.name_scope(None, 'RaggedTensorDynamicShape', |
| (partitioned_dim_sizes, inner_dim_sizes)): |
| partitioned_dim_sizes = tuple( |
| ops.convert_to_tensor(size, name='partitioned_dimension_size_%d' % i) |
| for (i, size) in enumerate(partitioned_dim_sizes)) |
| inner_dim_sizes = ops.convert_to_tensor( |
| inner_dim_sizes, name='inner_dim_sizes') |
| |
| # Validate shapes. |
| if partitioned_dim_sizes: |
| for axis, dimension_size in enumerate(partitioned_dim_sizes): |
| if dimension_size.shape.ndims is None: |
| raise ValueError( |
| 'rank of partitioned_dim_sizes[%d] is unknown' % axis) |
| dimension_size.shape.with_rank_at_most(1) |
| if partitioned_dim_sizes[0].shape.ndims == 1: |
| raise ValueError('outermost partitioned dimension must be uniform') |
| inner_dim_sizes.shape.assert_has_rank(1) |
| |
| # Convert dimension size tensors to a single dtype. |
| if dim_size_dtype is None: |
| dim_size_dtypes = set( |
| p.dtype for p in partitioned_dim_sizes if p.shape.ndims == 1) |
| if not dim_size_dtypes: |
| dim_size_dtype = dtypes.int64 |
| elif len(dim_size_dtypes) == 1: |
| dim_size_dtype = dim_size_dtypes.pop() |
| else: |
| if not ragged_config.auto_cast_partition_dtype(): |
| raise ValueError('partitioned_dim_sizes must have matching dtypes') |
| dim_size_dtype = dtypes.int64 |
| partitioned_dim_sizes = tuple(math_ops.cast(p, dim_size_dtype) |
| for p in partitioned_dim_sizes) |
| inner_dim_sizes = math_ops.cast(inner_dim_sizes, dim_size_dtype) |
| |
| self._partitioned_dim_sizes = partitioned_dim_sizes |
| self._inner_dim_sizes = inner_dim_sizes |
| |
| def __repr__(self): |
| return ('RaggedTensorDynamicShape' |
| '(partitioned_dim_sizes=%r, inner_dim_sizes=%r)' % |
| (self._partitioned_dim_sizes, self._inner_dim_sizes)) |
| |
| @staticmethod |
| def from_dim_sizes(dim_sizes): |
| """Constructs a ragged shape from a list of dimension sizes. |
| |
| This list contains a single tensor for each dimension, where the tensor |
| is a scalar if the dimension is uniform, or a vector if the dimension is |
| ragged. |
| |
| Args: |
| dim_sizes: List of int32 or int64 scalars or vectors. |
| |
| Returns: |
| A RaggedTensorDynamicShape. |
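| |
| Example (a sketch): `from_dim_sizes([3, [2, 0, 3]])` describes a 2-D |
| ragged shape with 3 rows, whose lengths are 2, 0, and 3. |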
| """ |
| with ops.name_scope(None, 'RaggedTensorDynamicShapeFromDimensionSizes', |
| [dim_sizes]): |
| dim_sizes = tuple( |
| ops.convert_to_tensor(size, preferred_dtype=dtypes.int64, |
| name='dim_sizes') for size in dim_sizes) |
| # Split the dimensions into partitioned & inner dimensions. |
| inner_split = 0 |
| for dim, dim_size in enumerate(dim_sizes): |
| if dim_size.shape.ndims == 1: |
| inner_split = dim + 1 |
| elif dim_size.shape.ndims != 0: |
| raise ValueError('Each dim_size must be a scalar or a vector') |
| return RaggedTensorDynamicShape(dim_sizes[:inner_split], |
| dim_sizes[inner_split:]) |
| |
| @classmethod |
| def from_tensor(cls, rt_input, dim_size_dtype=None): |
| """Constructs a ragged shape for a potentially ragged tensor.""" |
| with ops.name_scope(None, 'RaggedTensorDynamicShapeFromTensor', [rt_input]): |
| rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input) |
| if not ragged_tensor.is_ragged(rt_input): |
| return cls([], array_ops.shape(rt_input)) |
| else: |
| partitioned_dim_sizes = [rt_input.nrows()] |
| rt = rt_input |
| while ragged_tensor.is_ragged(rt): |
| if rt.uniform_row_length is None: |
| partitioned_dim_sizes.append(rt.row_lengths()) |
| else: |
| partitioned_dim_sizes.append(rt.uniform_row_length) |
| rt = rt.values |
| |
| return RaggedTensorDynamicShape( |
| tuple(partitioned_dim_sizes), |
| array_ops.shape(rt_input.flat_values)[1:], |
| dim_size_dtype=dim_size_dtype) |
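| # A quick sketch of `from_tensor` (values shown are illustrative): |
| # from_tensor(tf.ragged.constant([[1, 2], [], [3, 4, 5]])) |
| # -> partitioned_dim_sizes=(3, [2, 0, 3]), inner_dim_sizes=[] |
| # from_tensor(tf.ones([2, 3])) |
| # -> partitioned_dim_sizes=(), inner_dim_sizes=[2, 3] |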
| |
| def dimension_size(self, axis): |
| """Returns the size of slices across the specified dimension.""" |
| if not isinstance(axis, int): |
| raise TypeError('axis must be an integer') |
| partitioned_ndims = len(self._partitioned_dim_sizes) |
| if axis < partitioned_ndims: |
| return self._partitioned_dim_sizes[axis] |
| else: |
| return self._inner_dim_sizes[axis - partitioned_ndims] |
| |
| def is_ragged(self, axis): |
| """Returns true if the indicated dimension is ragged.""" |
| if not isinstance(axis, int): |
| raise TypeError('axis must be an integer') |
| rank = self.rank |
| if axis < 0: |
| raise ValueError('Negative axis values are not supported') |
| elif rank is not None and axis >= rank: |
| raise ValueError('Expected axis=%s < rank=%s' % (axis, rank)) |
| else: |
| return (axis > 0 and axis < len(self._partitioned_dim_sizes) and |
| self._partitioned_dim_sizes[axis].shape.ndims == 1) |
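| # E.g., for the shape of [[1, 2], [], [3, 4, 5]] (a sketch): |
| # dimension_size(0) -> 3 (scalar), dimension_size(1) -> [2, 0, 3], |
| # is_ragged(0) -> False, is_ragged(1) -> True. |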
| |
| @property |
| def rank(self): |
| """The number of dimensions in this shape, or None if unknown.""" |
| inner_ndims = tensor_shape.dimension_value(self._inner_dim_sizes.shape[0]) |
| if inner_ndims is None: |
| return None |
| else: |
| return len(self._partitioned_dim_sizes) + inner_ndims |
| |
| @property |
| def partitioned_dim_sizes(self): |
| """The partitioned dimension sizes for this shape. |
| |
| Returns: |
| A `list` of 0-D or 1-D integer `Tensor`. |
| """ |
| return self._partitioned_dim_sizes |
| |
| @property |
| def inner_dim_sizes(self): |
| """The inner dimension sizes for this shape. |
| |
| Returns: |
| A 1-D integer `Tensor`. |
| """ |
| return self._inner_dim_sizes |
| |
| @property |
| def num_partitioned_dimensions(self): |
| """The number of partitioned dimensions in this shape.""" |
| return len(self._partitioned_dim_sizes) |
| |
| @property |
| def num_inner_dimensions(self): |
| """The number of inner dimensions, or `None` if not statically known.""" |
| return tensor_shape.dimension_value(self._inner_dim_sizes.shape[0]) |
| |
| @property |
| def dim_size_dtype(self): |
| """DType used by this shape for dimension sizes.""" |
| return self._inner_dim_sizes.dtype |
| |
| def broadcast_to_rank(self, rank): |
| """Adds leading size-1 dimensions to broadcast `self` to the given rank. |
| |
| E.g., if `shape1` is `[3, (D2), 4]`, then `shape1.broadcast_to_rank(5)` |
| is `[1, 1, 3, (D2), 4]`. |
| |
| Args: |
| rank: The rank for the returned shape. |
| |
| Returns: |
| A RaggedTensorDynamicShape with `rank` dimensions, whose inner dimensions |
| have the same size as `self` and whose outer dimensions have size `1`. |
| |
| Raises: |
| ValueError: If `self.rank` is unknown or greater than `rank`. |
| """ |
| if self.rank is None: |
| raise ValueError('Unable to broadcast: self.rank is unknown') |
| dims_to_add = rank - self.rank |
| if dims_to_add < 0: |
| raise ValueError('Unable to broadcast: rank=%d must be greater than ' |
| 'or equal to self.rank=%d.' % (rank, self.rank)) |
| elif dims_to_add == 0: |
| return self |
| elif self._partitioned_dim_sizes: |
| partitioned_dims = (1,) * dims_to_add + self._partitioned_dim_sizes |
| return RaggedTensorDynamicShape(partitioned_dims, self._inner_dim_sizes) |
| else: |
| inner_dims = array_ops.concat( |
| [array_ops.ones([dims_to_add], self.dim_size_dtype), |
| self.inner_dim_sizes], |
| axis=0) |
| return RaggedTensorDynamicShape([], inner_dims) |
| |
| def broadcast_dimension(self, axis, lengths): |
| """Returns a shape that is broadcast-compatible with self & lengths. |
| |
| * If dimension[axis] is uniform and lengths is a scalar, then check |
| that lengths==1 or dimension[axis]==1 or dimension[axis]==lengths, |
| and tile dimension[axis] with lengths repeats if dimension[axis]==1. |
| |
| * If dimension[axis] is uniform and lengths is a vector, then check |
| that dimension[axis]==1, and raggedly tile dimension[axis] with |
| lengths repeats. (Tiling could be skipped when we statically know |
| that all lengths are 1.) |
| |
| * If dimension[axis] is ragged and lengths is a scalar, then check |
| that lengths==1. |
| |
| * If dimension[axis] is ragged and lengths is a vector, then check |
| that self.dimension_size(axis) == lengths. |
| |
| Args: |
| axis: `int`. The dimension to broadcast. |
| lengths: 0-D or 1-D integer `Tensor`. |
| |
| Returns: |
| A `RaggedTensorDynamicShape`. |
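| |
| Example (a sketch): |
| |
| shape = RaggedTensorDynamicShape.from_dim_sizes([3, 1]) |
| shape.broadcast_dimension(0, 3) # sizes already match: unchanged |
| shape.broadcast_dimension(1, [2, 0, 3]) # -> ragged [3, (2, 0, 3)] |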
| """ |
| lengths = ragged_util.convert_to_int_tensor( |
| lengths, name='lengths', dtype=self.dim_size_dtype) |
| # Check whether lengths is a scalar (for uniform dimensions) or |
| # vector (for ragged dimensions). |
| if lengths.shape.ndims is None: |
| raise ValueError('lengths must have a known rank.') |
| elif lengths.shape.ndims > 1: |
| raise ValueError('lengths must be a scalar or vector') |
| else: |
| lengths_is_scalar = (lengths.shape.ndims == 0) |
| |
| # Verify that the shapes are compatible. |
| if self.is_ragged(axis): |
| if lengths_is_scalar: |
| condition = math_ops.equal(lengths, 1) |
| else: |
| condition = math_ops.reduce_all( |
| math_ops.equal(lengths, self.dimension_size(axis))) |
| else: |
| axis_dim_size = self.dimension_size(axis) |
| if lengths_is_scalar: |
| condition = ( |
| math_ops.equal(lengths, 1) | math_ops.equal(axis_dim_size, 1) |
| | math_ops.equal(axis_dim_size, lengths)) |
| else: |
| condition = math_ops.equal(axis_dim_size, 1) |
| broadcast_err = [ |
| 'Unable to broadcast: dimension size mismatch in dimension', axis, |
| 'lengths=', lengths, 'dim_size=', |
| self.dimension_size(axis) |
| ] |
| broadcast_check = control_flow_ops.Assert( |
| condition, data=broadcast_err, summarize=10) |
| |
| with ops.control_dependencies([broadcast_check]): |
| # Partitioned dimensions: |
| if axis < self.num_partitioned_dimensions: |
| if self.is_ragged(axis): |
| # Use an identity op to make sure the check actually gets run. |
| return RaggedTensorDynamicShape( |
| self._partitioned_dim_sizes, |
| array_ops.identity(self.inner_dim_sizes)) |
| else: |
| return self._broadcast_uniform_partitioned_dimension(axis, lengths) |
| |
| # Inner dimensions: |
| else: |
| if lengths_is_scalar: |
| return self._broadcast_inner_dimension_to_uniform(axis, lengths) |
| else: |
| if axis == 0: |
| raise ValueError('Unable to broadcast: ' |
| 'outermost dimension must be uniform.') |
| return self._broadcast_inner_dimension_to_ragged(axis, lengths) |
| |
| def num_slices_in_dimension(self, axis): |
| """Returns the total number of slices across the indicated dimension.""" |
| if axis < 0: |
| return constant_op.constant(1, dtype=self.dim_size_dtype) |
| elif self.is_ragged(axis): |
| return math_ops.reduce_sum(self._partitioned_dim_sizes[axis]) |
| else: |
| return self.dimension_size(axis) * self.num_slices_in_dimension(axis - 1) |
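| # E.g. (a sketch): for the shape [3, (2, 0, 3)], axis 0 has 3 slices |
| # and axis 1 has 2 + 0 + 3 = 5 slices. |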
| |
| def _broadcast_uniform_partitioned_dimension(self, axis, lengths): |
| """Broadcasts the partitioned dimension `axis` to match `lengths`.""" |
| axis_dim_size = self.dimension_size(axis) |
| partitioned_sizes = list(self._partitioned_dim_sizes[:axis]) |
| |
| if lengths.shape.ndims == 0: |
| lengths = array_ops.where( |
| math_ops.equal(axis_dim_size, 1), lengths, axis_dim_size) |
| repeats = array_ops.where(math_ops.equal(axis_dim_size, 1), lengths, 1) |
| splits = array_ops.stack([0, self.num_slices_in_dimension(axis)]) |
| else: |
| splits = math_ops.range( |
| array_ops.size(lengths, out_type=self.dim_size_dtype) + 1) |
| repeats = lengths |
| |
| partitioned_sizes.append(lengths) |
| |
| for dim_size in self._partitioned_dim_sizes[axis + 1:]: |
| if dim_size.shape.ndims == 0: |
| partitioned_sizes.append(dim_size) |
| splits *= dim_size |
| else: |
| partitioned_sizes.append( |
| ragged_util.repeat_ranges(dim_size, splits, repeats)) |
| splits = array_ops.gather( |
| ragged_util.lengths_to_splits(dim_size), splits) |
| inner_sizes = self._inner_dim_sizes |
| return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes) |
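| # Sketch of the bookkeeping above: broadcasting axis 0 of the shape |
| # [1, (3,)] to lengths=2 gives splits=[0, 1] and repeats=2, so |
| # repeat_ranges turns the row lengths (3,) into (3, 3), yielding the |
| # shape [2, (3, 3)]. |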
| |
| def _broadcast_inner_dimension_to_uniform(self, axis, length): |
| """Broadcasts the inner dimension `axis` to match `lengths`.""" |
| dim_size = self.dimension_size(axis) |
| axis_in_inner_dims = axis - self.num_partitioned_dimensions |
| partitioned_sizes = self._partitioned_dim_sizes |
| inner_sizes = array_ops.concat([ |
| self._inner_dim_sizes[:axis_in_inner_dims], |
| [array_ops.where(math_ops.equal(dim_size, 1), length, dim_size)], |
| self._inner_dim_sizes[axis_in_inner_dims + 1:] |
| ], |
| axis=0) |
| return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes) |
| |
| def _broadcast_inner_dimension_to_ragged(self, axis, lengths): |
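| """Broadcasts the inner dimension `axis` to match the ragged `lengths`.""" |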
| axis_in_inner_dims = axis - self.num_partitioned_dimensions |
| partitioned_sizes = ( |
| self._partitioned_dim_sizes + tuple([ |
| self._inner_dim_sizes[i] for i in range(axis_in_inner_dims) |
| ]) + (lengths,)) |
| inner_sizes = self._inner_dim_sizes[axis_in_inner_dims + 1:] |
| return RaggedTensorDynamicShape(partitioned_sizes, inner_sizes) |
| |
| def with_dim_size_dtype(self, dtype): |
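| """Returns a copy of this shape, with dimension sizes cast to `dtype`.""" |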
| if dtype not in (dtypes.int32, dtypes.int64): |
| raise ValueError('dtype must be int32 or int64') |
| if self.dim_size_dtype == dtype: |
| return self |
| return RaggedTensorDynamicShape( |
| [math_ops.cast(p, dtype) for p in self._partitioned_dim_sizes], |
| math_ops.cast(self._inner_dim_sizes, dtype)) |
| |
| |
| def broadcast_dynamic_shape(shape_x, shape_y): |
| """Returns the shape formed by broadcasting two shapes to be compatible. |
| |
| Args: |
| shape_x: A `RaggedTensorDynamicShape` |
| shape_y: A `RaggedTensorDynamicShape` |
| |
| Returns: |
| A `RaggedTensorDynamicShape`. |
| Raises: |
| ValueError: If `shape_x` and `shape_y` are not broadcast-compatible. |
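| |
| Example (a sketch): |
| |
| x = RaggedTensorDynamicShape.from_dim_sizes([1, 1]) |
| y = RaggedTensorDynamicShape.from_dim_sizes([3, [2, 0, 3]]) |
| broadcast_dynamic_shape(x, y) # -> ragged shape [3, (2, 0, 3)] |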
| """ |
| if not isinstance(shape_x, RaggedTensorDynamicShape): |
| raise TypeError('shape_x must be a RaggedTensorDynamicShape') |
| if not isinstance(shape_y, RaggedTensorDynamicShape): |
| raise TypeError('shape_y must be a RaggedTensorDynamicShape') |
| |
| # Broadcast both shapes to have the same rank. |
| if shape_x.rank is None or shape_y.rank is None: |
| raise ValueError('Unable to broadcast: unknown rank') |
| broadcast_rank = max(shape_x.rank, shape_y.rank) |
| shape_x = shape_x.broadcast_to_rank(broadcast_rank) |
| shape_y = shape_y.broadcast_to_rank(broadcast_rank) |
| |
| # Broadcast dimensions one at a time, starting from the outermost dimension. |
| for axis in range(broadcast_rank): |
| shape_x = shape_x.broadcast_dimension(axis, shape_y.dimension_size(axis)) |
| shape_y = shape_y.broadcast_dimension(axis, shape_x.dimension_size(axis)) |
| |
| return shape_x |
| |
| |
| def broadcast_to(rt_input, shape, broadcast_inner_dimensions=True): |
| """Broadcasts a potentially ragged tensor to a ragged shape. |
| |
| Tiles `rt_input` as necessary to match the given shape. |
| |
| Behavior is undefined if `rt_input` is not broadcast-compatible with `shape`. |
| |
| Args: |
| rt_input: The potentially ragged tensor to broadcast. |
| shape: A `RaggedTensorDynamicShape` |
| broadcast_inner_dimensions: If false, then inner dimensions will not be |
| tiled. |
| |
| Returns: |
| A potentially ragged tensor whose values are taken from |
| `rt_input`, and whose shape matches `shape`. |
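| |
| Example (a sketch; `broadcast_to` here is this module's function, not |
| `tf.broadcast_to`): |
| |
| x = tf.constant([[10], [20], [30]]) # shape [3, 1] |
| shape = RaggedTensorDynamicShape.from_dim_sizes([3, [2, 0, 3]]) |
| broadcast_to(x, shape) # -> [[10, 10], [], [30, 30, 30]] |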
| """ |
| if not isinstance(shape, RaggedTensorDynamicShape): |
| raise TypeError('shape must be a RaggedTensorDynamicShape') |
| rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input) |
| |
| # Broadcasting to a uniform shape. |
| if shape.num_partitioned_dimensions == 0: |
| return _broadcast_to_uniform_shape(rt_input, shape, |
| broadcast_inner_dimensions) |
| else: |
| return _broadcast_to_ragged_shape(rt_input, shape, |
| broadcast_inner_dimensions) |
| |
| |
| def _broadcast_to_uniform_shape(rt_input, shape, broadcast_inner_dimensions): |
| """Broadcasts rt_input to the uniform shape `shape`.""" |
| if isinstance(rt_input, ragged_tensor.RaggedTensor): |
| raise ValueError('Incompatible with shape: ragged rank mismatch') |
| if broadcast_inner_dimensions: |
| return array_ops.broadcast_to(rt_input, shape.inner_dim_sizes) |
| else: |
| return rt_input |
| |
| |
| def _broadcast_to_ragged_shape(rt_input, dst_shape, broadcast_inner_dimensions): |
| """Broadcasts rt_input to the ragged shape `dst_shape`.""" |
| # Check that rt_input and dst_shape have the same row_splits dtype. |
| if (isinstance(rt_input, ragged_tensor.RaggedTensor) and |
| rt_input.row_splits.dtype != dst_shape.dim_size_dtype): |
| if not ragged_config.auto_cast_partition_dtype(): |
| raise ValueError('rt_input and dst_shape have different row_split ' |
| 'dtypes; use RaggedTensor.with_row_splits_dtype() or ' |
| 'RaggedTensorDynamicShape.with_dim_size_dtype() to ' |
| 'convert to a compatible dtype.') |
| rt_input = rt_input.with_row_splits_dtype(dtypes.int64) |
| dst_shape = dst_shape.with_dim_size_dtype(dtypes.int64) |
| |
| # dst_shape's rank and ragged_rank must be greater than or equal to rt_input's |
| if rt_input.shape.ndims is None or dst_shape.rank is None: |
| raise ValueError('Unable to broadcast: unknown rank') |
| if rt_input.shape.ndims > dst_shape.rank: |
| raise ValueError('Incompatible with shape: rank mismatch') |
| if (isinstance(rt_input, ragged_tensor.RaggedTensor) and |
| rt_input.ragged_rank >= dst_shape.num_partitioned_dimensions): |
| raise ValueError('Incompatible with shape: ragged rank mismatch') |
| |
| src_shape = RaggedTensorDynamicShape.from_tensor(rt_input) |
| src_shape = src_shape.broadcast_to_rank(dst_shape.rank) |
| |
| # Add dimensions to rt_input so its rank and ragged_rank matches dst_shape. |
| if dst_shape.rank > rt_input.shape.ndims: |
| if rt_input.shape.ndims < dst_shape.num_inner_dimensions + 1: |
| rt_input = array_ops.reshape( |
| rt_input, array_ops.concat([[-1], dst_shape.inner_dim_sizes], axis=0)) |
| for _ in range(dst_shape.rank - rt_input.shape.ndims): |
| if ragged_tensor.is_ragged(rt_input): |
| nrows = rt_input.nrows() |
| else: |
| nrows = array_ops.shape(rt_input, |
| out_type=dst_shape.dim_size_dtype)[0] |
| rt_input = ragged_tensor.RaggedTensor.from_row_lengths(rt_input, [nrows], |
| validate=False) |
| |
| # Add ragged dimensions to match dst_shape. |
| if ragged_tensor.is_ragged(rt_input): |
| inner_rank_diff = ( |
| rt_input.flat_values.shape.ndims - 1 - dst_shape.num_inner_dimensions) |
| if inner_rank_diff > 0: |
| rt_input = rt_input.with_flat_values( |
| ragged_tensor.RaggedTensor.from_tensor( |
| rt_input.flat_values, ragged_rank=inner_rank_diff, |
| row_splits_dtype=dst_shape.dim_size_dtype)) |
| else: |
| rt_input = ragged_tensor.RaggedTensor.from_tensor( |
| rt_input, ragged_rank=dst_shape.num_partitioned_dimensions - 1, |
| row_splits_dtype=dst_shape.dim_size_dtype) |
| |
| # Do broadcasting for any dimensions that will remain uniform. We can do |
| # these all at once, since they're independent of one another. |
| multiples = [1] * dst_shape.rank |
| for axis in range(dst_shape.num_partitioned_dimensions): |
| if not src_shape.is_ragged(axis) and not dst_shape.is_ragged(axis): |
| src_size = src_shape.dimension_size(axis) |
| dst_size = dst_shape.dimension_size(axis) |
| if ((tensor_util.constant_value(src_size) in (1, None)) and |
| (tensor_util.constant_value(dst_size) != 1)): |
| multiples[axis] = array_ops.where( |
| math_ops.equal(src_size, 1), dst_size, 1) |
| if not all(isinstance(v, int) and v == 1 for v in multiples): |
| multiples = array_ops.stack(multiples, axis=0) |
| rt_input = ragged_array_ops.tile(rt_input, multiples) |
| |
| if broadcast_inner_dimensions: |
| new_shape = array_ops.broadcast_dynamic_shape( |
| array_ops.shape( |
| rt_input.flat_values, out_type=dst_shape.dim_size_dtype), |
| array_ops.concat([[1], dst_shape.inner_dim_sizes], axis=0)) |
| rt_input = rt_input.with_flat_values( |
| array_ops.broadcast_to(rt_input.flat_values, new_shape)) |
| |
| # Do broadcasting for dimensions that become ragged. We must do these from |
| # outermost to innermost. |
| for axis in range(dst_shape.num_partitioned_dimensions): |
| if not src_shape.is_ragged(axis) and dst_shape.is_ragged(axis): |
| dst_size = dst_shape.dimension_size(axis) |
| rt_input = _ragged_tile_axis(rt_input, axis, dst_size, |
| dst_shape.dim_size_dtype) |
| |
| return rt_input |
| |
| |
| def _ragged_tile_axis(rt_input, axis, repeats, row_splits_dtype): |
| """Tile a dimension of a RaggedTensor to match a ragged shape.""" |
| assert axis > 0 # Outermost dimension may not be ragged. |
| |
| if not ragged_tensor.is_ragged(rt_input): |
| rt_input = ragged_tensor.RaggedTensor.from_tensor( |
| rt_input, ragged_rank=1, row_splits_dtype=row_splits_dtype) |
| |
| if axis > 1: |
| return rt_input.with_values( |
| _ragged_tile_axis(rt_input.values, axis - 1, repeats, |
| row_splits_dtype)) |
| else: |
| src_row_splits = rt_input.nested_row_splits |
| src_row_lengths = rt_input.nested_row_lengths() |
| splits = src_row_splits[0] |
| |
| dst_row_lengths = [repeats] |
| for i in range(1, len(src_row_lengths)): |
| dst_row_lengths.append( |
| ragged_util.repeat_ranges(src_row_lengths[i], splits, repeats)) |
| splits = array_ops.gather(src_row_splits[i], splits) |
| dst_values = ragged_util.repeat_ranges(rt_input.flat_values, splits, |
| repeats) |
| return ragged_tensor.RaggedTensor.from_nested_row_lengths( |
| dst_values, dst_row_lengths, validate=False) |