// blob: 987a6e4e94f79d07db120154a0b4cc2bff65a91e [file] [log] [blame]
// RUN: mlir-hlo-opt -tensor-bufferize -hlo-legalize-to-memref -cse --split-input-file %s | FileCheck %s
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + d2 * s2)>
// CHECK-LABEL: func @dyn_broadcast
// Test input: broadcasts a rank-2, fully dynamic f32 tensor to rank 3.
// The target shape is assembled in-function from the i64 constant 1
// (i.e. 1x1x1), and the operand's two dimensions are mapped onto result
// dimensions 1 and 2.
func.func @dyn_broadcast(%operand: tensor<?x?xf32>) -> tensor<?x?x?xf32> {
  // CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
  // Every extent of the requested output shape is the same constant.
  %one = arith.constant 1 : i64
  %dims = tensor.from_elements %one, %one, %one : tensor<3xi64>
  %bcast = "mhlo.dynamic_broadcast_in_dim"(%operand, %dims)
      {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
      : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
  func.return %bcast : tensor<?x?x?xf32>
}
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[OPERAND:.*]] = bufferization.to_memref %[[ARG]]
// CHECK: %[[OPER_DIM_1:.*]] = memref.dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
// CHECK: %[[OPER_DIM_0:.*]] = memref.dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
// CHECK: %[[EXPAND_1:.*]] = arith.cmpi slt, %[[OPER_DIM_0]], %[[C1]] : index
// CHECK: %[[STRIDE_1:.*]] = arith.select %[[EXPAND_1]], %[[C0]], %[[OPER_DIM_1]] : index
// CHECK: %[[EXPAND_2:.*]] = arith.cmpi slt, %[[OPER_DIM_1]], %[[C1]] : index
// CHECK: %[[STRIDE_2:.*]] = arith.select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref.reinterpret_cast %[[OPERAND]] to offset: [0], sizes: [%[[C1]], %[[C1]], %[[C1]]], strides: [%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xf32> to memref<?x?x?xf32, #map>
// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[TRANSFORMED_MEMREF]]
// CHECK: return %[[RESULT]]
// -----
// CHECK-LABEL: func @dyn_broadcast_unsigned
// Test input: broadcasts a rank-2, fully dynamic i32 tensor to rank 3, with
// the target shape supplied by the caller as a 3-element i64 tensor. The
// operand's two dimensions are mapped onto result dimensions 1 and 2.
// NOTE(review): despite the "_unsigned" suffix, the element type used here is
// the signless `i32`, not `ui32` - confirm whether that is intentional.
func.func @dyn_broadcast_unsigned(%operand: tensor<?x?xi32>, %shape: tensor<3xi64>) -> tensor<?x?x?xi32> {
  // CHECK-SAME: %[[ARG:.*]]: tensor<?x?xi32>, %[[SHAPE:.*]]: tensor<3xi64>
  %result = "mhlo.dynamic_broadcast_in_dim"(%operand, %shape) {
    broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
  } : (tensor<?x?xi32>, tensor<3xi64>) -> tensor<?x?x?xi32>
  func.return %result : tensor<?x?x?xi32>
}
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[OPERAND:.*]] = bufferization.to_memref %[[ARG]]
// CHECK: %[[OPER_DIM_1:.*]] = memref.dim %[[OPERAND]], %[[C1]] : memref<?x?xi32>
// CHECK: %[[OPER_DIM_0:.*]] = memref.dim %[[OPERAND]], %[[C0]] : memref<?x?xi32>
// CHECK: %[[EL0:.*]] = tensor.extract %[[SHAPE]][%[[C0]]] : tensor<3xi64>
// CHECK: %[[SIZE_0:.*]] = arith.index_cast %[[EL0]] : i64 to index
// CHECK: %[[EL1:.*]] = tensor.extract %[[SHAPE]][%[[C1]]] : tensor<3xi64>
// CHECK: %[[SIZE_1:.*]] = arith.index_cast %[[EL1]] : i64 to index
// CHECK: %[[EXPAND_1:.*]] = arith.cmpi slt, %[[OPER_DIM_0]], %[[SIZE_1]] : index
// CHECK: %[[STRIDE_1:.*]] = arith.select %[[EXPAND_1]], %[[C0]], %[[OPER_DIM_1]] : index
// CHECK: %[[EL2:.*]] = tensor.extract %[[SHAPE]][%[[C2]]] : tensor<3xi64>
// CHECK: %[[SIZE_2:.*]] = arith.index_cast %[[EL2]] : i64 to index
// CHECK: %[[EXPAND_2:.*]] = arith.cmpi slt, %[[OPER_DIM_1]], %[[SIZE_2]] : index
// CHECK: %[[STRIDE_2:.*]] = arith.select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref.reinterpret_cast %[[OPERAND]] to offset: [0], sizes: [%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: [%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]] : memref<?x?xi32> to memref<?x?x?xi32, #map>
// CHECK: %[[RESULT:.*]] = bufferization.to_tensor %[[TRANSFORMED_MEMREF]]
// CHECK: return %[[RESULT]]
// -----
// CHECK-LABEL: func @dyn_reshape_unsigned
// Test input: reshapes a rank-2, fully dynamic i32 tensor into a rank-3
// dynamic tensor whose extents are given at runtime by a 3-element i64
// shape tensor.
// NOTE(review): despite the "_unsigned" suffix, the element type used here is
// the signless `i32`, not `ui32` - confirm whether that is intentional.
func.func @dyn_reshape_unsigned(%operand: tensor<?x?xi32>, %shape: tensor<3xi64>) -> tensor<?x?x?xi32> {
  // CHECK-SAME: %[[ARG:.*]]: tensor<?x?xi32>, %[[SHAPE:.*]]: tensor<3xi64>
  %result = "mhlo.dynamic_reshape"(%operand, %shape)
      : (tensor<?x?xi32>, tensor<3xi64>) -> tensor<?x?x?xi32>
  func.return %result : tensor<?x?x?xi32>
}
// CHECK-DAG: %[[OPERAND:.*]] = bufferization.to_memref %[[ARG]]
// CHECK-DAG: %[[BSHAPE:.*]] = bufferization.to_memref %[[SHAPE]]
// CHECK: %[[RESHAPED:.*]] = memref.reshape %[[OPERAND]](%[[BSHAPE]]) : (memref<?x?xi32>, memref<3xi64>) -> memref<?x?x?xi32>
// CHECK: %[[TRESULT:.*]] = bufferization.to_tensor %[[RESHAPED]] : memref<?x?x?xi32>
// CHECK: return %[[TRESULT]]
// -----
// CHECK-LABEL: func @dyn_reshape_of_extract_slice
// Test input: dynamic reshape whose operand is itself a slice. An 11-element,
// unit-stride window starting at %offset is taken from a 1-D dynamic i32
// tensor and then reshaped to rank 3 using a runtime shape tensor.
func.func @dyn_reshape_of_extract_slice(%operand: tensor<?xi32>, %shape: tensor<3xi64>, %offset: index) -> tensor<?x?x?xi32> {
  // CHECK-SAME: %[[ARG:.*]]: tensor<?xi32>, %[[SHAPE:.*]]: tensor<3xi64>, %[[OFFSET:.*]]: index
  // Slice parameters are [offset][size][stride].
  %window = tensor.extract_slice %operand[%offset][11][1]
      : tensor<?xi32> to tensor<11xi32>
  %reshaped = "mhlo.dynamic_reshape"(%window, %shape)
      : (tensor<11xi32>, tensor<3xi64>) -> tensor<?x?x?xi32>
  func.return %reshaped : tensor<?x?x?xi32>
}
// CHECK-DAG: %[[OPERAND:.*]] = bufferization.to_memref %[[ARG]]
// CHECK-DAG: %[[BSHAPE:.*]] = bufferization.to_memref %[[SHAPE]]
// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[OPERAND]][%[[OFFSET]]] [11] [1]
// CHECK: %[[ALLOC:.*]] = memref.alloc() {alignment = 128 : i64}
// CHECK: memref.copy %[[SUBVIEW]], %[[ALLOC]]
// CHECK: %[[RESHAPED:.*]] = memref.reshape %[[ALLOC]](%[[BSHAPE]]) : (memref<11xi32>, memref<3xi64>) -> memref<?x?x?xi32>
// CHECK: %[[TRESULT:.*]] = bufferization.to_tensor %[[RESHAPED]] : memref<?x?x?xi32>
// CHECK: return %[[TRESULT]]
// -----
// CHECK-LABEL: func @reshape_unsigned
// Test input: static reshape of an unranked i32 tensor to a fixed 4x3 result.
// NOTE(review): despite the "_unsigned" suffix, the element type used here is
// the signless `i32`, not `ui32` - confirm whether that is intentional.
func.func @reshape_unsigned(%operand: tensor<*xi32>) -> tensor<4x3xi32> {
  // CHECK-SAME: %[[ARG:.*]]: tensor<*xi32>
  %reshaped = "mhlo.reshape"(%operand)
      : (tensor<*xi32>) -> tensor<4x3xi32>
  func.return %reshaped : tensor<4x3xi32>
}
// CHECK: %[[OPERAND:.*]] = bufferization.to_memref %[[ARG]]
// CHECK: %[[RESHAPED:.*]] = memref.cast %[[OPERAND]] : memref<*xi32> to memref<4x3xi32>
// CHECK: %[[TRESULT:.*]] = bufferization.to_tensor %[[RESHAPED]] : memref<4x3xi32>
// CHECK: return %[[TRESULT]]