// blob: e1c8e838d4f8b4efd4d908af773f0f17156c4a78 [file] [log] [blame]
// RUN: tf-mlir-translate -split-input-file -hlo-text-to-lhlo -optimize-xla-hlo=false %s | FileCheck %s
HloModule TestModule
// Loop fusion over a parameter that carries an explicit {0,1} layout. The
// layout is expected to be recorded as an xla_shape attribute on the
// generated to_tensor op.
// CHECK-LABEL: func @TestComputation
FusedComputation {
// CHECK: to_tensor {{.*}} {xla_shape = "f32[3,2]{0,1}"}
x = f32[3, 2]{0,1} parameter(0)
ROOT y = f32[3, 2]{0,1} add(x, x)
}
ENTRY TestComputation {
x = f32[3, 2]{0,1} parameter(0)
ROOT y = f32[3, 2]{0,1} fusion(x), kind=kLoop, calls=FusedComputation
}
// -----
HloModule ScatterModule
// Scatter whose update computation simply returns its second parameter.
// Expects an lmhlo.scatter op whose region returns the second block
// argument, the scatter dimension numbers given on the HLO instruction, and
// four memref operands (operand, indices, updates, output).
update_s32 (lhs: s32[], rhs: s32[]) -> s32[] {
lhs = s32[] parameter(0)
ROOT rhs = s32[] parameter(1)
}
// CHECK-LABEL: func @main
// CHECK: "lmhlo.scatter"
// CHECK: ^bb0(%[[ARG5:.*]]: tensor<i32>, %[[ARG6:.*]]: tensor<i32>):
// CHECK: "mhlo.return"(%[[ARG6]])
// CHECK: indices_are_sorted = false
// CHECK: update_window_dims = [1]
// CHECK: inserted_window_dims = [0]
// CHECK: scatter_dims_to_operand_dims = [0]
// CHECK: index_vector_dim = 1
// CHECK: unique_indices = false
// CHECK: (memref<3x3xi32>, memref<2xi32>, memref<2x3xi32>, memref<3x3xi32>) -> ()
ENTRY main {
operand = s32[3,3] parameter(0)
indices = s32[2] parameter(1)
updates = s32[2,3] parameter(2)
ROOT scatter_op = s32[3,3] scatter(operand, indices, updates),
to_apply=update_s32,
update_window_dims={1},
inserted_window_dims={0},
scatter_dims_to_operand_dims={0},
index_vector_dim=1
}
// -----
HloModule SelectAndScatter
// select-and-scatter with a GE `select` computation and an add `scatter`
// computation. The scalar constant used as the init value is expected to be
// emitted as a private constant memref.global and read back through
// memref.get_global before being passed as the third operand.
%ge_F32 (lhs.5: f32[], rhs.6: f32[]) -> pred[] {
%lhs.5 = f32[] parameter(0)
%rhs.6 = f32[] parameter(1)
ROOT %compare.7 = pred[] compare(f32[] %lhs.5, f32[] %rhs.6), direction=GE
}
%add_F32 (lhs.9: f32[], rhs.10: f32[]) -> f32[] {
%lhs.9 = f32[] parameter(0)
%rhs.10 = f32[] parameter(1)
ROOT %add.11 = f32[] add(f32[] %lhs.9, f32[] %rhs.10)
}
// CHECK-LABEL: module
// CHECK: memref.global "private" constant @[[$GLOBAL:.*]] : memref<f32> = dense<0.000000e+00>
// CHECK-LABEL: func @main
// CHECK: %[[GLOBAL_MEMREF:.*]] = memref.get_global @[[$GLOBAL]] : memref<f32>
// CHECK: "lmhlo.select_and_scatter"(%{{.*}}, %{{.*}}, %[[GLOBAL_MEMREF]], %{{.*}})
// CHECK: ^bb0(%[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>):
// CHECK: %[[COMPARE:.*]] = "mhlo.compare"(%[[ARG0]], %[[ARG1]]) {comparison_direction = #mhlo<"comparison_direction GE">}
// CHECK: "mhlo.return"(%[[COMPARE]]) : (tensor<i1>) -> ()
// CHECK: ^bb0(%[[ARG2:.*]]: tensor<f32>, %[[ARG3:.*]]: tensor<f32>):
// CHECK: %[[ADD:.*]] = mhlo.add %[[ARG2]], %[[ARG3]]
// CHECK: "mhlo.return"(%[[ADD]]) : (tensor<f32>) -> ()
// CHECK: padding = dense<0> : tensor<1xi64>
// CHECK: window_dimensions = dense<3> : tensor<1xi64>
// CHECK: window_strides = dense<3> : tensor<1xi64>
// CHECK: (memref<6xf32>, memref<2xf32>, memref<f32>, memref<6xf32>) -> ()
// NOTE(review): the signature below lists no parameters although the body
// declares two -- confirm this is intentional.
ENTRY main () -> f32[6] {
%operand = f32[6]{0} parameter(0)
%source = f32[2]{0} parameter(1)
%init = f32[] constant(0)
ROOT %select-and-scatter.12 = f32[6]{0} select-and-scatter(f32[6]{0} %operand, f32[2]{0} %source, f32[] %init), window={size=3 stride=3}, select=%ge_F32, scatter=%add_F32
}
// -----
HloModule SliceToDynamic
// Generic custom-call ("SliceToDynamic") producing a dynamically bounded
// result (s32[2,<=2,2]). Expects an lmhlo.custom_call with four inputs and
// one output (operand_segment_sizes = [4, 1]) and no target_arg_mapping
// attribute.
// CHECK-LABEL: func @main
// CHECK: "lmhlo.custom_call"
// CHECK: backend_config = "", call_target_name = "SliceToDynamic"
// CHECK-SAME: operand_segment_sizes = dense<[4, 1]> : vector<2xi32>
// CHECK-NOT: target_arg_mapping
// CHECK: (memref<2x2x2xi32>, memref<i32>, memref<i32>, memref<i32>, memref<2x2x2xi32>) -> ()
ENTRY main {
%param = s32[2,2,2] parameter(0)
%static = s32[] parameter(1)
%dynamic = s32[] parameter(2)
ROOT %custom-call = s32[2,<=2, 2] custom-call(s32[2,2,2] %param,
s32[] %static,
s32[] %dynamic,
s32[] %static),
custom_call_target="SliceToDynamic",
backend_config=""
}
// -----
HloModule Cholesky
// custom-call targeting "__cusolver$cholesky" with "lower":true in its
// backend_config. Expects an lmhlo_gpu.cholesky op with is_lower = true.
// CHECK-LABEL: func @main
// CHECK: "lmhlo_gpu.cholesky"
// CHECK-SAME: is_lower = true
ENTRY main {
%param = f32[3,3] parameter(0)
ROOT %custom-call = (f32[3,3], f32[3], s32[]) custom-call(f32[3,3] %param),
custom_call_target="__cusolver$cholesky",
operand_layout_constraints={f32[3,3]},
backend_config="{\"lower\":true}"
}
// -----
HloModule Gemm
// Two-operand custom-call targeting "__cublas$gemm". Expects an
// lmhlo_gpu.gemm op whose algorithm, alpha and contracting-dimension
// attributes mirror the values in backend_config; the empty batch dimension
// lists must not be printed.
// CHECK-LABEL: func @main
// CHECK: "lmhlo_gpu.gemm"
// CHECK-SAME: algorithm = 7 : i64
// CHECK-SAME: alpha_imag = 0.000000e+00 : f64
// CHECK-SAME: alpha_real = 1.000000e+00 : f64
// CHECK-NOT: lhs_batching_dimensions
// CHECK-NOT: rhs_batching_dimensions
// CHECK-SAME: lhs_contracting_dimensions = [1]
// CHECK-SAME: rhs_contracting_dimensions = [0]
// CHECK: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
ENTRY main {
%A = f32[2,2]{1,0} parameter(0)
%B = f32[2,2]{1,0} parameter(1)
ROOT %sgemm = f32[2,2]{1,0} custom-call(f32[2,2]{1,0} %A, f32[2,2]{1,0} %B),
custom_call_target="__cublas$gemm",
backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":0,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"selected_algorithm\":\"7\"}"
}
// -----
HloModule GemmBias
// Three-operand custom-call targeting "__cublas$gemm" with "beta":1 in its
// backend_config. Expects an lmhlo_gpu.gemm_bias op (rather than
// lmhlo_gpu.gemm) carrying beta = 1.0 and four memref operands.
// CHECK-LABEL: func @main
// CHECK: "lmhlo_gpu.gemm_bias"
// CHECK-SAME: algorithm = 0 : i64
// CHECK-SAME: alpha_imag = 0.000000e+00 : f64
// CHECK-SAME: alpha_real = 1.000000e+00 : f64
// CHECK-SAME: beta = 1.000000e+00 : f64
// CHECK-NOT: lhs_batching_dimensions
// CHECK-NOT: rhs_batching_dimensions
// CHECK-SAME: lhs_contracting_dimensions = [1]
// CHECK-SAME: rhs_contracting_dimensions = [0]
// CHECK: (memref<1x1xf32>, memref<1x4xf32>, memref<1x4xf32>, memref<1x4xf32>)
ENTRY main {
%A = f32[1,1]{1,0} parameter(0)
%B = f32[1,4]{1,0} parameter(1)
%C = f32[1,4]{1,0} parameter(2)
ROOT %sgemm_add = f32[1,4]{1,0} custom-call(f32[1,1]{0,1} %A, f32[1,4]{1,0} %B, f32[1,4]{1,0} %C),
custom_call_target="__cublas$gemm",
backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"selected_algorithm\":\"0\"}"
}
// -----
HloModule GemmBias
// Three-operand custom-call targeting "__cublas$gemm" with "beta":1 in its
// backend_config. Expects an lmhlo_gpu.gemm_bias op carrying beta = 1.0 and
// four memref operands.
// NOTE(review): this case appears to be identical to the GemmBias case that
// precedes it in this file -- confirm whether a distinct variant was intended.
// CHECK-LABEL: func @main
// CHECK: "lmhlo_gpu.gemm_bias"
// CHECK-SAME: algorithm = 0 : i64
// CHECK-SAME: alpha_imag = 0.000000e+00 : f64
// CHECK-SAME: alpha_real = 1.000000e+00 : f64
// CHECK-SAME: beta = 1.000000e+00 : f64
// CHECK-NOT: lhs_batching_dimensions
// CHECK-NOT: rhs_batching_dimensions
// CHECK-SAME: lhs_contracting_dimensions = [1]
// CHECK-SAME: rhs_contracting_dimensions = [0]
// CHECK: (memref<1x1xf32>, memref<1x4xf32>, memref<1x4xf32>, memref<1x4xf32>)
ENTRY main {
%A = f32[1,1]{1,0} parameter(0)
%B = f32[1,4]{1,0} parameter(1)
%C = f32[1,4]{1,0} parameter(2)
ROOT %sgemm_add = f32[1,4]{1,0} custom-call(f32[1,1]{0,1} %A, f32[1,4]{1,0} %B, f32[1,4]{1,0} %C),
custom_call_target="__cublas$gemm",
backend_config="{\"alpha_real\":1,\"alpha_imag\":0,\"beta\":1,\"dot_dimension_numbers\":{\"lhs_contracting_dimensions\":[\"1\"],\"rhs_contracting_dimensions\":[\"0\"],\"lhs_batch_dimensions\":[],\"rhs_batch_dimensions\":[]},\"selected_algorithm\":\"0\"}"
}
// -----
HloModule AllReduce
// Test all-reduce: a single f32[8] operand reduced with `add` over two
// replica groups of four, using channel_id=1. The parameter is expected to
// arrive as a 32-byte buffer that is viewed as memref<8xf32>.
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
// CHECK-LABEL: func @test_all_reduce
// CHECK-SAME: ([[INPUT:%.*]]: memref<32xi8> {lmhlo.params = 0
%test_all_reduce {
input = f32[8] parameter(0)
// CHECK: [[VIEW0:%.*]] = memref.view [[INPUT]]{{.*}} : memref<32xi8> to memref<8xf32>
// CHECK: "lmhlo.all_reduce"([[VIEW0]], {{.*}})
// CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
// CHECK: "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
// CHECK: }) {
// CHECK-SAME: channel_id = #mhlo.channel_handle<handle = 1, type = 0>
// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>
ROOT result = f32[8] all-reduce(input), channel_id=1, replica_groups={{0,1,2,3}, {4,5,6,7}}, to_apply=add
}
// -----
HloModule AllReduceTuple
// Test all-reduce with tuples: the tuple operand is produced by a
// custom-call. The op is expected to take two inputs and two outputs
// (memref<8xf32> and memref<5xf32> each) and, since the HLO specifies no
// channel id, to carry no channel_id attribute.
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
// CHECK-LABEL: func @test_all_reduce_tuple
// CHECK: "lmhlo.all_reduce"
// CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
// CHECK: "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
// CHECK: }) {
// CHECK-NOT: channel_id
// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3], [5, 6, 7, 4]]> : tensor<2x4xi64>
// CHECK-SAME: (memref<8xf32>, memref<5xf32>, memref<8xf32>, memref<5xf32>) -> ()
%test_all_reduce_tuple {
input = f32[8] parameter(0)
%tuple = (f32[8], f32[5]) custom-call(),custom_call_target="make-tuple"
ROOT result = (f32[8], f32[5]) all-reduce(%tuple), replica_groups={{0,1,2,3}, {5,6,7,4}}, to_apply=add
}
// -----
HloModule AllReduceNonUniformReplicaGroups
// Test all-reduce with replica groups of different sizes ({0,1,2} and
// {3,4,5,6}). The shorter group is expected to be padded with -1 so that
// replica_groups forms a dense 2x4 tensor.
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
// CHECK-LABEL: func @test_all_reduce
// CHECK-SAME: ([[INPUT:%.*]]: memref<32xi8> {lmhlo.params = 0
%test_all_reduce {
input = f32[8] parameter(0)
// CHECK: [[VIEW0:%.*]] = memref.view [[INPUT]]{{.*}} : memref<32xi8> to memref<8xf32>
// CHECK: "lmhlo.all_reduce"([[VIEW0]], {{.*}})
// CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
// CHECK: "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
// CHECK: }) {
// CHECK-SAME: channel_id = #mhlo.channel_handle<handle = 1, type = 0>
// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, -1], [3, 4, 5, 6]]> : tensor<2x4xi64>
ROOT result = f32[8] all-reduce(input), channel_id=1, replica_groups={{0,1,2}, {3,4,5,6}}, to_apply=add
}
// -----
HloModule AsyncAllReduce
// Test async all-reduce: an all-reduce-start / all-reduce-done pair. The
// start op is expected to yield a token that the done op consumes, with the
// input and output both being views of the same parameter buffer.
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
// CHECK-LABEL: func @test_async_all_reduce
// CHECK-SAME: [[BUFFER:%.*]]: memref<32xi8>
%test_async_all_reduce {
param0 = f32[8] parameter(0)
// CHECK: [[INPUT:%.*]] = memref.view [[BUFFER]]{{.*}} : memref<32xi8> to memref<8xf32>
// CHECK: [[OUTPUT:%.*]] = memref.view [[BUFFER]]{{.*}} : memref<32xi8> to memref<8xf32>
// CHECK: [[TOKEN:%.*]] = "lmhlo_gpu.all_reduce_start"([[INPUT]], [[OUTPUT]])
// CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
// CHECK: "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
// CHECK: }) {
// CHECK-SAME: channel_id = #mhlo.channel_handle<handle = 1, type = 0>
// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>
// CHECK: "lmhlo_gpu.all_reduce_done"([[TOKEN]], [[OUTPUT]])
start = f32[8] all-reduce-start(param0),
channel_id=1, replica_groups={{0,1,2,3}, {4,5,6,7}}, to_apply=add
ROOT done = f32[8] all-reduce-done(start)
}
// -----
HloModule AsyncAllReduceTwoOperands
// Test async all-reduce with two operands (f32[8] and f32[9]). The start op
// is expected to list both inputs followed by both outputs, and the done op
// to take the token followed by both outputs.
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
// CHECK-LABEL: func @test_async_all_reduce_two_operands
// CHECK-SAME: [[BUFFER0:%.*]]: memref<32xi8>
// CHECK-SAME: [[BUFFER1:%.*]]: memref<36xi8>
%test_async_all_reduce_two_operands {
param0 = f32[8] parameter(0)
param1 = f32[9] parameter(1)
// CHECK: [[INPUT0:%.*]] = memref.view [[BUFFER0]]{{.*}} : memref<32xi8> to memref<8xf32>
// CHECK: [[INPUT1:%.*]] = memref.view [[BUFFER1]]{{.*}} : memref<36xi8> to memref<9xf32>
// CHECK: [[OUTPUT0:%.*]] = memref.view [[BUFFER0]]{{.*}} : memref<32xi8> to memref<8xf32>
// CHECK: [[OUTPUT1:%.*]] = memref.view [[BUFFER1]]{{.*}} : memref<36xi8> to memref<9xf32>
// CHECK: [[TOKEN:%.*]] = "lmhlo_gpu.all_reduce_start"([[INPUT0]], [[INPUT1]], [[OUTPUT0]], [[OUTPUT1]])
// CHECK: ^bb0([[ARG0:%.*]]: tensor<f32>, [[ARG1:%.*]]: tensor<f32>):
// CHECK: [[ADD:%.*]] = mhlo.add [[ARG0]], [[ARG1]]
// CHECK: "mhlo.return"([[ADD]]) : (tensor<f32>) -> ()
// CHECK: }) {
// CHECK-SAME: channel_id = #mhlo.channel_handle<handle = 1, type = 0>
// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi64>
// CHECK: "lmhlo_gpu.all_reduce_done"([[TOKEN]], [[OUTPUT0]], [[OUTPUT1]])
start = (f32[8], f32[9]) all-reduce-start(param0, param1),
channel_id=1, replica_groups={{0,1,2,3}, {4,5,6,7}}, to_apply=add
ROOT done = (f32[8], f32[9]) all-reduce-done(start)
}
// -----
HloModule ConvForward
// custom-call targeting "__cudnn$convForward". Expects an
// lmhlo_gpu.conv_forward op whose dimension numbers, window, algorithm,
// operand/result layouts and result scale mirror the HLO instruction and its
// backend_config. The operand_1_layout expectation was previously spelled
// with a misspelled directive prefix and therefore never evaluated.
// CHECK-LABEL: func @main
// CHECK: lmhlo_gpu.conv_forward
// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1]
// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [0, 0], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [1, 1]}
// CHECK-SAME: algorithm = 2
// CHECK-SAME: tensor_ops_enabled = false
// CHECK-SAME: operand_0_layout = [3, 2, 1, 0]
// CHECK-SAME: operand_1_layout = [3, 2, 1, 0]
// CHECK-SAME: result_layout = [3, 2, 1, 0]
// CHECK-SAME: batch_group_count = 1 : i64
// CHECK-SAME: feature_group_count = 1 : i64
// CHECK-SAME: result_scale = 1.000000e+00 : f64
// CHECK: (memref<4x256x3x3xf32>, memref<256x256x2x2xf32>, memref<4x256x2x2xf32>, memref<65536xui8>)
ENTRY main {
%input = f32[4,256,3,3]{3,2,1,0} parameter(0)
%filter = f32[256,256,2,2]{3,2,1,0} parameter(1)
ROOT %custom-call.1 = (f32[4,256,2,2]{3,2, 1,0}, u8[65536]{0}) custom-call(f32[4,256,3,3]{3,2,1,0} %input, f32[256,256,2,2]{3,2,1,0} %filter),
window={size=2x2 rhs_reversal=1x1}, dim_labels=bf01_oi01->bf01,
custom_call_target="__cudnn$convForward",
backend_config="{\"algorithm\": {\"algo_id\":\"2\",\"math_type\":\"DEFAULT_MATH\"},\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}"
}
// -----
// custom-call targeting "__cudnn$convBiasActivationForward" with three
// operands (input, filter, bias) and "side_input_scale":0. Expects an
// lmhlo_gpu.conv_forward_fused op; "activation_mode":"2" in backend_config
// is expected to print as Relu, and the non-default operand layouts are
// expected to show up both as layout attributes and as memref layout maps.
// CHECK: func @main
// CHECK: lmhlo_gpu.conv_forward_fused
// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1]
// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [1, 1], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [0, 0]}
// CHECK-SAME: activation_mode = #lmhlo_gpu<"activation Relu">
// CHECK-SAME: algorithm = 0
// CHECK-SAME: tensor_ops_enabled = false
// CHECK-SAME: operand_0_layout = [1, 3, 2, 0]
// CHECK-SAME: operand_1_layout = [2, 1, 0, 3]
// CHECK-SAME: result_layout = [1, 3, 2, 0]
// CHECK-SAME: batch_group_count = 1 : i64
// CHECK-SAME: feature_group_count = 1 : i64
// CHECK-SAME: precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
// CHECK-SAME: result_scale = 1.000000e+00 : f64
// CHECK-SAME: (memref<1x17x9x9xf16, #map{{.*}}>, memref<3x3x17x32xf16, #map{{.*}}>, memref<32xf16>, memref<1x32x9x9xf16, #{{.*}}>, memref<0xui8>) -> ()
HloModule FusedConvForward
ENTRY main {
%input = f16[1,17,9,9]{1,3,2,0} parameter(0)
%filter = f16[3,3,17,32]{2,1,0,3} parameter(1)
%bias = f16[32]{0} parameter(2)
ROOT %custom-call.2 = (f16[1,32,9,9]{1,3,2,0}, u8[0]{0}) custom-call(f16[1,17,9,9]{1,3,2,0} %input, f16[3,3,17,32]{2,1,0,3} %filter, f16[32]{0} %bias), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"algorithm\": {\"algo_id\":\"0\",\"math_type\":\"DEFAULT_MATH\"},\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":0}"
}
// -----
// custom-call targeting "__cudnn$convBiasActivationForward" with four
// operands (input, filter, bias, side input) and "side_input_scale":1.
// Expects an lmhlo_gpu.conv_forward_fused_with_side_input op with four
// precision_config entries and side_input_scale = 1.0.
// NOTE(review): the final operand-type expectation names only one
// memref<32xf16>; the side-input operand is presumably absorbed by one of
// the wildcards -- confirm whether it should be listed explicitly.
// CHECK: func @main
// CHECK: lmhlo_gpu.conv_forward_fused_with_side_input
// CHECK-SAME: dim_numbers = [b, f, 0, 1]x[0, 1, i, o]->[b, f, 0, 1]
// CHECK-SAME{LITERAL}: window = {stride = [1, 1], pad = [1, 1], lhs_dilate = [1, 1], rhs_dilate = [1, 1], reverse = [0, 0]}
// CHECK-SAME: activation_mode = #lmhlo_gpu<"activation Relu">
// CHECK-SAME: algorithm = 0
// CHECK-SAME: tensor_ops_enabled = false
// CHECK-SAME: operand_0_layout = [1, 3, 2, 0]
// CHECK-SAME: operand_1_layout = [2, 1, 0, 3]
// CHECK-SAME: result_layout = [1, 3, 2, 0]
// CHECK-SAME: batch_group_count = 1 : i64
// CHECK-SAME: feature_group_count = 1 : i64
// CHECK-SAME: precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
// CHECK-SAME: result_scale = 1.000000e+00 : f64
// CHECK-SAME: side_input_scale = 1.000000e+00
// CHECK-SAME: (memref<1x17x9x9xf16, #map{{.*}}>, memref<3x3x17x32xf16, #map{{.*}}>, memref<32xf16>, memref<1x32x9x9xf16, #{{.*}}>, memref<0xui8>) -> ()
HloModule FusedConvForwardSideInput
ENTRY main {
%input = f16[1,17,9,9]{1,3,2,0} parameter(0)
%filter = f16[3,3,17,32]{2,1,0,3} parameter(1)
%bias = f16[32]{0} parameter(2)
%side = f16[32]{0} parameter(3)
ROOT %custom-call.2 = (f16[1,32,9,9]{1,3,2,0}, u8[0]{0}) custom-call(f16[1,17,9,9]{1,3,2,0} %input, f16[3,3,17,32]{2,1,0,3} %filter, f16[32]{0} %bias, f16[32]{0} %side), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config="{\"algorithm\":{\"algo_id\":\"0\",\"math_type\":\"DEFAULT_MATH\"},\"conv_result_scale\":1,\"activation_mode\":\"2\",\"side_input_scale\":1}"
}
// -----
HloModule Infeed
// infeed producing (f32[3], token[]). Expects an lmhlo.infeed op whose only
// operand is the memref<3xf32> data buffer; the token does not appear among
// the operand types.
// CHECK: func @main
// CHECK: "lmhlo.infeed"
// CHECK-SAME: (memref<3xf32>) -> ()
ENTRY main {
%tok = token[] parameter(0)
ROOT %infeed = (f32[3]{0}, token[]) infeed(token[] %tok)
}
// -----
HloModule Outfeed
// outfeed of a single f32[3] array. Expects an lmhlo.outfeed op with an
// empty config string and only the memref<3xf32> data operand; the token
// does not appear among the operand types.
// CHECK: func @main
// CHECK: "lmhlo.outfeed"
// CHECK-SAME: config = ""
// CHECK-SAME: (memref<3xf32>) -> ()
ENTRY main {
%source = f32[3] parameter(0)
%tok = token[] parameter(1)
ROOT %o = token[] outfeed(f32[3] %source, token[] %tok)
}
// -----
HloModule Outfeed
// outfeed of a tuple produced by a custom-call. Expects the custom-call
// ("foo") to be emitted first, followed by an lmhlo.outfeed op that takes
// one memref per tuple element (memref<3xf32>, memref<5xf16>).
// CHECK: func @main
// CHECK: "lmhlo.custom_call"
// CHECK-SAME: call_target_name = "foo"
// CHECK: "lmhlo.outfeed"
// CHECK-SAME: config = ""
// CHECK-SAME: (memref<3xf32>, memref<5xf16>) -> ()
ENTRY main {
%tok = token[] parameter(0)
%tuple = (f32[3], f16[5]) custom-call(),custom_call_target="foo"
ROOT %o = token[] outfeed((f32[3], f16[5]) %tuple, token[] %tok)
}
// -----
HloModule Test
// Batched dot wrapped in an lmhlo.fusion. lhs is [1,3,4] and rhs is [1,4,5]
// with batch dim 0 on both sides, contracting lhs dim 2 against rhs dim 1,
// so the result is [batch=1, lhs free=3, rhs free=5]. The result shape was
// previously declared as [1,4,5], which does not follow from these
// dimension numbers.
// CHECK: func @main
// CHECK: "lmhlo.fusion"()
// CHECK: "mhlo.dot_general"(%{{.*}}, %{{.*}}) {
// CHECK-SAME: dot_dimension_numbers =
// CHECK-SAME: lhs_batching_dimensions = [0]
// CHECK-SAME: rhs_batching_dimensions = [0]
// CHECK-SAME: lhs_contracting_dimensions = [2]
// CHECK-SAME: rhs_contracting_dimensions = [1]
// CHECK-SAME: precision_config = [#mhlo<"precision DEFAULT">, #mhlo<"precision DEFAULT">]
// CHECK-SAME: : (tensor<1x3x4xf32>, tensor<1x4x5xf32>) -> tensor<1x3x5xf32>
ENTRY main {
%arg0 = f32[1,3,4]{2,1,0} parameter(0)
%arg1 = f32[1,4,5]{2,1,0} parameter(1)
ROOT %out = f32[1,3,5]{2,1,0} dot(%arg0, %arg1), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
}
// -----
HloModule Test
// reshape from f32[2] to f32[1,2]. Expects an mhlo.reshape inside an
// lmhlo.fusion op.
// CHECK: func @main
// CHECK: "lmhlo.fusion"()
// CHECK: "mhlo.reshape"(%{{.*}}) : (tensor<2xf32>) -> tensor<1x2xf32>
ENTRY main {
%arg0 = f32[2]{0} parameter(0)
ROOT %out = f32[1,2]{1,0} reshape(%arg0)
}
// -----
HloModule Test
// reduce-window (max) with padding, strides and window dilation, wrapped in
// an lmhlo.fusion. rhs_dilate is expected to appear as window_dilations.
// Result shape, per spatial dim: floor((in + pad - dilated_window) / stride) + 1
//   dim 1: (17 + 2 - 3) / 4 + 1 = 5
//   dim 2: (31 + 2 - 3) / 4 + 1 = 8
// The result shape is spelled out on the ROOT instruction like every other
// instruction in this file instead of being left to the parser to infer.
max {
%a = f32[] parameter(0)
%b = f32[] parameter(1)
ROOT %c = f32[] maximum(%a, %b)
}
// CHECK: func @main
// CHECK: "lmhlo.fusion"()
// CHECK: "mhlo.reduce_window"(%{{.*}}, %{{.*}}) ({
// CHECK: ^bb0(%[[ARG6:.*]]: tensor<f32>, %[[ARG7:.*]]: tensor<f32>):
// CHECK: %[[RET:.*]] = mhlo.maximum %[[ARG6]], %[[ARG7]] : tensor<f32>
// CHECK: "mhlo.return"(%[[RET]]) : (tensor<f32>) -> ()
// CHECK: }) {
// CHECK-SAME: padding = dense<{{\[}}[0, 0], [2, 0], [0, 2], [0, 0]{{\]}}> : tensor<4x2xi64>,
// CHECK-SAME: window_dilations = dense<[1, 2, 2, 1]> : tensor<4xi64>,
// CHECK-SAME: window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
// CHECK-SAME: window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>}
// CHECK-SAME: : (tensor<2x17x31x7xf32>, tensor<f32>) -> tensor<2x5x8x7xf32>
ENTRY main {
%arg0 = f32[2,17,31,7] parameter(0)
%c = f32[] constant(0)
ROOT %out = f32[2,5,8,7] reduce-window(%arg0, %c), window={size=1x2x2x1 stride=1x4x4x1 pad=0_0x2_0x0_2x0_0 lhs_dilate=1x1x1x1 rhs_dilate=1x2x2x1}, to_apply=max
}
// -----
HloModule Test
// pad with per-dimension low_high_interior padding of 2_4_1 and 3_5_1.
// Expects an mhlo.pad inside an lmhlo.fusion whose edge_padding_low,
// edge_padding_high and interior_padding attributes carry those values.
// CHECK: func @main
// CHECK: "lmhlo.fusion"()
// CHECK: "mhlo.pad"(%{{.*}}, %{{.*}}) {
// CHECK-SAME: edge_padding_high = dense<[4, 5]> : tensor<2xi64>,
// CHECK-SAME: edge_padding_low = dense<[2, 3]> : tensor<2xi64>,
// CHECK-SAME: interior_padding = dense<1> : tensor<2xi64>}
// CHECK-SAME: : (tensor<4x6xf32>, tensor<f32>) -> tensor<13x19xf32>
ENTRY main {
%arg0 = f32[4,6] parameter(0)
%arg1 = f32[] parameter(1)
%out = f32[13,19] pad(%arg0, %arg1), padding=2_4_1x3_5_1
}
// -----
HloModule Test
// transpose with dimensions={1,0,3,2}. Expects an mhlo.transpose inside an
// lmhlo.fusion whose permutation attribute carries the same values.
// CHECK: func @main
// CHECK: "lmhlo.fusion"()
// CHECK: "mhlo.transpose"(%{{.*}}) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xf32>) -> tensor<2x1x4x3xf32>
ENTRY main {
%arg0 = f32[1,2,3,4] parameter(0)
%out = f32[2,1,4,3] transpose(%arg0), dimensions={1,0,3,2}
}
// -----
HloModule Test
// broadcast with dimensions={0}. Expects an mhlo.broadcast_in_dim inside an
// lmhlo.fusion whose broadcast_dimensions attribute carries the same value.
// CHECK: func @main
// CHECK: "lmhlo.fusion"()
// CHECK: "mhlo.broadcast_in_dim"(%{{.*}}) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<1xf32>) -> tensor<10xf32>
ENTRY main {
%arg0 = f32[1] parameter(0)
%out = f32[10] broadcast(%arg0), dimensions={0}
}
// -----
HloModule TestModule
// rng-get-and-update-state with delta=131072. Expects an
// lmhlo.rng_get_and_update_state op carrying the same delta and writing to a
// memref<2xui64>.
// CHECK: func @main
// CHECK: "lmhlo.rng_get_and_update_state"(%{{.*}}) {delta = 131072 : i64} : (memref<2xui64>) -> ()
ENTRY main {
ROOT %rng-get-and-update-state = u64[2]{0} rng-get-and-update-state(), delta=131072
}
// -----
HloModule TestAllGather
// all-gather along dimension 1 over a single replica group of four.
// Expects an lmhlo.all_gather op with the matching all_gather_dimension and
// replica_groups, and use_global_device_ids = false.
// CHECK: func @main
// CHECK: "lmhlo.all_gather"
// CHECK-SAME: all_gather_dimension = 1 : i64
// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>
// CHECK-SAME: use_global_device_ids = false
ENTRY main {
param0 = f32[10,20] parameter(0)
ROOT ag = f32[10,80] all-gather(param0), replica_groups={{0,1,2,3}},
dimensions={1}
}
// -----
// Tuple all-to-all over two operands with a single replica group {0,1}.
// Expects an lmhlo.all_to_all op with constrain_layout = false and the
// matching replica_groups.
// CHECK: func @entry
// CHECK: "lmhlo.all_to_all"
// CHECK-SAME: constrain_layout = false
// CHECK-SAME{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
HloModule TestAllToAll
ENTRY entry {
p0 = f32[128,4]{0,1} parameter(0)
p1 = f32[128,4]{1,0} parameter(1)
ROOT a2a = (f32[128,4]{0,1}, f32[128,4]{1,0}) all-to-all(p0, p1),
replica_groups={{0,1}}
}
// -----
// collective-permute with three source/target pairs and channel_id=2.
// Expects an lmhlo.collective_permute op carrying the same channel handle
// and source_target_pairs.
// CHECK: func @main
// CHECK: "lmhlo.collective_permute"
// CHECK-SAME: channel_id = #mhlo.channel_handle<handle = 2, type = 0>
// CHECK-SAME{LITERAL}: source_target_pairs = dense<[[0, 1], [1, 2], [2, 0]]> : tensor<3x2xi64>
HloModule TestCollectivePermute
ENTRY main {
p0 = f32[128] parameter(0)
ROOT permute = f32[128] collective-permute(p0),
source_target_pairs={{0,1}, {1,2}, {2,0}}, channel_id=2
}
// -----
HloModule TestReplicaId
// replica-id producing a u32 scalar. Expects an lmhlo.replica_id op that
// writes to a memref<ui32>.
// CHECK: func @main
// CHECK: "lmhlo.replica_id"
// CHECK-SAME: (memref<ui32>) -> ()
ENTRY main {
ROOT %replica_id = u32[] replica-id()
}
// -----
HloModule fft
// Inverse FFT over the last two dimensions. Expects an lmhlo.fft op whose
// fft_length and fft_type attributes mirror the HLO instruction.
// CHECK: func @main
// CHECK: "lmhlo.fft"
// CHECK-SAME: fft_length = dense<[8, 32]> : tensor<2xi64>
// CHECK-SAME: fft_type = #mhlo<"fft_type IFFT">
ENTRY main {
%input = c64[5,8,32] parameter(0)
ROOT %fft = c64[5,8,32] fft(c64[5,8,32] %input), fft_type=IFFT, fft_length={8,32}
}
// -----
HloModule TriangularSolve_module
// triangular-solve with lower=true and transpose_a=NO_TRANSPOSE; left_side
// and unit_diagonal are not given on the instruction and are expected to
// print as false. The {1,0} layouts of both operands and the result are
// expected to appear as layout_a / layout_b / layout_output attributes.
// CHECK: func @main
// CHECK: "lmhlo.triangular_solve"
// CHECK-SAME: layout_a = dense<[1, 0]> : tensor<2xindex>
// CHECK-SAME: layout_b = dense<[1, 0]> : tensor<2xindex>
// CHECK-SAME: layout_output = dense<[1, 0]> : tensor<2xindex>
// CHECK-SAME: left_side = false
// CHECK-SAME: lower = true
// CHECK-SAME: transpose_a = #mhlo<"transpose NO_TRANSPOSE">
// CHECK-SAME: unit_diagonal = false
ENTRY main {
%a = f32[4,4]{1,0} parameter(0)
%b = f32[3,4]{1,0} parameter(1)
ROOT %triangular-solve = f32[3,4]{1,0} triangular-solve(f32[4,4]{1,0} %a, f32[3,4]{1,0} %b), lower=true, transpose_a=NO_TRANSPOSE
}
// -----
HloModule CustomCallWithTokens
// custom-call whose only operand is a token and whose result tuple is
// (array, token). Expects num_args / num_results to count the tokens, while
// the target mappings list only the non-token positions: no arguments and
// result 0.
// CHECK: func @main
// CHECK: "lmhlo.custom_call"
// CHECK-SAME: num_args = 1
// CHECK-SAME: num_results = 2
// CHECK-SAME: args_to_target_args = []
// CHECK-SAME: results_to_target_results = [0]
ENTRY main {
%tok = token[] parameter(0)
ROOT %call = (f32[3], token[]) custom-call (%tok), custom_call_target="foo",
backend_config=""
}
// -----
HloModule CustomCallWithTokens
// custom-call with operands (token, array, token) and results
// (array, token, array). Expects the target mappings to list only the
// non-token positions: argument 1 and results 0 and 2.
// CHECK: func @main
// CHECK: "lmhlo.custom_call"
// CHECK-SAME: num_args = 3
// CHECK-SAME: num_results = 3
// CHECK-SAME: args_to_target_args = [1]
// CHECK-SAME: results_to_target_results = [0, 2]
ENTRY main {
%tok = token[] parameter(0)
%input = f32[5,8,32] parameter(1)
ROOT %call = (f32[3]{0}, token[], f32[3]) custom-call (%tok, %input, %tok),
custom_call_target="foo",
backend_config=""
}
// -----
HloModule CustomCallWithTokens
// custom-call with operands (token, array, token) and a single non-tuple
// array result. Expects the target mappings to list argument 1 and result 0.
// CHECK: func @main
// CHECK: "lmhlo.custom_call"
// CHECK-SAME: num_args = 3
// CHECK-SAME: num_results = 1
// CHECK-SAME: args_to_target_args = [1]
// CHECK-SAME: results_to_target_results = [0]
ENTRY main {
%tok = token[] parameter(0)
%input = f32[5,8,32] parameter(1)
ROOT %call = f32[3] custom-call (%tok, %input, %tok),
custom_call_target="foo",
backend_config=""
}
// -----
HloModule CustomCallWithTokens
// custom-call with a single array operand and results
// (token, array, token, token). Expects the target mappings to list
// argument 0 and result 1, the only non-token result.
// CHECK: func @main
// CHECK: "lmhlo.custom_call"
// CHECK-SAME: num_args = 1
// CHECK-SAME: num_results = 4
// CHECK-SAME: args_to_target_args = [0]
// CHECK-SAME: results_to_target_results = [1]
ENTRY main {
%input = f32[5,8,32] parameter(0)
ROOT %call = (token[], f32[3]{0}, token[], token[]) custom-call (%input),
custom_call_target="foo",
backend_config=""
}
// -----
// reduce-scatter (add) along dimension 0 over a single replica group {0,1}.
// Expects an lmhlo.reduce_scatter op with the matching replica_groups and
// scatter_dimension.
// CHECK: func @main
// CHECK: "lmhlo.reduce_scatter"
// CHECK{LITERAL}: replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>
// CHECK-SAME: scatter_dimension = 0 : i64
HloModule ReduceScatter
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY main {
input = f32[8]{0} parameter(0)
ROOT ars = f32[4]{0} reduce-scatter(input), replica_groups={{0,1}}, dimensions={0}, to_apply=add
}
// -----
// while loop whose condition computation is the constant `false`. Expects an
// lmhlo.while op to be emitted for it.
// CHECK: func @main
// CHECK: "lmhlo.while"(%{{.*}}) ({
HloModule WhileConstantCondition
%body {
%parameter.5 = (f32[5]{0}) parameter(0)
%constant_8 = f32[] constant(0)
%broadcast.9 = f32[5]{0} broadcast(f32[] %constant_8), dimensions={}
ROOT %tuple.2 = (f32[5]{0}) tuple(f32[5]{0} %broadcast.9)
}
%cond {
%parameter.12 = (f32[5]{0}) parameter(0)
ROOT %constant_1 = pred[] constant(false)
}
ENTRY %main (parameter.1: f32[5]) -> (f32[5]) {
%parameter.1 = f32[5]{0} parameter(0)
%copy.3 = f32[5]{0} copy(f32[5]{0} %parameter.1)
%tuple = (f32[5]{0}) tuple(f32[5]{0} %copy.3)
ROOT %while.19 = (f32[5]{0}) while((f32[5]{0}) %tuple), condition=%cond, body=%body
}
// -----
// copy of a dynamically bounded (s32[<=2]) value produced by a
// "SliceToDynamic" custom-call. Expects an mhlo.copy on tensor<2xi32>.
// CHECK: func @main
// CHECK: "mhlo.copy"
// CHECK-SAME: (tensor<2xi32>) -> tensor<2xi32>
HloModule CopyTest
ENTRY main {
%parameter.1 = s32[2]{0} parameter(0)
%parameter.2 = s32[] parameter(1)
%custom-call = s32[<=2]{0} custom-call(s32[2]{0} %parameter.1, s32[] %parameter.2), custom_call_target="SliceToDynamic"
ROOT %copy = s32[<=2]{0} copy(s32[<=2]{0} %custom-call)
}