Delete TOSA gather op legalizations temporarily.
Updating op signature in LLVM, new lowerings to be restored afterward.
Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
Change-Id: I0f87ef6d4aea4d28305b79bc437a743eeb2b506c
Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir
index 2c3939d..fdb64eb 100644
--- a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir
+++ b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir
@@ -705,18 +705,6 @@
// -----
-// CHECK-LABEL: test_gather
-// CHECK: tosa.const
-// CHECK: tosa.gather
-func @test_gather(%arg0: tensor<13x21x3xi32>) -> tensor<26x21x3xi32> {
- %2 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
- %3 = "tf.Const"() {value = dense<[2, 2, 7, 6, 6, 1, 5, 4, 2, 11, 10, 11, 7, 7, 5, 3, 12, 7, 11, 0, 9, 5, 4, 12, 1, 9]> : tensor<26xi32>} : () -> tensor<26xi32>
- %4 = "tf.GatherV2"(%arg0, %3, %2) {batch_dims = 0 : i64} : (tensor<13x21x3xi32>, tensor<26xi32>, tensor<i32>) -> tensor<26x21x3xi32>
- return %4 : tensor<26x21x3xi32>
-}
-
-// -----
-
// CHECK-LABEL: test_space_to_batch
// CHECK-DAG: "tosa.const"() {value = dense<{{\[}}[0, 0], [0, 1], [0, 0]]>
// CHECK-DAG: "tosa.const"() {value = dense<[2, 0, 1, 3]>
diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
index 535f1dd..a8cea22 100644
--- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
+++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir
@@ -81,13 +81,6 @@
return %0 : tensor<13x21x3xf32>
}
-// CHECK-LABEL: test_gather
-// CHECK: tosa.gather
-func @test_gather(%arg0: tensor<100x25xf32>, %arg1: tensor<1x20xi32>) -> tensor<20x25x3xf32> {
- %0 = "tfl.gather"(%arg0, %arg1) {axis = 0 : i32} : (tensor<100x25xf32>, tensor<1x20xi32>) -> tensor<20x25x3xf32>
- return %0 : tensor<20x25x3xf32>
-}
-
// -----
// CHECK-LABEL: test_sub
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
index 7ea350d..40c0219 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
@@ -117,8 +117,6 @@
DECL_CONVERT_OP(Pad);
DECL_CONVERT_OP(ResizeBilinear);
DECL_CONVERT_OP(ResizeNearestNeighbor);
-DECL_CONVERT_OP(Gather);
-DECL_CONVERT_OP(GatherV2);
DECL_CONVERT_OP(SelectV2);
DECL_CONVERT_OP(SpaceToDepth);
DECL_CONVERT_OP(DepthToSpace);
@@ -1719,55 +1717,6 @@
return success();
}
-LogicalResult ConvertTFGatherOp::matchAndRewrite(
- Operation* op, PatternRewriter& rewriter) const {
- auto tf_gather_op = cast<TF::GatherOp>(op);
-
- RankedTensorType output_type =
- tf_gather_op.getResult().getType().dyn_cast<RankedTensorType>();
- if (!output_type) return failure();
-
- IntegerAttr axis_attr = rewriter.getI32IntegerAttr(0);
-
- // TODO: batchdim_attr handling to be implemented with a revised
- // defintion of the TOSA operator.
- rewriter.replaceOpWithNewOp<tosa::GatherOp>(
- op, output_type, tf_gather_op.params(), tf_gather_op.indices(),
- axis_attr);
-
- return success();
-}
-
-LogicalResult ConvertTFGatherV2Op::matchAndRewrite(
- Operation* op, PatternRewriter& rewriter) const {
- auto tf_gather_op = cast<TF::GatherV2Op>(op);
-
- RankedTensorType output_type =
- tf_gather_op.getResult().getType().dyn_cast<RankedTensorType>();
- if (!output_type) return failure();
-
- // Axis is a tensor in TF. Convert to I64Attr for TOSA
- ElementsAttr axis_elem;
- if (!matchPattern(tf_gather_op.axis(), m_Constant(&axis_elem)))
- return failure();
- assert(axis_elem.getType().getRank() == 0 && "expected 0D tensor");
-
- IntegerAttr batchdim_attr;
- {
- auto tmpAttr = tf_gather_op.batch_dimsAttr();
- if (!tmpAttr) tmpAttr = rewriter.getI64IntegerAttr(0);
- batchdim_attr = tmpAttr;
- }
-
- // TODO: batchdim_attr handling to be implemented with a revised
- // defintion of the TOSA operator.
- rewriter.replaceOpWithNewOp<tosa::GatherOp>(
- op, output_type, tf_gather_op.params(), tf_gather_op.indices(),
- rewriter.getI32IntegerAttr(axis_elem.getValue<IntegerAttr>({}).getInt()));
-
- return success();
-}
-
LogicalResult ConvertTFSelectV2Op::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_sel_op = cast<TF::SelectV2Op>(op);
@@ -2097,8 +2046,6 @@
patterns.insert<ConvertTFPadOp>(ctx);
patterns.insert<ConvertTFResizeBilinearOp>(ctx);
patterns.insert<ConvertTFResizeNearestNeighborOp>(ctx);
- patterns.insert<ConvertTFGatherOp>(ctx);
- patterns.insert<ConvertTFGatherV2Op>(ctx);
patterns.insert<ConvertTFSelectV2Op>(ctx);
patterns.insert<ConvertTFSpaceToDepthOp>(ctx);
patterns.insert<ConvertTFDepthToSpaceOp>(ctx);
diff --git a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
index fd27be5..4f252ff 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
+++ b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td
@@ -46,10 +46,3 @@
// Ternary ops patterns.
//===----------------------------------------------------------------------===//
-def : Pat<(TFL_GatherOp $params,
- $indices,
- $axis,
- ConstantAttr<I32Attr, "0">:$batch_dims),
- (Tosa_GatherOp $indices,
- $params,
- $axis)>;