[XLA:GPU][NFC] Use constant pointers for HloInstruction and other types.

PiperOrigin-RevId: 351696339
Change-Id: I2806212d30f31052295638058131864f658cd6cd
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index 2c6e06b..151c9ec 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -206,7 +206,7 @@
 // instruction. If `num_operands` is valid, then only the first `num_operands`
 // operands of the HLO instruction will be considered.
 Status LhloDialectEmitter::CreateOperands(
-    HloInstruction* instr, absl::optional<xla::int64> num_operands,
+    const HloInstruction* instr, absl::optional<xla::int64> num_operands,
     llvm::SmallVectorImpl<Value>& operands, size_t& num_arguments,
     size_t& num_results) {
   if (num_operands.value_or(0) > instr->operand_count())
@@ -222,7 +222,7 @@
 }
 
 template <typename OpType>
-OpType LhloDialectEmitter::CreateOpWithoutAttrs(HloInstruction* instr,
+OpType LhloDialectEmitter::CreateOpWithoutAttrs(const HloInstruction* instr,
                                                 ValueRange operands) {
   Location loc = getLocation(instr);
   NamedAttribute attrs[] = {{Identifier::get("name", builder_.getContext()),
@@ -232,7 +232,7 @@
 
 template <typename OpType>
 StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
-    HloInstruction* instr, size_t& num_arguments, size_t& num_results,
+    const HloInstruction* instr, size_t& num_arguments, size_t& num_results,
     absl::optional<xla::int64> num_operands) {
   llvm::SmallVector<Value, 4> operands;
   TF_RETURN_IF_ERROR(CreateOperands(instr, num_operands, operands,
@@ -240,7 +240,8 @@
   return CreateOpWithoutAttrs<OpType>(instr, operands);
 }
 
-StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(HloInstruction* instr) {
+StatusOr<mlir::Operation*> LhloDialectEmitter::EmitOp(
+    const HloInstruction* instr) {
   using xla::HloOpcode;
   switch (instr->opcode()) {
     case HloOpcode::kAbs:
@@ -363,11 +364,12 @@
   }
 }
 
-Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) {
+Status LhloDialectEmitter::DefaultAction(const HloInstruction* instr) {
   return EmitOp(instr).status();
 }
 
-StatusOr<lmhlo::SortOp> LhloDialectEmitter::EmitSortOp(HloInstruction* instr) {
+StatusOr<lmhlo::SortOp> LhloDialectEmitter::EmitSortOp(
+    const HloInstruction* instr) {
   TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs<lmhlo::SortOp>(instr));
   auto* sort_instr = xla::Cast<xla::HloSortInstruction>(instr);
   sort.dimensionAttr(builder_.getI64IntegerAttr(sort_instr->sort_dimension()));
@@ -425,7 +427,7 @@
 }
 
 StatusOr<lmhlo::FusionOp> LhloDialectEmitter::EmitFusionOp(
-    HloInstruction* instr) {
+    const HloInstruction* instr) {
   Location loc = getLocation(instr);
 
   auto* fusion_instr = xla::Cast<xla::HloFusionInstruction>(instr);
@@ -494,7 +496,7 @@
 }
 
 StatusOr<mhlo::ScatterDimensionNumbers>
-LhloDialectEmitter::GetScatterDimensionNumbers(HloInstruction* instr) {
+LhloDialectEmitter::GetScatterDimensionNumbers(const HloInstruction* instr) {
   auto* scatter_instr = xla::Cast<xla::HloScatterInstruction>(instr);
 
   const xla::ScatterDimensionNumbers& xla_scatter_dim =
@@ -509,7 +511,7 @@
 }
 
 StatusOr<lmhlo::ScatterOp> LhloDialectEmitter::EmitScatterOp(
-    HloInstruction* instr) {
+    const HloInstruction* instr) {
   TF_ASSIGN_OR_RETURN(auto scatter,
                       CreateOpWithoutAttrs<lmhlo::ScatterOp>(instr));
 
@@ -533,7 +535,7 @@
 }
 
 StatusOr<lmhlo::SelectAndScatterOp> LhloDialectEmitter::EmitSelectAndScatterOp(
-    HloInstruction* instr) {
+    const HloInstruction* instr) {
   TF_ASSIGN_OR_RETURN(auto select_and_scatter,
                       CreateOpWithoutAttrs<lmhlo::SelectAndScatterOp>(instr));
 
@@ -566,7 +568,7 @@
 }
 
 StatusOr<mlir::Operation*> LhloDialectEmitter::EmitCustomCallOp(
-    HloInstruction* instr) {
+    const HloInstruction* instr) {
   auto* custom_call_instr = xla::Cast<xla::HloCustomCallInstruction>(instr);
 
   if (xla::gpu::IsCustomCallToCusolver(*instr)) {
@@ -601,7 +603,7 @@
 }
 
 StatusOr<lmhlo_gpu::CholeskyOp> LhloDialectEmitter::EmitCholesky(
-    HloCustomCallInstruction* custom_call) {
+    const HloCustomCallInstruction* custom_call) {
   TF_ASSIGN_OR_RETURN(auto cholesky_op,
                       CreateOpWithoutAttrs<lmhlo_gpu::CholeskyOp>(custom_call));
   TF_ASSIGN_OR_RETURN(xla::CholeskyOptions options,
@@ -611,7 +613,7 @@
 }
 
 StatusOr<Operation*> LhloDialectEmitter::EmitGemm(
-    HloCustomCallInstruction* custom_call) {
+    const HloCustomCallInstruction* custom_call) {
   TF_ASSIGN_OR_RETURN(
       auto const config,
       custom_call->backend_config<xla::gpu::GemmBackendConfig>());
@@ -675,7 +677,7 @@
 }
 
 StatusOr<Operation*> LhloDialectEmitter::EmitDnnConvolution(
-    HloCustomCallInstruction* custom_call) {
+    const HloCustomCallInstruction* custom_call) {
   TF_ASSIGN_OR_RETURN(
       auto const backend_config,
       custom_call->backend_config<xla::gpu::CudnnConvBackendConfig>());
@@ -796,7 +798,7 @@
 }
 
 StatusOr<Operation*> LhloDialectEmitter::EmitDnnBatchNorm(
-    HloCustomCallInstruction* custom_call) {
+    const HloCustomCallInstruction* custom_call) {
   const xla::int64 num_operands = custom_call->operand_count();
   auto set_batchnorm_attributes = [&](auto op) -> StatusOr<Operation*> {
     // The last 2 operands of a custom call for batch norm are the epsilon and
@@ -895,7 +897,7 @@
 }
 
 StatusOr<lmhlo::ReduceOp> LhloDialectEmitter::EmitReduceOp(
-    HloInstruction* instr) {
+    const HloInstruction* instr) {
   TF_ASSIGN_OR_RETURN(auto reduce_op,
                       CreateOpWithoutAttrs<lmhlo::ReduceOp>(instr));
   auto* reduce = xla::Cast<xla::HloReduceInstruction>(instr);
@@ -907,7 +909,8 @@
   return reduce_op;
 }
 
-StatusOr<lmhlo::MapOp> LhloDialectEmitter::EmitMapOp(HloInstruction* instr) {
+StatusOr<lmhlo::MapOp> LhloDialectEmitter::EmitMapOp(
+    const HloInstruction* instr) {
   TF_ASSIGN_OR_RETURN(auto map_op, CreateOpWithoutAttrs<lmhlo::MapOp>(instr));
   auto* map = xla::Cast<xla::HloMapInstruction>(instr);
   std::vector<int64_t> dimensions(map->dimensions().begin(),
@@ -919,7 +922,7 @@
 }
 
 StatusOr<lmhlo::CompareOp> LhloDialectEmitter::EmitCompareOp(
-    HloInstruction* instr) {
+    const HloInstruction* instr) {
   TF_ASSIGN_OR_RETURN(auto compare_op,
                       CreateOpWithoutAttrs<lmhlo::CompareOp>(instr));
 
@@ -960,7 +963,7 @@
 }
 
 StatusOr<lmhlo::ReducePrecisionOp> LhloDialectEmitter::EmitReducePrecisionOp(
-    HloInstruction* instr) {
+    const HloInstruction* instr) {
   TF_ASSIGN_OR_RETURN(auto reduce_precision_op,
                       CreateOpWithoutAttrs<lmhlo::ReducePrecisionOp>(instr));
   auto* reduce_precision = xla::Cast<xla::HloReducePrecisionInstruction>(instr);
@@ -972,7 +975,7 @@
 }
 
 StatusOr<lmhlo::AllReduceOp> LhloDialectEmitter::EmitAllReduceOp(
-    HloInstruction* instr) {
+    const HloInstruction* instr) {
   TF_ASSIGN_OR_RETURN(auto all_reduce_op,
                       CreateOpWithoutAttrs<lmhlo::AllReduceOp>(instr));
   auto* all_reduce = xla::Cast<xla::HloAllReduceInstruction>(instr);
@@ -993,8 +996,8 @@
 }
 
 StatusOr<lmhlo::InfeedOp> LhloDialectEmitter::EmitInfeedOp(
-    HloInstruction* instr) {
-  HloInfeedInstruction* infeed = xla::Cast<HloInfeedInstruction>(instr);
+    const HloInstruction* instr) {
+  const HloInfeedInstruction* infeed = xla::Cast<HloInfeedInstruction>(instr);
   // HLO Infeed instruction has a single operand of token type and a tuple
   // with buffers and a token as its output. LMHLO Infeed operation does not
   // need the token operand or result, so drop it.
@@ -1006,8 +1009,9 @@
 }
 
 StatusOr<lmhlo::OutfeedOp> LhloDialectEmitter::EmitOutfeedOp(
-    HloInstruction* instr) {
-  HloOutfeedInstruction* outfeed = xla::Cast<HloOutfeedInstruction>(instr);
+    const HloInstruction* instr) {
+  const HloOutfeedInstruction* outfeed =
+      xla::Cast<HloOutfeedInstruction>(instr);
   // HLO outfeed instruction has 2 operands, the source and a token, and a
   // single token output. LMHLO Outfeed does not need the token operand and
   // result, do drop it.
@@ -1223,7 +1227,7 @@
   module.getContext()
       ->loadDialect<StandardOpsDialect, mhlo::MhloDialect, lmhlo::LmhloDialect,
                     lmhlo_gpu::LmhloGpuDialect>();
-  HloComputation* computation = hlo_module.entry_computation();
+  const HloComputation* computation = hlo_module.entry_computation();
 
   LhloDialectEmitter emitter(assignment, *computation, module);
   TF_RETURN_IF_ERROR(emitter.Initialize());
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
index 8df68c7..72cd0c7 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h
@@ -35,7 +35,7 @@
 // This class will process an HloModule with the supplied BufferAssignment and
 // populate the MLIR ModuleOp with the computation converted in the LHLO
 // dialect.
-class LhloDialectEmitter : public xla::DfsHloVisitorWithDefault {
+class LhloDialectEmitter : public xla::ConstDfsHloVisitorWithDefault {
  public:
   // Initializes internal data structures. It must be called before calling any
   // of the visitors.
@@ -49,45 +49,46 @@
         builder_(module.getContext()),
         i8_type_(builder_.getIntegerType(8)) {}
 
-  xla::StatusOr<mlir::Operation*> EmitOp(xla::HloInstruction* instr);
+  xla::StatusOr<mlir::Operation*> EmitOp(const xla::HloInstruction* instr);
 
   xla::StatusOr<mhlo::ScatterDimensionNumbers> GetScatterDimensionNumbers(
-      xla::HloInstruction* instr);
+      const xla::HloInstruction* instr);
 
  private:
-  xla::StatusOr<lmhlo::SortOp> EmitSortOp(xla::HloInstruction* instr);
-  xla::StatusOr<lmhlo::FusionOp> EmitFusionOp(xla::HloInstruction* instr);
-  xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(xla::HloInstruction* instr);
+  xla::StatusOr<lmhlo::SortOp> EmitSortOp(const xla::HloInstruction* instr);
+  xla::StatusOr<lmhlo::FusionOp> EmitFusionOp(const xla::HloInstruction* instr);
+  xla::StatusOr<lmhlo::ScatterOp> EmitScatterOp(
+      const xla::HloInstruction* instr);
   xla::StatusOr<lmhlo::SelectAndScatterOp> EmitSelectAndScatterOp(
-      xla::HloInstruction* instr);
+      const xla::HloInstruction* instr);
 
-  xla::StatusOr<Operation*> EmitCustomCallOp(xla::HloInstruction* instr);
+  xla::StatusOr<Operation*> EmitCustomCallOp(const xla::HloInstruction* instr);
   xla::StatusOr<lmhlo_gpu::CholeskyOp> EmitCholesky(
-      xla::HloCustomCallInstruction* custom_call);
+      const xla::HloCustomCallInstruction* custom_call);
   xla::StatusOr<Operation*> EmitGemm(
-      xla::HloCustomCallInstruction* custom_call);
+      const xla::HloCustomCallInstruction* custom_call);
   xla::StatusOr<Operation*> EmitDnnConvolution(
-      xla::HloCustomCallInstruction* custom_call);
+      const xla::HloCustomCallInstruction* custom_call);
   xla::StatusOr<Operation*> EmitDnnBatchNorm(
-      xla::HloCustomCallInstruction* custom_call);
+      const xla::HloCustomCallInstruction* custom_call);
 
-  xla::StatusOr<lmhlo::ReduceOp> EmitReduceOp(xla::HloInstruction* instr);
-  xla::StatusOr<GetGlobalMemrefOp> EmitConstant(xla::HloInstruction* instr) {
-    return EmitConstant(static_cast<const xla::HloInstruction*>(instr));
-  }
+  xla::StatusOr<lmhlo::ReduceOp> EmitReduceOp(const xla::HloInstruction* instr);
   xla::StatusOr<GetGlobalMemrefOp> EmitConstant(
       const xla::HloInstruction* instr);
 
-  xla::StatusOr<lmhlo::CompareOp> EmitCompareOp(xla::HloInstruction* instr);
+  xla::StatusOr<lmhlo::CompareOp> EmitCompareOp(
+      const xla::HloInstruction* instr);
 
-  ::xla::StatusOr<lmhlo::InfeedOp> EmitInfeedOp(::xla::HloInstruction* instr);
-  ::xla::StatusOr<lmhlo::OutfeedOp> EmitOutfeedOp(::xla::HloInstruction* instr);
-  ::xla::StatusOr<lmhlo::MapOp> EmitMapOp(::xla::HloInstruction* instr);
+  xla::StatusOr<lmhlo::InfeedOp> EmitInfeedOp(const xla::HloInstruction* instr);
+  xla::StatusOr<lmhlo::OutfeedOp> EmitOutfeedOp(
+      const xla::HloInstruction* instr);
+  xla::StatusOr<lmhlo::MapOp> EmitMapOp(const xla::HloInstruction* instr);
 
   xla::StatusOr<lmhlo::ReducePrecisionOp> EmitReducePrecisionOp(
-      xla::HloInstruction* instr);
+      const xla::HloInstruction* instr);
 
-  xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(xla::HloInstruction* instr);
+  xla::StatusOr<lmhlo::AllReduceOp> EmitAllReduceOp(
+      const xla::HloInstruction* instr);
 
   // Create LHLO operation operands given an XLA HLO instruction. By default,
   // all XLA HLO operands and results are converted to MLIR and appended to
@@ -95,14 +96,14 @@
   // operands of the instruction are converted to MLIR. The function returns the
   // actual number of operands and results generated for MLIR in `num_arguments`
   // and `num_results`.
-  xla::Status CreateOperands(xla::HloInstruction* instr,
+  xla::Status CreateOperands(const xla::HloInstruction* instr,
                              absl::optional<xla::int64> num_operands,
                              SmallVectorImpl<Value>& operands,
                              size_t& num_arguments, size_t& num_results);
 
   template <typename OpType>
   xla::StatusOr<OpType> CreateOpWithoutAttrs(
-      xla::HloInstruction* instr,
+      const xla::HloInstruction* instr,
       absl::optional<xla::int64> num_operands = absl::nullopt) {
     size_t unused;
     return CreateOpWithoutAttrs<OpType>(instr, unused, unused, num_operands);
@@ -110,11 +111,13 @@
 
   template <typename OpType>
   xla::StatusOr<OpType> CreateOpWithoutAttrs(
-      xla::HloInstruction* instr, size_t& num_arguments, size_t& num_results,
+      const xla::HloInstruction* instr, size_t& num_arguments,
+      size_t& num_results,
       absl::optional<xla::int64> num_operands = absl::nullopt);
 
   template <typename OpType>
-  OpType CreateOpWithoutAttrs(xla::HloInstruction* instr, ValueRange operands);
+  OpType CreateOpWithoutAttrs(const xla::HloInstruction* instr,
+                              ValueRange operands);
 
   template <typename T>
   DenseIntElementsAttr GetI64DenseElementsAttr(const T& container) {
@@ -133,11 +136,11 @@
     return GetI64DenseElementsAttr(elements);
   }
 
-  tensorflow::Status DefaultAction(xla::HloInstruction* instr) final;
+  tensorflow::Status DefaultAction(const xla::HloInstruction* instr) final;
 
   // Computation parameters don't need any specific handling when they are
   // visited, they are already processed when we enter a new computation.
-  tensorflow::Status HandleParameter(xla::HloInstruction* instr) final {
+  tensorflow::Status HandleParameter(const xla::HloInstruction* instr) final {
     return tensorflow::Status::OK();
   }