// Source blob: e6ba7e5123e9108756e711922634e2bc340fc615
// RUN: mlir-hlo-opt %s --split-input-file --symbolic-shape-optimization | \
// RUN: FileCheck %s
// Reshape of a fully dynamic 2-D tensor whose target shape is
// [1, dim0, dim1], i.e. a unit dimension is prepended. The expected output is
// a single tensor.expand_shape that groups the new unit dimension with the
// first operand dimension.
// CHECK-LABEL: func @reshape_expand_front
func.func @reshape_expand_front(%arg0: tensor<?x?xf32>) -> tensor<1x?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %shape = tensor.from_elements %c1, %d0, %d1 : tensor<3xindex>
  %reshape = "mhlo.dynamic_reshape"(%arg0, %shape)
      : (tensor<?x?xf32>, tensor<3xindex>) -> tensor<1x?x?xf32>
  // CHECK: tensor.expand_shape %arg0 [
  // CHECK-SAME: [0, 1], [2]] : tensor<?x?xf32> into tensor<1x?x?xf32>
  func.return %reshape : tensor<1x?x?xf32>
}
// Same as the front-expansion case, but the operand's leading dimension is
// static (2). The expected tensor.expand_shape keeps that static extent in its
// result type.
// CHECK-LABEL: func @reshape_expand_front_static
func.func @reshape_expand_front_static(%arg0: tensor<2x?xf32>) -> tensor<1x2x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<2x?xf32>
  %d1 = tensor.dim %arg0, %c1 : tensor<2x?xf32>
  %shape = tensor.from_elements %c1, %d0, %d1 : tensor<3xindex>
  %reshape = "mhlo.dynamic_reshape"(%arg0, %shape)
      : (tensor<2x?xf32>, tensor<3xindex>) -> tensor<1x2x?xf32>
  // CHECK: tensor.expand_shape %arg0 [
  // CHECK-SAME: [0, 1], [2]] : tensor<2x?xf32> into tensor<1x2x?xf32>
  func.return %reshape : tensor<1x2x?xf32>
}
// -----
// Reshape of a dynamic 2-D tensor to [dim0, dim1, 1, 1], i.e. two unit
// dimensions are appended. The expected tensor.expand_shape groups both new
// unit dimensions with the last operand dimension.
// CHECK-LABEL: func @reshape_expand_back
func.func @reshape_expand_back(%arg0: tensor<?x?xf32>) -> tensor<?x?x1x1xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
  %shape = tensor.from_elements %d0, %d1, %c1, %c1 : tensor<4xindex>
  %reshape = "mhlo.dynamic_reshape"(%arg0, %shape)
      : (tensor<?x?xf32>, tensor<4xindex>) -> tensor<?x?x1x1xf32>
  // CHECK: tensor.expand_shape %arg0 [
  // CHECK-SAME: [0], [1, 2, 3]] : tensor<?x?xf32> into tensor<?x?x1x1xf32>
  func.return %reshape : tensor<?x?x1x1xf32>
}
// -----
// Reshape of a rank-0 tensor with the constant target shape [1, 1] but a
// dynamically typed result. The expected output expands the scalar to the
// static type tensor<1x1xf32> and then casts it to the declared result type.
// CHECK-LABEL: @reshape_expand_scalar
// CHECK-SAME: %[[ARG:.*]]: tensor<f32>
func.func @reshape_expand_scalar(%arg0: tensor<f32>) -> tensor<?x?xf32> {
  // CHECK-DAG: %[[EXPAND:.*]] = tensor.expand_shape %[[ARG]] [] : tensor<f32> into tensor<1x1xf32>
  // CHECK-DAG: %[[RES:.*]] = tensor.cast %[[EXPAND]] : tensor<1x1xf32> to tensor<?x?xf32>
  // CHECK: return %[[RES]]
  %shape = mhlo.constant dense<1> : tensor<2xi32>
  %reshape = "mhlo.dynamic_reshape"(%arg0, %shape)
      : (tensor<f32>, tensor<2xi32>) -> tensor<?x?xf32>
  func.return %reshape : tensor<?x?xf32>
}
// -----
// Reshape of a dynamic 2-D tensor to rank 0 (the target shape tensor has zero
// elements). The expected output first casts the operand to tensor<1x1xf32>
// and then collapses it to a scalar tensor.
// CHECK-LABEL: @reshape_collapse_scalar
// CHECK-SAME: %[[ARG:.*]]: tensor<?x?xf32>
func.func @reshape_collapse_scalar(%arg0 : tensor<?x?xf32>) -> tensor<f32> {
  %shape = mhlo.constant dense<1> : tensor<0xi32>
  // CHECK-DAG: %[[CASTED_ARG:.*]] = tensor.cast %[[ARG]] : tensor<?x?xf32> to tensor<1x1xf32>
  // CHECK-DAG: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[CASTED_ARG]] [] : tensor<1x1xf32> into tensor<f32>
  // CHECK: return %[[COLLAPSED]]
  %reshape = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<?x?xf32>, tensor<0xi32>) -> tensor<f32>
  func.return %reshape : tensor<f32>
}
// -----
// Reshape of a dynamic 1-D tensor to the constant shape [1, 1, 1]. The only
// expectation is that an mhlo.dynamic_reshape is still present in the output,
// i.e. this reshape is not rewritten.
// CHECK-LABEL: func @reshape_undefined
func.func @reshape_undefined(%arg0: tensor<?xf32>) -> tensor<1x1x1xf32> {
  // CHECK: mhlo.dynamic_reshape
  %c1 = arith.constant 1 : index
  %shape = tensor.from_elements %c1, %c1, %c1 : tensor<3xindex>
  %reshape = "mhlo.dynamic_reshape"(%arg0, %shape)
      : (tensor<?xf32>, tensor<3xindex>) -> tensor<1x1x1xf32>
  func.return %reshape : tensor<1x1x1xf32>
}
// -----
// Reshape from tensor<?x1xi64> to [dim0, 1, 1]. The expected output is one
// tensor.expand_shape that splits the trailing unit dimension into two.
// CHECK-LABEL: @shape_expansion
// CHECK-SAME: %[[ARG:.*]]: tensor<?x1xi64>
func.func @shape_expansion(%arg : tensor<?x1xi64>) -> tensor<?x1x1xi64> {
  // CHECK-DAG: %[[RES:.*]] = tensor.expand_shape %[[ARG]] {{\[}}[0], [1, 2]{{\]}} : tensor<?x1xi64> into tensor<?x1x1xi64>
  // CHECK: return %[[RES]]
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %arg, %c0 : tensor<?x1xi64>
  %shape = tensor.from_elements %d0, %c1, %c1 : tensor<3xindex>
  %result = "mhlo.dynamic_reshape"(%arg, %shape)
      : (tensor<?x1xi64>, tensor<3xindex>) -> tensor<?x1x1xi64>
  func.return %result : tensor<?x1x1xi64>
}
// -----
// Reshape from tensor<3x?x1xi64> to [3 * dim1, 1, 1]: the two leading
// dimensions are merged and a unit dimension is added. The expected output is
// a tensor.expand_shape followed by a tensor.collapse_shape.
// CHECK-LABEL: @shape_collapse_and_expansion
// CHECK-SAME: %[[ARG:.*]]: tensor<3x?x1xi64>
func.func @shape_collapse_and_expansion(%arg : tensor<3x?x1xi64>)
    -> tensor<?x1x1xi64> {
  // CHECK: %[[EED:.*]] = tensor.expand_shape %[[ARG]] {{\[}}[0], [1], [2, 3]{{\]}} : tensor<3x?x1xi64> into tensor<3x?x1x1xi64>
  // CHECK: %[[CED:.*]] = tensor.collapse_shape %[[EED]] {{\[}}[0, 1], [2], [3]{{\]}} : tensor<3x?x1x1xi64> into tensor<?x1x1xi64>
  // CHECK: return %[[CED]]
  %c1 = arith.constant 1 : index
  %c3 = arith.constant 3 : index
  %d1 = tensor.dim %arg, %c1 : tensor<3x?x1xi64>
  %three_d1 = arith.muli %c3, %d1 : index
  %15 = tensor.from_elements %three_d1, %c1, %c1 : tensor<3xindex>
  %16 = "mhlo.dynamic_reshape"(%arg, %15)
      : (tensor<3x?x1xi64>, tensor<3xindex>) -> tensor<?x1x1xi64>
  func.return %16 : tensor<?x1x1xi64>
}
// -----
// Reshape from tensor<16x8x?x?xf32> to [16, 4, 2, dim2 * dim3] with the
// declared result type tensor<16x4x?x?xf32>. The expected output splits the
// static 8 into 4x2, merges the two dynamic dimensions, and finishes with a
// tensor.cast of the collapsed value (type tensor<16x4x2x?xf32>).
// CHECK-LABEL: @shape_collapse_and_expansion_w_cast
// CHECK-SAME: %[[ARG:.*]]: tensor<16x8x?x?xf32>
func.func @shape_collapse_and_expansion_w_cast(%arg0: tensor<16x8x?x?xf32>) -> tensor<16x4x?x?xf32> {
  // CHECK-DAG: %[[EED:.*]] = tensor.expand_shape %[[ARG]] {{\[}}[0], [1, 2], [3], [4]{{\]}} : tensor<16x8x?x?xf32> into tensor<16x4x2x?x?xf32>
  // CHECK-DAG: %[[CED:.*]] = tensor.collapse_shape %[[EED]] {{\[}}[0], [1], [2], [3, 4]{{\]}} : tensor<16x4x2x?x?xf32> into tensor<16x4x2x?xf32>
  // CHECK-DAG: %[[RES:.*]] = tensor.cast %[[CED]]
  // CHECK: return %[[RES]]
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %c3 = arith.constant 3 : index
  %c4 = arith.constant 4 : index
  %c16 = arith.constant 16 : index
  %1 = tensor.dim %arg0, %c2 : tensor<16x8x?x?xf32>
  %2 = tensor.dim %arg0, %c3 : tensor<16x8x?x?xf32>
  %4 = arith.muli %1, %2 : index
  %5 = tensor.from_elements %c16, %c4, %c2, %4 : tensor<4xindex>
  %6 = "mhlo.dynamic_reshape"(%arg0, %5) : (tensor<16x8x?x?xf32>, tensor<4xindex>) -> tensor<16x4x?x?xf32>
  func.return %6 : tensor<16x4x?x?xf32>
}
// -----
// Reshape of an 8-D tensor to the i32 shape [4 * dim2, 64 * dim4, 8], where
// the dimension sizes are computed through index_cast and i32 multiplication.
// The expected output is a single tensor.collapse_shape with the groups
// [0, 1, 2], [3, 4], [5, 6, 7].
// CHECK-LABEL: @dynamic_reshape_to_collapse_shape
// CHECK-SAME: %[[ARG:.*]]: tensor<1x4x?x64x?x8x1x1xf32>
func.func @dynamic_reshape_to_collapse_shape(%arg0 : tensor<1x4x?x64x?x8x1x1xf32>)
    -> tensor<?x?x8xf32> {
  // CHECK: %[[RESULT:.*]] = tensor.collapse_shape %[[ARG]] {{\[}}[0, 1, 2], [3, 4], [5, 6, 7]{{\]}}
  // CHECK: return %[[RESULT]]
  %c2 = arith.constant 2 : index
  %c4 = arith.constant 4 : index
  %c4_i32 = arith.constant 4 : i32
  %c8_i32 = arith.constant 8 : i32
  %c64_i32 = arith.constant 64 : i32
  %d2 = tensor.dim %arg0, %c2 : tensor<1x4x?x64x?x8x1x1xf32>
  %d4 = tensor.dim %arg0, %c4 : tensor<1x4x?x64x?x8x1x1xf32>
  %d2_i32 = arith.index_cast %d2 : index to i32
  %d4_i32 = arith.index_cast %d4 : index to i32
  %s0 = arith.muli %c4_i32, %d2_i32 : i32
  %s1 = arith.muli %c64_i32, %d4_i32 : i32
  %shape = tensor.from_elements %s0, %s1, %c8_i32 : tensor<3xi32>
  %result = "mhlo.dynamic_reshape"(%arg0, %shape)
      : (tensor<1x4x?x64x?x8x1x1xf32>, tensor<3xi32>) -> tensor<?x?x8xf32>
  func.return %result : tensor<?x?x8xf32>
}
// -----
// Reshape from tensor<1x?x1xi64> to [1, 1, dim1, 1], adding one more unit
// dimension. The expected tensor.expand_shape attaches the new unit dimension
// to the leading operand dimension (groups [0, 1], [2], [3]).
// CHECK-LABEL: @expansion_unit_dims
// CHECK-SAME: %[[ARG:.*]]: tensor<1x?x1xi64>
func.func @expansion_unit_dims(%arg0: tensor<1x?x1xi64>) -> tensor<1x1x?x1xi64> {
  // CHECK-DAG: %[[RES:.*]] = tensor.expand_shape %[[ARG]] {{\[}}[0, 1], [2], [3]{{\]}} : tensor<1x?x1xi64> into tensor<1x1x?x1xi64>
  // CHECK: return %[[RES]]
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %0 = tensor.dim %arg0, %c1 : tensor<1x?x1xi64>
  %1 = tensor.from_elements %c1, %c1, %0, %c1 : tensor<4xindex>
  %2 = "mhlo.dynamic_reshape"(%arg0, %1)
      : (tensor<1x?x1xi64>, tensor<4xindex>) -> tensor<1x1x?x1xi64>
  func.return %2 : tensor<1x1x?x1xi64>
}
// -----
// Chain of three reductions interleaved with dynamic reshapes that only add
// unit dimensions. The expected output replaces the reshapes feeding the
// second and third reduction with tensor.expand_shape ops, while the final
// static mhlo.reshape is left untouched (see the TODO below).
// CHECK-LABEL: @multiple_reductions_and_reshape
// CHECK-SAME: %[[ARG:.*]]: tensor<?x?x?x?xi64>
func.func @multiple_reductions_and_reshape(%arg0: tensor<?x?x?x?xi64>) -> tensor<1x1x1x1xi64> {
  // CHECK: %[[RED0:.*]] = mhlo.reduce(%[[ARG]]
  // CHECK: %[[RED0_:.*]] = tensor.expand_shape %[[RED0]] {{\[}}[0], [1], [2, 3]{{\]}} : tensor<?x?x?xi64> into tensor<?x?x?x1xi64>
  // CHECK: %[[RED1:.*]] = mhlo.reduce(%[[RED0_]]
  // CHECK: %[[RED1_:.*]] = tensor.expand_shape %[[RED1]] {{\[}}[0, 1, 2], [3]{{\]}} : tensor<?x1xi64> into tensor<1x1x?x1xi64>
  // CHECK: %[[RED2:.*]] = mhlo.reduce(%[[RED1_]]
  // TODO(b/225204462): This should also become a shape expansion.
  // CHECK: %[[RED2_:.*]] = "mhlo.reshape"(%[[RED2]]) : (tensor<1xi64>) -> tensor<1x1x1x1xi64>
  // CHECK: return %[[RED2_]]
  %0 = mhlo.constant dense<9223372036854775807> : tensor<i64>
  %c1 = arith.constant 1 : index
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %1 = mhlo.constant dense<1> : tensor<i64>
  // First reduction: minimum over the last dimension.
  %2 = mhlo.reduce(%arg0 init: %0)
      applies mhlo.minimum across dimensions = [3]
      : (tensor<?x?x?x?xi64>, tensor<i64>) -> tensor<?x?x?xi64>
  %3 = tensor.dim %2, %c0 : tensor<?x?x?xi64>
  %4 = tensor.dim %2, %c1 : tensor<?x?x?xi64>
  %5 = tensor.dim %2, %c2 : tensor<?x?x?xi64>
  %6 = tensor.from_elements %3, %4, %5, %c1 : tensor<4xindex>
  %7 = "mhlo.dynamic_reshape"(%2, %6)
      : (tensor<?x?x?xi64>, tensor<4xindex>) -> tensor<?x?x?x1xi64>
  // Second reduction: minimum over the two leading dimensions.
  %8 = mhlo.reduce(%7 init: %0)
      applies mhlo.minimum across dimensions = [0, 1]
      : (tensor<?x?x?x1xi64>, tensor<i64>) -> tensor<?x1xi64>
  %9 = tensor.dim %8, %c0 : tensor<?x1xi64>
  %10 = tensor.from_elements %c1, %9, %c1 : tensor<3xindex>
  // %11 is only queried for its dimension below; the reshape producing %14
  // takes %8 as its operand directly.
  %11 = "mhlo.dynamic_reshape"(%8, %10)
      : (tensor<?x1xi64>, tensor<3xindex>) -> tensor<1x?x1xi64>
  %12 = tensor.dim %11, %c1 : tensor<1x?x1xi64>
  %13 = tensor.from_elements %c1, %c1, %12, %c1 : tensor<4xindex>
  %14 = "mhlo.dynamic_reshape"(%8, %13)
      : (tensor<?x1xi64>, tensor<4xindex>) -> tensor<1x1x?x1xi64>
  // Third reduction: product over the three leading dimensions.
  %15 = mhlo.reduce(%14 init: %1)
      applies mhlo.multiply across dimensions = [0, 1, 2]
      : (tensor<1x1x?x1xi64>, tensor<i64>) -> tensor<1xi64>
  %16 = "mhlo.reshape"(%15) : (tensor<1xi64>) -> tensor<1x1x1x1xi64>
  func.return %16 : tensor<1x1x1x1xi64>
}
// -----
// mhlo.compute_reshape_shape applied to a shape that is the element-wise
// square of an actual tensor shape (shape_of -> index_cast -> multiply). The
// expected output returns the multiply result directly, i.e. the
// compute_reshape_shape op is removed.
// CHECK-LABEL: func @compute_reshape_shape
func.func @compute_reshape_shape(%arg0: tensor<?x?xf32>, %arg1: index)
    -> tensor<2xi32> {
  %shape = shape.shape_of %arg0: tensor<?x?xf32> -> tensor<2xindex>
  %casted = arith.index_cast %shape : tensor<2xindex> to tensor<2xi32>
  %mul = mhlo.multiply %casted, %casted : tensor<2xi32>
  // CHECK: %[[MUL:.*]] = mhlo.multiply
  %crs = mhlo.compute_reshape_shape %arg1, %mul
      : index, tensor<2xi32> -> tensor<2xi32>
  func.return %crs : tensor<2xi32>
  // CHECK: return %[[MUL]] : tensor<2xi32>
}
// -----
// Counterpart of the foldable case: here the shape operand is derived from an
// opaque tensor<2xi32> function argument rather than from shape_of. The
// expected output still contains mhlo.compute_reshape_shape.
// CHECK-LABEL: func @compute_reshape_shape
func.func @compute_reshape_shape(%arg0: tensor<2xi32>, %arg1: index)
    -> tensor<2xi32> {
  %mul = mhlo.multiply %arg0, %arg0 : tensor<2xi32>
  %crs = mhlo.compute_reshape_shape %arg1, %mul
      : index, tensor<2xi32> -> tensor<2xi32>
  // CHECK: mhlo.compute_reshape_shape
  func.return %crs : tensor<2xi32>
}
// -----
// mhlo.cstr_reshapable between the element count of tensor<?x8x?x64xf32> and
// the target shape [dim2, dim0, 512] (note 8 * 64 = 512). The expected output
// replaces the constraint with shape.const_witness true.
// CHECK-LABEL: @redundant_cstr_reshapable
func.func @redundant_cstr_reshapable(%arg0 : tensor<?x8x?x64xf32>)
    -> !shape.witness {
  // CHECK: %[[WITNESS:.*]] = shape.const_witness true
  // CHECK: return %[[WITNESS]] : !shape.witness
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %c512 = arith.constant 512 : index
  %s0 = shape.shape_of %arg0 : tensor<?x8x?x64xf32> -> tensor<4xindex>
  %n0 = shape.num_elements %s0 : tensor<4xindex> -> index
  %dim00 = tensor.dim %arg0, %c0 : tensor<?x8x?x64xf32>
  %dim02 = tensor.dim %arg0, %c2 : tensor<?x8x?x64xf32>
  %s1_ = tensor.from_elements %dim02, %dim00, %c512 : tensor<3xindex>
  %s1 = arith.index_cast %s1_ : tensor<3xindex> to tensor<3xi32>
  %w = mhlo.cstr_reshapable %n0, %s1 : index, tensor<3xi32>
  func.return %w : !shape.witness
}
// -----
// Variant where the static factors are distributed differently: the operand
// is tensor<?x4x?x64xf32> and the target shape is [dim2, 2 * dim0, 128]
// (4 * 64 = 2 * 128). The expected output is still shape.const_witness true.
// CHECK-LABEL: @redundant_cstr_reshapable_less_obvious
func.func @redundant_cstr_reshapable_less_obvious(%arg0 : tensor<?x4x?x64xf32>)
    -> !shape.witness {
  // CHECK: %[[WITNESS:.*]] = shape.const_witness true
  // CHECK: return %[[WITNESS]] : !shape.witness
  %c0 = arith.constant 0 : index
  // %c2 serves both as the multiplier for dim0 and as the index of dim2.
  %c2 = arith.constant 2 : index
  %c128 = arith.constant 128 : i32
  %s0 = shape.shape_of %arg0 : tensor<?x4x?x64xf32> -> tensor<4xindex>
  %n0 = shape.num_elements %s0 : tensor<4xindex> -> index
  %dim00 = tensor.dim %arg0, %c0 : tensor<?x4x?x64xf32>
  %dim00_twice = arith.muli %c2, %dim00 : index
  %dim00_twice_ = arith.index_cast %dim00_twice : index to i32
  %dim02 = tensor.dim %arg0, %c2 : tensor<?x4x?x64xf32>
  %dim02_ = arith.index_cast %dim02 : index to i32
  %s1 = tensor.from_elements %dim02_, %dim00_twice_, %c128 : tensor<3xi32>
  %w = mhlo.cstr_reshapable %n0, %s1 : index, tensor<3xi32>
  func.return %w : !shape.witness
}
// -----
// Target shape [dim2, -1, 128] for an operand of type tensor<?x8x?x64xf32>.
// NOTE(review): the -1 entry presumably marks a dimension to be inferred;
// confirm against the mhlo.cstr_reshapable definition. The expected output is
// shape.const_witness true.
// CHECK-LABEL: @redundant_cstr_reshapable
func.func @redundant_cstr_reshapable(%arg0 : tensor<?x8x?x64xf32>)
    -> !shape.witness {
  // CHECK: %[[WITNESS:.*]] = shape.const_witness true
  // CHECK: return %[[WITNESS]] : !shape.witness
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %c128 = arith.constant 128 : i32
  %cminus1 = arith.constant -1 : i32
  %s0 = shape.shape_of %arg0 : tensor<?x8x?x64xf32> -> tensor<4xindex>
  %n0 = shape.num_elements %s0 : tensor<4xindex> -> index
  %dim02 = tensor.dim %arg0, %c2 : tensor<?x8x?x64xf32>
  %dim02_ = arith.index_cast %dim02 : index to i32
  %s1 = tensor.from_elements %dim02_, %cminus1, %c128 : tensor<3xi32>
  %w = mhlo.cstr_reshapable %n0, %s1 : index, tensor<3xi32>
  func.return %w : !shape.witness
}
// -----
// Negative case: target shape [dim2, -1, 42] for an operand of type
// tensor<?x8x?x64xf32> (8 * 64 = 512 is not a multiple of 42). The expected
// output keeps the mhlo.cstr_reshapable op and returns its witness.
// CHECK-LABEL: @nonredundant_cstr_reshapable
func.func @nonredundant_cstr_reshapable(%arg0 : tensor<?x8x?x64xf32>)
    -> !shape.witness {
  // CHECK: %[[WITNESS:.*]] = mhlo.cstr_reshapable %{{.*}}, %{{.*}}
  // CHECK: return %[[WITNESS]] : !shape.witness
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %c42 = arith.constant 42 : i32
  %cminus1 = arith.constant -1 : i32
  %s0 = shape.shape_of %arg0 : tensor<?x8x?x64xf32> -> tensor<4xindex>
  %n0 = shape.num_elements %s0 : tensor<4xindex> -> index
  %dim02 = tensor.dim %arg0, %c2 : tensor<?x8x?x64xf32>
  %dim02_ = arith.index_cast %dim02 : index to i32
  %s1 = tensor.from_elements %dim02_, %cminus1, %c42 : tensor<3xi32>
  %w = mhlo.cstr_reshapable %n0, %s1 : index, tensor<3xi32>
  func.return %w : !shape.witness
}
// -----
// Four-element target shape [dim2, dim0, 64, -1] for an operand of type
// tensor<?x8x?x64xf32>; both dynamic dimensions and the static 64 appear
// explicitly and the last entry is -1. The expected output is
// shape.const_witness true.
// CHECK-LABEL: @redundant_cstr_reshapable
func.func @redundant_cstr_reshapable(%arg0 : tensor<?x8x?x64xf32>)
    -> !shape.witness {
  // CHECK: %[[WITNESS:.*]] = shape.const_witness true
  // CHECK: return %[[WITNESS]] : !shape.witness
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  %c64 = arith.constant 64 : i32
  %cminus1 = arith.constant -1 : i32
  %s0 = shape.shape_of %arg0 : tensor<?x8x?x64xf32> -> tensor<4xindex>
  %n0 = shape.num_elements %s0 : tensor<4xindex> -> index
  %dim00 = tensor.dim %arg0, %c0 : tensor<?x8x?x64xf32>
  %dim02 = tensor.dim %arg0, %c2 : tensor<?x8x?x64xf32>
  %dim00_ = arith.index_cast %dim00 : index to i32
  %dim02_ = arith.index_cast %dim02 : index to i32
  %s1 = tensor.from_elements %dim02_, %dim00_ , %c64, %cminus1 : tensor<4xi32>
  %w = mhlo.cstr_reshapable %n0, %s1 : index, tensor<4xi32>
  func.return %w : !shape.witness
}
// -----
// Larger example combining two reshape sequences.
//  * The first sequence reshapes %arg1 with a shape taken from the opaque
//    function argument %arg2. Its cstr_reshapable, shape.assuming region,
//    compute_reshape_shape and dynamic_reshape are all expected to remain.
//  * The second sequence reshapes the doubly transposed value to
//    [dim0 * dim1, 512], with the shape assembled from slices of its own
//    shape. Here the constraint, the assuming region and the
//    compute_reshape_shape are expected to disappear, leaving a
//    tensor.collapse_shape.
// CHECK-LABEL: func @reshape_integration(
// CHECK-SAME: %arg0: tensor<512x512xf32>,
// CHECK-SAME: %arg1: tensor<?x8x?x64xf32>,
// CHECK-SAME: %[[DYN_SHAPE:.*]]: tensor<4xi32>,
// CHECK-SAME: %arg3: tensor<512xf32>,
// CHECK-SAME: %arg4: tensor<?x?x512xf32>,
// CHECK-SAME: %arg5: tensor<512xf32>,
// CHECK-SAME: %arg6: tensor<512xf32>,
// CHECK-SAME: %arg7: tensor<512x2048xf32>,
// CHECK-SAME: %arg8: tensor<2048xf32>,
// CHECK-SAME: %arg9: tensor<2048x512xf32>,
// CHECK-SAME: %arg10: tensor<512xf32>,
// CHECK-SAME: %arg11: tensor<512xf32>,
// CHECK-SAME: %arg12: tensor<512xf32>)
func.func @reshape_integration(%arg0: tensor<512x512xf32>,
    %arg1: tensor<?x8x?x64xf32>, %arg2: tensor<4xi32>, %arg3: tensor<512xf32>,
    %arg4: tensor<?x?x512xf32>, %arg5: tensor<512xf32>, %arg6: tensor<512xf32>,
    %arg7: tensor<512x2048xf32>, %arg8: tensor<2048xf32>,
    %arg9: tensor<2048x512xf32>, %arg10: tensor<512xf32>,
    %arg11: tensor<512xf32>, %arg12: tensor<512xf32>) -> tensor<?x512xf32> {
  %0 = mhlo.constant dense<512> : tensor<1xi32>
  %1 = shape.shape_of %arg1 : tensor<?x8x?x64xf32> -> tensor<4xindex>
  %2 = shape.num_elements %1 : tensor<4xindex> -> index
  // CHECK: %[[W:.*]] = mhlo.cstr_reshapable
  %3 = mhlo.cstr_reshapable %2, %arg2 : index, tensor<4xi32>
  // CHECK: shape.assuming %[[W]]
  %4 = shape.assuming %3 -> (tensor<?x8x?x64xf32>) {
    // CHECK: %[[SHAPE:.*]] = mhlo.compute_reshape_shape %{{.*}}, %[[DYN_SHAPE]]
    %20 = mhlo.compute_reshape_shape %2, %arg2
        : index, tensor<4xi32> -> tensor<4xi32>
    // CHECK: "mhlo.dynamic_reshape"(%arg1, %[[SHAPE]])
    %21 = "mhlo.dynamic_reshape"(%arg1, %20)
        : (tensor<?x8x?x64xf32>, tensor<4xi32>) -> tensor<?x8x?x64xf32>
    // CHECK: shape.assuming_yield
    shape.assuming_yield %21 : tensor<?x8x?x64xf32>
  }
  %5 = "mhlo.transpose"(%4) {permutation = dense<[0, 2, 1, 3]> : tensor<4xi64>}
      : (tensor<?x8x?x64xf32>) -> tensor<?x?x8x64xf32>
  %6 = "mhlo.transpose"(%5) {permutation = dense<[0, 1, 3, 2]>
      : tensor<4xi64>} : (tensor<?x?x8x64xf32>) -> tensor<?x?x64x8xf32>
  %7 = shape.shape_of %6 : tensor<?x?x64x8xf32> -> tensor<4xindex>
  %8 = arith.index_cast %7 : tensor<4xindex> to tensor<4xi32>
  // Extract dimension 0 of the transposed shape as a scalar.
  %9 = "mhlo.slice"(%8) {limit_indices = dense<1> : tensor<1xi64>,
      start_indices = dense<0> : tensor<1xi64>,
      strides = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1xi32>
  %10 = "mhlo.reshape"(%9) : (tensor<1xi32>) -> tensor<i32>
  // Extract dimension 1 of the transposed shape as a scalar.
  %11 = "mhlo.slice"(%8) {limit_indices = dense<2> : tensor<1xi64>,
      start_indices = dense<1> : tensor<1xi64>,
      strides = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<1xi32>
  %12 = "mhlo.reshape"(%11) : (tensor<1xi32>) -> tensor<i32>
  // Target shape is [dim0 * dim1, 512].
  %13 = mhlo.multiply %10, %12 : tensor<i32>
  %14 = "mhlo.reshape"(%13) : (tensor<i32>) -> tensor<1xi32>
  %15 = "mhlo.concatenate"(%14, %0) {dimension = 0 : i64}
      : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
  %16 = shape.shape_of %6 : tensor<?x?x64x8xf32> -> tensor<4xindex>
  %17 = shape.num_elements %16 : tensor<4xindex> -> index
  // CHECK-NOT: cstr_reshapable
  %18 = mhlo.cstr_reshapable %17, %15 : index, tensor<2xi32>
  // CHECK-NOT: assuming
  %19 = shape.assuming %18 -> (tensor<?x512xf32>) {
    // CHECK-NOT: compute_reshape_shape
    %20 = mhlo.compute_reshape_shape %17, %15
        : index, tensor<2xi32> -> tensor<2xi32>
    // CHECK: tensor.collapse_shape
    %21 = "mhlo.dynamic_reshape"(%6, %20)
        : (tensor<?x?x64x8xf32>, tensor<2xi32>) -> tensor<?x512xf32>
    // CHECK-NOT: assuming_yield
    shape.assuming_yield %21 : tensor<?x512xf32>
  }
  func.return %19 : tensor<?x512xf32>
}
// -----
// shape.cstr_broadcastable between the shapes of two 1-D arguments that carry
// the same jitrt.symbolic_shape value (-2).
// NOTE(review): presumably equal negative values denote the same unknown
// size; confirm against the attribute's documentation.
// The expected output contains shape.const_witness true.
// CHECK-LABEL: @optimize_1dx1d_constraint
func.func @optimize_1dx1d_constraint(
    %arg0: tensor<?xf32>
        {jitrt.symbolic_shape = dense<[-2]> : tensor<1xi64>},
    %arg1: tensor<?xf32>
        {jitrt.symbolic_shape = dense<[-2]> : tensor<1xi64>}
) -> !shape.witness {
  %0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<1xindex>
  %1 = shape.shape_of %arg1 : tensor<?xf32> -> tensor<1xindex>
  // CHECK: shape.const_witness true
  %2 = shape.cstr_broadcastable %0, %1 : tensor<1xindex>, tensor<1xindex>
  func.return %2: !shape.witness
}
// -----
// shape.cstr_broadcastable between a dynamically typed argument annotated
// with jitrt.symbolic_shape = [10] and a statically shaped tensor<10xf32>.
// The expected output contains shape.const_witness true.
// CHECK-LABEL: @optimize_1dx1d_constraint_with_static_shape
func.func @optimize_1dx1d_constraint_with_static_shape(
    %arg0: tensor<?xf32>
        {jitrt.symbolic_shape = dense<[10]> : tensor<1xi64>},
    %arg1: tensor<10xf32>
) -> !shape.witness {
  %0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<1xindex>
  %1 = shape.shape_of %arg1 : tensor<10xf32> -> tensor<1xindex>
  // CHECK: shape.const_witness true
  %2 = shape.cstr_broadcastable %0, %1 : tensor<1xindex>, tensor<1xindex>
  func.return %2: !shape.witness
}
// -----
// shape.cstr_broadcastable between a shape.const_shape [512] and the shape of
// a tensor<?x512xf32> argument annotated with jitrt.symbolic_shape =
// [-2, 512]. The expected output contains shape.const_witness true.
// CHECK-LABEL: @optimize_1dx1d_constraint_with_const_shape
func.func @optimize_1dx1d_constraint_with_const_shape(
    %arg0: tensor<512xf32>,
    %arg1: tensor<?x512xf32>
        {jitrt.symbolic_shape = dense<[-2,512]> : tensor<2xi64>}
) -> !shape.witness {
  %0 = shape.const_shape [512] : tensor<1xindex>
  %1 = shape.shape_of %arg1 : tensor<?x512xf32> -> tensor<2xindex>
  // CHECK: shape.const_witness true
  %2 = shape.cstr_broadcastable %0, %1 : tensor<1xindex>, tensor<2xindex>
  func.return %2: !shape.witness
}
// -----
// mhlo.dynamic_broadcast_in_dim of a 1-D argument to the broadcasted shape of
// two 1-D arguments that carry the same jitrt.symbolic_shape value (-2). The
// op is expected to be annotated with an empty known_expanding_dimensions and
// known_nonexpanding_dimensions = 0.
// CHECK-LABEL: @optimize_1dx1d_bcast
func.func @optimize_1dx1d_bcast(
    %arg0: tensor<?xf32> {jitrt.symbolic_shape = dense<[-2]> : tensor<1xi64>},
    %arg1: tensor<?xf32> {jitrt.symbolic_shape = dense<[-2]> : tensor<1xi64>})
    -> tensor<?xf32> {
  %0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<1xindex>
  %1 = shape.shape_of %arg1 : tensor<?xf32> -> tensor<1xindex>
  %2 = shape.broadcast %0, %1 : tensor<1xindex>, tensor<1xindex>
      -> tensor<1xindex>
  // CHECK: mhlo.dynamic_broadcast_in_dim
  // CHECK-SAME: known_expanding_dimensions = dense<>
  // CHECK-SAME: known_nonexpanding_dimensions = dense<0>
  %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2)
      {broadcast_dimensions = dense<[0]> : tensor<1xi64>}
      : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
  func.return %3: tensor<?xf32>
}
// -----
// Broadcast of a static tensor<512xf32> into the broadcasted shape of a
// constant shape [512] and a tensor<?x512xf32> argument annotated with
// jitrt.symbolic_shape = [-2, 512]. The op is expected to be annotated with
// an empty known_expanding_dimensions and known_nonexpanding_dimensions = 0.
// CHECK-LABEL: @optimize_1dx2d_bcast_const_shape
func.func @optimize_1dx2d_bcast_const_shape(
    %arg0: tensor<512xf32>,
    %arg1: tensor<?x512xf32>
        {jitrt.symbolic_shape = dense<[-2, 512]> : tensor<2xi64>})
    -> tensor<?x512xf32> {
  %0 = shape.const_shape [512] : tensor<1xindex>
  %1 = shape.shape_of %arg1 : tensor<?x512xf32> -> tensor<2xindex>
  %2 = shape.broadcast %0, %1 : tensor<1xindex>, tensor<2xindex>
      -> tensor<2xindex>
  // CHECK: mhlo.dynamic_broadcast_in_dim
  // CHECK-SAME: known_expanding_dimensions = dense<>
  // CHECK-SAME: known_nonexpanding_dimensions = dense<0>
  %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2)
      {broadcast_dimensions = dense<[1]> : tensor<1xi64>}
      : (tensor<512xf32>, tensor<2xindex>) -> tensor<?x512xf32>
  func.return %3: tensor<?x512xf32>
}
// -----
// Three 1-D arguments carrying the same jitrt.symbolic_shape value (-2); the
// broadcast shape is built by two chained shape.broadcast ops. The
// dynamic_broadcast_in_dim is expected to be annotated with an empty
// known_expanding_dimensions and known_nonexpanding_dimensions = 0.
// CHECK-LABEL: @optimize_1dx1dx1d_bcast
func.func @optimize_1dx1dx1d_bcast(
    %arg0: tensor<?xf32>
        {jitrt.symbolic_shape = dense<[-2]> : tensor<1xi64>},
    %arg1: tensor<?xf32>
        {jitrt.symbolic_shape = dense<[-2]> : tensor<1xi64>},
    %arg2: tensor<?xf32>
        {jitrt.symbolic_shape = dense<[-2]> : tensor<1xi64>}) -> tensor<?xf32> {
  %0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<1xindex>
  %1 = shape.shape_of %arg1 : tensor<?xf32> -> tensor<1xindex>
  %2 = shape.shape_of %arg2 : tensor<?xf32> -> tensor<1xindex>
  %3 = shape.broadcast %0, %1 : tensor<1xindex>, tensor<1xindex>
      -> tensor<1xindex>
  %4 = shape.broadcast %3, %2 : tensor<1xindex>, tensor<1xindex>
      -> tensor<1xindex>
  // CHECK: mhlo.dynamic_broadcast_in_dim
  // CHECK-SAME: known_expanding_dimensions = dense<>
  // CHECK-SAME: known_nonexpanding_dimensions = dense<0>
  %5 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4)
      {broadcast_dimensions = dense<[0]> : tensor<1xi64>}
      : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
  func.return %5: tensor<?xf32>
}
// -----
// Broadcast between a tensor<10x?xf32> (symbolic shape [10, -2]) and a
// tensor<?xf32> (symbolic shape [-2]); both operands are broadcast to the
// common shape. Expected annotations: the 2-D operand has both dimensions
// known non-expanding, the 1-D operand has its single dimension known
// non-expanding, and neither has known expanding dimensions.
// CHECK-LABEL: @optimize_2dx1d_bcast
func.func @optimize_2dx1d_bcast(
    %arg0: tensor<10x?xf32>
        {jitrt.symbolic_shape = dense<[10, -2]> : tensor<2xi64>},
    %arg1: tensor<?xf32>
        {jitrt.symbolic_shape = dense<[-2]> : tensor<1xi64>})
    -> (tensor<10x?xf32>, tensor<10x?xf32>) {
  %0 = shape.shape_of %arg0 : tensor<10x?xf32> -> tensor<2xindex>
  %1 = shape.shape_of %arg1 : tensor<?xf32> -> tensor<1xindex>
  %2 = shape.broadcast %0, %1 : tensor<2xindex>, tensor<1xindex>
      -> tensor<2xindex>
  // CHECK: mhlo.dynamic_broadcast_in_dim
  // CHECK-SAME: known_expanding_dimensions = dense<>
  // CHECK-SAME: known_nonexpanding_dimensions = dense<[0, 1]>
  %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2)
      {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
      : (tensor<10x?xf32>, tensor<2xindex>) -> tensor<10x?xf32>
  // CHECK: mhlo.dynamic_broadcast_in_dim
  // CHECK-SAME: known_expanding_dimensions = dense<>
  // CHECK-SAME: known_nonexpanding_dimensions = dense<0>
  %4 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %2)
      {broadcast_dimensions = dense<[1]> : tensor<1xi64>}
      : (tensor<?xf32>, tensor<2xindex>) -> tensor<10x?xf32>
  func.return %3, %4: tensor<10x?xf32>, tensor<10x?xf32>
}
// -----
// Broadcast between tensor<?x1x?xf32> (symbolic shape [-2, 1, -3]) and
// tensor<1x?x1xf32> (symbolic shape [1, -4, 1]), where each operand has unit
// dimensions exactly where the other is dynamic. Expected annotations: the
// first operand has dimensions 0 and 2 known non-expanding, the second has
// dimension 1 known non-expanding; neither lists known expanding dimensions.
// CHECK-LABEL: @optimize_3dx3d_bcast
func.func @optimize_3dx3d_bcast(
    %arg0: tensor<?x1x?xf32>
        {jitrt.symbolic_shape = dense<[-2, 1, -3]> : tensor<3xi64>},
    %arg1: tensor<1x?x1xf32>
        {jitrt.symbolic_shape = dense<[1, -4, 1]> : tensor<3xi64>})
    -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
  %0 = shape.shape_of %arg0 : tensor<?x1x?xf32> -> tensor<3xindex>
  %1 = shape.shape_of %arg1 : tensor<1x?x1xf32> -> tensor<3xindex>
  %2 = shape.broadcast %0, %1 : tensor<3xindex>, tensor<3xindex>
      -> tensor<3xindex>
  // CHECK: mhlo.dynamic_broadcast_in_dim
  // CHECK-SAME: known_expanding_dimensions = dense<>
  // CHECK-SAME: known_nonexpanding_dimensions = dense<[0, 2]>
  %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2)
      {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
      : (tensor<?x1x?xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
  // CHECK: mhlo.dynamic_broadcast_in_dim
  // CHECK-SAME: known_expanding_dimensions = dense<>
  // CHECK-SAME: known_nonexpanding_dimensions = dense<1>
  %4 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %2)
      {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>}
      : (tensor<1x?x1xf32>, tensor<3xindex>) -> tensor<?x?x?xf32>
  func.return %3, %4: tensor<?x?x?xf32>, tensor<?x?x?xf32>
}
// -----
// 10-D broadcast whose per-dimension symbolic sizes pair up unit, static (8)
// and symbolic values in different combinations between the two operands.
// For the first operand the expected annotation is known_expanding_dimensions
// = 1 and known_nonexpanding_dimensions = [0, 3, 4, 5, 6, 9]; dimensions 2, 7
// and 8 appear in neither list.
// CHECK-LABEL: @optimize_10d_all_cases
func.func @optimize_10d_all_cases(
    %arg0: tensor<1x1x1x8x8x8x?x?x?x?xf32>
        {jitrt.symbolic_shape = dense<[1, 1, 1, 8, 8, 8, -2, -3, -4, -5]>
        : tensor<10xi64>},
    %arg1: tensor<1x8x?x1x8x?x1x8x?x?xf32>
        {jitrt.symbolic_shape = dense<[1, 8, -6, 1, 8, -7, 1, 8, -8, -5]>
        : tensor<10xi64>}) -> tensor<?x?x?x?x?x?x?x?x?x?xf32> {
  %0 = shape.shape_of %arg0 : tensor<1x1x1x8x8x8x?x?x?x?xf32>
      -> tensor<10xindex>
  %1 = shape.shape_of %arg1 : tensor<1x8x?x1x8x?x1x8x?x?xf32>
      -> tensor<10xindex>
  %2 = shape.broadcast %0, %1 : tensor<10xindex>, tensor<10xindex>
      -> tensor<10xindex>
  // CHECK: mhlo.dynamic_broadcast_in_dim
  // CHECK-SAME: known_expanding_dimensions = dense<1>
  // CHECK-SAME: known_nonexpanding_dimensions = dense<[0, 3, 4, 5, 6, 9]>
  %3 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %2)
      {broadcast_dimensions = dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]>
      : tensor<10xi64>}
      : (tensor<1x1x1x8x8x8x?x?x?x?xf32>, tensor<10xindex>)
      -> tensor<?x?x?x?x?x?x?x?x?x?xf32>
  func.return %3: tensor<?x?x?x?x?x?x?x?x?x?xf32>
}
// -----
// shape.broadcast of the shapes of two rank-0 tensors. The expected output
// returns an empty constant of type tensor<0xindex>.
// CHECK-LABEL: @empty_bcast
// CHECK-SAME: %[[ARG0:.*]]: tensor<f32>, %[[ARG1:.*]]: tensor<f32>
func.func @empty_bcast(%arg0 : tensor<f32>, %arg1 : tensor<f32>) -> tensor<0xindex> {
  // CHECK-DAG: %[[SHAPE:.*]] = arith.constant dense<> : tensor<0xindex>
  // CHECK: return %[[SHAPE]]
  %0 = shape.shape_of %arg0 : tensor<f32> -> tensor<0xindex>
  %1 = shape.shape_of %arg1 : tensor<f32> -> tensor<0xindex>
  %2 = shape.broadcast %0, %1 : tensor<0xindex>, tensor<0xindex>
      -> tensor<0xindex>
  func.return %2 : tensor<0xindex>
}
// -----
// shape.broadcast of a rank-7 and a rank-6 shape whose arguments carry
// jitrt.symbolic_shape attributes. The expected output no longer goes through
// shape.broadcast for the result: the returned shape is assembled with
// tensor.from_elements from constants (1, 8, 4) and individual dimensions
// extracted from the two operand shapes.
// CHECK-LABEL: @simplifiable_bcast
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x1x1x4x?x?x1xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<1x8x1x?x1x?xf32>
func.func @simplifiable_bcast(
    %arg0 : tensor<?x1x1x4x?x?x1xf32>
        {jitrt.symbolic_shape = dense<[-2, 1, 1, 4, -2, -3, 1]> : tensor<7xi64>},
    %arg1 : tensor<1x8x1x?x1x?xf32>
        {jitrt.symbolic_shape = dense<[ 1, 8, 1, -2, 1, -4]> : tensor<6xi64>})
    -> tensor<7xindex> {
  // CHECK-DAG: %[[C0:.*]] = arith.constant 0
  // CHECK-DAG: %[[C1:.*]] = arith.constant 1
  // CHECK-DAG: %[[C4:.*]] = arith.constant 4
  // CHECK-DAG: %[[C5:.*]] = arith.constant 5
  // CHECK-DAG: %[[C8:.*]] = arith.constant 8
  // CHECK-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]]
  // CHECK-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]]
  // CHECK-DAG: %[[S0D0:.*]] = tensor.extract %[[S0]][%[[C0]]]
  // CHECK-DAG: %[[S0D4:.*]] = tensor.extract %[[S0]][%[[C4]]]
  // CHECK-DAG: %[[S0D5:.*]] = tensor.extract %[[S0]][%[[C5]]]
  // CHECK-DAG: %[[S1D5:.*]] = tensor.extract %[[S1]][%[[C5]]]
  // CHECK-DAG: %[[RES:.*]] = tensor.from_elements %[[S0D0]], %[[C1]], %[[C8]], %[[C4]], %[[S0D4]], %[[S0D5]], %[[S1D5]]
  // CHECK: return %[[RES]]
  %0 = shape.shape_of %arg0 : tensor<?x1x1x4x?x?x1xf32> -> tensor<7xindex>
  %1 = shape.shape_of %arg1 : tensor<1x8x1x?x1x?xf32> -> tensor<6xindex>
  %2 = shape.broadcast %0, %1 : tensor<7xindex>, tensor<6xindex>
      -> tensor<7xindex>
  func.return %2 : tensor<7xindex>
}
// -----
// Negative case: two dynamic 1-D arguments without any jitrt.symbolic_shape
// attribute. The expected output still contains the shape.broadcast of the
// two shape_of results and returns it.
// CHECK-LABEL: @very_dynamic_bcast
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xf32>, %[[ARG1:.*]]: tensor<?xf32>
func.func @very_dynamic_bcast(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>)
    -> tensor<1xindex> {
  // CHECK-DAG: %[[S0:.*]] = shape.shape_of %[[ARG0]]
  // CHECK-DAG: %[[S1:.*]] = shape.shape_of %[[ARG1]]
  // CHECK-DAG: %[[BCASTED:.*]] = shape.broadcast %[[S0]], %[[S1]]
  // CHECK: return %[[BCASTED]]
  %0 = shape.shape_of %arg0 : tensor<?xf32> -> tensor<1xindex>
  %1 = shape.shape_of %arg1 : tensor<?xf32> -> tensor<1xindex>
  %2 = shape.broadcast %0, %1 : tensor<1xindex>, tensor<1xindex>
      -> tensor<1xindex>
  func.return %2 : tensor<1xindex>
}
// -----
// shape.broadcast of a tensor<1xindex> shape with itself, where the declared
// result type is the dynamic tensor<?xindex>. The expected output rebuilds
// the shape with tensor.from_elements from the extracted element and then
// casts it from tensor<1xindex> to the dynamic result type.
// CHECK-LABEL: @broadcast_w_dyn_ty
// CHECK-SAME: %[[ARG:.*]]: tensor<1xindex>
func.func @broadcast_w_dyn_ty(%arg0: tensor<1xindex>) -> tensor<?xindex>{
  // CHECK: %[[C0:.*]] = arith.constant 0
  // CHECK: %[[D0:.*]] = tensor.extract %[[ARG]][%[[C0]]]
  // CHECK: %[[UNCAST:.*]] = tensor.from_elements %[[D0]]
  // CHECK: %[[CAST:.*]] = tensor.cast %[[UNCAST]] : tensor<1xindex> to tensor<?xindex>
  // CHECK: return %[[CAST]]
  %0 = shape.broadcast %arg0, %arg0
      : tensor<1xindex>, tensor<1xindex> -> tensor<?xindex>
  func.return %0 : tensor<?xindex>
}
// -----
// shape.broadcast of an empty (tensor<0xindex>) shape with itself, where the
// declared result type is the dynamic tensor<?xindex>. The expected output is
// an empty constant of type tensor<0xindex> followed by a cast to the dynamic
// result type.
// CHECK-LABEL: @broadcast_scalar_w_dyn_ty
// CHECK-SAME: %[[ARG:.*]]: tensor<0xindex>
func.func @broadcast_scalar_w_dyn_ty(%arg0: tensor<0xindex>) -> tensor<?xindex>{
  // CHECK: %[[UNCAST:.*]] = arith.constant dense<> : tensor<0xindex>
  // CHECK: %[[CAST:.*]] = tensor.cast %[[UNCAST]] : tensor<0xindex> to tensor<?xindex>
  // CHECK: return %[[CAST]]
  %0 = shape.broadcast %arg0, %arg0
      : tensor<0xindex>, tensor<0xindex> -> tensor<?xindex>
  func.return %0 : tensor<?xindex>
}