Support int16 index in TF StridedSlice
PiperOrigin-RevId: 424147336
Change-Id: I14cf188c52e68c10d977786e84ebc9f5cd3766fc
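
A minimal usage sketch of the new index dtype (illustrative only; the tensor values, shapes, and variable names below are arbitrary and not taken from this change):

    import tensorflow as tf

    x = tf.reshape(tf.range(24), [2, 3, 4])
    # begin/end/strides may now all be int16; previously the StridedSlice op
    # only accepted int32 or int64 index tensors.
    begin = tf.constant([0, 1, 0], dtype=tf.int16)
    end = tf.constant([2, 3, 4], dtype=tf.int16)
    strides = tf.constant([1, 1, 2], dtype=tf.int16)
    y = tf.strided_slice(x, begin, end, strides)  # shape [2, 2, 2]
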
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 2cdb128..10e95b9 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -16985,15 +16985,15 @@
let arguments = (ins
TF_Tensor:$input,
- Arg<TF_I32OrI64Tensor, [{`begin[k]` specifies the offset into the `k`th range specification.
+ Arg<TensorOf<[TF_Int16, TF_Int32, TF_Int64]>, [{`begin[k]` specifies the offset into the `k`th range specification.
The exact dimension this corresponds to will be determined by context.
Out-of-bounds values will be silently clamped. If the `k`th bit of
`begin_mask` is set then `begin[k]` is ignored and the full range of the
appropriate dimension is used instead. Negative values cause indexing
to start from the highest element, e.g. if `foo==[1,2,3]` then `foo[-1]==3`.}]>:$begin,
- Arg<TF_I32OrI64Tensor, [{`end[i]` is like `begin` with the exception that `end_mask` is
+ Arg<TensorOf<[TF_Int16, TF_Int32, TF_Int64]>, [{`end[i]` is like `begin` with the exception that `end_mask` is
used to determine full ranges.}]>:$end,
- Arg<TF_I32OrI64Tensor, [{`strides[i]` specifies the increment in the `i`th specification
+ Arg<TensorOf<[TF_Int16, TF_Int32, TF_Int64]>, [{`strides[i]` specifies the increment in the `i`th specification
after extracting a given element. Negative indices will reverse
the original order. Out-of-range values are
clamped to `[0,dim[i]) if slice[i]>0` or `[-1,dim[i]-1] if slice[i] < 0`}]>:$strides,
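
For reference, a short sketch of the mask and negative-index semantics described in the argument docs above (values chosen purely for illustration):

    import tensorflow as tf

    foo = tf.constant([1, 2, 3])
    # Bit 0 of begin_mask is set, so begin[0] is ignored and the slice starts
    # at the beginning of that dimension.
    tf.strided_slice(foo, begin=[2], end=[3], strides=[1], begin_mask=1)
    # => [1, 2, 3]
    # Negative indices count from the highest element (foo[-1] == 3), and a
    # negative stride reverses the traversal order.
    tf.strided_slice(foo, begin=[-1], end=[-3], strides=[-1])
    # => [3, 2]
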
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 8c76caf..e6c7d99 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1728,7 +1728,7 @@
.Input("strides: Index")
.Output("output: T")
.Attr("T: type")
- .Attr("Index: {int32, int64}")
+ .Attr("Index: {int16, int32, int64}")
.Attr("begin_mask: int = 0")
.Attr("end_mask: int = 0")
.Attr("ellipsis_mask: int = 0")
diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc
index ca1b90a..5342b26 100644
--- a/tensorflow/core/util/strided_slice_op.cc
+++ b/tensorflow/core/util/strided_slice_op.cc
@@ -261,8 +261,10 @@
TF_RETURN_IF_ERROR(BuildDenseSpec<int32>(sparse_spec, &dense_spec));
} else if (strides_tensor.dtype() == DT_INT64) {
TF_RETURN_IF_ERROR(BuildDenseSpec<int64_t>(sparse_spec, &dense_spec));
+ } else if (strides_tensor.dtype() == DT_INT16) {
+ TF_RETURN_IF_ERROR(BuildDenseSpec<int16_t>(sparse_spec, &dense_spec));
} else {
- LOG(FATAL) << "begin must be either int32 or int64";
+ LOG(FATAL) << "begin must be one of int16, int32 or int64";
}
// Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit
diff --git a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py
index 521692b..fa629af 100644
--- a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py
@@ -839,6 +839,30 @@
_ = checker2[mask]
_ = checker2[ops.convert_to_tensor(mask)]
+ def test_int16_indices(self):
+
+ def _int16(i):
+ return constant_op.constant(i, dtype=dtypes.int16)
+
+ for tensor_type in STRIDED_SLICE_TYPES:
+ with self.subTest(tensor_type=tensor_type, use_gpu=True):
+ checker = StridedSliceChecker(
+ self, StridedSliceChecker.REF_TENSOR, tensor_type=tensor_type)
+
+ with self.assertRaises(Exception):
+ _ = checker[::_int16(1), ::_int16(5), ::_int16(2)]
+
+ with self.assertRaises(Exception):
+ _ = checker[_int16(1)::1, :, :]
+
+ with self.assertRaises(Exception):
+ _ = checker[:, _int16(1):_int16(5):-1, :]
+
+ _ = checker[_int16(1):_int16(5):_int16(2), 1:2, :]
+ _ = checker[_int16(1):_int16(5):_int16(2),
+ _int16(0):_int16(5):_int16(1),
+ _int16(0):_int16(4):_int16(2)]
+
class StridedSliceShapeTest(test_util.TensorFlowTestCase):
"""Test the shape inference of StridedSliceShapes."""
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index b344f0c..90ba02a 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -887,8 +887,8 @@
"tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid "
"indices")
-_SUPPORTED_SLICE_DTYPES = (dtypes.int32, dtypes.int32_ref, dtypes.int64,
- dtypes.int64_ref)
+_SUPPORTED_SLICE_DTYPES = (dtypes.int16, dtypes.int32, dtypes.int32_ref,
+ dtypes.int64, dtypes.int64_ref)
def _check_index(idx):
@@ -1045,6 +1045,8 @@
if begin:
packed_begin, packed_end, packed_strides = (stack(begin), stack(end),
stack(strides))
+ # TODO(mdan): Instead of implicitly casting, it's better to enforce the
+ # same dtypes.
if (packed_begin.dtype == dtypes.int64 or
packed_end.dtype == dtypes.int64 or
packed_strides.dtype == dtypes.int64):
@@ -1054,6 +1056,15 @@
packed_end = gen_math_ops.cast(packed_end, dtypes.int64)
if packed_strides.dtype != dtypes.int64:
packed_strides = gen_math_ops.cast(packed_strides, dtypes.int64)
+ elif (packed_begin.dtype == dtypes.int16 and
+ packed_end.dtype == dtypes.int16 and
+ packed_strides.dtype == dtypes.int16):
+ if packed_begin.dtype != dtypes.int16:
+ packed_begin = gen_math_ops.cast(packed_begin, dtypes.int16)
+ if packed_end.dtype != dtypes.int16:
+ packed_end = gen_math_ops.cast(packed_end, dtypes.int16)
+ if packed_strides.dtype != dtypes.int16:
+ packed_strides = gen_math_ops.cast(packed_strides, dtypes.int16)
else:
var_empty = constant([], dtype=dtypes.int32)
packed_begin = packed_end = packed_strides = var_empty
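
The index-dtype selection in the slice helper above amounts to: any int64 among begin/end/strides promotes all three to int64, int16 is kept only when begin, end and strides are all already int16, and everything else stays int32. A compact sketch of that rule (the helper name is hypothetical, not part of this change):

    import tensorflow as tf

    def _select_index_dtype(begin, end, strides):
      # int64 wins if any operand is int64; int16 only when all three are
      # already int16; otherwise the indices remain int32.
      dts = {t.dtype for t in (begin, end, strides)}
      if tf.int64 in dts:
        return tf.int64
      if dts == {tf.int16}:
        return tf.int16
      return tf.int32

As the new test_int16_indices cases show, mixing int16 indices with the default int32 ones is not implicitly cast and is expected to raise.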