blob: 79f471b38696370d460ae882af4c66aec6c39ec1 [file] [log] [blame]
// RUN: tf-opt %s -tf-functional-control-flow-to-cfg -split-input-file -verify-diagnostics | FileCheck %s
// Single-result tf.If. The rank-0 i1 predicate is read with extract_element
// and drives a cond_br to one block per branch. Each branch block calls its
// branch function with only the If's data operands (%arg1, %arg2) and passes
// the call result to a shared merge block, whose block argument takes the
// place of the If result.
func @testIf1Then(tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
func @testIf1Else(tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK-LABEL: func @testIf1Result(%arg0: tensor<i1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>)
func @testIf1Result(tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> {
^bb0(%arg0: tensor<i1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>):
%1 = "tf.If"(%arg0, %arg1, %arg2) {
then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
} : (tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
// CHECK: %0 = extract_element %arg0[] : tensor<i1>
// CHECK: cond_br %0, ^bb1, ^bb2
// CHECK: ^bb1:
// CHECK: %1 = call @testIf1Then(%arg1, %arg2)
// CHECK: br ^bb3(%1 : tensor<*xf32>)
// CHECK: ^bb2:
// CHECK: %2 = call @testIf1Else(%arg1, %arg2)
// CHECK: br ^bb3(%2 : tensor<*xf32>)
// CHECK: ^bb3(%3: tensor<*xf32>):
// The use of the If result in the return is rewritten to the merge-block
// argument %3.
return %1 : tensor<*xf32>
// CHECK: return %3 : tensor<*xf32>
}
// Multi-result tf.If with three results of distinct element types. Each branch
// block unpacks every result of its call (%N#0, %N#1, %N#2) and forwards them,
// in order, to a merge block that has one argument per If result.
func @testIf3Then(tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>)
func @testIf3Else(tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>)
// CHECK-LABEL: func @testIf3Result(%arg0: tensor<i1>, %arg1: tensor<*xf32>)
func @testIf3Result(tensor<i1>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>) {
^bb0(%arg0: tensor<i1>, %arg1: tensor<*xf32>):
%1:3 = "tf.If"(%arg0, %arg1) {
then_branch = @testIf3Then, else_branch = @testIf3Else, is_stateless = false
} : (tensor<i1>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>)
// CHECK: %0 = extract_element %arg0[] : tensor<i1>
// CHECK: cond_br %0, ^bb1, ^bb2
// CHECK: ^bb1:
// CHECK: %1:3 = call @testIf3Then(%arg1)
// CHECK: br ^bb3(%1#0, %1#1, %1#2 : tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>)
// CHECK: ^bb2:
// CHECK: %2:3 = call @testIf3Else(%arg1)
// CHECK: br ^bb3(%2#0, %2#1, %2#2 : tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>)
// CHECK: ^bb3(%3: tensor<*xf32>, %4: tensor<*xi8>, %5: tensor<*xbf16>):
// The three If results map one-to-one onto the merge-block arguments %3, %4, %5.
return %1#0, %1#1, %1#2 : tensor<*xf32>, tensor<*xi8>, tensor<*xbf16>
// CHECK: return %3, %4, %5
}
// -----
// Type-mismatch variant of the single-result If. It lives in its own split
// chunk so that testIf1Then/testIf1Else can be redeclared with different
// signatures. Neither branch signature matches the If's operand or result
// types. The lowering therefore inserts a tensor_cast from each If operand to
// the corresponding callee parameter type. It also casts each callee result to
// the If result type (tensor<2x?xf32>) before branching to the merge block.
func @testIf1Then(tensor<2x?xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
func @testIf1Else(tensor<*xf32>, tensor<2x?xf32>) -> tensor<*xf32>
// CHECK-LABEL: func @testIf1Casts(%arg0: tensor<i1>, %arg1: tensor<2x2xf32>, %arg2: tensor<*xf32>)
func @testIf1Casts(tensor<i1>, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32> {
^bb0(%arg0: tensor<i1>, %arg1: tensor<2x2xf32>, %arg2: tensor<*xf32>):
%1 = "tf.If"(%arg0, %arg1, %arg2) {
then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
} : (tensor<i1>, tensor<2x2xf32>, tensor<*xf32>) -> tensor<2x?xf32>
// CHECK: %0 = extract_element %arg0[] : tensor<i1>
// CHECK: cond_br %0, ^bb1, ^bb2
// CHECK:^bb1: // pred: ^bb0
// CHECK: %1 = tensor_cast %arg1 : tensor<2x2xf32> to tensor<2x?xf32>
// CHECK: %2 = tensor_cast %arg2 : tensor<*xf32> to tensor<2x2xf32>
// CHECK: %3 = call @testIf1Then(%1, %2) : (tensor<2x?xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: %4 = tensor_cast %3 : tensor<2x2xf32> to tensor<2x?xf32>
// CHECK: br ^bb3(%4 : tensor<2x?xf32>)
// CHECK:^bb2: // pred: ^bb0
// CHECK: %5 = tensor_cast %arg1 : tensor<2x2xf32> to tensor<*xf32>
// CHECK: %6 = tensor_cast %arg2 : tensor<*xf32> to tensor<2x?xf32>
// CHECK: %7 = call @testIf1Else(%5, %6) : (tensor<*xf32>, tensor<2x?xf32>) -> tensor<*xf32>
// CHECK: %8 = tensor_cast %7 : tensor<*xf32> to tensor<2x?xf32>
// CHECK: br ^bb3(%8 : tensor<2x?xf32>)
// CHECK:^bb3(%9: tensor<2x?xf32>): // 2 preds: ^bb1, ^bb2
// A downstream (non-terminator) use of the If result is rewritten to the
// merge-block argument %9.
%2 = "tf.Add"(%1, %1) : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
// CHECK: %10 = "tf.Add"(%9, %9) : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
return %2 : tensor<2x?xf32>
// CHECK: return %10 : tensor<2x?xf32>
}
// -----
// If with a 4xi1 condition.
// The lowering only accepts a rank-0 bool predicate. A tensor<4xi1> condition
// must therefore produce a diagnostic on the tf.If op instead of being lowered.
// Note: nothing may be inserted between the diagnostic directive below and the
// tf.If line, because the directive's line offset is relative.
func @testIf1Then(tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
func @testIf1Else(tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
func @testIf1x4(tensor<4xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> {
^bb0(%arg0: tensor<4xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>):
// expected-error @+1 {{only supports zero-D bool tensors now}}
%1 = "tf.If"(%arg0, %arg1, %arg2) {
then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false
} : (tensor<4xi1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %1 : tensor<*xf32>
}
// -----
// Two-value tf.While. The lowering works as follows:
// - The entry block branches into a condition block whose arguments carry the
//   loop state.
// - The condition block calls the cond function, extracts the i1 and cond_br's
//   to either the body block or the exit block, passing the current state to
//   both.
// - The body block calls the body function and branches back to the condition
//   block with the call's results.
// - The exit block's arguments replace the While results.
// testWhile2Cond/testWhile2Body are also referenced by testComplexWhile1Result
// in this same split chunk.
func @testWhile2Cond(tensor<*xf32>, tensor<*xf32>) -> (tensor<i1>)
func @testWhile2Body(tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>)
// CHECK-LABEL: func @testWhile2Result(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>)
func @testWhile2Result(tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
^bb0(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>):
%1:2 = "tf.While"(%arg0, %arg1) {
cond = @testWhile2Cond, body = @testWhile2Body
} : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>)
// CHECK: br ^bb1(%arg0, %arg1 : tensor<*xf32>, tensor<*xf32>)
// CHECK: ^bb1(%0: tensor<*xf32>, %1: tensor<*xf32>):
// CHECK: %2 = call @testWhile2Cond(%0, %1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<i1>
// CHECK: %3 = extract_element %2[] : tensor<i1>
// CHECK: cond_br %3, ^bb2(%0, %1 : tensor<*xf32>, tensor<*xf32>), ^bb3(%0, %1 : tensor<*xf32>, tensor<*xf32>)
// CHECK: ^bb2(%4: tensor<*xf32>, %5: tensor<*xf32>):
// CHECK: %6:2 = call @testWhile2Body(%4, %5) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>)
// CHECK: br ^bb1(%6#0, %6#1 : tensor<*xf32>, tensor<*xf32>)
// CHECK: ^bb3(%7: tensor<*xf32>, %8: tensor<*xf32>):
// The While results are replaced by the exit-block arguments %7 and %8.
return %1#0, %1#1 : tensor<*xf32>, tensor<*xf32>
// CHECK: return %7, %8 : tensor<*xf32>, tensor<*xf32>
}
// Degenerate tf.While with no operands and no results. The lowering produces
// the same condition/body/exit block structure, but no block takes arguments
// and the body call produces no values.
func @testWhile0Cond() -> (tensor<i1>)
func @testWhile0Body() -> ()
// CHECK-LABEL: func @testWhile0Result() {
func @testWhile0Result() {
^bb0:
"tf.While"() { cond = @testWhile0Cond, body = @testWhile0Body } : () -> ()
// CHECK: br ^bb1
// CHECK: ^bb1:
// CHECK: %0 = call @testWhile0Cond() : () -> tensor<i1>
// CHECK: %1 = extract_element %0[] : tensor<i1>
// CHECK: cond_br %1, ^bb2, ^bb3
// CHECK: ^bb2:
// CHECK: call @testWhile0Body() : () -> ()
// CHECK: br ^bb1
// CHECK: ^bb3:
return
// CHECK: return
}
// A tf.While placed in the middle of a non-entry block, with one of its results
// feeding a branch. The lowering splits ^bb1 at the While:
// - The ops before the While stay in ^bb1, which now ends by branching to a new
//   condition block.
// - The body and exit blocks follow the condition block.
// - The ops after the While (the branch to the original ^bb2) move into the
//   exit block, with %3#1 replaced by the exit block's second argument.
// As a result, the original ^bb2 is renumbered ^bb5.
// The FileCheck directives also pin the ops on both sides of the While, so
// that the split cannot silently drop or reorder them. They also pin the
// definition of the returned value.
// CHECK-LABEL: func @testComplexWhile1Result(%arg0: tensor<*xf32>) -> tensor<*xf32> {
func @testComplexWhile1Result(tensor<*xf32>) -> (tensor<*xf32>) {
^bb0(%arg0: tensor<*xf32>):
br ^bb1(%arg0, %arg0 : tensor<*xf32>, tensor<*xf32>)
// CHECK: br ^bb1(%arg0, %arg0 : tensor<*xf32>, tensor<*xf32>)
^bb1(%0: tensor<*xf32>, %1: tensor<*xf32>):
// CHECK: ^bb1(%0: tensor<*xf32>, %1: tensor<*xf32>):
%2 = addf %0, %1 : tensor<*xf32>
// CHECK: %2 = addf %0, %1 : tensor<*xf32>
%3:2 = "tf.While"(%0, %2) {
cond = @testWhile2Cond, body = @testWhile2Body
} : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>)
// CHECK: br ^bb2(%0, %2 : tensor<*xf32>, tensor<*xf32>)
// CHECK: ^bb2(%3: tensor<*xf32>, %4: tensor<*xf32>):
// CHECK: %5 = call @testWhile2Cond(%3, %4) : (tensor<*xf32>, tensor<*xf32>) -> tensor<i1>
// CHECK: %6 = extract_element %5[] : tensor<i1>
// CHECK: cond_br %6, ^bb3(%3, %4 : tensor<*xf32>, tensor<*xf32>), ^bb4(%3, %4 : tensor<*xf32>, tensor<*xf32>)
// CHECK: ^bb3(%7: tensor<*xf32>, %8: tensor<*xf32>):
// CHECK: %9:2 = call @testWhile2Body(%7, %8) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>)
// CHECK: br ^bb2(%9#0, %9#1 : tensor<*xf32>, tensor<*xf32>)
// CHECK: ^bb4(%10: tensor<*xf32>, %11: tensor<*xf32>):
// CHECK: br ^bb5(%11, %2 : tensor<*xf32>, tensor<*xf32>)
br ^bb2(%3#1, %2 : tensor<*xf32>, tensor<*xf32>)
// CHECK: ^bb5(%12: tensor<*xf32>, %13: tensor<*xf32>):
^bb2(%4: tensor<*xf32>, %5: tensor<*xf32>):
%6 = subf %0, %1 : tensor<*xf32>
// CHECK: %14 = subf %0, %1 : tensor<*xf32>
return %6 : tensor<*xf32>
// CHECK: return %14 : tensor<*xf32>
}
// -----
// tf.While whose operand type, cond/body signatures and result type all
// differ, so casts are inserted at every boundary:
// - The initial operand is cast to the cond function's parameter type, which
//   becomes the condition block's argument type.
// - Before the cond_br, the loop value is cast to the body's parameter type,
//   which is also the exit block's argument type.
// - The body result is cast back to the condition block's type.
// - The exit value is cast to the While result type.
func @testWhileCond(tensor<?x3xf32>) -> (tensor<i1>)
func @testWhileBody(tensor<*xf32>) -> (tensor<?x?xf32>)
// CHECK-LABEL: func @testWhileCasts(%arg0: tensor<1x3xf32>)
func @testWhileCasts(%arg0: tensor<1x3xf32>) -> (tensor<?x?xf32>) {
%0 = "tf.While"(%arg0) {
cond = @testWhileCond, body = @testWhileBody
} : (tensor<1x3xf32>) -> (tensor<?x?xf32>)
// CHECK: %0 = tensor_cast %arg0 : tensor<1x3xf32> to tensor<?x3xf32>
// CHECK: br ^bb1(%0 : tensor<?x3xf32>)
// CHECK: ^bb1(%1: tensor<?x3xf32>):
// CHECK: %2 = call @testWhileCond(%1) : (tensor<?x3xf32>) -> tensor<i1>
// CHECK: %3 = extract_element %2[] : tensor<i1>
// CHECK: %4 = tensor_cast %1 : tensor<?x3xf32> to tensor<*xf32>
// CHECK: cond_br %3, ^bb2(%4 : tensor<*xf32>), ^bb3(%4 : tensor<*xf32>)
// CHECK: ^bb2(%5: tensor<*xf32>):
// CHECK: %6 = call @testWhileBody(%5) : (tensor<*xf32>) -> tensor<?x?xf32>
// CHECK: %7 = tensor_cast %6 : tensor<?x?xf32> to tensor<?x3xf32>
// CHECK: br ^bb1(%7 : tensor<?x3xf32>)
// CHECK: ^bb3(%8: tensor<*xf32>):
// CHECK: %9 = tensor_cast %8 : tensor<*xf32> to tensor<?x?xf32>
return %0 : tensor<?x?xf32>
// CHECK: return %9 : tensor<?x?xf32>
}
// -----