blob: d1aa03e2c3ae1e87735dcea940544125a71bb192 [file] [log] [blame]
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s
// ----------------------
// Test importing an If op whose branch operands are single-element tuples.
// ----------------------
// Single module holding all four test cases below. Because the RUN line
// passes -hlo-import-all-computations, the non-entry computations
// tfcompile.1/.2/.3 are imported as well (as private funcs, per their
// expected labels), not only the ENTRY computation.
HloModule tfcompile.0
// True branch of the tfcompile.0 conditional: unpacks the f32 from its
// single-element tuple operand, takes its log, and returns the result
// wrapped in a single-element tuple.
%then_branch0 {
%arg_tuple.7 = (f32[]) parameter(0), metadata={op_name="XLA_Args"}
%get-tuple-element.8 = f32[] get-tuple-element(%arg_tuple.7), index=0, metadata={op_name="XLA_Args"}
%log.9 = f32[] log(%get-tuple-element.8), metadata={op_type="Log" op_name="cond/Log"}
ROOT %tuple.10 = (f32[]) tuple(%log.9), metadata={op_name="XLA_Retvals"}
}
// False branch of the tfcompile.0 conditional: unpacks the f32 from its
// single-element tuple operand, applies exponential, and returns the
// result wrapped in a single-element tuple.
%else_branch0 {
%arg_tuple.12 = (f32[]) parameter(0), metadata={op_name="XLA_Args"}
%get-tuple-element.13 = f32[] get-tuple-element(%arg_tuple.12), index=0, metadata={op_name="XLA_Args"}
%exponential.14 = f32[] exponential(%get-tuple-element.13), metadata={op_type="Exp" op_name="cond/Exp"}
ROOT %tuple.15 = (f32[]) tuple(%exponential.14), metadata={op_name="XLA_Retvals"}
}
// CHECK-LABEL: func @main
// CHECK-SAME: ([[A0:%.+]]: tensor<f32>)
ENTRY %tfcompile.0 {
// Scalar f32 input (the function argument A0 in the expected output). It is
// compared against the constant and also used, wrapped in a 1-tuple, as the
// operand of both conditional branches.
%arg0.1 = f32[] parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
// CHECK: [[C0:%.+]] = mhlo.constant
%constant.3 = f32[] constant(10), metadata={op_type="Less" op_name="Less"}
// CHECK: [[R1:%.+]] = "mhlo.compare"([[A0]], [[C0]])
// Predicate: arg0 < 10 selects the true (log) branch.
%compare.4 = pred[] compare(%arg0.1, %constant.3), direction=LT, metadata={op_type="Less" op_name="Less"}
// Single-element tuple shared as the operand of both branches.
%tuple.5 = (f32[]) tuple(%arg0.1), metadata={op_type="If" op_name="cond/Merge_if"}
// CHECK: [[R2:%.+]] = "mhlo.if"([[R1]]) ({
// CHECK: [[R3:%.+]] = "mhlo.log"([[A0]])
// CHECK: "mhlo.return"([[R3]])
// CHECK: }, {
// CHECK: [[R3:%.+]] = "mhlo.exponential"([[A0]])
// CHECK: "mhlo.return"([[R3]])
// CHECK: })
// In the expected output above, the mhlo.if takes only the predicate and
// both regions read the function argument A0 directly; the tuple operand is
// not passed to the mhlo.if.
%conditional.16 = (f32[]) conditional(%compare.4, %tuple.5, %tuple.5), true_computation=%then_branch0, false_computation=%else_branch0, metadata={op_type="If" op_name="cond/Merge_if"}
%get-tuple-element.17 = f32[] get-tuple-element(%conditional.16), index=0, metadata={op_type="If" op_name="cond/Merge_if"}
// CHECK: [[R4:%.+]] = "mhlo.tuple"([[R2]])
// CHECK: return [[R4]]
// The entry result is re-wrapped in a 1-tuple, which is expected to appear
// as an explicit mhlo.tuple.
ROOT %tuple.19 = (f32[]) tuple(%get-tuple-element.17), metadata={op_name="XLA_Retvals"}
}
// ----------------------
// Test importing an If op whose two branches take nested tuple arguments of different types.
// ----------------------
// True branch of the tfcompile.1 conditional: takes a flat (f32, f32) tuple
// (a, b) and returns ((log(a), log(b)), a). Its parameter shape differs from
// %else_branch1's, but its nested result shape is the same.
%then_branch1 {
%arg_tuple = (f32[], f32[]) parameter(0)
%get-tuple-element.0 = f32[] get-tuple-element(%arg_tuple), index=0
%get-tuple-element.1 = f32[] get-tuple-element(%arg_tuple), index=1
%log.0 = f32[] log(%get-tuple-element.0)
%log.1 = f32[] log(%get-tuple-element.1)
%tuple.0 = (f32[], f32[]) tuple(%log.0, %log.1)
// The unmodified first element is returned as the outer second element.
ROOT %tuple.1 = ((f32[], f32[]), f32[]) tuple(%tuple.0, %get-tuple-element.0)
}
// False branch of the tfcompile.1 conditional: takes a nested
// ((f32, f32), f32) tuple ((a, b), c) and returns
// ((exp(a), exp(b)), exp(c)), the same result shape as %then_branch1.
%else_branch1 {
%arg_tuple = ((f32[], f32[]), f32[]) parameter(0)
// Inner pair (a, b).
%get-tuple-element.0 = (f32[], f32[]) get-tuple-element(%arg_tuple), index=0
// Outer scalar c.
%get-tuple-element.1 = f32[] get-tuple-element(%arg_tuple), index=1
%get-tuple-element.2 = f32[] get-tuple-element(%get-tuple-element.0), index=0
%get-tuple-element.3 = f32[] get-tuple-element(%get-tuple-element.0), index=1
// exp(c) is written first here, but the expected MLIR for this region
// lists exp(a) first; the textual order below is not the emitted order.
%exponential.0 = f32[] exponential(%get-tuple-element.1)
%exponential.1 = f32[] exponential(%get-tuple-element.2)
%exponential.2 = f32[] exponential(%get-tuple-element.3)
%tuple.0 = (f32[], f32[]) tuple(%exponential.1, %exponential.2)
ROOT %tuple.1 = ((f32[], f32[]), f32[]) tuple(%tuple.0, %exponential.0)
}
// Expected import: the two differently shaped tuple operands are flattened
// away. The mhlo.if takes only the predicate, both regions read the function
// arguments directly and return three scalars, and the nested result shape
// is rebuilt with mhlo.tuple after the mhlo.if.
// CHECK-LABEL: func private @tfcompile.1
// CHECK-SAME: (%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>)
// CHECK-NEXT: %[[CST:.*]] = mhlo.constant
// CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[A0]], %[[CST]]) {comparison_direction = #mhlo<"comparison_direction LT">}
// CHECK-NEXT: %[[IF:.*]]:3 = "mhlo.if"(%[[CMP]]) ({
// CHECK-NEXT: %[[LOG0:.*]] = "mhlo.log"(%[[A0]]) : (tensor<f32>) -> tensor<f32>
// CHECK-NEXT: %[[LOG1:.*]] = "mhlo.log"(%[[A1]]) : (tensor<f32>) -> tensor<f32>
// CHECK-NEXT: "mhlo.return"(%[[LOG0]], %[[LOG1]], %[[A0]]) : (tensor<f32>, tensor<f32>, tensor<f32>) -> ()
// CHECK-NEXT: }, {
// CHECK-NEXT: %[[EXP0:.*]] = "mhlo.exponential"(%[[A0]]) : (tensor<f32>) -> tensor<f32>
// CHECK-NEXT: %[[EXP1:.*]] = "mhlo.exponential"(%[[A1]]) : (tensor<f32>) -> tensor<f32>
// CHECK-NEXT: %[[EXP2:.*]] = "mhlo.exponential"(%[[A2]]) : (tensor<f32>) -> tensor<f32>
// CHECK-NEXT: "mhlo.return"(%[[EXP0]], %[[EXP1]], %[[EXP2]]) : (tensor<f32>, tensor<f32>, tensor<f32>) -> ()
// CHECK-NEXT: }) : (tensor<i1>) -> (tensor<f32>, tensor<f32>, tensor<f32>)
// CHECK-NEXT: %[[R0:.*]] = "mhlo.tuple"(%[[IF]]#0, %[[IF]]#1)
// CHECK-NEXT: %[[R1:.*]] = "mhlo.tuple"(%[[R0]], %[[IF]]#2)
// CHECK-NEXT: return %[[R1]] : tuple<tuple<tensor<f32>, tensor<f32>>, tensor<f32>>
%tfcompile.1 {
%arg0 = f32[] parameter(0), parameter_replication={false}
%arg1 = f32[] parameter(1), parameter_replication={false}
%arg2 = f32[] parameter(2), parameter_replication={false}
%constant.3 = f32[] constant(10)
// Predicate: arg0 < 10 selects %then_branch1.
%compare.4 = pred[] compare(%arg0, %constant.3), direction=LT
// True-branch operand: flat (arg0, arg1).
%tuple.5 = (f32[],f32[]) tuple(%arg0, %arg1)
// False-branch operand: nested ((arg0, arg1), arg2).
%tuple.6 = ((f32[],f32[]), f32[]) tuple(%tuple.5, %arg2)
ROOT %conditional.16 = ((f32[],f32[]), f32[]) conditional(%compare.4, %tuple.5, %tuple.6), true_computation=%then_branch1, false_computation=%else_branch1
}
// ----------------------
// Test importing an If op whose branches take non-tuple (scalar) arguments.
// ----------------------
// True branch of the tfcompile.2 conditional: a bare (non-tuple) f32
// parameter whose log is returned directly.
%then_branch2 {
%arg_tuple = f32[] parameter(0)
ROOT %log.0 = f32[] log(%arg_tuple)
}
// False branch of the tfcompile.2 conditional: a bare (non-tuple) f32
// parameter whose exponential is returned directly.
%else_branch2 {
%arg_tuple = f32[] parameter(0)
ROOT %exponential.0 = f32[] exponential(%arg_tuple)
}
// The branch operands here are the bare scalar %arg0, so the expected
// mhlo.if yields a single tensor<f32> that is returned directly, with no
// tuple ops in between.
// CHECK-LABEL: func private @tfcompile.2
// CHECK-SAME: (%[[A0:.*]]: tensor<f32>)
// CHECK-NEXT: %[[CST:.*]] = mhlo.constant
// CHECK-NEXT: %[[CMP:.*]] = "mhlo.compare"(%[[A0]], %[[CST]]) {comparison_direction = #mhlo<"comparison_direction LT">}
// CHECK-NEXT: %[[IF:.*]] = "mhlo.if"(%[[CMP]]) ({
// CHECK-NEXT: %[[LOG0:.*]] = "mhlo.log"(%[[A0]]) : (tensor<f32>) -> tensor<f32>
// CHECK-NEXT: "mhlo.return"(%[[LOG0]]) : (tensor<f32>) -> ()
// CHECK-NEXT: }, {
// CHECK-NEXT: %[[EXP0:.*]] = "mhlo.exponential"(%[[A0]]) : (tensor<f32>) -> tensor<f32>
// CHECK-NEXT: "mhlo.return"(%[[EXP0]]) : (tensor<f32>) -> ()
// CHECK-NEXT: }) : (tensor<i1>) -> tensor<f32>
// CHECK-NEXT: return %[[IF]] : tensor<f32>
%tfcompile.2 {
%arg0 = f32[] parameter(0), parameter_replication={false}
%constant.3 = f32[] constant(10)
// Predicate: arg0 < 10 selects %then_branch2.
%compare.4 = pred[] compare(%arg0, %constant.3), direction=LT
// The same scalar is passed unwrapped as both branch operands.
ROOT %conditional.16 = f32[] conditional(%compare.4, %arg0, %arg0), true_computation=%then_branch2, false_computation=%else_branch2
}
// ----------------------
// Test importing nested If ops where some branch computations take an empty-tuple (zero-element) parameter.
// ----------------------
// Inner true branch (used by %region_0.12): returns its f32 parameter
// unchanged.
%region_1.7 (Arg_.8: f32[]) -> f32[] {
ROOT %Arg_.8 = f32[] parameter(0)
}
// Inner false branch (used by %region_0.12): takes an empty-tuple parameter
// that it never reads and returns the constant 10.
%region_2.9 (arg_empty_tuple.10: ()) -> f32[] {
%arg_empty_tuple.10 = () parameter(0)
ROOT %constant.11 = f32[] constant(10)
}
// Outer true branch (used by tfcompile.3): ignores its empty-tuple parameter
// and evaluates an inner conditional on the constant false. That inner
// conditional's true branch (%region_1.7) receives the constant 10, and its
// false branch (%region_2.9) receives an empty tuple. The inner result is
// returned twice as a (f32, f32) tuple.
%region_0.12 (arg_empty_tuple.13: ()) -> (f32[], f32[]) {
%arg_empty_tuple.13 = () parameter(0)
%constant.14 = pred[] constant(false)
%constant.15 = f32[] constant(10)
// Empty-tuple operand for the zero-element false branch %region_2.9.
%tuple.16 = () tuple()
%conditional.17 = f32[] conditional(pred[] %constant.14, f32[] %constant.15, () %tuple.16), true_computation=%region_1.7, false_computation=%region_2.9
ROOT %tuple.18 = (f32[], f32[]) tuple(f32[] %conditional.17, f32[] %conditional.17)
}
// Outer false branch (used by tfcompile.3): unpacks its (f32, f32) tuple
// parameter and repacks both elements in the same order, so the result
// equals the input.
%region_3.19 (arg_tuple.20: (f32[], f32[])) -> (f32[], f32[]) {
%arg_tuple.20 = (f32[], f32[]) parameter(0)
%get-tuple-element.21 = f32[] get-tuple-element((f32[], f32[]) %arg_tuple.20), index=0
%get-tuple-element.22 = f32[] get-tuple-element((f32[], f32[]) %arg_tuple.20), index=1
ROOT %tuple.23 = (f32[], f32[]) tuple(f32[] %get-tuple-element.21, f32[] %get-tuple-element.22)
}
// The outer conditional's true operand is an empty tuple, and its false
// operand is (Arg_1.2, Arg_2.3). The expected outer mhlo.if takes only the
// predicate. Its false region returns A1 and A2 directly, and its true region
// contains the inner mhlo.if.
// CHECK-LABEL: func private @tfcompile.3
// CHECK-SAME: (%[[A0:.+]]: tensor<i1>, %[[A1:.+]]: tensor<f32>, %[[A2:.+]]: tensor<f32>)
// CHECK-NEXT: %[[CST1:.*]] = mhlo.constant
// CHECK-NEXT: %[[OUTER_IF:.+]]:2 = "mhlo.if"(%[[A0]]) ({
// CHECK-NEXT: %[[PRED1:.*]] = mhlo.constant
// CHECK-NEXT: %[[CST2:.*]] = mhlo.constant
// CHECK-NEXT: %[[INNER_IF:.+]] = "mhlo.if"(%[[PRED1]]) ({
// CHECK-NEXT: "mhlo.return"(%[[CST2]]) : (tensor<f32>) -> ()
// CHECK-NEXT: }, {
// CHECK-NEXT: %[[CST3:.*]] = mhlo.constant
// CHECK-NEXT: "mhlo.return"(%[[CST3]]) : (tensor<f32>) -> ()
// CHECK-NEXT: }) : (tensor<i1>) -> tensor<f32>
// CHECK-NEXT: "mhlo.return"(%[[INNER_IF]], %[[INNER_IF]]) : (tensor<f32>, tensor<f32>) -> ()
// CHECK-NEXT: }, {
// CHECK-NEXT: "mhlo.return"(%[[A1]], %[[A2]]) : (tensor<f32>, tensor<f32>) -> ()
// CHECK-NEXT: }) : (tensor<i1>) -> (tensor<f32>, tensor<f32>)
// CHECK-NEXT: return %[[OUTER_IF]]#1 : tensor<f32>
%tfcompile.3 (Arg_0.1: pred[], Arg_1.2: f32[], Arg_2.3: f32[]) -> f32[] {
// No other instruction uses this constant, yet it is still expected in the
// output (as CST1).
%constant.4 = f32[] constant(10)
%Arg_0.1 = pred[] parameter(0)
// Empty-tuple operand for %region_0.12, which ignores it.
%tuple.5 = () tuple()
%Arg_1.2 = f32[] parameter(1)
%Arg_2.3 = f32[] parameter(2)
// Operand for %region_3.19.
%tuple.6 = (f32[], f32[]) tuple(f32[] %Arg_1.2, f32[] %Arg_2.3)
%conditional.24 = (f32[], f32[]) conditional(pred[] %Arg_0.1, () %tuple.5, (f32[], f32[]) %tuple.6), true_computation=%region_0.12, false_computation=%region_3.19
// Element 0 is extracted but unused; only element 1 is the ROOT.
%get-tuple-element.25 = f32[] get-tuple-element((f32[], f32[]) %conditional.24), index=0
ROOT %get-tuple-element.26 = f32[] get-tuple-element((f32[], f32[]) %conditional.24), index=1
}