/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_H
#define MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_H

#include <algorithm>

#include "llvm/ADT/Sequence.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

// Include order below matters.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_enums.h.inc"
#define GET_ATTRDEF_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_attrs.h.inc"
namespace mlir {
namespace mhlo {

// Forward declaration for a function declared in hlo_ops.h.
bool isCompatibleForMhloTypeInference(Type tp1, Type tp2);

namespace OpTrait {
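
// Marker trait (no verification hook) identifying elementwise ops with
// broadcasting semantics.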
template <typename ConcreteType>
class BroadcastingElementwise
: public mlir::OpTrait::TraitBase<ConcreteType, BroadcastingElementwise> {};
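
// Verifies that operands and results match pairwise: the i-th result must
// have exactly the same type as the i-th operand. For illustration
// (hypothetical types, not tied to a concrete op): operand types
// (tensor<2xf32>, tensor<3xi32>) require result types
// (tensor<2xf32>, tensor<3xi32>), in that order.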
template <typename ConcreteType>
class PairwiseSameOperandAndResultType
: public mlir::OpTrait::TraitBase<ConcreteType,
PairwiseSameOperandAndResultType> {
public:
static LogicalResult verifyTrait(Operation *op) {
const int numOperands = op->getNumOperands();
const int numResults = op->getNumResults();
if (numOperands != numResults) {
return op->emitOpError()
<< "requires the same number of operands and results";
}
for (int idx : llvm::seq<int>(0, numOperands)) {
if (op->getOperand(idx).getType() != op->getResult(idx).getType()) {
return op->emitOpError()
<< "requires the same type for operand and result at index "
<< idx;
}
}
return success();
}
};
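
// Verifies that every operand and result type is compatible, in the sense of
// isCompatibleForMhloTypeInference, with a single reference type, and provides
// InferTypeOpInterface helpers that infer the most specific result type from
// the operands.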
template <typename ConcreteType>
class CompatibleOperandsAndResultType
: public mlir::OpTrait::TraitBase<ConcreteType,
CompatibleOperandsAndResultType> {
public:
static LogicalResult verifyTrait(Operation *op) {
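    // Pick a reference type: the first operand's type when operands exist,
    // otherwise the first result's type.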
Type expected;
if (op->getNumResults() != 0) expected = op->getResult(0).getType();
if (op->getNumOperands() != 0) expected = op->getOperand(0).getType();
if (!expected) return failure();
auto typeMatch = [&](Type actual) {
return isCompatibleForMhloTypeInference(actual, expected);
};
auto allMatch = llvm::all_of(op->getOperandTypes(), typeMatch) &&
llvm::all_of(op->getResultTypes(), typeMatch);
if (!allMatch) {
return op->emitOpError(
"requires compatible types for all operands and results");
}
    return success();
  }

static LogicalResult inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr /*attributes*/, RegionRange /*regions*/,
SmallVectorImpl<Type> &inferredReturnTypes) {
// TODO(b/231358795): Review the use of InferTypeOpInterface for ops that
// support quantization or sparsity.
if (operands.empty())
return emitOptionalError(
location,
"Expected non-empty operands for [CompatibleOperandsAndResultType]");
if (failed(inferMostSpecificType(context, location, operands.getTypes(),
inferredReturnTypes)))
return failure();
return success();
  }

  // This function is not called automatically by the op definition framework.
  // It must be paired with INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS
  // (see examples in hlo_ops.cc).
static LogicalResult inferReturnTypeComponentsFromOperands(
MLIRContext *context, Optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
SmallVector<Type> inferredReturnTypes;
if (failed(inferReturnTypes(context, location, operands.getValues(),
attributes, regions, inferredReturnTypes)))
return failure();
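    // On success, inferReturnTypes pushes exactly one inferred type.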
auto inferredReturnType = inferredReturnTypes[0].cast<ShapedType>();
inferredReturnShapes.push_back(inferredReturnType);
return success();
  }

 private:
  // Cases of inferring return shapes with bounds (lhs and rhs are
  // commutative):
  //       Dim of lhs    Dim of rhs    Inferred
  //   c0:  3             3             3
  //   c1:  3             ?             3
  //   c2:  3             ?, bound=4    3
  //   c3:  3             ?, bound=2    Error out
  //   c4:  ?             ?             ?
  //   c5:  ?             ?, bound=3    ?, bound=3
  //   c6:  ?, bound=3    ?, bound=3    ?, bound=3
  //   c7:  ?, bound=3    ?, bound=4    ?, bound=3
  // This method generalizes the table to any number of inputs: 1) take the
  // static dimension size of any input (if present) as the inferred size, and
  // 2) take the minimum of the input bounds as the inferred bound.
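  // For example (derived from the table above): merging tensor<3x?xf32> with
  // tensor<?x?xf32> whose encoding carries bounds [4, 3] infers sizes [3, ?]
  // with bound 3 on dimension 1 (case c2 for dim 0, case c5 for dim 1).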
static LogicalResult inferMostSpecificType(
MLIRContext *context, Optional<Location> location,
ValueTypeRange<ValueRange> inputTypes,
SmallVectorImpl<Type> &inferredReturnTypes) {
// TODO(zhouxin) remove this part and find a way to infer sparsity encoding.
if (inputTypes.size() == 1) {
inferredReturnTypes.push_back(inputTypes[0]);
return success();
}
SmallVector<RankedTensorType> rankedTypes;
for (auto inputType : inputTypes)
if (auto rankedType = inputType.dyn_cast<RankedTensorType>())
rankedTypes.push_back(rankedType);
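    // No ranked inputs to merge: fall back to the first input type (e.g. all
    // inputs are unranked tensors).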
if (rankedTypes.empty()) {
inferredReturnTypes.push_back(inputTypes[0]);
return success();
}
auto rank = rankedTypes[0].getRank();
SmallVector<int64_t> inferredDimSizes(rank, ShapedType::kDynamicSize);
SmallVector<int64_t> inferredBounds(rank, ShapedType::kDynamicSize);
for (auto rankedType : rankedTypes) {
SmallVector<int64_t> bounds;
if (auto encoding =
rankedType.getEncoding().dyn_cast_or_null<TypeExtensionsAttr>())
bounds = llvm::to_vector<4>(encoding.getBounds());
for (int dim = 0; dim < rank; ++dim) {
// Dimensions
auto dimSize = rankedType.getShape()[dim];
if (inferredDimSizes[dim] != ShapedType::kDynamicSize &&
dimSize != ShapedType::kDynamicSize &&
inferredDimSizes[dim] != dimSize)
return emitOptionalError(location, "Mismatch dimension size ",
inferredDimSizes[dim], " and ", dimSize,
" in dimension ", dim);
if (inferredDimSizes[dim] == ShapedType::kDynamicSize)
inferredDimSizes[dim] = dimSize;
// Bounds
if (!bounds.empty() && bounds[dim] != ShapedType::kDynamicSize) {
if (inferredBounds[dim] == ShapedType::kDynamicSize) {
inferredBounds[dim] = bounds[dim];
} else {
inferredBounds[dim] = std::min(inferredBounds[dim], bounds[dim]);
}
}
      // Error out if the inferred bound is smaller than the inferred
      // dimension size.
if (inferredBounds[dim] != ShapedType::kDynamicSize &&
inferredBounds[dim] < inferredDimSizes[dim])
return emitOptionalError(location,
"bound must not be less than static "
"dimension size but has bound ",
inferredBounds[dim], " vs static size ",
inferredDimSizes[dim], " in dimension ",
dim);
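      // A static dimension size is more specific than any bound, so the bound
      // is dropped once the size is known.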
if (inferredDimSizes[dim] != ShapedType::kDynamicSize)
inferredBounds[dim] = ShapedType::kDynamicSize;
}
}
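    // Attach a TypeExtensionsAttr encoding only when at least one dimension
    // still carries a bound.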
Attribute encoding = nullptr;
if (llvm::any_of(inferredBounds,
[](auto el) { return el != ShapedType::kDynamicSize; }))
encoding = TypeExtensionsAttr::get(context, inferredBounds);
inferredReturnTypes.push_back(RankedTensorType::get(
inferredDimSizes, rankedTypes[0].getElementType(), encoding));
return success();
}
};

}  // namespace OpTrait
}  // namespace mhlo
}  // namespace mlir

#endif  // MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_BASE_H