blob: 2bb7fb32fde4a01397b50a12ccc75b1134cb6785 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef MLIR_HLO_UTILS_HLO_UTILS_H
#define MLIR_HLO_UTILS_HLO_UTILS_H
#include <string>
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
namespace mlir {
namespace hlo {
// Computes the broadcast dimensions attr for an elementwise binary operator
// between two ranked tensors.
// If `allow_empty` is true, then null can be returned to mean that the
// broadcast is an "identity".
mlir::DenseIntElementsAttr getBroadcastDimensionsAttr(mlir::Builder* b,
mlir::Value x,
mlir::Value y,
bool allowEmpty = true);
// Get a constant splat for the given value of type. Requires value to be of
// type static shaped RankedTensorType.
template <typename T>
static ElementsAttr getSplat(Builder* b, RankedTensorType ty, T constant) {
Type elementTy = getElementTypeOrSelf(ty);
if (elementTy.isSignlessInteger())
return DenseElementsAttr::get(ty, b->getIntegerAttr(elementTy, constant));
if (elementTy.isa<FloatType>())
return DenseElementsAttr::get(ty, b->getFloatAttr(elementTy, constant));
if (auto complexTy = elementTy.dyn_cast<ComplexType>()) {
auto complexElementTy = complexTy.getElementType();
if (complexElementTy.isF32())
return DenseElementsAttr::get(ty,
static_cast<std::complex<float>>(constant));
if (complexElementTy.isF64())
return DenseElementsAttr::get(
ty, static_cast<std::complex<double>>(constant));
}
llvm_unreachable("unhandled element type");
}
template <typename T>
static ElementsAttr getSplat(Builder* b, Value val, T constant) {
return getSplat(b, val.getType().cast<RankedTensorType>(), constant);
}
// Returns DenseElementsAttr of rank zero with the given element type and the
// value.
// Requires `ty` to be either FloatType, IntegerType, or ComplexType.
DenseElementsAttr getScalarOfType(Type ty, int64_t rawValue);
// Returns DenseElementsAttr of rank zero with the given element type and the
// value which is the neutral element for additions.
// Requires `ty` to be either FloatType, IntegerType, or ComplexType.
DenseElementsAttr getScalarNegZeroOfType(Type ty);
// Enum type used to specify scalar argument to GetScalarLimitOfType.
enum ScalarLimit {
kLowest, // The scalar corresponding to numeric_limits<T>::lowest.
kInfinityLowest, // Like kLowest, but returns -infinity where available.
kMax, // The scalar corresponding to numeric_limits<T>::max.
kInfinityMax, // Like kMax, but returns infinity where available.
};
// Returns a scalar limit value for the given type.
//
// The argument 'limit' describes which scalar value to return.
//
// Requires `ty` to be either FloatType or IntegerType.
DenseElementsAttr getScalarLimitOfType(Type ty, ScalarLimit limit);
// Given `op_name` from LMHLO, returns the corresponding op name in MHLO.
// Returns empty string if no such op exists.
std::string lmhloToMhloOpName(llvm::StringRef opName,
mlir::MLIRContext* context);
// Return true if Attr has values [0, 1, ...].
bool isSequenceStartingWith0(Attribute attr);
// Returns the argument index for the giving FuncOp and its operand value.
int64_t getArgumentIndex(func::FuncOp op, Value value);
/// Computes the memory usage of the given allocations.
std::pair<size_t, size_t> computeMemory(const std::vector<Value>& allocs);
} // namespace hlo
} // namespace mlir
#endif // MLIR_HLO_UTILS_HLO_UTILS_H