| /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "mlir-hlo/Analysis/shape_component_analysis.h" |
| |
| #include <algorithm> |
| #include <vector> |
| |
| #include "llvm/ADT/STLExtras.h" |
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" |
| #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/Shape/IR/Shape.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/AffineExpr.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Matchers.h" |
| |
| using namespace mlir; |
| |
| using SymbolicShapeConstraintsMap = |
| ShapeComponentAnalysis::SymbolicShapeConstraintsMap; |
| using ShapeOrValueInfo = ShapeComponentAnalysis::ShapeOrValueInfo; |
| using Symbol = ShapeComponentAnalysis::Symbol; |
| using SymbolicExpr = ShapeComponentAnalysis::SymbolicExpr; |
| using SymbolicExprsMap = ShapeComponentAnalysis::SymbolicExprsMap; |
| |
| namespace { |
| // Shape visitor. This implements a symbolic interpreter for MHLO with some |
| // shape and tensor dialect ops mixed in. We are interested in shapes (e.g., the |
| // dimensions of a tensor) and values (e.g., the elements of a shape tensor). The |
| // goal is to assign every component of a shape or value either a symbol, a |
| // constant, or a symbolic expression. We propagate these symbolic expressions |
| // through the various operations. Later optimization passes can use this |
| // information for optimizations, e.g., exploiting the equality of dimensions. |
| // |
| // The visitation happens in two phases: |
| // 1. Find the sources of a value's shape or value. This climbs up the |
| // operations from a given value until an unknown op or a function argument |
| // is found. These sources are assigned the initial symbols for each of |
| // their components. |
| // 2. Propagate the initial symbols downwards. This builds symbolic |
| // expressions so users of the analysis can pattern match things like |
| // "two dimensions are multiplied". |
| // |
| // Conceptually, this is defined recursively. For each op, we compute the |
| // required shape or value information for the operands and then derive the |
| // resulting symbolic expression. |
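| // |
| // For illustration (a hypothetical example, not taken from a test), consider: |
| // |
| //   %shape = shape.shape_of %arg0 : tensor<?x?xf32> -> tensor<2xindex> |
| //   %n = shape.num_elements %shape : tensor<2xindex> -> index |
| // |
| // Requesting the value info of %n first climbs up to %arg0 (phase 1), whose |
| // two unknown dimensions receive the initial symbols. Phase 2 then propagates |
| // them down again, so the single value component of %n becomes the symbolic |
| // product s0 * s1 of those dimension symbols. A later pass could, e.g., |
| // compare such expressions for two values to prove that dimensions are equal. |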
| struct ShapeVisitor { |
| ShapeVisitor(SymbolicExprsMap *symbolicExprsMap, |
| SymbolicShapeConstraintsMap *symbolicShapeConstraintsMap) |
| : symbolicExprsMap(symbolicExprsMap), |
| symbolicShapeConstraintsMap(symbolicShapeConstraintsMap) {} |
| |
| void visit(ShapeOrValueInfo requestedInfo) { |
| backwards_worklist.push_back(requestedInfo); |
| |
| // First, we climb up the operations so we get the set of all ops taking |
| // part in this shape or value computation. An alternative would be |
| // analyzing everything eagerly. This backwards pass allows us to be lazy. |
| while (!backwards_worklist.empty()) { |
| // Skip if already processed. |
| ShapeOrValueInfo transitivelyRequestedInfo = |
| backwards_worklist.pop_back_val(); |
| if (symbolicExprsMap->count(transitivelyRequestedInfo)) continue; |
| |
| // Skip irrelevant cases early. |
| Value value = transitivelyRequestedInfo.value(); |
| Type ty = value.getType(); |
| if (!ty.isIntOrIndexOrFloat() && !ty.isa<RankedTensorType>()) continue; |
| |
| // Handle shapes. |
| if (transitivelyRequestedInfo.isShapeInfo()) { |
| if (value.getDefiningOp<shape::AssumingOp>()) { |
| backwardAssumingShape(value); |
| } else if (auto bcast = |
| value.getDefiningOp<mhlo::DynamicBroadcastInDimOp>()) { |
| backwardDynamicBroadcastInDimShape(bcast); |
| } else if (auto reshape = |
| value.getDefiningOp<mhlo::DynamicReshapeOp>()) { |
| backwardDynamicReshapeShape(reshape); |
| } else if (value.getDefiningOp<mhlo::ReduceOp>()) { |
| backwardReduceShape(value); |
| } else if (auto transpose = value.getDefiningOp<mhlo::TransposeOp>()) { |
| backwardTransposeShape(transpose); |
| } else if (auto select = value.getDefiningOp<mhlo::SelectOp>()) { |
| backwardSelectShape(select); |
| } else if (auto arg = value.dyn_cast<BlockArgument>()) { |
| backwardBlockArgumentShape(arg); |
| } else if (value.getDefiningOp() && |
| value.getDefiningOp() |
| ->hasTrait<OpTrait::SameOperandsAndResultShape>()) { |
| backwardSameOperandsAndResultShape(value); |
| } else { |
| backwardUnknownShape(value); |
| } |
| continue; |
| } |
| |
| // Skip irrelevant cases early. |
| auto ranked_ty = ty.dyn_cast<RankedTensorType>(); |
| bool is_possibly_interesting_scalar = ty.isIntOrIndex(); |
| bool is_possibly_interesting_tensor = |
| ranked_ty && ranked_ty.getRank() <= 1 && ranked_ty.hasStaticShape(); |
| if (!is_possibly_interesting_scalar && !is_possibly_interesting_tensor) { |
| continue; |
| } |
| |
| // Handle values. |
| assert(transitivelyRequestedInfo.isValueInfo() && |
| "Expect value info at this point."); |
| if (auto shapeof = value.getDefiningOp<shape::ShapeOfOp>()) { |
| backwardShapeOf(shapeof); |
| } else if (auto bcast = value.getDefiningOp<shape::BroadcastOp>()) { |
| backwardBroadcast(bcast); |
| } else if (auto num_elements = |
| value.getDefiningOp<shape::NumElementsOp>()) { |
| backwardNumElements(num_elements); |
| } else if (auto dim = value.getDefiningOp<tensor::DimOp>()) { |
| backwardDim(dim); |
| } else if (auto cast = value.getDefiningOp<arith::IndexCastOp>()) { |
| backwardIndexCast(cast); |
| } else if (auto fromElements = |
| value.getDefiningOp<tensor::FromElementsOp>()) { |
| backwardTensorFromElements(fromElements); |
| } else if (auto extract = value.getDefiningOp<tensor::ExtractOp>()) { |
| backwardTensorExtract(extract); |
| } else if (auto add = value.getDefiningOp<mhlo::AddOp>()) { |
| backwardBinOp(add); |
| } else if (auto mul = value.getDefiningOp<mhlo::MulOp>()) { |
| backwardBinOp(mul); |
| } else if (auto add = value.getDefiningOp<arith::AddIOp>()) { |
| backwardBinOp(add); |
| } else if (auto mul = value.getDefiningOp<arith::MulIOp>()) { |
| backwardBinOp(mul); |
| } else if (auto concat = value.getDefiningOp<mhlo::ConcatenateOp>()) { |
| backwardConcatenate(concat); |
| } else if (auto reshape = value.getDefiningOp<mhlo::ReshapeOp>()) { |
| backwardReshape(reshape); |
| } else if (auto slice = value.getDefiningOp<mhlo::SliceOp>()) { |
| backwardSlice(slice); |
| } else if (matchPattern(value, m_Constant())) { |
| backwardConstant(value); |
| } else { |
| backwardUnknown(value); |
| } |
| } |
| |
| // Second, we walk down from the defs to the uses, building symbolic |
| // expressions for shape and value components. |
| while (!forwards_worklist.empty()) { |
| auto transitivelyRequestedInfo = forwards_worklist.pop_back_val(); |
| |
| // Skip if already processed. |
| if (symbolicExprsMap->count(transitivelyRequestedInfo)) continue; |
| |
| // Handle shapes. |
| Value value = transitivelyRequestedInfo.value(); |
| if (!transitivelyRequestedInfo.isValueInfo()) { |
| if (value.getDefiningOp<shape::AssumingOp>()) { |
| forwardAssumingShape(value); |
| } else if (auto broadcast = |
| value.getDefiningOp<mhlo::DynamicBroadcastInDimOp>()) { |
| forwardDynamicBroadcastInDimShape(broadcast); |
| } else if (auto reshape = |
| value.getDefiningOp<mhlo::DynamicReshapeOp>()) { |
| forwardDynamicReshapeShape(reshape); |
| } else if (value.getDefiningOp<mhlo::ReduceOp>()) { |
| forwardReduceShape(value); |
| } else if (auto transpose = value.getDefiningOp<mhlo::TransposeOp>()) { |
| forwardTransposeShape(transpose); |
| } else if (auto select = value.getDefiningOp<mhlo::SelectOp>()) { |
| forwardSelectShape(select); |
| } else if (value.getDefiningOp() && |
| value.getDefiningOp() |
| ->hasTrait<OpTrait::SameOperandsAndResultShape>()) { |
| forwardSameOperandsShape(value); |
| } else { |
| forwardUnknownShape(value); |
| } |
| continue; |
| } |
| |
| // Handle values. |
| assert(transitivelyRequestedInfo.isValueInfo() && |
| "Expect value info at this point."); |
| if (auto shapeof = value.getDefiningOp<shape::ShapeOfOp>()) { |
| forwardShapeOf(shapeof); |
| } else if (auto bcast = value.getDefiningOp<shape::BroadcastOp>()) { |
| forwardBroadcast(bcast); |
| } else if (auto num_elements = |
| value.getDefiningOp<shape::NumElementsOp>()) { |
| forwardNumElements(num_elements); |
| } else if (auto dim = value.getDefiningOp<tensor::DimOp>()) { |
| forwardDim(dim); |
| } else if (auto cast = value.getDefiningOp<arith::IndexCastOp>()) { |
| forwardIndexCast(cast); |
| } else if (auto fromElements = |
| value.getDefiningOp<tensor::FromElementsOp>()) { |
| forwardTensorFromElements(fromElements); |
| } else if (auto extract = value.getDefiningOp<tensor::ExtractOp>()) { |
| forwardTensorExtract(extract); |
| } else if (auto add = value.getDefiningOp<mhlo::AddOp>()) { |
| forwardBinOp(add, [](AffineExpr a, AffineExpr b) { return a + b; }); |
| } else if (auto mul = value.getDefiningOp<mhlo::MulOp>()) { |
| forwardBinOp(mul, [](AffineExpr a, AffineExpr b) { return a * b; }); |
| } else if (auto add = value.getDefiningOp<arith::AddIOp>()) { |
| forwardBinOp(add, [](AffineExpr a, AffineExpr b) { return a + b; }); |
| } else if (auto mul = value.getDefiningOp<arith::MulIOp>()) { |
| forwardBinOp(mul, [](AffineExpr a, AffineExpr b) { return a * b; }); |
| } else if (auto concat = value.getDefiningOp<mhlo::ConcatenateOp>()) { |
| forwardConcatenate(concat); |
| } else if (auto reshape = value.getDefiningOp<mhlo::ReshapeOp>()) { |
| forwardReshape(reshape); |
| } else if (auto slice = value.getDefiningOp<mhlo::SliceOp>()) { |
| forwardSlice(slice); |
| } else if (matchPattern(value, m_Constant())) { |
| forwardConstant(value); |
| } else { |
| forwardUnknown(value); |
| } |
| } |
| } |
| |
| private: |
| // === |
| // Functions that traverse the shapes of operations. |
| // === |
| |
| void backwardAssumingShape(Value op) { |
| auto assumingOp = op.getDefiningOp<shape::AssumingOp>(); |
| auto number = op.cast<OpResult>().getResultNumber(); |
| forwards_worklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op)); |
| backwards_worklist.push_back(ShapeOrValueInfo::getShapeInfoOf( |
| cast<shape::AssumingYieldOp>( |
| assumingOp.getDoRegion().back().getTerminator()) |
| .getOperand(number))); |
| } |
| void forwardAssumingShape(Value op) { |
| auto assumingOp = op.getDefiningOp<shape::AssumingOp>(); |
| auto number = op.cast<OpResult>().getResultNumber(); |
| auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op)); |
| dims = lookup(ShapeOrValueInfo::getShapeInfoOf( |
| cast<shape::AssumingYieldOp>( |
| assumingOp.getDoRegion().back().getTerminator()) |
| .getOperand(number))); |
| } |
| void backwardBroadcast(shape::BroadcastOp op) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getValueInfoOf(op)); |
| for (Value s : op.getShapes()) |
| backwards_worklist.push_back(ShapeOrValueInfo::getValueInfoOf(s)); |
| } |
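| // For example (a sketch of the per-dimension rule below): broadcasting the |
| // symbolic shapes [1, s0] and [s1, 1] yields [s1, s0], whereas broadcasting |
| // [s0] with [s1] introduces a fresh symbol for the result dimension because |
| // the two expressions may disagree. |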
| void forwardBroadcast(shape::BroadcastOp op) { |
| auto *ctx = op.getContext(); |
| |
| // Get operands' info. |
| SmallVector<ArrayRef<SymbolicExpr>> args_info = |
| llvm::to_vector(llvm::map_range(op.getShapes(), [&](Value s) { |
| return lookup(ShapeOrValueInfo::getValueInfoOf(s)); |
| })); |
| |
| // Determine broadcasted rank. |
| size_t rank = 0; |
| for (auto &info : args_info) rank = std::max(rank, info.size()); |
| |
| // Evaluate broadcast per result dimension. |
| auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op)); |
| for (size_t i = 0; i < rank; ++i) { |
| // Initialize with the neutral element. |
| SymbolicExpr bcasted_expr; |
| bcasted_expr.expr = getAffineConstantExpr(1, ctx); |
| |
| // Consider all the operands. |
| for (auto &info : args_info) { |
| // Find corresponding symbolic expression for the ith result dimension, |
| // if the operand contributes. |
| size_t arg_rank = info.size(); |
| if (i + arg_rank < rank) continue; |
| size_t j = i + arg_rank - rank; |
| SymbolicExpr expr = info[j]; |
| |
| // Dimensions of size 1 are neutral. |
| if (expr.isConstant(1)) continue; |
| |
| // If a dimension is known not to be 1, we can use this expression. |
| if (expr.isKnownNotOne()) { |
| bcasted_expr = expr; |
| break; |
| } |
| |
| // If all other dimensions were neutral, try using this expression. |
| if (bcasted_expr.isConstant(1)) { |
| bcasted_expr = expr; |
| continue; |
| } |
| |
| // If we have contradicting expressions, give up and create a new |
| // symbol. |
| if (bcasted_expr != expr) { |
| bcasted_expr.expr = getAffineSymbolExpr(0, ctx); |
| bcasted_expr.symbols = {{ShapeOrValueInfo::getValueInfoOf(op), i}}; |
| break; |
| } |
| } |
| |
| dims.push_back(bcasted_expr); |
| } |
| assert(dims.size() == rank && "expect one expression per dimension"); |
| } |
| void backwardDynamicBroadcastInDimShape(mhlo::DynamicBroadcastInDimOp op) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op)); |
| backwards_worklist.push_back( |
| ShapeOrValueInfo::getValueInfoOf(op.output_dimensions())); |
| } |
| void forwardDynamicBroadcastInDimShape(mhlo::DynamicBroadcastInDimOp op) { |
| auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op)); |
| dims = lookup(ShapeOrValueInfo::getValueInfoOf(op.output_dimensions())); |
| } |
| void backwardDynamicReshapeShape(mhlo::DynamicReshapeOp op) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op)); |
| backwards_worklist.push_back( |
| ShapeOrValueInfo::getValueInfoOf(op.output_shape())); |
| } |
| void forwardDynamicReshapeShape(mhlo::DynamicReshapeOp op) { |
| auto ranked_ty = op.getResult().getType().cast<RankedTensorType>(); |
| auto shape_dims = |
| lookup(ShapeOrValueInfo::getValueInfoOf(op.output_shape())); |
| auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op)); |
| dimsFromStaticShape(ranked_ty, shape_dims, &dims); |
| } |
| void backwardReduceShape(Value op) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op)); |
| auto reduceOp = op.getDefiningOp<mhlo::ReduceOp>(); |
| if (reduceOp.inputs().size() == 1) |
| backwards_worklist.push_back( |
| ShapeOrValueInfo::getShapeInfoOf(reduceOp.inputs().back())); |
| } |
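| // For example (sketch): reducing a rank-3 input over dimension 1 keeps the |
| // symbolic expressions of dimensions 0 and 2. |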
| void forwardReduceShape(Value op) { |
| auto reduceOp = op.getDefiningOp<mhlo::ReduceOp>(); |
| if (reduceOp.inputs().size() != 1) return forwardUnknownShape(op); |
| auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op)); |
| for (const auto &dim : llvm::enumerate(lookup( |
| ShapeOrValueInfo::getShapeInfoOf(reduceOp.inputs().back())))) { |
| if (!llvm::is_contained(reduceOp.dimensions(), dim.index())) |
| dims.push_back(dim.value()); |
| } |
| } |
| void backwardTransposeShape(mhlo::TransposeOp op) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op)); |
| backwards_worklist.push_back( |
| ShapeOrValueInfo::getShapeInfoOf(op.operand())); |
| } |
| void forwardTransposeShape(mhlo::TransposeOp op) { |
| auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op)); |
| auto in = lookup(ShapeOrValueInfo::getShapeInfoOf(op.operand())); |
| auto elem = op.permutation().cast<DenseIntElementsAttr>(); |
| for (const auto &val : elem) dims.push_back(in[val.getZExtValue()]); |
| } |
| void backwardSelectShape(mhlo::SelectOp op) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op)); |
| backwards_worklist.push_back( |
| ShapeOrValueInfo::getShapeInfoOf(op.on_true())); |
| } |
| void forwardSelectShape(mhlo::SelectOp op) { |
| auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op)); |
| // Forward the `on_true` operand; it has the same shape as the output. |
| dims = lookup(ShapeOrValueInfo::getShapeInfoOf(op.on_true())); |
| } |
| void backwardSameOperandsAndResultShape(Value v) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getShapeInfoOf(v)); |
| backwards_worklist.push_back( |
| ShapeOrValueInfo::getShapeInfoOf(v.getDefiningOp()->getOperand(0))); |
| } |
| void forwardSameOperandsShape(Value v) { |
| auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(v)); |
| dims = lookup( |
| ShapeOrValueInfo::getShapeInfoOf(v.getDefiningOp()->getOperand(0))); |
| } |
| void backwardBlockArgumentShape(BlockArgument argument) { |
| // JitRT uses jitrt.symbolic_shape to describe identical dimensions. Make |
| // use of that when it exists. |
| // |
| // Example: |
| //   func @compute( |
| //       %arg0: tensor<?xf32> {jitrt.symbolic_shape = dense<-2> : tensor<1xi64>}, |
| //       %arg1: tensor<?xf32> {jitrt.symbolic_shape = dense<-2> : tensor<1xi64>}) |
| //       { ... } |
| // |
| // A symbolic dimension is encoded as a negative value smaller than `-1`. The |
| // concrete value is not known at compile time; in this particular example it |
| // is only known that both arguments have the same shape. |
| // |
| // TODO(ezhulenev): Add symbolic shape attribute verifier to the jitrt |
| // dialect. |
| if (auto func = dyn_cast_or_null<func::FuncOp>( |
| argument.getOwner()->getParentOp())) { |
| if (auto shape = func.getArgAttrOfType<DenseIntElementsAttr>( |
| argument.getArgNumber(), "jitrt.symbolic_shape")) { |
| auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(argument)); |
| auto id = getAffineSymbolExpr(0, argument.getContext()); |
| for (const auto &symbol : llvm::enumerate(shape.getValues<ssize_t>())) { |
| dims.emplace_back(); |
| auto &dim = dims.back(); |
| if (symbol.value() >= 0) { |
| dim.expr = |
| getAffineConstantExpr(symbol.value(), argument.getContext()); |
| } else { |
| auto it = symbolicShapeConstraintsMap->try_emplace( |
| symbol.value(), |
| Symbol{ShapeOrValueInfo::getShapeInfoOf(argument), |
| symbol.index()}); |
| dim.symbols.push_back(it.first->second); |
| dim.expr = id; |
| } |
| } |
| return; |
| } |
| } |
| forwardUnknownShape(argument); |
| } |
| void backwardUnknownShape(Value v) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getShapeInfoOf(v)); |
| } |
| void forwardUnknownShape(Value v) { |
| auto ranked_ty = v.getType().dyn_cast<RankedTensorType>(); |
| if (!ranked_ty) return; |
| auto id = getAffineSymbolExpr(0, v.getContext()); |
| auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(v)); |
| return dimsFromStaticShape( |
| ranked_ty, |
| [&](size_t i) { |
| SymbolicExpr d; |
| d.symbols.push_back({ShapeOrValueInfo::getShapeInfoOf(v), i}); |
| d.expr = id; |
| return d; |
| }, |
| &dims); |
| } |
| |
| // === |
| // Functions that traverse values. These can be shape tensors (e.g., of type |
| // tensor<3xindex>) or interesting scalars (e.g., of type index). |
| // === |
| |
| void backwardShapeOf(shape::ShapeOfOp op) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getValueInfoOf(op)); |
| backwards_worklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op.getArg())); |
| } |
| void forwardShapeOf(shape::ShapeOfOp op) { |
| auto ranked_ty = op.getArg().getType().cast<RankedTensorType>(); |
| auto arg = lookup(ShapeOrValueInfo::getShapeInfoOf(op.getArg())); |
| auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op)); |
| return dimsFromStaticShape(ranked_ty, arg, &dims); |
| } |
| void backwardNumElements(shape::NumElementsOp op) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getValueInfoOf(op)); |
| backwards_worklist.push_back( |
| ShapeOrValueInfo::getValueInfoOf(op.getShape())); |
| } |
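| // For example (sketch): a shape value [2, s0, 3, s1] folds the constant |
| // factors and yields the single value component 6 * (s0 * s1). |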
| void forwardNumElements(shape::NumElementsOp op) { |
| auto in = lookup(ShapeOrValueInfo::getValueInfoOf(op.getShape())); |
| |
| // Accumulate the product symbolically, and concretely where possible. |
| int64_t concrete_product = 1; |
| SymbolicExpr dim; |
| for (auto &it : in) { |
| // For constant expressions, we can accumulate a concrete product. |
| if (auto cexpr = it.expr.dyn_cast<AffineConstantExpr>()) { |
| assert(cexpr.getValue() > 0 && "shape value must be positive"); |
| concrete_product *= cexpr.getValue(); |
| continue; |
| } |
| |
| // Simply copy the first symbolic factor. |
| if (!dim.expr) { |
| dim = it; |
| continue; |
| } |
| |
| // Multiply remaining symbolic factors. |
| dim.expr = dim.expr * |
| it.expr.shiftSymbols(dim.symbols.size(), it.symbols.size()); |
| dim.symbols.append(it.symbols); |
| } |
| |
| // Combine concrete and symbolic product. |
| if (concrete_product != 1 || !dim.expr) { |
| auto cexpr = getAffineConstantExpr(concrete_product, op.getContext()); |
| if (dim.expr) |
| dim.expr = cexpr * dim.expr; |
| else |
| dim.expr = cexpr; |
| } |
| |
| auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op)); |
| dims.push_back(dim); |
| } |
| void backwardDim(tensor::DimOp op) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getValueInfoOf(op)); |
| backwards_worklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op.source())); |
| } |
| void forwardDim(tensor::DimOp op) { |
| auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op)); |
| if (auto index = op.index().getDefiningOp<arith::ConstantOp>()) { |
| int64_t i = index.getValue().cast<IntegerAttr>().getInt(); |
| auto in = lookup(ShapeOrValueInfo::getShapeInfoOf(op.source())); |
| dims.push_back({in[i].symbols, in[i].expr}); |
| } else { |
| forwardUnknown(op); |
| } |
| } |
| template <typename Op> |
| void backwardBinOp(Op op) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getValueInfoOf(op)); |
| // TODO(jpienaar): Switch to named accessors when MHLO uses prefixed form. |
| backwards_worklist.append( |
| {ShapeOrValueInfo::getValueInfoOf(op.getOperand(0)), |
| ShapeOrValueInfo::getValueInfoOf(op.getOperand(1))}); |
| } |
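| // The result's symbol list is the concatenation of the LHS and RHS symbols, |
| // so the RHS expression is shifted by the number of LHS symbols before being |
| // combined. E.g. (sketch), adding two single-symbol components yields s0 + s1 |
| // over the concatenated symbol list. |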
| template <typename Op, typename Combiner> |
| void forwardBinOp(Op op, Combiner &&combiner) { |
| auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op)); |
| // TODO(jpienaar): Switch to named accessors when MHLO uses prefixed form. |
| auto lhs = lookup(ShapeOrValueInfo::getValueInfoOf(op.getOperand(0))); |
| auto rhs = lookup(ShapeOrValueInfo::getValueInfoOf(op.getOperand(1))); |
| for (int64_t i = 0, e = dim0size(op.getType()); i != e; ++i) { |
| dims.emplace_back(); |
| auto &dim = dims.back(); |
| dim.symbols.append(lhs[i].symbols); |
| dim.symbols.append(rhs[i].symbols); |
| dim.expr = combiner(lhs[i].expr, |
| rhs[i].expr.shiftSymbols(rhs[i].symbols.size(), |
| lhs[i].symbols.size())); |
| } |
| } |
| void backwardIndexCast(arith::IndexCastOp op) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getValueInfoOf(op)); |
| backwards_worklist.push_back(ShapeOrValueInfo::getValueInfoOf(op.getIn())); |
| } |
| void forwardIndexCast(arith::IndexCastOp op) { |
| auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op)); |
| auto in = lookup(ShapeOrValueInfo::getValueInfoOf(op.getIn())); |
| for (int64_t i = 0, e = dim0size(op.getType()); i != e; ++i) { |
| // This intentionally does not model the truncation/zero-extension behavior |
| // of index_cast. While not strictly correct, it does not matter for shape |
| // computations. |
| dims.push_back({in[i].symbols, in[i].expr}); |
| } |
| } |
| void backwardTensorFromElements(tensor::FromElementsOp op) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getValueInfoOf(op)); |
| for (auto operand : op.getOperands()) |
| backwards_worklist.push_back(ShapeOrValueInfo::getValueInfoOf(operand)); |
| } |
| void forwardTensorFromElements(tensor::FromElementsOp op) { |
| auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op)); |
| for (auto operand : op.getOperands()) { |
| auto in = lookup(ShapeOrValueInfo::getValueInfoOf(operand)); |
| assert(in.size() == 1); |
| dims.push_back({in[0].symbols, in[0].expr}); |
| } |
| } |
| void backwardTensorExtract(tensor::ExtractOp op) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getValueInfoOf(op)); |
| backwards_worklist.push_back(ShapeOrValueInfo::getValueInfoOf(op.tensor())); |
| } |
| void forwardTensorExtract(tensor::ExtractOp op) { |
| auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op)); |
| assert(op.indices().size() == 1); |
| if (auto index = op.indices().front().getDefiningOp<arith::ConstantOp>()) { |
| int64_t i = index.getValue().cast<IntegerAttr>().getInt(); |
| // We assume this is in bounds. |
| auto in = lookup(ShapeOrValueInfo::getValueInfoOf(op.tensor())); |
| dims.push_back({in[i].symbols, in[i].expr}); |
| } else { |
| forwardUnknown(op); |
| } |
| } |
| void backwardConstant(Value v) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getValueInfoOf(v)); |
| } |
| void forwardConstant(Value v) { |
| IntegerAttr intAttr; |
| DenseIntElementsAttr denseAttr; |
| if (matchPattern(v, m_Constant(&denseAttr))) { |
| auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(v)); |
| for (uint64_t i = 0, e = dim0size(v.getType()); i != e; ++i) { |
| dims.emplace_back(); |
| auto &dim = dims.back(); |
| dim.expr = getAffineConstantExpr( |
| denseAttr.getValues<APInt>()[i].getSExtValue(), v.getContext()); |
| } |
| } else if (matchPattern(v, m_Constant(&intAttr))) { |
| auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(v)); |
| dims.emplace_back(); |
| auto &dim = dims.back(); |
| dim.expr = getAffineConstantExpr(intAttr.getInt(), v.getContext()); |
| } else { |
| forwardUnknown(v); |
| } |
| } |
| void backwardConcatenate(mhlo::ConcatenateOp op) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getValueInfoOf(op)); |
| for (auto operand : op.getOperands()) |
| backwards_worklist.push_back(ShapeOrValueInfo::getValueInfoOf(operand)); |
| } |
| void forwardConcatenate(mhlo::ConcatenateOp op) { |
| for (auto operand : op.getOperands()) { |
| auto in = lookup(ShapeOrValueInfo::getValueInfoOf(operand)); |
| if (in.size() != 1) return forwardUnknown(op); |
| } |
| auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op)); |
| for (auto operand : op.getOperands()) { |
| auto in = lookup(ShapeOrValueInfo::getValueInfoOf(operand)); |
| dims.push_back({in[0].symbols, in[0].expr}); |
| } |
| } |
| void backwardReshape(mhlo::ReshapeOp op) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getValueInfoOf(op)); |
| backwards_worklist.push_back( |
| ShapeOrValueInfo::getValueInfoOf(op.operand())); |
| } |
| void forwardReshape(mhlo::ReshapeOp op) { |
| auto in = lookup(ShapeOrValueInfo::getValueInfoOf(op.operand())); |
| if (in.size() != 1) return forwardUnknown(op); |
| auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op)); |
| dims.push_back({in[0].symbols, in[0].expr}); |
| } |
| void backwardSlice(mhlo::SliceOp op) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getValueInfoOf(op)); |
| backwards_worklist.push_back( |
| ShapeOrValueInfo::getValueInfoOf(op.operand())); |
| } |
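| // For example (sketch): a slice with result type tensor<1xi32> and start |
| // index 2 forwards the third component of the operand's value info. |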
| void forwardSlice(mhlo::SliceOp op) { |
| // Only handle slices equivalent to an extract. |
| if (!op.getType().hasStaticShape({1})) { |
| return forwardUnknown(op); |
| } |
| auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op)); |
| auto in = lookup(ShapeOrValueInfo::getValueInfoOf(op.operand())); |
| auto elem = op.start_indices().cast<DenseIntElementsAttr>(); |
| auto i = (*elem.begin()).getZExtValue(); |
| if (i >= in.size()) { // Bounds check. |
| return forwardUnknown(op); |
| } |
| dims.push_back({in[i].symbols, in[i].expr}); |
| } |
| void backwardUnknown(Value v) { |
| forwards_worklist.push_back(ShapeOrValueInfo::getValueInfoOf(v)); |
| } |
| void forwardUnknown(Value v) { |
| auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(v)); |
| auto id = getAffineSymbolExpr(0, v.getContext()); |
| for (size_t i = 0, e = dim0size(v.getType()); i != e; ++i) { |
| dims.emplace_back(); |
| auto &dim = dims.back(); |
| dim.symbols.push_back({ShapeOrValueInfo::getValueInfoOf(v), i}); |
| dim.expr = id; |
| } |
| } |
| |
| // === |
| // Helpers |
| // === |
| |
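| // Merges the static dimensions of `ranked_ty` with the symbolic expressions |
| // provided by `fallback` for the dynamic dimensions. E.g. (sketch), for |
| // tensor<3x?xf32> this produces [3, fallback(1)]. |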
| static void dimsFromStaticShape( |
| RankedTensorType ranked_ty, |
| llvm::function_ref<SymbolicExpr(int64_t)> fallback, |
| std::vector<SymbolicExpr> *merged_dims) { |
| auto *ctx = ranked_ty.getContext(); |
| for (int64_t i = 0, e = ranked_ty.getRank(); i != e; ++i) { |
| if (ranked_ty.isDynamicDim(i)) { |
| merged_dims->push_back(fallback(i)); |
| } else { |
| merged_dims->emplace_back(); |
| auto &d = merged_dims->back(); |
| d.expr = getAffineConstantExpr(ranked_ty.getDimSize(i), ctx); |
| } |
| } |
| } |
| |
| static void dimsFromStaticShape(RankedTensorType ranked_ty, |
| ArrayRef<SymbolicExpr> fallback, |
| std::vector<SymbolicExpr> *merged_dims) { |
| return dimsFromStaticShape( |
| ranked_ty, [&](int64_t i) { return fallback[i]; }, merged_dims); |
| } |
| |
| // Returns the size of the first dimension; returns 1 for scalars and rank-0 |
| // tensors. |
| static int64_t dim0size(Type type) { |
| if (auto rankedType = type.dyn_cast<RankedTensorType>()) |
| return rankedType.getRank() == 0 ? 1 : rankedType.getDimSize(0); |
| return 1; |
| } |
| |
| // Retrieves the existing information from the cache. |
| ArrayRef<SymbolicExpr> lookup(ShapeOrValueInfo requestedInfo) { |
| auto i = symbolicExprsMap->find(requestedInfo); |
| assert(i != symbolicExprsMap->end() && "op not processed yet?"); |
| return llvm::makeArrayRef(i->second); |
| } |
| |
| // Inserts a new entry into the cache and returns a reference to its result |
| // components. |
| std::vector<SymbolicExpr> &insert(ShapeOrValueInfo requestedInfo) { |
| auto i = symbolicExprsMap->try_emplace(requestedInfo); |
| assert(i.second && "op already processed?"); |
| return i.first->second; |
| } |
| |
| SymbolicExprsMap *symbolicExprsMap; |
| SymbolicShapeConstraintsMap *symbolicShapeConstraintsMap; |
| |
| // Worklists for the forward and backward passes. |
| SmallVector<ShapeOrValueInfo> backwards_worklist; |
| SmallVector<ShapeOrValueInfo> forwards_worklist; |
| }; |
| } // namespace |
| |
| void ShapeComponentAnalysis::compute(ShapeOrValueInfo requestedInfo) { |
| ShapeVisitor(&symbolicExprsMap, &symbolicShapeConstraintsMap) |
| .visit(requestedInfo); |
| } |
| |
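| // Note on intended use (sketch): GetShapeInfo(%t) describes the dimensions of |
| // the tensor %t, whereas GetValueInfo(%s) describes the elements of the shape |
| // tensor or index scalar %s. |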
| Optional<ArrayRef<SymbolicExpr>> |
| ShapeComponentAnalysis::GetShapeInfo(Value value) { |
| auto request = ShapeOrValueInfo::getShapeInfoOf(value); |
| compute(request); |
| auto found = symbolicExprsMap.find(request); |
| if (found == symbolicExprsMap.end()) return {}; |
| return llvm::makeArrayRef(found->second); |
| } |
| |
| Optional<ArrayRef<SymbolicExpr>> |
| ShapeComponentAnalysis::GetValueInfo(Value shape) { |
| auto request = ShapeOrValueInfo::getValueInfoOf(shape); |
| compute(request); |
| auto found = symbolicExprsMap.find(request); |
| if (found == symbolicExprsMap.end()) return {}; |
| return llvm::makeArrayRef(found->second); |
| } |
| |
| void ShapeComponentAnalysis::reset() { |
| symbolicExprsMap.clear(); |
| symbolicShapeConstraintsMap.clear(); |
| } |
| |
| bool SymbolicExpr::isConstant(int64_t value) const { |
| return expr.isa<AffineConstantExpr>() && |
| expr.cast<AffineConstantExpr>().getValue() == value; |
| } |
| |
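| // For example (sketch): s0 * 4 is known not to be -1 if s0 stems from a |
| // tensor's shape, whereas an expression like s0 - 1 is not known to differ |
| // from -1. |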
| bool SymbolicExpr::isKnownNotNegativeOne() const { |
| // If the symbol comes from a shape, it can't be -1. Also allow results of |
| // shape_of, compute_reshape_shape, and num_elements. This is correct, not |
| // complete. |
| auto isGoodSymbol = [](const Symbol &symbol) { |
| if (symbol.source.isShapeInfo()) return true; |
| Operation *op = symbol.source.value().getDefiningOp(); |
| if (op == nullptr) return false; |
| return llvm::isa<shape::ShapeOfOp, mhlo::ComputeReshapeShapeOp, |
| shape::NumElementsOp>(op); |
| }; |
| |
| // For constants we know whether they equal -1. Checking the sign is |
| // sufficient here and allows for reuse below. This is correct, not complete. |
| auto isGoodSymbolOrGoodConstantExpr = [&](AffineExpr expr) { |
| if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) |
| return isGoodSymbol(symbols[symExpr.getPosition()]); |
| if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) |
| return constExpr.getValue() >= 0; |
| return false; |
| }; |
| |
| if (isGoodSymbolOrGoodConstantExpr(expr)) return true; |
| |
| // Multiplying non-negative symbols and non-negative constants always gives a |
| // non-negative result. This is correct, not complete. |
| // TODO(kramerb): Could the analysis provide a generic interface for this? |
| if (auto bexpr = expr.dyn_cast<AffineBinaryOpExpr>()) { |
| return bexpr.getKind() == AffineExprKind::Mul && |
| isGoodSymbolOrGoodConstantExpr(bexpr.getLHS()) && |
| isGoodSymbolOrGoodConstantExpr(bexpr.getRHS()); |
| } |
| |
| return false; |
| } |
| |
| bool SymbolicExpr::isKnownNotOne() const { |
| if (auto const_expr = expr.dyn_cast<AffineConstantExpr>()) { |
| return const_expr.getValue() != 1; |
| } |
| return false; |
| } |
| |
| llvm::Optional<Symbol> SymbolicExpr::singleton() const { |
| if (expr.isa<AffineSymbolExpr>() && |
| expr.cast<AffineSymbolExpr>().getPosition() == 0) { |
| assert(symbols.size() == 1); |
| return symbols[0]; |
| } |
| return llvm::None; |
| } |
| |
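| // Prints, e.g. (sketch), "s0 * s1 with" followed by one indented line per |
| // symbol, such as "s0 = shapeof(<block argument ...>)[0]". |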
| void SymbolicExpr::dump(llvm::raw_ostream &os) const { |
| expr.print(os); |
| if (!symbols.empty()) os << " with"; |
| os << "\n"; |
| if (symbols.empty()) return; |
| for (const auto &sym : llvm::enumerate(symbols)) { |
| os.indent(4); |
| os << 's' << sym.index() << " = "; |
| if (!sym.value().source.isValueInfo()) os << "shapeof("; |
| sym.value().source.value().print(os); |
| if (!sym.value().source.isValueInfo()) os << ")"; |
| os << '[' << sym.value().index << "]\n"; |
| } |
| } |