/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace mhlo {
namespace {
// Broadcasts the 1D value tensor 'value1d' to the shape of 'resultType'. If
// 'shapeValue' is initialized, creates a dynamic broadcast, otherwise creates
// a static broadcast.
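// For example (hypothetical shapes), broadcasting a tensor<256xf32> value to a
// tensor<8x256xf32> result with featureDim == 1 sets broadcast_dimensions to
// [1], mapping the single dimension of 'value1d' to dimension 1 of the result.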
Value broadcastToFeatureDim(Location loc, RankedTensorType resultType,
Value value1d, Value shapeValue, int64_t featureDim,
PatternRewriter& rewriter) { // NOLINT
Builder b(rewriter.getContext());
auto dimsType = RankedTensorType::get({1}, b.getIntegerType(64));
auto dims = DenseIntElementsAttr::get(dimsType, {featureDim});
if (shapeValue) {
return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
loc, resultType, value1d, shapeValue, dims);
}
assert(resultType.hasStaticShape());
return rewriter.create<mhlo::BroadcastInDimOp>(loc, resultType, value1d,
dims);
}
// Calculates the shape of 'operand' as a 1D tensor of dimension sizes,
// assuming 'operand' has a dynamic shape with static rank.
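// The result is a 1D tensor assembled from "tensor.dim" and
// "tensor.from_elements" ops, suitable as the output-shape operand of
// "mhlo.dynamic_broadcast_in_dim".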
Value calculateShapeValue(Location loc, Value operand,
PatternRewriter& rewriter) { // NOLINT
RankedTensorType resultType = operand.getType().dyn_cast<RankedTensorType>();
llvm::SmallVector<Value, 4> shapeValues;
int64_t rank = resultType.getRank();
shapeValues.reserve(rank);
for (int64_t i = 0; i < rank; ++i) {
shapeValues.push_back(
rewriter.create<mlir::tensor::DimOp>(loc, operand, i));
}
return rewriter.create<tensor::FromElementsOp>(loc, shapeValues);
}
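// Materializes the batch-norm epsilon attribute as an mhlo constant of element
// type 'fpType', broadcast to 'broadcastToType': statically when the shape is
// known, and via the runtime shape of 'broadcastTo' otherwise. Returns nullptr
// if epsilon cannot be converted to 'fpType'.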
Value materializeEpsilon(Operation* op, FloatAttr epsilonAttr, FloatType fpType,
Value broadcastTo, RankedTensorType broadcastToType,
PatternRewriter& rewriter) { // NOLINT
Builder b(rewriter.getContext());
if (epsilonAttr.getType() != fpType) {
// Need to convert.
bool losesInfo;
APFloat epsilonFloat = epsilonAttr.getValue();
auto status = epsilonFloat.convert(
fpType.getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo);
if ((status & (~APFloat::opInexact)) != APFloat::opOK) {
op->emitWarning() << "Could not convert batch_norm epsilon to target fp "
"type: opStatus = "
<< static_cast<int>(status);
return nullptr;
}
if (losesInfo) {
op->emitWarning("Conversion of epsilon loses precision");
}
epsilonAttr = b.getFloatAttr(fpType, epsilonFloat);
}
auto scalarType = RankedTensorType::get({}, fpType);
auto epsilonTensorAttr =
DenseElementsAttr::get(scalarType, {epsilonAttr.cast<Attribute>()});
Value epsilon =
rewriter.create<mhlo::ConstOp>(op->getLoc(), epsilonTensorAttr);
auto dimsType = RankedTensorType::get({0}, b.getIntegerType(64));
auto dims = DenseIntElementsAttr::get(dimsType, SmallVector<int64_t, 1>{});
if (broadcastToType.hasStaticShape()) {
return rewriter.create<mhlo::BroadcastInDimOp>(
op->getLoc(), broadcastToType, epsilon, /*broadcast_dims=*/dims);
}
Value shapeValue = calculateShapeValue(op->getLoc(), broadcastTo, rewriter);
return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
op->getLoc(), broadcastToType, epsilon, shapeValue,
/*broadcast_dims=*/dims);
}
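// Rewrites "mhlo.batch_norm_inference" into primitive mhlo arithmetic:
//   (operand - mean) * scale / sqrt(variance + epsilon) + offset
// with the per-feature 1D operands broadcast to the shape of the input.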
class UnfuseBatchNormInferencePattern
: public OpRewritePattern<mhlo::BatchNormInferenceOp> {
public:
using OpRewritePattern<mhlo::BatchNormInferenceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(mhlo::BatchNormInferenceOp bnOp,
PatternRewriter& rewriter) const override {
// Enforce type invariants.
// Note that we deduce the actual element type from the variance,
// which should not be subject to quantization at a higher level.
auto inputType = bnOp.operand().getType().dyn_cast<RankedTensorType>();
auto varianceType = bnOp.variance().getType().dyn_cast<RankedTensorType>();
if (!inputType || !varianceType) {
return failure();
}
auto fpType = varianceType.getElementType().dyn_cast<FloatType>();
if (!fpType) {
return failure();
}
int64_t featureDim = bnOp.feature_index();
// Add epsilon to the variance and sqrt to get stddev:
// stddev = sqrt(variance + epsilon)
auto epsilon =
materializeEpsilon(bnOp.getOperation(), bnOp.epsilonAttr(), fpType,
bnOp.variance(), varianceType, rewriter);
if (!epsilon) {
return failure();
}
Value stddev =
rewriter.create<mhlo::AddOp>(bnOp.getLoc(), bnOp.variance(), epsilon);
stddev = rewriter.create<mhlo::SqrtOp>(bnOp.getLoc(), stddev);
// Broadcast all terms.
Value shapeValue;
if (!inputType.hasStaticShape()) {
shapeValue = calculateShapeValue(bnOp.getLoc(), bnOp.operand(), rewriter);
}
auto broadcastScale =
broadcastToFeatureDim(bnOp.getLoc(), inputType, bnOp.scale(),
shapeValue, featureDim, rewriter);
auto broadcastOffset =
broadcastToFeatureDim(bnOp.getLoc(), inputType, bnOp.offset(),
shapeValue, featureDim, rewriter);
auto broadcastMean =
broadcastToFeatureDim(bnOp.getLoc(), inputType, bnOp.mean(), shapeValue,
featureDim, rewriter);
auto broadcastStddev = broadcastToFeatureDim(
bnOp.getLoc(), inputType, stddev, shapeValue, featureDim, rewriter);
// Compute:
// scale * (input - mean) / stddev + offset
Value result = rewriter.create<mhlo::SubOp>(bnOp.getLoc(), bnOp.operand(),
broadcastMean);
result =
rewriter.create<mhlo::MulOp>(bnOp.getLoc(), result, broadcastScale);
result =
rewriter.create<mhlo::DivOp>(bnOp.getLoc(), result, broadcastStddev);
rewriter.replaceOpWithNewOp<mhlo::AddOp>(bnOp, result, broadcastOffset);
return success();
}
};
// Create "mhlo.reduce", "operand" is reduce input and "zero" is init value,
// reduce sum from operand to operand[feature_index].
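// For example (hypothetical shapes), reducing an operand of type
// tensor<8x4x256xf32> over reduceDims = {0, 1} with featureIndex == 2 yields
// a tensor<256xf32>.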
Value createReduce(Location loc, Value operand, Value zero,
SmallVector<int64_t>& reduceDims, int64_t featureIndex,
PatternRewriter& rewriter) {
auto operandType = operand.getType().cast<RankedTensorType>();
Type reduceResultType = RankedTensorType::get(
{operandType.getDimSize(featureIndex)}, operandType.getElementType());
mhlo::ReduceOp reduce =
rewriter.create<mhlo::ReduceOp>(loc, reduceResultType, operand, zero,
rewriter.getI64TensorAttr(reduceDims));
// Set up the "mhlo.reduce" body: a single block that adds its two arguments.
Region& region = reduce.body();
Block& block = region.emplaceBlock();
RankedTensorType blockArgumentType =
RankedTensorType::get({}, operandType.getElementType());
block.addArgument(blockArgumentType, loc);
block.addArgument(blockArgumentType, loc);
auto* firstArgument = block.args_begin();
auto secondArgument = block.args_rbegin();
{
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&block);
Value addResult =
rewriter.create<mhlo::AddOp>(loc, *firstArgument, *secondArgument);
rewriter.create<mhlo::ReturnOp>(loc, addResult);
}
return reduce.getResult(0);
}
// Calculates the total reduce size, i.e. the product of all dimensions of
// 'operand' except 'featureIndex', materialized as a tensor of 'scaleType'.
// Handles both static shapes and dynamic shapes with static rank.
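// For example (hypothetical shapes), for an operand of type
// tensor<8x16x16x256xf32> with featureIndex == 3, the reduce size is
// 8 * 16 * 16 = 2048, materialized as a splat constant of 'scaleType'.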
Value calculateReduceSize(Operation* op, Value operand,
RankedTensorType operandType,
RankedTensorType scaleType, int64_t featureIndex,
PatternRewriter& rewriter) {
Location loc = op->getLoc();
if (!operandType.hasStaticShape()) {
// the "operand" has dynamic shape with static rank
llvm::SmallVector<Value, 4> reduceValues;
for (int64_t i = 0, e = operandType.getRank(); i < e; i++) {
if (i != featureIndex) {
reduceValues.push_back(rewriter.create<tensor::DimOp>(loc, operand, i));
}
}
assert(!reduceValues.empty());
Value reduceSize = reduceValues[0];
for (size_t i = 1, e = reduceValues.size(); i < e; i++) {
reduceSize =
rewriter.create<arith::MulIOp>(loc, reduceSize, reduceValues[i]);
}
reduceSize = rewriter.create<arith::IndexCastOp>(loc, rewriter.getI64Type(),
reduceSize);
reduceSize = rewriter.create<tensor::FromElementsOp>(loc, reduceSize);
reduceSize = rewriter.create<mhlo::ConvertOp>(
loc, RankedTensorType::get({1}, operandType.getElementType()),
reduceSize);
reduceSize = rewriter.create<mhlo::ReshapeOp>(
loc, RankedTensorType::get({}, operandType.getElementType()),
reduceSize);
Value featureSize =
rewriter.create<tensor::DimOp>(loc, operand, featureIndex);
featureSize = rewriter.create<tensor::FromElementsOp>(loc, featureSize);
return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
loc, scaleType, reduceSize, featureSize, rewriter.getI64TensorAttr({}));
}
// the "operand" has static shape
int64_t reduceDimsSize = 1;
for (int64_t i = 0, e = operandType.getRank(); i < e; i++) {
if (i != featureIndex) {
reduceDimsSize *= operandType.getDimSize(i);
}
}
llvm::APFloat floatValue(static_cast<double>(reduceDimsSize));
bool losesInfo;
floatValue.convert(
scaleType.getElementType().cast<FloatType>().getFloatSemantics(),
APFloat::rmNearestTiesToEven, &losesInfo);
if (losesInfo) {
op->emitWarning("Conversion of reduce_dims_size loses precision");
}
Value reduceSize = rewriter.create<mhlo::ConstOp>(
loc, DenseFPElementsAttr::get(scaleType, floatValue));
return reduceSize;
}
// BatchNormTraining(X, scale, offset) =
// ((X - E[X]) / Sqrt(Var[X] + epsilon)) * scale + offset.
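// The variance is computed from the sufficient statistics Sum[X] and Sum[X^2]
// via the identity Var[X] = E[X^2] - E[X]^2, where the expectations are means
// over all non-feature dimensions.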
class UnfuseBatchNormTrainingPattern
: public OpRewritePattern<mhlo::BatchNormTrainingOp> {
public:
using OpRewritePattern<mhlo::BatchNormTrainingOp>::OpRewritePattern;
LogicalResult matchAndRewrite(mhlo::BatchNormTrainingOp bnOp,
PatternRewriter& rewriter) const override {
auto operandType = bnOp.operand().getType().dyn_cast<RankedTensorType>();
auto scaleType = bnOp.scale().getType().dyn_cast<RankedTensorType>();
if (!operandType || !scaleType) {
return failure();
}
auto fpType = operandType.getElementType().dyn_cast<FloatType>();
if (!fpType) {
return failure();
}
int64_t featureIndex = bnOp.feature_index();
SmallVector<int64_t> dimensionsWithoutFeature;
for (int64_t i = 0, e = operandType.getRank(); i < e; i++) {
if (i != featureIndex) {
dimensionsWithoutFeature.push_back(i);
}
}
// zero constant
Value constZero = rewriter.create<mhlo::ConstOp>(
bnOp.getLoc(),
DenseFPElementsAttr::get(RankedTensorType::get({}, fpType),
APFloat::getZero(fpType.getFloatSemantics())));
// epsilon
auto epsilon =
materializeEpsilon(bnOp.getOperation(), bnOp.epsilonAttr(), fpType,
bnOp.scale(), scaleType, rewriter);
if (!epsilon) {
return failure();
}
// reduce size constant
Value reduceSize =
calculateReduceSize(bnOp.getOperation(), bnOp.operand(), operandType,
scaleType, featureIndex, rewriter);
if (!reduceSize) {
return failure();
}
// Sum[X]
Value sum = createReduce(bnOp.getLoc(), bnOp.operand(), constZero,
dimensionsWithoutFeature, featureIndex, rewriter);
// X^2
Value operandSquare = rewriter.create<mhlo::MulOp>(
bnOp.getLoc(), bnOp.operand(), bnOp.operand());
// Sum[X^2]
Value squareSum =
createReduce(bnOp.getLoc(), operandSquare, constZero,
dimensionsWithoutFeature, featureIndex, rewriter);
// E[X]
Value mean = rewriter.create<mhlo::DivOp>(bnOp.getLoc(), sum, reduceSize);
// E[X^2]
Value squareMean =
rewriter.create<mhlo::DivOp>(bnOp.getLoc(), squareSum, reduceSize);
// E^2[X]
Value meanSquare = rewriter.create<mhlo::MulOp>(bnOp.getLoc(), mean, mean);
// Var[X]
Value var =
rewriter.create<mhlo::SubOp>(bnOp.getLoc(), squareMean, meanSquare);
// Var[X] + epsilon
Value varAddEpsilon =
rewriter.create<mhlo::AddOp>(bnOp.getLoc(), var, epsilon);
// Sqrt(Var[X] + epsilon)
Value sqrtVar = rewriter.create<mhlo::SqrtOp>(bnOp.getLoc(), varAddEpsilon);
Value shapeValue;
if (!operandType.hasStaticShape()) {
shapeValue = calculateShapeValue(bnOp.getLoc(), bnOp.operand(), rewriter);
}
// X - E[X]
Value meanBroadcast = broadcastToFeatureDim(
bnOp.getLoc(), operandType, mean, shapeValue, featureIndex, rewriter);
Value operandMinusMean = rewriter.create<mhlo::SubOp>(
bnOp.getLoc(), bnOp.operand(), meanBroadcast);
// (X - E[X]) / Sqrt(Var[X] + epsilon)
Value sqrtVarBroadcast =
broadcastToFeatureDim(bnOp.getLoc(), operandType, sqrtVar, shapeValue,
featureIndex, rewriter);
Value normalized = rewriter.create<mhlo::DivOp>(
bnOp.getLoc(), operandMinusMean, sqrtVarBroadcast);
// ((X - E[X]) / Sqrt(Var[X] + epsilon)) * scale
Value scaleBroadcast =
broadcastToFeatureDim(bnOp.getLoc(), operandType, bnOp.scale(),
shapeValue, featureIndex, rewriter);
Value scaledNormalized =
rewriter.create<mhlo::MulOp>(bnOp.getLoc(), normalized, scaleBroadcast);
// ((X - E[X]) / Sqrt(Var[X] + epsilon)) * scale + offset.
Value offsetBroadcast =
broadcastToFeatureDim(bnOp.getLoc(), operandType, bnOp.offset(),
shapeValue, featureIndex, rewriter);
Value shiftedNormalized = rewriter.create<mhlo::AddOp>(
bnOp.getLoc(), scaledNormalized, offsetBroadcast);
// results
SmallVector<Value> results = {shiftedNormalized, mean, var};
rewriter.replaceOp(bnOp, results);
return success();
}
};
} // namespace
// Populates conversion patterns to unfuse batch normalization operations.
// In combination with marking such ops as illegal, this allows backends that
// do not have special support for fused batchnorm to use simpler arithmetic
// primitives.
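// A typical use (sketch): mark mhlo::BatchNormInferenceOp and
// mhlo::BatchNormTrainingOp as illegal in a ConversionTarget, call
// PopulateUnfuseBatchNormPatterns(ctx, &patterns), and run the resulting
// pattern set with a dialect-conversion or greedy rewrite driver.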
void PopulateUnfuseBatchNormPatterns(MLIRContext* context,
RewritePatternSet* patterns) {
patterns->add<UnfuseBatchNormInferencePattern>(context);
patterns->add<UnfuseBatchNormTrainingPattern>(context);
}
} // namespace mhlo
} // namespace mlir