// blob: 008dc2a79f281456803f552b0245157c7477df23 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <cstdint>
#include <iterator>
#include <memory>
#include <numeric>
#include <utility>
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Analysis/shape_component_analysis.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Transforms/PassDetail.h"
#include "mlir-hlo/Transforms/passes.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
namespace mlir {
// Shorthands for nested types of `ShapeComponentAnalysis` used in this file.
using ShapeOrValueInfo = ShapeComponentAnalysis::ShapeOrValueInfo;
using Symbol = ShapeComponentAnalysis::Symbol;
using SymbolicExpr = ShapeComponentAnalysis::SymbolicExpr;
namespace {
// Temporary data structure to hold a single dimension of the symbolic result of
// `shape.broadcast`. It records which operand dimension the (unique, non-unit)
// result dimension was taken from, so it can be re-extracted if it is not a
// constant. Field order matters: instances are brace-initialized.
struct SymbolicBroadcastDimension {
// Index of the `shape.broadcast` operand that defines this dimension.
size_t operand_index;
// Dimension index within that operand's shape.
size_t operand_dim;
// Symbolic size of the dimension.
SymbolicExpr expr;
};
// Replace shape.broadcast with a shape if it's statically known.
//
// The broadcast is computed symbolically from the shape analysis of all
// operands (aligned at the trailing dimension). Each result dimension is then
// materialized as a constant if it is statically known, or extracted from the
// unique operand that defines it otherwise. The pattern fails if two operands
// contribute different non-unit expressions to the same result dimension, or
// if the result cannot be materialized as an extent tensor.
struct SimplifyBroadcasts : public mlir::OpRewritePattern<shape::BroadcastOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(
      shape::BroadcastOp op, mlir::PatternRewriter &rewriter) const override {
    // The result is built with `tensor.from_elements` and, if needed,
    // `tensor.cast`, which cannot produce a `!shape.shape` value.
    Type expected_ty = op.getResult().getType();
    if (!expected_ty.isa<TensorType>()) return failure();

    // Require successful shape analysis.
    ShapeComponentAnalysis shape_analysis;
    llvm::SmallVector<ArrayRef<SymbolicExpr>> shapes_info;
    auto shapes = op.getShapes();
    shapes_info.reserve(shapes.size());
    for (Value s : shapes) {
      auto s_info = shape_analysis.GetValueInfo(s);
      if (!s_info) return failure();
      shapes_info.push_back(*s_info);
    }

    // Find the result rank.
    size_t rank = 0;
    for (const auto &s_info : shapes_info) rank = std::max(rank, s_info.size());

    // Compute broadcast symbolically.
    SmallVector<Optional<SymbolicBroadcastDimension>> sym_result(rank,
                                                                 llvm::None);
    for (const auto &s_info : llvm::enumerate(shapes_info)) {
      // Lower-rank operands are aligned with the trailing result dimensions.
      size_t dim_offset = rank - s_info.value().size();
      for (const auto &sym_expr : llvm::enumerate(s_info.value())) {
        // Unit dimensions are neutral to the final result.
        if (sym_expr.value().isConstant(1)) continue;

        // Use unique expression.
        size_t i = dim_offset + sym_expr.index();
        if (!sym_result[i]) {
          sym_result[i] = {s_info.index(), sym_expr.index(), sym_expr.value()};
          continue;
        }

        // Bail if the dimensions are neither equal nor 1.
        if (sym_result[i]->expr != sym_expr.value()) return failure();
      }
    }

    // Non-constant result dimensions are read from their defining operand with
    // `tensor.extract`, which requires that operand to be a tensor.
    for (const auto &sym_result_dim : sym_result) {
      if (!sym_result_dim ||
          sym_result_dim->expr.expr.isa<AffineConstantExpr>()) {
        continue;
      }
      if (!shapes[sym_result_dim->operand_index].getType().isa<TensorType>())
        return failure();
    }

    // Materialize broadcast result.
    auto loc = op.getLoc();
    DenseMap<int64_t, Value> constants;
    auto find_or_create_constant = [&](int64_t c) {
      auto it = constants.find(c);
      if (it != constants.end()) return it->second;
      Value newly_created = rewriter.create<arith::ConstantIndexOp>(loc, c);
      constants[c] = newly_created;
      return newly_created;
    };
    auto elements = llvm::to_vector<8>(
        llvm::map_range(sym_result, [&](const auto &sym_result_dim) {
          // If we know the dimension statically, use a constant.
          if (!sym_result_dim) return find_or_create_constant(1);
          if (auto cexpr = sym_result_dim->expr.expr
                               .template dyn_cast<AffineConstantExpr>()) {
            return find_or_create_constant(cexpr.getValue());
          }

          // Otherwise, extract the dimension from the unique operand.
          Value operand = shapes[sym_result_dim->operand_index];
          Value operand_dim =
              find_or_create_constant(sym_result_dim->operand_dim);
          return rewriter.create<tensor::ExtractOp>(loc, operand, operand_dim)
              .getResult();
        }));
    Type index_ty = rewriter.getIndexType();
    Type concrete_result_ty = RankedTensorType::get(
        {static_cast<int64_t>(elements.size())}, index_ty);
    Value result = rewriter.create<tensor::FromElementsOp>(
        loc, concrete_result_ty, elements);

    // Insert cast, if needed.
    if (result.getType() != expected_ty) {
      result = rewriter.create<tensor::CastOp>(loc, expected_ty, result);
    }

    rewriter.replaceOp(op, result);
    return success();
  }
};
// Analyzes which operand dimensions of a broadcast of `value` to the shape
// `shape` are known to be expanding or non-expanding:
//   - expanding: the operand dimension is 1 while the corresponding result
//     dimension is known not to be 1;
//   - non-expanding: the operand and result dimension are symbolically equal,
//     or the result dimension is 1.
// The indices of such operand dimensions are inserted into
// `known_expanding_dims` and `known_nonexpanding_dims`, respectively.
//
// Operand dimension `i` corresponds to result dimension
// `(*broadcast_dimensions)[i]`. If `broadcast_dimensions` is not provided, the
// operand dimensions are aligned with the trailing result dimensions.
//
// Fails if either shape cannot be analyzed or if the operand dimensions cannot
// be mapped to valid result dimensions. Nothing is inserted in that case.
LogicalResult AnalyzeDynamicBroadcastInDimExpandingBehavior(
    ShapeComponentAnalysis &analysis, Value value, Value shape,
    llvm::SmallSetVector<int64_t, 4> *known_expanding_dims,
    llvm::SmallSetVector<int64_t, 4> *known_nonexpanding_dims,
    Optional<ArrayRef<int64_t>> broadcast_dimensions = llvm::None) {
  // Require successful analysis of shapes.
  auto shape_in = analysis.GetShapeInfo(value);
  auto shape_out = analysis.GetValueInfo(shape);
  if (!shape_in || !shape_out) return failure();

  // Map every operand dimension to its result dimension. Validate the mapping
  // before recording anything (also guards the unsigned subtraction below).
  size_t rank_in = shape_in->size();
  size_t rank_out = shape_out->size();
  if (rank_in > rank_out) return failure();
  if (broadcast_dimensions && broadcast_dimensions->size() != rank_in)
    return failure();
  SmallVector<size_t, 4> dims_out;
  dims_out.reserve(rank_in);
  for (size_t i = 0; i < rank_in; ++i) {
    if (!broadcast_dimensions) {
      dims_out.push_back(rank_out - rank_in + i);
      continue;
    }
    int64_t dim_out = (*broadcast_dimensions)[i];
    if (dim_out < 0 || static_cast<size_t>(dim_out) >= rank_out)
      return failure();
    dims_out.push_back(static_cast<size_t>(dim_out));
  }

  // Analyze per argument dimension.
  for (size_t i = 0; i < rank_in; ++i) {
    SymbolicExpr dim_in = (*shape_in)[i];
    SymbolicExpr dim_out = (*shape_out)[dims_out[i]];
    if (dim_in.isConstant(1) && dim_out.isKnownNotOne())
      known_expanding_dims->insert(i);
    if (dim_in == dim_out || dim_out.isConstant(1))
      known_nonexpanding_dims->insert(i);
  }
  return success();
}

// Analyze `mhlo.dynamic_broadcast_in_dim` op and populate attributes for
// statically known expanding and non-expanding dimensions. Previously annotated
// dimensions are kept; the pattern fails if there is nothing new to annotate.
struct AnnotateExpandingDimensionsInDynamicBroadcastInDim
    : public mlir::OpRewritePattern<mhlo::DynamicBroadcastInDimOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(
      mhlo::DynamicBroadcastInDimOp op,
      mlir::PatternRewriter &rewriter) const override {
    // Operand dimension `i` is broadcast into result dimension
    // `broadcast_dimensions[i]`, which is not necessarily trailing-aligned.
    llvm::SmallVector<int64_t, 4> broadcast_dims;
    for (auto dim : op.broadcast_dimensions())
      broadcast_dims.push_back(dim.getSExtValue());

    // Analyze shapes and identify expanding and non-expanding dims.
    ShapeComponentAnalysis analysis;
    llvm::SmallSetVector<int64_t, 4> known_expanding_dims,
        known_nonexpanding_dims;
    if (failed(AnalyzeDynamicBroadcastInDimExpandingBehavior(
            analysis, op.operand(), op.output_dimensions(),
            &known_expanding_dims, &known_nonexpanding_dims,
            ArrayRef<int64_t>(broadcast_dims)))) {
      return failure();
    }

    // Collect possibly already annotated info.
    auto insert_all = [](llvm::SmallSetVector<int64_t, 4> &dst,
                         Optional<DenseIntElementsAttr> src) {
      if (!src) return;
      for (auto it : *src) dst.insert(it.getLimitedValue());
    };
    insert_all(known_expanding_dims, op.known_expanding_dimensions());
    insert_all(known_nonexpanding_dims, op.known_nonexpanding_dimensions());

    // Fail pattern application if there is nothing new to annotate.
    auto is_equal = [](llvm::SmallSetVector<int64_t, 4> &set,
                       DenseIntElementsAttr attr) {
      return set.size() == attr.size() && llvm::all_of(attr, [&](auto it) {
               return set.count(it.getLimitedValue());
             });
    };
    if (op.known_expanding_dimensions() && op.known_nonexpanding_dimensions() &&
        is_equal(known_expanding_dims, *op.known_expanding_dimensions()) &&
        is_equal(known_nonexpanding_dims,
                 *op.known_nonexpanding_dimensions())) {
      return failure();
    }

    // Annotate op in place.
    rewriter.startRootUpdate(op);
    op.known_expanding_dimensionsAttr(
        rewriter.getI64TensorAttr(known_expanding_dims.takeVector()));
    op.known_nonexpanding_dimensionsAttr(
        rewriter.getI64TensorAttr(known_nonexpanding_dims.takeVector()));
    rewriter.finalizeRootUpdate(op);
    return success();
  }
};
// Replaces `mhlo.compute_reshape_shape` by its `dynamic_shape` operand once the
// shape analysis proves that no entry of the dynamic shape can be `-1`, i.e.
// there is no wildcard dimension left to resolve.
struct RemoveComputeReshapeShape final
    : public OpRewritePattern<mhlo::ComputeReshapeShapeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(mhlo::ComputeReshapeShapeOp op,
                                PatternRewriter &rewriter) const override {
    ShapeComponentAnalysis analysis;
    Value dyn_shape = op.dynamic_shape();
    auto dims = analysis.GetValueInfo(dyn_shape);
    if (!dims) return failure();

    // Every entry must be provably different from the wildcard `-1`.
    bool wildcard_free = llvm::all_of(*dims, [](const auto &dim) {
      return dim.isKnownNotNegativeOne();
    });
    if (!wildcard_free) return failure();

    rewriter.replaceOp(op, dyn_shape);
    return success();
  }
};
// Returns true iff `expr` consists only of multiplications of symbols and
// constants. Factors are reported left to right through the matching
// callback. Callbacks may already have fired when false is returned.
bool IsProduct(AffineExpr expr,
               llvm::function_ref<void(AffineConstantExpr)> cbkConstantFactor,
               llvm::function_ref<void(AffineSymbolExpr)> cbkSymbolicFactor) {
  switch (expr.getKind()) {
    case AffineExprKind::Mul: {
      auto mul = expr.cast<AffineBinaryOpExpr>();
      if (!IsProduct(mul.getLHS(), cbkConstantFactor, cbkSymbolicFactor))
        return false;
      return IsProduct(mul.getRHS(), cbkConstantFactor, cbkSymbolicFactor);
    }
    case AffineExprKind::SymbolId:
      cbkSymbolicFactor(expr.cast<AffineSymbolExpr>());
      return true;
    case AffineExprKind::Constant:
      cbkConstantFactor(expr.cast<AffineConstantExpr>());
      return true;
    default:
      // Any other expression kind (dims, additions, divisions, ...) is not a
      // pure product.
      return false;
  }
}
bool IsSymbolicProduct(const SymbolicExpr &symbolicExpr,
llvm::function_ref<void(int64_t)> cbkConstantFactor,
llvm::function_ref<void(Symbol)> cbkSymbolicFactor) {
return IsProduct(
symbolicExpr.expr,
[&](AffineConstantExpr cexpr) { cbkConstantFactor(cexpr.getValue()); },
[&](AffineSymbolExpr sexpr) {
cbkSymbolicFactor(symbolicExpr.symbols[sexpr.getPosition()]);
});
}
// Represents a product of symbolic and concrete factors. This will allow us to
// prove product equalities symbolically.
struct SymbolicProduct {
  // Product of all concrete factors.
  int64_t concrete = 1;
  // List all symbolic factors as they can not be aggregated.
  llvm::SmallVector<Symbol> symbolic;
  // Returns true iff the product is the neutral element, i.e. the concrete part
  // is 1 and there are no symbolic factors. Const so that it can be queried
  // through const references.
  bool empty() const { return concrete == 1 && symbolic.empty(); }
};
bool IsSymbolicProduct(const SymbolicExpr &symbolicExpr,
SymbolicProduct *product) {
return IsSymbolicProduct(
symbolicExpr, [&](int64_t c) { product->concrete *= c; },
[&](Symbol s) { product->symbolic.push_back(s); });
}
// Replaces `mhlo.cstr_reshapable` with a constant witness when the shape
// analysis can decide statically whether `num_elements` elements can be
// reshaped into `dynamic_shape`. Both must be simple products of constants and
// symbols. Cases that cannot be decided soundly are left to the runtime check.
struct RemoveRedundantCstrReshapable final
    : public OpRewritePattern<mhlo::CstrReshapableOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(mhlo::CstrReshapableOp op,
                                PatternRewriter &rewriter) const override {
    // Get shape analysis info for the number of elements.
    ShapeComponentAnalysis shapeComponentAnalysis;
    auto numElementsInfo =
        shapeComponentAnalysis.GetValueInfo(op.num_elements());
    if (!numElementsInfo) return failure();
    assert(numElementsInfo->size() == 1 && "expect one value for a scalar");
    auto numElements = numElementsInfo->front();

    // Get shape analysis info for the dynamic shape.
    auto dynShapeDims = shapeComponentAnalysis.GetValueInfo(op.dynamic_shape());
    if (!dynShapeDims) return failure();

    // We can handle two cases:
    //   - there is exactly one -1 in the dynamic shape, i.e. a unique wildcard
    //     dimension, or
    //   - there is no -1 in the dynamic shape, i.e. no wildcard dimension.
    bool unique_wildcard_dimension = false;
    for (const auto &d : *dynShapeDims) {
      if (d.isConstant(-1)) {
        if (unique_wildcard_dimension) return failure();
        unique_wildcard_dimension = true;
      } else if (!d.isKnownNotNegativeOne()) {
        return failure();
      }
    }

    // We can only handle simple products with constants and symbols. Find all
    // the factors based on the number of elements.
    SymbolicProduct numElementsRemainingFactors;
    if (!IsSymbolicProduct(numElements, &numElementsRemainingFactors)) {
      return failure();
    }
    // Leave zero (and invalid, negative) element counts to the runtime check.
    // The reasoning below relies on strictly positive concrete factors.
    if (numElementsRemainingFactors.concrete < 1) return failure();
    bool numElementsHasSymbolicFactors =
        !numElementsRemainingFactors.symbolic.empty();

    // Find all factors based on the dynamic shape.
    //   - Accumulate the concrete product to later compare it against its
    //     equivalent based on the number of elements.
    //   - Remove symbolic factors from the list and fail if we find an unknown
    //     factor, i.e. if the symbolic factors based on the dynamic shape are
    //     not a subset of the factors based on the number of elements.
    int64_t concreteProductDynShape = 1;
    for (const auto &dim : *dynShapeDims) {
      // The wildcard dimension contributes no factor; it is resolved from
      // whatever remains.
      if (dim.isConstant(-1)) continue;

      bool hasNonPositiveFactor = false;
      SmallVector<Symbol> partialSymbolicFactorsDynShape;
      if (!IsSymbolicProduct(
              dim,
              [&](int64_t c) {
                if (c < 1)
                  hasNonPositiveFactor = true;
                else
                  concreteProductDynShape *= c;
              },
              [&](Symbol s) { partialSymbolicFactorsDynShape.push_back(s); })) {
        return failure();
      }
      // Zero-sized or negative dimensions are left to the runtime check. A zero
      // factor would otherwise make the divisibility check below divide by
      // zero.
      if (hasNonPositiveFactor) return failure();

      for (const Symbol &symDynShape : partialSymbolicFactorsDynShape) {
        auto *it =
            llvm::find(numElementsRemainingFactors.symbolic, symDynShape);
        if (it == numElementsRemainingFactors.symbolic.end()) return failure();
        numElementsRemainingFactors.symbolic.erase(it);
      }
    }
    assert(concreteProductDynShape >= 1 &&
           "concrete product must not aggregate negative or zero factors");

    // A wildcard dimension can subsume the remaining symbolic factors and
    // potentially also a concrete factor.
    if (unique_wildcard_dimension) {
      if (numElementsRemainingFactors.concrete % concreteProductDynShape != 0)
        return failure();
      rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
      return success();
    }

    // W/o a wildcard, the symbolic and concrete products must be equal.
    bool isReshapable =
        numElementsRemainingFactors.symbolic.empty() &&
        numElementsRemainingFactors.concrete == concreteProductDynShape;

    // Unequal products only prove the reshape invalid if no symbols are
    // involved. A symbol may evaluate to 0 or 1 at runtime and still make both
    // products equal, so defer to the runtime check in that case.
    if (!isReshapable && numElementsHasSymbolicFactors) return failure();
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, isReshapable);
    return success();
  }
};
// Materializes the reshape of a scalar (rank-0) operand into `result_ty`. The
// scalar is expanded into a tensor of all unit dimensions with
// `tensor.expand_shape`, followed by a `tensor.cast` to `result_ty` if the
// types differ.
LogicalResult MaterializeReshapeAsScalarExpand(RankedTensorType operand_ty,
                                               RankedTensorType result_ty,
                                               mhlo::DynamicReshapeOp op,
                                               PatternRewriter &rewriter) {
  assert(operand_ty.getRank() == 0 && "expect scalar operand");

  // A scalar-to-scalar reshape is a no-op. `tensor.expand_shape` expects a rank
  // increase, so forward the operand instead of emitting an expansion.
  if (result_ty.getRank() == 0) {
    if (operand_ty != result_ty) return failure();
    rewriter.replaceOp(op, op.operand());
    return success();
  }

  // A scalar can only be reshaped into a tensor of all unit dimensions. Do not
  // emit an invalid `tensor.cast` if the result type statically contradicts
  // that; leave such ops untouched.
  if (llvm::any_of(result_ty.getShape(), [](int64_t dim) {
        return !ShapedType::isDynamic(dim) && dim != 1;
      })) {
    return failure();
  }

  auto loc = op.getLoc();
  SmallVector<int64_t> unit_dims(result_ty.getRank(), 1);
  auto expanded_ty =
      RankedTensorType::get(unit_dims, result_ty.getElementType());
  Value expanded_scalar = rewriter.create<tensor::ExpandShapeOp>(
      loc, expanded_ty, op.operand(), ArrayRef<ReassociationIndices>{});
  if (expanded_scalar.getType() != result_ty) {
    expanded_scalar =
        rewriter.create<tensor::CastOp>(loc, result_ty, expanded_scalar);
  }
  rewriter.replaceOp(op, expanded_scalar);
  return success();
}
// Materializes the reshape of `op`'s operand into a scalar (rank-0) result. The
// operand is cast to a type of all unit dimensions (if needed), followed by a
// `tensor.collapse_shape` to rank 0.
LogicalResult MaterializeReshapeAsScalarCollapse(RankedTensorType operand_ty,
                                                 RankedTensorType result_ty,
                                                 mhlo::DynamicReshapeOp op,
                                                 PatternRewriter &rewriter) {
  assert(result_ty.getRank() == 0 && "expect scalar result");

  // Only a tensor of all unit dimensions can be reshaped into a scalar. Do not
  // emit an invalid `tensor.cast` if the operand type statically contradicts
  // that; leave such ops untouched.
  if (llvm::any_of(operand_ty.getShape(), [](int64_t dim) {
        return !ShapedType::isDynamic(dim) && dim != 1;
      })) {
    return failure();
  }

  auto loc = op.getLoc();
  Value operand = op.operand();
  SmallVector<int64_t> unit_dims(operand_ty.getRank(), 1);
  auto casted_operand_ty =
      RankedTensorType::get(unit_dims, operand_ty.getElementType());
  if (operand.getType() != casted_operand_ty) {
    operand = rewriter.create<tensor::CastOp>(loc, casted_operand_ty, operand);
  }
  Value collapsed_scalar = rewriter.create<tensor::CollapseShapeOp>(
      loc, operand, ArrayRef<ReassociationIndices>{});
  rewriter.replaceOp(op, collapsed_scalar);
  return success();
}
// Classification of a group of dimensions computed when lowering a reshape to
// `expand_shape`/`collapse_shape`.
enum class DimensionGroupKind {
// 1:1 correspondence between one operand and one result dimension.
kNone,
// One operand dimension is expanded into multiple result dimensions.
kExpanding,
// Multiple operand dimensions are collapsed into one result dimension.
kCollapsing,
};
// A group of consecutive dimensions that is reshaped as a unit.
struct DimensionGroup {
// Number of dimensions the group spans in the intermediate (expanded) shape,
// i.e. the number of dimensions on the side with more dimensions.
int64_t size = 0;
// Whether the group is expanding, collapsing, or neither.
DimensionGroupKind kind = DimensionGroupKind::kNone;
};
// Divides the common factors out of `a` and `b` (in place) and returns them.
// The concrete parts are divided by their greatest common divisor. Symbolic
// factors that occur in both products are removed from both, one occurrence
// per match. Afterwards, the original `a` equals `gcd * a` and the original `b`
// equals `gcd * b`, where `gcd` is the returned product.
SymbolicProduct EliminateCommonFactors(SymbolicProduct &a, SymbolicProduct &b) {
  SymbolicProduct gcd;

  // Eliminate common concrete factors.
  if (a.concrete == 0 && b.concrete == 0) {
    // `GreatestCommonDivisor64(0, 0)` is 0, which must not be divided by. Both
    // products are zero regardless of their symbolic factors, so the common
    // concrete factor is 0 and nothing concrete remains on either side.
    gcd.concrete = 0;
    a.concrete = 1;
    b.concrete = 1;
  } else {
    gcd.concrete = llvm::GreatestCommonDivisor64(a.concrete, b.concrete);
    a.concrete /= gcd.concrete;
    b.concrete /= gcd.concrete;
  }

  // Eliminate common symbolic factors. A matched factor is removed from `a` by
  // swapping it with the last element, so `i` only advances on a miss.
  size_t i = 0;
  while (i < a.symbolic.size()) {
    auto it = llvm::find(b.symbolic, a.symbolic[i]);
    if (it != b.symbolic.end()) {
      gcd.symbolic.push_back(*it);
      std::swap(a.symbolic[i], a.symbolic.back());
      a.symbolic.pop_back();
      b.symbolic.erase(it);
    } else {
      i++;
    }
  }
  return gcd;
}
// Returns true if `it` points at a unit dimension while `other_it` does not,
// i.e. the unit dimension at `it` has no unit counterpart on the other side.
// An exhausted `it` (== `end`) never yields a unit dimension.
bool IsUnpairedUnitDimension(
    ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator it,
    ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator end,
    ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator other_it,
    ArrayRef<ShapeComponentAnalysis::SymbolicExpr>::iterator other_end) {
  if (it == end || !it->isConstant(1)) return false;
  bool other_is_unit = other_it != other_end && other_it->isConstant(1);
  return !other_is_unit;
}
// Converts `sym_product` into a `ShapedType` dimension size. The result is the
// concrete value if there are no symbolic factors, and
// `ShapedType::kDynamicSize` otherwise.
int64_t GetShapedTypyDimSize(const SymbolicProduct &sym_product) {
  if (!sym_product.symbolic.empty()) return ShapedType::kDynamicSize;
  return sym_product.concrete;
}
// Iterate over the operand's and the result's shape dimensions and find
// dimension groups that are collapsing, expanding, or untouched:
//   - Collapsing: Multiple dimensions of the operand shape can be collapsed
//     into a single dimension of the result shape. We must prove that the
//     product of the operand shape's dimensions is equal to the corresponding
//     result dimension.
//   - Expanding: A single dimension of the operand shape can be expanded into
//     multiple dimensions of the result shape. We must prove that the product
//     of the result shape's dimensions is equal to the corresponding operand
//     dimension. This case is limited to at most one dynamic dimension per
//     expansion group as otherwise not supported by the `expand_shape` op.
//   - Untouched: There is a 1:1 correspondence between an operand and a result
//     shape dimension.
//
// We can determine the optimal dimension groups greedily by consuming operand
// and result dimensions from left to right. If the leading operand dimension is
// a strict divisor of the leading result dimension, collapsing is required. In
// this case, we keep consuming the operand dimensions until the products are
// equal. If the leading result dimension is a strict divisor of the leading
// operand dimension, expanding is required. In this case, we keep consuming the
// result dimensions until the products are equal. Trailing unit dimensions may
// be included in the dimension group. This is useful iff they are "unpaired",
// in which case they would only limit us in the subsequent iteration.
//
// Results are appended to the output parameters:
//   - `dimension_groups` receives one `DimensionGroup` per group, in order.
//   - `expanded_interm_shape` receives one size (static or
//     `ShapedType::kDynamicSize`) per dimension counted in
//     `DimensionGroup::size`. This is the shape in which all expanding groups
//     are expanded but no collapsing group is collapsed yet.
// Fails if the products cannot be matched symbolically, if an expanding group
// would need more than one dynamic dimension, or if not all dimensions of both
// shapes can be consumed.
LogicalResult FindExpandingAndCollapsingDimensionGroups(
    ArrayRef<SymbolicExpr> operand_shape_info,
    ArrayRef<SymbolicExpr> result_shape_info,
    SmallVector<DimensionGroup> *dimension_groups,
    SmallVector<int64_t> *expanded_interm_shape) {
  auto operand_shape_it = operand_shape_info.begin();
  auto operand_shape_end = operand_shape_info.end();
  auto result_shape_it = result_shape_info.begin();
  auto result_shape_end = result_shape_info.end();

  // Crucial iteration state.
  SymbolicProduct remaining_operand_shape_factors;
  SymbolicProduct remaining_result_shape_factors;
  auto any_remaining_factors = [&]() {
    return !remaining_operand_shape_factors.empty() ||
           !remaining_result_shape_factors.empty();
  };

  while (operand_shape_it != operand_shape_end &&
         result_shape_it != result_shape_end) {
    assert(!any_remaining_factors() &&
           "expect no remaining factors from previous iteration");
    DimensionGroup &dim_group = dimension_groups->emplace_back();

    // Consume at least one operand and result dimension.
    {
      if (!IsSymbolicProduct(*operand_shape_it++,
                             &remaining_operand_shape_factors) ||
          !IsSymbolicProduct(*result_shape_it++,
                             &remaining_result_shape_factors)) {
        return failure();
      }
      dim_group.size++;
      SymbolicProduct gcd = EliminateCommonFactors(
          remaining_operand_shape_factors, remaining_result_shape_factors);
      expanded_interm_shape->push_back(GetShapedTypyDimSize(gcd));
    }

    // Fail if there are unresolvable, contradicting factors remaining.
    if (!remaining_operand_shape_factors.empty() &&
        !remaining_result_shape_factors.empty()) {
      return failure();
    }

    // Collapsing: Create a collapsing dimension group.
    bool requires_collapsing =
        remaining_operand_shape_factors.empty() &&
        (!remaining_result_shape_factors.empty() ||
         IsUnpairedUnitDimension(operand_shape_it, operand_shape_end,
                                 result_shape_it, result_shape_end));
    if (requires_collapsing) {
      dim_group.kind = DimensionGroupKind::kCollapsing;

      // Consume operand shape dimensions until their product matches the
      // corresponding result dimension (or fail if unresolvable/contradicting
      // factors are found).
      while (operand_shape_it != operand_shape_end &&
             remaining_operand_shape_factors.empty() &&
             !remaining_result_shape_factors.empty()) {
        if (!IsSymbolicProduct(*operand_shape_it++,
                               &remaining_operand_shape_factors)) {
          return failure();
        }
        dim_group.size++;
        SymbolicProduct gcd = EliminateCommonFactors(
            remaining_operand_shape_factors, remaining_result_shape_factors);
        expanded_interm_shape->push_back(GetShapedTypyDimSize(gcd));
      }
      if (any_remaining_factors()) return failure();

      // Consume trailing, unpaired unit dimensions.
      while (IsUnpairedUnitDimension(operand_shape_it, operand_shape_end,
                                     result_shape_it, result_shape_end)) {
        operand_shape_it++;
        dim_group.size++;
        expanded_interm_shape->push_back(1);
      }

      continue;
    }

    // Expanding: Create an expanding dimension group.
    bool requires_expanding =
        remaining_result_shape_factors.empty() &&
        (!remaining_operand_shape_factors.empty() ||
         IsUnpairedUnitDimension(result_shape_it, result_shape_end,
                                 operand_shape_it, operand_shape_end));
    if (requires_expanding) {
      dim_group.kind = DimensionGroupKind::kExpanding;

      // The dimension consumed above is the first dimension of this group and
      // must be counted toward the dynamic dimension limit, too.
      int64_t num_dynamic_dims =
          expanded_interm_shape->back() == ShapedType::kDynamicSize ? 1 : 0;

      // Consume result shape dimensions until their product matches the
      // corresponding operand dimension (or fail if unresolvable/contradicting
      // factors are found).
      while (result_shape_it != result_shape_end &&
             remaining_result_shape_factors.empty() &&
             !remaining_operand_shape_factors.empty()) {
        if (!IsSymbolicProduct(*result_shape_it++,
                               &remaining_result_shape_factors)) {
          return failure();
        }
        dim_group.size++;
        SymbolicProduct gcd = EliminateCommonFactors(
            remaining_operand_shape_factors, remaining_result_shape_factors);
        int64_t ty_dim_size = GetShapedTypyDimSize(gcd);

        // Allow no more than one dynamic dimension per expansion group.
        if (ty_dim_size == ShapedType::kDynamicSize) {
          num_dynamic_dims++;
          if (num_dynamic_dims > 1) return failure();
        }
        expanded_interm_shape->push_back(ty_dim_size);
      }
      if (any_remaining_factors()) return failure();

      // Consume trailing, unpaired unit dimensions.
      while (IsUnpairedUnitDimension(result_shape_it, result_shape_end,
                                     operand_shape_it, operand_shape_end)) {
        result_shape_it++;
        dim_group.size++;
        expanded_interm_shape->push_back(1);
      }

      continue;
    }

    // Untouched: 1:1 mapping between operand and result shape dimension. This
    // is neither expanding nor collapsing.
    assert(!requires_collapsing && !requires_expanding && "expect id case");
    assert(dim_group.size == 1 && dim_group.kind == DimensionGroupKind::kNone &&
           "expect simple dimension group");
  }

  // Fail if there are remaining dimensions that could not be consumed.
  assert(!any_remaining_factors() && "expect no remaining factors");
  if (operand_shape_it != operand_shape_end ||
      result_shape_it != result_shape_end) {
    return failure();
  }

  return success();
}
// Returns the operand shape refined by the shape analysis. Dimensions that the
// analysis knows to be constant become static; all others keep their size
// from `operand_shape`. Both inputs are expected to describe the same
// dimensions, in the same order.
SmallVector<int64_t> ConcretizeOperandShape(
    ArrayRef<int64_t> operand_shape,
    ArrayRef<SymbolicExpr> operand_shape_info) {
  SmallVector<int64_t> result;
  result.reserve(operand_shape.size());
  for (auto it : llvm::zip(operand_shape, operand_shape_info)) {
    int64_t dim_size = std::get<0>(it);
    // Bind by reference: `SymbolicExpr` carries its own list of symbols, so
    // copying it per dimension is unnecessary.
    const SymbolicExpr &s_expr = std::get<1>(it);
    if (auto cexpr = s_expr.expr.dyn_cast<AffineConstantExpr>()) {
      int64_t also_dim_size = cexpr.getValue();
      assert((ShapedType::isDynamic(dim_size) || dim_size == also_dim_size) &&
             "expect shape analysis result to be compatible with type");
      result.push_back(also_dim_size);
      continue;
    }
    result.push_back(dim_size);
  }
  return result;
}
// Builds the reassociation indices for a reshape that reassociates exactly the
// dimension groups of the given `kind`. Each such group becomes a single
// reassociation group over `size` consecutive dimensions; every other group
// contributes `size` singleton groups. Returns `llvm::None` if no group is of
// that kind, i.e. no such reshape is needed.
llvm::Optional<SmallVector<ReassociationIndices>> RequiresReassociationOfKind(
    DimensionGroupKind kind, const SmallVector<DimensionGroup> &dim_groups) {
  bool any_of_kind = llvm::any_of(
      dim_groups, [&](const DimensionGroup &g) { return g.kind == kind; });
  if (!any_of_kind) return llvm::None;

  SmallVector<ReassociationIndices> reassociation;
  reassociation.reserve(dim_groups.size());
  int64_t next_dim = 0;
  for (const DimensionGroup &group : dim_groups) {
    if (group.kind != kind) {
      for (int64_t k = 0; k < group.size; ++k)
        reassociation.push_back({next_dim++});
      continue;
    }
    reassociation.push_back(
        llvm::to_vector(llvm::seq<int64_t>(next_dim, next_dim + group.size)));
    next_dim += group.size;
  }
  return reassociation;
}
// Lowers `op` to the sequence cast -> expand_shape -> collapse_shape -> cast.
// Each step is only emitted if it is needed. The dimension groups are derived
// from the symbolic operand shape and the symbolic `output_shape` value.
LogicalResult MaterializeReshapeAsExpandAndCollapse(
    ShapeComponentAnalysis &shape_analysis, RankedTensorType operand_ty,
    RankedTensorType result_ty, mhlo::DynamicReshapeOp op,
    PatternRewriter &rewriter) {
  // Both the operand shape and the result shape must be known symbolically.
  auto operand_dims = shape_analysis.GetShapeInfo(op.operand());
  if (!operand_dims) return failure();
  auto result_dims = shape_analysis.GetValueInfo(op.output_shape());
  if (!result_dims) return failure();

  // Identify dimension groups and the intermediate expanded shape.
  SmallVector<DimensionGroup> groups;
  SmallVector<int64_t> interm_shape;
  LogicalResult grouping = FindExpandingAndCollapsingDimensionGroups(
      *operand_dims, *result_dims, &groups, &interm_shape);
  if (failed(grouping)) return failure();

  Location loc = op.getLoc();
  Type element_ty = operand_ty.getElementType();
  Value current = op.operand();

  // Make the statically known operand dimensions explicit in the type.
  RankedTensorType refined_operand_ty = RankedTensorType::get(
      ConcretizeOperandShape(operand_ty.getShape(), *operand_dims), element_ty);
  if (refined_operand_ty != operand_ty)
    current = rewriter.create<tensor::CastOp>(loc, refined_operand_ty, current);

  // Expand first, then collapse.
  if (auto expand_reassoc = RequiresReassociationOfKind(
          DimensionGroupKind::kExpanding, groups)) {
    auto expanded_ty = RankedTensorType::get(interm_shape, element_ty);
    current = rewriter.create<tensor::ExpandShapeOp>(loc, expanded_ty, current,
                                                     *expand_reassoc);
  }
  if (auto collapse_reassoc = RequiresReassociationOfKind(
          DimensionGroupKind::kCollapsing, groups)) {
    current = rewriter.create<tensor::CollapseShapeOp>(loc, current,
                                                       *collapse_reassoc);
  }

  // Match the expected result type.
  if (current.getType() != result_ty)
    current = rewriter.create<tensor::CastOp>(loc, result_ty, current);

  rewriter.replaceOp(op, current);
  return success();
}
// Rewrites `mhlo.dynamic_reshape` on ranked tensors into `tensor.expand_shape`
// and `tensor.collapse_shape` ops. Reshapes from or into scalars are handled
// separately.
struct DynamicReshapeToExpandAndCollapseShape final
    : public OpRewritePattern<mhlo::DynamicReshapeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(mhlo::DynamicReshapeOp op,
                                PatternRewriter &rewriter) const override {
    // Only ranked operand and result types are supported.
    auto operand_ty = op.operand().getType().dyn_cast<RankedTensorType>();
    auto result_ty = op.getType().dyn_cast<RankedTensorType>();
    if (!operand_ty || !result_ty) return failure();

    // Degenerate cases: reshaping a scalar, or reshaping into a scalar.
    if (operand_ty.getRank() == 0) {
      return MaterializeReshapeAsScalarExpand(operand_ty, result_ty, op,
                                              rewriter);
    }
    if (result_ty.getRank() == 0) {
      return MaterializeReshapeAsScalarCollapse(operand_ty, result_ty, op,
                                                rewriter);
    }

    // General case.
    ShapeComponentAnalysis analysis;
    return MaterializeReshapeAsExpandAndCollapse(analysis, operand_ty,
                                                 result_ty, op, rewriter);
  }
};
// Returns true if every shape in `bcasted_shapes` is provably broadcast
// compatible with `output_shape`. Dimensions are aligned at the trailing end,
// and each pair of dimensions must be symbolically equal or contain a constant
// 1. Statically zero dimensions are rejected conservatively.
bool IsKnownBroadcastable(ShapeComponentAnalysis &analysis,
                          ValueRange bcasted_shapes, Value output_shape) {
  auto output_dims = analysis.GetValueInfo(output_shape);
  if (!output_dims) return false;

  return llvm::all_of(bcasted_shapes, [&](Value shape) {
    auto dims = analysis.GetValueInfo(shape);
    if (!dims) return false;

    // Walk both shapes from the back; `zip` stops at the shorter one.
    for (auto pair :
         llvm::zip(llvm::reverse(*output_dims), llvm::reverse(*dims))) {
      const SymbolicExpr &out_dim = std::get<0>(pair);
      const SymbolicExpr &in_dim = std::get<1>(pair);

      // TODO(ezhulenev): What to do with dimensions statically known to be
      // zero?
      // Numpy can only broadcast [0] with [1], however Tensorflow can broadcast
      // [0] with any dimension size, and produces dimension of size [0].
      // Currently we'll conservatively return failure and will not proceed with
      // a rewrite.
      if (out_dim.isConstant(0) || in_dim.isConstant(0)) return false;

      // A static unit dimension on either side always broadcasts.
      if (out_dim.isConstant(1) || in_dim.isConstant(1)) continue;

      // Otherwise, the dimensions must be equal.
      if (out_dim != in_dim) return false;
    }
    return true;
  });
}
// Rewrite `shape.cstr_broadcastable` with a constant witness if we can prove
// from a symbolic analysis that all shapes are broadcastable with each other.
//
// `IsKnownBroadcastable` checks compatibility against a single reference shape.
// That is only sufficient if the reference is the broadcast result, not an
// arbitrary operand. For example, [1], [2], [3] are each compatible with [1]
// but not with each other. Likewise, [5], [2, 5], [3, 5] are each compatible
// with [5] but not with each other. Hence, every operand is used as the
// reference once, which checks all pairs (quadratic in the number of operands).
struct CstrBroadcastableOpLowering
    : public OpRewritePattern<shape::CstrBroadcastableOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op,
                                PatternRewriter &rewriter) const override {
    ShapeComponentAnalysis shape_component_analysis;
    ValueRange shapes = op.getShapes();
    for (Value reference : shapes) {
      if (!IsKnownBroadcastable(shape_component_analysis, shapes, reference))
        return failure();
    }
    rewriter.replaceOpWithNewOp<shape::ConstWitnessOp>(op, true);
    return success();
  }
};
// Greedily applies the symbolic shape optimization patterns defined in this
// file, together with the `shape.assuming` canonicalization patterns, to the
// function.
class SymbolicShapeOptimizationPass final
    : public SymbolicShapeOptimizationBase<SymbolicShapeOptimizationPass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    // The patterns create `shape.const_witness` and `tensor` dialect ops, so
    // declare those dialects explicitly. The linalg dialect is kept as before.
    // NOTE(review): linalg presumably also makes the `arith` dialect (used for
    // `arith.constant`) available transitively -- confirm before removing it.
    registry.insert<linalg::LinalgDialect, shape::ShapeDialect,
                    tensor::TensorDialect>();
  }

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    mlir::RewritePatternSet patterns(ctx);

    // clang-format off
    patterns.insert<
        AnnotateExpandingDimensionsInDynamicBroadcastInDim,
        CstrBroadcastableOpLowering,
        DynamicReshapeToExpandAndCollapseShape,
        RemoveComputeReshapeShape,
        RemoveRedundantCstrReshapable,
        SimplifyBroadcasts>(ctx);
    // clang-format on
    shape::AssumingOp::getCanonicalizationPatterns(patterns, ctx);

    if (failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
                                                  std::move(patterns)))) {
      signalPassFailure();
    }
  }
};
} // end namespace
// Creates a new instance of the symbolic shape optimization pass, which runs
// on functions.
std::unique_ptr<OperationPass<FuncOp>> createSymbolicShapeOptimizationPass() {
  auto pass = std::make_unique<SymbolicShapeOptimizationPass>();
  return pass;
}
} // end namespace mlir