// blob: 0e66b80a4c77a3e1a5fc7803a101432e8e17aa13 [file] [log] [blame]
// RUN: mlir-hlo-opt -hlo-legalize-sort %s | FileCheck %s
// Test input: a stable mhlo.sort over two statically shaped rank-1 i32
// tensors along dimension 0. The comparator reads only its first two block
// arguments and applies a signed greater-than comparison to them; the last
// two block arguments are declared but never read.
func.func @sort(%arg0 : tensor<2xi32>, %arg1 : tensor<2xi32>) -> (tensor<2xi32>, tensor<2xi32>) {
  %sorted:2 = "mhlo.sort"(%arg0, %arg1) ({
  ^bb0(%lhs0: tensor<i32>, %rhs0: tensor<i32>, %lhs1: tensor<i32>, %rhs1: tensor<i32>):
    // Unwrap the rank-0 tensors so the comparison runs on scalars.
    %lhs = tensor.extract %lhs0[] : tensor<i32>
    %rhs = tensor.extract %rhs0[] : tensor<i32>
    %lhs_gt_rhs = arith.cmpi sgt, %lhs, %rhs : i32
    // The comparator hands back a rank-0 i1 tensor, so re-wrap the scalar.
    %packed = tensor.from_elements %lhs_gt_rhs : tensor<i1>
    "mhlo.return"(%packed) : (tensor<i1>) -> ()
  }) {dimension = 0 : i64, is_stable = true} : (tensor<2xi32>, tensor<2xi32>) -> (tensor<2xi32>, tensor<2xi32>)
  func.return %sorted#0, %sorted#1 : tensor<2xi32>, tensor<2xi32>
}
// CHECK-LABEL: func @sort(
// CHECK-SAME: %[[ARG0:.*]]: tensor<2xi32>,
// CHECK-SAME: %[[ARG1:.*]]: tensor<2xi32>) -> (tensor<2xi32>, tensor<2xi32>) {
// Upper bound of the outer loop: the size of dimension 0
// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0
// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0
// CHECK: %[[VAL_4:.*]] = tensor.dim %[[ARG0]], %[[C0_1]] : tensor<2xi32>
// CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index
// Outer loop: one pass per element of dimension 0, carrying both tensors as
// iter_args (its induction variable is not referenced by any later pattern)
// CHECK: %[[VAL_6:.*]]:2 = scf.for %[[VAL_7:.*]] = %[[C0_0]] to %[[VAL_4]] step %[[C1_0]] iter_args(%[[VAL_8:.*]] = %[[ARG0]], %[[VAL_9:.*]] = %[[ARG1]]) -> (tensor<2xi32>, tensor<2xi32>) {
// Upper bound of the inner loop: static size of the sorted dimension minus one
// CHECK-DAG: %[[C0_2:.*]] = arith.constant 0
// CHECK-DAG: %[[C1_1:.*]] = arith.constant 1
// CHECK-DAG: %[[C2_0:.*]] = arith.constant 2
// CHECK: %[[VAL_13:.*]] = arith.subi %[[C2_0]], %[[C1_1]] : index
// Inner loop: walk adjacent index pairs (i, i + 1) along the sorted dimension
// CHECK: %[[VAL_14:.*]]:2 = scf.for %[[VAL_15:.*]] = %[[C0_2]] to %[[VAL_13]] step %[[C1_1]] iter_args(%[[VAL_16:.*]] = %[[VAL_8]], %[[VAL_17:.*]] = %[[VAL_9]]) -> (tensor<2xi32>, tensor<2xi32>) {
// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[C1_1]] : index
// Load the elements at i and i + 1 from both tensors and wrap each in a rank-0 tensor
// CHECK: %[[VAL_19:.*]] = tensor.extract %[[VAL_16]]{{\[}}%[[VAL_15]]] : tensor<2xi32>
// CHECK: %[[VAL_20:.*]] = tensor.from_elements %[[VAL_19]] : tensor<i32>
// CHECK: %[[VAL_21:.*]] = tensor.extract %[[VAL_16]]{{\[}}%[[VAL_18]]] : tensor<2xi32>
// CHECK: %[[VAL_22:.*]] = tensor.from_elements %[[VAL_21]] : tensor<i32>
// CHECK: %[[VAL_23:.*]] = tensor.extract %[[VAL_17]]{{\[}}%[[VAL_15]]] : tensor<2xi32>
// CHECK: %[[VAL_24:.*]] = tensor.from_elements %[[VAL_23]] : tensor<i32>
// CHECK: %[[VAL_25:.*]] = tensor.extract %[[VAL_17]]{{\[}}%[[VAL_18]]] : tensor<2xi32>
// CHECK: %[[VAL_26:.*]] = tensor.from_elements %[[VAL_25]] : tensor<i32>
// The comparator appears twice, once per direction, so each scalar is
// extracted twice because CSE has not run yet
// CHECK: %[[VAL_27:.*]] = tensor.extract %[[VAL_22]][] : tensor<i32>
// CHECK: %[[VAL_28:.*]] = tensor.extract %[[VAL_20]][] : tensor<i32>
// CHECK: %[[VAL_29:.*]] = arith.cmpi sgt, %[[VAL_27]], %[[VAL_28]] : i32
// CHECK: %[[VAL_30:.*]] = tensor.from_elements %[[VAL_29]] : tensor<i1>
// CHECK: %[[VAL_31:.*]] = tensor.extract %[[VAL_20]][] : tensor<i32>
// CHECK: %[[VAL_32:.*]] = tensor.extract %[[VAL_22]][] : tensor<i32>
// CHECK: %[[VAL_33:.*]] = arith.cmpi sgt, %[[VAL_31]], %[[VAL_32]] : i32
// CHECK: %[[VAL_34:.*]] = tensor.from_elements %[[VAL_33]] : tensor<i1>
// Extract comparison results that were packed back into tensors by mhlo
// CHECK: %[[VAL_35:.*]] = tensor.extract %[[VAL_34]][] : tensor<i1>
// CHECK: %[[VAL_36:.*]] = tensor.extract %[[VAL_30]][] : tensor<i1>
// Determine if swapping should occur which happens only if NOT(CMP(A,B)) && CMP(B,A),
// where A is the element at i and B is the element at i + 1
// CHECK: %[[TRUE:.*]] = arith.constant true
// CHECK: %[[VAL_38:.*]] = arith.xori %[[VAL_35]], %[[TRUE]] : i1
// CHECK: %[[VAL_39:.*]] = arith.andi %[[VAL_38]], %[[VAL_36]] : i1
// CHECK: %[[VAL_40:.*]]:2 = scf.if %[[VAL_39]] -> (tensor<2xi32>, tensor<2xi32>) {
// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_15]], %[[C1_1]] : index
// Swap first pair of values
// CHECK: %[[VAL_42:.*]] = tensor.extract %[[VAL_22]][] : tensor<i32>
// CHECK: %[[VAL_43:.*]] = tensor.insert %[[VAL_42]] into %[[VAL_16]]{{\[}}%[[VAL_15]]] : tensor<2xi32>
// CHECK: %[[VAL_44:.*]] = tensor.extract %[[VAL_20]][] : tensor<i32>
// CHECK: %[[VAL_45:.*]] = tensor.insert %[[VAL_44]] into %[[VAL_43]]{{\[}}%[[VAL_41]]] : tensor<2xi32>
// Swap second pair of values
// CHECK: %[[VAL_46:.*]] = tensor.extract %[[VAL_26]][] : tensor<i32>
// CHECK: %[[VAL_47:.*]] = tensor.insert %[[VAL_46]] into %[[VAL_17]]{{\[}}%[[VAL_15]]] : tensor<2xi32>
// CHECK: %[[VAL_48:.*]] = tensor.extract %[[VAL_24]][] : tensor<i32>
// CHECK: %[[VAL_49:.*]] = tensor.insert %[[VAL_48]] into %[[VAL_47]]{{\[}}%[[VAL_41]]] : tensor<2xi32>
// CHECK: scf.yield %[[VAL_45]], %[[VAL_49]] : tensor<2xi32>, tensor<2xi32>
// CHECK: } else {
// Don't swap
// CHECK: scf.yield %[[VAL_16]], %[[VAL_17]] : tensor<2xi32>, tensor<2xi32>
// CHECK: }
// Propagate values back through the loops. These patterns reuse the names
// captured at the defining ops, so the yielded and returned operands are
// actually verified rather than re-captured.
// CHECK: scf.yield %[[VAL_40]]#0, %[[VAL_40]]#1 : tensor<2xi32>, tensor<2xi32>
// CHECK: }
// CHECK: scf.yield %[[VAL_14]]#0, %[[VAL_14]]#1 : tensor<2xi32>, tensor<2xi32>
// CHECK: }
// CHECK: return %[[VAL_6]]#0, %[[VAL_6]]#1 : tensor<2xi32>, tensor<2xi32>
// CHECK: }
// Test input: the same stable mhlo.sort along dimension 0, but over two
// dynamically sized rank-1 i32 tensors. The comparator reads only its first
// two block arguments and applies a signed greater-than comparison to them;
// the last two block arguments are declared but never read.
func.func @dyn_sort(%arg0 : tensor<?xi32>, %arg1 : tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>) {
  %sorted:2 = "mhlo.sort"(%arg0, %arg1) ({
  ^bb0(%lhs0: tensor<i32>, %rhs0: tensor<i32>, %lhs1: tensor<i32>, %rhs1: tensor<i32>):
    // Unwrap the rank-0 tensors so the comparison runs on scalars.
    %lhs = tensor.extract %lhs0[] : tensor<i32>
    %rhs = tensor.extract %rhs0[] : tensor<i32>
    %lhs_gt_rhs = arith.cmpi sgt, %lhs, %rhs : i32
    // The comparator hands back a rank-0 i1 tensor, so re-wrap the scalar.
    %packed = tensor.from_elements %lhs_gt_rhs : tensor<i1>
    "mhlo.return"(%packed) : (tensor<i1>) -> ()
  }) {dimension = 0 : i64, is_stable = true} : (tensor<?xi32>, tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>)
  func.return %sorted#0, %sorted#1 : tensor<?xi32>, tensor<?xi32>
}
// CHECK-LABEL: func @dyn_sort(
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xi32>,
// CHECK-SAME: %[[ARG1:.*]]: tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>) {
// Upper bound of the outer loop: the runtime size of dimension 0
// CHECK-DAG: %[[C0_0:.*]] = arith.constant 0
// CHECK-DAG: %[[C0_1:.*]] = arith.constant 0
// CHECK: %[[VAL_4:.*]] = tensor.dim %[[ARG0]], %[[C0_1]] : tensor<?xi32>
// CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index
// Outer loop: one pass per element of dimension 0, carrying both tensors as
// iter_args (its induction variable is not referenced by any later pattern)
// CHECK: %[[VAL_6:.*]]:2 = scf.for %[[VAL_7:.*]] = %[[C0_0]] to %[[VAL_4]] step %[[C1_0]] iter_args(%[[VAL_8:.*]] = %[[ARG0]], %[[VAL_9:.*]] = %[[ARG1]]) -> (tensor<?xi32>, tensor<?xi32>) {
// Upper bound of the inner loop: runtime size of the sorted dimension minus
// one, queried with tensor.dim because the shape is dynamic
// CHECK-DAG: %[[C0_2:.*]] = arith.constant 0
// CHECK-DAG: %[[C1_1:.*]] = arith.constant 1
// CHECK-DAG: %[[C0_3:.*]] = arith.constant 0
// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C0_3]] : tensor<?xi32>
// CHECK: %[[VAL_13:.*]] = arith.subi %[[DIM]], %[[C1_1]] : index
// Inner loop: walk adjacent index pairs (i, i + 1) along the sorted dimension
// CHECK: %[[VAL_14:.*]]:2 = scf.for %[[VAL_15:.*]] = %[[C0_2]] to %[[VAL_13]] step %[[C1_1]] iter_args(%[[VAL_16:.*]] = %[[VAL_8]], %[[VAL_17:.*]] = %[[VAL_9]]) -> (tensor<?xi32>, tensor<?xi32>) {
// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_15]], %[[C1_1]] : index
// Load the elements at i and i + 1 from both tensors and wrap each in a rank-0 tensor
// CHECK: %[[VAL_19:.*]] = tensor.extract %[[VAL_16]]{{\[}}%[[VAL_15]]] : tensor<?xi32>
// CHECK: %[[VAL_20:.*]] = tensor.from_elements %[[VAL_19]] : tensor<i32>
// CHECK: %[[VAL_21:.*]] = tensor.extract %[[VAL_16]]{{\[}}%[[VAL_18]]] : tensor<?xi32>
// CHECK: %[[VAL_22:.*]] = tensor.from_elements %[[VAL_21]] : tensor<i32>
// CHECK: %[[VAL_23:.*]] = tensor.extract %[[VAL_17]]{{\[}}%[[VAL_15]]] : tensor<?xi32>
// CHECK: %[[VAL_24:.*]] = tensor.from_elements %[[VAL_23]] : tensor<i32>
// CHECK: %[[VAL_25:.*]] = tensor.extract %[[VAL_17]]{{\[}}%[[VAL_18]]] : tensor<?xi32>
// CHECK: %[[VAL_26:.*]] = tensor.from_elements %[[VAL_25]] : tensor<i32>
// The comparator appears twice, once per direction, so each scalar is
// extracted twice because CSE has not run yet
// CHECK: %[[VAL_27:.*]] = tensor.extract %[[VAL_22]][] : tensor<i32>
// CHECK: %[[VAL_28:.*]] = tensor.extract %[[VAL_20]][] : tensor<i32>
// CHECK: %[[VAL_29:.*]] = arith.cmpi sgt, %[[VAL_27]], %[[VAL_28]] : i32
// CHECK: %[[VAL_30:.*]] = tensor.from_elements %[[VAL_29]] : tensor<i1>
// CHECK: %[[VAL_31:.*]] = tensor.extract %[[VAL_20]][] : tensor<i32>
// CHECK: %[[VAL_32:.*]] = tensor.extract %[[VAL_22]][] : tensor<i32>
// CHECK: %[[VAL_33:.*]] = arith.cmpi sgt, %[[VAL_31]], %[[VAL_32]] : i32
// CHECK: %[[VAL_34:.*]] = tensor.from_elements %[[VAL_33]] : tensor<i1>
// Extract comparison results that were packed back into tensors by mhlo
// CHECK: %[[VAL_35:.*]] = tensor.extract %[[VAL_34]][] : tensor<i1>
// CHECK: %[[VAL_36:.*]] = tensor.extract %[[VAL_30]][] : tensor<i1>
// Determine if swapping should occur which happens only if NOT(CMP(A,B)) && CMP(B,A),
// where A is the element at i and B is the element at i + 1
// CHECK: %[[TRUE:.*]] = arith.constant true
// CHECK: %[[VAL_38:.*]] = arith.xori %[[VAL_35]], %[[TRUE]] : i1
// CHECK: %[[VAL_39:.*]] = arith.andi %[[VAL_38]], %[[VAL_36]] : i1
// CHECK: %[[VAL_40:.*]]:2 = scf.if %[[VAL_39]] -> (tensor<?xi32>, tensor<?xi32>) {
// CHECK: %[[VAL_41:.*]] = arith.addi %[[VAL_15]], %[[C1_1]] : index
// Swap first pair of values
// CHECK: %[[VAL_42:.*]] = tensor.extract %[[VAL_22]][] : tensor<i32>
// CHECK: %[[VAL_43:.*]] = tensor.insert %[[VAL_42]] into %[[VAL_16]]{{\[}}%[[VAL_15]]] : tensor<?xi32>
// CHECK: %[[VAL_44:.*]] = tensor.extract %[[VAL_20]][] : tensor<i32>
// CHECK: %[[VAL_45:.*]] = tensor.insert %[[VAL_44]] into %[[VAL_43]]{{\[}}%[[VAL_41]]] : tensor<?xi32>
// Swap second pair of values
// CHECK: %[[VAL_46:.*]] = tensor.extract %[[VAL_26]][] : tensor<i32>
// CHECK: %[[VAL_47:.*]] = tensor.insert %[[VAL_46]] into %[[VAL_17]]{{\[}}%[[VAL_15]]] : tensor<?xi32>
// CHECK: %[[VAL_48:.*]] = tensor.extract %[[VAL_24]][] : tensor<i32>
// CHECK: %[[VAL_49:.*]] = tensor.insert %[[VAL_48]] into %[[VAL_47]]{{\[}}%[[VAL_41]]] : tensor<?xi32>
// CHECK: scf.yield %[[VAL_45]], %[[VAL_49]] : tensor<?xi32>, tensor<?xi32>
// CHECK: } else {
// Don't swap
// CHECK: scf.yield %[[VAL_16]], %[[VAL_17]] : tensor<?xi32>, tensor<?xi32>
// CHECK: }
// Propagate values back through the loops. These patterns reuse the names
// captured at the defining ops, so the yielded and returned operands are
// actually verified rather than re-captured.
// CHECK: scf.yield %[[VAL_40]]#0, %[[VAL_40]]#1 : tensor<?xi32>, tensor<?xi32>
// CHECK: }
// CHECK: scf.yield %[[VAL_14]]#0, %[[VAL_14]]#1 : tensor<?xi32>, tensor<?xi32>
// CHECK: }
// CHECK: return %[[VAL_6]]#0, %[[VAL_6]]#1 : tensor<?xi32>, tensor<?xi32>
// CHECK: }