blob: 1900bb4007b2e28cfddeff4fb65e19194c97b521 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <memory>
#include <numeric>
#include <utility>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Analysis/shape_component_analysis.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Transforms/PassDetail.h"
#include "mlir-hlo/Transforms/passes.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
using ShapeOrValueInfo = ShapeComponentAnalysis::ShapeOrValueInfo;
using Symbol = ShapeComponentAnalysis::Symbol;
using SymbolicExpr = ShapeComponentAnalysis::SymbolicExpr;
namespace {
// Temporary data structure to hold a single dimension of the symbolic result of
// `shape.broadcast`.
struct SymbolicBroadcastDimension {
size_t operandIndex;
size_t operandDim;
SymbolicExpr expr;
};
// Replace shape.broadcast with a shape if it's statically known.
struct SimplifyBroadcasts : public mlir::OpRewritePattern<shape::BroadcastOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(
shape::BroadcastOp op, mlir::PatternRewriter &rewriter) const override {
// Require successful shape analysis.
ShapeComponentAnalysis shapeAnalysis;
llvm::SmallVector<ArrayRef<SymbolicExpr>> shapesInfo;
auto shapes = op.getShapes();
shapesInfo.reserve(shapes.size());
for (Value s : shapes) {
auto sInfo = shapeAnalysis.GetValueInfo(s);
if (!sInfo) return failure();
shapesInfo.push_back(*sInfo);
}
// Find the result rank.
size_t rank = 0;
for (const auto &sInfo : shapesInfo) rank = std::max(rank, sInfo.size());
// Compute broadcast symbolically.
SmallVector<Optional<SymbolicBroadcastDimension>> symResult(rank,
llvm::None);
for (const auto &sInfo : llvm::enumerate(shapesInfo)) {
size_t dimOffset = rank - sInfo.value().size();
for (const auto &symExpr : llvm::enumerate(sInfo.value())) {
// Unit dimensions are neutral to the final result.
if (symExpr.value().isConstant(1)) continue;
// Use unique expression.
size_t i = dimOffset + symExpr.index();
if (!symResult[i]) {
symResult[i] = {sInfo.index(), symExpr.index(), symExpr.value()};
continue;
}
// Bail if the dimensions are neither equal nor 1.
if (symResult[i]->expr != symExpr.value()) return failure();
}
}
// Materialize broadcast result.
auto loc = op.getLoc();
DenseMap<int64_t, Value> constants;
auto findOrCreateConstant = [&](int64_t c) {
auto it = constants.find(c);
if (it != constants.end()) return it->second;
Value newlyCreated = rewriter.create<arith::ConstantIndexOp>(loc, c);
constants[c] = newlyCreated;
return newlyCreated;
};
auto elements = llvm::to_vector<8>(
llvm::map_range(symResult, [&](const auto &symResultDim) {
// If we know the dimension statically, use a constant.
if (!symResultDim) return findOrCreateConstant(1);
if (auto cexpr = symResultDim->expr.expr
.template dyn_cast<AffineConstantExpr>()) {
return findOrCreateConstant(cexpr.getValue());
}
// Othwerise, extract the dimension from the unique operand.
Value operand = shapes[symResultDim->operandIndex];
Value operandDim = findOrCreateConstant(symResultDim->operandDim);
return rewriter.create<tensor::ExtractOp>(loc, operand, operandDim)
.getResult();
}));
Type indexTy = rewriter.getIndexType();
Type concreteResultTy =
RankedTensorType::get({static_cast<int64_t>(elements.size())}, indexTy);
Value result = rewriter.create<tensor::FromElementsOp>(
loc, concreteResultTy, elements);
// Insert cast, if needed.
Type expectedTy = op.getResult().getType();
if (result.getType() != expectedTy) {
result = rewriter.create<tensor::CastOp>(loc, expectedTy, result);
}
rewriter.replaceOp(op, result);
return success();
}
};
LogicalResult analyzeDynamicBroadcastInDimExpandingBehavior(
ShapeComponentAnalysis &analysis, Value value, Value shape,
llvm::SmallSetVector<int64_t, 4> *knownExpandingDims,
llvm::SmallSetVector<int64_t, 4> *knownNonexpandingDims) {
// Require successful analysis of shapes.
auto shapeIn = analysis.GetShapeInfo(value);
auto shapeOut = analysis.GetValueInfo(shape);
if (!shapeIn || !shapeOut) return failure();
// Analyze per argument dimension.
size_t rankIn = shapeIn->size();
size_t rankOut = shapeOut->size();
assert(rankIn <= rankOut);
size_t dimOutOffset = rankOut - rankIn;
for (size_t i = 0; i < rankIn; ++i) {
SymbolicExpr dimIn = (*shapeIn)[i];
SymbolicExpr dimOut = (*shapeOut)[dimOutOffset + i];
if (dimIn.isConstant(1) && dimOut.isKnownNotOne())
knownExpandingDims->insert(i);
if (dimIn == dimOut || dimOut.isConstant(1))
knownNonexpandingDims->insert(i);
}
return success();
}
// Analyze `mhlo.dynamic_broadcast_in_dim` op and populate attributes for
// statically known expanding and non-expanding dimensions.
struct AnnotateExpandingDimensionsInDynamicBroadcastInDim
: public mlir::OpRewritePattern<mhlo::DynamicBroadcastInDimOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(
mhlo::DynamicBroadcastInDimOp op,
mlir::PatternRewriter &rewriter) const override {
// Analyze shapes and identify expanding and non-expanding dims.
ShapeComponentAnalysis analysis;
llvm::SmallSetVector<int64_t, 4> knownExpandingDims, knownNonexpandingDims;
if (failed(analyzeDynamicBroadcastInDimExpandingBehavior(
analysis, op.operand(), op.output_dimensions(), &knownExpandingDims,
&knownNonexpandingDims))) {
return failure();
}
// Collect possibly already annotated info.
auto insertAll = [](llvm::SmallSetVector<int64_t, 4> &dst,
Optional<DenseIntElementsAttr> src) {
if (!src) return;
for (auto it : *src) dst.insert(it.getLimitedValue());
};
insertAll(knownExpandingDims, op.known_expanding_dimensions());
insertAll(knownNonexpandingDims, op.known_nonexpanding_dimensions());
// Fail pattern application if there is nothing new to annotate.
auto isEqual = [](llvm::SmallSetVector<int64_t, 4> &set,
DenseIntElementsAttr attr) {
return set.size() == attr.size() && llvm::all_of(attr, [&](auto it) {
return set.count(it.getLimitedValue());
});
};
if (op.known_expanding_dimensions() && op.known_nonexpanding_dimensions() &&
isEqual(knownExpandingDims, *op.known_expanding_dimensions()) &&
isEqual(knownNonexpandingDims, *op.known_nonexpanding_dimensions())) {
return failure();
}
// Annotate op in place.
rewriter.startRootUpdate(op);
op.known_expanding_dimensionsAttr(
rewriter.getI64TensorAttr(knownExpandingDims.takeVector()));
op.known_nonexpanding_dimensionsAttr(
rewriter.getI64TensorAttr(knownNonexpandingDims.takeVector()));
rewriter.finalizeRootUpdate(op);
return success();
}
};
// Remove compute_reshape_shape if we can prove that the dynamic shape does not
// contain a `-1` dimension.
struct RemoveComputeReshapeShape final
: public OpRewritePattern<mhlo::ComputeReshapeShapeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mhlo::ComputeReshapeShapeOp op,
PatternRewriter &rewriter) const override {
ShapeComponentAnalysis shapeComponentAnalysis;
auto dynamicShape = shapeComponentAnalysis.GetValueInfo(op.dynamic_shape());
if (!dynamicShape) return failure();
if (llvm::any_of(*dynamicShape, [](const auto &dim) {
return !dim.isKnownNotNegativeOne();
})) {
return failure();
}
rewriter.replaceOp(op, op.dynamic_shape());
return success();
}
};
bool isProduct(AffineExpr expr,
llvm::function_ref<void(AffineConstantExpr)> cbkConstantFactor,
llvm::function_ref<void(AffineSymbolExpr)> cbkSymbolicFactor) {
auto binExpr = expr.dyn_cast<AffineBinaryOpExpr>();
if (binExpr && binExpr.getKind() == AffineExprKind::Mul) {
return isProduct(binExpr.getLHS(), cbkConstantFactor, cbkSymbolicFactor) &&
isProduct(binExpr.getRHS(), cbkConstantFactor, cbkSymbolicFactor);
}
if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
cbkSymbolicFactor(symExpr);
return true;
}
if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
cbkConstantFactor(constExpr);
return true;
}
return false;
}
bool isSymbolicProduct(const SymbolicExpr &symbolicExpr,
llvm::function_ref<void(int64_t)> cbkConstantFactor,
llvm::function_ref<void(Symbol)> cbkSymbolicFactor) {
return isProduct(
symbolicExpr.expr,
[&](AffineConstantExpr cexpr) { cbkConstantFactor(cexpr.getValue()); },
[&](AffineSymbolExpr sexpr) {
cbkSymbolicFactor(symbolicExpr.symbols[sexpr.getPosition()]);
});
}
// Represents a product of symbolic and concrete factors. This will allow us to
// prove product equalities symbolically.
struct SymbolicProduct {
// Product of all concrete factors.
int64_t concrete = 1;
// List all symbolic factors as they can not be aggregated.
llvm::SmallVector<Symbol> symbolic;
bool empty() { return concrete == 1 && symbolic.empty(); }
};
bool isSymbolicProduct(const SymbolicExpr &symbolicExpr,
SymbolicProduct *product) {
return isSymbolicProduct(
symbolicExpr, [&](int64_t c) { product->concrete *= c; },
[&](Symbol s) { product->symbolic.push_back(s); });
}
struct RemoveRedundantCstrReshapable final
: public OpRewritePattern<mhlo::CstrReshapableOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mhlo::CstrReshapableOp op,
PatternRewriter &rewriter) const override {
// Get shape analysis info for the number of elements.
ShapeComponentAnalysis shapeComponentAnalysis;
auto numElementsInfo =
shapeComponentAnalysis.GetValueInfo(op.num_elements());
if (!numElementsInfo) return failure();
assert(numElementsInfo->size() == 1 && "expect one value for a scalar");
auto numElements = numElementsInfo->front();
// Get shape analysis info for the dynamic shape.
auto dynShapeDims = shapeComponentAnalysis.GetValueInfo(op.dynamic_shape());
if (!dynShapeDims) return failure();
// We can handle two cases:
// - there is exactly one -1 in the dynamic shape, i.e. a unique wildcard
// dimension, or
// - there is no -1 in the dynamic shape, i.e. no wildcard dimension.
bool uniqueWildcardDimension = false;
for (const auto &d : *dynShapeDims) {
if (d.isConstant(-1)) {
if (uniqueWildcardDimension) return failure();
uniqueWildcardDimension = true;
} else if (!d.isKnownNotNegativeOne()) {
return failure();
}
}
// We can only handle simple products with constants and symbols. Find all
// the factors based on the number of elements.
SymbolicProduct numElementsRemainingFactors;
if (!isSymbolicProduct(numElements, &numElementsRemainingFactors)) {
return failure();
}
assert(numElementsRemainingFactors.concrete >= 1 &&
"number of elements cannot entail negative or zero factors");
// Find all factors based on the dynamic shape.
// - Accumulate the conrete product to later compare it against its
// equivalent based on the number of elements.
// - Remove symbolic factors from the list and fail if we find an unknown
// factor, i.e. if the symbolic factors based on the dynamic shape are
// not a subset of the factors based on the number of elements.
int64_t concreteProductDynShape = 1;
for (const auto &dim : *dynShapeDims) {
SmallVector<Symbol> partialSymbolicFactorsDynShape;
if (!isSymbolicProduct(
dim,
[&](int64_t c) {
if (c != ShapedType::kDynamicSize) concreteProductDynShape *= c;
},
[&](Symbol s) { partialSymbolicFactorsDynShape.push_back(s); })) {
return failure();
}
for (const Symbol &symDynShape : partialSymbolicFactorsDynShape) {
auto *it =
llvm::find(numElementsRemainingFactors.symbolic, symDynShape);
if (it == numElementsRemainingFactors.symbolic.end()) return failure();
numElementsRemainingFactors.symbolic.erase(it);
}
}
assert(concreteProductDynShape >= 1 &&
"concrete product must not aggregate negative or zero factors");
// A wildcard dimension can subsume the remaining symbolic factors and
// potentially also a concrete factor.
if (uniqueWildcardDimension) {
if (numElementsRemainingFactors.concrete % concreteProductDynShape != 0)
return failure();
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
return success();
}
// W/o a wildcard, the symbolic and concrete products must be equal.
bool isReshapable =
numElementsRemainingFactors.symbolic.empty() &&
numElementsRemainingFactors.concrete == concreteProductDynShape;
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, isReshapable);
return success();
}
};
LogicalResult materializeReshapeAsScalarExpand(RankedTensorType operandTy,
RankedTensorType resultTy,
mhlo::DynamicReshapeOp op,
PatternRewriter &rewriter) {
assert(operandTy.getRank() == 0 && "expect scalar operand");
auto loc = op.getLoc();
SmallVector<int64_t> unitDims(resultTy.getRank(), 1);
auto expandedTy = RankedTensorType::get(unitDims, resultTy.getElementType());
Value expandedScalar = rewriter.create<tensor::ExpandShapeOp>(
loc, expandedTy, op.operand(), ArrayRef<ReassociationIndices>{});
if (expandedScalar.getType() != resultTy) {
expandedScalar =
rewriter.create<tensor::CastOp>(loc, resultTy, expandedScalar);
}
rewriter.replaceOp(op, expandedScalar);
return success();
}
LogicalResult materializeReshapeAsScalarCollapse(RankedTensorType operandTy,
RankedTensorType resultTy,
mhlo::DynamicReshapeOp op,
PatternRewriter &rewriter) {
assert(resultTy.getRank() == 0 && "expect scalar result");
auto loc = op.getLoc();
Value operand = op.operand();
SmallVector<int64_t> unitDims(operandTy.getRank(), 1);
auto castedOperandTy =
RankedTensorType::get(unitDims, operandTy.getElementType());
if (operand.getType() != castedOperandTy) {
operand = rewriter.create<tensor::CastOp>(loc, castedOperandTy, operand);
}
Value collapsedScalar = rewriter.create<tensor::CollapseShapeOp>(
loc, operand, ArrayRef<ReassociationIndices>{});
rewriter.replaceOp(op, collapsedScalar);
return success();
}
enum class DimensionGroupKind {
kNone,
kExpanding,
kCollapsing,
};
struct DimensionGroup {
int64_t size = 0;
DimensionGroupKind kind = DimensionGroupKind::kNone;
};
SymbolicProduct eliminateCommonFactors(SymbolicProduct &a, SymbolicProduct &b) {
SymbolicProduct gcd;
// Eliminate common concrete factors.
gcd.concrete = llvm::GreatestCommonDivisor64(a.concrete, b.concrete);
a.concrete /= gcd.concrete;
b.concrete /= gcd.concrete;
// Eliminate common symbolic factors.
int64_t i = 0;
while (i < a.symbolic.size()) {
auto *it = llvm::find(b.symbolic, a.symbolic[i]);
if (it != b.symbolic.end()) {
gcd.symbolic.push_back(*it);
std::swap(a.symbolic[i], a.symbolic.back());
a.symbolic.pop_back();
b.symbolic.erase(it);
} else {
i++;
}
}
return gcd;
}
bool isUnpairedUnitDimension(
ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator it,
ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator end,
ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator otherIt,
ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator otherEnd) {
return it != end && it->isConstant(1) &&
!(otherIt != otherEnd && otherIt->isConstant(1));
}
int64_t getShapedTypyDimSize(const SymbolicProduct &symProduct) {
return symProduct.symbolic.empty() ? symProduct.concrete
: ShapedType::kDynamicSize;
}
// Iterate over the operand's and the result's shape dimensions and find
// dimension groups that are collapsing, expanding, or untouched:
// - Collapsing: Multiple dimensions of the operand shape can be collapsed
// into a single dimension of the result shape. We must prove that the
// product of the operand shape's dimensions is equal to the corresponding
// result dimension.
// - Expanding: A single dimension of the operand shape can be expanded into
// multiple dimensions of the result shape. We must prove that the product
// of the result shape's dimensions is equal to the corresponding operand
// dimension. This case is limited to at most one dynamic dimension per
// expansion group as otherwise not supported by the `expand_shape` op.
// - Untouched: There is a 1:1 correspondance between an operand and a result
// shape dimension.
//
// We can determine the optimal dimension groups greedily by consuming operand
// and result dimensions from left to right. If the leading operand dimension is
// a strict divisor of the leading result dimension, collapsing is required. In
// this case, we keep consuming the operand dimensions until the products are
// equal. If the leading result dimension is a strict divisor of the leading
// operand dimension, expanding is required. In this case, we keep consuming the
// result dimensions until the products are equal. Trailing unit dimensions may
// be inlcuded in the dimension group. This is useful iff they are "unpaired",
// in which case they would only limit us in the subsequent iteration.
//
LogicalResult findExpandingAndCollapsingDimensionGroups(
ArrayRef<SymbolicExpr> operandShapeInfo,
ArrayRef<SymbolicExpr> resultShapeInfo,
SmallVector<DimensionGroup> *dimensionGroups,
SmallVector<int64_t> *expandedIntermShape) {
const auto *operandShapeIt = operandShapeInfo.begin();
const auto *operandShapeEnd = operandShapeInfo.end();
const auto *resultShapeIt = resultShapeInfo.begin();
const auto *resultShapeEnd = resultShapeInfo.end();
// Crucial iteration state.
SymbolicProduct remainingOperandShapeFactors;
SymbolicProduct remainingResultShapeFactors;
auto anyRemainingFactors = [&]() {
return !remainingOperandShapeFactors.empty() ||
!remainingResultShapeFactors.empty();
};
while (operandShapeIt != operandShapeEnd && resultShapeIt != resultShapeEnd) {
assert(!anyRemainingFactors() &&
"expect no remaining factors from previous iteration");
DimensionGroup &dimGroup = dimensionGroups->emplace_back();
// Consume at least one operand and result dimension.
{
if (!isSymbolicProduct(*operandShapeIt++,
&remainingOperandShapeFactors) ||
!isSymbolicProduct(*resultShapeIt++, &remainingResultShapeFactors)) {
return failure();
}
dimGroup.size++;
SymbolicProduct gcd = eliminateCommonFactors(remainingOperandShapeFactors,
remainingResultShapeFactors);
expandedIntermShape->push_back(getShapedTypyDimSize(gcd));
}
// Fail if there are unresolvable, contradicting factors remaining.
if (!remainingOperandShapeFactors.empty() &&
!remainingResultShapeFactors.empty()) {
return failure();
}
// Collapsing: Create a collapsing dimension group.
bool requiresCollapsing =
remainingOperandShapeFactors.empty() &&
(!remainingResultShapeFactors.empty() ||
isUnpairedUnitDimension(operandShapeIt, operandShapeEnd, resultShapeIt,
resultShapeEnd));
if (requiresCollapsing) {
dimGroup.kind = DimensionGroupKind::kCollapsing;
// Consume operand shape dimensions until their product matches the
// corresponding result dimension (or fail if unresolvable/contradicting
// factors are found).
while (operandShapeIt != operandShapeEnd &&
remainingOperandShapeFactors.empty() &&
!remainingResultShapeFactors.empty()) {
if (!isSymbolicProduct(*operandShapeIt++,
&remainingOperandShapeFactors)) {
return failure();
}
dimGroup.size++;
SymbolicProduct gcd = eliminateCommonFactors(
remainingOperandShapeFactors, remainingResultShapeFactors);
expandedIntermShape->push_back(getShapedTypyDimSize(gcd));
}
if (anyRemainingFactors()) return failure();
// Consume trailing, unpaired unit dimensions.
while (isUnpairedUnitDimension(operandShapeIt, operandShapeEnd,
resultShapeIt, resultShapeEnd)) {
operandShapeIt++;
dimGroup.size++;
expandedIntermShape->push_back(1);
}
continue;
}
// Expanding: Create an expanding dimension group.
bool requiresExpanding =
remainingResultShapeFactors.empty() &&
(!remainingOperandShapeFactors.empty() ||
isUnpairedUnitDimension(resultShapeIt, resultShapeEnd, operandShapeIt,
operandShapeEnd));
if (requiresExpanding) {
dimGroup.kind = DimensionGroupKind::kExpanding;
int64_t numDynamicDims = 0;
// Consume result shape dimensions until their product matches the
// corresponding operand dimension (or fail if unresolvable/contradicting
// factors are found).
while (resultShapeIt != resultShapeEnd &&
remainingResultShapeFactors.empty() &&
!remainingOperandShapeFactors.empty()) {
if (!isSymbolicProduct(*resultShapeIt++,
&remainingResultShapeFactors)) {
return failure();
}
dimGroup.size++;
SymbolicProduct gcd = eliminateCommonFactors(
remainingOperandShapeFactors, remainingResultShapeFactors);
int64_t tyDimSize = getShapedTypyDimSize(gcd);
// Allow no more than one dynamic dimension per expansion group.
if (tyDimSize == ShapedType::kDynamicSize) {
numDynamicDims++;
if (numDynamicDims > 1) return failure();
}
expandedIntermShape->push_back(tyDimSize);
}
if (anyRemainingFactors()) return failure();
// Consume trailing, unpaired unit dimensions.
while (isUnpairedUnitDimension(resultShapeIt, resultShapeEnd,
operandShapeIt, operandShapeEnd)) {
resultShapeIt++;
dimGroup.size++;
expandedIntermShape->push_back(1);
}
continue;
}
// Untouched: 1:1 mapping between operand and result shape dimension. This
// is neither expanding nor collapsing.
assert(!requiresCollapsing && !requiresExpanding && "expect id case");
assert(dimGroup.size == 1 && dimGroup.kind == DimensionGroupKind::kNone &&
"expect simple dimension group");
}
// Fail if there are remaining dimensions that could not be consumed.
assert(!anyRemainingFactors() && "expect no remaining factors");
if (operandShapeIt != operandShapeEnd || resultShapeIt != resultShapeEnd) {
return failure();
}
return success();
}
SmallVector<int64_t> concretizeOperandShape(
ArrayRef<int64_t> operandShape, ArrayRef<SymbolicExpr> operandShapeInfo) {
SmallVector<int64_t> result;
for (auto it : llvm::zip(operandShape, operandShapeInfo)) {
auto dimSize = std::get<0>(it);
auto sExpr = std::get<1>(it);
if (auto cexpr = sExpr.expr.dyn_cast<AffineConstantExpr>()) {
int64_t alsoDimSize = cexpr.getValue();
assert((ShapedType::isDynamic(dimSize) || dimSize == alsoDimSize) &&
"expect shape analysis result to be compatible with type");
result.push_back(alsoDimSize);
continue;
}
result.push_back(dimSize);
}
return result;
}
llvm::Optional<SmallVector<ReassociationIndices>> requiresReassociationOfKind(
DimensionGroupKind kind, const SmallVector<DimensionGroup> &dimGroups) {
SmallVector<ReassociationIndices> reassociation;
reassociation.reserve(dimGroups.size());
bool isStrictlyReassociating = false;
int64_t i = 0;
for (const DimensionGroup &g : dimGroups) {
if (g.kind == kind) {
isStrictlyReassociating = true;
reassociation.push_back(
llvm::to_vector(llvm::seq<int64_t>(i, i + g.size)));
i += g.size;
continue;
}
for (int64_t j = 0; j < g.size; j++) reassociation.push_back({i++});
}
// Return the reassociation if expansion is required.
if (isStrictlyReassociating) return reassociation;
return llvm::None;
}
LogicalResult materializeReshapeAsExpandAndCollapse(
ShapeComponentAnalysis &shapeAnalysis, RankedTensorType operandTy,
RankedTensorType resultTy, mhlo::DynamicReshapeOp op,
PatternRewriter &rewriter) {
// Require sucessful shape analysis for operand and result shape.
auto operandShapeInfo = shapeAnalysis.GetShapeInfo(op.operand());
if (!operandShapeInfo) return failure();
auto resultShapeInfo = shapeAnalysis.GetValueInfo(op.output_shape());
if (!resultShapeInfo) return failure();
// Identify dimension groups and the intermediate expanded type.
SmallVector<DimensionGroup> dimensionGroups;
SmallVector<int64_t> expandedIntermShape;
if (failed(findExpandingAndCollapsingDimensionGroups(
*operandShapeInfo, *resultShapeInfo, &dimensionGroups,
&expandedIntermShape))) {
return failure();
}
// Materialize cast, expand, collapse, and cast, as needed.
auto loc = op.getLoc();
Value interm = op.operand();
auto castedOperandTy = RankedTensorType::get(
concretizeOperandShape(operandTy.getShape(), *operandShapeInfo),
operandTy.getElementType());
if (operandTy != castedOperandTy) {
interm = rewriter.create<tensor::CastOp>(loc, castedOperandTy, interm);
}
if (auto reassociation = requiresReassociationOfKind(
DimensionGroupKind::kExpanding, dimensionGroups)) {
interm = rewriter.create<tensor::ExpandShapeOp>(
loc,
RankedTensorType::get(expandedIntermShape, operandTy.getElementType()),
interm, *reassociation);
}
if (auto reassociation = requiresReassociationOfKind(
DimensionGroupKind::kCollapsing, dimensionGroups)) {
interm =
rewriter.create<tensor::CollapseShapeOp>(loc, interm, *reassociation);
}
if (interm.getType() != resultTy) {
interm = rewriter.create<tensor::CastOp>(loc, resultTy, interm);
}
rewriter.replaceOp(op, interm);
return success();
}
// Tries to express `dynamic_reshape` ops through `expand_shape` and
// `collapse_shape` ops.
struct DynamicReshapeToExpandAndCollapseShape final
: public OpRewritePattern<mhlo::DynamicReshapeOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(mhlo::DynamicReshapeOp op,
PatternRewriter &rewriter) const override {
auto operandTy = op.operand().getType().dyn_cast<RankedTensorType>();
if (!operandTy) return failure();
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
if (!resultTy) return failure();
// Handle degenerate scalar expand case.
if (operandTy.getRank() == 0) {
return materializeReshapeAsScalarExpand(operandTy, resultTy, op,
rewriter);
}
// Handle degenerate scalar collapse case.
if (resultTy.getRank() == 0) {
return materializeReshapeAsScalarCollapse(operandTy, resultTy, op,
rewriter);
}
ShapeComponentAnalysis shapeAnalysis;
return materializeReshapeAsExpandAndCollapse(shapeAnalysis, operandTy,
resultTy, op, rewriter);
}
};
// Returns true if all of bcasted_shapes can be broadcasted with output_shape.
bool isKnownBroadcastable(ShapeComponentAnalysis &analysis,
ValueRange bcastedShapes, Value outputShape) {
auto outputShapeDims = analysis.GetValueInfo(outputShape);
if (!outputShapeDims) return false;
for (Value shape : bcastedShapes) {
auto shapeDims = analysis.GetValueInfo(shape);
if (!shapeDims) return false;
// Iterate backwards over the smallest input shape.
for (auto zip : llvm::zip(llvm::reverse(*outputShapeDims),
llvm::reverse(*shapeDims))) {
const auto &first = std::get<0>(zip);
const auto &second = std::get<1>(zip);
// TODO(ezhulenev): What to do with dimensions statically known to be
// zero?
// Numpy can only broadcast [0] with [1], however Tensorflow can broadcast
// [0] with any dimension size, and produces dimension of size [0].
// Currently we'll conservatively return failure and will not proceed with
// a rewrite.
if (first.isConstant(0) || second.isConstant(0)) return false;
// If either shape has a static one dimension the broadcast will always
// succeed.
if (first.isConstant(1) || second.isConstant(1)) continue;
// Otherwise dims have to be equal.
if (first != second) return false;
}
}
return true;
}
// Rewrite `shape.cstr_broadcastable` with constant witness if can prove that
// shapes are broadcastable from a symbolic analysis.
struct CstrBroadcastableOpLowering
: public OpRewritePattern<shape::CstrBroadcastableOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
PatternRewriter &rewriter) const override {
ShapeComponentAnalysis shapeComponentAnalysis;
if (!isKnownBroadcastable(shapeComponentAnalysis, op.getShapes(),
op.getShapes().front())) {
return failure();
}
rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
return success();
}
};
class SymbolicShapeOptimizationPass final
: public SymbolicShapeOptimizationBase<SymbolicShapeOptimizationPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<linalg::LinalgDialect>();
}
void runOnOperation() override {
MLIRContext *ctx = &getContext();
mlir::RewritePatternSet patterns(ctx);
// clang-format off
patterns.insert<
AnnotateExpandingDimensionsInDynamicBroadcastInDim,
CstrBroadcastableOpLowering,
DynamicReshapeToExpandAndCollapseShape,
RemoveComputeReshapeShape,
RemoveRedundantCstrReshapable,
SimplifyBroadcasts>(ctx);
// clang-format on
shape::AssumingOp::getCanonicalizationPatterns(patterns, ctx);
if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
signalPassFailure();
}
}
};
} // end namespace
std::unique_ptr<OperationPass<func::FuncOp>>
createSymbolicShapeOptimizationPass() {
return std::make_unique<SymbolicShapeOptimizationPass>();
}
} // end namespace mlir