Update `tf.data.experimental.dense_to_ragged_batch` to output variable ragged rank.
This means:
* For input tensors whose static shape is fully defined, the output will still be a `tf.Tensor`.
* For input tensors with a `None` size in the i-th axis, the output will be a `tf.RaggedTensor` with ragged rank `i`, where `i` is the highest axis whose size is `None`.
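As a rough sketch of the rule (the helper name below is hypothetical, not part of this change), the per-element ragged rank mirrors the new `to_ragged_spec` logic: it is the index of the innermost axis whose static size is `None`; fully defined shapes keep a plain `tf.TensorSpec` and are still batched as dense `tf.Tensor`s.

    def _ragged_rank_for(shape):
      # Index of the innermost axis with statically unknown size.
      unknown_axes = [axis for (axis, size) in enumerate(shape) if size is None]
      return max(unknown_axes)

    _ragged_rank_for([None])                  # -> 0
    _ragged_rank_for([2, None, 3])            # -> 1
    _ragged_rank_for([2, None, 3, None, 4])   # -> 3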
PiperOrigin-RevId: 303704654
Change-Id: Id4d232688d7a5e4ee0dbca6093743d27432e5de8
diff --git a/tensorflow/python/data/experimental/kernel_tests/dense_to_ragged_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/dense_to_ragged_batch_test.py
index 8fb92ec..e66f401 100644
--- a/tensorflow/python/data/experimental/kernel_tests/dense_to_ragged_batch_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/dense_to_ragged_batch_test.py
@@ -45,11 +45,27 @@
return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([x], x))
-def _make_matrix_ds(nrows):
+def _make_matrix_ds1(nrows):
"""Create a test dataset with matrix elements (of varying size)."""
return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([x, 2], x))
+def _make_matrix_ds2(nrows):
+ """Create a test dataset with matrix elements (of varying size)."""
+ return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([2, x], x))
+
+
+def _make_matrix_ds_fully_defined(nrows):
+ """Create a test dataset with matrix elements (of varying size)."""
+ return _make_scalar_ds(nrows).map(lambda x: array_ops.fill([2, 3], x))
+
+
+def _make_5dtensor_ds(nrows):
+ """Create a test dataset with matrix elements (of varying size)."""
+ return _make_scalar_ds(nrows).map(
+ lambda x: array_ops.fill([2, x, 3, 2*x, 4], x))
+
+
def _make_ragged_ds(nrows):
"""Create a test dataset with RaggedTensor elements (of varying size)."""
values = [[[i] * (i % 3) for i in range(j)] * (j % 3) for j in range(nrows)]
@@ -64,6 +80,8 @@
'shape=[]': ops.convert_to_tensor(x),
'shape=[x]': math_ops.range(x),
'shape=[x, 2]': array_ops.fill([x, 2], x),
+ 'shape=[2, x]': array_ops.fill([2, x], x),
+ 'shape=[2, x, 3, 2x, 4]': array_ops.fill([2, x, 3, 2*x, 4], x)
}
return _make_scalar_ds(nrows).map(transform)
@@ -88,8 +106,9 @@
test_base.default_test_combinations(),
combinations.combine(
make_dataset=[
- _make_scalar_ds, _make_vector_ds, _make_matrix_ds,
- _make_ragged_ds, _make_dict_ds, _make_tuple_ds,
+ _make_scalar_ds, _make_vector_ds, _make_matrix_ds1,
+ _make_matrix_ds2, _make_ragged_ds, _make_5dtensor_ds,
+ _make_dict_ds, _make_tuple_ds, _make_matrix_ds_fully_defined,
],
nrows=[0, 20, 23],
batch_size=[4],
diff --git a/tensorflow/python/data/experimental/ops/batching.py b/tensorflow/python/data/experimental/ops/batching.py
index a2d3e00..398ec98 100644
--- a/tensorflow/python/data/experimental/ops/batching.py
+++ b/tensorflow/python/data/experimental/ops/batching.py
@@ -50,11 +50,23 @@
batch from being produced.
Unlike `tf.data.Dataset.batch`, the input elements to be batched may have
- different shapes, and each batch will be encoded as a `tf.RaggedTensor`.
+ different shapes:
+
+ * If an input element is a `tf.Tensor` whose static `tf.TensorShape` is
+ fully defined, then it is batched as normal.
+ * If an input element is a `tf.Tensor` whose static `tf.TensorShape` contains
+ one or more axes with unknown size (i.e., `shape[i]=None`), then the output
+ will contain a `tf.RaggedTensor` that is ragged up to the last axis whose
+ size is unknown.
+ * If an input element is a `tf.RaggedTensor` or any other type, then it is
+ batched as normal.
+
Example:
>>> dataset = tf.data.Dataset.from_tensor_slices(np.arange(6))
>>> dataset = dataset.map(lambda x: tf.range(x))
+ >>> dataset.element_spec.shape
+ TensorShape([None])
>>> dataset = dataset.apply(
... tf.data.experimental.dense_to_ragged_batch(batch_size=2))
>>> for batch in dataset:
@@ -385,32 +397,44 @@
any new ragged tensors. Existing `tf.RaggedTensor` elements do *not*
have their row_splits dtype changed.
"""
-
# Replace each TensorSpec in the input dataset's structure with a
# corresponding RaggedTensorSpec.
def to_ragged_spec(spec):
- if isinstance(spec, tensor_spec.TensorSpec) and spec.shape.ndims != 0:
+ """Returns the new spec based on RaggedTensors."""
+ if (not isinstance(spec, tensor_spec.TensorSpec) or
+ spec.shape.rank is None or
+ spec.shape.is_fully_defined()):
+ return spec
+ else:
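+ # The ragged rank is the index of the innermost axis whose size is
+ # statically unknown.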
+ ragged_rank = max([
+ axis for (axis, size) in enumerate(spec.shape.as_list())
+ if size is None
+ ])
return ragged_tensor.RaggedTensorSpec(
shape=spec.shape,
dtype=spec.dtype,
- ragged_rank=0,
+ ragged_rank=ragged_rank,
row_splits_dtype=row_splits_dtype)
- else:
- return spec
self._structure = nest.map_structure(to_ragged_spec,
input_dataset.element_spec)
# Replace each tf.Tensor value in the input dataset with a variant-encoded
- # RaggedTensor. Since we're updating the corresponding structure to be
+ # RaggedTensor. Since we're updating the corresponding structure to be
# a RaggedTensorSpec, this variant-encoded tensor will be decoded with
# RaggedTensorSpec._from_tensor_list.
def to_ragged_variant(value):
- if isinstance(value, ops.Tensor) and value.shape.ndims != 0:
- spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value))
- return spec._to_tensor_list(value)[0] # pylint: disable=protected-access
- else:
+ """Re-encode Tensors as RaggedTensors."""
+ if (not isinstance(value, ops.Tensor) or
+ value.shape.rank is None or
+ value.shape.is_fully_defined()):
return value
+ else:
+ spec = to_ragged_spec(tensor_spec.TensorSpec.from_tensor(value))
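+ # A spec with ragged_rank > 0 expects a RaggedTensor, so convert the dense
+ # Tensor along its unknown axes before variant-encoding it.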
+ if spec._ragged_rank > 0: # pylint: disable=protected-access
+ value = ragged_tensor.RaggedTensor.from_tensor(
+ value, ragged_rank=spec._ragged_rank) # pylint: disable=protected-access
+ return spec._to_tensor_list(value)[0] # pylint: disable=protected-access
# Tuples are automatically unpacked by `dataset.map` so we repack them.
if dataset_ops._should_unpack_args(input_dataset.element_spec): # pylint: disable=protected-access