/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir-hlo/Dialect/gml_st/transforms/fusion_interface_impl.h"
#include <tuple>
#include "llvm/ADT/TypeSwitch.h"
#include "mlir-hlo/Dialect/gml_st/IR/gml_st_ops.h"
#include "mlir-hlo/Dialect/gml_st/transforms/fusion_interface.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/map_mhlo_to_scalar_op.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
namespace mlir {
namespace gml_st {
namespace {
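// Returns true if `genericOp` is a single-output op on tensors whose
// iterators are all parallel and whose indexing maps are all permutations,
// i.e. an element-wise op whose operands may additionally be transposed.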
bool isElementwiseOrTranspose(linalg::GenericOp genericOp) {
if (!genericOp.hasTensorSemantics()) return false;
if (genericOp.outputs().size() != 1) return false;
if (!llvm::all_of(genericOp.iterator_types(), [](Attribute attr) {
return mlir::isParallelIterator(attr);
})) {
return false;
}
if (!llvm::all_of(genericOp.indexing_maps(), [](Attribute attr) {
return attr.cast<AffineMapAttr>().getAffineMap().isPermutation();
})) {
return false;
}
return true;
}
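// External model implementing the `FusionInterface` for `linalg.generic`:
// given a tile subset of the result, it materializes matching subsets of all
// operands and clones the op on them.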
struct LinalgGenericFusionInterface
    : public FusionInterface::ExternalModel<LinalgGenericFusionInterface,
                                            linalg::GenericOp> {
Value fuse(Operation* op, Location loc, Value subset,
OpBuilder& builder) const {
auto genericOp = llvm::cast<linalg::GenericOp>(op);
// Supports only tile subsets.
auto tileTy = subset.getType().dyn_cast<TileType>();
if (!tileTy) return {};
    // Supports only element-wise or transposing `linalg.generic` ops.
if (!isElementwiseOrTranspose(genericOp)) return {};
// Create tiled op.
Value output = genericOp.outputs().front();
auto outputTy = output.getType().cast<RankedTensorType>();
auto subResultTy =
RankedTensorType::get(tileTy.getShape(), outputTy.getElementType());
SmallVector<Value> subOperands;
    subOperands.reserve(genericOp.getNumInputs() + 1);  // Inputs plus output.
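    // Materialize a tiled subset for every input: inputs with an identity
    // indexing map can reuse `subset` directly; permuted inputs need a
    // correspondingly permuted tile.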
for (const auto& inputAndMap :
llvm::zip(genericOp.inputs(), genericOp.getIndexingMaps())) {
Value input;
AffineMap map;
std::tie(input, map) = inputAndMap;
if (map.isIdentity()) {
subOperands.push_back(
builder.create<MaterializeOp>(loc, input, subset));
continue;
}
assert(map.isPermutation());
// Materialize a space for the input.
auto inputTy = input.getType().cast<RankedTensorType>();
auto spaceTy = builder.getType<TileType>(inputTy.getShape());
auto dynamicDims = tensor::createDynamicDimValues(builder, loc, input);
auto staticDims = builder.getI64ArrayAttr(inputTy.getShape());
Value inputSpace =
builder.create<SpaceOp>(loc, spaceTy, dynamicDims, staticDims);
      // Create a new tile with permuted offsets and sizes.
SmallVector<Value> inputTileOffsets, inputTileSizes, inputTileStrides;
SmallVector<int64_t> inputStaticOffsets, inputStaticSizes;
// Use original tileOp where possible.
auto tileOp = subset.getDefiningOp<TileOp>();
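      // For every input dimension, take the offset and size from the permuted
      // position of `subset`, keeping them static whenever the defining tile
      // op provides a static value.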
for (unsigned int r = 0, e = map.getNumResults(); r < e; ++r) {
auto permutedDim = map.getPermutedPosition(r);
auto permutedDimConstant =
builder.create<arith::ConstantIndexOp>(loc, permutedDim);
// TODO(unknown): With a canonicalizer, we could always use values.
if (!tileOp || tileOp.isDynamicOffset(permutedDim)) {
inputTileOffsets.push_back(
builder.createOrFold<OffsetOp>(loc, subset, permutedDimConstant));
inputStaticOffsets.push_back(ShapedType::kDynamicStrideOrOffset);
} else {
inputStaticOffsets.push_back(tileOp.getStaticOffset(permutedDim));
}
if (!tileOp || tileOp.isDynamicSize(permutedDim)) {
inputTileSizes.push_back(
builder.createOrFold<SizeOp>(loc, subset, permutedDimConstant));
inputStaticSizes.push_back(ShapedType::kDynamicSize);
} else {
inputStaticSizes.push_back(tileOp.getStaticSize(permutedDim));
}
}
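      // The permuted tile always uses unit strides.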
auto inputStaticStrides = builder.getI64ArrayAttr(
SmallVector<int64_t>(inputStaticSizes.size(), 1));
auto operandTileTy =
TileType::get(subResultTy.getContext(), inputStaticSizes);
auto permutedSubset = builder.create<TileOp>(
loc, operandTileTy, inputSpace, inputTileOffsets, inputTileSizes,
ValueRange{}, builder.getI64ArrayAttr(inputStaticOffsets),
builder.getI64ArrayAttr(inputStaticSizes), inputStaticStrides);
subOperands.push_back(
builder.create<MaterializeOp>(loc, input, permutedSubset));
}
// Materialize the tiled output.
subOperands.push_back(builder.create<MaterializeOp>(loc, output, subset));
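    // Clone the generic op onto the tiled operands; its single result is the
    // fused value.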
linalg::LinalgOp linalgOp = genericOp;
Operation* tiledOp = linalgOp.clone(builder, loc, subResultTy, subOperands);
return tiledOp->getResults().front();
}
};
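// External model implementing the `FusionInterface` for simple element-wise
// mhlo ops: for a tile subset the op is recreated on the materialized
// sub-operands, for a point subset it is mapped to the corresponding scalar
// op.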
template <typename OpTy>
struct ElementwiseFusionInterface
: public FusionInterface::ExternalModel<ElementwiseFusionInterface<OpTy>,
OpTy> {
Value fuse(Operation* op, Location loc, Value subset,
OpBuilder& builder) const {
// Supports tile and point subsets.
Type subsetTy = subset.getType();
if (!subsetTy.isa<PointType, TileType>()) return {};
// Expect ranked element-wise op.
auto cwiseOp = llvm::cast<OpTy>(op);
auto rankedTy = cwiseOp.getType().template dyn_cast<RankedTensorType>();
if (!rankedTy) return {};
// Materialize subsets for all arguments.
auto subsetArgs = llvm::to_vector(
llvm::map_range(cwiseOp->getOperands(), [&](const auto& arg) -> Value {
return builder.create<MaterializeOp>(loc, arg, subset);
}));
// Materialize elementwise op for subset.
return llvm::TypeSwitch<Type, Value>(subsetTy)
.Case([&](TileType) -> Value {
return builder.create<OpTy>(loc, subsetArgs);
})
.Case([&](PointType) -> Value {
return mhlo::MhloOpToStdScalarOp::mapOp(
cwiseOp, rankedTy.getElementType(), subsetArgs, &builder);
})
.Default([](Type) -> Value { return {}; });
}
};
} // namespace
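// Registers the external models above for the linalg and mhlo dialects.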
void registerFusionInterfaceExternalModels(DialectRegistry& registry) {
registry.insert<linalg::LinalgDialect>();
registry.addExtension(+[](MLIRContext* ctx, linalg::LinalgDialect*) {
    linalg::GenericOp::attachInterface<LinalgGenericFusionInterface>(*ctx);
});
// TODO(frgossen): Update tests and remove these in favor of
// `linalg.generic`-based fusions.
registry.insert<mhlo::MhloDialect>();
registry.addExtension(+[](MLIRContext* ctx, mhlo::MhloDialect*) {
mhlo::AddOp::attachInterface<ElementwiseFusionInterface<mhlo::AddOp>>(*ctx);
mhlo::SubOp::attachInterface<ElementwiseFusionInterface<mhlo::SubOp>>(*ctx);
mhlo::CosOp::attachInterface<ElementwiseFusionInterface<mhlo::CosOp>>(*ctx);
mhlo::TanhOp::attachInterface<ElementwiseFusionInterface<mhlo::TanhOp>>(
*ctx);
});
}
} // namespace gml_st
} // namespace mlir