[XLA/GPU] Generate GpuExecutable from pure LMHLO graph.

This CL doesn't enable the logic it adds; a separate CL enables it, to keep rollback easy.

This CL makes the following changes:
* For tuple parameters in ExecutionInput, add lmhlo.param_shape_index to LMHLO
  function arguments, encoding the information needed to reconstruct the
  corresponding BufferAllocations.
* Remove uses of BufferAllocation::assigned_buffers, which depends on
  HloInstruction. Since the HloInstruction is no longer available, accesses to
  HloInstruction::name() for kConstant have to be plumbed through
  lmhlo.constant_name.
* Add lmhlo.must_alias to support buffer donation detection in GpuExecutable.
* Add the function attribute `result_xla_shape` to support
  GpuExecutable/ExecutionOutput.

All added attributes are specific to Execution{Input,Output} and GpuExecutable, and therefore do not change the semantics of LMHLO.
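
For illustration only, an LMHLO entry function carrying these attributes might
look roughly like the following sketch (allocation indices, buffer sizes, and
the constant name are made up):

  func @main(
      // Parameter 0 is donated to the output (shape index {} of the result),
      // so it carries both lmhlo.output_index and lmhlo.must_alias. Buffers
      // that may be live out are emitted as flat byte buffers.
      %arg0: memref<16xi8> {lmhlo.alloc = 0 : index, lmhlo.params = 0 : index,
                            lmhlo.output_index = dense<> : tensor<0xi64>,
                            lmhlo.must_alias},
      // A constant buffer; the name is used to find the corresponding global.
      %arg1: memref<16xi8> {lmhlo.alloc = 1 : index,
                            lmhlo.constant_name = "some_constant"}
  ) attributes {result_xla_shape = "f32[4]{0}"} {
    ...
  }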

PiperOrigin-RevId: 365903777
Change-Id: I733fe70002005f84a33550c3e165a0c71f6c652f
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
index 5708ab4..c43d3e8 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/hlo_text_to_lhlo_no_opt.hlotxt
@@ -536,8 +536,8 @@
 
 // CHECK: func @main
 // CHECK: "lmhlo.reduce_window"(%arg0, %{{.*}}, %{{.*}}) ( {
-// CHECK:   ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>):
-// CHECK:   %2 = mhlo.maximum %arg6, %arg7 : tensor<f32>
+// CHECK:   ^bb0(%[[ARG6:.*]]: tensor<f32>, %[[ARG7:.*]]: tensor<f32>):
+// CHECK:   %2 = mhlo.maximum %[[ARG6]], %[[ARG7]] : tensor<f32>
 // CHECK:   "mhlo.return"(%2) : (tensor<f32>) -> ()
 // CHECK: }) {
 // CHECK-SAME: padding = dense<{{\[}}[0, 0], [2, 0], [0, 2], [0, 0]{{\]}}> : tensor<4x2xi64>,
diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir
index 17d5dc4..b6b92fa 100644
--- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir
@@ -4,8 +4,8 @@
 // another one for the output, and no returned values.
 // CHECK-LABEL: func @main
 // CHECK-SAME:  %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.alloc = 1 : index, lmhlo.params = 0 : index},
-// CHECK-SAME:  %[[ARG1:.*]]: memref<16xi8> {lmhlo.alloc = 0 : index, lmhlo.liveout = true}
-// CHECK-SAME: ) {
+// CHECK-SAME:  %[[ARG1:.*]]: memref<16xi8> {lmhlo.alloc = 0 : index, lmhlo.output_index = dense<> : tensor<0xi64>}
+// CHECK-SAME: ) {{.*}} {
 func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> {
   // The only expected instruction is a copy from the input into the output.
   // CHECK: %[[C0:.*]] = constant 0 : index
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index 54f848a..1c92372 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -1603,6 +1603,8 @@
 }
 
 Status LhloDialectEmitter::Initialize() {
+  TF_RET_CHECK(computation_.IsEntryComputation());
+
   mlir::IntegerAttr unique_id =
       builder_.getI32IntegerAttr(computation_.parent()->unique_id());
   module_->setAttr("hlo.unique_id", unique_id);
@@ -1613,6 +1615,13 @@
   // buffer allocation and update the type then.
   auto func_op = FuncOp::create(builder_.getUnknownLoc(), function_name,
                                 builder_.getFunctionType({}, {}));
+
+  {
+    const Shape& shape = computation_.root_instruction()->shape();
+    func_op->setAttr(
+        "result_xla_shape",
+        builder_.getStringAttr(shape.ToString(/*print_layout=*/true)));
+  }
   Block* block = func_op.addEntryBlock();
 
   llvm::SmallVector<const BufferAllocation*, 8> ordered_allocations;
@@ -1649,41 +1658,84 @@
                      allocation_comparator);
   }
 
+  absl::flat_hash_map<const BufferAllocation*, xla::ShapeIndex>
+      allocation_to_output_index;
+  TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus(
+      computation_.root_instruction()->shape(),
+      [&](const Shape& sub_shape, xla::ShapeIndex index) -> Status {
+        TF_ASSIGN_OR_RETURN(
+            auto slice,
+            assignment_.GetUniqueSlice(computation_.root_instruction(), index));
+        const BufferAllocation* alloc = slice.allocation();
+        TF_RET_CHECK(slice.offset() == 0);
+        TF_RET_CHECK(slice.size() == alloc->size());
+        allocation_to_output_index[alloc] = index;
+        return Status::OK();
+      }));
+
   // The function signature will be composed of:
   // - one memref for each of the parameters.
   // - one memref for each other buffer allocation.
   llvm::SmallVector<DictionaryAttr, 8> args_attrs;
   for (const BufferAllocation* alloc : ordered_allocations) {
-    if (computation_.IsEntryComputation() &&
-        alloc->is_entry_computation_parameter()) {
-      const xla::Shape& buffer_shape = xla::ShapeUtil::GetSubshape(
+    if (alloc->is_thread_local()) {
+      continue;
+    }
+
+    NamedAttrList arg_attr_list;
+    mlir::Type arg_type;
+    if (alloc->is_entry_computation_parameter() && !alloc->maybe_live_out()) {
+      xla::Shape buffer_shape = xla::ShapeUtil::GetSubshape(
           computation_.parameter_instruction(alloc->parameter_number())
               ->shape(),
           alloc->param_shape_index());
 
-      // TODO(jurahul): Revisit this when we can model memrefs with dynamic
-      // shape but static bounds in MLIR.
-      const Shape static_shape = xla::ShapeUtil::MakeStaticShape(buffer_shape);
-      TF_ASSIGN_OR_RETURN(auto arg_type, xla::ConvertShapeToType<MemRefType>(
-                                             static_shape, builder_));
-
-      // First map parameters to memrefs on the operation.
-      block->addArgument(arg_type);
-      allocations_[alloc] = block->getArguments().back();
-      NamedAttrList arg_attr_list;
-      arg_attr_list.set("lmhlo.alloc", builder_.getIndexAttr(alloc->index()));
+      if (buffer_shape.IsTuple()) {
+        arg_type = MemRefType::get({alloc->size()}, i8_type_);
+      } else {
+        // TODO(jurahul): Revisit this when we can model memrefs with dynamic
+        // shape but static bounds in MLIR.
+        const Shape static_shape =
+            xla::ShapeUtil::MakeStaticShape(buffer_shape);
+        TF_ASSIGN_OR_RETURN(arg_type, xla::ConvertShapeToType<MemRefType>(
+                                          static_shape, builder_));
+      }
+    } else {
+      arg_type = MemRefType::get({alloc->size()}, i8_type_);
+    }
+    block->addArgument(arg_type);
+    allocations_[alloc] = block->getArguments().back();
+    arg_attr_list.set("lmhlo.alloc", builder_.getIndexAttr(alloc->index()));
+    if (alloc->is_entry_computation_parameter()) {
       arg_attr_list.set("lmhlo.params",
                         builder_.getIndexAttr(alloc->parameter_number()));
-      args_attrs.push_back(arg_attr_list.getDictionary(builder_.getContext()));
-    } else {
-      block->addArgument(MemRefType::get({alloc->size()}, i8_type_));
-      allocations_[alloc] = block->getArguments().back();
-
-      NamedAttrList arg_attr_list;
-      arg_attr_list.set("lmhlo.alloc", builder_.getIndexAttr(alloc->index()));
-      arg_attr_list.set("lmhlo.liveout", builder_.getBoolAttr(true));
-      args_attrs.push_back(arg_attr_list.getDictionary(builder_.getContext()));
+      if (!alloc->param_shape_index().empty()) {
+        arg_attr_list.set("lmhlo.param_shape_index",
+                          builder_.getI64TensorAttr(llvm::makeArrayRef(
+                              alloc->param_shape_index().begin(),
+                              alloc->param_shape_index().end())));
+      }
     }
+    if (alloc->is_constant()) {
+      arg_attr_list.set(
+          "lmhlo.constant_name",
+          builder_.getStringAttr(
+              xla::llvm_ir::ConstantBufferAllocationToGlobalName(*alloc)));
+    }
+    auto iter = allocation_to_output_index.find(alloc);
+    if (iter != allocation_to_output_index.end()) {
+      arg_attr_list.set("lmhlo.output_index",
+                        builder_.getI64TensorAttr(llvm::makeArrayRef(
+                            iter->second.begin(), iter->second.end())));
+      if (auto alias = computation_.parent()
+                           ->input_output_alias_config()
+                           .GetAliasedParameter(iter->second)) {
+        if (alias->must_alias()) {
+          arg_attr_list.set("lmhlo.must_alias", builder_.getUnitAttr());
+        }
+      }
+    }
+    args_attrs.push_back(arg_attr_list.getDictionary(builder_.getContext()));
   }
 
   FunctionType function_type =
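
As a rough sketch of what the emitter above now produces for a tuple parameter
(shapes, allocation indices, and names below are hypothetical), a parameter of
type (f32[2,2], f32[3]) is flattened into one argument per allocation, with the
leaf buffers annotated by lmhlo.param_shape_index:

  func @main(
      // The tuple index table itself (shape index {}); tuple-shaped buffers
      // are emitted as flat byte buffers.
      %arg0: memref<16xi8> {lmhlo.alloc = 0 : index, lmhlo.params = 0 : index},
      // Leaf {0} of parameter 0.
      %arg1: memref<2x2xf32> {lmhlo.alloc = 1 : index, lmhlo.params = 0 : index,
                              lmhlo.param_shape_index = dense<[0]> : tensor<1xi64>},
      // Leaf {1} of parameter 0.
      %arg2: memref<3xf32> {lmhlo.alloc = 2 : index, lmhlo.params = 0 : index,
                            lmhlo.param_shape_index = dense<[1]> : tensor<1xi64>},
      ...) { ... }
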
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 19fb251..70593a6 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -342,9 +342,9 @@
   }
   if (is_entry_computation_parameter()) {
     const HloInstruction* param = GetEntryParameterInstruction(*this);
-    CHECK(param);
     StrAppend(&output, ", parameter ", parameter_number(), ", shape |",
-              param->shape().ToString(/*print_layout=*/false),
+              param ? param->shape().ToString(/*print_layout=*/false)
+                    : "<unknown shape>",
               "| at ShapeIndex ", param_shape_index().ToString());
   }
   if (const HloInstruction* instr = GetOutputInstruction(*this)) {
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h
index de9e665..ca78975 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -283,6 +283,8 @@
     param_shape_index_ = std::move(param_shape_index);
   }
 
+  void set_constant(bool is_constant) { is_constant_ = is_constant; }
+
  private:
   // Only BufferAssigner and BufferAssignment can modify BufferAllocation.
   friend class BufferAssigner;
@@ -291,7 +293,6 @@
   // Adds a LogicalBuffer to the set assigned to this buffer.
   void AddAssignment(const HloValue& buffer, int64 offset, int64 size);
 
-  void set_constant(bool is_constant) { is_constant_ = is_constant; }
   void set_index(Index index) { index_ = index; }
   void set_size(int64 size) { size_ = size; }
 
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index e5157ee..1354158 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -1253,6 +1253,8 @@
         ":target_constants",
         ":tree_reduction_rewriter",
         ":variadic_op_splitter",
+        "//tensorflow/compiler/mlir:name_utils",
+        "//tensorflow/compiler/mlir/xla:hlo_utils",
         "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",
         "//tensorflow/compiler/mlir/xla:type_to_shape",
         "//tensorflow/compiler/xla:protobuf_util",
@@ -1286,6 +1288,7 @@
         "//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
         "//tensorflow/compiler/xla/service:hlo_dce",
         "//tensorflow/compiler/xla/service:hlo_element_type_converter",
+        "//tensorflow/compiler/xla/service:hlo_parser",
         "//tensorflow/compiler/xla/service:hlo_pass",
         "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
         "//tensorflow/compiler/xla/service:hlo_proto_util",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 79feba8..7b53a9e 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -35,6 +35,8 @@
 #include "llvm/Transforms/Utils/SplitModule.h"
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/InitAllDialects.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/utils/name_utils.h"
+#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
 #include "tensorflow/compiler/xla/protobuf_util.h"
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
@@ -96,6 +98,7 @@
 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
 #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
 #include "tensorflow/compiler/xla/service/hlo_proto_util.h"
@@ -570,6 +573,23 @@
   return std::make_tuple(std::move(hlo_module), std::move(assignment));
 }
 
+using OutputInfoMap =
+    absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>;
+static Status GetMlirAllocationInfo(mlir::FuncOp func,
+                                    std::vector<BufferAllocation>* allocations,
+                                    OutputInfoMap* output_info,
+                                    Shape* output_shape);
+
+struct CompileModuleResults {
+  std::unique_ptr<llvm::Module> llvm_module;
+  std::unique_ptr<BufferAssignment> buffer_assignment;
+  std::vector<BufferAllocation> allocations;
+  std::unique_ptr<ThunkSchedule> thunk_schedule;
+  std::vector<GpuExecutable::ConstantInfo> constants;
+  OutputInfoMap output_info;
+  Shape output_shape;
+  std::string module_name;
+};
 // The order of `thunk_sequence` corresponds to
 // `hlo_schedule->ThunkLaunchOrder()`.
 static Status CompileModuleToLlvmIrImpl(
@@ -579,14 +599,10 @@
     absl::optional<CudaComputeCapability> cuda_compute_capability,
     const HloDataflowAnalysis::CanShareBuffer& can_share_buffer_function,
     int pointer_size, const HloProfileIndexMap* profile_index_map,
-    std::unique_ptr<llvm::Module>* llvm_module,
-    std::unique_ptr<BufferAssignment>* buffer_assignment,
-    std::unique_ptr<ThunkSchedule>* thunk_schedule,
-    std::vector<GpuExecutable::ConstantInfo>* constants) {
-  *llvm_module = absl::make_unique<llvm::Module>("", *llvm_context);
-
-  (*llvm_module)->setTargetTriple(target_triple);
-  (*llvm_module)->setDataLayout(data_layout);
+    CompileModuleResults* results) {
+  results->llvm_module = absl::make_unique<llvm::Module>("", *llvm_context);
+  results->llvm_module->setTargetTriple(target_triple);
+  results->llvm_module->setDataLayout(data_layout);
 
   std::unique_ptr<StreamAssignment> stream_assignment =
       AssignStreams(*hlo_module);
@@ -600,7 +616,7 @@
   };
 
   TF_ASSIGN_OR_RETURN(
-      *buffer_assignment,
+      results->buffer_assignment,
       BufferAssigner::Run(
           hlo_module, hlo_schedule->ConsumeHloOrdering(),
           buffer_size_bytes_function,
@@ -611,8 +627,8 @@
           /*must_not_live_out=*/{}, can_share_buffer_function));
 
   VLOG(1) << "Buffer Assignment Stats "
-          << (*buffer_assignment)->GetStats().ToString();
-  DumpHloModuleIfEnabled(*hlo_module, **buffer_assignment,
+          << results->buffer_assignment->GetStats().ToString();
+  DumpHloModuleIfEnabled(*hlo_module, *results->buffer_assignment,
                          "after_optimizations");
 
   mlir::MLIRContext mlir_context;
@@ -621,16 +637,29 @@
                            mlir::lmhlo_gpu::LmhloGpuDialect>();
   mlir::OwningModuleRef mlir_module =
       mlir::ModuleOp::create(mlir::Builder(&mlir_context).getUnknownLoc());
+
   TF_RETURN_IF_ERROR(
-      HloToLhloModule(**buffer_assignment, *hlo_module, *mlir_module));
+      HloToLhloModule(*results->buffer_assignment, *hlo_module, *mlir_module));
+
+  results->module_name = mlir::GetNameFromLoc(mlir_module->getLoc());
 
   llvm_ir::DumpIrIfEnabled(mlir_module.get(), hlo_module->unique_id(),
                            hlo_module->config().debug_options());
 
+  auto entry_function = mlir::cast<mlir::FuncOp>(
+      mlir_module->lookupSymbol(hlo_module->entry_computation()->name()));
+
+  TF_RETURN_IF_ERROR(
+      GetMlirAllocationInfo(entry_function, &results->allocations,
+                            &results->output_info, &results->output_shape));
+
+  CHECK(!results->allocations.empty());
+
   IrEmitterContext ir_emitter_context(
-      hlo_module, buffer_assignment->get(), platform_name, gpu_device_info,
-      cuda_compute_capability, profile_index_map, &mlir_context,
-      llvm_module->get());
+      /*hlo_module=*/nullptr,
+      /*buffer_assignment=*/results->buffer_assignment.get(), platform_name,
+      gpu_device_info, cuda_compute_capability, profile_index_map,
+      &mlir_context, results->llvm_module.get());
 
   TF_ASSIGN_OR_RETURN(
       auto ir_emitter,
@@ -639,17 +668,12 @@
   {
     XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
 
-    auto entry_function = mlir::cast<mlir::FuncOp>(
-        mlir_module->lookupSymbol(hlo_module->entry_computation()->name()));
-
     TF_RETURN_IF_ERROR(ir_emitter->EmitLmhloRegion(&entry_function.body()));
 
-    *thunk_schedule =
+    results->thunk_schedule =
         absl::make_unique<ThunkSchedule>(ir_emitter->ConsumeThunkSequence());
 
-    if (constants) {
-      *constants = std::move(ir_emitter_context.constants());
-    }
+    results->constants = std::move(ir_emitter_context.constants());
   }
 
   return Status::OK();
@@ -894,49 +918,47 @@
     }
   }
 
-  std::unique_ptr<llvm::Module> llvm_module;
-  std::unique_ptr<BufferAssignment> buffer_assignment;
-  std::unique_ptr<ThunkSchedule> thunk_schedule;
-  std::vector<GpuExecutable::ConstantInfo> constants;
-
+  CompileModuleResults compile_module_results;
   TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
       module.get(), &llvm_context, target_triple_, data_layout_,
       stream_exec->platform()->Name(), gpu_device_info, cuda_compute_capability,
-      GetCanShareBuffer(), pointer_size_, profile_index_map.get(), &llvm_module,
-      &buffer_assignment, &thunk_schedule, &constants));
+      GetCanShareBuffer(), pointer_size_, profile_index_map.get(),
+      &compile_module_results));
 
   if (user_pre_optimization_hook_) {
-    user_pre_optimization_hook_(*llvm_module);
+    user_pre_optimization_hook_(*compile_module_results.llvm_module);
   }
   string ir_module_string_before_opt;
   const bool embed_ir_in_executable =
       module->config().debug_options().xla_embed_ir_in_executable();
   if (embed_ir_in_executable) {
-    ir_module_string_before_opt = llvm_ir::DumpModuleToString(*llvm_module);
+    ir_module_string_before_opt =
+        llvm_ir::DumpModuleToString(*compile_module_results.llvm_module);
   }
 
-  llvm_ir::DumpIrIfEnabled(*module, *llvm_module, /*optimized=*/false);
+  llvm_ir::DumpIrIfEnabled(*module, *compile_module_results.llvm_module,
+                           /*optimized=*/false);
 
   using BackendCompileResult = std::pair<std::string, std::vector<uint8>>;
   TF_ASSIGN_OR_RETURN(
       BackendCompileResult backend_result,
-      CompileToTargetBinary(module->config(), std::move(llvm_module),
+      CompileToTargetBinary(module->config(),
+                            std::move(compile_module_results.llvm_module),
                             stream_exec, options, module.get()));
   if (DumpingEnabledForHloModule(*module)) {
     DumpToFileInDirOrStdout(*module, "", "thunk_schedule",
-                            thunk_schedule->ToString());
+                            compile_module_results.thunk_schedule->ToString());
   }
 
-  using OutputInfoMap =
-      absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>;
-  TF_ASSIGN_OR_RETURN(OutputInfoMap output_info,
-                      GetOutputInfo(*module, *buffer_assignment));
-  auto buffer_assignment_proto =
-      std::make_unique<BufferAssignmentProto>(buffer_assignment->ToProto());
+  TF_ASSIGN_OR_RETURN(
+      OutputInfoMap output_info,
+      GetOutputInfo(*module, *compile_module_results.buffer_assignment));
+  auto buffer_assignment_proto = std::make_unique<BufferAssignmentProto>(
+      compile_module_results.buffer_assignment->ToProto());
   std::vector<BufferAllocation> allocations =
-      buffer_assignment->ReleaseAllocations();
-  std::string module_name = module->name();
+      compile_module_results.buffer_assignment->ReleaseAllocations();
   Shape output_shape = module->entry_computation()->root_instruction()->shape();
+
   size_t profile_index = 0;
   if (profile_index_map) {
     profile_index =
@@ -946,11 +968,11 @@
   GpuVersion gpu_version = GetGpuVersion(stream_exec);
   auto* gpu_executable = new GpuExecutable(
       {std::move(backend_result.first), std::move(backend_result.second),
-       gpu_version, std::move(thunk_schedule), std::move(constants),
-       std::move(output_info), module_name, output_shape,
-       std::move(allocations), std::move(buffer_assignment_proto),
-       std::move(module), profile_index, std::move(profile_printer),
-       std::move(profile_index_map)});
+       gpu_version, std::move(compile_module_results.thunk_schedule),
+       std::move(compile_module_results.constants), std::move(output_info),
+       compile_module_results.module_name, output_shape, std::move(allocations),
+       std::move(buffer_assignment_proto), std::move(module), profile_index,
+       std::move(profile_printer), std::move(profile_index_map)});
   if (embed_ir_in_executable) {
     DCHECK_NE("", ir_module_string_before_opt);
     gpu_executable->set_ir_module_string(ir_module_string_before_opt);
@@ -990,16 +1012,12 @@
     const std::string& platform_name, GpuDeviceInfo gpu_device_info,
     absl::optional<CudaComputeCapability> cuda_compute_capability,
     int pointer_size) {
-  std::unique_ptr<llvm::Module> llvm_module;
-  std::unique_ptr<BufferAssignment> buffer_assignment;
-  std::unique_ptr<ThunkSchedule> thunk_schedule;
-
+  CompileModuleResults results;
   TF_RETURN_IF_ERROR(CompileModuleToLlvmIrImpl(
       hlo_module, llvm_context, target_triple, data_layout, platform_name,
       gpu_device_info, cuda_compute_capability, DummyCanShareBufferFunction,
-      pointer_size, /*profile_index_map=*/nullptr, &llvm_module,
-      &buffer_assignment, &thunk_schedule, nullptr));
-  return llvm_module;
+      pointer_size, /*profile_index_map=*/nullptr, &results));
+  return std::move(results.llvm_module);
 }
 
 // Analyze the function signature to reconstruct a vector of BufferAllocation
@@ -1007,10 +1025,10 @@
 //
 // This function also serves as a half-baked verifier for function arg
 // attributes, since a full verifier doesn't exist yet.
-static Status GetMlirAllocationInfo(
-    mlir::FuncOp func, std::vector<BufferAllocation>* allocations,
-    absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo>* output_info,
-    Shape* output_shape) {
+static Status GetMlirAllocationInfo(mlir::FuncOp func,
+                                    std::vector<BufferAllocation>* allocations,
+                                    OutputInfoMap* output_info,
+                                    Shape* output_shape) {
   std::vector<absl::optional<BufferAllocation>> maybe_allocations;
 
   for (int i = 0; i < func.getNumArguments(); i++) {
@@ -1022,8 +1040,12 @@
       maybe_allocations.resize(index + 1);
     }
     mlir::BlockArgument arg = func.getArgument(i);
+
     TF_RET_CHECK(arg.getType().isa<mlir::ShapedType>());
-    size_t size = arg.getType().cast<mlir::ShapedType>().getSizeInBits() / 8;
+    mlir::ShapedType type = arg.getType().cast<mlir::ShapedType>();
+    TF_ASSIGN_OR_RETURN(auto element_type_bytes,
+                        GetElementTypeBytes(type.getElementType()));
+    size_t size = type.getNumElements() * element_type_bytes;
     maybe_allocations[index].emplace(index, size, 0);
   }
 
@@ -1032,7 +1054,7 @@
     if (maybe_alloc.has_value()) {
       allocations->push_back(*maybe_alloc);
     } else {
-      return InvalidArgument("Allocation indices should range in [0, n)");
+      allocations->push_back(BufferAllocation(allocations->size(), 0, {}));
     }
   }
 
@@ -1040,59 +1062,80 @@
     for (const mlir::NamedAttribute& attr : func.getArgAttrs(i)) {
       TF_RET_CHECK(attr.first == "lmhlo.alloc" ||
                    attr.first == "lmhlo.params" ||
+                   attr.first == "lmhlo.param_shape_index" ||
+                   attr.first == "lmhlo.constant_name" ||
+                   attr.first == "lmhlo.must_alias" ||
                    attr.first == "lmhlo.output_index");
     }
   }
 
-  std::vector<Shape> output_shapes;
-  absl::optional<int> rank;
+  std::vector<std::pair<ShapeIndex, Shape>> sub_shapes;
   for (int i = 0; i < func.getNumArguments(); i++) {
     auto index =
         func.getArgAttr(i, "lmhlo.alloc").cast<mlir::IntegerAttr>().getInt();
     if (auto param_attr = func.getArgAttr(i, "lmhlo.params")) {
+      xla::ShapeIndex shape_index;
+      if (auto shape_index_attr =
+              func.getArgAttrOfType<mlir::DenseIntElementsAttr>(
+                  i, "lmhlo.param_shape_index")) {
+        for (const llvm::APInt& element : shape_index_attr) {
+          shape_index.push_back(element.getSExtValue());
+        }
+      }
       allocations->at(index).set_entry_computation_parameter(
-          param_attr.cast<mlir::IntegerAttr>().getInt(), {},
+          param_attr.cast<mlir::IntegerAttr>().getInt(), shape_index,
           static_cast<bool>(func.getArgAttr(i, "lmhlo.output_index")));
     }
+    // TODO(timshen): this information is redundant; it is here only for a
+    // smooth migration to LMHLO. Remove it.
+    if (func.getArgAttr(i, "lmhlo.constant_name")) {
+      allocations->at(index).set_constant(true);
+    }
     if (auto output_index_attr = func.getArgAttr(i, "lmhlo.output_index")) {
       allocations->at(index).set_maybe_live_out(true);
 
       // Reconstruct a shape index from output_index.
       ShapeIndex shape_index;
-      for (const llvm::APInt& i :
+      for (const llvm::APInt& element :
            output_index_attr.cast<mlir::DenseIntElementsAttr>()) {
-        shape_index.push_back(i.getSExtValue());
-      }
-      if (rank.has_value()) {
-        if (*rank != shape_index.size()) {
-          return InvalidArgument("Expect output_index to have the same ranks");
-        }
-      } else {
-        rank.emplace(shape_index.size());
+        shape_index.push_back(element.getSExtValue());
       }
       auto& o = (*output_info)[shape_index];
       o.allocation_index = index;
       if (auto param_attr = func.getArgAttr(i, "lmhlo.params")) {
+        HloInputOutputAliasConfig::AliasKind kind =
+            HloInputOutputAliasConfig::kMayAlias;
+        if (func.getArgAttr(i, "lmhlo.must_alias")) {
+          kind = HloInputOutputAliasConfig::kMustAlias;
+        }
         o.alias_config.emplace(param_attr.cast<mlir::IntegerAttr>().getInt(),
-                               ShapeIndex{});
+                               ShapeIndex{}, kind);
       }
-
-      if (shape_index.size() > 1) {
-        return Unimplemented("Expect array type or 1-level tuple type");
+      if (func.getArgument(i).use_empty()) {
+        o.passthrough = true;
       }
 
       mlir::BlockArgument arg = func.getArgument(i);
-      if (shape_index.empty()) {
-        output_shapes.push_back(TypeToShape(arg.getType()));
-      } else {
-        if (shape_index[0] >= output_shapes.size()) {
-          output_shapes.resize(shape_index[0] + 1);
-        }
-        output_shapes[shape_index[0]] = TypeToShape(arg.getType());
-      }
+      sub_shapes.push_back(
+          std::make_pair(shape_index, TypeToShape(arg.getType())));
     }
   }
-  *output_shape = ShapeUtil::MakeTupleShape(output_shapes);
+  // Expects result_xla_shape to hold an XLA shape in string form.
+  //
+  // The attribute is necessary because GpuExecutable/ExecutionOutput supports
+  // tuples / tree-like shapes, while the LMHLO argument list loses the tree
+  // form.
+  //
+  // The string form is necessary since MLIR can't represent XLA shapes with
+  // dynamic dimensions.
+  //
+  // TODO(timshen): this field is currently mandatory. Make it optional for
+  // non-GpuExecutable outputs.
+  TF_ASSIGN_OR_RETURN(
+      *output_shape,
+      ParseShape(func->getAttrOfType<mlir::StringAttr>("result_xla_shape")
+                     .getValue()
+                     .str()));
 
   return Status::OK();
 }
@@ -1108,13 +1151,12 @@
       llvm::StringRef(entry_function_name.data(), entry_function_name.size())));
 
   std::vector<BufferAllocation> allocations;
-  absl::flat_hash_map<ShapeIndex, GpuExecutable::OutputInfo> output_info;
+  OutputInfoMap output_info;
   Shape output_shape;
-  absl::flat_hash_map<ShapeIndex, int> output_to_argnum_map;
   TF_RETURN_IF_ERROR(GetMlirAllocationInfo(entry_function, &allocations,
                                            &output_info, &output_shape));
 
-  CHECK(!allocations.empty());
+  TF_RET_CHECK(!allocations.empty());
 
   ir_emitter_context->set_allocations(allocations);
 
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 365ead7..9e818b8 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -509,7 +509,8 @@
         // the indices to drop the addresses from its own ScopedShapedBuffer
         // result, if the ExecutionOutput is not committed.
         result.AddAliasedIndex(index);
-      } else if (!output_info.passthrough) {
+      } else if (!output_info.passthrough &&
+                 !ShapeUtil::GetSubshape(output_shape_, index).IsTuple()) {
        // The guard above avoids inserting copy-protection when aliasing
         // pass-through params, as we do not need to write into the output
         // buffer.
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 6dff7dc..3e64520 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -727,9 +727,16 @@
   }
 }
 
-static int64_t GetAllocationIndex(mlir::BlockArgument func_arg) {
+static int64_t GetAllocationIndex(mlir::BlockArgument func_arg,
+                                  std::string* constant_name) {
   auto func_op =
       mlir::cast<mlir::FuncOp>(func_arg.getParentRegion()->getParentOp());
+  if (constant_name) {
+    if (auto constant_name_attr = func_op.getArgAttrOfType<mlir::StringAttr>(
+            func_arg.getArgNumber(), "lmhlo.constant_name")) {
+      *constant_name = constant_name_attr.getValue().str();
+    }
+  }
   return func_op
       .getArgAttrOfType<mlir::IntegerAttr>(func_arg.getArgNumber(),
                                            "lmhlo.alloc")
@@ -738,12 +745,17 @@
 }
 
 StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(
-    mlir::Value v, absl::Span<const BufferAllocation> allocations) {
+    mlir::Value v, absl::Span<const BufferAllocation> allocations,
+    std::string* constant_name) {
+  if (constant_name) {
+    constant_name->clear();
+  }
+
   int64 size = GetMemRefSizeInBytes(v.getType().cast<mlir::MemRefType>());
 
   if (auto arg = v.dyn_cast<mlir::BlockArgument>()) {
-    return BufferAllocation::Slice(&allocations[GetAllocationIndex(arg)], 0,
-                                   size);
+    return BufferAllocation::Slice(
+        &allocations[GetAllocationIndex(arg, constant_name)], 0, size);
   }
 
   // We match the following patterns here:
@@ -761,7 +773,7 @@
     if (auto view = mlir::dyn_cast<mlir::memref::ViewOp>(op)) {
       return BufferAllocation::Slice(
           &allocations[GetAllocationIndex(
-              view.source().cast<mlir::BlockArgument>())],
+              view.source().cast<mlir::BlockArgument>(), constant_name)],
           mlir::cast<mlir::ConstantOp>(view.byte_shift().getDefiningOp())
               .value()
               .cast<mlir::IntegerAttr>()
@@ -771,6 +783,9 @@
     } else if (auto get_global =
                    mlir::dyn_cast<mlir::memref::GetGlobalOp>(op)) {
       auto module = get_global->getParentOfType<mlir::ModuleOp>();
+      if (constant_name) {
+        *constant_name = get_global.name().str();
+      }
       auto global = mlir::cast<mlir::memref::GlobalOp>(
           module.lookupSymbol(get_global.name()));
       int64_t index =
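
As a sketch of the two forms a constant buffer can take when
GetAllocationSliceForMlir recovers its name (the names and types below are
hypothetical):

  // 1. As an entry-function argument, named via the lmhlo.constant_name
  //    attribute added in this CL:
  %arg2: memref<16xi8> {lmhlo.alloc = 2 : index,
                        lmhlo.constant_name = "some_constant"}

  // 2. As a module-level global accessed through memref.get_global; the name
  //    is taken from the symbol reference:
  memref.global "private" constant @some_constant : memref<4xf32> = dense<1.0>
  ...
  %cst = memref.get_global @some_constant : memref<4xf32>
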
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
index a7b4432..a571bcc 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h
@@ -261,7 +261,8 @@
 }
 
 StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(
-    mlir::Value v, absl::Span<const BufferAllocation> allocations);
+    mlir::Value v, absl::Span<const BufferAllocation> allocations,
+    std::string* constant_name = nullptr);
 
 bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
     mlir::lmhlo::FusionOp fusion,
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index f96e0cd..ad8aef2 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -605,9 +605,9 @@
 }
 
 StatusOr<BufferAllocation::Slice> IrEmitterUnnested::GetAllocationSliceForMlir(
-    mlir::Value v) {
+    mlir::Value v, std::string* constant_name) {
   return xla::gpu::GetAllocationSliceForMlir(
-      v, ir_emitter_context_->allocations());
+      v, ir_emitter_context_->allocations(), constant_name);
 }
 
 Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
@@ -3511,12 +3511,12 @@
     const ShapeIndex& gte_index = slice->gte_index;
 
     llvm::Value* loc;
-    if (buffer_slice.allocation()->is_constant()) {
+    if (!slice->constant_name.empty()) {
       loc = ir_emitter_context_->llvm_module()->getGlobalVariable(
-          llvm_ir::ConstantBufferAllocationToGlobalName(
-              *buffer_slice.allocation()));
+          slice->constant_name);
       CHECK_NE(loc, nullptr);
     } else {
+      CHECK(!buffer_slice.allocation()->is_constant());
       loc = InBoundsGEP(kernel_args.at(buffer_slice.allocation()),
                         {b_.getInt64(buffer_slice.offset())});
     }
@@ -3586,7 +3586,8 @@
   for (mlir::Value operand : operands) {
     slices.emplace_back();
     auto& slice = slices.back();
-    TF_ASSIGN_OR_RETURN(slice.buffer_slice, GetAllocationSliceForMlir(operand));
+    TF_ASSIGN_OR_RETURN(slice.buffer_slice, GetAllocationSliceForMlir(
+                                                operand, &slice.constant_name));
     slice.written = WritesMlirBuffer(op, operand);
     slice.shape = TypeToShape(operand.getType());
   }
@@ -3606,16 +3607,18 @@
     for (auto operand : operands) {
       slices.emplace_back();
       auto& slice = slices.back();
-      TF_ASSIGN_OR_RETURN(slice.buffer_slice,
-                          GetAllocationSliceForMlir(operand));
+      TF_ASSIGN_OR_RETURN(
+          slice.buffer_slice,
+          GetAllocationSliceForMlir(operand, &slice.constant_name));
       slice.written = false;
       slice.shape = TypeToShape(operand.getType());
     }
     for (auto output : outputs) {
       slices.emplace_back();
       auto& slice = slices.back();
-      TF_ASSIGN_OR_RETURN(slice.buffer_slice,
-                          GetAllocationSliceForMlir(output));
+      TF_ASSIGN_OR_RETURN(
+          slice.buffer_slice,
+          GetAllocationSliceForMlir(output, &slice.constant_name));
       slice.written = true;
       slice.shape = TypeToShape(output.getType());
     }
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 8d8805a..420b9c2 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -33,6 +33,9 @@
   // The root buffer to look at.
   BufferAllocation::Slice buffer_slice;
 
+  // The global constant name of the buffer, if it's a constant.
+  std::string constant_name;
+
   // Describes how to dereference starting at that buffer to get to the buffer
   // in question.
   ShapeIndex gte_index;
@@ -339,7 +342,8 @@
     return MaybeGetAllocationSlice(hlo, index).ConsumeValueOrDie();
   }
 
-  StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(mlir::Value v);
+  StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(
+      mlir::Value v, std::string* constant_name = nullptr);
 
   int64 ByteSizeOf(const Shape& shape) const {
     return llvm_ir::ByteSizeOf(
diff --git a/tensorflow/compiler/xla/service/gpu/tests/fused_scatter.hlo b/tensorflow/compiler/xla/service/gpu/tests/fused_scatter.hlo
index 2fe8cfe..340c4ca 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/fused_scatter.hlo
+++ b/tensorflow/compiler/xla/service/gpu/tests/fused_scatter.hlo
@@ -150,9 +150,9 @@
 // CHECK:         %[[VAL_136:.*]] = icmp ult i32 %[[VAL_134]], 3
 // CHECK:         %[[VAL_137:.*]] = and i1 true, %[[VAL_136]]
 // CHECK:         br i1 %[[VAL_137]], label %[[VAL_138:.*]], label %[[VAL_131]]
-// CHECK:       scatter.in_bounds-after3:                         ; preds = %[[VAL_138]], %[[VAL_129]]
+// CHECK:       scatter.in_bounds-after{{.*}}:                         ; preds = %[[VAL_138]], %[[VAL_129]]
 // CHECK:         br label %[[VAL_130]]
-// CHECK:       scatter.in_bounds-true2:                          ; preds = %[[VAL_129]]
+// CHECK:       scatter.in_bounds-true{{.*}}:                          ; preds = %[[VAL_129]]
 // CHECK:         %[[VAL_139:.*]] = getelementptr inbounds [3 x [3 x i32]], [3 x [3 x i32]]* %[[VAL_119]], i32 0, i32 %[[VAL_135]], i32 %[[VAL_126]]
 // CHECK:         %[[VAL_140:.*]] = bitcast [2 x [3 x i32]]* %[[VAL_116]] to i32*
 // CHECK:         %[[VAL_141:.*]] = getelementptr inbounds i32, i32* %[[VAL_140]], i32 %[[VAL_123]]
diff --git a/tensorflow/compiler/xla/service/gpu/tests/mlir_sorting_test.cc b/tensorflow/compiler/xla/service/gpu/tests/mlir_sorting_test.cc
index 1b19716..af59bec 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/mlir_sorting_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/mlir_sorting_test.cc
@@ -30,7 +30,9 @@
                  %arg2: memref<4xf32> {lmhlo.alloc = 2 : index, lmhlo.output_index = dense<[0]> : tensor<1xindex>},
                  %arg3: memref<4xf32> {lmhlo.alloc = 3 : index, lmhlo.output_index = dense<[1]> : tensor<1xindex>},
                  %arg4: memref<4xf32> {lmhlo.alloc = 4 : index, lmhlo.output_index = dense<[2]> : tensor<1xindex>},
-                 %arg5: memref<4xf32> {lmhlo.alloc = 5 : index, lmhlo.output_index = dense<[3]> : tensor<1xindex>}) -> () {
+                 %arg5: memref<4xf32> {lmhlo.alloc = 5 : index, lmhlo.output_index = dense<[3]> : tensor<1xindex>}) attributes {
+                     result_xla_shape = "(f32[4], f32[4], f32[4], f32[4]) "
+                 } {
           "lmhlo.sort"(%arg0, %arg1, %arg2, %arg3) ( {
           ^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f32>, %d: tensor<f32>):
             %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>