| /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ |
| #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ |
| |
| #include "llvm/ADT/ArrayRef.h" |
| #include "llvm/ADT/StringRef.h" |
| #include "llvm/ADT/StringSwitch.h" |
| #include "llvm/ADT/iterator_range.h" |
| #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" |
| #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" |
| #include "mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" |
| #include "mlir/Dialect/SCF/SCF.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/TypeUtilities.h" |
| |
| namespace mlir { |
| namespace lmhlo { |
| namespace impl { |
| |
| // A struct to map LhloBinaryOpTy type to the corresponding floating-point and |
| // integer scalar operation types. |
| template <typename LhloBinaryOpTy> |
| struct LhloToScalarOp; |
| |
| template <> |
| struct LhloToScalarOp<lmhlo::AddOp> { |
| using FOp = ::mlir::AddFOp; |
| using IOp = ::mlir::AddIOp; |
| using COp = ::mlir::AddCFOp; |
| }; |
| template <> |
| struct LhloToScalarOp<lmhlo::CompareOp> { |
| using FOp = ::mlir::CmpFOp; |
| using IOp = ::mlir::CmpIOp; |
| }; |
| template <> |
| struct LhloToScalarOp<lmhlo::DivOp> { |
| using FOp = ::mlir::DivFOp; |
| using IOp = ::mlir::SignedDivIOp; |
| }; |
| template <> |
| struct LhloToScalarOp<lmhlo::MulOp> { |
| using FOp = ::mlir::MulFOp; |
| using IOp = ::mlir::MulIOp; |
| }; |
| template <> |
| struct LhloToScalarOp<lmhlo::RemOp> { |
| using FOp = ::mlir::RemFOp; |
| using IOp = ::mlir::SignedRemIOp; |
| }; |
| template <> |
| struct LhloToScalarOp<lmhlo::SubOp> { |
| using FOp = ::mlir::SubFOp; |
| using IOp = ::mlir::SubIOp; |
| using COp = ::mlir::SubCFOp; |
| }; |
| |
| // Alias for the map from LHLO binary op type to STD floating-point op type. |
| template <typename LhloOp> |
| using ScalarFOp = typename LhloToScalarOp<LhloOp>::FOp; |
| // Alias for the map from LHLO binary op type to STD integer op type. |
| template <typename LhloOp> |
| using ScalarIOp = typename LhloToScalarOp<LhloOp>::IOp; |
| // Alias for the map from LHLO binary op type to STD complex op type. |
| template <typename LhloOp> |
| using ScalarCOp = typename LhloToScalarOp<LhloOp>::COp; |
| |
| template <typename... Args> |
| struct MapLhloOpToStdScalarOpImpl { |
| Value operator()(Location loc, ArrayRef<Type> result_types, |
| ArrayRef<Value> args, OpBuilder* b) { |
| return nullptr; |
| } |
| }; |
| |
| template <typename StdScalarOp> |
| struct MapLhloOpToStdScalarOpImpl<StdScalarOp> { |
| Value operator()(Location loc, ArrayRef<Type> result_types, |
| ArrayRef<Value> args, OpBuilder* b) { |
| return b->template create<StdScalarOp>(loc, result_types, args, mlir::None); |
| } |
| }; |
| |
| template <typename SupportedType, typename StdScalarOp, typename... Args> |
| struct MapLhloOpToStdScalarOpImpl<SupportedType, StdScalarOp, Args...> { |
| Value operator()(Location loc, ArrayRef<Type> result_types, |
| ArrayRef<Value> args, OpBuilder* b) { |
| Type element_type = getElementTypeOrSelf(args.front().getType()); |
| if (element_type.isa<SupportedType>()) { |
| return b->template create<StdScalarOp>(loc, result_types, args, |
| mlir::None); |
| } |
| return MapLhloOpToStdScalarOpImpl<Args...>{}(loc, result_types, args, b); |
| } |
| }; |
| |
| // Inserts the computation that corresponds to the body of the loop for lowered |
| // LHLO unary/binary op. Returns the value for the result. |
| template <typename LhloOpTy> |
| inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types, |
| ArrayRef<Value> args, OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<LhloOpTy>, FloatType, |
| ScalarFOp<LhloOpTy>>{}(loc, result_types, |
| args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| Type element_type = getElementTypeOrSelf(args.front().getType()); |
| if (element_type.isa<FloatType>()) { |
| return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}( |
| loc, result_types, args, b); |
| } |
| if (element_type.isa<IntegerType>()) { |
| // lmhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x)) |
| Value lhs = args[0]; |
| auto integer_type = element_type.dyn_cast<IntegerType>(); |
| |
| Value zero_intval = |
| b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); |
| if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { |
| zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval); |
| } |
| auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge, |
| lhs, zero_intval); |
| auto neg_val = b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs); |
| return b->create<::mlir::SelectOp>(loc, lhs_gt_zero, lhs, neg_val); |
| } |
| return nullptr; |
| } |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::AddOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::AddOp>, |
| FloatType, ScalarFOp<lmhlo::AddOp>, |
| ComplexType, ScalarCOp<lmhlo::AddOp>>{}( |
| loc, result_types, args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}( |
| loc, result_types, args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::Atan2Op>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::Atan2Op>{}( |
| loc, result_types, args, b); |
| } |
| |
| template <typename PredicateType> |
| inline Optional<PredicateType> getCmpPredicate(StringRef comparison_direction) { |
| return llvm::None; |
| } |
| |
| template <> |
| inline Optional<CmpFPredicate> getCmpPredicate<CmpFPredicate>( |
| StringRef comparison_direction) { |
| return llvm::StringSwitch<Optional<CmpFPredicate>>(comparison_direction) |
| .Case("EQ", CmpFPredicate::OEQ) |
| .Case("NE", CmpFPredicate::UNE) |
| .Case("GE", CmpFPredicate::OGE) |
| .Case("GT", CmpFPredicate::OGT) |
| .Case("LE", CmpFPredicate::OLE) |
| .Case("LT", CmpFPredicate::OLT) |
| .Default(llvm::None); |
| } |
| |
| template <> |
| inline Optional<CmpIPredicate> getCmpPredicate<CmpIPredicate>( |
| StringRef comparison_direction) { |
| return llvm::StringSwitch<Optional<CmpIPredicate>>(comparison_direction) |
| .Case("EQ", CmpIPredicate::eq) |
| .Case("NE", CmpIPredicate::ne) |
| .Case("GE", CmpIPredicate::sge) |
| .Case("GT", CmpIPredicate::sgt) |
| .Case("LE", CmpIPredicate::sle) |
| .Case("LT", CmpIPredicate::slt) |
| .Default(llvm::None); |
| } |
| |
| template <typename CompareOpTy> |
| inline Value MapCompareOpToStdScalarOp(Location loc, |
| StringRef comparison_direction, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, OpBuilder* b) { |
| const auto& lhs = args[0]; |
| const auto& rhs = args[1]; |
| Type element_type = getElementTypeOrSelf(lhs.getType()); |
| if (element_type.isSignlessInteger()) { |
| Optional<CmpIPredicate> predicate = |
| getCmpPredicate<CmpIPredicate>(comparison_direction); |
| assert(predicate.hasValue() && "expected valid comparison direction"); |
| return b->create<ScalarIOp<CompareOpTy>>(loc, predicate.getValue(), lhs, |
| rhs); |
| } |
| if (element_type.isa<FloatType>()) { |
| Optional<CmpFPredicate> predicate = |
| getCmpPredicate<CmpFPredicate>(comparison_direction); |
| assert(predicate.hasValue() && "expected valid comparison direction"); |
| return b->create<ScalarFOp<CompareOpTy>>(loc, predicate.getValue(), lhs, |
| rhs); |
| } |
| return nullptr; |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::CopyOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return args.front(); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}( |
| loc, result_types, args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}( |
| loc, result_types, args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::ComplexOp>( |
| Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<CreateComplexOp>{}(loc, result_types, args, |
| b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::RealOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<ReOp>{}(loc, result_types, args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::ImagOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<ImOp>{}(loc, result_types, args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>( |
| Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, |
| OpBuilder* b) { |
| Type sourceType = getElementTypeOrSelf(args.front().getType()); |
| Type targetType = getElementTypeOrSelf(result_types.front()); |
| |
| if (mlir::SIToFPOp::areCastCompatible(sourceType, targetType)) { |
| return b->create<mlir::SIToFPOp>(loc, result_types, args, mlir::None); |
| } else if (sourceType.isa<FloatType>() && targetType.isa<FloatType>()) { |
| FloatType src = sourceType.cast<FloatType>(); |
| FloatType res = targetType.cast<FloatType>(); |
| if (src.getWidth() > res.getWidth()) { |
| return b->create<mlir::FPTruncOp>(loc, result_types, args, mlir::None); |
| } else if (src.getWidth() < res.getWidth()) { |
| return b->create<mlir::FPExtOp>(loc, result_types, args, mlir::None); |
| } |
| // No conversion is needed for the same width floats |
| return args.front(); |
| } |
| if (sourceType.isSignlessInteger() && targetType.isSignlessInteger()) { |
| IntegerType src = sourceType.cast<IntegerType>(); |
| IntegerType res = targetType.cast<IntegerType>(); |
| if (src.getWidth() > res.getWidth()) { |
| return b->create<mlir::TruncateIOp>(loc, result_types, args, mlir::None); |
| } else if (src.getWidth() < res.getWidth()) { |
| return b->create<mlir::ZeroExtendIOp>(loc, result_types, args, |
| mlir::None); |
| } |
| // No conversion is needed for the same width integers |
| return args.front(); |
| } |
| if (mlir::FPToSIOp::areCastCompatible(sourceType, targetType)) { |
| return b->create<mlir::FPToSIOp>(loc, result_types, args, mlir::None); |
| } |
| return nullptr; |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::DotOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| // Dot Op converter from lhlo to affine only accepts float and integer types. |
| const auto& lhs = args[0]; |
| const auto& rhs = args[1]; |
| const auto& result = args[2]; |
| Type element_type = lhs.getType(); |
| if (element_type.isa<FloatType>()) { |
| Value float_mul = MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::MulFOp>{}( |
| loc, result_types, {lhs, rhs}, b); |
| return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AddFOp>{}( |
| loc, result_types, {float_mul, result}, b); |
| } |
| if (element_type.isa<IntegerType>()) { |
| Value int_mul = MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::MulIOp>{}( |
| loc, result_types, {lhs, rhs}, b); |
| return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AddIOp>{}( |
| loc, result_types, {int_mul, result}, b); |
| } |
| return nullptr; |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::CosOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}( |
| loc, result_types, args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SinOp>{}( |
| loc, result_types, args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::FloorOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::FloorFOp>{}( |
| loc, result_types, args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::IsFiniteOp>( |
| Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, |
| OpBuilder* b) { |
| if (args[0].getType().isa<FloatType>()) { |
| auto pos_inf = APFloat::getInf( |
| args[0].getType().cast<FloatType>().getFloatSemantics()); |
| auto const_pos_inf = |
| b->create<ConstantOp>(loc, b->getFloatAttr(args[0].getType(), pos_inf)); |
| Value abs_x = b->create<::mlir::AbsFOp>(loc, args[0]); |
| return b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, abs_x, |
| const_pos_inf); |
| } |
| return nullptr; |
| } |
| |
| /// Implements the conversion of HLO op to scalar op (to use within region of a |
| /// linalg.generic op) for compare-select style operations like min/max. |
| template <typename... Args> |
| struct CompareSelectOpToStdScalarOp { |
| static Value map(Location loc, StringRef comparison_direction, |
| ArrayRef<Type> result_types, ArrayRef<Value> args, |
| OpBuilder* b) { |
| return nullptr; |
| } |
| }; |
| |
| /// Specialization which allows converting to a comparison operation in standard |
| /// dialect with a given predicate based on the element type of the operand. |
| template <typename SupportedType, typename StdCompareOp, typename Predicate, |
| typename... Args> |
| struct CompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate, |
| Args...> { |
| static Value map(Location loc, StringRef comparison_direction, |
| ArrayRef<Type> result_types, ArrayRef<Value> args, |
| OpBuilder* b) { |
| Type element_type = getElementTypeOrSelf(args.front().getType()); |
| if (element_type.isa<SupportedType>()) { |
| auto predicate = getCmpPredicate<Predicate>(comparison_direction); |
| assert(predicate.hasValue() && "expected valid comparison direction"); |
| auto cmp = b->template create<StdCompareOp>(loc, predicate.getValue(), |
| args[0], args[1]); |
| return b->create<::mlir::SelectOp>(loc, cmp, args[0], args[1]); |
| } |
| return CompareSelectOpToStdScalarOp<Args...>::map(loc, comparison_direction, |
| result_types, args, b); |
| } |
| }; |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}( |
| loc, result_types, args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::Log1pOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| auto ty = result_types.front().cast<FloatType>(); |
| Value x = args.front(); |
| Value one = b->create<ConstantOp>(loc, b->getFloatAttr(ty, 1.0)); |
| Value x_plus_one = b->create<AddFOp>(loc, x, one); |
| return b->create<::mlir::LogOp>(loc, x_plus_one); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::MaxOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return CompareSelectOpToStdScalarOp< |
| IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType, |
| ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT", result_types, |
| args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return CompareSelectOpToStdScalarOp< |
| IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType, |
| ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT", result_types, |
| args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::ClampOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| assert(args.size() == 3 && "expected 3 arguments"); |
| Value lb = args[0]; |
| Value x = args[1]; |
| Value ub = args[2]; |
| |
| // clamp(lb, x, ub) = max(min(x, ub), lb) |
| Value min_x_ub = |
| MapLhloOpToStdScalarOp<lmhlo::MinOp>(loc, result_types, {x, ub}, b); |
| return MapLhloOpToStdScalarOp<lmhlo::MaxOp>(loc, result_types, {min_x_ub, lb}, |
| b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| Type element_type = getElementTypeOrSelf(args.front().getType()); |
| if (element_type.isa<FloatType>()) { |
| return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}( |
| loc, result_types, args, b); |
| } |
| if (element_type.isa<IntegerType>()) { |
| // lmhlo.neg(x, result) -> result = sub(0, x) |
| Value lhs = args[0]; |
| auto integer_type = element_type.dyn_cast<IntegerType>(); |
| |
| Value zero_intval = |
| b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); |
| if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { |
| zero_intval = b->create<::mlir::SplatOp>(loc, vec_type, zero_intval); |
| } |
| return b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs); |
| } |
| return nullptr; |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::NotOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| Type element_type = getElementTypeOrSelf(args.front().getType()); |
| if (auto integer_type = element_type.dyn_cast<IntegerType>()) { |
| // lmhlo.not(x) -> x ^ -1 |
| Value all_ones = |
| b->create<::mlir::ConstantIntOp>(loc, -1, integer_type.getWidth()); |
| if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { |
| all_ones = b->create<::mlir::SplatOp>(loc, vec_type, all_ones); |
| } |
| return b->create<::mlir::XOrOp>(loc, all_ones, args[0]); |
| } |
| return nullptr; |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::OrOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::OrOp>{}( |
| loc, result_types, args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::RsqrtOp>{}( |
| loc, result_types, args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::PowOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| lmhlo::PowOp::Adaptor adaptor(args); |
| // Floating point can use std::powf |
| auto result_type = result_types.front(); |
| if (result_type.isa<::mlir::FloatType>()) |
| return MapLhloOpToStdScalarOpImpl<::mlir::PowFOp>{}(loc, result_types, args, |
| b); |
| |
| assert(result_type.isa<::mlir::IntegerType>() && |
| "only float and integer `pow` is supported right now"); |
| |
| // There is no powi, so lower to a simple product. |
| Value neg_one = |
| b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, -1)); |
| Value zero = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 0)); |
| Value one = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 1)); |
| Value two = b->create<ConstantOp>(loc, b->getIntegerAttr(result_type, 2)); |
| |
| Value lowerBound = b->create<ConstantIndexOp>(loc, 0); |
| Value upperBound = |
| b->create<IndexCastOp>(loc, adaptor.rhs(), b->getIndexType()); |
| Value step = b->create<ConstantIndexOp>(loc, 1); |
| Value for_result = |
| b->create<scf::ForOp>( |
| loc, lowerBound, upperBound, step, llvm::makeArrayRef(one), |
| [&](OpBuilder& b, Location l, Value v, ValueRange iters) { |
| Value prod = |
| b.create<::mlir::MulIOp>(l, adaptor.lhs(), iters.front()); |
| b.create<scf::YieldOp>(l, prod); |
| }) |
| .getResult(0); |
| |
| Value rhs_is_even = |
| b->create<CmpIOp>(loc, CmpIPredicate::eq, |
| b->create<SignedRemIOp>(loc, adaptor.rhs(), two), zero); |
| Value rhs_is_negative = |
| b->create<CmpIOp>(loc, CmpIPredicate::slt, adaptor.rhs(), zero); |
| Value lhs_is_one = |
| b->create<CmpIOp>(loc, CmpIPredicate::eq, adaptor.lhs(), one); |
| Value lhs_is_neg_one = |
| b->create<CmpIOp>(loc, CmpIPredicate::eq, adaptor.lhs(), neg_one); |
| |
| // The for_result is correct when the rhs is non-negative. When rhs is |
| // negative, we return 0 for integer, with the exception of lhs values of 1 |
| // and -1 which have integer results for negative exponents. Specifically, the |
| // calulation is the following: |
| // |
| // - Return for_result if the rhs is not negative. |
| // - Return 1 or -1 depending on the parity of rhs when the lhs is -1. |
| // - Return 1 if lhs is 1. |
| // - Else return 0. |
| Value if_lhs_is_one = b->create<::mlir::SelectOp>(loc, lhs_is_one, one, zero); |
| Value if_lhs_is_neg_one = b->create<::mlir::SelectOp>( |
| loc, lhs_is_neg_one, |
| b->create<::mlir::SelectOp>(loc, rhs_is_even, one, neg_one), |
| if_lhs_is_one); |
| return b->create<::mlir::SelectOp>(loc, rhs_is_negative, if_lhs_is_neg_one, |
| for_result); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>( |
| Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args, |
| b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftLeftOp>( |
| Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::ShiftLeftOp>{}( |
| loc, result_types, args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightArithmeticOp>( |
| Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::SignedShiftRightOp>{}( |
| loc, result_types, args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::ShiftRightLogicalOp>( |
| Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<IntegerType, mlir::UnsignedShiftRightOp>{}( |
| loc, result_types, args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| Type element_type = getElementTypeOrSelf(args.front().getType()); |
| if (auto float_type = element_type.dyn_cast<FloatType>()) { |
| bool ignored; |
| APFloat zero_apfloat(0.0f); |
| zero_apfloat.convert(float_type.getFloatSemantics(), |
| APFloat::rmNearestTiesToEven, &ignored); |
| Value zero = |
| b->create<mlir::ConstantFloatOp>(loc, zero_apfloat, float_type); |
| if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { |
| zero = b->create<::mlir::SplatOp>(loc, vec_type, zero); |
| } |
| Value ne0_i1 = |
| b->create<::mlir::CmpFOp>(loc, CmpFPredicate::ONE, args[0], zero); |
| Value ne0_float = b->create<::mlir::UIToFPOp>(loc, ne0_i1, zero.getType()); |
| Value copy_sign = |
| b->create<::mlir::CopySignOp>(loc, result_types, ne0_float, args[0]); |
| auto is_nan = |
| b->create<::mlir::CmpFOp>(loc, CmpFPredicate::UNO, args[0], args[0]); |
| return b->create<::mlir::SelectOp>(loc, is_nan, args[0], copy_sign); |
| } else if (auto integer_type = element_type.dyn_cast<IntegerType>()) { |
| // sign(x) = x == 0 ? 0 : ((x s>> 31) | 1) |
| Value zero = |
| b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); |
| Value bitwidth_minus_one = b->create<::mlir::ConstantIntOp>( |
| loc, integer_type.getWidth() - 1, integer_type.getWidth()); |
| Value one = |
| b->create<::mlir::ConstantIntOp>(loc, 1, integer_type.getWidth()); |
| if (VectorType vec_type = args.front().getType().dyn_cast<VectorType>()) { |
| zero = b->create<::mlir::SplatOp>(loc, vec_type, zero); |
| bitwidth_minus_one = |
| b->create<::mlir::SplatOp>(loc, vec_type, bitwidth_minus_one); |
| one = b->create<::mlir::SplatOp>(loc, vec_type, one); |
| } |
| Value cmp = |
| b->create<::mlir::CmpIOp>(loc, CmpIPredicate::eq, args[0], zero); |
| Value ashr = |
| b->create<::mlir::SignedShiftRightOp>(loc, args[0], bitwidth_minus_one); |
| Value or_op = b->create<::mlir::OrOp>(loc, ashr, one); |
| return b->create<::mlir::SelectOp>(loc, cmp, zero, or_op); |
| } |
| return nullptr; |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SqrtOp>{}( |
| loc, result_types, args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::SubOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<IntegerType, ScalarIOp<lmhlo::SubOp>, |
| FloatType, ScalarFOp<lmhlo::SubOp>, |
| ComplexType, ScalarCOp<lmhlo::SubOp>>{}( |
| loc, result_types, args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}( |
| loc, result_types, args, b); |
| } |
| |
| template <> |
| inline Value MapLhloOpToStdScalarOp<lmhlo::XorOp>(Location loc, |
| ArrayRef<Type> result_types, |
| ArrayRef<Value> args, |
| OpBuilder* b) { |
| return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::XOrOp>{}( |
| loc, result_types, args, b); |
| } |
| |
| } // namespace impl |
| |
| struct HloOpToStdScalarOp { |
| // Implementation for LHLO ops except lmhlo::CompareOp. |
| template <typename HloOpTy, typename LhloOpTy = HloOpTy, |
| typename = std::enable_if_t< |
| !std::is_same<LhloOpTy, lmhlo::CompareOp>::value && |
| std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>, |
| std::false_type>::value>> |
| static Value map(HloOpTy op, ArrayRef<Type> result_types, |
| ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) { |
| return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types, |
| args, b); |
| } |
| |
| // Implementation for HLO ops except mhlo::CompareOp. |
| template <typename HloOpTy, typename LhloOpTy = mhlo::HloToLhloOp<HloOpTy>, |
| typename = std::enable_if_t< |
| !std::is_same<LhloOpTy, lmhlo::CompareOp>::value && |
| !std::is_same<LhloOpTy, std::false_type>::value>> |
| static Value map(HloOpTy op, ArrayRef<Type> result_types, |
| ArrayRef<Value> args, OpBuilder* b, int i = 0) { |
| return impl::MapLhloOpToStdScalarOp<LhloOpTy>(op.getLoc(), result_types, |
| args, b); |
| } |
| |
| // Implementation for lmhlo::CompareOp. |
| template <typename LhloOpTy, typename = std::enable_if_t<std::is_same< |
| LhloOpTy, lmhlo::CompareOp>::value>> |
| static Value map(lmhlo::CompareOp op, ArrayRef<Type> result_types, |
| ArrayRef<Value> args, OpBuilder* b) { |
| auto comparison_direction = op.comparison_direction(); |
| return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>( |
| op.getLoc(), comparison_direction, result_types, args, b); |
| } |
| |
| // Implementation for mhlo::CompareOp. |
| template <typename HloOpTy, |
| typename = |
| std::enable_if_t<std::is_same<HloOpTy, mhlo::CompareOp>::value>> |
| static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types, |
| ArrayRef<Value> args, OpBuilder* b) { |
| auto comparison_direction = op.comparison_direction(); |
| return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>( |
| op.getLoc(), comparison_direction, result_types, args, b); |
| } |
| |
| // Implementation for LHLO ops except lmhlo::CompareOp. |
| template <typename LhloOpTy, |
| typename = std::enable_if_t< |
| !std::is_same<LhloOpTy, lmhlo::CompareOp>::value && |
| std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>, |
| std::false_type>::value>> |
| static Value map(Location loc, ArrayRef<Type> result_types, |
| ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) { |
| return impl::MapLhloOpToStdScalarOp<LhloOpTy>(loc, result_types, args, b); |
| } |
| |
| // Implementation for lmhlo::CompareOp. |
| template <typename LhloOpTy, typename = std::enable_if_t<std::is_same< |
| LhloOpTy, lmhlo::CompareOp>::value>> |
| static Value map(Location loc, StringRef comparison_direction, |
| ArrayRef<Type> result_types, ArrayRef<Value> args, |
| OpBuilder* b) { |
| return impl::MapCompareOpToStdScalarOp<lmhlo::CompareOp>( |
| loc, comparison_direction, result_types, args, b); |
| } |
| }; |
| |
| } // namespace lmhlo |
| } // namespace mlir |
| |
| #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_LMHLO_TO_SCALAR_OP_H_ |