[XLA:GPU][NFC] Use pointers to const for HloInstruction and other types.
PiperOrigin-RevId: 351696339
Change-Id: I2806212d30f31052295638058131864f658cd6cd
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index 2c6e06b..151c9ec 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -206,7 +206,7 @@
// instruction. If `num_operands` is valid, then only the first `num_operands`
// operands of the HLO instruction will be considered.
Status LhloDialectEmitter::CreateOperands(
- HloInstruction* instr, absl::optional<xla::int64> num_operands,
+ const HloInstruction* instr, absl::optional<xla::int64> num_operands,
llvm::SmallVectorImpl<Value>& operands, size_t& num_arguments,
size_t& num_results) {
if (num_operands.value_or(0) > instr->operand_count())
@@ -222,7 +222,7 @@
}
template <typename OpType>
-OpType LhloDialectEmitter::CreateOpWithoutAttrs(HloInstruction* instr,
+OpType LhloDialectEmitter::CreateOpWithoutAttrs(const HloInstruction* instr,
ValueRange operands) {
Location loc = getLocation(instr);
NamedAttribute attrs[] = {{Identifier::get("name", builder_.getContext()),
@@ -232,7 +232,7 @@
template <typename OpType>
StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
- HloInstruction* instr, size_t& num_arguments, size_t& num_results,
+ const HloInstruction* instr, size_t& num_arguments, size_t& num_results,
absl::optional<xla::int64> num_operands) {
llvm::SmallVector<Value, 4> operands;
TF_RETURN_IF_ERROR(CreateOperands(instr, num_operands, operands,
@@ -240,7 +240,8 @@
return CreateOpWithoutAttrs<OpType>(instr, operands);
}
-StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
+StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
+ const HloInstruction* instr) {
using xla::HloOpcode;
switch (instr->opcode()) {
case HloOpcode::kAbs:
@@ -363,11 +364,12 @@
}
}
-Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) {
+Status LhloDialectEmitter::DefaultAction(const HloInstruction* instr) {
return EmitOp(instr).status();
}
-StatusOr<lmhlo::SortOp> LhloDialectEmitter::EmitSortOp(HloInstruction* instr) {
+StatusOr<lmhlo::SortOp> LhloDialectEmitter::EmitSortOp(
+ const HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs<lmhlo::SortOp>(instr));
auto* sort_instr = xla::Cast<xla::HloSortInstruction>(instr);
sort.dimensionAttr(builder_.getI64IntegerAttr(sort_instr->sort_dimension()));
@@ -425,7 +427,7 @@
}
StatusOr<lmhlo::FusionOp> LhloDialectEmitter::EmitFusionOp(
- HloInstruction* instr) {
+ const HloInstruction* instr) {
Location loc = getLocation(instr);
auto* fusion_instr = xla::Cast<xla::HloFusionInstruction>(instr);
@@ -494,7 +496,7 @@
}
StatusOr<mhlo::ScatterDimensionNumbers>
-LhloDialectEmitter::GetScatterDimensionNumbers(HloInstruction* instr) {
+LhloDialectEmitter::GetScatterDimensionNumbers(const HloInstruction* instr) {
auto* scatter_instr = xla::Cast<xla::HloScatterInstruction>(instr);
const xla::ScatterDimensionNumbers& xla_scatter_dim =
@@ -509,7 +511,7 @@
}
StatusOr<lmhlo::ScatterOp> LhloDialectEmitter::EmitScatterOp(
- HloInstruction* instr) {
+ const HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(auto scatter,
CreateOpWithoutAttrs<lmhlo::ScatterOp>(instr));
@@ -533,7 +535,7 @@
}
StatusOr<lmhlo::SelectAndScatterOp> LhloDialectEmitter::EmitSelectAndScatterOp(
- HloInstruction* instr) {
+ const HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(auto select_and_scatter,
CreateOpWithoutAttrs<lmhlo::SelectAndScatterOp>(instr));
@@ -566,7 +568,7 @@
}
StatusOr<mlir::Operation*> LhloDialectEmitter::EmitCustomCallOp(
- HloInstruction* instr) {
+ const HloInstruction* instr) {
auto* custom_call_instr = xla::Cast<xla::HloCustomCallInstruction>(instr);
if (xla::gpu::IsCustomCallToCusolver(*instr)) {
@@ -601,7 +603,7 @@
}
StatusOr<lmhlo_gpu::CholeskyOp> LhloDialectEmitter::EmitCholesky(
- HloCustomCallInstruction* custom_call) {
+ const HloCustomCallInstruction* custom_call) {
TF_ASSIGN_OR_RETURN(auto cholesky_op,
CreateOpWithoutAttrs<lmhlo_gpu::CholeskyOp>(custom_call));
TF_ASSIGN_OR_RETURN(xla::CholeskyOptions options,
@@ -611,7 +613,7 @@
}
StatusOr<Operation*> LhloDialectEmitter::EmitGemm(
- HloCustomCallInstruction* custom_call) {
+ const HloCustomCallInstruction* custom_call) {
TF_ASSIGN_OR_RETURN(
auto const config,
custom_call->backend_config<xla::gpu::GemmBackendConfig>());
@@ -675,7 +677,7 @@
}
StatusOr<Operation*> LhloDialectEmitter::EmitDnnConvolution(
- HloCustomCallInstruction* custom_call) {
+ const HloCustomCallInstruction* custom_call) {
TF_ASSIGN_OR_RETURN(
auto const backend_config,
custom_call->backend_config<xla::gpu::CudnnConvBackendConfig>());
@@ -796,7 +798,7 @@
}
StatusOr<Operation*> LhloDialectEmitter::EmitDnnBatchNorm(
- HloCustomCallInstruction* custom_call) {
+ const HloCustomCallInstruction* custom_call) {
const xla::int64 num_operands = custom_call->operand_count();
auto set_batchnorm_attributes = [&](auto op) -> StatusOr<Operation*> {
// The last 2 operands of a custom call for batch norm are the epsilon and
@@ -895,7 +897,7 @@
}
StatusOr<lmhlo::ReduceOp> LhloDialectEmitter::EmitReduceOp(
- HloInstruction* instr) {
+ const HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(auto reduce_op,
CreateOpWithoutAttrs<lmhlo::ReduceOp>(instr));
auto* reduce = xla::Cast<xla::HloReduceInstruction>(instr);
@@ -907,7 +909,8 @@
return reduce_op;
}
-StatusOr<lmhlo::MapOp> LhloDialectEmitter::EmitMapOp(HloInstruction* instr) {
+StatusOr<lmhlo::MapOp> LhloDialectEmitter::EmitMapOp(
+ const HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(auto map_op, CreateOpWithoutAttrs<lmhlo::MapOp>(instr));
auto* map = xla::Cast<xla::HloMapInstruction>(instr);
std::vector<int64_t> dimensions(map->dimensions().begin(),
@@ -919,7 +922,7 @@
}
StatusOr<lmhlo::CompareOp> LhloDialectEmitter::EmitCompareOp(
- HloInstruction* instr) {
+ const HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(auto compare_op,
CreateOpWithoutAttrs<lmhlo::CompareOp>(instr));
@@ -960,7 +963,7 @@
}
StatusOr<lmhlo::ReducePrecisionOp> LhloDialectEmitter::EmitReducePrecisionOp(
- HloInstruction* instr) {
+ const HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(auto reduce_precision_op,
CreateOpWithoutAttrs<lmhlo::ReducePrecisionOp>(instr));
auto* reduce_precision = xla::Cast<xla::HloReducePrecisionInstruction>(instr);
@@ -972,7 +975,7 @@
}
StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
- HloInstruction* instr) {
+ const HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(auto all_reduce_op,
CreateOpWithoutAttrs<lmhlo::AllReduceOp>(instr));
auto* all_reduce = xla::Cast<xla::HloAllReduceInstruction>(instr);
@@ -993,8 +996,8 @@
}
StatusOr<lmhlo::InfeedOp> LhloDialectEmitter::EmitInfeedOp(
- HloInstruction* instr) {
- HloInfeedInstruction* infeed = xla::Cast<HloInfeedInstruction>(instr);
+ const HloInstruction* instr) {
+ const HloInfeedInstruction* infeed = xla::Cast<HloInfeedInstruction>(instr);
// HLO Infeed instruction has a single operand of token type and a tuple
// with buffers and a token as its output. LMHLO Infeed operation does not
// need the token operand or result, so drop it.
@@ -1006,8 +1009,9 @@
}
StatusOr<lmhlo::OutfeedOp> LhloDialectEmitter::EmitOutfeedOp(
- HloInstruction* instr) {
- HloOutfeedInstruction* outfeed = xla::Cast<HloOutfeedInstruction>(instr);
+ const HloInstruction* instr) {
+ const HloOutfeedInstruction* outfeed =
+ xla::Cast<HloOutfeedInstruction>(instr);
// HLO outfeed instruction has 2 operands, the source and a token, and a
// single token output. LMHLO Outfeed does not need the token operand and
// result, do drop it.
@@ -1223,7 +1227,7 @@
module.getContext()
->loadDialect<StandardOpsDialect, mhlo::MhloDialect, lmhlo::LmhloDialect,
lmhlo_gpu::LmhloGpuDialect>();
- HloComputation* computation = hlo_module.entry_computation();
+ const HloComputation* computation = hlo_module.entry_computation();
LhloDialectEmitter emitter(assignment, *computation, module);
TF_RETURN_IF_ERROR(emitter.Initialize());
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
index 8df68c7..72cd0c7 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
@@ -35,7 +35,7 @@
// This class will process an HloModule with the supplied BufferAssignment and
// populate the MLIR ModuleOp with the computation converted in the LHLO
// dialect.
-class LhloDialectEmitter : public xla::DfsHloVisitorWithDefault {
+class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
public:
// Initializes internal data structures. It must be called before calling any
// of the visitors.
@@ -49,45 +49,46 @@
builder_(module.getContext()),
i8_type_(builder_.getIntegerType(8)) {}
- xla::StatusOr<mlir::Operation*> EmitOp(xla::HloInstruction* instr);
+ xla::StatusOr<mlir::Operation*> EmitOp(const xla::HloInstruction* instr);
xla::StatusOr<mhlo::ScatterDimensionNumbers> GetScatterDimensionNumbers(
- xla::HloInstruction* instr);
+ const xla::HloInstruction* instr);
private:
- xla::StatusOr<lmhlo::SortOp> EmitSortOp(xla::HloInstruction* instr);
- xla::StatusOr<lmhlo::FusionOp> EmitFusionOp(xla::HloInstruction* instr);
- xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(xla::HloInstruction* instr);
+ xla::StatusOr<lmhlo::SortOp> EmitSortOp(const xla::HloInstruction* instr);
+ xla::StatusOr<lmhlo::FusionOp> EmitFusionOp(const xla::HloInstruction* instr);
+ xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(
+ const xla::HloInstruction* instr);
xla::StatusOr<lmhlo::SelectAndScatterOp> EmitSelectAndScatterOp(
- xla::HloInstruction* instr);
+ const xla::HloInstruction* instr);
- xla::StatusOr<Operation*> EmitCustomCallOp(xla::HloInstruction* instr);
+ xla::StatusOr<Operation*> EmitCustomCallOp(const xla::HloInstruction* instr);
xla::StatusOr<lmhlo_gpu::CholeskyOp> EmitCholesky(
- xla::HloCustomCallInstruction* custom_call);
+ const xla::HloCustomCallInstruction* custom_call);
xla::StatusOr<Operation*> EmitGemm(
- xla::HloCustomCallInstruction* custom_call);
+ const xla::HloCustomCallInstruction* custom_call);
xla::StatusOr<Operation*> EmitDnnConvolution(
- xla::HloCustomCallInstruction* custom_call);
+ const xla::HloCustomCallInstruction* custom_call);
xla::StatusOr<Operation*> EmitDnnBatchNorm(
- xla::HloCustomCallInstruction* custom_call);
+ const xla::HloCustomCallInstruction* custom_call);
- xla::StatusOr<lmhlo::ReduceOp> EmitReduceOp(xla::HloInstruction* instr);
- xla::StatusOr<GetGlobalMemrefOp> EmitConstant(xla::HloInstruction* instr) {
- return EmitConstant(static_cast<const xla::HloInstruction*>(instr));
- }
+ xla::StatusOr<lmhlo::ReduceOp> EmitReduceOp(const xla::HloInstruction* instr);
xla::StatusOr<GetGlobalMemrefOp> EmitConstant(
const xla::HloInstruction* instr);
- xla::StatusOr<lmhlo::CompareOp> EmitCompareOp(xla::HloInstruction* instr);
+ xla::StatusOr<lmhlo::CompareOp> EmitCompareOp(
+ const xla::HloInstruction* instr);
- ::xla::StatusOr<lmhlo::InfeedOp> EmitInfeedOp(::xla::HloInstruction* instr);
- ::xla::StatusOr<lmhlo::OutfeedOp> EmitOutfeedOp(::xla::HloInstruction* instr);
- ::xla::StatusOr<lmhlo::MapOp> EmitMapOp(::xla::HloInstruction* instr);
+ xla::StatusOr<lmhlo::InfeedOp> EmitInfeedOp(const xla::HloInstruction* instr);
+ xla::StatusOr<lmhlo::OutfeedOp> EmitOutfeedOp(
+ const xla::HloInstruction* instr);
+ xla::StatusOr<lmhlo::MapOp> EmitMapOp(const xla::HloInstruction* instr);
xla::StatusOr<lmhlo::ReducePrecisionOp> EmitReducePrecisionOp(
- xla::HloInstruction* instr);
+ const xla::HloInstruction* instr);
- xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(xla::HloInstruction* instr);
+ xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
+ const xla::HloInstruction* instr);
// Create LHLO operation operands given an XLA HLO instruction. By default,
// all XLA HLO operands and results are converted to MLIR and appended to
@@ -95,14 +96,14 @@
// operands of the instruction are converted to MLIR. The function returns the
// actual number of operands and results generated for MLIR in `num_arguments`
// and `num_results`.
- xla::Status CreateOperands(xla::HloInstruction* instr,
+ xla::Status CreateOperands(const xla::HloInstruction* instr,
absl::optional<xla::int64> num_operands,
SmallVectorImpl<Value>& operands,
size_t& num_arguments, size_t& num_results);
template <typename OpType>
xla::StatusOr<OpType> CreateOpWithoutAttrs(
- xla::HloInstruction* instr,
+ const xla::HloInstruction* instr,
absl::optional<xla::int64> num_operands = absl::nullopt) {
size_t unused;
return CreateOpWithoutAttrs<OpType>(instr, unused, unused, num_operands);
@@ -110,11 +111,13 @@
template <typename OpType>
xla::StatusOr<OpType> CreateOpWithoutAttrs(
- xla::HloInstruction* instr, size_t& num_arguments, size_t& num_results,
+ const xla::HloInstruction* instr, size_t& num_arguments,
+ size_t& num_results,
absl::optional<xla::int64> num_operands = absl::nullopt);
template <typename OpType>
- OpType CreateOpWithoutAttrs(xla::HloInstruction* instr, ValueRange operands);
+ OpType CreateOpWithoutAttrs(const xla::HloInstruction* instr,
+ ValueRange operands);
template <typename T>
DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) {
@@ -133,11 +136,11 @@
return GetI64DenseElementsAttr(elements);
}
- tensorflow::Status DefaultAction(xla::HloInstruction* instr) final;
+ tensorflow::Status DefaultAction(const xla::HloInstruction* instr) final;
// Computation parameters don't need any specific handling when they are
// visited, they are already processed when we enter a new computation.
- tensorflow::Status HandleParameter(xla::HloInstruction* instr) final {
+ tensorflow::Status HandleParameter(const xla::HloInstruction* instr) final {
return tensorflow::Status::OK();
}