// blob: a4a161e4b21ec0b34666c9271283a1d9a42538ff [file] [log] [blame]
// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file -allow-unregistered-dialect | FileCheck %s
// CHECK: func @select_and_scatter
// Positive case: a select_and_scatter on a 10x24x24x64 f32 operand with a
// 10x12x12x64 f32 source, window_dimensions and window_strides both
// [1, 2, 2, 1], and no padding attribute. No diagnostic is annotated in this
// function, so the op has to pass verification.
func.func @select_and_scatter(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
// Select region: two scalar f32 arguments, returns (%arg3 >= %arg4) as
// tensor<i1>.
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
compare_type = #mhlo<comparison_type TOTALORDER>,
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
// Scatter region: two scalar f32 arguments, returns their sum.
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return %1 : tensor<10x24x24x64xf32>
}
// CHECK: func @select_and_scatter_with_unranked_dims
// Positive case: both regions take and return unranked tensors
// (tensor<*xbf16> / tensor<*xi1>) and the op result is tensor<?x?x?x?xbf16>,
// while the operand, source and init value are statically shaped. No
// diagnostic is annotated in this function, so the op has to pass
// verification; the FileCheck line above additionally confirms the function is
// printed back, mirroring the check on @select_and_scatter.
func.func @select_and_scatter_with_unranked_dims(
%arg0: tensor<4x5x1x1xbf16>,
%arg1: tensor<2x2x1x1xbf16>,
%arg2: tensor<bf16>) -> tensor<?x?x?x?xbf16> {
// NOTE(review): %0, %1 and %2 are not referenced by the op below, which takes
// its window configuration from attributes; kept as-is to leave the test
// input unchanged.
%0 = mhlo.constant dense<0> : tensor<4x2xi32>
%1 = mhlo.constant dense<[2, 2, 1, 1]> : tensor<4xi32>
%2 = mhlo.constant dense<[2, 3, 1, 1]> : tensor<4xi32>
%3 = "mhlo.select_and_scatter"(%arg0, %arg1, %arg2) ({
// Select region: unranked bf16 arguments, returns (%arg3 >= %arg4).
^bb0(%arg3: tensor<*xbf16>, %arg4: tensor<*xbf16>):
%4 = "mhlo.compare"(%arg3, %arg4) {
compare_type = #mhlo<comparison_type TOTALORDER>,
comparison_direction = #mhlo<comparison_direction GE>}
: (tensor<*xbf16>, tensor<*xbf16>) -> tensor<*xi1>
"mhlo.return"(%4) : (tensor<*xi1>) -> ()
}, {
// Scatter region: unranked bf16 arguments, returns their sum.
^bb0(%arg3: tensor<*xbf16>, %arg4: tensor<*xbf16>):
%4 = "mhlo.add"(%arg3, %arg4) : (tensor<*xbf16>, tensor<*xbf16>) ->
tensor<*xbf16>
"mhlo.return"(%4) : (tensor<*xbf16>) -> ()
}) {
padding = dense<0> : tensor<4x2xi64>,
window_dimensions = dense<[2, 3, 1, 1]> : tensor<4xi64>,
window_strides = dense<[2, 2, 1, 1]> : tensor<4xi64>}
: (tensor<4x5x1x1xbf16>, tensor<2x2x1x1xbf16>, tensor<bf16>) ->
tensor<?x?x?x?xbf16>
func.return %3 : tensor<?x?x?x?xbf16>
}
// -----
// Negative case: the select region declares a single block argument instead
// of two. The annotation inside the function states the diagnostic the
// verifier must emit on the op.
func.func @select_and_scatter_invalid_select_computation(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{expects the select-region to take 2 parameters, but takes 1}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
// Offending part: only one argument in the select region.
^bb0(%arg3: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg3) {
compare_type = #mhlo<comparison_type TOTALORDER>,
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return %1 : tensor<10x24x24x64xf32>
}
// -----
// Negative case: the select region's arguments are tensor<i32> although the
// operand's element type is f32. The annotation inside the function states the
// diagnostic the verifier must emit on the op.
func.func @select_and_scatter_invalid_select_computation(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{expects the type of select-region's parameter at index 0 to be 'tensor<f32>', but got 'tensor<i32>'}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
// Offending part: i32 element type on the select region arguments.
^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<i32>, tensor<i32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return %1 : tensor<10x24x24x64xf32>
}
// -----
// Negative case: the select region's arguments are rank-1 (tensor<1xf32>)
// where the annotated diagnostic asks for tensor<f32>. The region reshapes its
// comparison result back to tensor<i1>, so only the argument types are wrong.
func.func @select_and_scatter_invalid_select_computation(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{expects the type of select-region's parameter at index 0 to be 'tensor<f32>', but got 'tensor<1xf32>'}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
// Offending part: rank-1 select region arguments.
^bb0(%arg3: tensor<1xf32>, %arg4: tensor<1xf32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xi1>
%3 = "mhlo.reshape"(%2) : (tensor<1xi1>) -> tensor<i1>
"mhlo.return"(%3) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return %1 : tensor<10x24x24x64xf32>
}
// -----
// Negative case: the select region returns two values instead of one. The
// annotation inside the function states the diagnostic the verifier must emit
// on the op.
func.func @select_and_scatter_invalid_select_computation(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{expects select-region to return single value, but got: 2}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// Offending part: the same predicate is returned twice.
"mhlo.return"(%2, %2) : (tensor<i1>, tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return %1 : tensor<10x24x24x64xf32>
}
// -----
// Negative case: the select region returns tensor<1xi1> (the predicate
// reshaped to rank 1) where the annotated diagnostic asks for tensor<i1>.
func.func @select_and_scatter_invalid_select_computation(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{expects the return-type of select-region to be tensor<i1>, but got: 'tensor<1xi1>'}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// Offending part: the scalar predicate is reshaped to rank 1 before being
// returned.
%3 = "mhlo.reshape"(%2) : (tensor<i1>) -> tensor<1xi1>
"mhlo.return"(%3) : (tensor<1xi1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return %1 : tensor<10x24x24x64xf32>
}
// -----
// Negative case: the select region returns tuple<tensor<i1>> where the
// annotated diagnostic asks for tensor<i1>.
func.func @select_and_scatter_invalid_select_computation(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{expects the return-type of select-region to be tensor<i1>, but got: 'tuple<tensor<i1>>'}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
// Offending part: the predicate is wrapped in a tuple before being returned.
%3 = "mhlo.tuple"(%2) : (tensor<i1>) -> tuple<tensor<i1>>
"mhlo.return"(%3) : (tuple<tensor<i1>>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return %1 : tensor<10x24x24x64xf32>
}
// -----
// Negative case: the window_dimensions attribute is omitted entirely (only
// window_strides is given), so the annotated diagnostic reports a
// window-dimensions size of 0 against an operand of rank 4.
func.func @select_and_scatter_invalid_attributes(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{expects window-dimensions size == operand rank, but got window-dimensions size: 0 and operand-type: 'tensor<10x24x24x64xf32>' with rank = 4.}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
// Offending part: no window_dimensions entry in this attribute dictionary.
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return %1 : tensor<10x24x24x64xf32>
}
// -----
// Negative case: window_strides has 3 elements while window_dimensions has 4.
// The annotation inside the function states the diagnostic the verifier must
// emit on the op.
func.func @select_and_scatter_invalid_attributes(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{expects window-strides to have same dimension-size as size of window dimensions (4), but got: 3}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
// Offending part: one stride entry short.
window_strides = dense<[1, 2, 2]> : tensor<3xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return %1 : tensor<10x24x24x64xf32>
}
// -----
// Negative case: padding is a 3x2 tensor, i.e. 3 padding entries, while
// window_dimensions has 4 elements. The annotation inside the function states
// the diagnostic the verifier must emit on the op.
func.func @select_and_scatter_invalid_attributes(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{expects padding-entries to have same dimension-size as size of window dimensions (4), but got: 3}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
// Offending part: 3 rows of padding for a 4-element window.
padding = dense<0> : tensor<3x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return %1 : tensor<10x24x24x64xf32>
}
// -----
// Negative case: padding has shape 4x1 where the annotated diagnostic asks for
// {N, 2}.
func.func @select_and_scatter_invalid_attributes(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{expects the shape of padding-attribute to be {N, 2}, but got {4, 1}.}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
// Offending part: a single column per padding row.
padding = dense<[[2], [2], [0], [0]]> : tensor<4x1xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return %1 : tensor<10x24x24x64xf32>
}
// -----
// Negative case: padding is a rank-1 tensor with 5 elements, an odd count.
// The annotation inside the function states the diagnostic the verifier must
// emit on the op.
func.func @select_and_scatter_invalid_attributes(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{expects the padding-entries to have even number of elements, but got 5 elements.}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
// Offending part: five padding values.
padding = dense<[2, 2, 0, 0, 0]> : tensor<5xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return %1 : tensor<10x24x24x64xf32>
}
// -----
// Negative case: the last entry (index 3) of window_dimensions is 0. The
// annotation inside the function states the diagnostic the verifier must emit
// on the op.
func.func @select_and_scatter_invalid_attributes(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{expects window to have positive value for 3-th window dimension, but got 0.}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
// Offending part: zero-sized window in the last position.
window_dimensions = dense<[1, 2, 2, 0]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return %1 : tensor<10x24x24x64xf32>
}
// -----
// Negative case: the first entry (index 0) of window_strides is 0. The
// annotation inside the function states the diagnostic the verifier must emit
// on the op.
func.func @select_and_scatter_invalid_attributes(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> tensor<10x24x24x64xf32> {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{expects window to have positive stride for 0-th window dimension, but got 0.}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
// Offending part: zero stride in the first position.
window_strides = dense<[0, 2, 2, 1]> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return %1 : tensor<10x24x24x64xf32>
}
// -----
// Negative case: the op's result type is tensor<10x24x24x32xf32>, whose last
// dimension (32) differs from the operand's tensor<10x24x24x64xf32>. The
// annotation inside the function states the diagnostic the verifier must emit
// on the op.
func.func @select_and_scatter_invalid_ret_type(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> () {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{expects the return-type to match the operand-type, but got 'tensor<10x24x24x32xf32>' and 'tensor<10x24x24x64xf32>' resp.}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x32xf32>
func.return
}
// -----
// Negative case: the op's result type is tensor<10x24x24x64xi32>, whose
// element type (i32) differs from the operand's f32. The annotation inside the
// function states the diagnostic the verifier must emit on the op.
func.func @select_and_scatter_invalid_ret_type(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> () {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{expects the return-type to match the operand-type, but got 'tensor<10x24x24x64xi32>' and 'tensor<10x24x24x64xf32>' resp.}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xi32>
func.return
}
// -----
// Negative case: the scatter (second) region declares a single block argument
// instead of two. The annotation inside the function states the diagnostic the
// verifier must emit on the op.
func.func @select_and_scatter_invalid_scatter_reducer(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> () {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{Reduction-region must take 2 parameters, but takes 1 parameter(s)}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
// Offending part: only one argument in the scatter region.
^bb0(%arg3: tensor<f32>):
%2 = mhlo.add %arg3, %arg3 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return
}
// -----
// Negative case: the scatter (second) region's terminator returns no values.
// The annotation inside the function states the diagnostic the verifier must
// emit on the op.
func.func @select_and_scatter_invalid_scatter_reducer(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> () {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{The reduction-region expected to return some value(s)}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
// Offending part: the computed sum is dropped and nothing is returned.
"mhlo.return"() : () -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return
}
// -----
// Negative case: the scatter (second) region returns two values instead of
// one. The annotation inside the function states the diagnostic the verifier
// must emit on the op.
func.func @select_and_scatter_invalid_scatter_reducer(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> () {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{Reduction-region here must produce 1 tensors, but produces 2 instead}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
// Offending part: the same sum is returned twice.
"mhlo.return"(%2, %2) : (tensor<f32>, tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return
}
// -----
// Negative case: the scatter (second) region returns tuple<tensor<f32>>
// rather than a tensor. The annotation inside the function states the
// diagnostic the verifier must emit on the op.
func.func @select_and_scatter_invalid_scatter_reducer(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> () {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{Reduction-region here must produce tensor-typed result(s), but produces 'tuple<tensor<f32>>' instead}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
// Offending part: the sum is wrapped in a tuple before being returned.
%3 = "mhlo.tuple"(%2) : (tensor<f32>) -> tuple<tensor<f32>>
"mhlo.return"(%3) : (tuple<tensor<f32>>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return
}
// -----
// Negative case: the scatter (second) region works on bare f32 values rather
// than tensors and returns an f32. Its body uses the generic-form op
// "llvm.add"; the RUN line passes -allow-unregistered-dialect, which
// presumably is what lets that op parse here (confirm against the tool's
// registered dialects).
func.func @select_and_scatter_invalid_scatter_reducer(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> () {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{Reduction-region here must produce tensor-typed result(s), but produces 'f32' instead}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
// Offending part: non-tensor arguments and a non-tensor return value.
^bb0(%arg3: f32, %arg4: f32):
%2 = "llvm.add" (%arg3, %arg4) : (f32, f32) -> f32
"mhlo.return"(%2) : (f32) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return
}
// -----
// Negative case: the scatter (second) region takes tensor<f32> arguments but
// returns a tensor<i32> constant, so its argument type at index 0 and its
// result type disagree. The annotation inside the function states the
// diagnostic the verifier must emit on the op.
func.func @select_and_scatter_invalid_scatter_reducer(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> () {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{The type of reduction-region's parameter at index 0 is different than the corresponding result type: 'tensor<f32>' vs 'tensor<i32>'}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
// Offending part: an i32 constant is returned instead of an f32 value.
%2 = mhlo.constant dense<0> : tensor<i32>
"mhlo.return"(%2) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return
}
// -----
// Negative case: the scatter (second) region's argument at index 1 is
// tensor<i32> while the region returns tensor<f32>. The annotation inside the
// function states the diagnostic the verifier must emit on the op.
func.func @select_and_scatter_invalid_scatter_reducer(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> () {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{The type of reduction-region's parameter at index 1 is different than the corresponding result type: 'tensor<i32>' vs 'tensor<f32>'}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
// Offending part: the second argument is i32; the body only uses %arg3.
^bb0(%arg3: tensor<f32>, %arg4: tensor<i32>):
%2 = mhlo.add %arg3, %arg3 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return
}
// -----
// Negative case: the init value is tensor<i32> and the scatter (second)
// region consistently uses tensor<i32>, while the operand and source are f32.
// The annotated diagnostic asks for element type 'f32' on the region's
// argument at index 1.
func.func @select_and_scatter_invalid_scatter_reducer(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> () {
%0 = mhlo.constant dense<0> : tensor<i32>
// expected-error @+1 {{The element-type of reduction-region's argument at index 1 is expected to be 'f32', but got 'tensor<i32>' as its type.}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
// Offending part: i32 scatter region (matching the i32 init value) on f32
// data.
^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
%2 = mhlo.add %arg3, %arg4 : tensor<i32>
"mhlo.return"(%2) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<i32>) ->
tensor<10x24x24x64xf32>
func.return
}
// -----
// Negative case: the scatter (second) region uses tensor<i32> for its
// arguments and result while the init value is tensor<f32>. The annotation
// inside the function states the diagnostic the verifier must emit on the op.
func.func @select_and_scatter_invalid_scatter_reducer(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> () {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{The type of reduction-region's result type at index 0 differs from the op's corresponding init-value type: 'tensor<i32>' vs 'tensor<f32>'}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
// Offending part: i32 scatter region against an f32 init value.
^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
%2 = mhlo.add %arg3, %arg4 : tensor<i32>
"mhlo.return"(%2) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return
}
// -----
// Negative case: the init value is rank-1 (tensor<1xf32>) and the scatter
// (second) region consistently uses tensor<1xf32>. The annotated diagnostic
// asks for rank <= 0 on the region's argument at index 1.
func.func @select_and_scatter_invalid_scatter_reducer(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xf32>) -> () {
%0 = mhlo.constant dense<0.000000e+00> : tensor<1xf32>
// expected-error @+1 {{The rank of reduction-region's argument at index 1 is expected to be <= 0, got 1}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
// Offending part: rank-1 scatter region arguments.
^bb0(%arg3: tensor<1xf32>, %arg4: tensor<1xf32>):
%2 = mhlo.add %arg3, %arg4 : tensor<1xf32>
"mhlo.return"(%2) : (tensor<1xf32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xf32>, tensor<1xf32>) ->
tensor<10x24x24x64xf32>
func.return
}
// -----
// Negative case: the source operand is tensor<10x12x12x32xf32>; the annotated
// diagnostic asks for tensor<10x12x12x64xf32>, so only the last dimension
// (32 vs 64) is wrong.
func.func @select_and_scatter_invalid_source_operand(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x32xf32>) -> () {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
// expected-error @+1 {{expects source-type to be 'tensor<10x12x12x64xf32>', but got'tensor<10x12x12x32xf32>'}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = mhlo.add %arg3, %arg4 : tensor<f32>
"mhlo.return"(%2) : (tensor<f32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x32xf32>, tensor<f32>) ->
tensor<10x24x24x64xf32>
func.return
}
// -----
// Negative case: the source operand is tensor<10x12x12x64xi32>; the annotated
// diagnostic asks for tensor<10x12x12x64xf32>, so only the element type is
// wrong. The init value and the scatter (second) region are i32 as well.
func.func @select_and_scatter_invalid_source_operand(
%arg0: tensor<10x24x24x64xf32>,
%arg1: tensor<10x12x12x64xi32>) -> () {
%0 = mhlo.constant dense<0> : tensor<i32>
// expected-error @+1 {{expects source-type to be 'tensor<10x12x12x64xf32>', but got'tensor<10x12x12x64xi32>'}}
%1 = "mhlo.select_and_scatter"(%arg0, %arg1, %0) ({
^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
%2 = "mhlo.compare"(%arg3, %arg4) {
comparison_direction = #mhlo<comparison_direction GE>
} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%2) : (tensor<i1>) -> ()
}, {
^bb0(%arg3: tensor<i32>, %arg4: tensor<i32>):
%2 = mhlo.add %arg3, %arg4 : tensor<i32>
"mhlo.return"(%2) : (tensor<i32>) -> ()
}) {
window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>,
window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>,
padding = dense<0> : tensor<4x2xi64>
} : (tensor<10x24x24x64xf32>, tensor<10x12x12x64xi32>, tensor<i32>) ->
tensor<10x24x24x64xf32>
func.return
}