// RUN: tf-opt %s -tfl-optimize-functional-ops -split-input-file | FileCheck %s
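
// Test: a tf.If whose condition is a compile-time constant should be replaced
// by an inlined copy of the taken branch, so @main reduces to a single tf.Add.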
// CHECK-LABEL: main
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
// CHECK: %[[INPUT0:.*]] = "tf.Placeholder.input"
  %0 = "tf.Placeholder.input"(%arg0) : (tensor<f32>) -> tensor<f32>
// CHECK: %[[INPUT1:.*]] = "tf.Placeholder.input"
  %1 = "tf.Placeholder.input"(%arg1) : (tensor<f32>) -> tensor<f32>
  %2 = constant dense<true> : tensor<i1>
// CHECK: "tf.Add"(%[[INPUT0]], %[[INPUT1]])
  %3 = "tf.If"(%2, %0, %1) {else_branch = @sub, then_branch = @add, is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
  return %3 : tensor<f32>
}

func private @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "tf.Add"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

func private @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
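
// For reference, a sketch of @main after the pass, inferred from the FileCheck
// directives above (whether the now-dead @add/@sub functions are also erased
// is not asserted here):
//
//   func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
//     %0 = "tf.Placeholder.input"(%arg0) : (tensor<f32>) -> tensor<f32>
//     %1 = "tf.Placeholder.input"(%arg1) : (tensor<f32>) -> tensor<f32>
//     %2 = "tf.Add"(%0, %1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
//     return %2 : tensor<f32>
//   }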
// -----

// Verify that nested tf.If ops with constant conditions are resolved and
// inlined recursively.
// CHECK-LABEL: main
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
// CHECK: %[[INPUT0:.*]] = "tf.Placeholder.input"
  %0 = "tf.Placeholder.input"(%arg0) : (tensor<f32>) -> tensor<f32>
// CHECK: %[[INPUT1:.*]] = "tf.Placeholder.input"
  %1 = "tf.Placeholder.input"(%arg1) : (tensor<f32>) -> tensor<f32>
  %2 = constant dense<true> : tensor<i1>
// CHECK: "tf.Multiply"(%[[INPUT1]], %[[INPUT0]])
  %3 = "tf.If"(%2, %0, %1) {else_branch = @sub, then_branch = @addormul, is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
  return %3 : tensor<f32>
}

func private @addormul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %0 = constant dense<false> : tensor<i1>
  %1 = "tf.If"(%0, %arg1, %arg0) {else_branch = @mul, then_branch = @add, is_stateless = true} : (tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %1 : tensor<*xf32>
}

func private @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

func private @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "tf.Add"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

func private @mul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "tf.Multiply"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
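
// Resolution walk-through implied by the FileCheck directive above: the outer
// condition is true, selecting @addormul; its inner condition is false,
// selecting @mul, and the inner call swaps the operands, so @main reduces to
// "tf.Multiply"(%1, %0), i.e. Multiply(INPUT1, INPUT0).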
// -----

// Verify that an unused tf.If whose branch functions have no side effects is
// removed.
// CHECK-LABEL: main
func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
    attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
  %cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
  %cst_0 = constant dense<1.000000e+00> : tensor<f32>
  %cst_1 = constant dense<0.000000e+00> : tensor<8xf32>
  %cst_2 = constant dense<0.000000e+00> : tensor<8x3x3x3xf32>
  %0 = "tfl.sub"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<3x15x14x3xf32>, tensor<f32>) -> tensor<3x15x14x3xf32>
  %1 = "tfl.greater_equal"(%arg0, %0) : (tensor<3x15x14x3xf32>, tensor<3x15x14x3xf32>) -> tensor<3x15x14x3xi1>
  %2 = "tf.All"(%1, %cst) {Tidx = i32, device = "/device:CPU:0", keep_dims = false} : (tensor<3x15x14x3xi1>, tensor<4xi32>) -> tensor<i1>
  %3 = "tf.If"(%2, %2, %arg0, %0) {Tcond = i1,
      else_branch = @_functionalize_if_else_branch_00, is_stateless = false,
      then_branch = @_functionalize_if_then_branch_00} :
      (tensor<i1>, tensor<i1>, tensor<3x15x14x3xf32>, tensor<3x15x14x3xf32>) -> tensor<i1>
  %4 = "tfl.conv_2d"(%arg0, %cst_2, %cst_1) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<3x15x14x3xf32>, tensor<8x3x3x3xf32>, tensor<8xf32>) -> tensor<3x15x14x8xf32>
  return %4 : tensor<3x15x14x8xf32>
}

func private @_functionalize_if_else_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
  %cst = constant dense<false> : tensor<i1>
  return %cst : tensor<i1>
}

func private @_functionalize_if_then_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
  %cst = constant dense<true> : tensor<i1>
  return %cst : tensor<i1>
}
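
// %3 (the tf.If result) is unused and both branch functions only return
// constants, so the op can be erased even though is_stateless = false.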
// CHECK-NOT: tf.If
// CHECK: return
// -----

// Verify that an unused tf.If whose branch function has side effects is not
// removed.
// CHECK-LABEL: main
func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
    attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
  %cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
  %cst_0 = constant dense<1.000000e+00> : tensor<f32>
  %cst_1 = constant dense<0.000000e+00> : tensor<8xf32>
  %cst_2 = constant dense<0.000000e+00> : tensor<8x3x3x3xf32>
  %0 = "tfl.sub"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<3x15x14x3xf32>, tensor<f32>) -> tensor<3x15x14x3xf32>
  %1 = "tfl.greater_equal"(%arg0, %0) : (tensor<3x15x14x3xf32>, tensor<3x15x14x3xf32>) -> tensor<3x15x14x3xi1>
  %2 = "tf.All"(%1, %cst) {Tidx = i32, device = "/device:CPU:0", keep_dims = false} : (tensor<3x15x14x3xi1>, tensor<4xi32>) -> tensor<i1>
  %3 = "tf.If"(%2, %2, %arg0, %0) {Tcond = i1,
      else_branch = @_functionalize_if_else_branch_01, is_stateless = false,
      then_branch = @_functionalize_if_then_branch_01} :
      (tensor<i1>, tensor<i1>, tensor<3x15x14x3xf32>, tensor<3x15x14x3xf32>) -> tensor<i1>
  %4 = "tfl.conv_2d"(%arg0, %cst_2, %cst_1) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<3x15x14x3xf32>, tensor<8x3x3x3xf32>, tensor<8xf32>) -> tensor<3x15x14x8xf32>
  return %4 : tensor<3x15x14x8xf32>
}

func private @_functionalize_if_else_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
  %cst = constant dense<false> : tensor<i1>
  return %cst : tensor<i1>
}

func private @_functionalize_if_then_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
  %0 = "tf.blah"() : () -> tensor<i1>
  return %0 : tensor<i1>
}
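
// "tf.blah" is an unregistered op, so it must conservatively be assumed to
// have side effects; with is_stateless = false the unused tf.If has to stay.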
// CHECK: tf.If
// CHECK: return
// -----

// Verify that an unused tf.If whose branch function has side effects is still
// removed when the op itself is marked is_stateless = true.
// CHECK-LABEL: main
func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
    attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
  %cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
  %cst_0 = constant dense<1.000000e+00> : tensor<f32>
  %cst_1 = constant dense<0.000000e+00> : tensor<8xf32>
  %cst_2 = constant dense<0.000000e+00> : tensor<8x3x3x3xf32>
  %0 = "tfl.sub"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<3x15x14x3xf32>, tensor<f32>) -> tensor<3x15x14x3xf32>
  %1 = "tfl.greater_equal"(%arg0, %0) : (tensor<3x15x14x3xf32>, tensor<3x15x14x3xf32>) -> tensor<3x15x14x3xi1>
  %2 = "tf.All"(%1, %cst) {Tidx = i32, device = "/device:CPU:0", keep_dims = false} : (tensor<3x15x14x3xi1>, tensor<4xi32>) -> tensor<i1>
  %3 = "tf.If"(%2, %2, %arg0, %0) {Tcond = i1,
      else_branch = @_functionalize_if_else_branch_02, is_stateless = true,
      then_branch = @_functionalize_if_then_branch_02} :
      (tensor<i1>, tensor<i1>, tensor<3x15x14x3xf32>, tensor<3x15x14x3xf32>) -> tensor<i1>
  %4 = "tfl.conv_2d"(%arg0, %cst_2, %cst_1) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<3x15x14x3xf32>, tensor<8x3x3x3xf32>, tensor<8xf32>) -> tensor<3x15x14x8xf32>
  return %4 : tensor<3x15x14x8xf32>
}

func private @_functionalize_if_else_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
  %cst = constant dense<false> : tensor<i1>
  return %cst : tensor<i1>
}

func private @_functionalize_if_then_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
  %0 = "tf.blah"() : () -> tensor<i1>
  return %0 : tensor<i1>
}
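
// Same then-branch as the previous test, but here is_stateless = true asserts
// that the op is pure, so the pass trusts the attribute and erases the unused
// tf.If anyway.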
// CHECK-NOT: tf.If
// CHECK: return