blob: 998918bdfa074416d9ae8a9bc60e19d7730d76d9 [file] [log] [blame]
// RUN: mlir-hlo-opt -split-input-file -shape-simplification %s | FileCheck %s
// Incompatible shapes. No folding.
// The extents 2 and 7 cannot be broadcast together, so the pass must leave the
// broadcast in place instead of folding it into a constant shape.
// CHECK-LABEL: func @f
func.func @f() -> !shape.shape {
  // CHECK: shape.broadcast
  %0 = shape.const_shape [2] : !shape.shape
  %1 = shape.const_shape [7] : !shape.shape
  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape
  func.return %2 : !shape.shape
}
// -----
// Broadcast of partially dynamic shapes yields a static shape.
// Right-aligned, the operands are [256], [42, ?, 42, ?] and [42, ?, ?]. Every
// result extent is fixed by at least one static operand, so the broadcast folds
// to the constant [42, 42, 42, 256].
// CHECK-LABEL: func @f
func.func @f(%arg0 : tensor<42x?x42x?xf32>, %arg1 : tensor<42x?x?xf32>) -> !shape.shape {
  // CHECK: %[[CST:.*]] = shape.const_shape [42, 42, 42, 256] : !shape.shape
  // CHECK: return %[[CST]]
  %0 = shape.const_shape [256] : !shape.shape
  %1 = shape.shape_of %arg0 : tensor<42x?x42x?xf32> -> !shape.shape
  %2 = shape.shape_of %arg1 : tensor<42x?x?xf32> -> !shape.shape
  %3 = shape.broadcast %0, %1, %2 : !shape.shape, !shape.shape, !shape.shape -> !shape.shape
  func.return %3 : !shape.shape
}
// -----
// Remove operands that don't contribute to the result.
// The constant [42, 1] and the static 42x42 shape of %arg1 are already implied
// by the trailing 42x42 extents of %arg0. The broadcast therefore reduces to
// the shape of %arg0 alone.
// CHECK-LABEL: func @f
func.func @f(%arg0 : tensor<?x?x42x42xf32>, %arg1 : tensor<42x42xf32>) -> tensor<?xindex> {
  // CHECK: %[[SHAPE0:.*]] = shape.shape_of %arg0 : tensor<?x?x42x42xf32> -> tensor<?xindex>
  // CHECK: return %[[SHAPE0]]
  %0 = shape.const_shape [42, 1] : tensor<2xindex>
  %1 = shape.shape_of %arg0 : tensor<?x?x42x42xf32> -> tensor<?xindex>
  %2 = shape.shape_of %arg1 : tensor<42x42xf32> -> tensor<2xindex>
  %3 = shape.broadcast %0, %1, %2 : tensor<2xindex>, tensor<?xindex>, tensor<2xindex> -> tensor<?xindex>
  func.return %3 : tensor<?xindex>
}
// -----
// The constant shape [1, 1] must be kept. It is the only rank-2 operand
// (%arg0 has rank 1), so dropping it would shrink the result from rank 2 to
// rank 1.
// CHECK-LABEL: func @f
func.func @f(%arg0 : tensor<?xf32>) -> !shape.shape {
  // CHECK: shape.broadcast
  %0 = shape.const_shape [1, 1] : !shape.shape
  %1 = shape.shape_of %arg0 : tensor<?xf32> -> !shape.shape
  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape
  func.return %2 : !shape.shape
}
// -----
// [256] is the only operand that fixes the trailing extent to 256 (both
// extents of %arg0 are dynamic), so it must be kept.
// CHECK-LABEL: func @f
func.func @f(%arg0 : tensor<?x?xf32>) -> !shape.shape {
  // CHECK: shape.broadcast
  %0 = shape.const_shape [256] : !shape.shape
  %1 = shape.shape_of %arg0 : tensor<?x?xf32> -> !shape.shape
  %2 = shape.broadcast %0, %1 : !shape.shape, !shape.shape -> !shape.shape
  func.return %2 : !shape.shape
}
// -----
// Extracting dimension 1 of the broadcast folds to the constant 2. In that
// dimension, %arg1 contributes a static 1 and %arg2 a static 2, which
// determine the broadcast extent.
// CHECK-LABEL: func @static_non1_succeeds
// CHECK-NEXT: %[[C2:.*]] = arith.constant 2
// CHECK-NEXT: return %[[C2]]
func.func @static_non1_succeeds(%arg0 : tensor<?x?xf64>, %arg1 : tensor<?x1xf64>,
                                %arg2: tensor<?x2xf64>) -> index {
  %c1 = arith.constant 1 : index
  %1 = shape.shape_of %arg0 : tensor<?x?xf64> -> tensor<2xindex>
  %2 = shape.shape_of %arg1 : tensor<?x1xf64> -> tensor<2xindex>
  %3 = shape.shape_of %arg2 : tensor<?x2xf64> -> tensor<2xindex>
  %4 = shape.broadcast %1, %2, %3 : tensor<2xindex>, tensor<2xindex>,
                                    tensor<2xindex> -> tensor<2xindex>
  %result = tensor.extract %4[%c1] : tensor<2xindex>
  func.return %result : index
}
// -----
// Both operands have a static 1 in dimension 1, so extracting that dimension
// of the broadcast folds to the constant 1.
// CHECK-LABEL: func @all_static_1s_succeeds
// CHECK-NEXT: %[[C1:.*]] = arith.constant 1
// CHECK-NEXT: return %[[C1]]
func.func @all_static_1s_succeeds(%arg0 : tensor<?x1xf64>, %arg1 : tensor<?x1xf64>)
    -> index {
  %c1 = arith.constant 1 : index
  %1 = shape.shape_of %arg0 : tensor<?x1xf64> -> tensor<2xindex>
  %2 = shape.shape_of %arg1 : tensor<?x1xf64> -> tensor<2xindex>
  %3 = shape.broadcast %1, %2 : tensor<2xindex>, tensor<2xindex>
      -> tensor<2xindex>
  %result = tensor.extract %3[%c1] : tensor<2xindex>
  func.return %result : index
}
// -----
// In dimension 1, %arg1 has a static 1, so only %arg0 can determine the
// broadcast extent. The extract is rewritten to a tensor.dim on %arg0.
// CHECK-LABEL: func @single_non_static_1_succeeds
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?xf64>
// CHECK: %[[C1:.*]] = arith.constant 1
// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK: return %[[DIM]]
func.func @single_non_static_1_succeeds(%arg0 : tensor<?x?xf64>,
                                        %arg1 : tensor<?x1xf64>) -> index {
  // Named %c1 because it holds the constant 1 (the extracted dimension index).
  %c1 = arith.constant 1 : index
  %1 = shape.shape_of %arg0 : tensor<?x?xf64> -> tensor<2xindex>
  %2 = shape.shape_of %arg1 : tensor<?x1xf64> -> tensor<2xindex>
  %3 = shape.broadcast %1, %2 : tensor<2xindex>, tensor<2xindex>
      -> tensor<2xindex>
  %result = tensor.extract %3[%c1] : tensor<2xindex>
  func.return %result : index
}
// -----
// Dimension 0 is dynamic in both operands, so no single operand determines the
// extent. The shape_of / broadcast / extract chain must stay unchanged.
// CHECK-LABEL: func @multiple_non_static_1_fails
// CHECK-NEXT: constant 0
// CHECK-NEXT: shape.shape_of
// CHECK-NEXT: shape.shape_of
// CHECK-NEXT: shape.broadcast
// CHECK-NEXT: %[[RESULT:.*]] = tensor.extract
// CHECK-NEXT: return %[[RESULT]]
func.func @multiple_non_static_1_fails(%arg0 : tensor<?x?xf64>,
                                       %arg1 : tensor<?x1xf64>) -> index {
  %c0 = arith.constant 0 : index
  %1 = shape.shape_of %arg0 : tensor<?x?xf64> -> tensor<2xindex>
  %2 = shape.shape_of %arg1 : tensor<?x1xf64> -> tensor<2xindex>
  %3 = shape.broadcast %1, %2 : tensor<2xindex>, tensor<2xindex>
      -> tensor<2xindex>
  %result = tensor.extract %3[%c0] : tensor<2xindex>
  func.return %result : index
}
// -----
// A tensor.extract from a rank-0 tensor has no indices and is not fed by a
// broadcast; the pass must leave it alone. Presumably a regression test for a
// crash on index-less extracts; confirm against the pass history.
// CHECK-LABEL: func @extract_no_crash
// CHECK-NEXT: tensor.extract
func.func @extract_no_crash(%arg0 : tensor<index>) -> index {
  %result = tensor.extract %arg0[] : tensor<index>
  func.return %result : index
}