//===- LowerUniformRealMath.cpp ------------------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "UniformKernelUtils.h"
#include "mlir/Dialect/FxpMathOps/FxpMathOps.h"
#include "mlir/Dialect/FxpMathOps/Passes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/StandardOps/Ops.h"
using namespace mlir;
using namespace mlir::fxpmath;
using namespace mlir::fxpmath::detail;
using namespace mlir::quant;
namespace {
struct LowerUniformRealMathPass
: public FunctionPass<LowerUniformRealMathPass> {
void runOnFunction() override;
};
struct LowerUniformCastsPass : public FunctionPass<LowerUniformCastsPass> {
void runOnFunction() override;
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Dequantize
//===----------------------------------------------------------------------===//
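// Lowers a DequantizeCastOp on a per-layer uniform quantized type into the
// primitive op sequence: storage cast -> promotion to i32 -> zero-point
// subtraction -> int-to-float conversion -> multiplication by scale. This
// implements the standard affine dequantize mapping:
//   real = scale * (storage - zeroPoint)
// For example, with scale = 0.5 and zeroPoint = 127 on i8 storage, a stored
// value of 131 dequantizes to 0.5 * (131 - 127) = 2.0.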
static Value *emitUniformPerLayerDequantize(Location loc, Value *input,
UniformQuantizedType elementType,
PatternRewriter &rewriter) {
// Pre-conditions.
if (!elementType.isSigned()) {
// TODO: Support unsigned storage type.
    emitWarning(loc, "unimplemented: dequantize unsigned uniform");
return nullptr;
}
  Type storageType = elementType.castToStorageType(input->getType());
  Type realType = elementType.castToExpressedType(input->getType());
  assert(storageType && "cannot cast to storage type");
  assert(realType && "cannot cast to expressed type");
  Type intermediateType =
      castElementType(storageType, IntegerType::get(32, rewriter.getContext()));
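  // Note: i32 is assumed wide enough here that applying the zero-point
  // offset to promoted i8/i16 storage values cannot overflow.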
// Cast to storage type.
input = rewriter.create<StorageCastOp>(loc, storageType, input);
// Promote to intermediate type.
input = rewriter.create<ConvertISOp>(loc, intermediateType, input);
// Apply zero-point offset.
if (elementType.getZeroPoint() != 0) {
Value *negZeroPointConst = rewriter.create<ConstantOp>(
loc, broadcastScalarConstIntValue(intermediateType,
-elementType.getZeroPoint()));
input = rewriter.create<AddIOp>(loc, input, negZeroPointConst);
}
// Convert to float.
input = rewriter.create<ConvertISToFOp>(loc, realType, input);
// Mul by scale.
Value *scaleConst = rewriter.create<ConstantOp>(
loc, broadcastScalarConstFloatValue(realType,
APFloat(elementType.getScale())));
return rewriter.create<MulFOp>(loc, input, scaleConst);
}
static Value *
emitUniformPerAxisDequantize(Location loc, Value *input,
UniformQuantizedPerAxisType elementType,
PatternRewriter &rewriter) {
// TODO: Support per-axis dequantize.
  emitWarning(loc, "unimplemented: per-axis uniform dequantization");
return nullptr;
}
static Value *emitDequantize(Location loc, Value *input,
PatternRewriter &rewriter) {
Type inputType = input->getType();
QuantizedType qElementType =
QuantizedType::getQuantizedElementType(inputType);
  if (auto perLayerElementType =
          qElementType.dyn_cast_or_null<UniformQuantizedType>()) {
    return emitUniformPerLayerDequantize(loc, input, perLayerElementType,
                                         rewriter);
  } else if (auto perAxisElementType =
                 qElementType.dyn_cast_or_null<UniformQuantizedPerAxisType>()) {
    return emitUniformPerAxisDequantize(loc, input, perAxisElementType,
                                        rewriter);
} else {
return nullptr;
}
}
namespace {
struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;
  PatternMatchResult matchAndRewrite(DequantizeCastOp op,
                                     PatternRewriter &rewriter) const override {
Type inputType = op.arg()->getType();
Type outputType = op.getResult()->getType();
    QuantizedType inputElementType =
        QuantizedType::getQuantizedElementType(inputType);
    if (!inputElementType) {
      // Not a quantized input.
      return matchFailure();
    }
    Type expressedOutputType = inputElementType.castToExpressedType(inputType);
    if (expressedOutputType != outputType) {
      // Not a valid uniform cast.
      return matchFailure();
    }
Value *dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
if (!dequantizedValue) {
return matchFailure();
}
rewriter.replaceOp(op, dequantizedValue);
return matchSuccess();
}
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Elementwise add
//===----------------------------------------------------------------------===//
static LogicalResult
tryRewriteAffineAddEwIsomorphicSigned(const UniformBinaryOpInfo &info,
PatternRewriter &rewriter) {
if (!info.resultType.isSigned() || info.lhsType != info.resultType ||
info.rhsType != info.resultType) {
return failure();
}
// Choose a byte aligned intermediate width big enough to perform the
// calculation without overflow.
// TODO: This should probably be made just big enough to avoid overflow and
// leave the downstream tooling to decide how to align that to machine
// word sizes.
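  // For example, the sum of two i8 values lies in [-256, 254], which (even
  // with the zero-point offset applied below) still fits comfortably in i16.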
unsigned intermediateWidth =
info.resultType.getStorageTypeIntegralWidth() <= 8 ? 16 : 32;
IntegerType intermediateElementType =
IntegerType::get(intermediateWidth, rewriter.getContext());
Type intermediateType =
castElementType(info.resultStorageType, intermediateElementType);
// Cast operands to storage type.
Value *lhsValue = rewriter
.create<StorageCastOp>(info.op->getLoc(),
info.lhsStorageType, info.lhs)
.getResult();
Value *rhsValue = rewriter
.create<StorageCastOp>(info.op->getLoc(),
info.rhsStorageType, info.rhs)
.getResult();
// Cast to the intermediate sized type.
lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
lhsValue);
rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
rhsValue);
// Add.
Value *resultValue =
rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, rhsValue);
// Zero point offset adjustment.
// result = (lhs - zp) + (rhs - zp) + zp
// zpOffset = -zp
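  // Derivation, given that lhs, rhs and result share scale S and zero
  // point zp (checked above):
  //   S*(lhs - zp) + S*(rhs - zp) = S*((lhs + rhs - zp) - zp)
  // so the quantized sum is the raw integer sum offset by -zp.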
  int64_t zpOffset = -info.resultType.getZeroPoint();
if (zpOffset != 0) {
Value *zpOffsetConst = rewriter.create<ConstantOp>(
info.op->getLoc(),
broadcastScalarConstIntValue(intermediateType, zpOffset));
resultValue =
rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
}
// Clamp.
auto clampMinMax = info.getClampMinMax(intermediateElementType);
resultValue = rewriter.create<ClampISOp>(
info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
// Convert back to original type.
resultValue = rewriter.create<ConvertISOp>(
info.op->getLoc(), info.resultStorageType, resultValue);
// Cast back for new result.
rewriter.replaceOpWithNewOp<StorageCastOp>(
info.op, info.getQuantizedResultType(), resultValue);
return success();
}
//===----------------------------------------------------------------------===//
// Elementwise mul
//===----------------------------------------------------------------------===//
static LogicalResult
tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info,
PatternRewriter &rewriter) {
if (!info.resultType.isSigned()) {
return failure();
}
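  // The combined rescale factor follows from equating the real values on
  // both sides of the multiplication:
  //   S_lhs*(lhs - zp_lhs) * S_rhs*(rhs - zp_rhs) = S_out*(result - zp_out)
  // => result = (S_lhs*S_rhs/S_out)*(lhs - zp_lhs)*(rhs - zp_rhs) + zp_out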
double outputMultiplierReal = info.lhsType.getScale() *
info.rhsType.getScale() /
info.resultType.getScale();
if (outputMultiplierReal > 1.0) {
    info.op->emitWarning("unimplemented: cannot multiply with multiplier > 1.0");
return failure();
}
// TODO: Choose an appropriate intermediate width for muls > 8 bits to
// avoid overflow.
unsigned intermediateWidth = 32;
IntegerType intermediateElementType =
IntegerType::get(intermediateWidth, rewriter.getContext());
Type intermediateType =
castElementType(info.resultStorageType, intermediateElementType);
// Cast operands to storage type.
Value *lhsValue = rewriter
.create<StorageCastOp>(info.op->getLoc(),
info.lhsStorageType, info.lhs)
.getResult();
Value *rhsValue = rewriter
.create<StorageCastOp>(info.op->getLoc(),
info.rhsStorageType, info.rhs)
.getResult();
// Cast to the intermediate sized type.
lhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
lhsValue);
rhsValue = rewriter.create<ConvertISOp>(info.op->getLoc(), intermediateType,
rhsValue);
// Apply argument zeroPoints.
if (info.lhsType.getZeroPoint() != 0) {
Value *zpOffsetConst = rewriter.create<ConstantOp>(
info.op->getLoc(), broadcastScalarConstIntValue(
intermediateType, -info.lhsType.getZeroPoint()));
lhsValue =
rewriter.create<AddIOp>(info.op->getLoc(), lhsValue, zpOffsetConst);
}
if (info.rhsType.getZeroPoint() != 0) {
Value *zpOffsetConst = rewriter.create<ConstantOp>(
info.op->getLoc(), broadcastScalarConstIntValue(
intermediateType, -info.rhsType.getZeroPoint()));
rhsValue =
rewriter.create<AddIOp>(info.op->getLoc(), rhsValue, zpOffsetConst);
}
// Mul.
Value *resultValue =
rewriter.create<MulIOp>(info.op->getLoc(), lhsValue, rhsValue);
// Scale output.
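  // QuantizedMultiplierSmallerThanOneExp decomposes the real multiplier
  // gemmlowp-style as
  //   outputMultiplierReal ~= multiplier * 2^exponent
  // where the significand in [0.5, 1) is encoded as a Q0.31 fixed-point
  // integer. For example, 0.2 ~= 0.8 * 2^-2 gives
  // multiplier = round(0.8 * 2^31) and exponent = -2, which is applied
  // below as a high multiply followed by a rounding right shift by 2.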
QuantizedMultiplierSmallerThanOneExp outputMultiplier(outputMultiplierReal);
resultValue = rewriter.create<VecScalarSaturatingRoundingDoublingHighMulISOp>(
info.op->getLoc(), resultValue,
IntegerAttr::get(intermediateElementType, outputMultiplier.multiplier));
resultValue = rewriter.create<RoundingDivideByPotISOp>(
info.op->getLoc(), resultValue,
IntegerAttr::get(intermediateElementType, -outputMultiplier.exponent));
// Zero point offset adjustment.
if (info.resultType.getZeroPoint() != 0) {
Value *zpOffsetConst = rewriter.create<ConstantOp>(
info.op->getLoc(),
broadcastScalarConstIntValue(intermediateType,
info.resultType.getZeroPoint()));
resultValue =
rewriter.create<AddIOp>(info.op->getLoc(), resultValue, zpOffsetConst);
}
// Clamp.
auto clampMinMax = info.getClampMinMax(intermediateElementType);
resultValue = rewriter.create<ClampISOp>(
info.op->getLoc(), resultValue, clampMinMax.first, clampMinMax.second);
// Convert back to original type.
resultValue = rewriter.create<ConvertISOp>(
info.op->getLoc(), info.resultStorageType, resultValue);
// Cast back for new result.
rewriter.replaceOpWithNewOp<StorageCastOp>(
info.op, info.getQuantizedResultType(), resultValue);
return success();
}
namespace {
struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
using OpRewritePattern<RealAddEwOp>::OpRewritePattern;
  PatternMatchResult matchAndRewrite(RealAddEwOp op,
                                     PatternRewriter &rewriter) const override {
const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
op.clamp_max());
if (!info.isValid()) {
return matchFailure();
}
// Try all of the permutations we support.
if (succeeded(tryRewriteAffineAddEwIsomorphicSigned(info, rewriter))) {
return matchSuccess();
}
return matchFailure();
}
};
struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
using OpRewritePattern<RealMulEwOp>::OpRewritePattern;
  PatternMatchResult matchAndRewrite(RealMulEwOp op,
                                     PatternRewriter &rewriter) const override {
const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
op.clamp_max());
if (!info.isValid()) {
return matchFailure();
}
// Try all of the permutations we support.
if (succeeded(tryRewriteAffineMulEwSigned(info, rewriter))) {
return matchSuccess();
}
return matchFailure();
}
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// LowerUniformRealMath pass
//===----------------------------------------------------------------------===//
void LowerUniformRealMathPass::runOnFunction() {
auto fn = getFunction();
OwningRewritePatternList patterns;
auto *context = &getContext();
patterns.insert<UniformRealAddEwPattern, UniformRealMulEwPattern>(context);
applyPatternsGreedily(fn, patterns);
}
FunctionPassBase *mlir::fxpmath::createLowerUniformRealMathPass() {
return new LowerUniformRealMathPass();
}
static PassRegistration<LowerUniformRealMathPass> lowerUniformRealMathPass(
"fxpmath-lower-uniform-real-math",
"Lowers uniform-quantized real math ops to integer arithmetic.");
//===----------------------------------------------------------------------===//
// LowerUniformCasts pass
//===----------------------------------------------------------------------===//
void LowerUniformCastsPass::runOnFunction() {
auto fn = getFunction();
OwningRewritePatternList patterns;
auto *context = &getContext();
patterns.insert<UniformDequantizePattern>(context);
applyPatternsGreedily(fn, patterns);
}
FunctionPassBase *mlir::fxpmath::createLowerUniformCastsPass() {
return new LowerUniformCastsPass();
}
static PassRegistration<LowerUniformCastsPass>
lowerUniformCastsPass("fxpmath-lower-uniform-casts",
"Lowers uniform-quantized casts.");