// blob: d626f5208244c74acc598d922bffb62d039fcdfe [file] [log] [blame]
// RUN: mlir-hlo-opt --mhlo-test-infer-shaped-type-methods --allow-unregistered-dialect --split-input-file %s | FileCheck %s
// -----
// Reifies the result shape of a partially dynamic mhlo.select. The expected
// output builds the shape tensor from the static extent 2 and the runtime
// extent of dimension 1 of the predicate (first) operand.
// CHECK-LABEL: @select
// CHECK-SAME: (%[[PRED_ARG:.*]]: tensor<2x?xi1>,
func @select(%cond : tensor<2x?xi1>, %on_true : tensor<2x?xf32>,
             %on_false : tensor<2x?xf32>) -> tensor<2xi64> {
  // CHECK: %[[EXTENT0:.*]] = constant 2 : i64
  // CHECK: %[[DIM_IDX:.*]] = constant 1 : index
  // CHECK: %[[DYN_IDX:.*]] = dim %[[PRED_ARG]], %[[DIM_IDX]] : tensor<2x?xi1>
  // CHECK: %[[DYN_I64:.*]] = index_cast %[[DYN_IDX]] : index to i64
  // CHECK: %[[RESULT_SHAPE:.*]] = tensor_from_elements %[[EXTENT0]], %[[DYN_I64]] : tensor<2xi64>
  // CHECK: return %[[RESULT_SHAPE]] : tensor<2xi64>
  %sel = "mhlo.select"(%cond, %on_true, %on_false)
      : (tensor<2x?xi1>, tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
  %shape = "mhlo_test.reify_return_type_shapes"(%sel)
      : (tensor<2x?xf32>) -> tensor<2xi64>
  return %shape : tensor<2xi64>
}
// -----
// Reifies the result shape of a partially dynamic mhlo.compare. The expected
// output builds the shape tensor from the static extent 2 and the runtime
// extent of dimension 1 of the first (lhs) operand.
// CHECK-LABEL: @compare
// CHECK-SAME: (%[[LHS_ARG:.*]]: tensor<2x?xf32>,
func @compare(%lhs : tensor<2x?xf32>, %rhs : tensor<2x?xf32>)
    -> tensor<2xi64> {
  // CHECK: %[[EXTENT0:.*]] = constant 2 : i64
  // CHECK: %[[DIM_IDX:.*]] = constant 1 : index
  // CHECK: %[[DYN_IDX:.*]] = dim %[[LHS_ARG]], %[[DIM_IDX]] : tensor<2x?xf32>
  // CHECK: %[[DYN_I64:.*]] = index_cast %[[DYN_IDX]] : index to i64
  // CHECK: %[[RESULT_SHAPE:.*]] = tensor_from_elements %[[EXTENT0]], %[[DYN_I64]] : tensor<2xi64>
  // CHECK: return %[[RESULT_SHAPE]] : tensor<2xi64>
  %cmp = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "NE"}
      : (tensor<2x?xf32>, tensor<2x?xf32>) -> tensor<2x?xi1>
  %shape = "mhlo_test.reify_return_type_shapes"(%cmp)
      : (tensor<2x?xi1>) -> tensor<2xi64>
  return %shape : tensor<2xi64>
}