[MLIR] Add LMHLO support for getting HloOpcode from MLIR
PiperOrigin-RevId: 345180174
Change-Id: I6eea22d23b5544a344ef34fbff182c9e1781c019
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index 5dc89ad..ab7e215 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -193,6 +193,7 @@
deps = [
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/hlo:convert_op_folder",
+ "//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc
index 7660f15..51e00bc 100644
--- a/tensorflow/compiler/mlir/xla/hlo_utils.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc
@@ -21,6 +21,7 @@
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/platform/bfloat16.h"
#include "tensorflow/core/platform/logging.h"
@@ -200,4 +201,219 @@
builder.getContext());
}
+StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op) {
+ using mlir::isa;
+
+ if (isa<mlir::mhlo::ConstOp, mlir::lmhlo::ConstOp>(op)) {
+ return xla::HloOpcode::kConstant;
+ } else if (isa<mlir::mhlo::IotaOp, mlir::lmhlo::IotaOp>(op)) {
+ return xla::HloOpcode::kIota;
+ } else if (isa<mlir::mhlo::ConvertOp, mlir::lmhlo::ConvertOp>(op)) {
+ return xla::HloOpcode::kConvert;
+ } else if (isa<mlir::mhlo::AddOp, mlir::lmhlo::AddOp>(op)) {
+ return xla::HloOpcode::kAdd;
+ } else if (isa<mlir::mhlo::Atan2Op, mlir::lmhlo::Atan2Op>(op)) {
+ return xla::HloOpcode::kAtan2;
+ } else if (isa<mlir::mhlo::DivOp, mlir::lmhlo::DivOp>(op)) {
+ return xla::HloOpcode::kDivide;
+ } else if (isa<mlir::mhlo::MaxOp, mlir::lmhlo::MaxOp>(op)) {
+ return xla::HloOpcode::kMaximum;
+ } else if (isa<mlir::mhlo::MinOp, mlir::lmhlo::MinOp>(op)) {
+ return xla::HloOpcode::kMinimum;
+ } else if (isa<mlir::mhlo::MulOp, mlir::lmhlo::MulOp>(op)) {
+ return xla::HloOpcode::kMultiply;
+ } else if (isa<mlir::mhlo::PowOp, mlir::lmhlo::PowOp>(op)) {
+ return xla::HloOpcode::kPower;
+ } else if (isa<mlir::mhlo::RemOp, mlir::lmhlo::RemOp>(op)) {
+ return xla::HloOpcode::kRemainder;
+ } else if (isa<mlir::mhlo::ShiftLeftOp, mlir::lmhlo::ShiftLeftOp>(op)) {
+ return xla::HloOpcode::kShiftLeft;
+ } else if (isa<mlir::mhlo::ShiftRightArithmeticOp,
+ mlir::lmhlo::ShiftRightArithmeticOp>(op)) {
+ return xla::HloOpcode::kShiftRightArithmetic;
+ } else if (isa<mlir::mhlo::ShiftRightLogicalOp,
+ mlir::lmhlo::ShiftRightLogicalOp>(op)) {
+ return xla::HloOpcode::kShiftRightLogical;
+ } else if (isa<mlir::mhlo::SubOp, mlir::lmhlo::SubOp>(op)) {
+ return xla::HloOpcode::kSubtract;
+ } else if (isa<mlir::mhlo::XorOp, mlir::lmhlo::XorOp>(op)) {
+ return xla::HloOpcode::kXor;
+ } else if (isa<mlir::mhlo::InfeedOp, mlir::lmhlo::Infeed>(op)) {
+ return xla::HloOpcode::kInfeed;
+ } else if (isa<mlir::mhlo::OutfeedOp, mlir::lmhlo::Outfeed>(op)) {
+ return xla::HloOpcode::kOutfeed;
+ } else if (isa<mlir::mhlo::SendOp>(op)) {
+ return xla::HloOpcode::kSend;
+ } else if (isa<mlir::mhlo::RecvOp>(op)) {
+ return xla::HloOpcode::kRecv;
+ } else if (isa<mlir::mhlo::ReplicaIdOp, mlir::lmhlo::ReplicaIdOp>(op)) {
+ return xla::HloOpcode::kReplicaId;
+ } else if (isa<mlir::mhlo::AfterAllOp>(op)) {
+ return xla::HloOpcode::kAfterAll;
+ } else if (isa<mlir::mhlo::AllReduceOp, mlir::lmhlo::AllReduceOp>(op)) {
+ return xla::HloOpcode::kAllReduce;
+ } else if (isa<mlir::mhlo::AllToAllOp>(op)) {
+ return xla::HloOpcode::kAllToAll;
+ } else if (isa<mlir::mhlo::TupleOp>(op)) {
+ return xla::HloOpcode::kTuple;
+ } else if (isa<mlir::mhlo::BatchNormGradOp, mlir::lmhlo::BatchNormGradOp>(
+ op)) {
+ return xla::HloOpcode::kBatchNormGrad;
+ } else if (isa<mlir::mhlo::BatchNormInferenceOp,
+ mlir::lmhlo::BatchNormInferenceOp>(op)) {
+ return xla::HloOpcode::kBatchNormInference;
+ } else if (isa<mlir::mhlo::BatchNormTrainingOp,
+ mlir::lmhlo::BatchNormTrainingOp>(op)) {
+ return xla::HloOpcode::kBatchNormTraining;
+ } else if (isa<mlir::mhlo::BitcastConvertOp, mlir::lmhlo::BitcastConvertOp>(
+ op)) {
+ return xla::HloOpcode::kBitcastConvert;
+ } else if (isa<mlir::mhlo::BroadcastOp, mlir::lmhlo::BroadcastOp>(op)) {
+ return xla::HloOpcode::kBroadcast;
+ } else if (isa<mlir::mhlo::CholeskyOp, mlir::lmhlo::CholeskyOp>(op)) {
+ return xla::HloOpcode::kCholesky;
+ } else if (isa<mlir::mhlo::ClampOp, mlir::lmhlo::ClampOp>(op)) {
+ return xla::HloOpcode::kClamp;
+ } else if (isa<mlir::mhlo::ConcatenateOp, mlir::lmhlo::ConcatenateOp>(op)) {
+ return xla::HloOpcode::kConcatenate;
+ } else if (isa<mlir::mhlo::ConvOp, mlir::lmhlo::ConvOp>(op)) {
+ return xla::HloOpcode::kConvolution;
+ } else if (isa<mlir::mhlo::SortOp, mlir::lmhlo::SortOp>(op)) {
+ return xla::HloOpcode::kSort;
+ } else if (isa<mlir::mhlo::RngBitGeneratorOp>(op)) {
+ return xla::HloOpcode::kRngBitGenerator;
+ } else if (isa<mlir::mhlo::FusionOp, mlir::lmhlo::FusionOp>(op)) {
+ return xla::HloOpcode::kFusion;
+ } else if (isa<mlir::mhlo::BitcastOp, mlir::lmhlo::BitcastOp>(op)) {
+ return xla::HloOpcode::kBitcast;
+ } else if (isa<mlir::mhlo::AbsOp, mlir::lmhlo::AbsOp>(op)) {
+ return xla::HloOpcode::kAbs;
+ } else if (isa<mlir::mhlo::CbrtOp>(op)) {
+ return xla::HloOpcode::kCbrt;
+ } else if (isa<mlir::mhlo::CeilOp, mlir::lmhlo::CeilOp>(op)) {
+ return xla::HloOpcode::kCeil;
+ } else if (isa<mlir::mhlo::ClzOp, mlir::lmhlo::ClzOp>(op)) {
+ return xla::HloOpcode::kClz;
+ } else if (isa<mlir::mhlo::CosOp, mlir::lmhlo::CosOp>(op)) {
+ return xla::HloOpcode::kCos;
+ } else if (isa<mlir::mhlo::ExpOp, mlir::lmhlo::ExpOp>(op)) {
+ return xla::HloOpcode::kExp;
+ } else if (isa<mlir::mhlo::Expm1Op, mlir::lmhlo::Expm1Op>(op)) {
+ return xla::HloOpcode::kExpm1;
+ } else if (isa<mlir::mhlo::FloorOp, mlir::lmhlo::FloorOp>(op)) {
+ return xla::HloOpcode::kFloor;
+ } else if (isa<mlir::mhlo::ImagOp, mlir::lmhlo::ImagOp>(op)) {
+ return xla::HloOpcode::kImag;
+ } else if (isa<mlir::mhlo::IsFiniteOp, mlir::lmhlo::IsFiniteOp>(op)) {
+ return xla::HloOpcode::kIsFinite;
+ } else if (isa<mlir::mhlo::LogOp, mlir::lmhlo::LogOp>(op)) {
+ return xla::HloOpcode::kLog;
+ } else if (isa<mlir::mhlo::Log1pOp, mlir::lmhlo::Log1pOp>(op)) {
+ return xla::HloOpcode::kLog1p;
+ } else if (isa<mlir::mhlo::LogisticOp>(op)) {
+ return xla::HloOpcode::kLogistic;
+ } else if (isa<mlir::mhlo::NotOp, mlir::lmhlo::NotOp>(op)) {
+ return xla::HloOpcode::kNot;
+ } else if (isa<mlir::mhlo::NegOp, mlir::lmhlo::NegOp>(op)) {
+ return xla::HloOpcode::kNegate;
+ } else if (isa<mlir::mhlo::PopulationCountOp, mlir::lmhlo::PopulationCountOp>(
+ op)) {
+ return xla::HloOpcode::kPopulationCount;
+ } else if (isa<mlir::mhlo::RealOp, mlir::lmhlo::RealOp>(op)) {
+ return xla::HloOpcode::kReal;
+ } else if (isa<mlir::mhlo::RoundOp, mlir::lmhlo::RoundOp>(op)) {
+ return xla::HloOpcode::kRoundNearestAfz;
+ } else if (isa<mlir::mhlo::RsqrtOp, mlir::lmhlo::RsqrtOp>(op)) {
+ return xla::HloOpcode::kRsqrt;
+ } else if (isa<mlir::mhlo::SignOp, mlir::lmhlo::SignOp>(op)) {
+ return xla::HloOpcode::kSign;
+ } else if (isa<mlir::mhlo::SinOp, mlir::lmhlo::SinOp>(op)) {
+ return xla::HloOpcode::kSin;
+ } else if (isa<mlir::mhlo::SqrtOp, mlir::lmhlo::SqrtOp>(op)) {
+ return xla::HloOpcode::kSqrt;
+ } else if (isa<mlir::mhlo::TanhOp, mlir::lmhlo::TanhOp>(op)) {
+ return xla::HloOpcode::kTanh;
+ } else if (isa<mlir::mhlo::ComplexOp, mlir::lmhlo::ComplexOp>(op)) {
+ return xla::HloOpcode::kComplex;
+ } else if (isa<mlir::mhlo::AndOp, mlir::lmhlo::AndOp>(op)) {
+ return xla::HloOpcode::kAnd;
+ } else if (isa<mlir::mhlo::OrOp, mlir::lmhlo::OrOp>(op)) {
+ return xla::HloOpcode::kOr;
+ } else if (isa<mlir::mhlo::WhileOp, mlir::lmhlo::WhileOp>(op)) {
+ return xla::HloOpcode::kWhile;
+ } else if (isa<mlir::mhlo::ReduceOp, mlir::lmhlo::ReduceOp>(op)) {
+ return xla::HloOpcode::kReduce;
+ } else if (isa<mlir::mhlo::GetTupleElementOp>(op)) {
+ return xla::HloOpcode::kGetTupleElement;
+ } else if (isa<mlir::mhlo::CompareOp, mlir::lmhlo::CompareOp>(op)) {
+ return xla::HloOpcode::kCompare;
+ } else if (isa<mlir::mhlo::SliceOp, mlir::lmhlo::SliceOp>(op)) {
+ return xla::HloOpcode::kSlice;
+ } else if (isa<mlir::mhlo::DynamicSliceOp>(op)) {
+ return xla::HloOpcode::kDynamicSlice;
+ } else if (isa<mlir::mhlo::DynamicUpdateSliceOp,
+ mlir::lmhlo::DynamicUpdateSliceOp>(op)) {
+ return xla::HloOpcode::kDynamicUpdateSlice;
+ } else if (isa<mlir::mhlo::CollectivePermuteOp,
+ mlir::lmhlo::CollectivePermuteOp>(op)) {
+ return xla::HloOpcode::kCollectivePermute;
+ } else if (isa<mlir::mhlo::CopyOp, mlir::lmhlo::CopyOp>(op)) {
+ return xla::HloOpcode::kCopy;
+ } else if (isa<mlir::mhlo::CustomCallOp, mlir::lmhlo::CustomCallOp>(op)) {
+ return xla::HloOpcode::kCustomCall;
+ } else if (isa<mlir::mhlo::DotOp, mlir::lmhlo::DotOp>(op)) {
+ return xla::HloOpcode::kDot;
+ } else if (isa<mlir::mhlo::FftOp, mlir::lmhlo::FftOp>(op)) {
+ return xla::HloOpcode::kFft;
+ } else if (isa<mlir::mhlo::GatherOp, mlir::lmhlo::GatherOp>(op)) {
+ return xla::HloOpcode::kGather;
+ } else if (isa<mlir::mhlo::GetDimensionSizeOp>(op)) {
+ return xla::HloOpcode::kGetDimensionSize;
+ } else if (isa<mlir::mhlo::MapOp, mlir::lmhlo::MapOp>(op)) {
+ return xla::HloOpcode::kMap;
+ } else if (isa<mlir::mhlo::ReshapeOp, mlir::lmhlo::ReshapeOp>(op)) {
+ return xla::HloOpcode::kReshape;
+ } else if (isa<mlir::mhlo::DynamicReshapeOp>(op)) {
+ return xla::HloOpcode::kDynamicReshape;
+ } else if (isa<mlir::mhlo::ScatterOp, mlir::lmhlo::ScatterOp>(op)) {
+ return xla::HloOpcode::kScatter;
+ } else if (isa<mlir::mhlo::SelectOp, mlir::lmhlo::SelectOp>(op)) {
+ return xla::HloOpcode::kSelect;
+ } else if (isa<mlir::mhlo::SelectAndScatterOp,
+ mlir::lmhlo::SelectAndScatterOp>(op)) {
+ return xla::HloOpcode::kSelectAndScatter;
+ } else if (isa<mlir::mhlo::SetDimensionSizeOp>(op)) {
+ return xla::HloOpcode::kSetDimensionSize;
+ } else if (isa<mlir::mhlo::ReverseOp, mlir::lmhlo::ReverseOp>(op)) {
+ return xla::HloOpcode::kReverse;
+ } else if (isa<mlir::mhlo::PadOp, mlir::lmhlo::PadOp>(op)) {
+ return xla::HloOpcode::kPad;
+ } else if (isa<mlir::mhlo::TraceOp>(op)) {
+ return xla::HloOpcode::kTrace;
+ } else if (isa<mlir::mhlo::TransposeOp, mlir::lmhlo::TransposeOp>(op)) {
+ return xla::HloOpcode::kTranspose;
+ } else if (isa<mlir::mhlo::TriangularSolveOp, mlir::lmhlo::TriangularSolveOp>(
+ op)) {
+ return xla::HloOpcode::kTriangularSolve;
+ } else if (isa<mlir::mhlo::ReduceWindowOp, mlir::lmhlo::ReduceWindowOp>(op)) {
+ return xla::HloOpcode::kReduceWindow;
+ } else if (isa<mlir::mhlo::ReducePrecisionOp, mlir::lmhlo::ReducePrecisionOp>(
+ op)) {
+ return xla::HloOpcode::kReducePrecision;
+ } else if (isa<mlir::mhlo::DotGeneralOp>(op)) {
+ return xla::HloOpcode::kDot;
+ } else if (isa<mlir::mhlo::BroadcastInDimOp, mlir::lmhlo::BroadcastInDimOp>(
+ op)) {
+ return xla::HloOpcode::kBroadcast;
+ } else {
+ std::string s;
+ {
+ llvm::raw_string_ostream os(s);
+ op->print(os);
+ }
+ return tensorflow::errors::Unimplemented(
+ "Unimplemented MHLO -> HloOpcode: ", s);
+ }
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.h b/tensorflow/compiler/mlir/xla/hlo_utils.h
index 1b77d60..3ad39ae 100644
--- a/tensorflow/compiler/mlir/xla/hlo_utils.h
+++ b/tensorflow/compiler/mlir/xla/hlo_utils.h
@@ -82,6 +82,8 @@
return ConvertTensorShapeToType<TypeT>(shape, builder);
}
+::xla::StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_UTILS_H_
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index 1a17d61..5c7a592 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -1733,204 +1733,4 @@
return op->getAttrOfType<mlir::DenseIntElementsAttr>("minor_to_major");
}
-StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op) {
- if (mlir::isa<mlir::mhlo::ConstOp>(op)) {
- return xla::HloOpcode::kConstant;
- } else if (mlir::isa<mlir::mhlo::IotaOp>(op)) {
- return xla::HloOpcode::kIota;
- } else if (mlir::isa<mlir::mhlo::ConvertOp>(op)) {
- return xla::HloOpcode::kConvert;
- } else if (mlir::isa<mlir::mhlo::AddOp>(op)) {
- return xla::HloOpcode::kAdd;
- } else if (mlir::isa<mlir::mhlo::Atan2Op>(op)) {
- return xla::HloOpcode::kAtan2;
- } else if (mlir::isa<mlir::mhlo::DivOp>(op)) {
- return xla::HloOpcode::kDivide;
- } else if (mlir::isa<mlir::mhlo::MaxOp>(op)) {
- return xla::HloOpcode::kMaximum;
- } else if (mlir::isa<mlir::mhlo::MinOp>(op)) {
- return xla::HloOpcode::kMinimum;
- } else if (mlir::isa<mlir::mhlo::MulOp>(op)) {
- return xla::HloOpcode::kMultiply;
- } else if (mlir::isa<mlir::mhlo::PowOp>(op)) {
- return xla::HloOpcode::kPower;
- } else if (mlir::isa<mlir::mhlo::RemOp>(op)) {
- return xla::HloOpcode::kRemainder;
- } else if (mlir::isa<mlir::mhlo::ShiftLeftOp>(op)) {
- return xla::HloOpcode::kShiftLeft;
- } else if (mlir::isa<mlir::mhlo::ShiftRightArithmeticOp>(op)) {
- return xla::HloOpcode::kShiftRightArithmetic;
- } else if (mlir::isa<mlir::mhlo::ShiftRightLogicalOp>(op)) {
- return xla::HloOpcode::kShiftRightLogical;
- } else if (mlir::isa<mlir::mhlo::SubOp>(op)) {
- return xla::HloOpcode::kSubtract;
- } else if (mlir::isa<mlir::mhlo::XorOp>(op)) {
- return xla::HloOpcode::kXor;
- } else if (mlir::isa<mlir::mhlo::InfeedOp>(op)) {
- return xla::HloOpcode::kInfeed;
- } else if (mlir::isa<mlir::mhlo::OutfeedOp>(op)) {
- return xla::HloOpcode::kOutfeed;
- } else if (mlir::isa<mlir::mhlo::SendOp>(op)) {
- return xla::HloOpcode::kSend;
- } else if (mlir::isa<mlir::mhlo::RecvOp>(op)) {
- return xla::HloOpcode::kRecv;
- } else if (mlir::isa<mlir::mhlo::ReplicaIdOp>(op)) {
- return xla::HloOpcode::kReplicaId;
- } else if (mlir::isa<mlir::mhlo::AfterAllOp>(op)) {
- return xla::HloOpcode::kAfterAll;
- } else if (mlir::isa<mlir::mhlo::AllReduceOp>(op)) {
- return xla::HloOpcode::kAllReduce;
- } else if (mlir::isa<mlir::mhlo::AllToAllOp>(op)) {
- return xla::HloOpcode::kAllToAll;
- } else if (mlir::isa<mlir::mhlo::TupleOp>(op)) {
- return xla::HloOpcode::kTuple;
- } else if (mlir::isa<mlir::mhlo::BatchNormGradOp>(op)) {
- return xla::HloOpcode::kBatchNormGrad;
- } else if (mlir::isa<mlir::mhlo::BatchNormInferenceOp>(op)) {
- return xla::HloOpcode::kBatchNormInference;
- } else if (mlir::isa<mlir::mhlo::BatchNormTrainingOp>(op)) {
- return xla::HloOpcode::kBatchNormTraining;
- } else if (mlir::isa<mlir::mhlo::BitcastConvertOp>(op)) {
- return xla::HloOpcode::kBitcastConvert;
- } else if (mlir::isa<mlir::mhlo::BroadcastOp>(op)) {
- return xla::HloOpcode::kBroadcast;
- } else if (mlir::isa<mlir::mhlo::CholeskyOp>(op)) {
- return xla::HloOpcode::kCholesky;
- } else if (mlir::isa<mlir::mhlo::ClampOp>(op)) {
- return xla::HloOpcode::kClamp;
- } else if (mlir::isa<mlir::mhlo::ConcatenateOp>(op)) {
- return xla::HloOpcode::kConcatenate;
- } else if (mlir::isa<mlir::mhlo::ConvOp>(op)) {
- return xla::HloOpcode::kConvolution;
- } else if (mlir::isa<mlir::mhlo::SortOp>(op)) {
- return xla::HloOpcode::kSort;
- } else if (mlir::isa<mlir::mhlo::RngBitGeneratorOp>(op)) {
- return xla::HloOpcode::kRngBitGenerator;
- } else if (mlir::isa<mlir::mhlo::FusionOp>(op)) {
- return xla::HloOpcode::kFusion;
- } else if (mlir::isa<mlir::mhlo::BitcastOp>(op)) {
- return xla::HloOpcode::kBitcast;
- } else if (mlir::isa<mlir::mhlo::AbsOp>(op)) {
- return xla::HloOpcode::kAbs;
- } else if (mlir::isa<mlir::mhlo::CbrtOp>(op)) {
- return xla::HloOpcode::kCbrt;
- } else if (mlir::isa<mlir::mhlo::CeilOp>(op)) {
- return xla::HloOpcode::kCeil;
- } else if (mlir::isa<mlir::mhlo::ClzOp>(op)) {
- return xla::HloOpcode::kClz;
- } else if (mlir::isa<mlir::mhlo::CosOp>(op)) {
- return xla::HloOpcode::kCos;
- } else if (mlir::isa<mlir::mhlo::ExpOp>(op)) {
- return xla::HloOpcode::kExp;
- } else if (mlir::isa<mlir::mhlo::Expm1Op>(op)) {
- return xla::HloOpcode::kExpm1;
- } else if (mlir::isa<mlir::mhlo::FloorOp>(op)) {
- return xla::HloOpcode::kFloor;
- } else if (mlir::isa<mlir::mhlo::ImagOp>(op)) {
- return xla::HloOpcode::kImag;
- } else if (mlir::isa<mlir::mhlo::IsFiniteOp>(op)) {
- return xla::HloOpcode::kIsFinite;
- } else if (mlir::isa<mlir::mhlo::LogOp>(op)) {
- return xla::HloOpcode::kLog;
- } else if (mlir::isa<mlir::mhlo::Log1pOp>(op)) {
- return xla::HloOpcode::kLog1p;
- } else if (mlir::isa<mlir::mhlo::LogisticOp>(op)) {
- return xla::HloOpcode::kLogistic;
- } else if (mlir::isa<mlir::mhlo::NotOp>(op)) {
- return xla::HloOpcode::kNot;
- } else if (mlir::isa<mlir::mhlo::NegOp>(op)) {
- return xla::HloOpcode::kNegate;
- } else if (mlir::isa<mlir::mhlo::PopulationCountOp>(op)) {
- return xla::HloOpcode::kPopulationCount;
- } else if (mlir::isa<mlir::mhlo::RealOp>(op)) {
- return xla::HloOpcode::kReal;
- } else if (mlir::isa<mlir::mhlo::RoundOp>(op)) {
- return xla::HloOpcode::kRoundNearestAfz;
- } else if (mlir::isa<mlir::mhlo::RsqrtOp>(op)) {
- return xla::HloOpcode::kRsqrt;
- } else if (mlir::isa<mlir::mhlo::SignOp>(op)) {
- return xla::HloOpcode::kSign;
- } else if (mlir::isa<mlir::mhlo::SinOp>(op)) {
- return xla::HloOpcode::kSin;
- } else if (mlir::isa<mlir::mhlo::SqrtOp>(op)) {
- return xla::HloOpcode::kSqrt;
- } else if (mlir::isa<mlir::mhlo::TanhOp>(op)) {
- return xla::HloOpcode::kTanh;
- } else if (mlir::isa<mlir::mhlo::ComplexOp>(op)) {
- return xla::HloOpcode::kComplex;
- } else if (mlir::isa<mlir::mhlo::AndOp>(op)) {
- return xla::HloOpcode::kAnd;
- } else if (mlir::isa<mlir::mhlo::OrOp>(op)) {
- return xla::HloOpcode::kOr;
- } else if (mlir::isa<mlir::mhlo::WhileOp>(op)) {
- return xla::HloOpcode::kWhile;
- } else if (mlir::isa<mlir::mhlo::ReduceOp>(op)) {
- return xla::HloOpcode::kReduce;
- } else if (mlir::isa<mlir::mhlo::GetTupleElementOp>(op)) {
- return xla::HloOpcode::kGetTupleElement;
- } else if (mlir::isa<mlir::mhlo::CompareOp>(op)) {
- return xla::HloOpcode::kCompare;
- } else if (mlir::isa<mlir::mhlo::SliceOp>(op)) {
- return xla::HloOpcode::kSlice;
- } else if (mlir::isa<mlir::mhlo::DynamicSliceOp>(op)) {
- return xla::HloOpcode::kDynamicSlice;
- } else if (mlir::isa<mlir::mhlo::DynamicUpdateSliceOp>(op)) {
- return xla::HloOpcode::kDynamicUpdateSlice;
- } else if (mlir::isa<mlir::mhlo::CollectivePermuteOp>(op)) {
- return xla::HloOpcode::kCollectivePermute;
- } else if (mlir::isa<mlir::mhlo::CopyOp>(op)) {
- return xla::HloOpcode::kCopy;
- } else if (mlir::isa<mlir::mhlo::CustomCallOp>(op)) {
- return xla::HloOpcode::kCustomCall;
- } else if (mlir::isa<mlir::mhlo::DotOp>(op)) {
- return xla::HloOpcode::kDot;
- } else if (mlir::isa<mlir::mhlo::FftOp>(op)) {
- return xla::HloOpcode::kFft;
- } else if (mlir::isa<mlir::mhlo::GatherOp>(op)) {
- return xla::HloOpcode::kGather;
- } else if (mlir::isa<mlir::mhlo::GetDimensionSizeOp>(op)) {
- return xla::HloOpcode::kGetDimensionSize;
- } else if (mlir::isa<mlir::mhlo::MapOp>(op)) {
- return xla::HloOpcode::kMap;
- } else if (mlir::isa<mlir::mhlo::ReshapeOp>(op)) {
- return xla::HloOpcode::kReshape;
- } else if (mlir::isa<mlir::mhlo::DynamicReshapeOp>(op)) {
- return xla::HloOpcode::kDynamicReshape;
- } else if (mlir::isa<mlir::mhlo::ScatterOp>(op)) {
- return xla::HloOpcode::kScatter;
- } else if (mlir::isa<mlir::mhlo::SelectOp>(op)) {
- return xla::HloOpcode::kSelect;
- } else if (mlir::isa<mlir::mhlo::SelectAndScatterOp>(op)) {
- return xla::HloOpcode::kSelectAndScatter;
- } else if (mlir::isa<mlir::mhlo::SetDimensionSizeOp>(op)) {
- return xla::HloOpcode::kSetDimensionSize;
- } else if (mlir::isa<mlir::mhlo::ReverseOp>(op)) {
- return xla::HloOpcode::kReverse;
- } else if (mlir::isa<mlir::mhlo::PadOp>(op)) {
- return xla::HloOpcode::kPad;
- } else if (mlir::isa<mlir::mhlo::TraceOp>(op)) {
- return xla::HloOpcode::kTrace;
- } else if (mlir::isa<mlir::mhlo::TransposeOp>(op)) {
- return xla::HloOpcode::kTranspose;
- } else if (mlir::isa<mlir::mhlo::TriangularSolveOp>(op)) {
- return xla::HloOpcode::kTriangularSolve;
- } else if (mlir::isa<mlir::mhlo::ReduceWindowOp>(op)) {
- return xla::HloOpcode::kReduceWindow;
- } else if (mlir::isa<mlir::mhlo::ReducePrecisionOp>(op)) {
- return xla::HloOpcode::kReducePrecision;
- } else if (mlir::isa<mlir::mhlo::DotGeneralOp>(op)) {
- return xla::HloOpcode::kDot;
- } else if (mlir::isa<mlir::mhlo::BroadcastInDimOp>(op)) {
- return xla::HloOpcode::kBroadcast;
- } else {
- std::string s;
- {
- llvm::raw_string_ostream os(s);
- op->print(os);
- }
- return tensorflow::errors::Unimplemented(
- "Unimplemented MHLO -> HloOpcode: %s", s);
- }
-}
-
} // namespace mlir
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
index e601f41..a260a79 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
@@ -66,8 +66,6 @@
mlir::DenseIntElementsAttr GetLayoutFromMlirHlo(mlir::Operation* op);
-::xla::StatusOr<::xla::HloOpcode> MhloToHloOpcode(mlir::Operation* op);
-
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_MLIR_HLO_TO_HLO_H_
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 1bafe90..e2d57dd 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -4303,7 +4303,7 @@
if (mlir::isa<mlir::TensorLoadOp>(op)) {
opcode = HloOpcode::kParameter;
} else {
- opcode = *mlir::MhloToHloOpcode(op);
+ opcode = *MhloToHloOpcode(op);
}
if (HloInstruction::IsOpElementwise(opcode)) {
for (mlir::Value v : op->getResults()) {