| // RUN: tf-tfrt-opt -split-input-file -tf-cpurt-pipeline %s | FileCheck %s |
| |
| // CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> |
| |
| // Two chained tf.Tanh ops on a tensor<?x32xf32>. Expected output: the tensor |
| // argument becomes a memref, a result buffer is allocated from dimension 0 |
| // of the argument, and a linalg.generic (identity maps, two parallel |
| // iterators) reads the argument, writes that buffer, and has both tanh ops |
| // on consecutive lines of its body, i.e. the two element-wise ops are fused. |
| // The allocated buffer is what the function returns. |
| // CHECK-LABEL: @tanh_lower_and_fuse |
| // CHECK-SAME: %[[ARG:.*]]: memref<?x32xf32> |
| func @tanh_lower_and_fuse(%arg0: tensor<?x32xf32>) -> tensor<?x32xf32> { |
| // CHECK: %[[C0:.*]] = constant 0 : index |
| // CHECK: %[[DIM:.*]] = memref.dim %[[ARG]], %[[C0]] |
| // CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DIM]]) : memref<?x32xf32> |
| |
| // CHECK: linalg.generic |
| // CHECK-SAME: indexing_maps = [#map, #map] |
| // CHECK-SAME: iterator_types = ["parallel", "parallel"] |
| // CHECK-SAME: ins(%[[ARG]] : memref<?x32xf32>) |
| // CHECK-SAME: outs(%[[MEMREF]] : memref<?x32xf32>) |
| // CHECK: tanh |
| // CHECK-NEXT: tanh |
| |
| // CHECK: return %[[MEMREF]] |
| %0 = "tf.Tanh"(%arg0): (tensor<?x32xf32>) -> tensor<?x32xf32> |
| %1 = "tf.Tanh"(%0): (tensor<?x32xf32>) -> tensor<?x32xf32> |
| return %1 : tensor<?x32xf32> |
| } |
| |
| // ----- |
| |
| // CHECK: #map = affine_map<(d0, d1) -> (d0, d1)> |
| |
| // tf.Sigmoid on a tensor whose second dimension is statically 1. Only checks |
| // that it lowers to a linalg.generic with 2-D identity indexing maps (the |
| // size-1 dimension is kept) and two parallel iterators. |
| // CHECK-LABEL: @sigmoid_dynamic_dim |
| func @sigmoid_dynamic_dim(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> { |
| // CHECK: linalg.generic |
| // CHECK-SAME: indexing_maps = [#map, #map] |
| // CHECK-SAME: iterator_types = ["parallel", "parallel"] |
| %0 = "tf.Sigmoid"(%arg0) : (tensor<?x1xf32>) -> tensor<?x1xf32> |
| return %0 : tensor<?x1xf32> |
| } |
| |
| // ----- |
| |
| // CHECK: #map0 = affine_map<(d0) -> ()> |
| // CHECK: #map1 = affine_map<(d0) -> (d0)> |
| |
| // Adds a rank-0 tensor to a 1-D dynamic vector. The scalar broadcast is |
| // expected to lower to a single linalg.generic (no second generic follows). |
| // NOTE(review): #map0 (presumably the access map of the rank-0 operand) and |
| // #map1 (identity) are only matched as definitions; nothing ties them to the |
| // generic's indexing_maps. |
| // CHECK-LABEL: @add_scalar_with_vec |
| func @add_scalar_with_vec(%arg0: tensor<f32>, |
| %arg1: tensor<?xf32>) -> tensor<?xf32> { |
| // CHECK: linalg.generic |
| // CHECK-NOT: linalg.generic |
| %0 = "tf.AddV2"(%arg0, %arg1): (tensor<f32>, tensor<?xf32>) -> tensor<?xf32> |
| return %0 : tensor<?xf32> |
| } |
| |
| // ----- |
| |
| // CHECK: #map = affine_map<(d0) -> (d0)> |
| |
| // Both operands carry the same cpurt.symbolic_shape (-2), presumably telling |
| // the compiler that their dynamic sizes are equal. The add must then lower to |
| // exactly one linalg.generic, with no memref.reinterpret_cast emitted before |
| // it. |
| // CHECK-LABEL: @add_vec_vec |
| func @add_vec_vec( |
| %arg0: tensor<?xf32> {cpurt.symbolic_shape = dense<-2>: tensor<1xi64>}, |
| %arg1: tensor<?xf32> {cpurt.symbolic_shape = dense<-2>: tensor<1xi64>} |
| ) -> tensor<?xf32> { |
| // CHECK-NOT: memref.reinterpret_cast |
| // CHECK: linalg.generic |
| // CHECK-NOT: linalg.generic |
| %0 = "tf.AddV2"(%arg0, %arg1): (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> |
| return %0 : tensor<?xf32> |
| } |
| |
| // ----- |
| |
| // CHECK: #map = affine_map<(d0) -> (d0)> |
| |
| // Three vectors annotated with the same cpurt.symbolic_shape (-2) summed by |
| // two chained tf.AddV2 ops. The chain is expected to fuse into one |
| // linalg.generic containing both addf, with no memref.reinterpret_cast |
| // emitted before it. |
| // CHECK-LABEL: @add_vec_vec_vec |
| func @add_vec_vec_vec( |
| %arg0: tensor<?xf32> {cpurt.symbolic_shape = dense<-2>: tensor<1xi64>}, |
| %arg1: tensor<?xf32> {cpurt.symbolic_shape = dense<-2>: tensor<1xi64>}, |
| %arg2: tensor<?xf32> {cpurt.symbolic_shape = dense<-2>: tensor<1xi64>} |
| ) -> tensor<?xf32> { |
| // CHECK-NOT: memref.reinterpret_cast |
| // CHECK: linalg.generic |
| // CHECK: addf |
| // CHECK: addf |
| // CHECK-NOT: linalg.generic |
| %0 = "tf.AddV2"(%arg0, %arg1): (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> |
| %1 = "tf.AddV2"(%0, %arg2): (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> |
| return %1 : tensor<?xf32> |
| } |
| |
| // ----- |
| |
| // Verify that symbolic shape optimization can move all the broadcasts up, |
| // progressively remove all shape constraints, and replace mhlo broadcasts |
| // with linalg.generic operations that are all fused together in the end. |
| |
| // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, 0)> |
| // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d2)> |
| // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> |
| |
| // The seven TF ops below (add, rsqrt, three muls, sub, add) are expected to |
| // end up in a single linalg.generic, with no memref.reinterpret_cast emitted |
| // before it. Its body is pinned op by op, including the op order, which |
| // differs from the order of the TF ops. |
| // NOTE(review): MAP0..MAP2 are matched as definitions only; nothing ties |
| // them to the generic's indexing_maps. |
| // CHECK-LABEL: @compute_with_bcast |
| func @compute_with_bcast( |
| %arg0: tensor<1x?x1xf32> |
| {cpurt.symbolic_shape = dense<[1, -2, 1]> : tensor<3xi64>}, |
| %arg1: tensor<512xf32>, |
| %arg2: tensor<1x?x512xf32> |
| {cpurt.symbolic_shape = dense<[1, -2, 512]> : tensor<3xi64>}, |
| %arg3: tensor<1x?x1xf32> |
| {cpurt.symbolic_shape = dense<[1, -2, 1]> : tensor<3xi64>}, |
| %arg4: tensor<512xf32> |
| ) -> tensor<?x?x512xf32> { |
| // CHECK-NOT: memref.reinterpret_cast |
| // CHECK: linalg.generic |
| // CHECK: addf |
| // CHECK-NEXT: math.rsqrt |
| // CHECK-NEXT: mulf |
| // CHECK-NEXT: mulf |
| // CHECK-NEXT: subf |
| // CHECK-NEXT: mulf |
| // CHECK-NEXT: addf |
| // CHECK-NEXT: linalg.yield |
| // CHECK-NOT: linalg.generic |
| %c = "tf.Const"() {value = dense<9.99999996E-13> |
| : tensor<f32>} : () -> tensor<f32> |
| %0 = "tf.AddV2"(%arg0, %c) |
| : (tensor<1x?x1xf32>, tensor<f32>) -> tensor<?x?x1xf32> |
| %1 = "tf.Rsqrt"(%0) |
| : (tensor<?x?x1xf32>) -> tensor<?x?x1xf32> |
| %2 = "tf.Mul"(%1, %arg1) |
| : (tensor<?x?x1xf32>, tensor<512xf32>) -> tensor<?x?x512xf32> |
| %3 = "tf.Mul"(%2, %arg2) |
| : (tensor<?x?x512xf32>, tensor<1x?x512xf32>) -> tensor<?x?x512xf32> |
| %4 = "tf.Mul"(%2, %arg3) |
| : (tensor<?x?x512xf32>, tensor<1x?x1xf32>) -> tensor<?x?x512xf32> |
| %5 = "tf.Sub"(%arg4, %4) |
| : (tensor<512xf32>, tensor<?x?x512xf32>) -> tensor<?x?x512xf32> |
| %6 = "tf.AddV2"(%3, %5) |
| : (tensor<?x?x512xf32>, tensor<?x?x512xf32>) -> tensor<?x?x512xf32> |
| return %6 : tensor<?x?x512xf32> |
| } |
| |
| // ----- |
| |
| // Four vectors annotated with the same cpurt.symbolic_shape (-2) summed by a |
| // chain of three tf.AddV2 ops. The chain is expected to fuse into one |
| // linalg.generic containing the three addf, with no memref.reinterpret_cast |
| // emitted before it. |
| // CHECK-LABEL: @add_vec_vec_vec_vec |
| func @add_vec_vec_vec_vec( |
| %arg0: tensor<?xf32> {cpurt.symbolic_shape = dense<-2>: tensor<1xi64>}, |
| %arg1: tensor<?xf32> {cpurt.symbolic_shape = dense<-2>: tensor<1xi64>}, |
| %arg2: tensor<?xf32> {cpurt.symbolic_shape = dense<-2>: tensor<1xi64>}, |
| %arg3: tensor<?xf32> {cpurt.symbolic_shape = dense<-2>: tensor<1xi64>} |
| ) -> tensor<?xf32> { |
| // CHECK-NOT: memref.reinterpret_cast |
| // CHECK: linalg.generic |
| // CHECK: addf |
| // CHECK: addf |
| // CHECK: addf |
| // CHECK-NOT: linalg.generic |
| %0 = "tf.AddV2"(%arg0, %arg1): (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> |
| %1 = "tf.AddV2"(%0, %arg2): (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> |
| %2 = "tf.AddV2"(%1, %arg3): (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> |
| return %2 : tensor<?xf32> |
| } |
| |
| // ----- |
| |
| // Adds a static 512-element vector (broadcast over the leading dimensions) |
| // to a 1x?x512 tensor, then adds a second 1x?x512 tensor that has the same |
| // cpurt.symbolic_shape. Both adds are expected to fuse into one |
| // linalg.generic, with no memref.reinterpret_cast emitted before it. |
| // CHECK-LABEL: @add_vec_tensor_tensor |
| func @add_vec_tensor_tensor( |
| %arg0: tensor<512xf32>, |
| %arg1: tensor<1x?x512xf32> |
| {cpurt.symbolic_shape = dense<[1, -2, 512]> : tensor<3xi64>}, |
| %arg2: tensor<1x?x512xf32> |
| {cpurt.symbolic_shape = dense<[1, -2, 512]> : tensor<3xi64>} |
| ) -> tensor<1x?x512xf32> { |
| // CHECK-NOT: memref.reinterpret_cast |
| // CHECK: linalg.generic |
| // CHECK: addf |
| // CHECK: addf |
| // CHECK-NOT: linalg.generic |
| %0 = "tf.AddV2"(%arg0, %arg1) |
| : (tensor<512xf32>, tensor<1x?x512xf32>) -> tensor<1x?x512xf32> |
| %1 = "tf.AddV2"(%arg2, %0) |
| : (tensor<1x?x512xf32>, tensor<1x?x512xf32>) -> tensor<1x?x512xf32> |
| return %1 : tensor<1x?x512xf32> |
| } |
| |
| // ----- |
| |
| // ?x1 times ?x4 with no symbolic shape annotations, so presumably nothing |
| // proves the leading dimensions equal. Both operands are expected to go |
| // through memref.reinterpret_cast and feed a linalg.generic whose body has a |
| // mulf, and no shape-dialect op may appear before the first cast. |
| // CHECK-LABEL: @tf_binary_with_bcast |
| func @tf_binary_with_bcast(%arg0: tensor<?x1xf32>, |
| %arg1: tensor<?x4xf32>) -> tensor<?x4xf32> { |
| // CHECK-NOT: shape. |
| // CHECK: %[[LHS:.*]] = memref.reinterpret_cast |
| // CHECK: %[[RHS:.*]] = memref.reinterpret_cast |
| // CHECK: linalg.generic {{.*}} ins(%[[LHS]], %[[RHS]] : |
| // CHECK: mulf |
| %0 = "tf.Mul"(%arg0, %arg1) |
| : (tensor<?x1xf32>, tensor<?x4xf32>) -> tensor<?x4xf32> |
| return %0 : tensor<?x4xf32> |
| } |
| |
| // ----- |
| |
| // tf.Log1p on a ?x4 tensor, then tf.Sub and tf.Mul with 4-element vectors |
| // broadcast over the rows. All three ops are expected to fuse into a single |
| // linalg.generic that takes the three function arguments directly as inputs. |
| // Its body is exactly log1p, subf, mulf, yield. |
| // CHECK-LABEL: @tf_binary_with_bcast_and_fusion |
| // CHECK-SAME: %[[ARG0:.*]]: memref<?x4xf32>, |
| // CHECK-SAME: %[[ARG1:.*]]: memref<4xf32>, |
| // CHECK-SAME: %[[ARG2:.*]]: memref<4xf32> |
| func @tf_binary_with_bcast_and_fusion(%arg0: tensor<?x4xf32>, |
| %arg1: tensor<4xf32>, |
| %arg2: tensor<4xf32>) -> tensor<?x4xf32> { |
| // CHECK: linalg.generic |
| // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]], %[[ARG2]] : {{.*}}) |
| // CHECK: math.log1p |
| // CHECK-NEXT: subf |
| // CHECK-NEXT: mulf |
| // CHECK-NEXT: linalg.yield |
| // CHECK-NOT: linalg.generic |
| %0 = "tf.Log1p"(%arg0) |
| : (tensor<?x4xf32>) -> tensor<?x4xf32> |
| %1 = "tf.Sub"(%0, %arg1) |
| : (tensor<?x4xf32>, tensor<4xf32>) -> tensor<?x4xf32> |
| %2 = "tf.Mul"(%1, %arg2) |
| : (tensor<?x4xf32>, tensor<4xf32>) -> tensor<?x4xf32> |
| return %2 : tensor<?x4xf32> |
| } |
| |
| // ----- |
| |
| // CHECK: #[[MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> |
| |
| // A 1-D log1p result is broadcast-added to three 2-D arguments. The symbolic |
| // shapes give the vector the same symbol (-3) as the trailing dimension of |
| // the matrices. The whole chain is expected to fuse into one linalg.generic |
| // (log1p followed by three addf), with no memref.reinterpret_cast emitted |
| // before it. |
| // CHECK-LABEL: @tf_binary_with_bcast_symbolic_shapes |
| func @tf_binary_with_bcast_symbolic_shapes( |
| %arg0: tensor<?xf32> {cpurt.symbolic_shape = dense<[-3]>: tensor<1xi64>}, |
| %arg1: tensor<?x?xf32> {cpurt.symbolic_shape = dense<[-2,-3]>: tensor<2xi64>}, |
| %arg2: tensor<?x?xf32> {cpurt.symbolic_shape = dense<[-2,-3]>: tensor<2xi64>}, |
| %arg3: tensor<?x?xf32> {cpurt.symbolic_shape = dense<[-2,-3]>: tensor<2xi64>} |
| ) -> tensor<?x?xf32> { |
| // CHECK-NOT: memref.reinterpret_cast |
| // CHECK: linalg.generic |
| // CHECK: log1p |
| // CHECK: addf |
| // CHECK: addf |
| // CHECK: addf |
| // CHECK-NOT: linalg.generic |
| %0 = "tf.Log1p"(%arg0) |
| : (tensor<?xf32>) -> tensor<?xf32> |
| %1 = "tf.AddV2"(%0, %arg1) |
| : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> |
| %2 = "tf.AddV2"(%1, %arg2) |
| : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> |
| %3 = "tf.AddV2"(%2, %arg3) |
| : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> |
| return %3 : tensor<?x?xf32> |
| } |
| |
| // ----- |
| |
| // A dynamic-shape tf.MatMul without transposes. Expected output: no |
| // linalg.copy around the dimension queries, M taken from dim 0 of the LHS |
| // and N from dim 1 of the RHS, and two nested scf.for loops over M and N. |
| // The tile sizes MR and NR are recovered from the SSA names of the loop step |
| // constants and must match the shape of the yielded vector tile. |
| // NOTE(review): this relies on index constants printing as %c<value> (as |
| // %c0 and %c1 are assumed to below); a printer naming change breaks it. |
| // CHECK-LABEL: @tf_lower_matmul |
| // CHECK-SAME: %[[ARG0:.*]]: memref<?x?xf32>, |
| // CHECK-SAME: %[[ARG1:.*]]: memref<?x?xf32> |
| func @tf_lower_matmul(%arg0: tensor<?x?xf32>, |
| %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> { |
| // CHECK-NOT: linalg.copy |
| // CHECK: %[[DIM_M:.*]] = memref.dim %[[ARG0]], %c0 : memref<?x?xf32> |
| // CHECK: %[[DIM_N:.*]] = memref.dim %[[ARG1]], %c1 : memref<?x?xf32> |
| // CHECK-NOT: linalg.copy |
| // Outer loops over M and N, tiled for register reuse. |
| // CHECK: scf.for %[[M:.*]] = %c0 to %[[DIM_M]] step %c[[MR:[0-9]+]] |
| // CHECK: scf.for %[[N:.*]] = %c0 to %[[DIM_N]] step %c[[NR:[0-9]+]] |
| // The unrolled vector-dialect matmul of one tile goes here; it is too large |
| // to match. |
| // CHECK: scf.yield %[[TILE:.*]] : vector<[[MR]]x[[NR]]xf32> |
| %0 = "tf.MatMul"(%arg0, %arg1) { transpose_a = false, transpose_b = false} |
| : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> |
| return %0 : tensor<?x?xf32> |
| } |
| |
| // ----- |
| |
| // An i16 -> f16 tf.Cast feeding a broadcasting tf.Sub. The cast is expected |
| // to fuse into the sub's linalg.generic. The body receives the f16 element of |
| // %arg1 and the raw i16 element of %arg0, converts the latter with sitofp, |
| // and subtracts it. The generic's output buffer is returned directly. |
| // CHECK-LABEL: @cast_sub |
| func @cast_sub(%arg0: tensor<?x32xi16>, %arg1: tensor<?x?x32xf16>) |
| -> tensor<?x?x32xf16> { |
| // CHECK: linalg.generic |
| // CHECK-SAME: outs(%[[RESULT_BUF:.*]] : memref<?x?x32xf16>) |
| // CHECK-SAME: { |
| // CHECK: ^bb0(%[[LHS:.*]]: f16, %[[RHS:.*]]: i16, %{{.*}}: f16): |
| // CHECK: %[[RHS_CASTED:.*]] = sitofp %[[RHS]] : i16 to f16 |
| // CHECK: %[[RESULT:.*]] = subf %[[LHS]], %[[RHS_CASTED]] : f16 |
| // CHECK: linalg.yield %[[RESULT]] : f16 |
| // CHECK: } |
| // CHECK: return %[[RESULT_BUF]] : memref<?x?x32xf16> |
| %0 = "tf.Cast"(%arg0) : (tensor<?x32xi16>) -> tensor<?x32xf16> |
| %1 = "tf.Sub"(%arg1, %0) : (tensor<?x?x32xf16>, tensor<?x32xf16>) |
| -> tensor<?x?x32xf16> |
| return %1 : tensor<?x?x32xf16> |
| } |
| |
| // ----- |
| |
| // CHECK: #map0 = affine_map<(d0, d1) -> (d1, d0)> |
| // CHECK: #map1 = affine_map<(d0, d1) -> (d0, d1)> |
| |
| // tf.Transpose with the constant permutation [1, 0]. It is expected to lower |
| // to a linalg.generic that reads %arg0 through (d0, d1) -> (d1, d0) and |
| // writes a freshly allocated 3x2 buffer through the identity map. |
| // CHECK-LABEL: @tf_transpose_const_perm |
| func @tf_transpose_const_perm(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { |
| // CHECK: %[[OUT:.*]] = memref.alloc() : memref<3x2xf32> |
| // CHECK: linalg.generic {indexing_maps = [#map0, #map1] |
| // CHECK-SAME: ins(%arg0 : memref<2x3xf32>) |
| // CHECK-SAME: outs(%[[OUT]] : memref<3x2xf32>) |
| %0 = "tf.Const"() { value = dense<[1, 0]> : tensor<2xi32> } |
| : () -> tensor<2xi32> |
| %1 = "tf.Transpose"(%arg0, %0) |
| : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> |
| return %1 : tensor<3x2xf32> |
| } |
| |
| // ----- |
| |
| // CHECK: #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)> |
| // CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> |
| |
| // Two constant-permutation transposes, [0, 2, 1] followed by [2, 1, 0], |
| // compose into a single one: out[i, j, k] = in[k, i, j]. The result is |
| // expected to be one linalg.generic reading %arg0 through |
| // (d0, d1, d2) -> (d2, d0, d1), with no second generic after it. |
| // CHECK-LABEL: @tf_transpose_after_transpose |
| func @tf_transpose_after_transpose(%arg0: tensor<?x?x?xf32>) |
| -> tensor<?x?x?xf32> { |
| // CHECK: %[[OUT:.*]] = memref.alloc |
| // CHECK: linalg.generic {indexing_maps = [#map0, #map1] |
| // CHECK-SAME: ins(%arg0 : memref<?x?x?xf32>) |
| // CHECK-SAME: outs(%[[OUT]] : memref<?x?x?xf32>) |
| // CHECK-NOT: linalg.generic |
| %0 = "tf.Const"() { value = dense<[0, 2, 1]> : tensor<3xi32> } |
| : () -> tensor<3xi32> |
| %1 = "tf.Const"() { value = dense<[2, 1, 0]> : tensor<3xi32> } |
| : () -> tensor<3xi32> |
| %2 = "tf.Transpose"(%arg0, %0) |
| : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32> |
| %3 = "tf.Transpose"(%2, %1) |
| : (tensor<?x?x?xf32>, tensor<3xi32>) -> tensor<?x?x?xf32> |
| return %3 : tensor<?x?x?xf32> |
| } |
| |
| // ----- |
| |
| // tf.BiasAdd of a 32-element bias followed by tf.Relu. Both ops are expected |
| // to fuse into one linalg.generic that reads the two arguments directly. In |
| // its body an addf comes first, and the Relu shows up as the maxf right |
| // before the yield. |
| // CHECK-LABEL: @bias_add_and_relu |
| // CHECK-SAME: %[[ARG0:.*]]: memref<?x32xf32> |
| // CHECK-SAME: %[[ARG1:.*]]: memref<32xf32> |
| func @bias_add_and_relu(%arg0: tensor<?x32xf32>, |
| %arg1: tensor<32xf32>) -> tensor<?x32xf32> { |
| // CHECK: linalg.generic |
| // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : {{.*}}) |
| // CHECK: addf |
| // CHECK: maxf |
| // CHECK-NEXT: linalg.yield |
| // CHECK-NOT: linalg.generic |
| %0 = "tf.BiasAdd"(%arg0, %arg1) |
| : (tensor<?x32xf32>, tensor<32xf32>) -> tensor<?x32xf32> |
| %1 = "tf.Relu"(%0): (tensor<?x32xf32>) -> tensor<?x32xf32> |
| return %1 : tensor<?x32xf32> |
| } |
| |
| // ----- |
| |
| // Computes %arg2 - (%arg0 - %arg1), broadcasting the 2-D operands into the |
| // 3-D result. Both subs are expected to fuse into one linalg.generic whose |
| // body computes A - (B - C), where A, B and C are its first three block |
| // arguments. The generic's output buffer is returned directly. |
| // CHECK-LABEL: @sub_sub |
| func @sub_sub(%arg0: tensor<?x32xf16>, %arg1: tensor<?x32xf16>, %arg2: tensor<?x?x32xf16>) -> tensor<?x?x32xf16> { |
| // CHECK: linalg.generic |
| // CHECK-SAME: outs(%[[RESULT_BUF:.*]] : memref<?x?x32xf16>) |
| // CHECK: ^bb0(%[[A:.*]]: f16, %[[B:.*]]: f16, %[[C:.*]]: f16, %{{.*}}: f16): |
| // CHECK: %[[TMP:.*]] = subf %[[B]], %[[C]] |
| // CHECK: %[[RESULT:.*]] = subf %[[A]], %[[TMP]] |
| // CHECK: linalg.yield %[[RESULT]] |
| // CHECK: return %[[RESULT_BUF]] : memref<?x?x32xf16> |
| %0 = "tf.Sub"(%arg0, %arg1) : (tensor<?x32xf16>, tensor<?x32xf16>) -> tensor<?x32xf16> |
| %1 = "tf.Sub"(%arg2, %0) : (tensor<?x?x32xf16>, tensor<?x32xf16>) -> tensor<?x?x32xf16> |
| return %1 : tensor<?x?x32xf16> |
| } |
| |
| // ----- |
| |
| // tf.StridedSlice with begin 0, end 1, stride 1 and shrink_axis_mask = 1 |
| // extracts element 0 of a 3-element vector as a 0-D tensor. Expected output: |
| // a 1-element subview at offset 0 is copied into a fresh memref<1xi32>, which |
| // is then passed through memref.collapse_shape and returned. |
| // CHECK-LABEL: @strided_slice_1d_to_0d |
| func @strided_slice_1d_to_0d(%arg0: tensor<3xi32>) -> tensor<i32> { |
| %cst_0 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> |
| %cst_1 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> |
| // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<1xi32> |
| // CHECK: %[[SUBVIEW:.*]] = memref.subview %arg0[0] [1] [1] |
| // CHECK-SAME: : memref<3xi32> to memref<1xi32> |
| // CHECK: linalg.copy(%[[SUBVIEW]], %[[ALLOC]]) |
| // CHECK: %[[RET:.*]] = memref.collapse_shape %[[ALLOC]] |
| // CHECK: return %[[RET]] |
| %0 = "tf.StridedSlice"(%arg0, %cst_1, %cst_0, %cst_0) |
| { |
| begin_mask = 0 : i64, |
| ellipsis_mask = 0 : i64, |
| end_mask = 0 : i64, |
| new_axis_mask = 0 : i64, |
| shrink_axis_mask = 1 : i64 |
| } : (tensor<3xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) |
| -> tensor<i32> |
| return %0 : tensor<i32> |
| } |
| |
| // ----- |
| |
| // CHECK: memref.global "private" constant @__constant_2xi32 : memref<2xi32> = dense<[0, 1]> |
| // tf.Pack of the scalar constants 0 and 1 is folded at compile time. The |
| // function body is expected to reduce to a memref.get_global of a private |
| // constant global holding [0, 1], which is returned directly. |
| // CHECK-LABEL: @constant_folding |
| func @constant_folding() -> tensor<2xi32> { |
| %0 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> |
| %1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32> |
| // CHECK: %[[CONST:.*]] = memref.get_global @__constant_2xi32 : memref<2xi32> |
| // CHECK: return %[[CONST]] |
| %2 = "tf.Pack"(%0, %1) {axis = 0 : i64} |
| : (tensor<i32>, tensor<i32>) -> tensor<2xi32> |
| return %2 : tensor<2xi32> |
| } |