blob: 16054a5614a1f9db0a709d56aa41d3821aa83870 [file] [log] [blame]
// RUN: tf-opt -split-input-file -verify-diagnostics %s | FileCheck %s --dump-input-on-failure
// Unary math ops
// -----
// Round-trips tfl.cos on a dynamically sized 1-D f32 tensor through tf-opt and
// matches the printed op in generic form.
// CHECK-LABEL: testCos
func @testCos(tensor<? x f32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>):
// CHECK: "tfl.cos"(%arg0)
%0 = "tfl.cos"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
return %0 : tensor<? x f32>
}
// -----
// Negative test: tfl.cos is given an i32 tensor, and the verifier must report
// that operand #0 has to be a tensor of floating-point values.
func @testCosWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
^bb0(%arg0: tensor<?xi32>):
// expected-error @+1 {{tfl.cos' op operand #0 must be tensor of floating-point values}}
%0 = "tfl.cos"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
return %0#0 : tensor<?xi32>
}
// -----
// Round-trips tfl.exp on a dynamically sized 1-D f32 tensor and matches the
// printed op in generic form.
// CHECK-LABEL: testExp
func @testExp(tensor<? x f32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>):
// CHECK: "tfl.exp"(%arg0)
%0 = "tfl.exp"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
return %0 : tensor<? x f32>
}
// Round-trips tfl.floor on a dynamically sized 1-D f32 tensor and matches the
// printed op in generic form.
// CHECK-LABEL: testFloor
func @testFloor(tensor<? x f32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>):
// CHECK: "tfl.floor"(%arg0)
%0 = "tfl.floor"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
return %0 : tensor<? x f32>
}
// -----
// Accepts tfl.gather with dynamically sized 1-D params (f32) and indices (i32).
// Only the function label is matched in the output; the op text is not.
// CHECK-LABEL: testGather
func @testGather(%arg0 : tensor<?xf32>, %arg1 : tensor<?xi32>) -> tensor<?xf32> {
%0 = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32}: (tensor<?xf32>,tensor<?xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// Same as the dynamic-shape gather case but with statically shaped
// (2-element) params and indices. Only the function label is matched.
// CHECK-LABEL: testGather
func @testGather(%arg0 : tensor<2xf32>, %arg1 : tensor<2xi32>) -> tensor<2xf32> {
%0 = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32}: (tensor<2xf32>,tensor<2xi32>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}
// -----
// Accepts tfl.gather when params and result are unranked (tensor<*xf32>).
// Only the function label is matched in the output.
// CHECK-LABEL: testGatherUnknownRank
func @testGatherUnknownRank(%arg0 : tensor<*xf32>, %arg1 : tensor<1xi32>) -> tensor<*xf32> {
%0 = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32}: (tensor<*xf32>,tensor<1xi32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// Negative test: tfl.gather with i32 params but an f32 result must be rejected
// because params and output are required to share an element type.
func @testGatherUnsupportedType(%arg0 : tensor<?xi32>, %arg1 : tensor<?xi32>) -> tensor<?xf32> {
// expected-error @+1 {{op failed to verify that params and output must have same element type}}
%0 = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32}: (tensor<?xi32>,tensor<?xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// Negative test: tfl.gather with a rank-0 params tensor (tensor<f32>) must
// fail the "operand 0 is 1-D" verification.
func @testGatherUnsupportedRank(%arg0 : tensor<f32>, %arg1 : tensor<1xi32>) -> tensor<?xf32> {
// expected-error @+1 {{op failed to verify that operand 0 is 1-D}}
%0 = "tfl.gather"(%arg0, %arg1) {axis = 1 : i32}: (tensor<f32>,tensor<1xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// Round-trips tfl.abs on a dynamically sized 1-D f32 tensor and matches the
// printed op in generic form.
// CHECK-LABEL: testAbs
func @testAbs(tensor<? x f32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>):
// CHECK: "tfl.abs"(%arg0)
%0 = "tfl.abs"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
return %0 : tensor<? x f32>
}
// Round-trips tfl.add_n with three f32 tensor operands and matches the printed
// op, including all three operands, in generic form.
// CHECK-LABEL: testAddN
func @testAddN(tensor<? x f32>, tensor<? x f32>, tensor<? x f32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>):
// CHECK: "tfl.add_n"(%arg0, %arg1, %arg2)
%0 = "tfl.add_n"(%arg0, %arg1, %arg2): (tensor<? x f32>, tensor<? x f32>, tensor<? x f32>) -> tensor<? x f32>
return %0 : tensor<? x f32>
}
// -----
// Negative test: tfl.add_n is given f16 tensors, which are outside the element
// types listed in the diagnostic (32-bit float, 32-bit integer, QI16, QUI16).
func @testAddNWrongOperandResultType(tensor<? x f16>, tensor<? x f16>, tensor<? x f16>) -> tensor<? x f16> {
^bb0(%arg0: tensor<? x f16>, %arg1: tensor<? x f16>, %arg2: tensor<? x f16>):
// expected-error @+1 {{'tfl.add_n' op operand #0 must be tensor of 32-bit float or 32-bit integer or QI16 type or QUI16 type values}}
%0 = "tfl.add_n"(%arg0, %arg1, %arg2): (tensor<? x f16>, tensor<? x f16>, tensor<? x f16>) -> tensor<? x f16>
return %0 : tensor<? x f16>
}
// -----
// Round-trips tfl.log on a dynamically sized 1-D f32 tensor and matches the
// printed op in generic form.
// CHECK-LABEL: testLog
func @testLog(tensor<? x f32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>):
// CHECK: "tfl.log"(%arg0)
%0 = "tfl.log"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
return %0 : tensor<? x f32>
}
// Round-trips tfl.neg on a dynamically sized 1-D f32 tensor and matches the
// printed op in generic form.
// CHECK-LABEL: testNeg
func @testNeg(tensor<? x f32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>):
// CHECK: "tfl.neg"(%arg0)
%0 = "tfl.neg"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
return %0 : tensor<? x f32>
}
// Round-trips tfl.rsqrt on a dynamically sized 1-D f32 tensor and matches the
// printed op in generic form.
// CHECK-LABEL: testRsqrt
func @testRsqrt(tensor<? x f32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>):
// CHECK: "tfl.rsqrt"(%arg0)
%0 = "tfl.rsqrt"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
return %0 : tensor<? x f32>
}
// Round-trips tfl.sin on a dynamically sized 1-D f32 tensor and matches the
// printed op in generic form.
// CHECK-LABEL: testSin
func @testSin(tensor<? x f32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>):
// CHECK: "tfl.sin"(%arg0)
%0 = "tfl.sin"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
return %0 : tensor<? x f32>
}
// -----
// Negative test: tfl.sin is given an i32 tensor, and the verifier must report
// that operand #0 has to be a tensor of floating-point values.
func @testSinWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
^bb0(%arg0: tensor<?xi32>):
// expected-error @+1 {{tfl.sin' op operand #0 must be tensor of floating-point values}}
%0 = "tfl.sin"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
return %0#0 : tensor<?xi32>
}
// -----
// Negative test: tfl.sqrt is given an i32 tensor, and the verifier must report
// that operand #0 has to be a tensor of floating-point values.
func @testSqrtWithWrongInputType(tensor<? x i32>) -> tensor<? x i32> {
^bb0(%arg0: tensor<? x i32>):
// expected-error @+1 {{tfl.sqrt' op operand #0 must be tensor of floating-point values}}
%0 = "tfl.sqrt"(%arg0): (tensor<? x i32>) -> tensor<? x i32>
return %0#0 : tensor<? x i32>
}
// -----
// Negative test: tfl.square is given an i32 tensor, which is neither
// floating-point nor one of the quantized types (QI8, QUI8) the diagnostic lists.
func @testSquareWithWrongInputType(tensor<? x i32>) -> tensor<? x i32> {
^bb0(%arg0: tensor<? x i32>):
// expected-error @+1 {{tfl.square' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type values}}
%0 = "tfl.square"(%arg0): (tensor<? x i32>) -> tensor<? x i32>
return %0#0 : tensor<? x i32>
}
// -----
// Round-trips tfl.sqrt on a dynamically sized 1-D f32 tensor and matches the
// printed op in generic form.
// CHECK-LABEL: testSqrt
func @testSqrt(tensor<? x f32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>):
// CHECK: "tfl.sqrt"(%arg0)
%0 = "tfl.sqrt"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
return %0 : tensor<? x f32>
}
// Round-trips tfl.square on a dynamically sized 1-D f32 tensor and matches the
// printed op in generic form.
// CHECK-LABEL: testSquare
func @testSquare(tensor<? x f32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>):
// CHECK: "tfl.square"(%arg0)
%0 = "tfl.square"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
return %0 : tensor<? x f32>
}
// Verifies that tfl.square accepts a uniform-quantized u8 tensor. The label and
// op match below anchor the round-tripped output to this function, so a missing
// or mangled op is reported here rather than passing silently.
// CHECK-LABEL: testQuantizedSquare
func @testQuantizedSquare(tensor<? x !quant.uniform<u8:f32, 0.1>>) -> tensor<? x !quant.uniform<u8:f32, 0.1>> {
^bb0(%arg0: tensor<? x !quant.uniform<u8:f32, 0.1>>):
// CHECK: "tfl.square"(%arg0)
%0 = "tfl.square"(%arg0): (tensor<? x !quant.uniform<u8:f32, 0.1>>) -> tensor<? x !quant.uniform<u8:f32, 0.1>>
return %0 : tensor<? x !quant.uniform<u8:f32, 0.1>>
}
// Verifies that tfl.resize_nearest_neighbor accepts a uniform-quantized u8
// input together with an i32 size tensor. The label anchors the output to this
// function; the op match uses only the op name so it does not depend on the
// exact printed form.
// CHECK-LABEL: testQuantizedResizeNearestNeighbor
func @testQuantizedResizeNearestNeighbor(tensor<? x !quant.uniform<u8:f32, 0.1>>, tensor<? x i32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>> {
^bb0(%arg0: tensor<? x !quant.uniform<u8:f32, 0.1>>, %arg1: tensor<? x i32>):
// CHECK: tfl.resize_nearest_neighbor
%0 = "tfl.resize_nearest_neighbor"(%arg0, %arg1) { align_corners = false } : (tensor<? x !quant.uniform<u8:f32, 0.1>>, tensor<? x i32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>>
return %0 : tensor<? x !quant.uniform<u8:f32, 0.1>>
}
// Round-trips tfl.tanh on a dynamically sized 1-D f32 tensor and matches the
// printed op in generic form.
// CHECK-LABEL: testTanh
func @testTanh(tensor<? x f32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>):
// CHECK: "tfl.tanh"(%arg0)
%0 = "tfl.tanh"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
return %0 : tensor<? x f32>
}
// Accepts tfl.tanh on a uniform-quantized signed 8-bit tensor. Only the
// function label is matched in the output.
// CHECK-LABEL: testTanhWithQI8
func @testTanhWithQI8(%arg0: tensor<? x !quant.uniform<i8:f32, 0.1>>) -> tensor<? x !quant.uniform<i8:f32, 0.1>> {
%0 = "tfl.tanh"(%arg0): (tensor<? x !quant.uniform<i8:f32, 0.1>>) -> tensor<? x !quant.uniform<i8:f32, 0.1>>
return %0 : tensor<? x !quant.uniform<i8:f32, 0.1>>
}
// Accepts tfl.tanh on a uniform-quantized unsigned 8-bit tensor. Only the
// function label is matched in the output.
// CHECK-LABEL: testTanhWithQUI8
func @testTanhWithQUI8(%arg0: tensor<? x !quant.uniform<u8:f32, 0.1>>) -> tensor<? x !quant.uniform<u8:f32, 0.1>> {
%0 = "tfl.tanh"(%arg0): (tensor<? x !quant.uniform<u8:f32, 0.1>>) -> tensor<? x !quant.uniform<u8:f32, 0.1>>
return %0 : tensor<? x !quant.uniform<u8:f32, 0.1>>
}
// Round-trips tfl.zeros_like on a dynamically sized 1-D f32 tensor and matches
// the printed op in generic form.
// CHECK-LABEL: testZerosLike
func @testZerosLike(tensor<? x f32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>):
// CHECK: "tfl.zeros_like"(%arg0)
%0 = "tfl.zeros_like"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
return %0 : tensor<? x f32>
}
// Round-trips tfl.dequantize from an i32 tensor to an f32 tensor. The match
// also pins the printed type signature, which has no spaces inside tensor<...>.
// CHECK-LABEL: testDequantize
func @testDequantize(tensor<? x i32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x i32>):
// CHECK: "tfl.dequantize"(%arg0) : (tensor<?xi32>) -> tensor<?xf32>
%0 = "tfl.dequantize"(%arg0): (tensor<? x i32>) -> tensor<? x f32>
return %0 : tensor<? x f32>
}
// Round-trips tfl.logical_not on an i1 (boolean) tensor and matches the
// printed op in generic form.
// CHECK-LABEL: testLogicalNot
func @testLogicalNot(tensor<? x i1>) -> tensor<? x i1> {
^bb0(%arg0: tensor<? x i1>):
// CHECK: "tfl.logical_not"(%arg0)
%0 = "tfl.logical_not"(%arg0): (tensor<? x i1>) -> tensor<? x i1>
return %0 : tensor<? x i1>
}
// -----
// Negative test: tfl.logical_not is given an i32 tensor, and the verifier must
// report that operand #0 has to be a tensor of 1-bit integer values.
func @testLogicalNotWrongOperandType(tensor<? x i32>) -> tensor<? x i32> {
^bb0(%arg0: tensor<? x i32>):
// expected-error @+1 {{'tfl.logical_not' op operand #0 must be tensor of 1-bit integer values}}
%0 = "tfl.logical_not"(%arg0) : (tensor<? x i32>) -> tensor<? x i32>
return %0 : tensor<? x i32>
}
// Binary math ops
// -----
// Round-trips tfl.add written in its custom (pretty) assembly form with a
// RELU6 fused activation and matches the same custom form in the output.
// CHECK-LABEL: testAdd
func @testAdd(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
// TODO(jpienaar): Enable specifying label of enum for parsing.
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "RELU6"}
%0 = tfl.add %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<? x i32>
return %0#0 : tensor<? x i32>
}
// Round-trips tfl.sub in its custom assembly form with a RELU6 fused
// activation and matches the same custom form in the output.
// CHECK-LABEL: testSub
func @testSub(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
// CHECK: tfl.sub %arg0, %arg1 {fused_activation_function = "RELU6"}
%0 = tfl.sub %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<? x i32>
return %0#0 : tensor<? x i32>
}
// Round-trips tfl.mul in its custom assembly form with a RELU6 fused
// activation and matches the same custom form in the output.
// CHECK-LABEL: testMul
func @testMul(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
// CHECK: tfl.mul %arg0, %arg1 {fused_activation_function = "RELU6"}
%0 = tfl.mul %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<? x i32>
return %0#0 : tensor<? x i32>
}
// Round-trips tfl.div in its custom assembly form with a RELU6 fused
// activation and matches the same custom form in the output.
// CHECK-LABEL: testDiv
func @testDiv(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
// CHECK: tfl.div %arg0, %arg1 {fused_activation_function = "RELU6"}
%0 = tfl.div %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<? x i32>
return %0#0 : tensor<? x i32>
}
// Round-trips tfl.less comparing two i32 tensors into an i1 tensor and matches
// the printed op in generic form.
// CHECK-LABEL: testLess
func @testLess(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i1> {
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
// CHECK: "tfl.less"(%arg0, %arg1)
%0 = "tfl.less"(%arg0, %arg1) : (tensor<? x i32>, tensor<? x i32>) -> tensor<? x i1>
return %0#0 : tensor<? x i1>
}
// -----
// Round-trips tfl.floor_div on i32 tensors in its custom assembly form and
// matches the same form in the output.
// CHECK-LABEL: testFloorDivI32
func @testFloorDivI32(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
// CHECK: tfl.floor_div %arg0, %arg1
%0 = tfl.floor_div %arg0, %arg1 : tensor<? x i32>
return %0#0 : tensor<? x i32>
}
// -----
// Round-trips tfl.floor_div on f32 tensors in its custom assembly form and
// matches the same form in the output.
// CHECK-LABEL: testFloorDivF32
func @testFloorDivF32(tensor<? x f32>, tensor<? x f32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>):
// CHECK: tfl.floor_div %arg0, %arg1
%0 = tfl.floor_div %arg0, %arg1 : tensor<? x f32>
return %0#0 : tensor<? x f32>
}
// -----
// Negative test: tfl.floor_div with an f32 lhs and an i32 rhs must be rejected
// because both operands are required to have the same element type. It reuses
// the name of the valid f32 case but lives in its own split section.
func @testFloorDivF32(%arg0: tensor<2 x f32>, %arg1: tensor<2 x i32>) -> tensor<2 x f32> {
// expected-error @+1 {{failed to verify that operands have same element type}}
%0 = "tfl.floor_div"(%arg0, %arg1) : (tensor<2 x f32>, tensor<2 x i32>) -> tensor<2 x f32>
return %0#0 : tensor<2 x f32>
}
// -----
// Accepts tfl.floor_mod on two i32 tensors. Only the function label is matched
// in the output; the op text is not.
// CHECK-LABEL: testFloorMod
func @testFloorMod(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>) -> tensor<? x i32> {
%0 = "tfl.floor_mod"(%arg0, %arg1) : (tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32>
return %0 : tensor<? x i32>
}
// Round-trips tfl.pow on i32 tensors in its custom assembly form and matches
// the same form in the output.
// CHECK-LABEL: testPow
func @testPow(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
// CHECK: tfl.pow %arg0, %arg1
%0 = tfl.pow %arg0, %arg1 : tensor<? x i32>
return %0#0 : tensor<? x i32>
}
// Round-trips tfl.conv_2d with input, filter and a tensor bias operand, using
// unit strides and dilations, SAME padding and a RELU6 fused activation.
// CHECK-LABEL: testConv2D
func @testConv2D(tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> {
^bb0(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>, %arg2: tensor<16xf32>):
// CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2)
%0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "RELU6"} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
return %0 : tensor<256x30x30x16xf32>
}
// Verifies that tfl.conv_2d accepts a `none`-typed value in place of the bias
// operand. The label anchors the op match to this function, so it can no longer
// be satisfied by output belonging to a neighbouring function.
// CHECK-LABEL: testConv2DNoBias
func @testConv2DNoBias(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>, %arg2: none) -> tensor<256x30x30x16xf32> {
// CHECK: "tfl.conv_2d"(%arg0, %arg1, %arg2)
%0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32, fused_activation_function = "RELU6"} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, none) -> tensor<256x30x30x16xf32>
return %0 : tensor<256x30x30x16xf32>
}
// Round-trips two chained tfl.fake_quant ops, one with an empty minmax array
// and one with [0.3, 1.4]. The matched output lists the attributes in a
// different order than they are written in the input and prints the float
// values in scientific notation. The f32 arguments %arg1 and %arg2 are unused.
// CHECK-LABEL: testFakeQuant
func @testFakeQuant(tensor<? x f32>, f32, f32) -> tensor<? x f32> {
^bb0(%arg0: tensor<? x f32>, %arg1: f32, %arg2: f32):
// CHECK: %0 = "tfl.fake_quant"(%arg0) {minmax = [], narrow_range = true, num_bits = 2 : i32} : (tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.fake_quant"(%arg0) {minmax = [], num_bits = 2 : i32, narrow_range = true} : (tensor<? x f32>) -> tensor<? x f32>
// CHECK: %1 = "tfl.fake_quant"(%0) {minmax = [3.000000e-01, 1.400000e+00], narrow_range = false, num_bits = 6 : i32} : (tensor<?xf32>) -> tensor<?xf32>
%1 = "tfl.fake_quant"(%0) {num_bits = 6 : i32, narrow_range = false, minmax = [0.3, 1.4]} : (tensor<? x f32>) -> tensor<? x f32>
return %1 : tensor<? x f32>
}
// Round-trips tfl.quantize from f32 to a uniform-quantized u8 tensor. The
// match pins the qtype attribute, where the scale 0.1 is printed as
// 1.000000e-01 and the zero point 128 is kept.
// CHECK-LABEL: testQuantize
func @testQuantize(tensor<? x f32>) -> tensor<? x !quant.uniform<u8:f32, 0.1:128>> {
^bb0(%arg0: tensor<? x f32>):
// CHECK: %0 = "tfl.quantize"(%arg0) {qtype = tensor<?x!quant.uniform<u8:f32, 1.000000e-01:128>>}
%0 = "tfl.quantize"(%arg0) {qtype = tensor<? x !quant.uniform<u8:f32, 0.1:128>>} : (tensor<? x f32>) -> tensor<? x !quant.uniform<u8:f32, 0.1:128>>
return %0 : tensor<? x !quant.uniform<u8:f32, 0.1:128>>
}
// Parses tfl.logical_and written in generic form on i1 tensors and matches it
// printed back in its custom assembly form.
// CHECK-LABEL: testLogicalAnd
func @testLogicalAnd(tensor<? x i1>, tensor<? x i1>) -> tensor<? x i1> {
^bb0(%arg0: tensor<? x i1>, %arg1: tensor<? x i1>):
// CHECK: tfl.logical_and %arg0, %arg1
%0 = "tfl.logical_and"(%arg0, %arg1) : (tensor<? x i1>, tensor<? x i1>) -> tensor<? x i1>
return %0#0 : tensor<? x i1>
}
// -----
// Negative test: tfl.logical_and is given i32 tensors, and the verifier must
// report that operand #0 has to be a tensor of 1-bit integer values.
func @testLogicalAndWrongOperandType(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
// expected-error @+1 {{'tfl.logical_and' op operand #0 must be tensor of 1-bit integer values}}
%0 = "tfl.logical_and"(%arg0, %arg1) : (tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32>
return %0 : tensor<? x i32>
}
// -----
// Parses tfl.logical_or written in generic form on i1 tensors and matches it
// printed back in its custom assembly form.
// CHECK-LABEL: testLogicalOr
func @testLogicalOr(tensor<? x i1>, tensor<? x i1>) -> tensor<? x i1> {
^bb0(%arg0: tensor<? x i1>, %arg1: tensor<? x i1>):
// CHECK: tfl.logical_or %arg0, %arg1
%0 = "tfl.logical_or"(%arg0, %arg1) : (tensor<? x i1>, tensor<? x i1>) -> tensor<? x i1>
return %0#0 : tensor<? x i1>
}
// -----
// Negative test: tfl.logical_or is given i32 tensors, and the verifier must
// report that operand #0 has to be a tensor of 1-bit integer values.
func @testLogicalOrWrongOperandType(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
// expected-error @+1 {{'tfl.logical_or' op operand #0 must be tensor of 1-bit integer values}}
%0 = "tfl.logical_or"(%arg0, %arg1) : (tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32>
return %0 : tensor<? x i32>
}
// -----
// Round-trips tfl.elu on a dynamically sized 1-D f32 tensor and matches the
// printed op in generic form.
// CHECK-LABEL: testEluF32
func @testEluF32(%arg0: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.elu"(%arg0)
%0 = "tfl.elu"(%arg0): (tensor<? x f32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>
}
// -----
// Round-trips tfl.tile with a 4x1 f32 input and a 4-element i32 multiples
// tensor, and matches the printed op in generic form.
// CHECK-LABEL: testTileF32
func @testTileF32(%arg0: tensor<4 x 1 x f32>, %arg1: tensor<4 x i32>) -> tensor<? x f32> {
// CHECK: "tfl.tile"(%arg0, %arg1)
%0 = "tfl.tile"(%arg0, %arg1): (tensor<4 x 1 x f32>, tensor<4 x i32>) -> tensor<? x f32>
return %0 : tensor<? x f32>
}
// -----
// Negative test: tfl.elu is given an i32 tensor, and the verifier must report
// that operand #0 has to be a tensor of floating-point values.
func @testEluI32(%arg0: tensor<? x i32>) -> tensor<? x i32> {
// expected-error @+1 {{operand #0 must be tensor of floating-point values}}
%0 = "tfl.elu"(%arg0): (tensor<? x i32>) -> tensor<? x i32>
return %0#0 : tensor<? x i32>
}
// -----
// Verifies that every accepted value of the fused_activation_function enum
// (NONE, RELU, RELU_N1_TO_1, RELU6, TANH, SIGN_BIT) parses on tfl.add and is
// printed back. The label anchors the quoted-string matches to this function;
// without it they could be satisfied by output from other functions.
// CHECK-LABEL: testFusedActiviationFunction
func @testFusedActiviationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> (tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>) {
// CHECK: "NONE"
%0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<4xi32>
// CHECK: "RELU"
%1 = tfl.add %arg0, %arg1 {fused_activation_function = "RELU"} : tensor<4xi32>
// CHECK: "RELU_N1_TO_1"
%2 = tfl.add %arg0, %arg1 {fused_activation_function = "RELU_N1_TO_1"} : tensor<4xi32>
// CHECK: "RELU6"
%3 = tfl.add %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<4xi32>
// CHECK: "TANH"
%4 = tfl.add %arg0, %arg1 {fused_activation_function = "TANH"} : tensor<4xi32>
// CHECK: "SIGN_BIT"
%5 = tfl.add %arg0, %arg1 {fused_activation_function = "SIGN_BIT"} : tensor<4xi32>
return %0, %1, %2, %3, %4, %5: tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>, tensor<4xi32>
}
// -----
// Negative test: the mixed-case value "Relu6" must be rejected by the
// fused_activation_function enum constraint on tfl.add.
func @testFusedActiviationFunction(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// expected-error @+1 {{attribute 'fused_activation_function' failed to satisfy constraint: fused activation enum}}
%0 = tfl.add %arg0, %arg1 {fused_activation_function = "Relu6"} : tensor<4xi32>
return %0: tensor<4xi32>
}
// -----
// Verifies that both padding values used here ("SAME" and "VALID") parse on
// tfl.conv_2d and are printed back. The label anchors the quoted-string matches
// to this function; without it they could be satisfied by other output.
// CHECK-LABEL: testPadding
func @testPadding(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>, %arg2: tensor<16xf32>) -> (tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>) {
// CHECK: "SAME"
%0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
// CHECK: "VALID"
%1 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
return %0, %1 : tensor<256x30x30x16xf32>, tensor<256x30x30x16xf32>
}
// -----
// Negative test: the value "SOMETHING" must be rejected by the padding enum
// constraint on tfl.conv_2d.
func @testPadding(%arg0: tensor<256x32x32x3xf32>, %arg1: tensor<3x3x3x16xf32>, %arg2: tensor<16xf32>) -> tensor<256x30x30x16xf32> {
// expected-error @+1 {{attribute 'padding' failed to satisfy constraint: padding enum}}
%0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SOMETHING", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>, tensor<3x3x3x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32>
return %0 : tensor<256x30x30x16xf32>
}
// -----
// Round-trips tfl.max_pool_2d on an f32 input and matches the full printed op,
// including every attribute and the type signature.
// CHECK-LABEL: testMaxPool2D
func @testMaxPool2D(tensor<256x32x32x3xf32>) -> tensor<?xf32> {
^bb0(%arg0: tensor<256x32x32x3xf32>):
// CHECK: "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>) -> tensor<?xf32>
%0 = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// Round-trips tfl.max_pool_2d on a uniform-quantized i8 input and matches the
// printed op and attributes (the type signature is not matched).
// CHECK-LABEL: testMaxPool2DQuantized
func @testMaxPool2DQuantized(tensor<256x32x32x3x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<?x!quant.uniform<i8:f32, 0.1:128>> {
^bb0(%arg0: tensor<256x32x32x3x!quant.uniform<i8:f32, 0.1:128>>):
// CHECK: "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32}
%0 = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<256x32x32x3x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<?x!quant.uniform<i8:f32, 0.1:128>>
return %0 : tensor<?x!quant.uniform<i8:f32, 0.1:128>>
}
// -----
// Negative test: tfl.max_pool_2d with a plain i32 operand and result must fail
// the op's operand/result type constraint.
func @testMaxPool2DWrongOperandResultType(tensor<1x7x7x16xi32>) -> tensor<1x7x7x16xi32> {
^bb0(%arg0: tensor<1x7x7x16xi32>):
// expected-error @+1 {{failed to verify that MaxPool2D operand and result types match specified constraints}}
%0 = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x7x7x16xi32>) -> tensor<1x7x7x16xi32>
return %0 : tensor<1x7x7x16xi32>
}
// -----
// Negative test: tfl.max_pool_2d with a quantized type whose storage type is
// i9 must fail the op's operand/result type constraint.
func @testMaxPool2DWrongOperandStorageType(tensor<1x7x7x16x!quant.uniform<i9:f32, 0.1:128>>) -> tensor<1x7x7x16x!quant.uniform<i9:f32, 0.1:128>> {
^bb0(%arg0: tensor<1x7x7x16x!quant.uniform<i9:f32, 0.1:128>>):
// expected-error @+1 {{failed to verify that MaxPool2D operand and result types match specified constraints}}
%0 = "tfl.max_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "RELU6", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x7x7x16x!quant.uniform<i9:f32, 0.1:128>>) -> tensor<1x7x7x16x!quant.uniform<i9:f32, 0.1:128>>
return %0 : tensor<1x7x7x16x!quant.uniform<i9:f32, 0.1:128>>
}
// -----
// Round-trips tfl.logistic on a statically shaped 5-D bf16 tensor and matches
// the printed op in generic form.
// CHECK-LABEL: testLogistic
func @testLogistic(tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16> {
^bb0(%arg0: tensor<1x2x3x4x5xbf16>):
// CHECK: "tfl.logistic"(%arg0)
%0 = "tfl.logistic"(%arg0): (tensor<1x2x3x4x5xbf16>) -> tensor<1x2x3x4x5xbf16>
return %0 : tensor<1x2x3x4x5xbf16>
}
// -----
// Negative test: tfl.logistic is given an i32 tensor, which is neither
// floating-point nor one of the quantized types the diagnostic lists
// (QI8, QUI8, QI16, QUI16).
func @testLogisticWithWrongInputType(tensor<?xi32>) -> tensor<?xi32> {
^bb0(%arg0: tensor<?xi32>):
// expected-error @+1 {{tfl.logistic' op operand #0 must be tensor of floating-point or QI8 type or QUI8 type or QI16 type or QUI16 type values}}
%0 = "tfl.logistic"(%arg0): (tensor<?xi32>) -> tensor<?xi32>
return %0#0 : tensor<?xi32>
}
// -----
// Round-trips tfl.unidirectional_sequence_rnn with five f32 tensor operands
// and matches the full printed op, including attributes and type signature.
// CHECK-LABEL: testUnidirectionalSequenceRnn
func @testUnidirectionalSequenceRnn(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_rnn"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: testUnidirectionalSequenceLstm
func @testUnidirectionalSequenceLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr
func @testUnidirectionalSequenceLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// test invalid none type applied to a tensor type arg
func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: none, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.unidirectional_sequence_lstm' op operand #2 must be tensor of 32-bit float or 8-bit integer values}}
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<? x f32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// Negative test for the projection weight / projection bias predicate: operand
// #16 is passed as `none` while every other operand is a tensor, and the op
// must fail the "either projection weight must be specified or both projection
// weight and projection bias must not be specified" verification.
func @testUnidirectionalSequenceLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: none, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.unidirectional_sequence_lstm' op failed to verify that either projection weight must be specified or both projection weight and projection bias must not be specified}}
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<?xf32>, tensor<? x f32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: testLstm
func @testLstm(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23)
// CHECK-NEXT: {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: testLstmWithNoneTypeAndOverrideAttr
func @testLstmWithNoneTypeAndOverrideAttr(%arg0: tensor<? x f32>, %arg1: none, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// CHECK: "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23)
// CHECK-NEXT: {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL"} : (tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// test invalid none type applied to a tensor type arg
func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: none, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.lstm' op operand #2 must be tensor of 32-bit float or 8-bit integer values}}
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// test violation of projection weight and projection bias pred op trait
// Negative test: operand #16 is `none` while operand #17 is still a tensor, so
// the projection weight/bias predicate must fail (presumably #16 is the
// projection weight and #17 the projection bias -- TODO confirm).
// NOTE(review): shares its name with the operand-type test in the previous
// split section; this only works because each section is parsed separately.
func @testLstmWithInvalidNoneType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: none, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.lstm' op failed to verify that either projection weight must be specified or both projection weight and projection bias must not be specified}}
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {fused_activation_function = "NONE"} : (tensor<?xf32>, tensor<? x f32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, none, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// test invalid kernel type
func @testLstmWithInvalidKernelType(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>, %arg2: tensor<? x f32>, %arg3: tensor<? x f32>, %arg4: tensor<? x f32>, %arg5: tensor<? x f32>, %arg6: tensor<? x f32>, %arg7: tensor<? x f32>, %arg8: tensor<? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<? x f32>, %arg16: tensor<? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<? x f32>, %arg19: tensor<? x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.lstm' op attribute 'kernel_type' failed to satisfy constraint: lstm kernel type enum case FULL}}
%0 = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) ({}) {cell_clip = 1.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "BASIC"} : (tensor<?xf32>, tensor<? x f32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: testReverseV2
// Positive test: tfl.reverse_v2 on a 1x2x3x4 f32 tensor with a 2-element i32
// second operand; the result keeps the input's static shape.
func @testReverseV2(%arg0: tensor<1x2x3x4xf32>, %arg1 : tensor<2xi32>) -> tensor<1x2x3x4xf32> {
// CHECK: "tfl.reverse_v2"(%arg0, %arg1)
%0 = "tfl.reverse_v2"(%arg0, %arg1): (tensor<1x2x3x4xf32>, tensor<2xi32>) -> tensor<1x2x3x4xf32>
return %0 : tensor<1x2x3x4xf32>
}
// -----
// test select
// CHECK-LABEL: testSelect
// Positive test: tfl.select with an i1 condition and two i32 operands, all
// rank-1 with a dynamic dimension. No pattern is written for the op itself, so
// this case only requires that the op passes verification.
func @testSelect(%cond : tensor<?xi1>, %arg0 : tensor<?xi32>, %arg1 : tensor<?xi32>) -> tensor<?xi32> {
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<?xi1>,tensor<?xi32>,tensor<?xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// -----
// test select with multi-dim inputs
// CHECK-LABEL: testSelectMultiDim
// Positive test: the condition is rank-1 (`tensor<?xi1>`) while both value
// operands are rank-2 (`tensor<?x4xi32>`); the op must still verify.
func @testSelectMultiDim(%cond : tensor<?xi1>, %arg0 : tensor<?x4xi32>, %arg1 : tensor<?x4xi32>) -> tensor<?x4xi32> {
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<?xi1>,tensor<?x4xi32>,tensor<?x4xi32>) -> tensor<?x4xi32>
return %0 : tensor<?x4xi32>
}
// -----
// Negative test: the condition operand has element type i32, but operand #0
// must be a tensor of 1-bit integers.
func @testSelectWithUnsupportedType(%cond : tensor<?xi32>, %arg0 : tensor<?xi32>, %arg1 : tensor<?xi32>) -> tensor<?xi32> {
// expected-error @+1 {{op operand #0 must be tensor of 1-bit integer values}}
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<?xi32>,tensor<?xi32>,tensor<?xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// -----
// Negative test: the condition has static shape 2 while both value operands
// have static shape 3, which violates the select shape criteria.
func @testSelectWithUnsupportedShapes(%cond : tensor<2xi1>, %arg0 : tensor<3xi32>, %arg1 : tensor<3xi32>) -> tensor<3xi32> {
// expected-error @+1 {{failed to verify that Select operands meet shape criteria}}
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<2xi1>,tensor<3xi32>,tensor<3xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
// -----
// Negative test: the two value operands differ in element type (i32 vs f32),
// so the same-element-type predicate must fail.
func @testSelectWithUnsupportedType(%cond : tensor<?xi1>, %arg0 : tensor<?xi32>, %arg1 : tensor<?xf32>) -> tensor<?xi32> {
// expected-error @+1 {{failed to verify that operands have same element type}}
%0 = "tfl.select"(%cond, %arg0, %arg1): (tensor<?xi1>,tensor<?xi32>,tensor<?xf32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// -----
// CHECK-LABEL: topk
// Positive test: tfl.topk_v2 on an 8-element f32 tensor with a scalar i32
// second operand, producing two dynamically shaped results (f32 and i32).
func @topk(%arg0: tensor<8xf32>, %arg1: tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>) {
%0, %1 = "tfl.topk_v2"(%arg0, %arg1) : (tensor<8xf32>, tensor<i32>) -> (tensor<?xf32>, tensor<?xi32>)
return %0, %1: tensor<?xf32>, tensor<?xi32>
}
// -----
// CHECK-LABEL: topk
// Positive test: same op as the ranked variant, but the input and both results
// are unranked tensors.
func @topk(%arg0: tensor<*xf32>, %arg1: tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>) {
%0, %1 = "tfl.topk_v2"(%arg0, %arg1) : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>)
return %0, %1: tensor<*xf32>, tensor<*xi32>
}
// -----
// CHECK-LABEL: topk_2
// Positive test: the second operand is the constant 2 and the static result
// shapes end in a dimension of 2 (3x4x8 -> 3x4x2). Both results are bound
// through the `%1:2` multi-result syntax.
func @topk_2(%arg0: tensor<3x4x8xf32>) -> (tensor<3x4x2xf32>, tensor<3x4x2xi32>) {
%0 = constant dense<2> : tensor<i32>
%1:2 = "tfl.topk_v2"(%arg0, %0) : (tensor<3x4x8xf32>, tensor<i32>) -> (tensor<3x4x2xf32>, tensor<3x4x2xi32>)
return %1#0, %1#1: tensor<3x4x2xf32>, tensor<3x4x2xi32>
}
// -----
// CHECK-LABEL: topk_d
// Positive test: dynamic leading dimension (?x8 -> ?x2) with the second
// operand given as the constant 2.
func @topk_d(%arg0: tensor<?x8xf32>) -> (tensor<?x2xf32>, tensor<?x2xi32>) {
%0 = constant dense<2> : tensor<i32>
%1:2 = "tfl.topk_v2"(%arg0, %0) : (tensor<?x8xf32>, tensor<i32>) -> (tensor<?x2xf32>, tensor<?x2xi32>)
return %1#0, %1#1: tensor<?x2xf32>, tensor<?x2xi32>
}
// -----
// CHECK-LABEL: topk_d
// TODO(jpienaar): This should fail but doesn't as the op definition does not
// include shape verification.
// The result types end in a dimension of 3 although the constant second
// operand is 2; the op is currently accepted, so this case has no diagnostic.
func @topk_d(%arg0: tensor<?x8xf32>) -> (tensor<?x3xf32>, tensor<?x3xi32>) {
%0 = constant dense<2> : tensor<i32>
%1:2 = "tfl.topk_v2"(%arg0, %0) : (tensor<?x8xf32>, tensor<i32>) -> (tensor<?x3xf32>, tensor<?x3xi32>)
return %1#0, %1#1: tensor<?x3xf32>, tensor<?x3xi32>
}
// -----
// CHECK-LABEL: topk_d
// Positive test: ranked input (?x8) with a constant second operand, but both
// results are declared as unranked tensors.
func @topk_d(%arg0: tensor<?x8xf32>) -> (tensor<*xf32>, tensor<*xi32>) {
%0 = constant dense<2> : tensor<i32>
%1:2 = "tfl.topk_v2"(%arg0, %0) : (tensor<?x8xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xi32>)
return %1#0, %1#1: tensor<*xf32>, tensor<*xi32>
}
// -----
// CHECK-LABEL: testEqual
// Positive test: tfl.equal on two f32 tensors yields an i1 tensor. The
// signature lists types only; the `^bb0` entry block names the arguments.
func @testEqual(tensor<? x f32>, tensor<? x f32>) -> tensor<? x i1> {
^bb0(%arg0: tensor<? x f32>, %arg1: tensor<? x f32>):
// CHECK: "tfl.equal"(%arg0, %arg1)
%0 = "tfl.equal"(%arg0, %arg1) : (tensor<? x f32>, tensor<? x f32>) -> tensor<? x i1>
return %0#0 : tensor<? x i1>
}
// -----
// CHECK-LABEL: testPad
// Positive test: tfl.pad with a rank-3 input (2x1x3) and a 3x2 i32 paddings
// operand, i.e. one paddings row per input dimension.
func @testPad(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
// CHECK: "tfl.pad"(%arg0, %arg1)
%0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>
}
// -----
// test Pad with invalid paddings size
// Negative test: the input has rank 3 but the paddings operand is 2x2, so its
// leading size does not equal the input rank.
func @testPadWithInvalidPaddingsDim(tensor<2x1x3xf32>, tensor<2x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<2x2xi32>):
// expected-error @+1 {{'tfl.pad' op failed to verify that operand 0's rank equals operand 1's size}}
%0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3xf32>, tensor<2x2xi32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>
}
// -----
// test Pad with invalid paddings rank
// Negative test: the paddings operand is rank-3 (1x3x2), but operand 1 must be
// 2-D.
func @testPadWithInvalidPaddingsRank(tensor<2x1x3xf32>, tensor<1x3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<1x3x2xi32>):
// expected-error @+1 {{'tfl.pad' op failed to verify that operand 1 is 2-D}}
%0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3xf32>, tensor<1x3x2xi32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>
}
// -----
// CHECK-LABEL: testPadQuantizedU8
// Positive test: tfl.pad accepts a `!quant.uniform<u8:f32, 0.1>` element type
// on both the input and the result.
func @testPadQuantizedU8(%arg0: tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, %arg1: tensor<3x2xi32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>> {
// CHECK: "tfl.pad"(%arg0, %arg1)
%0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3x!quant.uniform<u8:f32, 0.1>>, tensor<3x2xi32>) -> tensor<? x !quant.uniform<u8:f32, 0.1>>
return %0#0 : tensor<? x !quant.uniform<u8:f32, 0.1>>
}
// CHECK-LABEL: testPadQuantizedI8
// Positive test: same as the u8 variant but with a signed
// `!quant.uniform<i8:f32, 0.1>` element type.
func @testPadQuantizedI8(%arg0: tensor<2x1x3x!quant.uniform<i8:f32, 0.1>>, %arg1: tensor<3x2xi32>) -> tensor<? x !quant.uniform<i8:f32, 0.1>> {
// CHECK: "tfl.pad"(%arg0, %arg1)
%0 = "tfl.pad"(%arg0, %arg1) : (tensor<2x1x3x!quant.uniform<i8:f32, 0.1>>, tensor<3x2xi32>) -> tensor<? x !quant.uniform<i8:f32, 0.1>>
return %0#0 : tensor<? x !quant.uniform<i8:f32, 0.1>>
}
// -----
// CHECK-LABEL: testPadV2
// Positive test: tfl.padv2 takes a third operand, here a scalar f32 constant
// (2.0) whose element type matches the input's.
func @testPadV2(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
%cst = constant dense<2.0> : tensor<f32>
// CHECK: "tfl.padv2"(%arg0, %arg1, %cst)
%0 = "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<3x2xi32>, tensor<f32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>
}
// -----
// test PadV2 with invalid paddings size
// Negative test: the input has rank 3 but the paddings operand is 2x2, so its
// leading size does not equal the input rank.
func @testPadV2WithInvalidPaddingsDim(tensor<2x1x3xf32>, tensor<2x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<2x2xi32>):
%cst = constant dense<2.0> : tensor<f32>
// expected-error @+1 {{'tfl.padv2' op failed to verify that operand 0's rank equals operand 1's size}}
%0 = "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<2x2xi32>, tensor<f32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>
}
// -----
// test PadV2 with invalid paddings rank
// Negative test: the paddings operand is rank-3 (1x3x2), but operand 1 must be
// 2-D.
func @testPadV2WithInvalidPaddingsRank(tensor<2x1x3xf32>, tensor<1x3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<1x3x2xi32>):
%cst = constant dense<2.0> : tensor<f32>
// expected-error @+1 {{'tfl.padv2' op failed to verify that operand 1 is 2-D}}
%0 = "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<1x3x2xi32>, tensor<f32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>
}
// -----
// test PadV2 with invalid constant rank
// Negative test: the constant operand is a 1-element rank-1 tensor
// (`tensor<1xf32>`), but operand 2 must be 0-D.
func @testPadV2WithInvalidConstantScalar(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
%cst = constant dense<[2.0]> : tensor<1xf32>
// expected-error @+1 {{'tfl.padv2' op failed to verify that operand 2 is 0-D}}
%0 = "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<3x2xi32>, tensor<1xf32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>
}
// -----
// test PadV2 with invalid constant data type
// Negative test: the constant operand is a scalar i32 while the input is f32,
// so the input/constant element-type predicate must fail.
func @testPadV2WithInvalidConstantScalar(tensor<2x1x3xf32>, tensor<3x2xi32>) -> tensor<? x f32> {
^bb0(%arg0: tensor<2x1x3xf32>, %arg1: tensor<3x2xi32>):
%cst = constant dense<2> : tensor<i32>
// expected-error @+1 {{'tfl.padv2' op failed to verify that input and constant value operands must have same element type}}
%0 = "tfl.padv2"(%arg0, %arg1, %cst) : (tensor<2x1x3xf32>, tensor<3x2xi32>, tensor<i32>) -> tensor<? x f32>
return %0#0 : tensor<? x f32>
}
// -----
// Positive test: tfl.pack of two `!quant.uniform<u8:f32, 0.1>` vectors along
// axis 0 into a 2x2 tensor of the same quantized type.
func @packQuantizedU8(%arg0: tensor<2x!quant.uniform<u8:f32, 0.1>>, %arg1: tensor<2x!quant.uniform<u8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<u8:f32, 0.1>> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2x!quant.uniform<u8:f32, 0.1>>, tensor<2x!quant.uniform<u8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<u8:f32, 0.1>>
return %0 : tensor<2x2x!quant.uniform<u8:f32, 0.1>>
}
// Positive test: same as the u8 variant but with the signed
// `!quant.uniform<i8:f32, 0.1>` element type.
func @packQuantizedI8(%arg0: tensor<2x!quant.uniform<i8:f32, 0.1>>, %arg1: tensor<2x!quant.uniform<i8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1>> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2x!quant.uniform<i8:f32, 0.1>>, tensor<2x!quant.uniform<i8:f32, 0.1>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1>>
return %0 : tensor<2x2x!quant.uniform<i8:f32, 0.1>>
}
// -----
// Positive test: two 2xi32 operands packed along axis 0 into 2x2xi32, with
// values_count equal to the number of operands.
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// Positive test: axis = 2 equals the rank of the 1x4 operands, so the new
// dimension is appended at the end (1x4 -> 1x4x2).
func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x4x2xi32> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x4x2xi32>
return %0 : tensor<1x4x2xi32>
}
// -----
// Positive test: a negative axis (-2) is accepted for rank-2 operands; the
// declared result shape is 2x1x4.
func @packNegInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<2x1x4xi32> {
// CHECK: "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32}
%0 = "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<2x1x4xi32>
return %0 : tensor<2x1x4xi32>
}
// -----
// Negative test: two operands are supplied but values_count is 1.
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
// expected-error @+1 {{input count should match 'values_count' attribute}}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 1 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// Negative test: the operands have different types (1xi32 vs 2xi32).
func @pack(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
// expected-error @+1 {{operands should be of the same type}}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 0 : i32, values_count = 2 : i32} : (tensor<1xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// Negative test: axis = 3 is out of bounds for rank-1 operands.
func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
// expected-error @+1 {{op attribute 'axis' is out of bounds, got 3}}
%0 = "tfl.pack"(%arg0, %arg1) {axis = 3 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// Positive test: a 2x3 tensor unpacked along axis 1 into three 2xi32 results
// (num = 3); only the first result is returned.
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
// CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32}
%0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
return %0#0 : tensor<2xi32>
}
// -----
// Positive test: tfl.unpack with a `!quant.uniform<u8:f32, 0.02>` element type.
// No output pattern is written, so this only requires that the op verifies.
func @unpackQuantized(%arg0: tensor<2x3x!quant.uniform<u8:f32, 0.02>>) -> tensor<2x!quant.uniform<u8:f32, 0.02>> {
%0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3x!quant.uniform<u8:f32, 0.02>>) -> (tensor<2x!quant.uniform<u8:f32, 0.02>>, tensor<2x!quant.uniform<u8:f32, 0.02>>, tensor<2x!quant.uniform<u8:f32, 0.02>>)
return %0#0 : tensor<2x!quant.uniform<u8:f32, 0.02>>
}
// -----
// Negative test: the op declares three results but num is 2.
func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> {
// expected-error @+1 {{output count should match 'num' attribute}}
%0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 2 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>)
return %0#0 : tensor<2xi32>
}
// -----
// CHECK-LABEL: testMean
// Positive test: tfl.mean with keep_dims = false; the attribute must be
// printed back as written.
func @testMean(%arg0: tensor<2x2xf32>, %arg1 : tensor<1xi32>) -> tensor<1x2xf32> {
// CHECK: "tfl.mean"(%arg0, %arg1) {keep_dims = false}
%0 = "tfl.mean"(%arg0, %arg1) {keep_dims = false}: (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
return %0 : tensor<1x2xf32>
}
// -----
// CHECK-LABEL: testMean_true
// Positive test: same op and types as the keep_dims = false case, with
// keep_dims = true.
func @testMean_true(%arg0: tensor<2x2xf32>, %arg1 : tensor<1xi32>) -> tensor<1x2xf32> {
// CHECK: "tfl.mean"(%arg0, %arg1) {keep_dims = true}
%0 = "tfl.mean"(%arg0, %arg1) {keep_dims = true}: (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
return %0 : tensor<1x2xf32>
}
// -----
// Negative test: the op is written without any attribute dictionary, so the
// required keep_dims attribute is missing.
func @testMean_missing_keep_dims(%arg0: tensor<2x2xf32>, %arg1 : tensor<1xi32>) -> tensor<1x2xf32> {
// expected-error @+1 {{'tfl.mean' op requires attribute 'keep_dims'}}
%0 = "tfl.mean"(%arg0, %arg1): (tensor<2x2xf32>, tensor<1xi32>) -> tensor<1x2xf32>
return %0 : tensor<1x2xf32>
}
// -----
// CHECK-LABEL: testBatchToSpaceND
// Positive test: tfl.batch_to_space_nd with a 4x2x2x3 f32 input, a 2-element
// i32 operand and a 2x2 i32 operand; the result shape is left dynamic.
func @testBatchToSpaceND(%arg0 : tensor<4x2x2x3xf32>, %arg1 : tensor<2xi32>, %arg2 : tensor<2x2xi32>) -> tensor<?xf32> {
// CHECK: "tfl.batch_to_space_nd"(%arg0, %arg1, %arg2)
%0 = "tfl.batch_to_space_nd"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: testSpaceToBatchND
// Positive test: tfl.space_to_batch_nd with a 1x4x4x3 f32 input, a 2-element
// i32 operand and a 2x2 i32 operand; the result shape is left dynamic.
func @testSpaceToBatchND(%arg0 : tensor<1x4x4x3xf32>, %arg1 : tensor<2xi32>, %arg2 : tensor<2x2xi32>) -> tensor<?xf32> {
// CHECK: "tfl.space_to_batch_nd"(%arg0, %arg1, %arg2)
%0 = "tfl.space_to_batch_nd"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// Positive test: tfl.concatenation of two 2xi32 tensors with axis = 0 and no
// fused activation; both attributes must be printed back.
func @testConcat(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> {
// CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"}
%0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// Positive test: tfl.concatenation accepts operands and a result whose element
// type is `!quant.uniform<i8:f32, 0.1:128>`.
func @testConcatQuantized(%arg0: tensor<2x!quant.uniform<i8:f32, 0.1:128>>, %arg1: tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>> {
// CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"}
%0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2x!quant.uniform<i8:f32, 0.1:128>>, tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>>
return %0 : tensor<2x2x!quant.uniform<i8:f32, 0.1:128>>
}
// -----
// Negative test: the operands are i32 but the result is declared as f32.
func @testConcatInvalidOutputElementalType(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<?xf32> {
// expected-error @+1 {{'tfl.concatenation' op failed to verify that values and output must have same element type}}
%0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// Negative test: operand #0 uses a 9-bit storage type (`i9`) in its quantized
// element type, which is not among the element types listed in the diagnostic.
func @testConcatInvalidStorageType(%arg0: tensor<2x!quant.uniform<i9:f32, 0.1:128>>, %arg1: tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>> {
// expected-error @+1 {{'tfl.concatenation' op operand #0 must be tensor of 32-bit float or 64-bit integer or 32-bit integer or 16-bit integer or 8-bit integer or QI8 type or QUI8 type or TFLite uint8 type values}}
%0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<2x!quant.uniform<i9:f32, 0.1:128>>, tensor<2x!quant.uniform<i8:f32, 0.1:128>>) -> tensor<2x2x!quant.uniform<i8:f32, 0.1:128>>
return %0 : tensor<2x2x!quant.uniform<i8:f32, 0.1:128>>
}
// -----
// CHECK-LABEL: testResizeBilinear
// Positive test: tfl.resize_bilinear on a 1x100x100x3 f32 input with a
// 4-element i32 second operand and align_corners = false.
func @testResizeBilinear(%arg0 : tensor<1x100x100x3xf32>, %arg1 : tensor<4xi32>) -> tensor<?xf32> {
// CHECK: "tfl.resize_bilinear"(%arg0, %arg1) {align_corners = false}
%0 = "tfl.resize_bilinear"(%arg0, %arg1) {align_corners = false} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// Negative test: the result element type is i32, which is not one of the
// result types listed in the diagnostic (f32, QI8, QUI8).
func @testResizeBilinearInvalidOutputType(%arg0 : tensor<1x100x100x3xf32>, %arg1 : tensor<4xi32>) -> tensor<?xi32> {
// expected-error @+1 {{'tfl.resize_bilinear' op result #0 must be tensor of 32-bit float or QI8 type or QUI8 type values}}
%0 = "tfl.resize_bilinear"(%arg0, %arg1) {align_corners = false} : (tensor<1x100x100x3xf32>, tensor<4xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// -----
// CHECK-LABEL: testStridedSlice
// Positive test: tfl.strided_slice with three 1-element i32 operands after the
// input and all five mask attributes set to 0. The pattern below requires the
// full attribute dictionary and type signature to be printed back.
func @testStridedSlice(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xf32> {
// CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
%0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32>
return %0 : tensor<1x2x2x5xf32>
}
// CHECK-LABEL: testStridedSliceWithQI8
// Positive test: tfl.strided_slice accepts a `!quant.uniform<i8:f32, 0.1>`
// element type on input and result. No pattern is written for the op itself.
func @testStridedSliceWithQI8(%arg0: tensor<12x2x2x5x!quant.uniform<i8:f32, 0.1>>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!quant.uniform<i8:f32, 0.1>> {
%0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!quant.uniform<i8:f32, 0.1>>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!quant.uniform<i8:f32, 0.1>>
return %0 : tensor<1x2x2x5x!quant.uniform<i8:f32, 0.1>>
}
// CHECK-LABEL: testStridedSliceWithQUI8
// Positive test: same as the i8 variant but with the unsigned
// `!quant.uniform<u8:f32, 0.1>` element type.
func @testStridedSliceWithQUI8(%arg0: tensor<12x2x2x5x!quant.uniform<u8:f32, 0.1>>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!quant.uniform<u8:f32, 0.1>> {
%0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!quant.uniform<u8:f32, 0.1>>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!quant.uniform<u8:f32, 0.1>>
return %0 : tensor<1x2x2x5x!quant.uniform<u8:f32, 0.1>>
}
// -----
// Negative test: the input is f32 but the result is declared as i32.
func @testStridedSliceWithInvalidOutputType(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xi32> {
// expected-error @+1 {{op failed to verify that input and output must have same element type}}
%0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xi32>
return %0 : tensor<1x2x2x5xi32>
}
// -----
// CHECK-LABEL: testOneHot
// Positive test: tfl.one_hot with a 3xi32 first operand, a scalar i32 second
// operand, two scalar f32 operands, and axis = -1; the result is unranked f32.
func @testOneHot(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<*xf32> {
// CHECK: "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<*xf32>
%0 = "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// Negative test: the result element type is i8, which is not among the result
// types listed in the diagnostic (f32, i32, i64, i1).
func @testOneHotWithInvalidOutputType(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<*xi8> {
// expected-error @+1 {{'tfl.one_hot' op result #0 must be tensor of 32-bit float or 32-bit integer or 64-bit integer or 1-bit integer values}}
%0 = "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) {axis = -1 : i32} : (tensor<3xi32>, tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<*xi8>
return %0 : tensor<*xi8>
}
// -----
// Positive test: tfl.arg_max on a 3xi32 input with a scalar i32 second operand
// and output_type = 2; the result is a scalar i32.
func @testArgMax(%arg0: tensor<3xi32>, %arg1: tensor<i32>) -> tensor<i32> {
// CHECK: "tfl.arg_max"(%arg0, %arg1) {output_type = 2 : i32} : (tensor<3xi32>, tensor<i32>) -> tensor<i32>
%0 = "tfl.arg_max"(%arg0, %arg1) {output_type = 2 : i32} : (tensor<3xi32>, tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
// -----
// Positive test: tfl.arg_min with the same operand types and output_type as
// the arg_max case.
func @testArgMin(%arg0: tensor<3xi32>, %arg1: tensor<i32>) -> tensor<i32> {
// CHECK: "tfl.arg_min"(%arg0, %arg1) {output_type = 2 : i32} : (tensor<3xi32>, tensor<i32>) -> tensor<i32>
%0 = "tfl.arg_min"(%arg0, %arg1) {output_type = 2 : i32} : (tensor<3xi32>, tensor<i32>) -> tensor<i32>
return %0 : tensor<i32>
}
// -----
// CHECK-LABEL: testSpaceToDepthF32
// Positive test: tfl.space_to_depth with block_size = 2 on a 1x2x2x1 f32 input
// producing 1x1x1x4. The first pattern captures the printed argument name from
// the remainder of the signature line; the second reuses it as the op operand.
func @testSpaceToDepthF32(%arg0: tensor<1x2x2x1xf32>) -> tensor<1x1x1x4xf32> {
// CHECK: %[[ARG:.*]]: tensor<1x2x2x1xf32>
// CHECK: "tfl.space_to_depth"(%[[ARG]]) {block_size = 2 : i32} : (tensor<1x2x2x1xf32>) -> tensor<1x1x1x4xf32>
%0 = "tfl.space_to_depth"(%arg0) {block_size = 2: i32} : (tensor<1x2x2x1xf32>) -> tensor<1x1x1x4xf32>
return %0 : tensor<1x1x1x4xf32>
}
// -----
// Negative test: the input is f32 but the result is declared as i32.
func @testSpaceToDepthInvalidOutputType(%arg0: tensor<1x2x2x1xf32>) -> tensor<1x1x1x4xi32> {
// expected-error @+1 {{'tfl.space_to_depth' op failed to verify that input and output must have same element type}}
%0 = "tfl.space_to_depth"(%arg0) {block_size = 2: i32} : (tensor<1x2x2x1xf32>) -> tensor<1x1x1x4xi32>
return %0 : tensor<1x1x1x4xi32>
}
// -----
// Positive test: tfl.range with three scalar i32 operands and a dynamically
// sized i32 result. No output pattern is written, so this only requires that
// the op verifies.
func @testRange(%arg0 : tensor<i32>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<?xi32> {
%0 = "tfl.range"(%arg0, %arg1, %arg2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// -----
// Negative test: operand 0 is a rank-1 tensor (`tensor<1xi32>`) but must be
// 0-D.
func @testRangeNonScalarTensorInput(%arg0 : tensor<1xi32>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<?xi32> {
// expected-error @+1 {{op failed to verify that operand 0 is 0-D}}
%0 = "tfl.range"(%arg0, %arg1, %arg2) : (tensor<1xi32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// -----
// Negative test: all operands are i32 but the result is declared as f32.
func @testRangeOutputTypeMismatch(%arg0 : tensor<i32>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<?xf32> {
// expected-error @+1 {{op failed to verify that operands and output must have same element type}}
%0 = "tfl.range"(%arg0, %arg1, %arg2) : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// Positive test: tfl.transpose of a 2x2 i32 tensor with a 2-element i32 second
// operand. No output pattern is written, so this only requires that the op
// verifies.
func @transpose(%arg0 : tensor<2x2xi32>, %arg1 : tensor<2xi32>) -> tensor<2x2xi32> {
%0 = "tfl.transpose"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// Negative test: operand #1 has element type f32 but must be a tensor of
// 32-bit integers.
func @transpose_perm_not_i32(%arg0 : tensor<2x2xi32>, %arg1 : tensor<2xf32>) -> tensor<2x2xi32> {
// expected-error @+1 {{op operand #1 must be tensor of 32-bit integer values}}
%0 = "tfl.transpose"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2xf32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// Negative test: the input is f32 but the result is declared as i32.
func @transpose_element_type(%arg0 : tensor<2x2xf32>, %arg1 : tensor<2xi32>) -> tensor<2x2xi32> {
// expected-error @+1 {{input and output must have same element type}}
%0 = "tfl.transpose"(%arg0, %arg1) : (tensor<2x2xf32>, tensor<2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// Negative test: operand 1 is a rank-2 tensor (2x2) but must be 1-D.
func @transpose_1d_perm(%arg0 : tensor<2x2xi32>, %arg1 : tensor<2x2xi32>) -> tensor<2x2xi32> {
// expected-error @+1 {{op failed to verify that operand 1 is 1-D}}
%0 = "tfl.transpose"(%arg0, %arg1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}
// -----
// Negative test: operand #1 of tfl.reduce_any is an i64 scalar but must be a
// tensor of 32-bit integers.
func @anyWithI64Axis(%arg0: tensor<2x2xi1>, %arg1: tensor<i64>) -> tensor<i1> {
// expected-error @+1 {{tfl.reduce_any' op operand #1 must be tensor of 32-bit integer values}}
%0 = "tfl.reduce_any"(%arg0, %arg1) {keep_dims = false} : (tensor<2x2xi1>, tensor<i64>) -> tensor<i1>
return %0 : tensor<i1>
}
// -----
// Negative test: tfl.round is given an i32 tensor but operand #0 must be a
// tensor of 32-bit floats.
func @testRoundInvalidInputType(%arg: tensor<?xi32>) -> tensor<?xi32> {
// expected-error @+1 {{'tfl.round' op operand #0 must be tensor of 32-bit float values}}
%0 = "tfl.round"(%arg) : (tensor<?xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// -----
// Positive test: tfl.split must accept a uniform-quantized (u8 storage)
// tensor as the value being split and as its result type.
func @testSplitWithQuantizedTypes(%split_dim : tensor<i32>, %input : tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
%split = "tfl.split"(%split_dim, %input) {num_splits = 1 : i32} : (tensor<i32>, tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
return %split : tensor<10x!quant.uniform<u8:f32, 1.0>>
}
// -----
// Positive test: tfl.split_v must accept a uniform-quantized (u8 storage)
// tensor as the value being split and as its result type.
func @testSplitVWithQuantizedTypes(%input : tensor<10x!quant.uniform<u8:f32, 1.0>>, %size_splits : tensor<1xi32>, %split_dim : tensor<i32>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
%split = "tfl.split_v"(%input, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<10x!quant.uniform<u8:f32, 1.0>>, tensor<1xi32>, tensor<i32>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
return %split : tensor<10x!quant.uniform<u8:f32, 1.0>>
}
// -----
// Negative test: tfl.where is given an i32 tensor.
// Operand #0 must be a tensor of 1-bit integers, so verification must fail.
func @whereWithI32Input(%arg0: tensor<3x5xi32>) -> tensor<?x2xi64> {
// expected-error @+1 {{'tfl.where' op operand #0 must be tensor of 1-bit integer values}}
%0 = "tfl.where"(%arg0) : (tensor<3x5xi32>) -> tensor<?x2xi64>
return %0 : tensor<?x2xi64>
}
// -----
// Positive test: tfl.minimum must accept uniform-quantized (u8 storage)
// operands and produce a result of the same quantized type.
func @testMinimumWithQuantizedTypes(%lhs : tensor<10x!quant.uniform<u8:f32, 1.0>>, %rhs : tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
%min = "tfl.minimum"(%lhs, %rhs) : (tensor<10x!quant.uniform<u8:f32, 1.0>>, tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
return %min : tensor<10x!quant.uniform<u8:f32, 1.0>>
}
// -----
// Positive test: tfl.maximum must accept uniform-quantized (u8 storage)
// operands and produce a result of the same quantized type.
func @testMaximumWithQuantizedTypes(%lhs : tensor<10x!quant.uniform<u8:f32, 1.0>>, %rhs : tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
%max = "tfl.maximum"(%lhs, %rhs) : (tensor<10x!quant.uniform<u8:f32, 1.0>>, tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
return %max : tensor<10x!quant.uniform<u8:f32, 1.0>>
}
// -----
// Positive test: tfl.relu must accept a uniform-quantized (u8 storage)
// operand and produce a result of the same quantized type.
func @testReluWithQuantizedTypes(%input : tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
%activated = "tfl.relu"(%input) : (tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
return %activated : tensor<10x!quant.uniform<u8:f32, 1.0>>
}
// -----
// Positive test: tfl.relu6 must accept a uniform-quantized (u8 storage)
// operand and produce a result of the same quantized type.
func @testRelu6WithQuantizedTypes(%input : tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>> {
%activated = "tfl.relu6"(%input) : (tensor<10x!quant.uniform<u8:f32, 1.0>>) -> tensor<10x!quant.uniform<u8:f32, 1.0>>
return %activated : tensor<10x!quant.uniform<u8:f32, 1.0>>
}
// -----
// Positive test: tfl.embedding_lookup with an i32 first operand, an f32
// second operand and an f32 result must be accepted without any diagnostic.
func @testEmbeddingLookup(%lookup : tensor<?xi32>, %value : tensor<?xf32>) -> tensor<?xf32> {
%embedded = "tfl.embedding_lookup"(%lookup, %value) : (tensor<?xi32>,tensor<?xf32>) -> tensor<?xf32>
return %embedded : tensor<?xf32>
}
// -----
// Negative test: tfl.embedding_lookup declares an i32 result.
// The result must be a tensor of 32-bit float, 8-bit integer or TFLite uint8
// values, so verification must fail.
func @testEmbeddingLookupInvalidResultType(%arg0 : tensor<?xi32>, %arg1 : tensor<?xf32>) -> tensor<?xi32> {
// expected-error @+1 {{'tfl.embedding_lookup' op result #0 must be tensor of 32-bit float or 8-bit integer or TFLite uint8 type values}}
%0 = "tfl.embedding_lookup"(%arg0, %arg1) : (tensor<?xi32>,tensor<?xf32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// -----
// Negative test: tfl.embedding_lookup pairs an i8 second operand with an f32
// result. Verification must fail because the value and the output are
// required to share one element type.
func @testEmbeddingLookupValueAndResultElementTypeTraitFailed(%arg0 : tensor<?xi32>, %arg1 : tensor<?xi8>) -> tensor<?xf32> {
// expected-error @+1 {{'tfl.embedding_lookup' op failed to verify that value and output must have same element type}}
%0 = "tfl.embedding_lookup"(%arg0, %arg1) : (tensor<?xi32>,tensor<?xi8>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// Positive test: tfl.local_response_normalization must accept a
// uniform-quantized (u8 storage) operand and result of the same type.
func @testQuantizedLocalResponseNormalization(%input : tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>) -> tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>> {
%normalized = "tfl.local_response_normalization"(%input) {alpha = 9.99999974E-5 : f32, beta = 5.000000e-01 : f32, bias = 2.000000e+00 : f32, radius = 5 : i32} : (tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>) -> tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>
return %normalized : tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>
}
// -----
// Round-trip test: a tfl.svdf op with five f32 operands, rank = 2 and no
// fused activation must verify and be printed back in generic form.
// CHECK-LABEL: testSvdf
func @testSvdf(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>, %arg3: tensor<?xf32>, %arg4: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", rank = 2 : i32} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%svdf = "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", rank = 2 : i32} : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %svdf : tensor<?xf32>
}
// -----
// Negative test: tfl.svdf is fed i32 operands.
// Operand #0 must be a tensor of 32-bit float or 8-bit integer values, so
// verification must fail.
func @testSvdfUnsupportedType(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>, %arg2: tensor<? x i32>, %arg3: tensor<? x i32>, %arg4: tensor<? x i32>) -> tensor<? x f32> {
// expected-error @+1 {{'tfl.svdf' op operand #0 must be tensor of 32-bit float or 8-bit integer values}}
%0 = "tfl.svdf"(%arg0, %arg1, %arg2, %arg3, %arg4) {fused_activation_function = "NONE", rank = 2 : i32} : (tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>) -> tensor<?xf32>
return %0 : tensor<?xf32>
}
// -----
// Round-trip test: a valid tfl.depth_to_space (block_size = 2) on an f32
// tensor must verify and be printed back in generic form.
// The label spells out the full function name so that it cannot also match
// other functions that merely share the "testDepthToSpace" prefix.
// CHECK-LABEL: testDepthToSpaceF32
func @testDepthToSpaceF32(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32> {
// Capture the printed argument name so the op pattern below does not hard-code it.
// CHECK: %[[ARG:.*]]: tensor<1x1x1x4xf32>
// CHECK: "tfl.depth_to_space"(%[[ARG]]) {block_size = 2 : i32} : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32>
%0 = "tfl.depth_to_space"(%arg0) {block_size = 2: i32} : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xf32>
return %0 : tensor<1x2x2x1xf32>
}
// -----
// Negative test: tfl.depth_to_space takes an f32 input but declares an i32
// result. Verification must fail because input and output element types differ.
func @testDepthToSpaceInvalidOutputType(%arg0: tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xi32> {
// expected-error @+1 {{'tfl.depth_to_space' op failed to verify that input and output must have same element type}}
%0 = "tfl.depth_to_space"(%arg0) {block_size = 2: i32} : (tensor<1x1x1x4xf32>) -> tensor<1x2x2x1xi32>
return %0 : tensor<1x2x2x1xi32>
}
// -----
// Positive test: tfl.slice on a rank-3 f32 tensor with two 3-element i32
// operands (one entry per input dimension) must be accepted without any
// diagnostic.
func @testSlice(%input: tensor<2x3x5xf32>, %begin: tensor<3xi32>, %size: tensor<3xi32>) -> tensor<?x3x5xf32> {
%sliced = "tfl.slice"(%input, %begin, %size) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
return %sliced : tensor<?x3x5xf32>
}
// -----
// Negative test: the begin operand of tfl.slice has 2 elements while the input
// is rank 3. Verification must fail because the begin tensor needs one element
// per input dimension.
func @testSliceBadBeginDimension(%arg0: tensor<2x3x5xf32>, %arg1: tensor<2xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> {
// expected-error @+1 {{begin tensor elements size is not equal to input tensor rank}}
%0 = "tfl.slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<2xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
return %0 : tensor<?x3x5xf32>
}
// -----
// Negative test: the size operand of tfl.slice has 2 elements while the input
// is rank 3. Verification must fail because the size tensor needs one element
// per input dimension.
func @testSliceBadSizeDimension(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<2xi32>) -> tensor<?x3x5xf32> {
// expected-error @+1 {{size tensor elements size is not equal to input tensor rank}}
%0 = "tfl.slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<2xi32>) -> tensor<?x3x5xf32>
return %0 : tensor<?x3x5xf32>
}
// -----
// Negative test: tfl.slice receives a constant begin operand whose element 1
// is negative. Verification must fail because begin indices cannot be negative.
func @testSliceBadBegin(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>) -> tensor<?x3x5xf32> {
// Constant begin indices; the entry at index 1 is -1.
%cst = constant dense<[2, -1, 5]> : tensor<3xi32>
// expected-error @+1 {{begin[1] cannot be negative}}
%0 = "tfl.slice"(%arg0, %cst, %arg1) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
return %0 : tensor<?x3x5xf32>
}
// -----
// Negative test: tfl.slice receives a constant size operand whose element 0
// is -2. Verification must fail because -1 is the only negative size allowed.
func @testSliceNegativeSize(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>) -> tensor<?x3x5xf32> {
// Constant sizes; index 0 holds -2 (rejected) and index 1 holds -1 (permitted).
%cst = constant dense<[-2, -1, 5]> : tensor<3xi32>
// expected-error @+1 {{size[0] cannot be negative other than -1}}
%0 = "tfl.slice"(%arg0, %arg1, %cst) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
return %0 : tensor<?x3x5xf32>
}
// -----
// Negative test: constant begin and size operands of tfl.slice overrun the
// last input dimension (begin[2] = 1, size[2] = 5, dimension length 5).
// Verification must fail because begin + size exceeds the dimension length.
func @testSliceSizeOutOfRange(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>) -> tensor<?x3x5xf32> {
// %cst is passed as the size operand and %cst_1 as the begin operand below.
%cst = constant dense<[2, 1, 5]> : tensor<3xi32>
%cst_1 = constant dense<[0, 1, 1]> : tensor<3xi32>
// expected-error @+1 {{begin[2] + size[2] cannot exceed dimension length: 5}}
%0 = "tfl.slice"(%arg0, %cst_1, %cst) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
return %0 : tensor<?x3x5xf32>
}
// -----
// Negative test: the constant begin operand of tfl.slice has begin[0] = 2
// while dimension 0 of the input has length 2. Verification must reject the
// out-of-range begin index.
func @testSliceBeginOutOfRange(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>) -> tensor<?x3x5xf32> {
// %cst is passed as the size operand and %cst_1 as the begin operand below.
%cst = constant dense<[1, 1, 1]> : tensor<3xi32>
%cst_1 = constant dense<[2, 1, 3]> : tensor<3xi32>
// expected-error @+1 {{begin[0] cannot exceed dimension length: 2}}
%0 = "tfl.slice"(%arg0, %cst_1, %cst) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
return %0 : tensor<?x3x5xf32>
}
// -----
// Negative test: tfl.split is built with num_splits = 0.
// Verification must fail because the attribute has to be a positive 32-bit
// integer.
func @testSplitOpWithBadNumSplits(%arg0 : tensor<16xf32>) -> () {
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split' op attribute 'num_splits' failed to satisfy constraint: positive 32-bit integer attribute}}
"tfl.split"(%split_dim, %arg0) {num_splits = 0 : i32} : (tensor<i32>, tensor<16xf32>) -> ()
return
}
// -----
// Negative test: tfl.split declares num_splits = 4 but produces only two
// results. Verification must fail because the output count has to match the
// attribute.
func @testSplitOpWithMismatchedNumResults(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split' op output count should match 'num_splits' attribute}}
%0, %1 = "tfl.split"(%split_dim, %arg0) {num_splits = 4 : i32} : (tensor<i32>, tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>)
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
// -----
// Negative test: the split_dim operand of tfl.split is a 2x2 tensor.
// Verification must fail because operand #0 has to be a 0-D tensor of 32-bit
// integers.
func @testSplitOpWithBadSplitDimTensorType(%arg0: tensor<16x4x4xf32>) -> tensor<16x4x4xf32> {
%split_dim = constant dense<0> : tensor<2x2xi32>
// expected-error @+1 {{'tfl.split' op operand #0 must be 0D tensor of 32-bit integer values}}
%0 = "tfl.split"(%split_dim, %arg0) {num_splits = 1 : i32} : (tensor<2x2xi32>, tensor<16x4x4xf32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}
// -----
// Negative test: the split_dim operand of tfl.split is an unranked i32 tensor.
// Verification must fail because operand #0 has to be a 0-D tensor of 32-bit
// integers.
func @testSplitOpWithBadSplitDimUnrankedTensorType(%arg0: tensor<16x4x4xf32>, %split_dim : tensor<*xi32>) -> tensor<16x4x4xf32> {
// expected-error @+1 {{'tfl.split' op operand #0 must be 0D tensor of 32-bit integer values}}
%0 = "tfl.split"(%split_dim, %arg0) {num_splits = 1 : i32} : (tensor<*xi32>, tensor<16x4x4xf32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}
// -----
// Negative test: tfl.split on a rank-1 input with a constant split_dim of 1.
// Verification must fail because split_dim has to lie in [-rank, rank).
func @testSplitOpWithOutOfRangeSplitDim(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
%split_dim = constant dense<1> : tensor<i32>
// expected-error @+1 {{'tfl.split' op 'split_dim' should be in [-rank, rank)}}
%0, %1 = "tfl.split"(%split_dim, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>)
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
// -----
// Negative test: same out-of-range split_dim (1 on a rank-1 input) as the
// standard-constant case, but the value is produced by tfl.pseudo_const.
// The range check must still fire.
func @testSplitOpWithOutOfRangeSplitDimTFLConst(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
%split_dim = "tfl.pseudo_const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// expected-error @+1 {{'tfl.split' op 'split_dim' should be in [-rank, rank)}}
%0, %1 = "tfl.split"(%split_dim, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>)
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
// -----
// Negative test: tfl.split on a rank-1 input with a constant split_dim of -2.
// Verification must fail because split_dim has to lie in [-rank, rank).
func @testSplitOpWithOutOfRangeSplitDimNegative(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
%split_dim = constant dense<-2> : tensor<i32>
// expected-error @+1 {{'tfl.split' op 'split_dim' should be in [-rank, rank)}}
%0, %1 = "tfl.split"(%split_dim, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>)
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
// -----
// Negative test: tfl.split asks for 3 splits of a 16-element axis.
// Verification must fail because num_splits has to divide the split axis
// evenly.
func @testSplitOpWithUnevenDivision(%arg0 : tensor<16xf32>) -> (tensor<6xf32>, tensor<5xf32>, tensor<5xf32>) {
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split' op 'num_splits' should evenly divide 'split_dim' axis}}
%0, %1, %2 = "tfl.split"(%split_dim, %arg0) {num_splits = 3 : i32} : (tensor<i32>, tensor<16xf32>) -> (tensor<6xf32>, tensor<5xf32>, tensor<5xf32>)
return %0, %1, %2 : tensor<6xf32>, tensor<5xf32>, tensor<5xf32>
}
// -----
// Negative test: splitting a 16-element tensor in two should yield
// tensor<8xf32> results, but both results are declared as tensor<4xf32>.
// The verifier must report the first wrong output (#0).
func @testSplitOpWithMismatchTensorTypeSplitDimOut0(%arg0 : tensor<16xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split' op output #0 should be 'tensor<8xf32>'}}
%0, %1 = "tfl.split"(%split_dim, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<16xf32>) -> (tensor<4xf32>, tensor<4xf32>)
return %0, %1 : tensor<4xf32>, tensor<4xf32>
}
// -----
// Negative test: splitting a 16-element tensor in two with a correct first
// result (tensor<8xf32>) but a second result declared as tensor<4xf32>.
// The verifier must report output #1.
func @testSplitOpWithMismatchTensorTypeSplitDimOut1(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<4xf32>) {
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split' op output #1 should be 'tensor<8xf32>'}}
%0, %1 = "tfl.split"(%split_dim, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<16xf32>) -> (tensor<8xf32>, tensor<4xf32>)
return %0, %1 : tensor<8xf32>, tensor<4xf32>
}
// -----
// Negative test: a 16x4 tensor is split along axis 0, yet the results are
// declared as 8x2, changing the non-split dimension.
// The verifier must demand tensor<8x4xf32> for output #0.
func @testSplitOpWithMismatchTensorTypeNonSplitDim(%arg0 : tensor<16x4xf32>) -> (tensor<8x2xf32>, tensor<8x2xf32>) {
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split' op output #0 should be 'tensor<8x4xf32>'}}
%0, %1 = "tfl.split"(%split_dim, %arg0) {num_splits = 2 : i32} : (tensor<i32>, tensor<16x4xf32>) -> (tensor<8x2xf32>, tensor<8x2xf32>)
return %0, %1 : tensor<8x2xf32>, tensor<8x2xf32>
}
// -----
// Positive test: a 16x4 tensor is split in two along axis 0 (two 8x4 results)
// and along axis 1 (two 16x2 results); both ops must pass verification.
func @testSplitOpWithValidTensorType(%input : tensor<16x4xf32>) -> (tensor<8x4xf32>, tensor<8x4xf32>, tensor<16x2xf32>, tensor<16x2xf32>) {
%axis0 = constant dense<0> : tensor<i32>
%axis1 = constant dense<1> : tensor<i32>
%rows_a, %rows_b = "tfl.split"(%axis0, %input) {num_splits = 2 : i32} : (tensor<i32>, tensor<16x4xf32>) -> (tensor<8x4xf32>, tensor<8x4xf32>)
%cols_a, %cols_b = "tfl.split"(%axis1, %input) {num_splits = 2 : i32} : (tensor<i32>, tensor<16x4xf32>) -> (tensor<16x2xf32>, tensor<16x2xf32>)
return %rows_a, %rows_b, %cols_a, %cols_b : tensor<8x4xf32>, tensor<8x4xf32>, tensor<16x2xf32>, tensor<16x2xf32>
}
// -----
// Positive test: tfl.split along the static axis 0 of a 16x? tensor must be
// accepted; the dynamic second dimension is carried through to both results.
func @testSplitOpWithValidTensorTypeDynamic(%input : tensor<16x?xf32>) -> (tensor<8x?xf32>, tensor<8x?xf32>) {
%axis = constant dense<0> : tensor<i32>
%first, %second = "tfl.split"(%axis, %input) {num_splits = 2 : i32} : (tensor<i32>, tensor<16x?xf32>) -> (tensor<8x?xf32>, tensor<8x?xf32>)
return %first, %second : tensor<8x?xf32>, tensor<8x?xf32>
}
// -----
// Negative test: tfl.split_v is built with num_splits = 0.
// Verification must fail because the attribute has to be a positive 32-bit
// integer.
func @testSplitVOpWithBadNumSplits(%arg0 : tensor<16xf32>) -> () {
%size_splits = constant dense<[]> : tensor<0xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op attribute 'num_splits' failed to satisfy constraint: positive 32-bit integer attribute}}
"tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 0 : i32} : (tensor<16xf32>, tensor<0xi32>, tensor<i32>) -> ()
return
}
// -----
// Negative test: tfl.split_v declares num_splits = 4 (with four size_splits
// entries) but produces only two results. Verification must fail because the
// output count has to match the attribute.
func @testSplitVOpWithMismatchedNumResults(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
%size_splits = constant dense<[4, 4, 4, 4]> : tensor<4xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op output count should match 'num_splits' attribute}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 4 : i32} : (tensor<16xf32>, tensor<4xi32>, tensor<i32>) -> (tensor<8xf32>, tensor<8xf32>)
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
// -----
// Negative test: the size_splits operand of tfl.split_v is a 2x2 tensor.
// Verification must fail because operand #1 has to be a 1-D tensor of 32-bit
// integers.
func @testSplitVOpWithBadSizeSplitsTensorType(%arg0: tensor<16x4x4xf32>) -> tensor<16x4x4xf32> {
%size_splits = constant dense<[[8, 8], [2, 2]]> : tensor<2x2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op operand #1 must be 1D tensor of 32-bit integer values}}
%0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<2x2xi32>, tensor<i32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}
// -----
// Negative test: the size_splits operand of tfl.split_v is an unranked i32
// tensor. Verification must fail because operand #1 has to be a 1-D tensor of
// 32-bit integers.
func @testSplitVOpWithBadSizeSplitsUnrankedTensorType(%arg0: tensor<16x4x4xf32>, %size_splits: tensor<*xi32>) -> tensor<16x4x4xf32> {
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op operand #1 must be 1D tensor of 32-bit integer values}}
%0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<*xi32>, tensor<i32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}
// -----
// Negative test: the constant size_splits operand of tfl.split_v contains -2.
// Verification must fail because every element has to be >= -1.
func @testSplitVOpWithBadSizeSplitsConstant(%arg0: tensor<16x4x4xf32>) -> tensor<16x4x4xf32> {
%size_splits = constant dense<[-2]> : tensor<1xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op elements of 'size_splits' should be greater than or equal to -1}}
%0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<1xi32>, tensor<i32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}
// -----
// Negative test: the constant size_splits operand of tfl.split_v contains two
// -1 entries. Verification must fail because at most one -1 is allowed.
func @testSplitVOpWithBadSizeSplitsConstantMultipleNegativeOne(%arg0: tensor<16x4x4xf32>) -> (tensor<1x4x4xf32>, tensor<1x4x4xf32>, tensor<14x4x4xf32>) {
%size_splits = constant dense<[-1, -1, 14]> : tensor<3xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op 'size_splits' can only have one -1}}
%0, %1, %2 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 3 : i32} : (tensor<16x4x4xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<1x4x4xf32>, tensor<1x4x4xf32>, tensor<14x4x4xf32>)
return %0, %1, %2 : tensor<1x4x4xf32>, tensor<1x4x4xf32>, tensor<14x4x4xf32>
}
// -----
// Negative test: size_splits is [-1, 17] while the split axis has length 16.
// Verification must fail because the non-negative entries alone already
// exceed the axis size.
func @testSplitVOpWithBadSizeSplitsConstantSum(%arg0: tensor<16x4x4xf32>) -> (tensor<0x4x4xf32>, tensor<16x4x4xf32>) {
%size_splits = constant dense<[-1, 17]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op sum of non-negative elements of 'size_splits' is greater than the dimension size of 'split_dim' axis}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16x4x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<0x4x4xf32>, tensor<16x4x4xf32>)
return %0, %1 : tensor<0x4x4xf32>, tensor<16x4x4xf32>
}
// -----
// Negative test: size_splits holds two entries while num_splits = 1.
// Verification must fail because size_splits has to be tensor<1xi32> here.
func @testSplitVOpWithBadSizeSplitsSize(%arg0: tensor<16x4x4xf32>) -> tensor<15x4x4xf32> {
%size_splits = constant dense<[15, 1]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op 'size_splits' should be 'tensor<1xi32>'}}
%0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<2xi32>, tensor<i32>) -> tensor<15x4x4xf32>
return %0 : tensor<15x4x4xf32>
}
// -----
// Negative test: the split_dim operand of tfl.split_v is a 2x2 tensor.
// Verification must fail because operand #2 has to be a 0-D tensor of 32-bit
// integers.
func @testSplitVOpWithBadSplitDimTensorType(%arg0: tensor<16x4x4xf32>) -> tensor<16x4x4xf32> {
%size_splits = constant dense<[16]> : tensor<1xi32>
%split_dim = constant dense<0> : tensor<2x2xi32>
// expected-error @+1 {{'tfl.split_v' op operand #2 must be 0D tensor of 32-bit integer values}}
%0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<1xi32>, tensor<2x2xi32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}
// -----
// Negative test: the split_dim operand of tfl.split_v is an unranked i32
// tensor. Verification must fail because operand #2 has to be a 0-D tensor of
// 32-bit integers.
func @testSplitVOpWithBadSplitDimUnrankedTensorType(%arg0: tensor<16x4x4xf32>, %split_dim : tensor<*xi32>) -> tensor<16x4x4xf32> {
%size_splits = constant dense<[16]> : tensor<1xi32>
// expected-error @+1 {{'tfl.split_v' op operand #2 must be 0D tensor of 32-bit integer values}}
%0 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 1 : i32} : (tensor<16x4x4xf32>, tensor<1xi32>, tensor<*xi32>) -> tensor<16x4x4xf32>
return %0 : tensor<16x4x4xf32>
}
// -----
// Negative test: tfl.split_v on a rank-1 input with a constant split_dim of 1.
// Verification must fail because split_dim has to lie in [-rank, rank).
func @testSplitVOpWithOutOfRangeSplitDim(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = constant dense<1> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op 'split_dim' should be in [-rank, rank)}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8xf32>, tensor<8xf32>)
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
// -----
// Negative test: same out-of-range split_dim (1 on a rank-1 input) as the
// standard-constant case, but the value is produced by tfl.pseudo_const.
// The range check must still fire.
func @testSplitVOpWithOutOfRangeSplitDimTFLConst(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = "tfl.pseudo_const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// expected-error @+1 {{'tfl.split_v' op 'split_dim' should be in [-rank, rank)}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8xf32>, tensor<8xf32>)
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
// -----
// Negative test: tfl.split_v on a rank-1 input with a constant split_dim of
// -2. Verification must fail because split_dim has to lie in [-rank, rank).
func @testSplitVOpWithOutOfRangeSplitDimNegative(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<8xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = constant dense<-2> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op 'split_dim' should be in [-rank, rank)}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8xf32>, tensor<8xf32>)
return %0, %1 : tensor<8xf32>, tensor<8xf32>
}
// -----
// Negative test: size_splits is [8, 4] (sum 12) while the split axis has
// length 16. Verification must fail because the sum has to match the axis size.
func @testSplitVOpWithMismatchSizeSplitsSum(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<4xf32>) {
%size_splits = constant dense<[8, 4]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op sum of 'size_splits' should match the dimension size of 'split_dim' axis}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8xf32>, tensor<4xf32>)
return %0, %1 : tensor<8xf32>, tensor<4xf32>
}
// -----
// Negative test: size_splits is [8, 8] on a 16-element tensor, yet both
// results are declared as tensor<4xf32>.
// The verifier must report the first wrong output (#0).
func @testSplitVOpWithMismatchTensorTypeSplitDimOut0(%arg0 : tensor<16xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op output #0 should be 'tensor<8xf32>'}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<4xf32>, tensor<4xf32>)
return %0, %1 : tensor<4xf32>, tensor<4xf32>
}
// -----
// Negative test: size_splits is [8, 8] on a 16-element tensor with a correct
// first result but a second result declared as tensor<4xf32>.
// The verifier must report output #1.
func @testSplitVOpWithMismatchTensorTypeSplitDimOut1(%arg0 : tensor<16xf32>) -> (tensor<8xf32>, tensor<4xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op output #1 should be 'tensor<8xf32>'}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8xf32>, tensor<4xf32>)
return %0, %1 : tensor<8xf32>, tensor<4xf32>
}
// -----
// Negative test: a 16x4 tensor is split along axis 0 with size_splits [8, 8],
// yet the results are declared as 8x2, changing the non-split dimension.
// The verifier must demand tensor<8x4xf32> for output #0.
func @testSplitVOpWithMismatchTensorTypeNonSplitDim(%arg0 : tensor<16x4xf32>) -> (tensor<8x2xf32>, tensor<8x2xf32>) {
%size_splits = constant dense<[8, 8]> : tensor<2xi32>
%split_dim = constant dense<0> : tensor<i32>
// expected-error @+1 {{'tfl.split_v' op output #0 should be 'tensor<8x4xf32>'}}
%0, %1 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 2 : i32} : (tensor<16x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8x2xf32>, tensor<8x2xf32>)
return %0, %1 : tensor<8x2xf32>, tensor<8x2xf32>
}
// -----
// Positive test: a 16x4 tensor is split with tfl.split_v along axis 0 using
// sizes [8, 8] and along axis 1 using sizes [2, 2]; both ops must pass
// verification.
func @testSplitVOpWithValidTensorType(%input : tensor<16x4xf32>) -> (tensor<8x4xf32>, tensor<8x4xf32>, tensor<16x2xf32>, tensor<16x2xf32>) {
%row_sizes = constant dense<[8, 8]> : tensor<2xi32>
%axis0 = constant dense<0> : tensor<i32>
%col_sizes = constant dense<[2, 2]> : tensor<2xi32>
%axis1 = constant dense<1> : tensor<i32>
%rows_a, %rows_b = "tfl.split_v"(%input, %row_sizes, %axis0) {num_splits = 2 : i32} : (tensor<16x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8x4xf32>, tensor<8x4xf32>)
%cols_a, %cols_b = "tfl.split_v"(%input, %col_sizes, %axis1) {num_splits = 2 : i32} : (tensor<16x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<16x2xf32>, tensor<16x2xf32>)
return %rows_a, %rows_b, %cols_a, %cols_b : tensor<8x4xf32>, tensor<8x4xf32>, tensor<16x2xf32>, tensor<16x2xf32>
}
// -----
// Positive test: tfl.split_v along the static axis 0 of a 16x? tensor with
// sizes [8, 8] must be accepted; the dynamic second dimension is carried
// through to both results.
func @testSplitVOpWithValidTensorTypeDynamic(%input : tensor<16x?xf32>) -> (tensor<8x?xf32>, tensor<8x?xf32>) {
%sizes = constant dense<[8, 8]> : tensor<2xi32>
%axis = constant dense<0> : tensor<i32>
%first, %second = "tfl.split_v"(%input, %sizes, %axis) {num_splits = 2 : i32} : (tensor<16x?xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<8x?xf32>, tensor<8x?xf32>)
return %first, %second : tensor<8x?xf32>, tensor<8x?xf32>
}
// -----
// Positive test: tfl.split_v accepts unequal piece sizes as long as they add
// up to the axis length: [7, 3, 6] along axis 0 (length 16) and [1, 3] along
// axis 1 (length 4).
func @testSplitVOpWithValidSizeSplitsUneven(%input : tensor<16x4xf32>) -> (tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>, tensor<16x1xf32>, tensor<16x3xf32>) {
%row_sizes = constant dense<[7, 3, 6]> : tensor<3xi32>
%axis0 = constant dense<0> : tensor<i32>
%col_sizes = constant dense<[1, 3]> : tensor<2xi32>
%axis1 = constant dense<1> : tensor<i32>
%rows_a, %rows_b, %rows_c = "tfl.split_v"(%input, %row_sizes, %axis0) {num_splits = 3 : i32} : (tensor<16x4xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>)
%cols_a, %cols_b = "tfl.split_v"(%input, %col_sizes, %axis1) {num_splits = 2 : i32} : (tensor<16x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<16x1xf32>, tensor<16x3xf32>)
return %rows_a, %rows_b, %rows_c, %cols_a, %cols_b : tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>, tensor<16x1xf32>, tensor<16x3xf32>
}
// -----
// Positive test: tfl.split_v accepts a single -1 entry in size_splits.
// Along axis 0 (length 16) the sizes [7, -1, 6] pair with results of 7, 3 and
// 6 rows; along axis 1 (length 4) the sizes [-1, 4] pair with results of 0
// and 4 columns.
func @testSplitVOpWithValidSizeSplitsNegative(%input : tensor<16x4xf32>) -> (tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>, tensor<16x0xf32>, tensor<16x4xf32>) {
%row_sizes = constant dense<[7, -1, 6]> : tensor<3xi32>
%axis0 = constant dense<0> : tensor<i32>
%col_sizes = constant dense<[-1, 4]> : tensor<2xi32>
%axis1 = constant dense<1> : tensor<i32>
%rows_a, %rows_b, %rows_c = "tfl.split_v"(%input, %row_sizes, %axis0) {num_splits = 3 : i32} : (tensor<16x4xf32>, tensor<3xi32>, tensor<i32>) -> (tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>)
%cols_a, %cols_b = "tfl.split_v"(%input, %col_sizes, %axis1) {num_splits = 2 : i32} : (tensor<16x4xf32>, tensor<2xi32>, tensor<i32>) -> (tensor<16x0xf32>, tensor<16x4xf32>)
return %rows_a, %rows_b, %rows_c, %cols_a, %cols_b : tensor<7x4xf32>, tensor<3x4xf32>, tensor<6x4xf32>, tensor<16x0xf32>, tensor<16x4xf32>
}
// -----
// Positive test: tfl.non_max_suppression_v4 with a 3x4 first operand (second
// dimension equal to 4) must be accepted without any diagnostic.
func @testNonMaxSuppressionV4WithCorrectBoxShape(%boxes: tensor<3x4xf32>, %scores: tensor<3xf32>, %max_output_size: tensor<i32>, %iou_threshold: tensor<f32>, %score_threshold: tensor<f32>) -> (tensor<2xi32>, tensor<i32>) {
%selected_indices, %valid_outputs = "tfl.non_max_suppression_v4"(%boxes, %scores, %max_output_size, %iou_threshold, %score_threshold) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
return %selected_indices, %valid_outputs : tensor<2xi32>, tensor<i32>
}
// -----
// Negative test: the first operand of tfl.non_max_suppression_v4 is 3x2.
// Verification must fail because boxes are required to have dim[1] == 4.
func @testNonMaxSuppressionV4WithWrongBoxShape(%arg0: tensor<3x2xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>) -> (tensor<2xi32>, tensor<i32>) {
// expected-error @+1 {{'tfl.non_max_suppression_v4' op failed to verify that boxes should have dim[1] == 4}}
%0, %1 = "tfl.non_max_suppression_v4"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<3x2xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<i32>)
return %0, %1 : tensor<2xi32>, tensor<i32>
}
// -----
// Positive test: tfl.non_max_suppression_v5 with a 3x4 first operand (second
// dimension equal to 4) must be accepted without any diagnostic.
func @testNonMaxSuppressionV5WithCorrectBoxShape(%boxes: tensor<3x4xf32>, %scores: tensor<3xf32>, %max_output_size: tensor<i32>, %iou_threshold: tensor<f32>, %score_threshold: tensor<f32>, %soft_nms_sigma: tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>) {
%selected_indices, %selected_scores, %valid_outputs = "tfl.non_max_suppression_v5"(%boxes, %scores, %max_output_size, %iou_threshold, %score_threshold, %soft_nms_sigma) : (tensor<3x4xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
return %selected_indices, %selected_scores, %valid_outputs : tensor<2xi32>, tensor<2xf32>, tensor<i32>
}
// -----
// Negative test: the first operand of tfl.non_max_suppression_v5 is 3x2.
// Verification must fail because boxes are required to have dim[1] == 4.
func @testNonMaxSuppressionV5WithWrongBoxShape(%arg0: tensor<3x2xf32>, %arg1: tensor<3xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>) {
// expected-error @+1 {{'tfl.non_max_suppression_v5' op failed to verify that boxes should have dim[1] == 4}}
%0, %1, %2 = "tfl.non_max_suppression_v5"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tensor<3x2xf32>, tensor<3xf32>, tensor<i32>, tensor<f32>, tensor<f32>, tensor<f32>) -> (tensor<2xi32>, tensor<2xf32>, tensor<i32>)
return %0, %1, %2 : tensor<2xi32>, tensor<2xf32>, tensor<i32>
}