| // RUN: tf-mlir-translate -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s |
| |
| // ---------------------- |
| // Test importing If op with tuple args; the tuple has exactly one element |
| // ---------------------- |
| HloModule tfcompile.0 |
| |
| // True branch of %conditional.16 in %tfcompile.0: unpacks the one-element |
| // tuple parameter, applies log to it, and re-wraps the result in a |
| // one-element tuple. |
| %then_branch0 { |
| %arg_tuple.7 = (f32[]) parameter(0), metadata={op_name="XLA_Args"} |
| %get-tuple-element.8 = f32[] get-tuple-element(%arg_tuple.7), index=0, metadata={op_name="XLA_Args"} |
| %log.9 = f32[] log(%get-tuple-element.8), metadata={op_type="Log" op_name="cond/Log"} |
| ROOT %tuple.10 = (f32[]) tuple(%log.9), metadata={op_name="XLA_Retvals"} |
| } |
| |
| // False branch of %conditional.16 in %tfcompile.0: unpacks the one-element |
| // tuple parameter, applies exponential to it, and re-wraps the result in a |
| // one-element tuple. |
| %else_branch0 { |
| %arg_tuple.12 = (f32[]) parameter(0), metadata={op_name="XLA_Args"} |
| %get-tuple-element.13 = f32[] get-tuple-element(%arg_tuple.12), index=0, metadata={op_name="XLA_Args"} |
| %exponential.14 = f32[] exponential(%get-tuple-element.13), metadata={op_type="Exp" op_name="cond/Exp"} |
| ROOT %tuple.15 = (f32[]) tuple(%exponential.14), metadata={op_name="XLA_Retvals"} |
| } |
| |
| // CHECK-LABEL: func @main |
| // CHECK-SAME: ([[A0:%.+]]: tensor<f32>) |
| // Entry computation: selects log(arg) or exponential(arg) depending on |
| // whether arg < 10, passing the argument to both branches as a one-element |
| // tuple. |
| ENTRY %tfcompile.0 { |
| %arg0.1 = f32[] parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"} |
| |
| // CHECK: [[C0:%.+]] = mhlo.constant |
| %constant.3 = f32[] constant(10), metadata={op_type="Less" op_name="Less"} |
| |
| // CHECK: [[R1:%.+]] = "mhlo.compare"([[A0]], [[C0]]) |
| %compare.4 = pred[] compare(%arg0.1, %constant.3), direction=LT, metadata={op_type="Less" op_name="Less"} |
| |
| // Wrap the argument in a one-element tuple; the same tuple is the operand |
| // for both branches of the conditional below. |
| %tuple.5 = (f32[]) tuple(%arg0.1), metadata={op_type="If" op_name="cond/Merge_if"} |
| |
| // The expectations below match both regions using the function argument |
| // directly, i.e. without the tuple wrapper. |
| // CHECK: [[R2:%.+]] = "mhlo.if"([[R1]]) ({ |
| // CHECK: [[R3:%.+]] = "mhlo.log"([[A0]]) |
| // CHECK: "mhlo.return"([[R3]]) |
| // CHECK: }, { |
| // CHECK: [[R3:%.+]] = "mhlo.exponential"([[A0]]) |
| // CHECK: "mhlo.return"([[R3]]) |
| // CHECK: }) |
| %conditional.16 = (f32[]) conditional(%compare.4, %tuple.5, %tuple.5), true_computation=%then_branch0, false_computation=%else_branch0, metadata={op_type="If" op_name="cond/Merge_if"} |
| |
| // Unpack the conditional's single result and re-tuple it as the result of |
| // the entry computation. |
| %get-tuple-element.17 = f32[] get-tuple-element(%conditional.16), index=0, metadata={op_type="If" op_name="cond/Merge_if"} |
| // CHECK: [[R4:%.+]] = "mhlo.tuple"([[R2]]) |
| // CHECK: return [[R4]] |
| ROOT %tuple.19 = (f32[]) tuple(%get-tuple-element.17), metadata={op_name="XLA_Retvals"} |
| } |
| |
| // ---------------------- |
| // Test importing If op with nested tuple block-arguments of different types. |
| // ---------------------- |
| |
| // True branch of the conditional in %tfcompile.1. Takes a flat (f32[], f32[]) |
| // tuple, applies log to both elements, and returns the nested tuple |
| // ((log(e0), log(e1)), e0), where e0 and e1 are the two input elements. |
| %then_branch1 { |
| %arg_tuple = (f32[], f32[]) parameter(0) |
| |
| %get-tuple-element.0 = f32[] get-tuple-element(%arg_tuple), index=0 |
| %get-tuple-element.1 = f32[] get-tuple-element(%arg_tuple), index=1 |
| |
| %log.0 = f32[] log(%get-tuple-element.0) |
| %log.1 = f32[] log(%get-tuple-element.1) |
| %tuple.0 = (f32[], f32[]) tuple(%log.0, %log.1) |
| // The second element of the result is the unmodified first input element. |
| ROOT %tuple.1 = ((f32[], f32[]), f32[]) tuple(%tuple.0, %get-tuple-element.0) |
| } |
| |
| // False branch of the conditional in %tfcompile.1. Takes a nested |
| // ((f32[], f32[]), f32[]) tuple, applies exponential to each of its three |
| // scalar leaves, and returns a tuple with the same nested structure. |
| %else_branch1 { |
| %arg_tuple = ((f32[], f32[]), f32[]) parameter(0) |
| |
| // Element 0 is the inner pair; element 1 is the trailing scalar. |
| %get-tuple-element.0 = (f32[], f32[]) get-tuple-element(%arg_tuple), index=0 |
| %get-tuple-element.1 = f32[] get-tuple-element(%arg_tuple), index=1 |
| %get-tuple-element.2 = f32[] get-tuple-element(%get-tuple-element.0), index=0 |
| %get-tuple-element.3 = f32[] get-tuple-element(%get-tuple-element.0), index=1 |
| |
| // %exponential.0 is computed from the trailing scalar; %exponential.1 and |
| // %exponential.2 are computed from the two elements of the inner pair. |
| %exponential.0 = f32[] exponential(%get-tuple-element.1) |
| %exponential.1 = f32[] exponential(%get-tuple-element.2) |
| %exponential.2 = f32[] exponential(%get-tuple-element.3) |
| %tuple.0 = (f32[], f32[]) tuple(%exponential.1, %exponential.2) |
| ROOT %tuple.1 = ((f32[], f32[]), f32[]) tuple(%tuple.0, %exponential.0) |
| } |
| |
| // CHECK-LABEL: func private @tfcompile.1 |
| // CHECK-SAME: (%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) |
| //CHECK-NEXT: %[[CST:.*]] = mhlo.constant |
| //CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[A0]], %[[CST]]) {comparison_direction = #mhlo<"comparison_direction LT">} |
| //CHECK-NEXT: %[[IF:.*]]:3 = "mhlo.if"(%[[CMP]]) ({ |
| //CHECK-NEXT: %[[LOG0:.*]] = "mhlo.log"(%[[A0]]) : (tensor<f32>) -> tensor<f32> |
| //CHECK-NEXT: %[[LOG1:.*]] = "mhlo.log"(%[[A1]]) : (tensor<f32>) -> tensor<f32> |
| //CHECK-NEXT: "mhlo.return"(%[[LOG0]], %[[LOG1]], %[[A0]]) : (tensor<f32>, tensor<f32>, tensor<f32>) -> () |
| //CHECK-NEXT: }, { |
| //CHECK-NEXT: %[[EXP0:.*]] = "mhlo.exponential"(%[[A0]]) : (tensor<f32>) -> tensor<f32> |
| //CHECK-NEXT: %[[EXP1:.*]] = "mhlo.exponential"(%[[A1]]) : (tensor<f32>) -> tensor<f32> |
| //CHECK-NEXT: %[[EXP2:.*]] = "mhlo.exponential"(%[[A2]]) : (tensor<f32>) -> tensor<f32> |
| //CHECK-NEXT: "mhlo.return"(%[[EXP0]], %[[EXP1]], %[[EXP2]]) : (tensor<f32>, tensor<f32>, tensor<f32>) -> () |
| //CHECK-NEXT: }) : (tensor<i1>) -> (tensor<f32>, tensor<f32>, tensor<f32>) |
| //CHECK-NEXT: %[[R0:.*]] = "mhlo.tuple"(%[[IF]]#0, %[[IF]]#1) |
| //CHECK-NEXT: %[[R1:.*]] = "mhlo.tuple"(%[[R0]], %[[IF]]#2) |
| //CHECK-NEXT: return %[[R1]] : tuple<tuple<tensor<f32>, tensor<f32>>, tensor<f32>> |
| |
| // Non-entry computation whose conditional passes differently shaped operands |
| // to its two branches: a flat pair to %then_branch1 and a nested tuple to |
| // %else_branch1. Both branches return the same nested tuple type. |
| %tfcompile.1 { |
| %arg0 = f32[] parameter(0), parameter_replication={false} |
| %arg1 = f32[] parameter(1), parameter_replication={false} |
| %arg2 = f32[] parameter(2), parameter_replication={false} |
| |
| %constant.3 = f32[] constant(10) |
| %compare.4 = pred[] compare(%arg0, %constant.3), direction=LT |
| // %tuple.5 is the true-branch operand; %tuple.6 nests %tuple.5 together with |
| // %arg2 to form the false-branch operand. |
| %tuple.5 = (f32[],f32[]) tuple(%arg0, %arg1) |
| %tuple.6 = ((f32[],f32[]), f32[]) tuple(%tuple.5, %arg2) |
| |
| ROOT %conditional.16 = ((f32[],f32[]), f32[]) conditional(%compare.4, %tuple.5, %tuple.6), true_computation=%then_branch1, false_computation=%else_branch1 |
| } |
| |
| // ---------------------- |
| // Test importing If op with non-tuple block-arguments. |
| // ---------------------- |
| |
| // True branch of the conditional in %tfcompile.2: returns log of its scalar |
| // parameter. Despite its name, %arg_tuple is a plain f32[] here, not a tuple. |
| %then_branch2 { |
| %arg_tuple = f32[] parameter(0) |
| |
| ROOT %log.0 = f32[] log(%arg_tuple) |
| } |
| |
| // False branch of the conditional in %tfcompile.2: returns exponential of its |
| // scalar parameter. Despite its name, %arg_tuple is a plain f32[] here, not a |
| // tuple. |
| %else_branch2 { |
| %arg_tuple = f32[] parameter(0) |
| |
| ROOT %exponential.0 = f32[] exponential(%arg_tuple) |
| } |
| |
| // CHECK-LABEL: func private @tfcompile.2 |
| // CHECK-SAME: (%[[A0:.*]]: tensor<f32>) |
| // CHECK-NEXT: %[[CST:.*]] = mhlo.constant |
| // CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[A0]], %[[CST]]) {comparison_direction = #mhlo<"comparison_direction LT">} |
| // CHECK-NEXT: %[[IF:.*]] = "mhlo.if"(%[[CMP]]) ({ |
| // CHECK-NEXT: %[[LOG0:.*]] = "mhlo.log"(%[[A0]]) : (tensor<f32>) -> tensor<f32> |
| // CHECK-NEXT: "mhlo.return"(%[[LOG0]]) : (tensor<f32>) -> () |
| // CHECK-NEXT: }, { |
| // CHECK-NEXT: %[[EXP0:.*]] = "mhlo.exponential"(%[[A0]]) : (tensor<f32>) -> tensor<f32> |
| // CHECK-NEXT: "mhlo.return"(%[[EXP0]]) : (tensor<f32>) -> () |
| // CHECK-NEXT: }) : (tensor<i1>) -> tensor<f32> |
| // CHECK-NEXT: return %[[IF]] : tensor<f32> |
| |
| // Non-entry computation whose conditional takes scalar (non-tuple) operands: |
| // selects log(arg0) or exponential(arg0) depending on whether arg0 < 10. |
| %tfcompile.2 { |
| %arg0 = f32[] parameter(0), parameter_replication={false} |
| |
| %constant.3 = f32[] constant(10) |
| %compare.4 = pred[] compare(%arg0, %constant.3), direction=LT |
| |
| // %arg0 is passed directly as the operand of both branches. |
| ROOT %conditional.16 = f32[] conditional(%compare.4, %arg0, %arg0), true_computation=%then_branch2, false_computation=%else_branch2 |
| } |
| |
| // ---------------------- |
| // Test importing nested If op with zero XLA parameters. |
| // ---------------------- |
| |
| // True branch of the inner conditional in %region_0.12: returns its scalar |
| // parameter unchanged. |
| %region_1.7 (Arg_.8: f32[]) -> f32[] { |
| ROOT %Arg_.8 = f32[] parameter(0) |
| } |
| |
| // False branch of the inner conditional in %region_0.12: takes an empty |
| // tuple, which it does not read, and returns the constant 10. |
| %region_2.9 (arg_empty_tuple.10: ()) -> f32[] { |
| %arg_empty_tuple.10 = () parameter(0) |
| ROOT %constant.11 = f32[] constant(10) |
| } |
| |
| // True branch of the outer conditional in %tfcompile.3. Takes an empty tuple |
| // parameter, evaluates a nested conditional on a constant-false predicate, |
| // and returns a pair holding the nested result twice. |
| %region_0.12 (arg_empty_tuple.13: ()) -> (f32[], f32[]) { |
| %arg_empty_tuple.13 = () parameter(0) |
| %constant.14 = pred[] constant(false) |
| %constant.15 = f32[] constant(10) |
| %tuple.16 = () tuple() |
| // The true branch receives the scalar constant; the false branch receives |
| // the empty tuple. |
| %conditional.17 = f32[] conditional(pred[] %constant.14, f32[] %constant.15, () %tuple.16), true_computation=%region_1.7, false_computation=%region_2.9 |
| ROOT %tuple.18 = (f32[], f32[]) tuple(f32[] %conditional.17, f32[] %conditional.17) |
| } |
| |
| // False branch of the outer conditional in %tfcompile.3: unpacks its |
| // two-element tuple parameter and repacks the same two elements, in the same |
| // order, into a new tuple. |
| %region_3.19 (arg_tuple.20: (f32[], f32[])) -> (f32[], f32[]) { |
| %arg_tuple.20 = (f32[], f32[]) parameter(0) |
| %get-tuple-element.21 = f32[] get-tuple-element((f32[], f32[]) %arg_tuple.20), index=0 |
| %get-tuple-element.22 = f32[] get-tuple-element((f32[], f32[]) %arg_tuple.20), index=1 |
| ROOT %tuple.23 = (f32[], f32[]) tuple(f32[] %get-tuple-element.21, f32[] %get-tuple-element.22) |
| } |
| |
| // CHECK-LABEL: func private @tfcompile.3 |
| // CHECK-SAME: (%[[A0:.+]]: tensor<i1>, %[[A1:.+]]: tensor<f32>, %[[A2:.+]]: tensor<f32>) |
| // CHECK-NEXT: %[[CST1:.*]] = mhlo.constant |
| // CHECK-NEXT: %[[OUTER_IF:.+]]:2 = "mhlo.if"(%[[A0]]) ({ |
| // CHECK-NEXT: %[[PRED1:.*]] = mhlo.constant |
| // CHECK-NEXT: %[[CST2:.*]] = mhlo.constant |
| // CHECK-NEXT: %[[INNER_IF:.+]] = "mhlo.if"(%[[PRED1]]) ({ |
| // CHECK-NEXT: "mhlo.return"(%[[CST2]]) : (tensor<f32>) -> () |
| // CHECK-NEXT: }, { |
| // CHECK-NEXT: %[[CST3:.*]] = mhlo.constant |
| // CHECK-NEXT: "mhlo.return"(%[[CST3]]) : (tensor<f32>) -> () |
| // CHECK-NEXT: }) : (tensor<i1>) -> tensor<f32> |
| // CHECK-NEXT: "mhlo.return"(%[[INNER_IF]], %[[INNER_IF]]) : (tensor<f32>, tensor<f32>) -> () |
| // CHECK-NEXT: }, { |
| // CHECK-NEXT: "mhlo.return"(%[[A1]], %[[A2]]) : (tensor<f32>, tensor<f32>) -> () |
| // CHECK-NEXT: }) : (tensor<i1>) -> (tensor<f32>, tensor<f32>) |
| // CHECK-NEXT: return %[[OUTER_IF]]#1 : tensor<f32> |
| |
| // Non-entry computation with a nested conditional. The outer conditional's |
| // true branch (%region_0.12) takes an empty tuple and its false branch |
| // (%region_3.19) takes the pair (Arg_1.2, Arg_2.3). The computation returns |
| // element 1 of the outer conditional's result. |
| %tfcompile.3 (Arg_0.1: pred[], Arg_1.2: f32[], Arg_2.3: f32[]) -> f32[] { |
| // %constant.4 is not referenced by any instruction in this computation. |
| %constant.4 = f32[] constant(10) |
| %Arg_0.1 = pred[] parameter(0) |
| %tuple.5 = () tuple() |
| %Arg_1.2 = f32[] parameter(1) |
| %Arg_2.3 = f32[] parameter(2) |
| %tuple.6 = (f32[], f32[]) tuple(f32[] %Arg_1.2, f32[] %Arg_2.3) |
| %conditional.24 = (f32[], f32[]) conditional(pred[] %Arg_0.1, () %tuple.5, (f32[], f32[]) %tuple.6), true_computation=%region_0.12, false_computation=%region_3.19 |
| // Element 0 is extracted but not used by the root instruction. |
| %get-tuple-element.25 = f32[] get-tuple-element((f32[], f32[]) %conditional.24), index=0 |
| ROOT %get-tuple-element.26 = f32[] get-tuple-element((f32[], f32[]) %conditional.24), index=1 |
| } |