Delete TOSA gather op legalizations temporarily.

Updating op signature in LLVM, new lowerings to be restored afterward.

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
Change-Id: I0f87ef6d4aea4d28305b79bc437a743eeb2b506c
Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir
index 2c3939d..fdb64eb 100644
--- a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir
+++ b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir
@@ -705,18 +705,6 @@
 
 // -----
 
-// CHECK-LABEL: test_gather
-// CHECK: tosa.const
-// CHECK: tosa.gather
-func @test_gather(%arg0: tensor<13x21x3xi32>) -> tensor<26x21x3xi32> {
-  %2 = "tf.Const"()  {value = dense<0> : tensor<i32>}  : () -> tensor<i32>
-  %3 = "tf.Const"()  {value = dense<[2, 2, 7, 6, 6, 1, 5, 4, 2, 11, 10, 11, 7, 7, 5, 3, 12, 7, 11, 0, 9, 5, 4, 12, 1, 9]> : tensor<26xi32>}  : () -> tensor<26xi32>
-  %4 = "tf.GatherV2"(%arg0, %3, %2)  {batch_dims = 0 : i64}  : (tensor<13x21x3xi32>, tensor<26xi32>, tensor<i32>) -> tensor<26x21x3xi32>
-  return %4 : tensor<26x21x3xi32>
-}
-
-// -----
-
 // CHECK-LABEL: test_space_to_batch
 // CHECK-DAG: "tosa.const"() {value = dense<{{\[}}[0, 0], [0, 1], [0, 0]]>
 // CHECK-DAG: "tosa.const"() {value = dense<[2, 0, 1, 3]>
diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
index 535f1dd..a8cea22 100644
--- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
+++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
@@ -81,13 +81,6 @@
   return %0 : tensor<13x21x3xf32>
 }
 
-// CHECK-LABEL: test_gather
-// CHECK: tosa.gather
-func @test_gather(%arg0: tensor<100x25xf32>, %arg1: tensor<1x20xi32>) -> tensor<20x25x3xf32> {
-  %0 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<100x25xf32>, tensor<1x20xi32>) -> tensor<20x25x3xf32>
-  return %0 : tensor<20x25x3xf32>
-}
-
 // -----
 
 // CHECK-LABEL: test_sub
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
index 7ea350d..40c0219 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
@@ -117,8 +117,6 @@
 DECL_CONVERT_OP(Pad);
 DECL_CONVERT_OP(ResizeBilinear);
 DECL_CONVERT_OP(ResizeNearestNeighbor);
-DECL_CONVERT_OP(Gather);
-DECL_CONVERT_OP(GatherV2);
 DECL_CONVERT_OP(SelectV2);
 DECL_CONVERT_OP(SpaceToDepth);
 DECL_CONVERT_OP(DepthToSpace);
@@ -1719,55 +1717,6 @@
   return success();
 }
 
-LogicalResult ConvertTFGatherOp::matchAndRewrite(
-    Operation* op, PatternRewriter& rewriter) const {
-  auto tf_gather_op = cast<TF::GatherOp>(op);
-
-  RankedTensorType output_type =
-      tf_gather_op.getResult().getType().dyn_cast<RankedTensorType>();
-  if (!output_type) return failure();
-
-  IntegerAttr axis_attr = rewriter.getI32IntegerAttr(0);
-
-  // TODO: batchdim_attr handling to be implemented with a revised
-  // defintion of the TOSA operator.
-  rewriter.replaceOpWithNewOp<tosa::GatherOp>(
-      op, output_type, tf_gather_op.params(), tf_gather_op.indices(),
-      axis_attr);
-
-  return success();
-}
-
-LogicalResult ConvertTFGatherV2Op::matchAndRewrite(
-    Operation* op, PatternRewriter& rewriter) const {
-  auto tf_gather_op = cast<TF::GatherV2Op>(op);
-
-  RankedTensorType output_type =
-      tf_gather_op.getResult().getType().dyn_cast<RankedTensorType>();
-  if (!output_type) return failure();
-
-  // Axis is a tensor in TF. Convert to I64Attr for TOSA
-  ElementsAttr axis_elem;
-  if (!matchPattern(tf_gather_op.axis(), m_Constant(&axis_elem)))
-    return failure();
-  assert(axis_elem.getType().getRank() == 0 && "expected 0D tensor");
-
-  IntegerAttr batchdim_attr;
-  {
-    auto tmpAttr = tf_gather_op.batch_dimsAttr();
-    if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
-    batchdim_attr = tmpAttr;
-  }
-
-  // TODO: batchdim_attr handling to be implemented with a revised
-  // defintion of the TOSA operator.
-  rewriter.replaceOpWithNewOp<tosa::GatherOp>(
-      op, output_type, tf_gather_op.params(), tf_gather_op.indices(),
-      rewriter.getI32IntegerAttr(axis_elem.getValue<IntegerAttr>({}).getInt()));
-
-  return success();
-}
-
 LogicalResult ConvertTFSelectV2Op::matchAndRewrite(
     Operation* op, PatternRewriter& rewriter) const {
   auto tf_sel_op = cast<TF::SelectV2Op>(op);
@@ -2097,8 +2046,6 @@
   patterns.insert<ConvertTFPadOp>(ctx);
   patterns.insert<ConvertTFResizeBilinearOp>(ctx);
   patterns.insert<ConvertTFResizeNearestNeighborOp>(ctx);
-  patterns.insert<ConvertTFGatherOp>(ctx);
-  patterns.insert<ConvertTFGatherV2Op>(ctx);
   patterns.insert<ConvertTFSelectV2Op>(ctx);
   patterns.insert<ConvertTFSpaceToDepthOp>(ctx);
   patterns.insert<ConvertTFDepthToSpaceOp>(ctx);
diff --git a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
index fd27be5..4f252ff 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
+++ b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
@@ -46,10 +46,3 @@
 // Ternary ops patterns.
 //===----------------------------------------------------------------------===//
 
-def : Pat<(TFL_GatherOp $params,
-                        $indices,
-                        $axis,
-                        ConstantAttr<I32Attr, "0">:$batch_dims),
-          (Tosa_GatherOp $indices,
-                           $params,
-                           $axis)>;