#include <algorithm>
#include <cstdint>
#include <iterator>
#include <memory>
#include <numeric>
#include <utility>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Analysis/shape_component_analysis.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Transforms/PassDetail.h"
#include "mlir-hlo/Transforms/passes.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
using ShapeOrValueInfo = ShapeComponentAnalysis::ShapeOrValueInfo;
using Symbol = ShapeComponentAnalysis::Symbol;
using SymbolicExpr = ShapeComponentAnalysis::SymbolicExpr;
namespace {
// Temporary data structure to hold a single dimension of the symbolic result of
// `shape.broadcast`:
//   - `operandIndex` is the index of the `shape.broadcast` operand providing
//     this dimension,
//   - `operandDim` is the position of the dimension within that operand, and
//   - `expr` is the symbolic expression of the dimension's extent.
struct SymbolicBroadcastDimension {
  size_t operandIndex;
  size_t operandDim;
  SymbolicExpr expr;
};
// Replace `shape.broadcast` with a shape if it's statically known.
//
// The broadcast is evaluated symbolically on the shape analysis results of all
// operands, aligned at their trailing dimensions. Every result dimension
// becomes
//   - the constant 1 if no operand has a non-unit extent there,
//   - a constant if the unique non-unit expression is an affine constant, or
//   - a `tensor.extract` from the first operand that provides the unique
//     non-unit expression, otherwise.
// The pattern fails if the analysis fails for any operand or if two operands
// have non-unit extents at the same position that are not symbolically equal.
struct SimplifyBroadcasts : public mlir::OpRewritePattern<shape::BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(
      shape::BroadcastOp op, mlir::PatternRewriter &rewriter) const override {
    // Require successful shape analysis for every operand.
    ShapeComponentAnalysis shapeAnalysis;
    llvm::SmallVector<ArrayRef<SymbolicExpr>> shapesInfo;
    auto shapes = op.getShapes();
    for (Value s : shapes) {
      auto sInfo = shapeAnalysis.GetValueInfo(s);
      if (!sInfo) return failure();
      shapesInfo.push_back(*sInfo);
    }

    // Find the result rank.
    size_t rank = 0;
    for (const auto &sInfo : shapesInfo) rank = std::max(rank, sInfo.size());

    // Compute broadcast symbolically.
    SmallVector<Optional<SymbolicBroadcastDimension>> symResult(rank,
                                                                llvm::None);
    for (const auto &sInfo : llvm::enumerate(shapesInfo)) {
      size_t dimOffset = rank - sInfo.value().size();
      for (const auto &symExpr : llvm::enumerate(sInfo.value())) {
        // Unit dimensions are neutral to the final result.
        if (symExpr.value().isConstant(1)) continue;

        // Use unique expression: the first operand providing a non-unit
        // extent determines this result dimension.
        size_t i = dimOffset + symExpr.index();
        if (!symResult[i]) {
          symResult[i] = {sInfo.index(), symExpr.index(), symExpr.value()};
          continue;
        }

        // Bail if the dimensions are neither equal nor 1.
        if (symResult[i]->expr != symExpr.value()) return failure();
      }
    }

    // Materialize broadcast result. Index constants are cached so that each
    // value is created at most once.
    auto loc = op.getLoc();
    DenseMap<int64_t, Value> constants;
    auto findOrCreateConstant = [&](int64_t c) -> Value {
      auto it = constants.find(c);
      if (it != constants.end()) return it->second;
      Value newlyCreated = rewriter.create<arith::ConstantIndexOp>(loc, c);
      constants[c] = newlyCreated;
      return newlyCreated;
    };
    auto elements = llvm::to_vector<8>(
        llvm::map_range(symResult, [&](const auto &symResultDim) -> Value {
          // If we know the dimension statically, use a constant.
          if (!symResultDim) return findOrCreateConstant(1);
          if (auto cexpr = symResultDim->expr.expr
                               .template dyn_cast<AffineConstantExpr>()) {
            return findOrCreateConstant(cexpr.getValue());
          }

          // Otherwise, extract the dimension from the unique operand.
          Value operand = shapes[symResultDim->operandIndex];
          Value operandDim = findOrCreateConstant(symResultDim->operandDim);
          return rewriter.create<tensor::ExtractOp>(loc, operand, operandDim)
              .getResult();
        }));
    Type indexTy = rewriter.getIndexType();
    Type concreteResultTy =
        RankedTensorType::get({static_cast<int64_t>(elements.size())}, indexTy);
    Value result = rewriter.create<tensor::FromElementsOp>(
        loc, concreteResultTy, elements);

    // Insert cast, if needed.
    Type expectedTy = op.getResult().getType();
    if (result.getType() != expectedTy) {
      result = rewriter.create<tensor::CastOp>(loc, expectedTy, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
// Compares the symbolic shape of `value` with the symbolic output shape held
// by the shape tensor `shape` and classifies each dimension `i` of `value`:
//   - known expanding: the input dimension is the constant 1 and the output
//     dimension is known not to be 1, and
//   - known non-expanding: both dimensions are symbolically equal or the
//     output dimension is the constant 1.
// Input dimension `i` is paired with output dimension `rankOut - rankIn + i`,
// i.e. shapes are aligned at their trailing dimensions; callers must make
// sure this matches the op being analyzed.
// Returns failure if either shape cannot be analyzed or if the input rank
// exceeds the output rank.
LogicalResult analyzeDynamicBroadcastInDimExpandingBehavior(
    ShapeComponentAnalysis &analysis, Value value, Value shape,
    llvm::SmallSetVector<int64_t, 4> *knownExpandingDims,
    llvm::SmallSetVector<int64_t, 4> *knownNonexpandingDims) {
  // Require successful analysis of shapes.
  auto shapeIn = analysis.GetShapeInfo(value);
  auto shapeOut = analysis.GetValueInfo(shape);
  if (!shapeIn || !shapeOut) return failure();

  // Analyze per argument dimension. A larger input rank would underflow
  // `dimOutOffset` and index out of bounds, so reject it explicitly rather
  // than relying on a debug-only assertion.
  size_t rankIn = shapeIn->size();
  size_t rankOut = shapeOut->size();
  if (rankIn > rankOut) return failure();
  size_t dimOutOffset = rankOut - rankIn;
  for (size_t i = 0; i < rankIn; ++i) {
    const SymbolicExpr &dimIn = (*shapeIn)[i];
    const SymbolicExpr &dimOut = (*shapeOut)[dimOutOffset + i];
    if (dimIn.isConstant(1) && dimOut.isKnownNotOne())
      knownExpandingDims->insert(i);
    if (dimIn == dimOut || dimOut.isConstant(1))
      knownNonexpandingDims->insert(i);
  }
  return success();
}
// Analyze `mhlo.dynamic_broadcast_in_dim` op and populate attributes for
// statically known expanding and non-expanding dimensions. Annotations that
// are already present on the op are merged into the newly derived ones. The
// pattern fails if the resulting annotations would not differ from the
// existing ones, so that it does not apply repeatedly.
struct AnnotateExpandingDimensionsInDynamicBroadcastInDim
    : public mlir::OpRewritePattern<mhlo::DynamicBroadcastInDimOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(
      mhlo::DynamicBroadcastInDimOp op,
      mlir::PatternRewriter &rewriter) const override {
    ShapeComponentAnalysis analysis;

    // The analysis helper pairs operand dimension `i` with output dimension
    // `rankOut - rankIn + i`. Only rely on it if `broadcast_dimensions`
    // describes exactly this trailing alignment; otherwise the annotations
    // would be attached to the wrong dimensions.
    auto outputDimsInfo = analysis.GetValueInfo(op.output_dimensions());
    if (!outputDimsInfo) return failure();
    auto rankOut = static_cast<int64_t>(outputDimsInfo->size());
    DenseIntElementsAttr bcastDims = op.broadcast_dimensions();
    int64_t rankIn = bcastDims.getNumElements();
    int64_t operandDim = 0;
    for (auto outputDim : bcastDims) {
      if (outputDim.getSExtValue() != rankOut - rankIn + operandDim++)
        return failure();
    }

    // Analyze shapes and identify expanding and non-expanding dims.
    llvm::SmallSetVector<int64_t, 4> knownExpandingDims, knownNonexpandingDims;
    if (failed(analyzeDynamicBroadcastInDimExpandingBehavior(
            analysis, op.operand(), op.output_dimensions(), &knownExpandingDims,
            &knownNonexpandingDims))) {
      return failure();
    }

    // Collect possibly already annotated info.
    auto insertAll = [](llvm::SmallSetVector<int64_t, 4> &dst,
                        Optional<DenseIntElementsAttr> src) {
      if (!src) return;
      for (auto it : *src) dst.insert(it.getLimitedValue());
    };
    insertAll(knownExpandingDims, op.known_expanding_dimensions());
    insertAll(knownNonexpandingDims, op.known_nonexpanding_dimensions());

    // Fail pattern application if there is nothing new to annotate.
    auto isEqual = [](llvm::SmallSetVector<int64_t, 4> &set,
                      DenseIntElementsAttr attr) {
      return set.size() == static_cast<size_t>(attr.size()) &&
             llvm::all_of(attr, [&](auto it) {
               return set.count(it.getLimitedValue());
             });
    };
    if (op.known_expanding_dimensions() && op.known_nonexpanding_dimensions() &&
        isEqual(knownExpandingDims, *op.known_expanding_dimensions()) &&
        isEqual(knownNonexpandingDims, *op.known_nonexpanding_dimensions())) {
      return failure();
    }

    // Annotate op in place. The dimensions are sorted so that the attributes
    // do not depend on the order in which they were discovered.
    auto toSortedI64TensorAttr = [&](llvm::SmallSetVector<int64_t, 4> &dims) {
      SmallVector<int64_t> sorted(dims.begin(), dims.end());
      llvm::sort(sorted);
      return rewriter.getI64TensorAttr(sorted);
    };
    rewriter.updateRootInPlace(op, [&]() {
      op.known_expanding_dimensionsAttr(
          toSortedI64TensorAttr(knownExpandingDims));
      op.known_nonexpanding_dimensionsAttr(
          toSortedI64TensorAttr(knownNonexpandingDims));
    });
    return success();
  }
};
// Remove compute_reshape_shape if we can prove that the dynamic shape does not
// contain a `-1` dimension. In that case there is no wildcard to resolve and
// the op's result is replaced by its `dynamic_shape` operand.
struct RemoveComputeReshapeShape final
    : public OpRewritePattern<mhlo::ComputeReshapeShapeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(mhlo::ComputeReshapeShapeOp op,
                                PatternRewriter &rewriter) const override {
    // Forwarding the operand is only valid if it has exactly the type of the
    // replaced result; otherwise the rewritten IR would be type-inconsistent.
    Value dynamicShapeValue = op.dynamic_shape();
    if (dynamicShapeValue.getType() != op->getResult(0).getType())
      return failure();

    ShapeComponentAnalysis shapeComponentAnalysis;
    auto dynamicShape = shapeComponentAnalysis.GetValueInfo(dynamicShapeValue);
    if (!dynamicShape) return failure();

    if (llvm::any_of(*dynamicShape, [](const auto &dim) {
          return !dim.isKnownNotNegativeOne();
        })) {
      return failure();
    }
    rewriter.replaceOp(op, dynamicShapeValue);
    return success();
  }
};
// Returns true iff `expr` is a (possibly nested) product of affine constants
// and affine symbols. Each constant factor is reported to `cbkConstantFactor`
// and each symbolic factor to `cbkSymbolicFactor`, in left-to-right order.
// When false is returned, the callbacks may already have been invoked for some
// of the factors.
bool isProduct(AffineExpr expr,
               llvm::function_ref<void(AffineConstantExpr)> cbkConstantFactor,
               llvm::function_ref<void(AffineSymbolExpr)> cbkSymbolicFactor) {
  auto binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
  if (binExpr && binExpr.getKind() == AffineExprKind::Mul) {
    return isProduct(binExpr.getLHS(), cbkConstantFactor, cbkSymbolicFactor) &&
           isProduct(binExpr.getRHS(), cbkConstantFactor, cbkSymbolicFactor);
  }
  if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
    cbkSymbolicFactor(symExpr);
    return true;
  }
  if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
    cbkConstantFactor(constExpr);
    return true;
  }
  return false;
}
bool isSymbolicProduct(const SymbolicExpr &symbolicExpr,
llvm::function_ref<void(int64_t)> cbkConstantFactor,
llvm::function_ref<void(Symbol)> cbkSymbolicFactor) {
return isProduct(
[&](AffineConstantExpr cexpr) { cbkConstantFactor(cexpr.getValue()); },
[&](AffineSymbolExpr sexpr) {
// Represents a product of symbolic and concrete factors. This will allow us to
// prove product equalities symbolically.
struct SymbolicProduct {
  // Product of all concrete factors.
  int64_t concrete = 1;
  // List all symbolic factors as they can not be aggregated.
  llvm::SmallVector<Symbol> symbolic;
  // True iff the product is the neutral element 1 (no factors left).
  bool empty() const { return concrete == 1 && symbolic.empty(); }
};
// Convenience overload that accumulates the factors of `symbolicExpr` into
// `product`: constant factors are multiplied into `product->concrete` and
// symbolic factors are appended to `product->symbolic`. Returns false if the
// expression is not a product of constants and symbols.
bool isSymbolicProduct(const SymbolicExpr &symbolicExpr,
                       SymbolicProduct *product) {
  return isSymbolicProduct(
      symbolicExpr, [&](int64_t c) { product->concrete *= c; },
      [&](Symbol s) { product->symbolic.push_back(s); });
}
// Replace `mhlo.cstr_reshapable` with a constant witness if the shape analysis
// can decide symbolically whether `dynamic_shape` is a valid reshape target
// for `num_elements` elements. At most one `-1` entry (a wildcard dimension)
// is supported in the dynamic shape.
struct RemoveRedundantCstrReshapable final
    : public OpRewritePattern<mhlo::CstrReshapableOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(mhlo::CstrReshapableOp op,
                                PatternRewriter &rewriter) const override {
    // Get shape analysis info for the number of elements, a scalar value.
    ShapeComponentAnalysis shapeComponentAnalysis;
    auto numElementsInfo =
        shapeComponentAnalysis.GetValueInfo(op.num_elements());
    if (!numElementsInfo || numElementsInfo->size() != 1) return failure();
    const SymbolicExpr &numElements = numElementsInfo->front();

    // Get shape analysis info for the dynamic shape.
    auto dynShapeDims = shapeComponentAnalysis.GetValueInfo(op.dynamic_shape());
    if (!dynShapeDims) return failure();

    // We can handle two cases:
    //   - there is exactly one -1 in the dynamic shape, i.e. a unique wildcard
    //     dimension, or
    //   - there is no -1 in the dynamic shape, i.e. no wildcard dimension.
    bool uniqueWildcardDimension = false;
    for (const auto &d : *dynShapeDims) {
      if (d.isConstant(-1)) {
        if (uniqueWildcardDimension) return failure();
        uniqueWildcardDimension = true;
      } else if (!d.isKnownNotNegativeOne()) {
        return failure();
      }
    }

    // We can only handle simple products with constants and symbols. Find all
    // the factors based on the number of elements. Only strictly positive
    // concrete products are handled; zero or negative values bail out.
    SymbolicProduct numElementsRemainingFactors;
    if (!isSymbolicProduct(numElements, &numElementsRemainingFactors) ||
        numElementsRemainingFactors.concrete < 1) {
      return failure();
    }

    // Find all factors based on the dynamic shape.
    //   - Accumulate the concrete product to later compare it against its
    //     equivalent based on the number of elements.
    //   - Remove symbolic factors from the list and fail if we find an unknown
    //     factor, i.e. if the symbolic factors based on the dynamic shape are
    //     not a subset of the factors based on the number of elements.
    // The wildcard dimension itself contributes no factor.
    int64_t concreteProductDynShape = 1;
    for (const auto &dim : *dynShapeDims) {
      if (dim.isConstant(-1)) continue;
      SmallVector<Symbol> partialSymbolicFactorsDynShape;
      if (!isSymbolicProduct(
              dim, [&](int64_t c) { concreteProductDynShape *= c; },
              [&](Symbol s) { partialSymbolicFactorsDynShape.push_back(s); })) {
        return failure();
      }
      for (const Symbol &symDynShape : partialSymbolicFactorsDynShape) {
        auto *it =
            llvm::find(numElementsRemainingFactors.symbolic, symDynShape);
        if (it == numElementsRemainingFactors.symbolic.end()) return failure();
        numElementsRemainingFactors.symbolic.erase(it);
      }
    }

    // A zero (or negative) concrete product cannot be reasoned about below
    // and would cause a division by zero in the wildcard case.
    if (concreteProductDynShape < 1) return failure();

    // A wildcard dimension can subsume the remaining symbolic factors and
    // potentially also a concrete factor.
    if (uniqueWildcardDimension) {
      if (numElementsRemainingFactors.concrete % concreteProductDynShape != 0)
        return failure();
      rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
      return success();
    }

    // W/o a wildcard, the symbolic and concrete products must be equal.
    bool isReshapable =
        numElementsRemainingFactors.symbolic.empty() &&
        numElementsRemainingFactors.concrete == concreteProductDynShape;
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, isReshapable);
    return success();
  }
};
// Materializes a `dynamic_reshape` of a scalar (rank-0) operand. The operand
// is expanded to an all-unit static shape of the result's rank with an empty
// reassociation, then cast to the result type if that differs.
LogicalResult materializeReshapeAsScalarExpand(RankedTensorType operandTy,
                                               RankedTensorType resultTy,
                                               mhlo::DynamicReshapeOp op,
                                               PatternRewriter &rewriter) {
  assert(operandTy.getRank() == 0 && "expect scalar operand");
  auto loc = op.getLoc();
  SmallVector<int64_t> unitDims(resultTy.getRank(), 1);
  auto expandedTy = RankedTensorType::get(unitDims, resultTy.getElementType());
  Value expandedScalar = rewriter.create<tensor::ExpandShapeOp>(
      loc, expandedTy, op.operand(), ArrayRef<ReassociationIndices>{});
  if (expandedScalar.getType() != resultTy) {
    expandedScalar =
        rewriter.create<tensor::CastOp>(loc, resultTy, expandedScalar);
  }
  rewriter.replaceOp(op, expandedScalar);
  return success();
}
// Materializes a `dynamic_reshape` to a scalar (rank-0) result. The operand is
// cast to an all-unit static shape (if its type differs) and then collapsed
// with an empty reassociation.
LogicalResult materializeReshapeAsScalarCollapse(RankedTensorType operandTy,
                                                 RankedTensorType resultTy,
                                                 mhlo::DynamicReshapeOp op,
                                                 PatternRewriter &rewriter) {
  assert(resultTy.getRank() == 0 && "expect scalar result");
  auto loc = op.getLoc();
  Value operand = op.operand();
  SmallVector<int64_t> unitDims(operandTy.getRank(), 1);
  auto castedOperandTy =
      RankedTensorType::get(unitDims, operandTy.getElementType());
  if (operand.getType() != castedOperandTy) {
    operand = rewriter.create<tensor::CastOp>(loc, castedOperandTy, operand);
  }
  Value collapsedScalar = rewriter.create<tensor::CollapseShapeOp>(
      loc, operand, ArrayRef<ReassociationIndices>{});
  rewriter.replaceOp(op, collapsedScalar);
  return success();
}
// Classifies a dimension group of a reshape: its dimensions map 1:1 between
// operand and result (`kNone`), stem from expanding a single operand dimension
// (`kExpanding`), or collapse into a single result dimension (`kCollapsing`).
enum class DimensionGroupKind {
  kNone,
  kExpanding,
  kCollapsing,
};
// A group of consecutive dimensions of the intermediate (fully expanded)
// reshape shape. `size` is the number of intermediate dimensions in the group,
// and `kind` tells how the group relates operand and result dimensions.
struct DimensionGroup {
  int64_t size = 0;
  DimensionGroupKind kind = DimensionGroupKind::kNone;
};
// Splits off the common factors of two symbolic products.
//
// The returned product holds the shared factors, and `a` and `b` are reduced
// to their remaining, unshared factors:
//   - the concrete parts are divided by their greatest common divisor, and
//   - every symbol occurring in both lists is removed from both (once per
//     matching pair) and appended to the result.
// The order of the symbols remaining in `a` is not preserved.
// If both concrete parts are 0, the shared concrete factor is 0 and both
// remainders become 1; this avoids a division by zero.
SymbolicProduct eliminateCommonFactors(SymbolicProduct &a, SymbolicProduct &b) {
  SymbolicProduct gcd;

  // Eliminate common concrete factors.
  if (a.concrete == 0 && b.concrete == 0) {
    gcd.concrete = 0;
    a.concrete = 1;
    b.concrete = 1;
  } else {
    gcd.concrete = llvm::GreatestCommonDivisor64(a.concrete, b.concrete);
    a.concrete /= gcd.concrete;
    b.concrete /= gcd.concrete;
  }

  // Eliminate common symbolic factors.
  size_t i = 0;
  while (i < a.symbolic.size()) {
    auto *it = llvm::find(b.symbolic, a.symbolic[i]);
    if (it != b.symbolic.end()) {
      gcd.symbolic.push_back(*it);
      // Swap-and-pop keeps removal from `a` cheap; `i` is re-examined since a
      // new symbol now occupies this slot.
      std::swap(a.symbolic[i], a.symbolic.back());
      a.symbolic.pop_back();
      b.symbolic.erase(it);
    } else {
      ++i;
    }
  }

  return gcd;
}
// Returns true iff `it` points to a constant-1 dimension that is not paired
// with a constant-1 dimension at `otherIt` (i.e. `otherIt` is exhausted or
// points to a dimension that is not the constant 1).
bool isUnpairedUnitDimension(
    ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator it,
    ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator end,
    ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator otherIt,
    ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator otherEnd) {
  return it != end && it->isConstant(1) &&
         (otherIt == otherEnd || !otherIt->isConstant(1));
}
// Returns the type-level dimension size represented by `symProduct`: its
// concrete part if it has no symbolic factors, `ShapedType::kDynamicSize`
// otherwise.
int64_t getShapedTypyDimSize(const SymbolicProduct &symProduct) {
  return symProduct.symbolic.empty() ? symProduct.concrete
                                     : ShapedType::kDynamicSize;
}
// Iterate over the operand's and the result's shape dimensions and find
// dimension groups that are collapsing, expanding, or untouched:
//   - Collapsing: Multiple dimensions of the operand shape can be collapsed
//     into a single dimension of the result shape. We must prove that the
//     product of the operand shape's dimensions is equal to the corresponding
//     result dimension.
//   - Expanding: A single dimension of the operand shape can be expanded into
//     multiple dimensions of the result shape. We must prove that the product
//     of the result shape's dimensions is equal to the corresponding operand
//     dimension. This case is limited to at most one dynamic dimension per
//     expansion group as otherwise not supported by the `expand_shape` op.
//   - Untouched: There is a 1:1 correspondence between an operand and a result
//     shape dimension.
//
// We can determine the optimal dimension groups greedily by consuming operand
// and result dimensions from left to right. If the leading operand dimension is
// a strict divisor of the leading result dimension, collapsing is required. In
// this case, we keep consuming the operand dimensions until the products are
// equal. If the leading result dimension is a strict divisor of the leading
// operand dimension, expanding is required. In this case, we keep consuming the
// result dimensions until the products are equal. Trailing unit dimensions may
// be included in the dimension group. This is useful iff they are "unpaired",
// in which case they would only limit us in the subsequent iteration.
//
// On success, one entry per group is appended to `dimensionGroups` (its size
// counts the intermediate dimensions it spans), and the sizes of the
// intermediate, fully expanded shape are appended to `expandedIntermShape`
// (`ShapedType::kDynamicSize` for symbolic sizes).
LogicalResult findExpandingAndCollapsingDimensionGroups(
    ArrayRef<SymbolicExpr> operandShapeInfo,
    ArrayRef<SymbolicExpr> resultShapeInfo,
    SmallVector<DimensionGroup> *dimensionGroups,
    SmallVector<int64_t> *expandedIntermShape) {
  const auto *operandShapeIt = operandShapeInfo.begin();
  const auto *operandShapeEnd = operandShapeInfo.end();
  const auto *resultShapeIt = resultShapeInfo.begin();
  const auto *resultShapeEnd = resultShapeInfo.end();

  // Crucial iteration state: factors of the consumed operand/result dimensions
  // that have not yet been matched by the other side.
  SymbolicProduct remainingOperandShapeFactors;
  SymbolicProduct remainingResultShapeFactors;
  auto anyRemainingFactors = [&]() {
    return !remainingOperandShapeFactors.empty() ||
           !remainingResultShapeFactors.empty();
  };

  while (operandShapeIt != operandShapeEnd && resultShapeIt != resultShapeEnd) {
    assert(!anyRemainingFactors() &&
           "expect no remaining factors from previous iteration");
    DimensionGroup &dimGroup = dimensionGroups->emplace_back();

    // Consume at least one operand and result dimension. Their common factors
    // form the group's leading intermediate dimension.
    if (!isSymbolicProduct(*operandShapeIt++, &remainingOperandShapeFactors) ||
        !isSymbolicProduct(*resultShapeIt++, &remainingResultShapeFactors)) {
      return failure();
    }
    dimGroup.size++;
    SymbolicProduct leadingGcd = eliminateCommonFactors(
        remainingOperandShapeFactors, remainingResultShapeFactors);
    expandedIntermShape->push_back(getShapedTypyDimSize(leadingGcd));

    // Fail if there are unresolvable, contradicting factors remaining.
    if (!remainingOperandShapeFactors.empty() &&
        !remainingResultShapeFactors.empty()) {
      return failure();
    }

    // Collapsing: Create a collapsing dimension group.
    bool requiresCollapsing =
        remainingOperandShapeFactors.empty() &&
        (!remainingResultShapeFactors.empty() ||
         isUnpairedUnitDimension(operandShapeIt, operandShapeEnd, resultShapeIt,
                                 resultShapeEnd));
    if (requiresCollapsing) {
      dimGroup.kind = DimensionGroupKind::kCollapsing;

      // Consume operand shape dimensions until their product matches the
      // corresponding result dimension (or fail if unresolvable/contradicting
      // factors are found).
      while (operandShapeIt != operandShapeEnd &&
             remainingOperandShapeFactors.empty() &&
             !remainingResultShapeFactors.empty()) {
        if (!isSymbolicProduct(*operandShapeIt++,
                               &remainingOperandShapeFactors)) {
          return failure();
        }
        dimGroup.size++;
        SymbolicProduct gcd = eliminateCommonFactors(
            remainingOperandShapeFactors, remainingResultShapeFactors);
        expandedIntermShape->push_back(getShapedTypyDimSize(gcd));
      }
      if (anyRemainingFactors()) return failure();

      // Consume trailing, unpaired unit dimensions.
      while (isUnpairedUnitDimension(operandShapeIt, operandShapeEnd,
                                     resultShapeIt, resultShapeEnd)) {
        ++operandShapeIt;
        dimGroup.size++;
        expandedIntermShape->push_back(1);
      }

      continue;
    }

    // Expanding: Create an expanding dimension group.
    bool requiresExpanding =
        remainingResultShapeFactors.empty() &&
        (!remainingOperandShapeFactors.empty() ||
         isUnpairedUnitDimension(resultShapeIt, resultShapeEnd, operandShapeIt,
                                 operandShapeEnd));
    if (requiresExpanding) {
      dimGroup.kind = DimensionGroupKind::kExpanding;

      // The group's leading intermediate dimension was emitted above and
      // counts towards the dynamic-dimension limit, too.
      int64_t numDynamicDims =
          expandedIntermShape->back() == ShapedType::kDynamicSize ? 1 : 0;

      // Consume result shape dimensions until their product matches the
      // corresponding operand dimension (or fail if unresolvable/contradicting
      // factors are found).
      while (resultShapeIt != resultShapeEnd &&
             remainingResultShapeFactors.empty() &&
             !remainingOperandShapeFactors.empty()) {
        if (!isSymbolicProduct(*resultShapeIt++,
                               &remainingResultShapeFactors)) {
          return failure();
        }
        dimGroup.size++;
        SymbolicProduct gcd = eliminateCommonFactors(
            remainingOperandShapeFactors, remainingResultShapeFactors);
        int64_t tyDimSize = getShapedTypyDimSize(gcd);

        // Allow no more than one dynamic dimension per expansion group.
        if (tyDimSize == ShapedType::kDynamicSize) {
          numDynamicDims++;
          if (numDynamicDims > 1) return failure();
        }
        expandedIntermShape->push_back(tyDimSize);
      }
      if (anyRemainingFactors()) return failure();

      // Consume trailing, unpaired unit dimensions.
      while (isUnpairedUnitDimension(resultShapeIt, resultShapeEnd,
                                     operandShapeIt, operandShapeEnd)) {
        ++resultShapeIt;
        dimGroup.size++;
        expandedIntermShape->push_back(1);
      }

      continue;
    }

    // Untouched: 1:1 mapping between operand and result shape dimension. This
    // is neither expanding nor collapsing.
    assert(!requiresCollapsing && !requiresExpanding && "expect id case");
    assert(dimGroup.size == 1 && dimGroup.kind == DimensionGroupKind::kNone &&
           "expect simple dimension group");
  }

  // Fail if there are remaining dimensions that could not be consumed.
  assert(!anyRemainingFactors() && "expect no remaining factors");
  if (operandShapeIt != operandShapeEnd || resultShapeIt != resultShapeEnd) {
    return failure();
  }

  return success();
}
// Refines the static operand shape with constants proven by the shape
// analysis. A dimension whose symbolic expression is an affine constant takes
// that constant; all other dimensions keep their size from `operandShape`.
// Only the common prefix of both inputs is considered.
SmallVector<int64_t> concretizeOperandShape(
    ArrayRef<int64_t> operandShape, ArrayRef<SymbolicExpr> operandShapeInfo) {
  SmallVector<int64_t> result;
  result.reserve(operandShape.size());
  for (auto it : llvm::zip(operandShape, operandShapeInfo)) {
    int64_t dimSize = std::get<0>(it);
    const SymbolicExpr &sExpr = std::get<1>(it);
    if (auto cexpr = sExpr.expr.dyn_cast<AffineConstantExpr>()) {
      int64_t alsoDimSize = cexpr.getValue();
      assert((ShapedType::isDynamic(dimSize) || dimSize == alsoDimSize) &&
             "expect shape analysis result to be compatible with type");
      result.push_back(alsoDimSize);
      continue;
    }
    result.push_back(dimSize);
  }
  return result;
}
// Builds the reassociation that groups the intermediate dimensions of all
// groups of `kind` together, while every dimension of any other group forms a
// singleton group. Returns llvm::None if no group is of `kind`, i.e. if no
// reshape of that kind is required.
llvm::Optional<SmallVector<ReassociationIndices>> requiresReassociationOfKind(
    DimensionGroupKind kind, const SmallVector<DimensionGroup> &dimGroups) {
  SmallVector<ReassociationIndices> reassociation;
  reassociation.reserve(dimGroups.size());
  bool isStrictlyReassociating = false;
  int64_t i = 0;
  for (const DimensionGroup &g : dimGroups) {
    if (g.kind == kind) {
      isStrictlyReassociating = true;
      reassociation.push_back(
          llvm::to_vector(llvm::seq<int64_t>(i, i + g.size)));
      i += g.size;
      continue;
    }
    for (int64_t j = 0; j < g.size; j++) reassociation.push_back({i++});
  }

  // Return the reassociation if a reshape of this kind is required.
  if (isStrictlyReassociating) return reassociation;
  return llvm::None;
}
// Materializes a `dynamic_reshape` as (optional) cast, `expand_shape`,
// `collapse_shape`, and cast, based on the dimension groups found by
// `findExpandingAndCollapsingDimensionGroups`. Fails if the shape analysis
// fails for the operand or the output shape, or if no valid grouping exists.
LogicalResult materializeReshapeAsExpandAndCollapse(
    ShapeComponentAnalysis &shapeAnalysis, RankedTensorType operandTy,
    RankedTensorType resultTy, mhlo::DynamicReshapeOp op,
    PatternRewriter &rewriter) {
  // Require successful shape analysis for operand and result shape.
  auto operandShapeInfo = shapeAnalysis.GetShapeInfo(op.operand());
  if (!operandShapeInfo) return failure();
  auto resultShapeInfo = shapeAnalysis.GetValueInfo(op.output_shape());
  if (!resultShapeInfo) return failure();

  // Identify dimension groups and the intermediate expanded type.
  SmallVector<DimensionGroup> dimensionGroups;
  SmallVector<int64_t> expandedIntermShape;
  if (failed(findExpandingAndCollapsingDimensionGroups(
          *operandShapeInfo, *resultShapeInfo, &dimensionGroups,
          &expandedIntermShape))) {
    return failure();
  }

  // Materialize cast, expand, collapse, and cast, as needed.
  auto loc = op.getLoc();
  Value interm = op.operand();
  auto castedOperandTy = RankedTensorType::get(
      concretizeOperandShape(operandTy.getShape(), *operandShapeInfo),
      operandTy.getElementType());
  if (operandTy != castedOperandTy) {
    interm = rewriter.create<tensor::CastOp>(loc, castedOperandTy, interm);
  }
  if (auto reassociation = requiresReassociationOfKind(
          DimensionGroupKind::kExpanding, dimensionGroups)) {
    interm = rewriter.create<tensor::ExpandShapeOp>(
        loc,
        RankedTensorType::get(expandedIntermShape, operandTy.getElementType()),
        interm, *reassociation);
  }
  if (auto reassociation = requiresReassociationOfKind(
          DimensionGroupKind::kCollapsing, dimensionGroups)) {
    interm =
        rewriter.create<tensor::CollapseShapeOp>(loc, interm, *reassociation);
  }
  if (interm.getType() != resultTy) {
    interm = rewriter.create<tensor::CastOp>(loc, resultTy, interm);
  }
  rewriter.replaceOp(op, interm);
  return success();
}
// Tries to express `dynamic_reshape` ops through `expand_shape` and
// `collapse_shape` ops. Only ranked operand and result types are handled.
struct DynamicReshapeToExpandAndCollapseShape final
    : public OpRewritePattern<mhlo::DynamicReshapeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(mhlo::DynamicReshapeOp op,
                                PatternRewriter &rewriter) const override {
    auto operandTy = op.operand().getType().dyn_cast<RankedTensorType>();
    if (!operandTy) return failure();
    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy) return failure();

    // Handle the degenerate scalar-to-scalar case: nothing needs to be
    // expanded or collapsed. Forwarding the operand avoids emitting a
    // rank-preserving `expand_shape`, which the tensor dialect may reject.
    if (operandTy.getRank() == 0 && resultTy.getRank() == 0) {
      if (operandTy != resultTy) return failure();
      rewriter.replaceOp(op, op.operand());
      return success();
    }

    // Handle degenerate scalar expand case.
    if (operandTy.getRank() == 0) {
      return materializeReshapeAsScalarExpand(operandTy, resultTy, op,
                                              rewriter);
    }

    // Handle degenerate scalar collapse case.
    if (resultTy.getRank() == 0) {
      return materializeReshapeAsScalarCollapse(operandTy, resultTy, op,
                                                rewriter);
    }

    ShapeComponentAnalysis shapeAnalysis;
    return materializeReshapeAsExpandAndCollapse(shapeAnalysis, operandTy,
                                                 resultTy, op, rewriter);
  }
};
// Returns true if every shape in `bcastedShapes` is known to be broadcastable
// with `outputShape`. Shapes are compared at their trailing dimensions, and
// every pair of aligned dimensions must contain a constant 1 or be
// symbolically equal. Returns false if any shape cannot be analyzed or if an
// aligned dimension is statically zero.
bool isKnownBroadcastable(ShapeComponentAnalysis &analysis,
                          ValueRange bcastedShapes, Value outputShape) {
  auto outputShapeDims = analysis.GetValueInfo(outputShape);
  if (!outputShapeDims) return false;
  for (Value shape : bcastedShapes) {
    auto shapeDims = analysis.GetValueInfo(shape);
    if (!shapeDims) return false;
    // Iterate backwards over the smallest input shape.
    for (auto zip : llvm::zip(llvm::reverse(*outputShapeDims),
                              llvm::reverse(*shapeDims))) {
      const auto &first = std::get<0>(zip);
      const auto &second = std::get<1>(zip);
      // TODO(ezhulenev): What to do with dimensions statically known to be
      // zero?
      // Numpy can only broadcast [0] with [1], however Tensorflow can broadcast
      // [0] with any dimension size, and produces dimension of size [0].
      // Currently we'll conservatively return failure and will not proceed with
      // a rewrite.
      if (first.isConstant(0) || second.isConstant(0)) return false;
      // If either shape has a static one dimension the broadcast will always
      // succeed.
      if (first.isConstant(1) || second.isConstant(1)) continue;
      // Otherwise dims have to be equal.
      if (first != second) return false;
    }
  }
  return true;
}
// Rewrite `shape.cstr_broadcastable` with constant witness if can prove that
// shapes are broadcastable from a symbolic analysis.
//
// Broadcastability of several shapes is a pairwise property: at every aligned
// position, all non-unit extents must agree. Checking the shapes only against
// the first operand is not sufficient (e.g. `[1]`, `[s0]`, `[s1]` would pass
// although `s0` and `s1` may differ), so every shape is used once as the
// reference shape.
struct CstrBroadcastableOpLowering
    : public OpRewritePattern<shape::CstrBroadcastableOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
                                PatternRewriter &rewriter) const override {
    ShapeComponentAnalysis shapeComponentAnalysis;
    ValueRange shapes = op.getShapes();
    for (Value referenceShape : shapes) {
      if (!isKnownBroadcastable(shapeComponentAnalysis, shapes,
                                referenceShape)) {
        return failure();
      }
    }
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
    return success();
  }
};
// Greedily applies the symbolic shape optimization patterns of this file,
// together with the `shape.assuming` canonicalizations, to the current
// operation. Signals pass failure if the greedy driver fails.
class SymbolicShapeOptimizationPass final
    : public SymbolicShapeOptimizationBase<SymbolicShapeOptimizationPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    // The patterns create `arith`, `tensor`, and `shape` ops, which need not
    // be present in the input IR, so their dialects must be loadable.
    // NOTE(review): Linalg is registered as well; confirm it is still needed.
    registry.insert<arith::ArithmeticDialect, linalg::LinalgDialect,
                    shape::ShapeDialect, tensor::TensorDialect>();
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    mlir::RewritePatternSet patterns(ctx);

    // clang-format off
    patterns.insert<
        AnnotateExpandingDimensionsInDynamicBroadcastInDim,
        CstrBroadcastableOpLowering,
        DynamicReshapeToExpandAndCollapseShape,
        RemoveComputeReshapeShape,
        RemoveRedundantCstrReshapable,
        SimplifyBroadcasts>(ctx);
    // clang-format on
    shape::AssumingOp::getCanonicalizationPatterns(patterns, ctx);

    if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
                                                  std::move(patterns)))) {
      signalPassFailure();
    }
  }
};
} // end namespace
createSymbolicShapeOptimizationPass() {
return std::make_unique<SymbolicShapeOptimizationPass>();
} // end namespace mlir