// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo -hlo-import-all-computations %s -o - | FileCheck %s
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s -check-prefix=NO_DEAD_FUNCTION
// NO_DEAD_FUNCTION-NOT: @test
// CHECK: module @foobar
HloModule foobar
// CHECK-LABEL: func @main(%arg0: tensor<f32>) -> tensor<f32> {
ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
ROOT %Arg_0.1 = f32[] parameter(0)
}
// CHECK-LABEL: func private @test_simple
%test_simple (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[] {
%Arg_0.1 = f32[4]{0} parameter(0)
%Arg_1.2 = f32[4]{0} parameter(1)
// CHECK-NEXT: mhlo.add %arg0, %arg1 : tensor<4xf32>
%add.42 = f32[4]{0} add(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2)
// TODO(b/129709049) consider making this default precision config inferred.
// CHECK-NEXT: "mhlo.dot"(%0, %arg1) {precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
ROOT %dot.4 = f32[] dot(f32[4]{0} %add.42, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}
}
// CHECK-LABEL: func private @test_after_all
// CHECK-SAME: ([[VAL_0:%.*]]: !mhlo.token, [[VAL_1:%.*]]: !mhlo.token) -> !mhlo.token
%test_after_all (token0: token[], token1: token[] ) -> token[] {
token0 = token[] parameter(0)
token1 = token[] parameter(1)
// CHECK-NEXT: mhlo.after_all [[VAL_0]], [[VAL_1]] {xla_shape = {{.*}}} : (!mhlo.token, !mhlo.token) -> !mhlo.token
ROOT after-all = token[] after-all(token0, token1)
}
// CHECK-LABEL: func private @test_after_all_no_inputs
// CHECK-SAME: () -> !mhlo.token
%test_after_all_no_inputs () -> token[] {
// CHECK-NEXT: mhlo.create_token {xla_shape = {{.*}}} : !mhlo.token
ROOT after-all = token[] after-all()
}
// CHECK-LABEL: func private @test_all_gather
// CHECK-SAME: ([[INPUT:%.*]]: tensor<128x32xf32>)
%test_all_gather {
input = f32[128,32] parameter(0)
// CHECK-NEXT: "mhlo.all_gather"([[INPUT]])
// CHECK-SAME: all_gather_dim = 1 : i64
// CHECK-SAME: channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
ROOT ag = f32[128,128] all-gather(input), channel_id=1, replica_groups={{0, 2, 4, 6}, {1, 3, 5, 7}}, dimensions={1}
}
// Test all-to-all
// CHECK-LABEL: func private @test_all_to_all
// CHECK-SAME: ([[ARG:%.*]]: tensor<2x2xi32>)
%test_all_to_all {
%parameter = s32[2,2]{1,0} parameter(0)
// CHECK-NEXT: "mhlo.all_to_all"([[ARG]]) {
// CHECK-SAME: concat_dimension = 1 : i64,
// CHECK-SAME{LITERAL}: replica_groups = dense<[[1, 2], [3, 0]]> : tensor<2x2xi64>,
// CHECK-SAME: split_count = 2 : i64,
// CHECK-SAME: split_dimension = 1 : i64
// CHECK-SAME: } : (tensor<2x2xi32>) -> tensor<2x2xi32>
ROOT %all-to-all = s32[2,2]{1,0} all-to-all(s32[2,2]{1,0} %parameter), replica_groups={{1,2}, {3,0}}, dimensions={1}
}
// Test all-reduce
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
// CHECK-LABEL: func private @test_all_reduce
// CHECK-SAME: ([[INPUT:%.*]]: tensor<8xf32>)
%test_all_reduce {
input = f32[8] parameter(0)
// CHECK-NEXT: "mhlo.all_reduce"([[INPUT]])
// CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
// CHECK: mhlo.return [[ADD]] : tensor<f32>
// CHECK: }) {
// CHECK-SAME: channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
// CHECK-NOT: use_global_device_ids
// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>
// CHECK-NOT: use_global_device_ids
// CHECK-SAME: :
ROOT result = f32[8] all-reduce(input), channel_id=1, replica_groups={{0,1,2,3}, {5,6,7,8}}, to_apply=add
}
// CHECK-LABEL: func private @test_all_reduce_global
// CHECK-SAME: ([[INPUT:%.*]]: tensor<8xf32>)
%test_all_reduce_global {
input = f32[8] parameter(0)
// CHECK-NEXT: "mhlo.all_reduce"([[INPUT]])
// CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
// CHECK: mhlo.return [[ADD]] : tensor<f32>
// CHECK: }) {
// CHECK-SAME: channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3], [5, 6, 7, 8]]> : tensor<2x4xi64>
// CHECK-SAME: use_global_device_ids
ROOT result = f32[8] all-reduce(input), channel_id=1, replica_groups={{0,1,2,3}, {5,6,7,8}}, use_global_device_ids=true, to_apply=add
}
// CHECK-LABEL: func private @test_and
%test_and (Arg_0.1: pred[4], Arg_1.2: pred[4]) -> pred[4] {
%Arg_0.1 = pred[4] parameter(0)
%Arg_1.2 = pred[4] parameter(1)
// CHECK-NEXT: mhlo.and %arg0, %arg1
ROOT %and.3 = pred[4] and(pred[4] %Arg_0.1, pred[4] %Arg_1.2)
}
// CHECK-LABEL: func private @test_atan2
// CHECK-SAME: ([[VAL_0:%.*]]: tensor<4xi32>, [[VAL_1:%.*]]: tensor<4xi32>) -> tensor<4xi32>
%test_atan2 (Arg_0.1: s32[4], Arg_1.2: s32[4]) -> s32[4] {
%Arg_0.1 = s32[4] parameter(0)
%Arg_1.2 = s32[4] parameter(1)
// CHECK: mhlo.atan2 [[VAL_0]], [[VAL_1]]
ROOT %atan2 = s32[4] atan2(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
}
// CHECK-LABEL: func private @test_broadcast_in_dim
%test_broadcast_in_dim {
%Arg_0.1 = f32[1, 2] parameter(0)
// CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<1x2x3xf32>
%broadcast.2 = f32[1,2,3] broadcast(%Arg_0.1), dimensions={0,1}
// CHECK-NEXT: "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<3x1x2xf32>
ROOT broadcast.4 = f32[3,1,2] broadcast(%Arg_0.1), dimensions={1, 2}
}
// CHECK-LABEL: func private @test_batch_norm_grad
%test_batch_norm_grad (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], variance: f32[2], grad_output: f32[2,2,2,2]) -> (f32[2,2,2,2], f32[2], f32[2]) {
%input = f32[2,2,2,2] parameter(0)
%scale = f32[2] parameter(1)
%mean = f32[2] parameter(2)
%variance = f32[2] parameter(3)
%grad_output = f32[2,2,2,2] parameter(4)
// CHECK: %[[GRAD_OPERAND:.+]], %[[GRAD_SCALE:.+]], %[[GRAD_OFFSET:.+]] = "mhlo.batch_norm_grad"
// CHECK-SAME: epsilon = 1.000000e-03 : f32
// CHECK-SAME: feature_index = 1 : i64
// CHECK: %[[TUPLE:.+]] = "mhlo.tuple"
// CHECK: return %[[TUPLE]]
ROOT %batch-norm-grad = (f32[2,2,2,2], f32[2], f32[2]) batch-norm-grad(f32[2,2,2,2] %input, f32[2] %scale, f32[2] %mean, f32[2] %variance, f32[2,2,2,2] %grad_output), epsilon=0.001, feature_index=1
}
// CHECK-LABEL: func private @test_batch_norm_train
%test_batch_norm_train (input: f32[2,2,2,2], scale: f32[2], offset: f32[2]) -> (f32[2,2,2,2], f32[2], f32[2]) {
%input = f32[2,2,2,2] parameter(0)
%scale = f32[2] parameter(1)
%offset = f32[2] parameter(2)
// CHECK: %[[OUT:.+]], %[[MEAN:.+]], %[[VAR:.+]] = "mhlo.batch_norm_training"
// CHECK-SAME: epsilon = 1.000000e-03 : f32
// CHECK-SAME: feature_index = 1 : i64
// CHECK: %[[TUPLE:.+]] = "mhlo.tuple"
// CHECK: return %[[TUPLE]]
ROOT %batch-norm-train = (f32[2,2,2,2], f32[2], f32[2]) batch-norm-training(f32[2,2,2,2] %input, f32[2] %scale, f32[2] %offset), epsilon=0.001, feature_index=1
}
// CHECK-LABEL: func private @call(%arg0: tensor<i64>) -> tensor<i64>
%call (arg_1: s64[]) -> s64[] {
%arg_1 = s64[] parameter(0), metadata={op_name="HLO_Args"}
ROOT %compare.2 = s64[] add(%arg_1, %arg_1), metadata={op_type="Less" op_name="Less"}
}
// CHECK-LABEL: func private @test_call
%test_call (arg0.1: s64[]) -> s64[] {
%arg0.1 = s64[] parameter(0), metadata={op_name="HLO_Args"}
// CHECK-NEXT: call @call(%arg0) : (tensor<i64>) -> tensor<i64>
ROOT %call.2 = s64[] call(%arg0.1), to_apply=%call
}
// CHECK-LABEL: func private @test_cholesky
// CHECK-SAME: ([[ARG:%.*]]: tensor<1x291x291xf32>) -> tensor<1x291x291xf32>
%test_cholesky (a: f32[1,291,291]) -> f32[1,291,291] {
%a = f32[1,291,291] parameter(0)
// CHECK-NEXT: "mhlo.cholesky"([[ARG]]) {lower = true} : (tensor<1x291x291xf32>) -> tensor<1x291x291xf32>
ROOT %out = f32[1,291,291] cholesky(f32[1,291,291] %a), lower=true
}
// CHECK-LABEL: func private @test_clamp(
%test_clamp (Arg_0.1: f32[], Arg_1.2: f32[4], Arg_2.3: f32[]) -> f32[4] {
%Arg_0.1 = f32[] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
%Arg_2.3 = f32[] parameter(2)
// CHECK-NEXT: mhlo.clamp %arg0, %arg1, %arg2 : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
ROOT %clamp.3 = f32[4] clamp(f32[] %Arg_0.1, f32[4] %Arg_1.2, f32[] %Arg_2.3)
}
// CHECK-LABEL: func private @test_collective_permute
// CHECK-SAME: ([[ARG:%.*]]: tensor<128x32xf32>) -> tensor<128x32xf32>
%test_collective_permute (input: f32[128,32]) -> f32[128,32] {
%input = f32[128,32]{1,0} parameter(0)
// CHECK-NEXT: "mhlo.collective_permute"([[ARG]]) {source_target_pairs = dense<{{\[\[}}0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>} : (tensor<128x32xf32>) -> tensor<128x32xf32>
ROOT root = f32[128,32]{1,0} collective-permute(%input), source_target_pairs={{0,1},{1,2},{2,3}}
}
// CHECK-LABEL: func private @test_compare(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<3xf32>) -> tensor<3xi1>
%test_compare (Arg_0.1: f32[3], Arg_1.2: f32[3], Arg_2.3: f32[3]) -> pred[3] {
%Arg_0.1 = f32[3] parameter(0)
%Arg_1.2 = f32[3] parameter(1)
%Arg_2.3 = f32[3] parameter(2)
// CHECK-NEXT: mhlo.compare EQ, %arg0, %arg1 : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
%compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ, type=FLOAT
// CHECK-NEXT: mhlo.compare LE, %arg0, %arg1, TOTALORDER : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
%compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE, type=TOTALORDER
// Requires broadcast of compatible tensors.
// CHECK-NEXT: mhlo.compare GT, %arg0, %arg2 : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
ROOT %compare.6 = pred[3] compare(Arg_0.1, Arg_2.3), direction=GT
}
// CHECK-LABEL: func private @test_complex
%test_complex (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> c64[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: mhlo.complex(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
ROOT %complex.3 = c64[4] complex(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
// CHECK-LABEL: func private @test_concat(%arg0: tensor<4x1xf32>, %arg1: tensor<4x2xf32>) -> tensor<4x3xf32>
%test_concat (Arg_0.1: f32[4, 1], Arg_1.2: f32[4, 2]) -> f32[4, 3] {
%Arg_0.1 = f32[4, 1] parameter(0)
%Arg_1.2 = f32[4, 2] parameter(1)
// CHECK-NEXT: "mhlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<4x1xf32>, tensor<4x2xf32>) -> tensor<4x3xf32>
ROOT %concatenate.3 = f32[4, 3] concatenate(f32[4, 1] %Arg_0.1, f32[4, 2] %Arg_1.2), dimensions={1}
}
// CHECK-LABEL: func private @test_constant
%test_constant {
// Scalar/0D tensor constant
// CHECK-NEXT: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<i64>
%constant.0 = s64[] constant(1)
// Note that double brackets "[[" have to be escaped as they denote variables
// in FileCheck. The only way to do so is to drop into regex with "{{"
// CHECK-NEXT: %[[VAL_1:.*]] = mhlo.constant dense<{{\[\[}}{{\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[}}[3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32>
%constant.1 = f32[2,2,1,1]{3,2,1,0} constant({{{{1.0}},{{2.0}}},{{{3.0}},{{4.0}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<[1, 2, 4, 8]> : tensor<4xui64>
%constant.2 = u64[4] constant({ 1, 2, 4, 8 })
// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<4xbf16>
%constant.3 = bf16[4] constant({1, 2, 3, 4})
// CHECK: %[[VAL_4:.*]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f32>>
%constant.4 = c64[] constant((1, 0))
// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<(1.000000e+00,0.000000e+00)> : tensor<complex<f64>>
%constant.5 = c128[] constant((1, 0))
// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<[1.000000e+00, -4.000000e+00, -6.550400e+04, 1.562500e-02]> : tensor<4xf16>
ROOT %constant.6 = f16[4] constant({1, -4, -65504, 0.015625})
}
// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
// implementations with attributes, etc.
// CHECK-LABEL: func private @test_conv(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<256x32x32x1xf32>) -> tuple<tensor<256x30x30x1xf32>> {
%test_conv {
%arg0.1 = f32[256,32,32,1]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
// CHECK-NEXT: %[[VAL_1:.*]] = mhlo.copy %[[VAL_0]] {xla_shape = "f32[256,32,32,1]{2,1,3,0}"} : (tensor<256x32x32x1xf32>) -> tensor<256x32x32x1xf32>
%copy.1 = f32[256,32,32,1]{2,1,3,0} copy(%arg0.1), metadata={op_name="HLO_Args"}
// CHECK-NEXT: %[[VAL_2:.*]] = mhlo.reshape %[[VAL_1]] {xla_shape = "f32[256,32,32,1]{2,1,3,0}"} : (tensor<256x32x32x1xf32>) -> tensor<256x32x32x1xf32>
%reshape.2 = f32[256,32,32,1]{2,1,3,0} reshape(%copy.1)
// Note that double brackets "[[" have to be escaped as they denote variables
// in FileCheck. The only way to do so is to drop into regex with "{{"
// CHECK-NEXT: %[[VAL_3:.*]] = mhlo.constant dense<{{\[\[}}{{\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[}}[3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32>
%constant.3 = f32[2,2,1,1]{3,2,1,0} constant({{{{0.5}}, {{-0.6}}}, {{{0.3}}, {{-0.1}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
// CHECK-NEXT: %[[VAL_4:.*]] = mhlo.convolution(%[[VAL_2]], %[[VAL_3]])
// CHECK-SAME: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[f, 0, 1, b]
// CHECK-SAME{LITERAL}: window = {stride = [4, 5], pad = [[44, 45], [60, 60]], lhs_dilate = [1, 1], rhs_dilate = [2, 3], reverse = [0, 0]}
// CHECK-SAME: feature_group_count = 1 : i64
// CHECK-SAME: precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
// CHECK-SAME: (tensor<256x32x32x1xf32>, tensor<2x2x1x1xf32>) -> tensor<1x30x30x256xf32>
%convolution.4 = f32[1,30,30,256]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=3x3 stride=4x5 pad=44_45x60_60 rhs_dilate=2x3}, dim_labels=b01f_01io->f01b, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
// CHECK-NEXT: %[[VAL_5:.*]] = mhlo.reshape %[[VAL_4]] : (tensor<1x30x30x256xf32>) -> tensor<256x30x30x1xf32>
%reshape.5 = f32[256,30,30,1]{3,2,1,0} reshape(%convolution.4), metadata={op_name="HLO_Retvals"}
// CHECK-NEXT: %[[VAL_6:.*]] = "mhlo.tuple"(%[[VAL_5]]) {xla_shape = {{.*}}} : (tensor<256x30x30x1xf32>) -> tuple<tensor<256x30x30x1xf32>>
ROOT %tuple.6 = (f32[256,30,30,1]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="HLO_Retvals"}
}
// Test for padding attribute shape in convolution
// CHECK-LABEL: func private @test_convolve1D_padding
%test_convolve1D_padding (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,5,1] {
%input = f32[1,2,1] parameter(0)
%filter = f32[1,1,1] parameter(1)
// CHECK: mhlo.convolution
// CHECK-SAME{LITERAL}: pad = [[1, 2]]
ROOT %convolution = f32[1,5,1] convolution(f32[1,2,1] %input, f32[1,1,1] %filter), feature_group_count=1, dim_labels=b0f_0io->b0f, window={pad=1_2 size=1}
}
// Test for window_reversal attribute in convolution
// CHECK-LABEL: func private @test_convolve1D_reversal
%test_convolve1D_reversal (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,5,1] {
%input = f32[1,2,1] parameter(0)
%filter = f32[1,1,1] parameter(1)
// CHECK: mhlo.convolution
// CHECK-SAME{LITERAL}: reverse = [1]
ROOT %convolution = f32[1,5,1] convolution(f32[1,2,1] %input, f32[1,1,1] %filter), feature_group_count=1, dim_labels=b0f_0io->b0f, window={pad=1_2 size=1 rhs_reversal=1}
}
// CHECK-LABEL: func private @test_convert(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf64>
%test_convert (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f64[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: %0 = mhlo.convert(%arg0) : (tensor<4xf32>) -> tensor<4xf64>
%convert.3 = f64[4] convert(f32[4] %Arg_0.1)
// CHECK-NEXT: %1 = mhlo.convert(%arg1) : (tensor<4xf32>) -> tensor<4xf64>
%convert.4 = f64[4] convert(f32[4] %Arg_1.2)
// CHECK-NEXT: mhlo.add %0, %1
ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[4] %convert.4)
}
// CHECK-LABEL: func private @test_cosine(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
%test_cosine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
%arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
// CHECK-NEXT: mhlo.cosine %arg0 : tensor<1x16x16x3xf32>
ROOT %cosine.3 = f32[1,16,16,3]{3,2,1,0} cosine(f32[1,16,16,3]{3,2,1,0} %arg0.1)
}
// CHECK-LABEL: func private @test_custom_call
// CHECK-SAME: [[ARG_0:%.*]]: tensor<2x3xf32>, [[ARG_1:%.*]]: tensor<5x5xf32>) -> tensor<1x2x3xf32>
%test_custom_call (arg1: f32[2,3], arg2: f32[5,5]) -> f32[1,2,3] {
%arg1 = f32[2,3] parameter(0)
%arg2 = f32[5,5] parameter(1)
// CHECK: "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {
// CHECK-SAME: api_version = 1 : i32
// CHECK-SAME: backend_config = "bar"
// CHECK-SAME: call_target_name = "foo"
// CHECK-SAME: has_side_effect = true
// CHECK-SAME: : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32>
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[2,3] %arg1, f32[5,5] %arg2), custom_call_target="foo", backend_config="bar", custom_call_has_side_effect=true
}
// CHECK-LABEL: func private @test_custom_call_layout
// CHECK-SAME: [[ARG_0:%.*]]: tensor<2x3xf32>, [[ARG_1:%.*]]: tensor<5x5xf32>, [[ARG_2:%.*]]: !mhlo.token, [[ARG_3:%.*]]: tensor<i32>) -> tensor<1x2x3xf32>
%test_custom_call_layout (arg1: f32[2,3], arg2: f32[5,5]) -> f32[1,2,3] {
%arg1 = f32[2,3] parameter(0)
%arg2 = f32[5,5] parameter(1)
%arg3 = token[] parameter(2)
%arg4 = s32[] parameter(3)
// CHECK: "mhlo.custom_call"([[ARG_0]], [[ARG_1]], [[ARG_2]], [[ARG_3]]) {
// CHECK-SAME: api_version = 1 : i32
// CHECK-SAME: backend_config = "bar"
// CHECK-SAME: call_target_name = "foo"
// CHECK-SAME: has_side_effect = true
// CHECK-SAME: operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>]
// CHECK-SAME: result_layouts = [dense<[0, 2, 1]> : tensor<3xindex>]
// CHECK-SAME: : (tensor<2x3xf32>, tensor<5x5xf32>, !mhlo.token, tensor<i32>) -> tensor<1x2x3xf32>
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[2,3] %arg1, f32[5,5] %arg2, token[] %arg3, s32[] %arg4), custom_call_target="foo", backend_config="bar", custom_call_has_side_effect=true, operand_layout_constraints={f32[2,3]{0,1}, f32[5,5]{1,0}, token[], s32[]}
}
// CHECK-LABEL: func private @test_custom_call_tuple_output
// CHECK-SAME: [[ARG_0:%.*]]: tensor<2x3xf32>, [[ARG_1:%.*]]: tensor<5x5xf32>) -> tuple<tensor<1x2x3xf32>, tensor<3x7x9xi32>>
%test_custom_call_tuple_output (arg1: f32[2,3], arg2: f32[5,5]) -> (f32[1,2,3], s32[3,7,9]) {
%arg1 = f32[2,3] parameter(0)
%arg2 = f32[5,5] parameter(1)
// CHECK: "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {
// CHECK-SAME: api_version = 1 : i32
// CHECK-SAME: backend_config = "bar"
// CHECK-SAME: call_target_name = "foo"
// CHECK-SAME: has_side_effect = true
// CHECK-SAME: operand_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>]
// CHECK-SAME: result_layouts = [dense<[0, 2, 1]> : tensor<3xindex>, dense<[2, 0, 1]> : tensor<3xindex>]
// CHECK-SAME: : (tensor<2x3xf32>, tensor<5x5xf32>) -> tuple<tensor<1x2x3xf32>, tensor<3x7x9xi32>>
ROOT %custom-call = (f32[1,2,3]{0,2,1}, s32[3,7,9]{2,0,1}) custom-call(f32[2,3] %arg1, f32[5,5] %arg2), custom_call_target="foo", backend_config="bar", custom_call_has_side_effect=true, operand_layout_constraints={f32[2,3]{0,1}, f32[5,5]{1,0}}
}
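// Test custom-call with called computations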
// CHECK-LABEL: func private @custom_call_computation_0
%custom_call_computation_0 (arg_1: s64[]) -> s64[] {
%arg_1 = s64[] parameter(0), metadata={op_name="HLO_Args"}
ROOT %compare.2 = s64[] add(%arg_1, %arg_1), metadata={op_type="Less" op_name="Less"}
}
// CHECK-LABEL: func private @custom_call_computation_1
%custom_call_computation_1 (arg_1: s64[]) -> s64[] {
%arg_1 = s64[] parameter(0), metadata={op_name="HLO_Args"}
ROOT %compare.2 = s64[] add(%arg_1, %arg_1), metadata={op_type="Less" op_name="Less"}
}
// CHECK-LABEL: func private @test_custom_call_with_computations
// CHECK-SAME: [[ARG_0:%.*]]: tensor<2x3xf32>, [[ARG_1:%.*]]: tensor<5x5xf32>) -> tensor<1x2x3xf32>
%test_custom_call_with_computations (arg1: f32[2,3], arg2: f32[5,5]) -> f32[1,2,3] {
%arg1 = f32[2,3] parameter(0)
%arg2 = f32[5,5] parameter(1)
// CHECK: "mhlo.custom_call"([[ARG_0]], [[ARG_1]]) {
// CHECK-SAME: api_version = 1 : i32
// CHECK-SAME: call_target_name = "foo"
// CHECK-SAME: called_computations = [@custom_call_computation_0, @custom_call_computation_1]
// CHECK-SAME: : (tensor<2x3xf32>, tensor<5x5xf32>) -> tensor<1x2x3xf32>
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[2,3] %arg1, f32[5,5] %arg2), custom_call_target="foo", called_computations={%custom_call_computation_0, %custom_call_computation_1}
}
// CHECK-LABEL: func private @test_div(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
%test_div (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: mhlo.divide %arg0, %arg1 : tensor<4xf32>
ROOT %divide.3 = f32[4] divide(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
// CHECK-LABEL: func private @test_dot(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<1x1xf32>
%test_dot (Arg_0.1: f32[1, 4], Arg_1.2: f32[4, 1]) -> f32[1, 1] {
%Arg_0.1 = f32[1, 4] parameter(0)
%Arg_1.2 = f32[4, 1] parameter(1)
// CHECK-NEXT: %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = [#mhlo<precision HIGH>, #mhlo<precision HIGHEST>]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<1x1xf32>
dot.3 = f32[1, 1] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={high,highest}
// CHECK-NEXT: %1 = "mhlo.dot"(%arg0, %arg1) {precision_config = [#mhlo<precision HIGHEST>, #mhlo<precision DEFAULT>]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<1x1xf32>
dot.4 = f32[1, 1] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,default}
// CHECK-NEXT: %2 = "mhlo.dot"(%arg0, %arg1) {precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<1x1xf32>
%dot.5 = f32[1, 1] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={default,default}
// TODO(b/129709049) consider making this default precision config inferred.
// CHECK-NEXT: "mhlo.dot"(%arg0, %arg1) {precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<1x1xf32>
ROOT %dot.6 = f32[1, 1] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
// CHECK-LABEL: @test_dot_general
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]]
%test_dot_general (Arg_0.1: f32[4, 1], Arg_1.2: f32[1, 4]) -> f32[] {
%Arg_0.1 = f32[4, 1] parameter(0)
%Arg_1.2 = f32[1, 4] parameter(1)
// CHECK-NEXT: [[R0:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]])
// CHECK-NOT: lhs_batching_dimensions
// CHECK-NOT: rhs_batching_dimensions
// CHECK-SAME: lhs_contracting_dimensions = [0]
// CHECK-SAME: rhs_contracting_dimensions = [1]
// CHECK-SAME: precision_config = [#mhlo<precision HIGH>, #mhlo<precision HIGHEST>]
dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={high,highest}
// CHECK-NEXT: [[R1:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]])
// CHECK-SAME: precision_config = [#mhlo<precision HIGHEST>, #mhlo<precision DEFAULT>]
dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={highest,default}
// CHECK-NEXT: [[R2:%.+]] = "mhlo.dot_general"([[ARG0]], [[ARG1]])
// CHECK-SAME: precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]
%dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}, operand_precision={default,default}
// TODO(b/129709049) consider making this default precision config inferred.
// CHECK-NEXT: "mhlo.dot_general"([[ARG0]], [[ARG1]])
// CHECK-SAME: precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]}
%dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}
// CHECK: [[RESHAPE:%[a-zA-Z0-9]+]] = mhlo.reshape
// CHECK-NEXT: "mhlo.dot_general"([[RESHAPE]], [[ARG1]])
// CHECK-SAME: lhs_contracting_dimensions = [0]
// CHECK-SAME: rhs_contracting_dimensions = [1]
// CHECK-SAME: precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]}
reshape.0 = f32[4]{0} reshape(f32[4, 1] Arg_0.1)
ROOT %dot.7 = f32[] dot(%reshape.0, Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={1}
}
// CHECK-LABEL: func private @test_dynamic_slice
// CHECK-SAME: [[OPERAND:%.*]]: tensor<2x2x258xi32>, [[START_IDX_1:%.*]]: tensor<i32>, [[START_IDX_2:%.*]]: tensor<i32>, [[START_IDX_3:%.*]]: tensor<i32>
%test_dynamic_slice (operand: s32[2,2,258], start_idx_1: s32[], start_idx_2: s32[], start_idx_3: s32[]) -> s32[1,1,32] {
%operand = s32[2,2,258] parameter(0)
%start_idx_1 = s32[] parameter(1)
%start_idx_2 = s32[] parameter(2)
%start_idx_3 = s32[] parameter(3)
// CHECK: "mhlo.dynamic_slice"([[OPERAND]], [[START_IDX_1]], [[START_IDX_2]], [[START_IDX_3]])
// CHECK-SAME: slice_sizes = dense<[1, 1, 32]> : tensor<3xi64>
ROOT %dynamic-slice = s32[1,1,32] dynamic-slice(s32[2,2,258] %operand, s32[] %start_idx_1, s32[] %start_idx_2, s32[] %start_idx_3), dynamic_slice_sizes={1,1,32}
}
// CHECK-LABEL: func private @test_dynamic_update_slice_1(%arg0: tensor<4x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<4x4xf32>
%test_dynamic_update_slice_1 (Arg_0.1: f32[4, 4], Arg_1.2: f32[1, 4], Arg_2.3: s32[], Arg_3.4: s32[]) -> f32[4, 4] {
%Arg_0.1 = f32[4, 4] parameter(0)
%Arg_1.2 = f32[1, 4] parameter(1)
%Arg_2.3 = s32[] parameter(2)
%Arg_3.4 = s32[] parameter(3)
// CHECK-NEXT: mhlo.dynamic_update_slice %arg0, %arg1, %arg2, %arg3 : (tensor<4x4xf32>, tensor<1x4xf32>, tensor<i32>, tensor<i32>) -> tensor<4x4xf32>
ROOT %dynamic-update-slice.5 = f32[4, 4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4)
}
// CHECK-LABEL: func private @test_dynamic_update_slice_2(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<i32>) -> tensor<4xf32>
%test_dynamic_update_slice_2 (Arg_0.1: f32[4], Arg_1.2: f32[2], Arg_2.3: s32[]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[2] parameter(1)
%Arg_2.3 = s32[] parameter(2)
// CHECK-NEXT: mhlo.dynamic_update_slice %arg0, %arg1, %arg2 : (tensor<4xf32>, tensor<2xf32>, tensor<i32>) -> tensor<4xf32>
ROOT %dynamic-update-slice.5 = f32[4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3)
}
// CHECK-LABEL: func private @test_exponential(%arg0: tensor<16xf32>) -> tensor<16xf32>
%test_exponential (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
// CHECK-NEXT: mhlo.exponential %arg0 : tensor<16xf32>
ROOT %exp.2 = f32[16] exponential(f32[16] %arg0.1)
}
// CHECK-LABEL: func private @test_expm1(%arg0: tensor<16xf32>) -> tensor<16xf32>
%test_expm1 (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
// CHECK: mhlo.exponential_minus_one %arg0 : tensor<16xf32>
ROOT %expm1.2 = f32[16] exponential-minus-one(f32[16] %arg0.1)
}
// CHECK-LABEL: func private @test_fft(%arg0: tensor<3x9xf32>) -> tensor<3x5xcomplex<f32>>
%test_fft {
%arg0.1 = f32[3,9]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
// CHECK: "mhlo.fft"(%arg0) {fft_length = dense<9> : tensor<1xi64>, fft_type = #mhlo<fft_type RFFT>
ROOT %fft.2 = c64[3,5]{1,0} fft(%arg0.1), fft_type=RFFT, fft_length={9}, metadata={op_type="RFFT" op_name="rfft"}
}
// CHECK-LABEL: func private @test_floor(
// CHECK-SAME: [[A0:%.+]]: tensor<16xf32>) -> tensor<16xf32>
%test_floor (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
// CHECK-NEXT: mhlo.floor [[A0]] : tensor<16xf32>
ROOT %floor.2 = f32[16] floor(f32[16] %arg0.1)
}
// CHECK-LABEL: func private @test_gather(
// CHECK-SAME: [[ARG0:%.+]]: tensor<200x100x300xf32>, [[ARG1:%.+]]: tensor<10x2xi32>) -> tensor<10x300xf32>
%test_gather (arg.0: f32[200,100,300], arg.1: s32[10,2]) -> f32[10,300] {
%arg.0 = f32[200,100,300] parameter(0)
%arg.1 = s32[10,2] parameter(1)
// CHECK: "mhlo.gather"([[ARG0]], [[ARG1]])
// CHECK-SAME: dimension_numbers
// CHECK-SAME: offset_dims = [1]
// CHECK-SAME: collapsed_slice_dims = [0, 1]
// CHECK-SAME: start_index_map = [0, 1]
// CHECK-SAME: index_vector_dim = 1
// CHECK-SAME: indices_are_sorted = true
// CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64>
ROOT gather = f32[10,300] gather(f32[200,100,300] %arg.0, s32[10,2] %arg.1),
collapsed_slice_dims={0,1},
index_vector_dim=1,
offset_dims={1},
start_index_map={0,1},
indices_are_sorted=true,
slice_sizes={1,1,300}
}
// CHECK-LABEL: func private @test_get_dimension_size
// CHECK-SAME: ([[ARG:%.*]]: tensor<4x2xf32>)
%test_get_dimension_size (Arg_0.1: f32[4,2]) -> s32[] {
%Arg_0.1 = f32[4,2] parameter(0)
// CHECK-NEXT: "mhlo.get_dimension_size"([[ARG]]) {dimension = 1 : i64} : (tensor<4x2xf32>) -> tensor<i32>
ROOT %get-dimension-size.2 = s32[] get-dimension-size(f32[4,2] %Arg_0.1), dimensions={1}
}
// CHECK-LABEL: func private @test_imag
%test_imag (Arg_0.1: c64[4]) -> f32[4] {
%Arg_0.1 = c64[4] parameter(0)
// CHECK-NEXT: mhlo.imag(%arg0) : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
ROOT %imag.3 = f32[4] imag(c64[4] %Arg_0.1)
}
// CHECK-LABEL: func private @test_infeed
// CHECK-SAME: ([[TOKEN:%.*]]: !mhlo.token) -> tuple<tensor<3x3x3xi32>, !mhlo.token>
%test_infeed (token0: token[]) -> (s32[3, 3, 3], token[]) {
%token0 = token[] parameter(0)
// CHECK-NEXT: "mhlo.infeed"([[TOKEN]])
// CHECK-SAME: infeed_config = "foobar"
// CHECK-SAME: layout = {{\[\[2, 0, 1]]}}
ROOT %infeed = (s32[3, 3, 3]{2, 0, 1}, token[]) infeed(token[] %token0), infeed_config="foobar"
}
// CHECK-LABEL: func private @test_infeed_with_empty_tuple_data
// CHECK-SAME: ([[TOKEN:%.*]]: !mhlo.token) -> !mhlo.token
// CHECK-NEXT: [[INFEED:%.*]] = "mhlo.infeed"([[TOKEN]])
// CHECK-SAME: infeed_config = "foobar"
// CHECK-SAME: layout = []
// CHECK: return [[INFEED]] : !mhlo.token
%test_infeed_with_empty_tuple_data (Arg_0.1: token[]) -> token[] {
%Arg_0.1 = token[] parameter(0)
%infeed.2 = ((), token[]) infeed(token[] %Arg_0.1), infeed_config="foobar"
%get-tuple-element.3 = () get-tuple-element(((), token[]) %infeed.2), index=0
ROOT %get-tuple-element.4 = token[] get-tuple-element(((), token[]) %infeed.2), index=1
}
// CHECK-LABEL: func private @test_infeed_layout
// CHECK-SAME: ([[TOKEN:%.*]]: !mhlo.token) -> tuple<tuple<tensor<3x4xi32>, tensor<5x6xi32>, tensor<7x8xi32>, tensor<9x10xi32>>, !mhlo.token>
%test_infeed_layout (token0: token[]) -> ( ( s32[3, 4]{1, 0}, s32[5, 6]{1, 0}, s32[7, 8]{1, 0}, s32[9,10]{1, 0}), token[]) {
%token0 = token[] parameter(0)
// CHECK-NEXT: [[INFEED:%.*]]:5 = "mhlo.infeed"([[TOKEN]])
// CHECK-SAME: infeed_config = "foobar"
// CHECK-SAME: layout = {{\[\[}}1, 0], [1, 0], [1, 0], [1, 0]]
// CHECK-NEXT: [[T1:%.*]] = "mhlo.tuple"([[INFEED]]#0, [[INFEED]]#1, [[INFEED]]#2, [[INFEED]]#3)
// CHECK-NEXT: [[T2:%.*]] = "mhlo.tuple"([[T1]], [[INFEED]]#4)
// CHECK-NEXT: return [[T2]]
ROOT %infeed = ( ( s32[3, 4]{1, 0}, s32[5, 6]{1, 0}, s32[7, 8]{1, 0}, s32[9,10]{1, 0}), token[]) infeed(token[] %token0), infeed_config="foobar"
}
// CHECK-LABEL: func private @test_iota_1() -> tensor<4xf32>
%test_iota_1 () -> f32[4] {
// CHECK-NEXT: "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
ROOT %iota.0 = f32[4] iota(), iota_dimension=0
}
// CHECK-LABEL: func private @test_iota_2() -> tensor<4x5xf32>
%test_iota_2 () -> f32[4, 5] {
// CHECK-NEXT: "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4x5xf32>
ROOT %iota.0 = f32[4, 5] iota(), iota_dimension=1
}
// CHECK-LABEL: func private @test_log(%arg0: tensor<16xf32>) -> tensor<16xf32>
%test_log (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
// CHECK-NEXT: mhlo.log %arg0 : tensor<16xf32>
ROOT %log.2 = f32[16] log(f32[16] %arg0.1)
}
// CHECK-LABEL: func private @test_log1p(%arg0: tensor<16xf32>) -> tensor<16xf32>
%test_log1p (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
// CHECK: mhlo.log_plus_one %arg0 : tensor<16xf32>
ROOT %log1p.2 = f32[16] log-plus-one(f32[16] %arg0.1)
}
// Test mhlo.map
%map_computation {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
// CHECK-LABEL: func private @test_map
// CHECK-SAME: [[ARG_0:%.*]]: tensor<4xf32>, [[ARG_1:%.*]]: tensor<4xf32>) -> tensor<4xf32>
%test_map {
param0 = f32[4]{0} parameter(0)
param1 = f32[4]{0} parameter(1)
// CHECK: "mhlo.map"([[ARG_0]], [[ARG_1]]) ({
// CHECK: ^bb0([[ARG_2:%.*]]: tensor<f32>, [[ARG_3:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = mhlo.add [[ARG_2]], [[ARG_3]]
// CHECK: mhlo.return [[ADD]] : tensor<f32>
// CHECK: }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=%map_computation
}
// Test mhlo.map
%map_computation_returning_tuple {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
add = f32[] add(lhs, rhs)
ROOT tple = (f32[]) tuple(add)
}
// CHECK-LABEL: func private @test_map_with_reducer_returning_tuple
// CHECK-SAME: [[ARG_0:%.*]]: tensor<4xf32>, [[ARG_1:%.*]]: tensor<4xf32>) -> tensor<4xf32>
%test_map_with_reducer_returning_tuple {
param0 = f32[4]{0} parameter(0)
param1 = f32[4]{0} parameter(1)
// CHECK: "mhlo.map"([[ARG_0]], [[ARG_1]]) ({
// CHECK: ^bb0([[ARG_2:%.*]]: tensor<f32>, [[ARG_3:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = mhlo.add [[ARG_2]], [[ARG_3]]
// CHECK: mhlo.return [[ADD]] : tensor<f32>
// CHECK: }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=%map_computation_returning_tuple
}
// Test mhlo.map with heterogeneous inputs
%map_computation_take_left {
ROOT %Arg_0.4 = f32[] parameter(0)
%Arg_1.5 = s32[] parameter(1)
}
// CHECK-LABEL: func private @map_heterogeneous
// CHECK-SAME: [[ARG_0:%.*]]: tensor<4xf32>, [[ARG_1:%.*]]: tensor<4xi32>) -> tensor<4xf32>
%map_heterogeneous (Arg_0.1: f32[4], Arg_1.2: s32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = s32[4] parameter(1)
ROOT %map.6 = f32[4] map(f32[4] %Arg_0.1, s32[4] %Arg_1.2), dimensions={0}, to_apply=%map_computation_take_left
}
// CHECK: "mhlo.map"([[ARG_0]], [[ARG_1]]) ({
// CHECK: ^bb0([[ARG_2:%.*]]: tensor<f32>, [[ARG_3:%.*]]: tensor<i32>):
// CHECK: mhlo.return [[ARG_2]] : tensor<f32>
// CHECK: }) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<4xi32>) -> tensor<4xf32>
// CHECK-LABEL: func private @test_maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
%test_maximum (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: mhlo.maximum %arg0, %arg1 : tensor<4xf32>
ROOT %maximum.3 = f32[4] maximum(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
// CHECK-LABEL: func private @test_minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
%test_minimum (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: mhlo.minimum %arg0, %arg1 : tensor<4xf32>
ROOT %minimum.3 = f32[4] minimum(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
// CHECK-LABEL: func private @test_multiply(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
%test_multiply (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: %0 = mhlo.multiply %arg0, %arg1 : tensor<4xf32>
ROOT %multiply.3 = f32[4] multiply(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
// CHECK-LABEL: func private @test_negate(%arg0: tensor<16xf32>) -> tensor<16xf32>
%test_negate (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
// CHECK-NEXT: mhlo.negate %arg0 : tensor<16xf32>
ROOT %negate.2 = f32[16] negate(f32[16] %arg0.1)
}
// CHECK-LABEL: func private @test_not(%arg0: tensor<16xi1>) -> tensor<16xi1>
%test_not (arg0.1: pred[16]) -> pred[16] {
%arg0.1 = pred[16] parameter(0)
// CHECK: mhlo.not %arg0 : tensor<16xi1>
ROOT %not.2 = pred[16] not(pred[16] %arg0.1)
}
// CHECK-LABEL: func private @test_or
%test_or (Arg_0.1: pred[4], Arg_1.2: pred[4]) -> pred[4] {
%Arg_0.1 = pred[4] parameter(0)
%Arg_1.2 = pred[4] parameter(1)
// CHECK-NEXT: mhlo.or %arg0, %arg1
ROOT %or.3 = pred[4] or(pred[4] %Arg_0.1, pred[4] %Arg_1.2)
}
// CHECK-LABEL: func private @test_logistic(%arg0: tensor<16xf32>) -> tensor<16xf32>
%test_logistic (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
// CHECK-NEXT: mhlo.logistic %arg0 : tensor<16xf32>
ROOT %logistic.2 = f32[16] logistic(f32[16] %arg0.1)
}
// CHECK-LABEL: func private @test_outfeed
// CHECK-SAME: ([[DATA:%.*]]: tensor<3xi32>, [[TOKEN:%.*]]: !mhlo.token) -> !mhlo.token
%test_outfeed (Arg_0.1: s32[3], Arg_1.2: token[]) -> token[] {
%Arg_0.1 = s32[3] parameter(0)
%Arg_1.2 = token[] parameter(1)
// CHECK-NEXT: "mhlo.outfeed"([[DATA]], [[TOKEN]])
// CHECK-SAME: outfeed_config = "foobar"
ROOT %outfeed.3 = token[] outfeed(s32[3] %Arg_0.1, token[] %Arg_1.2), outfeed_config="foobar"
}
// CHECK-LABEL: func private @test_outfeed_with_sharding
// CHECK-SAME: ([[DATA:%.*]]: tensor<3xi32>, [[TOKEN:%.*]]: !mhlo.token) -> (!mhlo.token {mhlo.sharding = "\08\03\1A\02\02\01\22\02\00\01"})
%test_outfeed_with_sharding (Arg_0.1: s32[3], Arg_1.2: token[]) -> token[] {
%Arg_0.1 = s32[3] parameter(0)
%Arg_1.2 = token[] parameter(1)
// CHECK-NEXT: "mhlo.outfeed"([[DATA]], [[TOKEN]])
// CHECK-SAME: mhlo.sharding = "\08\03\1A\02\02\01\22\02\00\01"
// CHECK-SAME: outfeed_config = "foobar"
ROOT %outfeed.3 = token[] outfeed(s32[3] %Arg_0.1, token[] %Arg_1.2), outfeed_config="foobar", sharding={devices=[2,1]0,1}
}
// CHECK-LABEL: func private @test_outfeed_with_empty_data
// CHECK-SAME: ([[TOKEN:%.*]]: !mhlo.token) -> !mhlo.token
%test_outfeed_with_empty_data (Arg_0.1: token[]) -> token[] {
%tuple.2 = () tuple()
%Arg_0.1 = token[] parameter(0)
// CHECK-NEXT: "mhlo.outfeed"([[TOKEN]])
// CHECK-SAME: outfeed_config = "foobar"
// CHECK-SAME: xla_shape = "token[]"
// CHECK-SAME: : (!mhlo.token) -> !mhlo.token
ROOT %outfeed.3 = token[] outfeed(() %tuple.2, token[] %Arg_0.1), outfeed_shape=(), outfeed_config="foobar"
}
// CHECK-LABEL: func private @test_pad(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<4xf32>
%test_pad (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[] parameter(1)
// CHECK-NEXT: "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
ROOT %pad.3 = f32[4] pad(%Arg_0.1, %Arg_1.2), padding=0_0_0
}
// CHECK-LABEL: func private @test_pad_edge(%arg0: tensor<4x4x4xf32>, %arg1: tensor<f32>) -> tensor<7x11x15xf32>
%test_pad_edge (Arg_0.1: f32[4, 4, 4], Arg_1.2: f32[]) -> f32[7, 11, 15] {
%Arg_0.1 = f32[4, 4, 4] parameter(0)
%Arg_1.2 = f32[] parameter(1)
// CHECK-NEXT: "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<[2, 4, 6]> : tensor<3xi64>, edge_padding_low = dense<[1, 3, 5]> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<4x4x4xf32>, tensor<f32>) -> tensor<7x11x15xf32>
ROOT %pad.3 = f32[7, 11, 15] pad(%Arg_0.1, %Arg_1.2), padding=1_2x3_4x5_6
}
// CHECK-LABEL: func private @test_pad_interior(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<10xf32>
%test_pad_interior (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f32[10] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[] parameter(1)
// CHECK-NEXT: "mhlo.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<2> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<10xf32>
ROOT %pad.3 = f32[10] pad(%Arg_0.1, %Arg_1.2), padding=0_0_2
}
// CHECK-LABEL: func private @test_popcnt(%arg0: tensor<16xi32>) -> tensor<16xi32>
%test_popcnt (arg0.1: s32[16]) -> s32[16] {
%arg0.1 = s32[16] parameter(0)
// CHECK: mhlo.popcnt %arg0 : tensor<16xi32>
ROOT %popcnt.2 = s32[16] popcnt(s32[16] %arg0.1)
}
// CHECK-LABEL: func private @test_pow(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
%test_pow (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: mhlo.power %arg0, %arg1 : tensor<4xf32>
ROOT %power.3 = f32[4] power(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
// CHECK-LABEL: func private @test_rng_normal
// CHECK-SAME: ([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>) -> tensor<2x3x5xf32>
%test_rng_normal (Arg_0.1: f32[], Arg_1.2: f32[]) -> f32[2,3,5] {
%Arg_0.1 = f32[] parameter(0)
%Arg_1.2 = f32[] parameter(1)
// CHECK: [[CST:%.*]] = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64>
// CHECK: "mhlo.rng"([[ARG0]], [[ARG1]], [[CST]]) {rng_distribution = #mhlo.rng_distribution<NORMAL>}
ROOT %rng.4 = f32[2,3,5] rng(f32[] %Arg_0.1, f32[] %Arg_1.2), distribution=rng_normal
}
// CHECK-LABEL: func private @test_rng_uniform
// CHECK-SAME: ([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>) -> tensor<2x3x5xf32>
%test_rng_uniform (Arg_0.1: f32[], Arg_1.2: f32[]) -> f32[2,3,5] {
%Arg_0.1 = f32[] parameter(0)
%Arg_1.2 = f32[] parameter(1)
// CHECK: [[CST:%.*]] = mhlo.constant dense<[2, 3, 5]> : tensor<3xi64>
// CHECK: "mhlo.rng"([[ARG0]], [[ARG1]], [[CST]]) {rng_distribution = #mhlo.rng_distribution<UNIFORM>}
ROOT %rng.4 = f32[2,3,5] rng(f32[] %Arg_0.1, f32[] %Arg_1.2), distribution=rng_uniform
}
// CHECK-LABEL: func private @test_real
%test_real (Arg_0.1: c64[4]) -> f32[4] {
%Arg_0.1 = c64[4] parameter(0)
// CHECK-NEXT: mhlo.real(%arg0) : (tensor<4xcomplex<f32>>) -> tensor<4xf32>
ROOT %real.3 = f32[4] real(c64[4] %Arg_0.1)
}
// Test reduce
%reduce_helper.1 (Arg_0.1: f32[], Arg_1.2: f32[], Arg_2.3: f32[], Arg_3.4: f32[]) -> (f32[], f32[]) {
%Arg_0.1 = f32[] parameter(0)
%Arg_1.2 = f32[] parameter(1)
%Arg_2.3 = f32[] parameter(2)
%Arg_3.4 = f32[] parameter(3)
%add.4 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_2.3)
%add.5 = f32[] add(f32[] %Arg_1.2, f32[] %Arg_3.4)
ROOT %tuple.6 = (f32[], f32[]) tuple(%add.4, %add.5)
}
%reduce_helper.2 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
ROOT %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
%reduce_helper.3 (Arg_0.1: f32[], Arg_1.2: f32[]) -> f32[] {
%Arg_0.1 = f32[] parameter(0)
%Arg_1.2 = f32[] parameter(1)
ROOT %add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2)
}
// CHECK-LABEL: func private @test_reduce
// CHECK-SAME: ([[ARG0:%.*]]: tensor<4x4xf32>, [[ARG1:%.*]]: tensor<4xf32>, [[ARG2:%.*]]: tensor<f32>) -> tuple<tuple<tensor<f32>, tensor<f32>>, tensor<f32>>
%test_reduce (Arg_0.1: f32[4, 4], Arg_1.2: f32[4], Arg_2.3: f32[]) -> ((f32[], f32[]), f32[]) {
%Arg_0.1 = f32[4, 4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
%Arg_2.3 = f32[] parameter(2)
// CHECK: mhlo.reduce([[ARG0]] init: [[ARG2]]), ([[ARG0]] init: [[ARG2]])
// CHECK-SAME: dimensions = [0, 1]
// CHECK: %[[A:.*]] = mhlo.add{{.*}} : tensor<f32>
// CHECK: %[[B:.*]] = mhlo.add{{.*}} : tensor<f32>
// CHECK: mhlo.return %[[A]], %[[B]] : tensor<f32>, tensor<f32>
// CHECK: "mhlo.tuple"(%0#0, %0#1) {xla_shape = {{.*}}} : (tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
%reduce.1 = (f32[], f32[]) reduce(%Arg_0.1, %Arg_0.1, %Arg_2.3, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.1
// CHECK: [[VAL2:%.*]] = mhlo.reduce([[ARG0]] init: [[ARG2]])
// CHECK: mhlo.add{{.*}} : tensor<f32>
%reduce.3 = f32[] reduce(%Arg_0.1, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.3
// CHECK: [[VAL3:%.*]] = mhlo.reduce([[ARG0]] init: [[ARG1]])
// CHECK-SAME: dimensions = [0]
// CHECK: mhlo.add{{.*}} : tensor<4xf32>
%reduce.2 = f32[4] reduce(%Arg_0.1, %Arg_1.2), dimensions={0}, to_apply=%reduce_helper.2
// CHECK: [[VAL4:%.*]] = mhlo.reduce([[VAL3]] init: [[ARG2]])
// CHECK-SAME: dimensions = [0]
// CHECK: mhlo.add{{.*}} : tensor<f32>
%reduce.4 = f32[] reduce(%reduce.2, %Arg_2.3), dimensions={0}, to_apply=%reduce_helper.3
// CHECK: %5 = mhlo.subtract [[VAL2]], [[VAL4]] : tensor<f32>
%sub.5 = f32[] subtract(%reduce.3, %reduce.4)
ROOT %tuple.6 = ((f32[], f32[]), f32[]) tuple(%reduce.1, %sub.5)
}
// Test reduce-scatter
%reduce_helper_add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
%reduce_helper_add_returning_tuple {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
add = f32[] add(lhs, rhs)
ROOT tple = (f32[]) tuple(add)
}
// CHECK-LABEL: func private @test_reduce_scatter
// CHECK-SAME: ([[ARG0:%.*]]: tensor<4x8xf32>)
%test_reduce_scatter {
input = f32[4,8] parameter(0)
// CHECK-NEXT: "mhlo.reduce_scatter"([[ARG0]]) ({
// CHECK-NEXT: ^bb0([[BARG0:%.*]]: tensor<f32>, [[BARG1:%.*]]: tensor<f32>):
// CHECK-NEXT: [[ADD:%.*]] = mhlo.add [[BARG0]], [[BARG1]] : tensor<f32>
// CHECK-NEXT: mhlo.return [[ADD]]
// CHECK-NEXT: }) {
// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// CHECK-SAME: scatter_dimension = 1 : i64
// CHECK-SAME: } : (tensor<4x8xf32>) -> tensor<4x4xf32>
ROOT ars = f32[4,4] reduce-scatter(input), replica_groups={{0,1}}, dimensions={1}, to_apply=reduce_helper_add
}
// CHECK-LABEL: func private @test_reduce_scatter_with_region_returning_tuple
%test_reduce_scatter_with_region_returning_tuple {
input = f32[4,8] parameter(0)
// CHECK-NEXT: "mhlo.reduce_scatter"
// CHECK-NEXT: ^bb0
// CHECK-NEXT: [[ADD:%.*]] = mhlo.add
// CHECK-NEXT: mhlo.return [[ADD]]
ROOT ars = f32[4,4] reduce-scatter(input), replica_groups={{0,1}}, dimensions={1}, to_apply=reduce_helper_add_returning_tuple
}
// CHECK-LABEL: func private @test_reduce_scatter_with_channel
// CHECK-SAME: ([[ARG0:%.*]]: tensor<4x8xf32>)
%test_reduce_scatter_with_channel {
input = f32[4,8] parameter(0)
// CHECK-NEXT: "mhlo.reduce_scatter"([[ARG0]]) ({
// CHECK-NEXT: ^bb0([[BARG0:%.*]]: tensor<f32>, [[BARG1:%.*]]: tensor<f32>):
// CHECK-NEXT: [[ADD:%.*]] = mhlo.add [[BARG0]], [[BARG1]] : tensor<f32>
// CHECK-NEXT: mhlo.return [[ADD]]
// CHECK-NEXT: }) {
// CHECK-SAME: channel_handle = #mhlo.channel_handle<handle = 1, type = 0>
// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// CHECK-SAME: scatter_dimension = 1 : i64
// CHECK-SAME: } : (tensor<4x8xf32>) -> tensor<4x4xf32>
ROOT ars = f32[4,4] reduce-scatter(input), channel_id=1, replica_groups={{0,1}}, dimensions={1}, to_apply=reduce_helper_add
}
// CHECK-LABEL: func private @test_reduce_window
// CHECK-SAME: ([[ARG0:%.*]]: tensor<2x17x31x7xf32>, [[ARG1:%.*]]: tensor<f32>)
%test_reduce_window (Arg_0.1: f32[2,17,31,7], Arg_1.2: f32[]) -> f32[2,5,8,7] {
%Arg_0.1 = f32[2,17,31,7] parameter(0)
%Arg_1.2 = f32[] parameter(1)
// CHECK: "mhlo.reduce_window"([[ARG0]], [[ARG1]]) ({
// CHECK: mhlo.add {{.*}} : tensor<f32>
// CHECK: }) {
// CHECK-SAME: base_dilations = dense<1> : tensor<4xi64>
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64>
// CHECK-SAME: window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>
// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>
// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>
// CHECK-SAME: }
ROOT %reduce-window.1 = f32[2,5,8,7] reduce-window(f32[2,17,31,7] %Arg_0.1, f32[] %Arg_1.2), window={size=1x2x2x1 stride=1x4x4x1 pad=0_0x2_0x0_2x0_0 rhs_dilate=1x2x2x1}, to_apply=%reduce_helper.3
}
// Test reduce-window with multiple outputs
%reducer_window_helper (Arg0: f32[], Arg1: f32[], Arg2: f32[], Arg3: f32[]) -> (f32[], f32[]) {
%Arg0 = f32[] parameter(0)
%Arg1 = f32[] parameter(1)
%Arg2 = f32[] parameter(2)
%Arg3 = f32[] parameter(3)
%compare.11 = pred[] compare(f32[] %Arg0, f32[] %Arg2), direction=GE
%select.12 = f32[] select(pred[] %compare.11, f32[] %Arg0, f32[] %Arg2)
%select.13 = f32[] select(pred[] %compare.11, f32[] %Arg1, f32[] %Arg3)
ROOT %tuple.14 = (f32[], f32[]) tuple(f32[] %select.12, f32[] %select.13)
}
// CHECK-LABEL: func private @test_reduce_window_multiple_outputs
// CHECK-SAME: ([[ARG0:%.*]]: tensor<4x6xf32>, [[ARG1:%.*]]: tensor<4x6xf32>)
%test_reduce_window_multiple_outputs (Arg0: f32[4,6], Arg1: f32[4,6]) -> f32[4,3] {
%Arg0 = f32[4,6] parameter(0)
%Arg1 = f32[4,6] parameter(1)
// CHECK-DAG: [[CONST_FALSE:%.*]] = mhlo.constant dense<false>
// CHECK-DAG: [[CONST_INF:%.*]] = mhlo.constant dense<0xFF800000>
// CHECK-DAG: [[CONST_ZERO:%.*]] = mhlo.constant dense<0.000000e+00>
%constant.3 = pred[] constant(false)
%constant.4 = f32[] constant(-inf)
%constant.5 = f32[] constant(0)
// CHECK: [[REDUCE_WINDOW:%.*]]:2 = "mhlo.reduce_window"([[ARG1]], [[ARG0]], [[CONST_INF]], [[CONST_ZERO]]) ({
// CHECK: ^bb0([[BARG0:%.*]]: tensor<f32>, [[BARG1:%.*]]: tensor<f32>, [[BARG2:%.*]]: tensor<f32>, [[BARG3:%.*]]: tensor<f32>)
// CHECK: [[COMPARE:%.*]] = mhlo.compare GE, [[BARG0]], [[BARG2]]
// CHECK: [[SELECT_0:%.*]] = "mhlo.select"([[COMPARE]], [[BARG0]], [[BARG2]])
// CHECK: [[SELECT_1:%.*]] = "mhlo.select"([[COMPARE]], [[BARG1]], [[BARG3]])
// CHECK: }) {
// CHECK-SAME: base_dilations = dense<1> : tensor<2xi64>
// CHECK-SAME: padding = dense<0> : tensor<2x2xi64>
// CHECK-SAME: window_dilations = dense<1> : tensor<2xi64>
// CHECK-SAME: window_dimensions = dense<[1, 2]> : tensor<2xi64>
// CHECK-SAME: window_strides = dense<[1, 2]> : tensor<2xi64>
// CHECK-SAME: }
// CHECK-SAME: : (tensor<4x6xf32>, tensor<4x6xf32>, tensor<f32>, tensor<f32>) -> (tensor<4x3xf32>, tensor<4x3xf32>)
%reduce-window.15 = (f32[4,3], f32[4,3]) reduce-window(f32[4,6] %Arg1, f32[4,6] %Arg0, f32[] %constant.4, f32[] %constant.5), window={size=1x2 stride=1x2}, to_apply=%reducer_window_helper
// CHECK: return [[REDUCE_WINDOW]]#1 : tensor<4x3xf32>
ROOT %get-tuple-element.16 = f32[4,3] get-tuple-element((f32[4,3], f32[4,3]) %reduce-window.15), index=1
}
// CHECK-LABEL: func private @test_remainder
// CHECK-SAME: ([[VAL_0:%.*]]: tensor<4xf32>, [[VAL_1:%.*]]: tensor<4xf32>)
%test_remainder (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK: mhlo.remainder [[VAL_0]], [[VAL_1]]
ROOT %remainder.3 = f32[4] remainder(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
// CHECK-LABEL: func private @test_reverse_1d(%arg0: tensor<4xf32>) -> tensor<4xf32>
%test_reverse_1d (Arg_0.1: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
// CHECK-NEXT: "mhlo.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
ROOT reverse.2 = f32[4] reverse(%Arg_0.1), dimensions={0}
}
// CHECK-LABEL: func private @test_reverse_2d(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32>
%test_reverse_2d (Arg_0.1: f32[4, 4]) -> f32[4, 4] {
%Arg_0.1 = f32[4, 4] parameter(0)
// CHECK-NEXT: "mhlo.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4xf32>
ROOT reverse.2 = f32[4, 4] reverse(%Arg_0.1), dimensions={0, 1}
}
// CHECK-LABEL: func private @test_rsqrt(
// CHECK-SAME: [[ARG0:%.+]]: tensor<16xf32>) -> tensor<16xf32>
%test_rsqrt (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
// CHECK: mhlo.rsqrt [[ARG0]] : tensor<16xf32>
ROOT %rsqrt.2 = f32[16] rsqrt(f32[16] %arg0.1)
}
// CHECK-LABEL: func private @test_scalar(%arg0: tensor<f32>) -> tensor<f32>
%test_scalar (Arg_0.1: f32[]) -> f32[] {
// CHECK-NEXT: return %arg0 : tensor<f32>
ROOT %Arg_0.1 = f32[] parameter(0)
}
// Test scatter
%update_computation {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
ROOT %sum = f32[] add(f32[] %lhs, f32[] %rhs)
}
%test_scatter {
%input_tensor = f32[200,100,300] parameter(0)
%scatter_indices = s64[10,2] parameter(1)
%updates = f32[10,300] parameter(2)
ROOT %scatter = f32[200,100,300] scatter(f32[200,100,300] %input_tensor, s64[10,2] %scatter_indices, f32[10,300] %updates), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, to_apply=%update_computation
}
// CHECK-LABEL: func private @test_scatter
// CHECK-SAME: [[ARG_0:%.*]]: tensor<200x100x300xf32>, [[ARG_1:%.*]]: tensor<10x2xi64>, [[ARG_2:%.*]]: tensor<10x300xf32>) -> tensor<200x100x300xf32>
// CHECK: "mhlo.scatter"([[ARG_0]], [[ARG_1]], [[ARG_2]]) ({
// CHECK: ^bb0([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]]
// CHECK: mhlo.return [[ADD]] : tensor<f32>
// CHECK: })
// CHECK-SAME: indices_are_sorted = false
// CHECK-SAME: scatter_dimension_numbers =
// CHECK-SAME: update_window_dims = [1]
// CHECK-SAME: inserted_window_dims = [0, 1]
// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1]
// CHECK-SAME: index_vector_dim = 1
// CHECK-SAME: unique_indices = false
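// Test variadic scatter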
%wide_update_computation {
%lhs = f32[] parameter(0)
%lupd = f32[] parameter(1)
%rhs = f32[] parameter(2)
%rupd = f32[] parameter(3)
ROOT %tuple = (f32[], f32[]) tuple(f32[] %lhs, f32[] %rhs)
}
%test_variadic_scatter {
%input_tensor = f32[200,100,300] parameter(0)
%scatter_indices = s64[10,2] parameter(1)
%updates = f32[10,300] parameter(2)
ROOT %scatter = (f32[200,100,300], f32[200,100,300]) scatter(%input_tensor, %input_tensor, s64[10,2] %scatter_indices, %updates, %updates), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, to_apply=%wide_update_computation
}
// CHECK-LABEL: func.func private @test_variadic_scatter
// CHECK-SAME: [[ARG_0:%.*]]: tensor<200x100x300xf32>, [[ARG_1:%.*]]: tensor<10x2xi64>, [[ARG_2:%.*]]: tensor<10x300xf32>) -> tuple<tensor<200x100x300xf32>, tensor<200x100x300xf32>>
// CHECK: "mhlo.scatter"([[ARG_0]], [[ARG_0]], [[ARG_1]], [[ARG_2]], [[ARG_2]]) ({
// CHECK: ^bb0([[LHS:%.*]]: tensor<f32>, [[UPD:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>, [[UPP:%.*]]: tensor<f32>):
// CHECK: mhlo.return [[LHS]], [[RHS]] : tensor<f32>, tensor<f32>
// CHECK: })
// CHECK-SAME: indices_are_sorted = false
// CHECK-SAME: scatter_dimension_numbers =
// CHECK-SAME: update_window_dims = [1]
// CHECK-SAME: inserted_window_dims = [0, 1]
// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1]
// CHECK-SAME: index_vector_dim = 1
// CHECK-SAME: unique_indices = false
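// Test scatter with reducer returning tuple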
%update_computation_returning_tuple {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
%sum = f32[] add(f32[] %lhs, f32[] %rhs)
ROOT %tuple = (f32[]) tuple(%sum)
}
%test_scatter_with_reducer_returning_tuple {
%input_tensor = f32[200,100,300] parameter(0)
%scatter_indices = s64[10,2] parameter(1)
%updates = f32[10,300] parameter(2)
ROOT %scatter = f32[200,100,300] scatter(f32[200,100,300] %input_tensor, s64[10,2] %scatter_indices, f32[10,300] %updates), update_window_dims={1}, inserted_window_dims={0,1}, scatter_dims_to_operand_dims={0,1}, index_vector_dim=1, to_apply=%update_computation_returning_tuple
}
// CHECK-LABEL: func private @test_scatter_with_reducer_returning_tuple
// CHECK-SAME: [[ARG_0:%.*]]: tensor<200x100x300xf32>, [[ARG_1:%.*]]: tensor<10x2xi64>, [[ARG_2:%.*]]: tensor<10x300xf32>) -> tensor<200x100x300xf32>
// CHECK: "mhlo.scatter"([[ARG_0]], [[ARG_1]], [[ARG_2]]) ({
// CHECK: ^bb0([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]]
// CHECK: mhlo.return [[ADD]] : tensor<f32>
// CHECK: })
// CHECK-LABEL: func private @test_select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32>
%test_select {
%Arg_0.1 = pred[2,3] parameter(0)
%Arg_1.2 = s32[2,3] parameter(1)
%Arg_2.3 = s32[2,3] parameter(2)
// CHECK: "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
ROOT %select.4 = s32[2,3] select(pred[2,3] %Arg_0.1, s32[2,3] %Arg_1.2, s32[2,3] %Arg_2.3)
}
// Test SelectAndScatter
%ge_select (lhs: f32[], rhs: f32[]) -> pred[] {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
}
%add_gather (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}
// CHECK-LABEL: func private @test_select_and_scatter
// CHECK-SAME: [[INPUT:%.*]]: tensor<4x5xf32>, [[SOURCE:%.*]]: tensor<2x2xf32>, [[INIT_VAL:%.*]]: tensor<f32>
%test_select_and_scatter {
%input = f32[4,5] parameter(0)
%source = f32[2,2] parameter(1)
%init_value = f32[] parameter(2)
ROOT %select-and-scatter = f32[4,5] select-and-scatter(f32[4,5] %input, f32[2,2] %source, f32[] %init_value), window={size=2x3 stride=2x3 pad=0_0x0_1}, select=%ge_select, scatter=%add_gather
}
// CHECK: [[RESULT:%.*]] = "mhlo.select_and_scatter"([[INPUT]], [[SOURCE]], [[INIT_VAL]]) ({
// CHECK: ^bb0([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
// CHECK: [[CMP:%.*]] = mhlo.compare GE, [[LHS]], [[RHS]]
// CHECK: mhlo.return [[CMP]] : tensor<i1>
// CHECK: }, {
// CHECK: ^bb0([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]]
// CHECK: mhlo.return [[ADD]] : tensor<f32>
// CHECK: }) {
// CHECK-SAME: padding = dense<{{\[\[}}0, 0], [0, 1]]> : tensor<2x2xi64>
// CHECK-SAME: window_dimensions = dense<[2, 3]> : tensor<2xi64>
// CHECK-SAME: window_strides = dense<[2, 3]> : tensor<2xi64>
// CHECK-SAME: }
// CHECK: return [[RESULT:%.*]] : tensor<4x5xf32>
// Test SelectAndScatter with tuple returns from computations.
%ge_select_returning_tuple (lhs: f32[], rhs: f32[]) -> (pred[]) {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
%greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
ROOT tple = (pred[]) tuple(%greater-than-or-equal-to)
}
%add_gather_returning_tuple (lhs.1: f32[], rhs.1: f32[]) -> (f32[]) {
%lhs = f32[] parameter(0)
%rhs = f32[] parameter(1)
%add = f32[] add(f32[] %lhs, f32[] %rhs)
ROOT tple = (f32[]) tuple(%add)
}
// CHECK-LABEL: func private @test_select_and_scatter_with_regions_returning_tuple
%test_select_and_scatter_with_regions_returning_tuple {
%input = f32[4,5] parameter(0)
%source = f32[2,2] parameter(1)
%init_value = f32[] parameter(2)
ROOT %select-and-scatter = f32[4,5] select-and-scatter(f32[4,5] %input, f32[2,2] %source, f32[] %init_value), window={size=2x3 stride=2x3 pad=0_0x0_1}, select=%ge_select_returning_tuple, scatter=%add_gather_returning_tuple
}
// CHECK: ^bb0([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
// CHECK: [[CMP:%.*]] = mhlo.compare GE, [[LHS]], [[RHS]]
// CHECK: mhlo.return [[CMP]] : tensor<i1>
// CHECK: ^bb0([[LHS:%.*]]: tensor<f32>, [[RHS:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = mhlo.add [[LHS]], [[RHS]]
// CHECK: mhlo.return [[ADD]] : tensor<f32>
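// Test set-dimension-size: the bounded result dimension imports as a dynamic
// dimension with a #mhlo.type_extensions bound.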
// CHECK-LABEL: func private @test_set_dimension_size
// CHECK-SAME: ([[ARG:%.*]]: tensor<4x4xf32>, [[SIZE:%.*]]: tensor<i32>)
%test_set_dimension_size (Arg_0.1: f32[4,4], Arg_1.2: s32[]) -> f32[4,<=4] {
%Arg_0.1 = f32[4,4] parameter(0)
%Arg_1.2 = s32[] parameter(1)
// CHECK-NEXT: "mhlo.set_dimension_size"([[ARG]], [[SIZE]]) {dimension = 1 : i64} : (tensor<4x4xf32>, tensor<i32>)
// CHECK-SAME: tensor<4x?xf32, #mhlo.type_extensions<bounds = [-1, 4]>>
ROOT %set-dimension-size.2 = f32[4,<=4] set-dimension-size(f32[4,4] %Arg_0.1, s32[] %Arg_1.2), dimensions={1}
}
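// Test sine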
// CHECK-LABEL: func private @test_sine(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
%test_sine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
%arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
// CHECK-NEXT: mhlo.sine %arg0 : tensor<1x16x16x3xf32>
ROOT %sine.3 = f32[1,16,16,3]{3,2,1,0} sine(f32[1,16,16,3]{3,2,1,0} %arg0.1)
}
// Test sort
%compare {
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
%test_sort {
x = f32[1024]{0} parameter(0)
ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, is_stable=true, to_apply=compare
}
// CHECK-LABEL: func private @test_sort
// CHECK-SAME: [[ARG:%.*]]: tensor<1024xf32>) -> tensor<1024xf32>
// CHECK: "mhlo.sort"([[ARG]]) ({
// CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
// CHECK: [[CMP:%.*]] = mhlo.compare LT, [[ARG0]], [[ARG1]] : (tensor<f32>, tensor<f32>) -> tensor<i1>
// CHECK: mhlo.return [[CMP]] : tensor<i1>
// CHECK: }) {dimension = 0 : i64, is_stable = true} : (tensor<1024xf32>) -> tensor<1024xf32>
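// Test subtract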
// CHECK-LABEL: func private @test_subtract
%test_subtract (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: mhlo.subtract %arg0, %arg1 : tensor<4xf32>
ROOT %subtract.3 = f32[4] subtract(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
// Test sort with comparator returning tuple
%compare_returning_tuple {
p.0.lhs = f32[] parameter(0)
p.0.rhs = f32[] parameter(1)
lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
ROOT tple = (pred[]) tuple(lt)
}
%test_sort_with_comp_returning_tuple {
x = f32[1024]{0} parameter(0)
ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, is_stable=true, to_apply=compare_returning_tuple
}
// CHECK-LABEL: func private @test_sort_with_comp_returning_tuple
// CHECK: "mhlo.sort"
// CHECK: ^bb0
// CHECK: [[CMP:%.*]] = mhlo.compare
// CHECK: mhlo.return [[CMP]] : tensor<i1>
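// Test tanh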
// CHECK-LABEL: func private @test_tanh(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
%test_tanh (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
%arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
// CHECK-NEXT: mhlo.tanh %arg0 : tensor<1x16x16x3xf32>
ROOT %tanh.3 = f32[1,16,16,3]{3,2,1,0} tanh(f32[1,16,16,3]{3,2,1,0} %arg0.1), metadata={op_type="Tanh" op_name="embedded_inference/tanh_model/Tanh"}
}
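// Test transpose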
// CHECK-LABEL: func private @test_transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
%test_transpose {
%Arg_0.1 = s32[1,2,3,4] parameter(0)
// CHECK: "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2}
}
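// Test triangular-solve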
// CHECK-LABEL: func private @test_triangular_solve
// CHECK-SAME: ([[ARG_A:%.*]]: tensor<4x4xf32>, [[ARG_B:%.*]]: tensor<4x3xf32>) -> tensor<4x3xf32>
%test_triangular_solve (Arg_0.1: f32[4,4], Arg_1.2: f32[4,3]) -> f32[4,3] {
%Arg_0.1 = f32[4,4] parameter(0)
%Arg_1.2 = f32[4,3] parameter(1)
// CHECK-NEXT: "mhlo.triangular_solve"([[ARG_A]], [[ARG_B]])
// CHECK-SAME: left_side = true
// CHECK-SAME: lower = true
// CHECK-SAME: transpose_a = #mhlo<transpose NO_TRANSPOSE>
// CHECK-SAME: unit_diagonal = true
ROOT %triangular-solve.3 = f32[4,3] triangular-solve(f32[4,4] %Arg_0.1, f32[4,3] %Arg_1.2), left_side=true, lower=true, transpose_a=NO_TRANSPOSE, unit_diagonal=true
}
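// Test tuple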
// CHECK-LABEL: func private @test_tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>>
%test_tuple (Arg_0.1: s32[1], Arg_1.2: f32[1,2]) -> (s32[1], f32[1,2]) {
%Arg_0.1 = s32[1] parameter(0)
%Arg_1.2 = f32[1,2] parameter(1)
// CHECK: "mhlo.tuple"(%arg0, %arg1) {xla_shape = {{.*}}} : (tensor<1xi32>, tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>>
ROOT %tuple.4 = (s32[1], f32[1,2]) tuple(%Arg_0.1, %Arg_1.2)
}
// Test while op
// CHECK-LABEL: func private @cond
%cond (arg_1: s64[]) -> pred[] {
%arg_1 = s64[] parameter(0), metadata={op_name="HLO_Args"}
ROOT %compare.2 = pred[] compare(%arg_1, %arg_1), direction=LT, metadata={op_type="Less" op_name="Less"}
}
// CHECK-LABEL: func private @loop
%loop (arg_1: s64[]) -> s64[] {
%arg_1 = s64[] parameter(0), metadata={op_name="HLO_Args"}
ROOT %add.2 = s64[] add(%arg_1, %arg_1), metadata={op_type="Add" op_name="Add"}
}
// CHECK-LABEL: func private @test_while(%arg0: tensor<i64>) -> tensor<i64>
%test_while (arg0.1: s64[]) -> s64[] {
%arg0.1 = s64[] parameter(0), metadata={op_name="HLO_Args"}
// CHECK-NEXT: mhlo.while(%[[ITER_ARG:.*]] = %arg0)
// CHECK: [[CMP:%.*]] = mhlo.compare LT, %[[ITER_ARG]], %[[ITER_ARG]] : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK-NEXT: mhlo.return [[CMP]] : tensor<i1>
// CHECK: [[ADD:%.*]] = mhlo.add %[[ITER_ARG]], %[[ITER_ARG]] : tensor<i64>
// CHECK-NEXT: mhlo.return [[ADD]] : tensor<i64>
ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond
}
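// Test xor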
// CHECK-LABEL: func private @test_xor
// CHECK-SAME: ([[VAL_0:%.*]]: tensor<4xi1>, [[VAL_1:%.*]]: tensor<4xi1>) -> tensor<4xi1>
%test_xor (Arg_0.1: pred[4], Arg_1.2: pred[4]) -> pred[4] {
%Arg_0.1 = pred[4] parameter(0)
%Arg_1.2 = pred[4] parameter(1)
// CHECK: mhlo.xor [[VAL_0]], [[VAL_1]]
ROOT %xor.3 = pred[4] xor(pred[4] %Arg_0.1, pred[4] %Arg_1.2)
}
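// Test shift-left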
// CHECK-LABEL: func private @test_shiftleft
// CHECK-SAME: ([[VAL_0:%.*]]: tensor<4xi32>, [[VAL_1:%.*]]: tensor<4xi32>) -> tensor<4xi32>
%test_shiftleft (Arg_0.1: s32[4], Arg_1.2: s32[4]) -> s32[4] {
%Arg_0.1 = s32[4] parameter(0)
%Arg_1.2 = s32[4] parameter(1)
// CHECK: mhlo.shift_left [[VAL_0]], [[VAL_1]]
ROOT %shiftleft = s32[4] shift-left(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
}
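// Test shift-right-arithmetic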
// CHECK-LABEL: func private @test_shiftright_arithmetic
// CHECK-SAME: ([[VAL_0:%.*]]: tensor<4xi32>, [[VAL_1:%.*]]: tensor<4xi32>) -> tensor<4xi32>
%test_shiftright_arithmetic (Arg_0.1: s32[4], Arg_1.2: s32[4]) -> s32[4] {
%Arg_0.1 = s32[4] parameter(0)
%Arg_1.2 = s32[4] parameter(1)
// CHECK: mhlo.shift_right_arithmetic [[VAL_0]], [[VAL_1]]
ROOT %shiftright.arithmetic = s32[4] shift-right-arithmetic(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
}
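// Test shift-right-logical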
// CHECK-LABEL: func private @test_shiftright_logical
// CHECK-SAME: ([[VAL_0:%.*]]: tensor<4xi32>, [[VAL_1:%.*]]: tensor<4xi32>) -> tensor<4xi32>
%test_shiftright_logical (Arg_0.1: s32[4], Arg_1.2: s32[4]) -> s32[4] {
%Arg_0.1 = s32[4] parameter(0)
%Arg_1.2 = s32[4] parameter(1)
// CHECK: mhlo.shift_right_logical [[VAL_0]], [[VAL_1]]
ROOT %shiftright.logical = s32[4] shift-right-logical(s32[4] %Arg_0.1, s32[4] %Arg_1.2)
}
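// Test complex element types: abs of c64/c128 operands yields f32/f64 results.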
// CHECK-LABEL: func private @complex_type
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2xcomplex<f32>>, %[[ARG1:.*]]: tensor<2xcomplex<f64>>) -> tuple<tensor<2xf32>, tensor<2xf64>>
%complex_type (Arg_0.1: c64[2], Arg_1.2: c128[2]) -> (f32[2], f64[2]) {
%Arg_0.1 = c64[2] parameter(0)
%abs.3 = f32[2] abs(c64[2] %Arg_0.1)
%Arg_1.2 = c128[2] parameter(1)
%abs.4 = f64[2] abs(c128[2] %Arg_1.2)
// CHECK: mhlo.abs(%[[ARG0]]) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
// CHECK: mhlo.abs(%[[ARG1]]) : (tensor<2xcomplex<f64>>) -> tensor<2xf64>
ROOT %tuple.5 = (f32[2], f64[2]) tuple(f32[2] %abs.3, f64[2] %abs.4)
}
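// Test unsigned integer types (u16 imports as ui16).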
// CHECK-LABEL: func private @unsigned_int
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4xui16>)
%unsigned_int (Arg_0.1: u16[4]) -> u16[4] {
%Arg_0.1 = u16[4] parameter(0)
// CHECK: mhlo.not %[[ARG0]] : tensor<4xui16>
ROOT %not.2 = u16[4] not(u16[4] %Arg_0.1)
}
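// Test rng-bit-generator with the Philox algorithm.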
// CHECK-LABEL: func private @rngbitgen
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3xui64>)
%rngbitgen (Arg_0.1: u64[3]) -> (u64[3], u32[2,2]) {
%Arg_0.1 = u64[3] parameter(0)
// CHECK: %[[RNG0:.+]], %[[RNG1:.+]] = "mhlo.rng_bit_generator"(%[[ARG0]]) {rng_algorithm = #mhlo.rng_algorithm<PHILOX>} : (tensor<3xui64>) -> (tensor<3xui64>, tensor<2x2xui32>)
// CHECK: %[[TUPLE:.+]] = "mhlo.tuple"(%[[RNG0]], %[[RNG1]]) {xla_shape = "(u64[3]{0}, u32[2,2]{1,0})"} : (tensor<3xui64>, tensor<2x2xui32>) -> tuple<tensor<3xui64>, tensor<2x2xui32>>
// CHECK: return %[[TUPLE]]
ROOT %rng-bit-generator.2 = (u64[3], u32[2,2]) rng-bit-generator(u64[3] %Arg_0.1), algorithm=rng_philox
}
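// Test cbrt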
// CHECK-LABEL: func private @cbrt
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xf32>)
%cbrt (Arg_0.1: f32[3,4]) -> f32[3,4] {
%Arg_0.1 = f32[3,4] parameter(0)
// CHECK: mhlo.cbrt %[[ARG0]] : tensor<3x4xf32>
ROOT %cbrt = f32[3,4] cbrt(f32[3,4] %Arg_0.1)
}
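// Test bitcast with explicit source and result layouts preserved as attributes.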
// CHECK-LABEL: func private @bitcast
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xf32>) -> tensor<3x4x1xf32>
%bitcast (Arg_0.1: f32[3,4]) -> f32[3,4,1]{2,0,1} {
%Arg_0.1 = f32[3,4] parameter(0)
// CHECK: mhlo.bitcast %[[ARG0]] {result_layout = dense<[2, 0, 1]> : tensor<3xindex>, source_layout = dense<[1, 0]> : tensor<2xindex>, xla_shape = "f32[3,4,1]{2,0,1}"} : (tensor<3x4xf32>) -> tensor<3x4x1xf32>
ROOT %bitcast = f32[3,4,1]{2,0,1} bitcast(f32[3,4] %Arg_0.1)
}
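// Test reduce-precision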
// CHECK-LABEL: func private @reduce_precision
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x4xf32>)
%reduce_precision (Arg_0.1: f32[3,4]) -> f32[3,4] {
%Arg_0.1 = f32[3,4] parameter(0)
// CHECK: "mhlo.reduce_precision"(%[[ARG0]]) {exponent_bits = 8 : i32, mantissa_bits = 10 : i32} : (tensor<3x4xf32>) -> tensor<3x4xf32>
ROOT %reduce_precision = f32[3,4] reduce-precision(f32[3,4] %Arg_0.1), exponent_bits=8, mantissa_bits=10
}
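// Test opt-barrier: the tuple operand is flattened into separate operands and results.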
// CHECK-LABEL: func private @optimization_barrier
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x4xf32>, %[[ARG1:.*]]: tensor<3x4xf32>)
%optimization_barrier (Arg_0.1: f32[4,4], Arg_1.2: f32[3,4]) -> (f32[4,4], f32[3,4]) {
%Arg_0.1 = f32[4,4] parameter(0)
%Arg_1.2 = f32[3,4] parameter(1)
%args = (f32[4,4], f32[3,4]) tuple(f32[4,4] %Arg_0.1, f32[3,4] %Arg_1.2)
// CHECK: mhlo.optimization_barrier %[[ARG0]], %[[ARG1]] : (tensor<4x4xf32>, tensor<3x4xf32>) -> (tensor<4x4xf32>, tensor<3x4xf32>)
ROOT %opt-barrier = (f32[4,4], f32[3,4]) opt-barrier((f32[4,4], f32[3,4]) %args)
}
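// Test partition-id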
// CHECK-LABEL: func private @partition_id
%partition_id {
// CHECK: mhlo.partition_id : tensor<ui32>
ROOT %pid = u32[] partition-id()
}
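// Test round-nearest-even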
// CHECK-LABEL: func private @round_nearest_even
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2xf32>)
%round_nearest_even (Arg_0.1: f32[2]) -> f32[2] {
%Arg_0.1 = f32[2] parameter(0)
ROOT %round-nearest-even.2 = f32[2] round-nearest-even(f32[2] %Arg_0.1)
// CHECK: mhlo.round_nearest_even %[[ARG0]] : tensor<2xf32>
}
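// Test domain instruction with sharding kind.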
// CHECK-LABEL: func private @domain
%domain (Arg_0.1: u32[]) -> u32[] {
// CHECK: "mhlo.domain"(
// CHECK-SAME: {entry_metadata = "\08\01\1A\01\01\22\01\01", exit_metadata = "\08\02", kind = #mhlo<kind sharding>}
%Arg_0.1 = u32[] parameter(0)
ROOT %domain.2 = u32[] domain(u32[] %Arg_0.1), domain={kind="sharding", entry={maximal device=1}, exit={}}
}
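// Test add-dependency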
// CHECK-LABEL: func private @add_dependency
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x4xf32>)
%add_dependency (Arg_0.1: f32[4,4]) -> f32[4,4] {
%Arg_0.1 = f32[4,4] parameter(0)
%after-all = token[] after-all()
// CHECK: %[[TOK:.*]] = mhlo.create_token
// CHECK: mhlo.add_dependency %[[ARG0]], %[[TOK]] : (tensor<4x4xf32>, !mhlo.token) -> tensor<4x4xf32>
ROOT %add-dep = f32[4,4] add-dependency(f32[4,4] %Arg_0.1, token[] %after-all)
}
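// Test parameter and result shardings imported as mhlo.sharding attributes.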
// CHECK-LABEL: func private @test_args_and_result_with_sharding
// CHECK-SAME: ([[Arg:%.*]]: tensor<4xi32> {mhlo.sharding = "\08\03\1A\02\02\01\22\02\00\01"}) -> (tensor<4xi32> {mhlo.sharding = "\08\03\1A\02\02\01\22\02\00\01"})
%test_args_and_result_with_sharding (Arg_0.1: s32[4]) -> s32[4] {
%arg0.1 = s32[4] parameter(0), sharding={devices=[2,1]0,1}
ROOT %copy.1 = s32[4] copy(%arg0.1), sharding={devices=[2,1]0,1}
}