blob: e4dc4d5e211bb0cbd191dcd27ba0e2d8b89991ea [file] [log] [blame]
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
// This test is more thorough than those of the the other binary ops to test
// their shared functionality.
HloModule main.5
%reduce_helper.1 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
ROOT %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
%reduce_helper.2 (Arg_0.1: f32[], Arg_1.2: f32[]) -> f32[] {
%Arg_0.1 = f32[] parameter(0)
%Arg_1.2 = f32[] parameter(1)
ROOT %add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2)
}
%reduce_helper.3 (Arg_0.1: f32[], Arg_1.2: f32[], Arg_2.3: f32[], Arg_3.4: f32[]) -> (f32[], f32[]) {
%Arg_0.1 = f32[] parameter(0)
%Arg_1.2 = f32[] parameter(1)
%Arg_2.3 = f32[] parameter(2)
%Arg_3.4 = f32[] parameter(3)
%add.4 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_2.3)
%add.5 = f32[] add(f32[] %Arg_1.2, f32[] %Arg_3.4)
ROOT %tuple.6 = (f32[], f32[]) tuple(%add.4, %add.5)
}
// CHECK-LABEL: func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<f32>) -> tuple<tuple<tensor<f32>, tensor<f32>>, tensor<f32>> {
ENTRY %foo.5 (Arg_0.1: f32[4, 4], Arg_1.2: f32[4], Arg_2.3: f32[]) -> ((f32[], f32[]), f32[]) {
%Arg_0.1 = f32[4, 4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
%Arg_2.3 = f32[] parameter(2)
// CHECK-NEXT: %0 = "xla_hlo.reduce"(%arg0, %arg0, %arg2, %arg2) {computation = @reduce_helper.3, dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
%reduce.1 = f32[4] reduce(%Arg_0.1, %Arg_1.2), dimensions={0}, to_apply=%reduce_helper.1
// CHECK-NEXT: %1 = "xla_hlo.reduce"(%arg0, %arg1) {computation = @reduce_helper.1, dimensions = dense<0> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4xf32>
%reduce.2 = f32[] reduce(%reduce.1, %Arg_2.3), dimensions={0}, to_apply=%reduce_helper.2
// CHECK-NEXT: %2 = "xla_hlo.reduce"(%1, %arg2) {computation = @reduce_helper.2, dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<f32>
%reduce.3 = f32[] reduce(%Arg_0.1, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.2
// CHECK-NEXT: %3 = "xla_hlo.reduce"(%arg0, %arg2) {computation = @reduce_helper.2, dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor<f32>) -> tensor<f32>
%reduce.4 = (f32[], f32[]) reduce(%Arg_0.1, %Arg_0.1, %Arg_2.3, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.3
// CHECK-NEXT: %4 = "xla_hlo.sub"(%2, %3) {name = "sub.5"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
%sub.5 = f32[] subtract(%reduce.2, %reduce.3)
ROOT %tuple.6 = ((f32[], f32[]), f32[]) tuple(%reduce.4, %sub.5)
}