blob: fe6ab0e1860da2ac7ff40bdbe0ef3bd7e1e4ac9f [file] [log] [blame]
// RUN: mlir-hlo-opt %s -hlo-legalize-shape-computations -split-input-file | FileCheck %s
// Scalar size query: the expected lowering reads the requested axis with
// tensor.dim, casts the index result to i32, and wraps it back into a rank-0
// tensor with tensor.from_elements. Constant and cast types are pinned so that
// an unrelated constant (e.g. 10, or an i32 1) cannot be bound by mistake.
// CHECK-LABEL: func @get_dimension_size
func.func @get_dimension_size(%arg0: tensor<?x?xf32>) -> (tensor<i32>) {
  %1 = "mhlo.get_dimension_size"(%arg0) {dimension = 1 : i64} : (tensor<?x?xf32>) -> tensor<i32>
  func.return %1 : tensor<i32>
}
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[DIM:.+]] = tensor.dim %arg0, %[[C1]]
// CHECK-DAG: %[[IDX:.+]] = arith.index_cast %[[DIM]] : index to i32
// CHECK-DAG: %[[FROM:.+]] = tensor.from_elements %[[IDX]]
// CHECK: return %[[FROM]] : tensor<i32>
// -----
// Size query followed by a reshape to tensor<1xi32>: the expected output uses
// the same dim + index_cast sequence, and the rank-1 result is produced directly
// by tensor.from_elements, which is what the function returns.
// CHECK-LABEL: func @reshape_dimension_size
func.func @reshape_dimension_size(%arg0: tensor<?x?xf32>) -> (tensor<1xi32>) {
  %0 = "mhlo.get_dimension_size"(%arg0) {dimension = 1 : i64} : (tensor<?x?xf32>) -> tensor<i32>
  %1 = "mhlo.reshape"(%0) : (tensor<i32>) -> tensor<1xi32>
  func.return %1 : tensor<1xi32>
}
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[DIM:.+]] = tensor.dim %arg0, %[[C1]]
// CHECK-DAG: %[[IDX:.+]] = arith.index_cast %[[DIM]] : index to i32
// CHECK-DAG: %[[FROM:.+]] = tensor.from_elements %[[IDX]] : tensor<1xi32>
// CHECK: return %[[FROM]] : tensor<1xi32>
// -----
// Scalar arithmetic on a size: multiplying the size by a constant 2 is expected
// to become an i32 arith.muli between the index_cast result and a scalar i32
// constant, with the product wrapped by tensor.from_elements.
// CHECK-LABEL: func @multiply_dimension_size
func.func @multiply_dimension_size(%arg0: tensor<?x?xf32>) -> (tensor<i32>) {
  %0 = mhlo.constant dense<2> : tensor<i32>
  %1 = "mhlo.get_dimension_size"(%arg0) {dimension = 1 : i64} : (tensor<?x?xf32>) -> tensor<i32>
  %2 = "mhlo.multiply"(%0, %1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
  func.return %2 : tensor<i32>
}
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i32
// CHECK-DAG: %[[DIM:.+]] = tensor.dim %arg0, %[[C1]]
// CHECK-DAG: %[[IDX:.+]] = arith.index_cast %[[DIM]] : index to i32
// CHECK-DAG: %[[MUL:.+]] = arith.muli %[[IDX]], %[[C2]]
// CHECK-DAG: %[[RES:.+]] = tensor.from_elements %[[MUL]]
// CHECK: return %[[RES]] : tensor<i32>
// -----
// Building a shape vector: concatenating the reshaped size with a constant
// tensor<1xi32> is expected to collapse into one tensor.from_elements whose two
// i32 operands are the cast size and the scalar constant 2, in that order.
// CHECK-LABEL: func @concat_dimension_size
func.func @concat_dimension_size(%arg0: tensor<?x?xf32>) -> (tensor<2xi32>) {
  %0 = "mhlo.get_dimension_size"(%arg0) {dimension = 1 : i64} : (tensor<?x?xf32>) -> tensor<i32>
  %1 = "mhlo.reshape"(%0) : (tensor<i32>) -> tensor<1xi32>
  %2 = mhlo.constant dense<2> : tensor<1xi32>
  %3 = "mhlo.concatenate"(%1, %2) {dimension = 0 : i64} : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
  func.return %3 : tensor<2xi32>
}
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i32
// CHECK-DAG: %[[DIM:.+]] = tensor.dim %arg0, %[[C1]]
// CHECK-DAG: %[[IDX:.+]] = arith.index_cast %[[DIM]] : index to i32
// CHECK-DAG: %[[RES:.+]] = tensor.from_elements %[[IDX]], %[[C2]]
// CHECK: return %[[RES]] : tensor<2xi32>