Support int16 index in TF StridedSlice

PiperOrigin-RevId: 424147336
Change-Id: I14cf188c52e68c10d977786e84ebc9f5cd3766fc
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
index 2cdb128..10e95b9 100644
--- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
+++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td
@@ -16985,15 +16985,15 @@
 
   let arguments = (ins
     TF_Tensor:$input,
-    Arg<TF_I32OrI64Tensor, [{`begin[k]` specifies the offset into the `k`th range specification.
+    Arg<TensorOf<[TF_Int16, TF_Int32, TF_Int64]>, [{`begin[k]` specifies the offset into the `k`th range specification.
 The exact dimension this corresponds to will be determined by context.
 Out-of-bounds values will be silently clamped. If the `k`th bit of
 `begin_mask` then `begin[k]` is ignored and the full range of the
 appropriate dimension is used instead. Negative values causes indexing
 to start from the highest element e.g. If `foo==[1,2,3]` then `foo[-1]==3`.}]>:$begin,
-    Arg<TF_I32OrI64Tensor, [{`end[i]` is like `begin` with the exception that `end_mask` is
+    Arg<TensorOf<[TF_Int16, TF_Int32, TF_Int64]>, [{`end[i]` is like `begin` with the exception that `end_mask` is
 used to determine full ranges.}]>:$end,
-    Arg<TF_I32OrI64Tensor, [{`strides[i]` specifies the increment in the `i`th specification
+    Arg<TensorOf<[TF_Int16, TF_Int32, TF_Int64]>, [{`strides[i]` specifies the increment in the `i`th specification
 after extracting a given element. Negative indices will reverse
 the original order. Out or range values are
 clamped to `[0,dim[i]) if slice[i]>0` or `[-1,dim[i]-1] if slice[i] < 0`}]>:$strides,
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 8c76caf..e6c7d99 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1728,7 +1728,7 @@
     .Input("strides: Index")
     .Output("output: T")
     .Attr("T: type")
-    .Attr("Index: {int32, int64}")
+    .Attr("Index: {int16, int32, int64}")
     .Attr("begin_mask: int = 0")
     .Attr("end_mask: int = 0")
     .Attr("ellipsis_mask: int = 0")
diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc
index ca1b90a..5342b26 100644
--- a/tensorflow/core/util/strided_slice_op.cc
+++ b/tensorflow/core/util/strided_slice_op.cc
@@ -261,8 +261,10 @@
     TF_RETURN_IF_ERROR(BuildDenseSpec<int32>(sparse_spec, &dense_spec));
   } else if (strides_tensor.dtype() == DT_INT64) {
     TF_RETURN_IF_ERROR(BuildDenseSpec<int64_t>(sparse_spec, &dense_spec));
+  } else if (strides_tensor.dtype() == DT_INT16) {
+    TF_RETURN_IF_ERROR(BuildDenseSpec<int16_t>(sparse_spec, &dense_spec));
   } else {
-    LOG(FATAL) << "begin must be either int32 or int64";
+    LOG(FATAL) << "begin must be either int16, int32 or int64";
   }
 
   // Step 3: Make implicit ranges (non-zero begin_masks and end_masks) explicit
diff --git a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py
index 521692b..fa629af 100644
--- a/tensorflow/python/kernel_tests/array_ops/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops/array_ops_test.py
@@ -839,6 +839,30 @@
       _ = checker2[mask]
       _ = checker2[ops.convert_to_tensor(mask)]
 
+  def test_int16_indices(self):
+
+    def _int16(i):
+      return constant_op.constant(i, dtype=dtypes.int16)
+
+    for tensor_type in STRIDED_SLICE_TYPES:
+      with self.subTest(tensor_type=tensor_type, use_gpu=True):
+        checker = StridedSliceChecker(
+            self, StridedSliceChecker.REF_TENSOR, tensor_type=tensor_type)
+
+        with self.assertRaises(Exception):
+          _ = checker[::_int16(1), ::_int16(5), ::_int16(2)]
+
+        with self.assertRaises(Exception):
+          _ = checker[_int16(1)::1, :, :]
+
+        with self.assertRaises(Exception):
+          _ = checker[:, _int16(1):_int16(5):-1, :]
+
+        _ = checker[_int16(1):_int16(5):_int16(2), 1:2, :]
+        _ = checker[_int16(1):_int16(5):_int16(2),
+                    _int16(0):_int16(5):_int16(1),
+                    _int16(0):_int16(4):_int16(2)]
+
 
 class StridedSliceShapeTest(test_util.TensorFlowTestCase):
   """Test the shape inference of StridedSliceShapes."""
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index b344f0c..90ba02a 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -887,8 +887,8 @@
     "tf.newaxis (`None`) and scalar tf.int32/tf.int64 tensors are valid "
     "indices")
 
-_SUPPORTED_SLICE_DTYPES = (dtypes.int32, dtypes.int32_ref, dtypes.int64,
-                           dtypes.int64_ref)
+_SUPPORTED_SLICE_DTYPES = (dtypes.int16, dtypes.int32, dtypes.int32_ref,
+                           dtypes.int64, dtypes.int64_ref)
 
 
 def _check_index(idx):
@@ -1045,6 +1045,8 @@
     if begin:
       packed_begin, packed_end, packed_strides = (stack(begin), stack(end),
                                                   stack(strides))
+      # TODO(mdan): Instead of implicitly casting, it's better to enforce the
+      # same dtypes.
       if (packed_begin.dtype == dtypes.int64 or
           packed_end.dtype == dtypes.int64 or
           packed_strides.dtype == dtypes.int64):
@@ -1054,6 +1056,15 @@
           packed_end = gen_math_ops.cast(packed_end, dtypes.int64)
         if packed_strides.dtype != dtypes.int64:
           packed_strides = gen_math_ops.cast(packed_strides, dtypes.int64)
+      elif (packed_begin.dtype == dtypes.int16 and
+            packed_end.dtype == dtypes.int16 and
+            packed_strides.dtype == dtypes.int16):
+        if packed_begin.dtype != dtypes.int16:
+          packed_begin = gen_math_ops.cast(packed_begin, dtypes.int16)
+        if packed_end.dtype != dtypes.int16:
+          packed_end = gen_math_ops.cast(packed_end, dtypes.int16)
+        if packed_strides.dtype != dtypes.int16:
+          packed_strides = gen_math_ops.cast(packed_strides, dtypes.int16)
     else:
       var_empty = constant([], dtype=dtypes.int32)
       packed_begin = packed_end = packed_strides = var_empty