added ones_like to structured_array_ops.
PiperOrigin-RevId: 371349598
Change-Id: If8d2b1bcb9777ca9183294d88f2128267aed53c0
diff --git a/tensorflow/python/ops/structured/structured_array_ops.py b/tensorflow/python/ops/structured/structured_array_ops.py
index a92d979..805669b 100644
--- a/tensorflow/python/ops/structured/structured_array_ops.py
+++ b/tensorflow/python/ops/structured/structured_array_ops.py
@@ -234,6 +234,51 @@
return result
+# pylint: disable=protected-access
+@dispatch.dispatch_for_types(array_ops.ones_like, StructuredTensor)
+def ones_like(tensor, dtype=None, name=None, optimize=True):
+ """Implementation of zeros_like for StructuredTensor for TF v1."""
+ del optimize
+ return ones_like_v2(tensor, dtype=dtype, name=name)
+
+
+# pylint: disable=protected-access
+@dispatch.dispatch_for_types(array_ops.ones_like_v2, StructuredTensor)
+def ones_like_v2(input, dtype=None, name=None): # pylint: disable=redefined-builtin
+ """Replace every object with a zero.
+
+ Example:
+ >>> st = StructuredTensor.from_pyval([{"x":[3]}, {"x":[4,5]}])
+ >>> tf.ones_like(st)
+ <tf.Tensor: shape=(2,), dtype=int32, numpy=array([1.0, 1.0], dtype=float32)>
+ >>> st = StructuredTensor.from_pyval([[{"x":[3]}], [{"x":[4,5]}, {"x":[]}]])
+ >>> tf.ones_like(st, dtype=tf.int32)
+ <tf.RaggedTensor [[1], [1, 1]]>
+
+ Args:
+ input: a structured tensor.
+ dtype: the dtype of the resulting zeros. (default is tf.float32)
+ name: a name for the op.
+ Returns:
+ a tensor of zeros of the same shape.
+ """
+ if dtype is None:
+ dtype = dtypes.float32
+ with ops.name_scope(name, 'ones_like', [input]) as name:
+ if not input._row_partitions:
+ if input._nrows is not None:
+ return array_ops.ones([input._nrows], dtype) # vector.
+ else:
+ return array_ops.ones([], dtype) # scalar.
+ # 2D and up.
+ last_row_partition = input._row_partitions[-1]
+
+ result = ragged_tensor.RaggedTensor._from_nested_row_partitions(
+ array_ops.ones(last_row_partition.nvals(), dtype=dtype),
+ input._row_partitions)
+ return result
+
+
def _expand_dims_impl(st, axis, name=None): # pylint: disable=redefined-builtin
"""Creates a StructuredTensor with a length 1 axis inserted at index `axis`.
diff --git a/tensorflow/python/ops/structured/structured_array_ops_test.py b/tensorflow/python/ops/structured/structured_array_ops_test.py
index 17d3575..98d3590 100644
--- a/tensorflow/python/ops/structured/structured_array_ops_test.py
+++ b/tensorflow/python/ops/structured/structured_array_ops_test.py
@@ -338,6 +338,156 @@
@parameterized.named_parameters([
dict(
+ testcase_name="scalar_int32",
+ row_partitions=None,
+ shape=(),
+ dtype=dtypes.int32,
+ expected=1),
+ dict(
+ testcase_name="scalar_bool",
+ row_partitions=None,
+ shape=(),
+ dtype=dtypes.bool,
+ expected=True),
+ dict(
+ testcase_name="scalar_int64",
+ row_partitions=None,
+ shape=(),
+ dtype=dtypes.int64,
+ expected=1),
+ dict(
+ testcase_name="scalar_float32",
+ row_partitions=None,
+ shape=(),
+ dtype=dtypes.float32,
+ expected=1.0),
+ dict(
+ testcase_name="list_0_int32",
+ row_partitions=None,
+ shape=(0),
+ dtype=dtypes.int32,
+ expected=[]),
+ dict(
+ testcase_name="list_0_0_int32",
+ row_partitions=None,
+ shape=(0, 0),
+ dtype=dtypes.int32,
+ expected=[]),
+ dict(
+ testcase_name="list_int32",
+ row_partitions=None,
+ shape=(7),
+ dtype=dtypes.int32,
+ expected=[1, 1, 1, 1, 1, 1, 1]),
+ dict(
+ testcase_name="list_int64",
+ row_partitions=None,
+ shape=(7),
+ dtype=dtypes.int64,
+ expected=[1, 1, 1, 1, 1, 1, 1]),
+ dict(
+ testcase_name="list_float32",
+ row_partitions=None,
+ shape=(7),
+ dtype=dtypes.float32,
+ expected=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
+ dict(
+ testcase_name="matrix_int32",
+ row_partitions=[[0, 3, 6]],
+ shape=(2, 3),
+ dtype=dtypes.int32,
+ expected=[[1, 1, 1], [1, 1, 1]]),
+ dict(
+ testcase_name="matrix_float64",
+ row_partitions=[[0, 3, 6]],
+ shape=(2, 3),
+ dtype=dtypes.float64,
+ expected=[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]),
+ dict(
+ testcase_name="tensor_int32",
+ row_partitions=[[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]],
+ shape=(2, 3, 1),
+ dtype=dtypes.int32,
+ expected=[[[1], [1], [1]], [[1], [1], [1]]]),
+ dict(
+ testcase_name="tensor_float32",
+ row_partitions=[[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]],
+ shape=(2, 3, 1),
+ dtype=dtypes.float32,
+ expected=[[[1.0], [1.0], [1.0]], [[1.0], [1.0], [1.0]]]),
+ dict(
+ testcase_name="ragged_1_float32",
+ row_partitions=[[0, 3, 4]],
+ shape=(2, None),
+ dtype=dtypes.float32,
+ expected=[[1.0, 1.0, 1.0], [1.0]]),
+ dict(
+ testcase_name="ragged_2_float32",
+ row_partitions=[[0, 3, 4], [0, 2, 3, 5, 7]],
+ shape=(2, None, None),
+ dtype=dtypes.float32,
+ expected=[[[1.0, 1.0], [1.0], [1.0, 1.0]], [[1.0, 1.0]]]),
+ ]) # pyformat: disable
+ def testOnesLikeObject(self, row_partitions, shape, dtype, expected):
+ if row_partitions is not None:
+ row_partitions = [
+ row_partition.RowPartition.from_row_splits(r) for r in row_partitions
+ ]
+ st = StructuredTensor.from_fields({},
+ shape=shape,
+ row_partitions=row_partitions)
+ # NOTE: ones_like is very robust. There aren't arguments that
+ # should cause this operation to fail.
+ actual = array_ops.ones_like(st, dtype)
+ self.assertAllEqual(actual, expected)
+
+ actual2 = array_ops.ones_like_v2(st, dtype)
+ self.assertAllEqual(actual2, expected)
+
+ @parameterized.named_parameters([
+ dict(
+ testcase_name="list_empty_2_1",
+ values=[[{}, {}], [{}]],
+ dtype=dtypes.int32,
+ expected=[[1, 1], [1]]),
+ dict(
+ testcase_name="list_empty_2",
+ values=[{}, {}],
+ dtype=dtypes.int32,
+ expected=[1, 1]),
+ dict(
+ testcase_name="list_empty_1",
+ values=[{}],
+ dtype=dtypes.int32,
+ expected=[1]),
+ dict(
+ testcase_name="list_example_1",
+ values=[{"x": [3]}, {"x": [4, 5]}],
+ dtype=dtypes.int32,
+ expected=[1, 1]),
+ dict(
+ testcase_name="list_example_2",
+ values=[[{"x": [3]}], [{"x": [4, 5]}, {"x": []}]],
+ dtype=dtypes.float32,
+ expected=[[1.0], [1.0, 1.0]]),
+ dict(
+ testcase_name="list_example_2_None",
+ values=[[{"x": [3]}], [{"x": [4, 5]}, {"x": []}]],
+ dtype=None,
+ expected=[[1.0], [1.0, 1.0]]),
+ ]) # pyformat: disable
+ def testOnesLikeObjectAlt(self, values, dtype, expected):
+ st = StructuredTensor.from_pyval(values)
+ # NOTE: ones_like is very robust. There aren't arguments that
+ # should cause this operation to fail.
+ actual = array_ops.ones_like(st, dtype)
+ self.assertAllEqual(actual, expected)
+
+ actual2 = array_ops.ones_like_v2(st, dtype)
+ self.assertAllEqual(actual2, expected)
+
+ @parameterized.named_parameters([
+ dict(
testcase_name="list_empty",
values=[[{}], [{}]],
axis=0,