Re-apply 328481481 after fixing more internal issues
PiperOrigin-RevId: 328665093
Change-Id: I1111d69d96a87969013bf11890672ea369cac214
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index 32a2ed1..ec98d9d 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -238,7 +238,6 @@
deps = [
":type_to_shape",
"//tensorflow/compiler/mlir/hlo",
- "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration",
"//tensorflow/compiler/mlir/tensorflow:convert_type",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/tf2xla:common",
@@ -389,7 +388,6 @@
":xla_legalize_tf_with_tf2xla",
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo",
- "//tensorflow/compiler/mlir/hlo:hlo_dialect_force_registration",
"//tensorflow/compiler/mlir/hlo:hlo_legalize_to_lhlo",
"//tensorflow/compiler/mlir/hlo:legalize_control_flow",
"//tensorflow/compiler/mlir/hlo:legalize_tanh_to_approximation",
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index d1d0827..ce761d8 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -254,6 +254,7 @@
":target_util",
":thunk",
":thunk_emitter",
+ "//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/mlir/xla:hlo_utils",
"//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h
index 7d5a8d0..2a493fe 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h
@@ -17,7 +17,10 @@
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_CONTEXT_H_
#include "llvm/IR/Module.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
@@ -44,7 +47,11 @@
cuda_compute_capability_(cuda_compute_capability),
profile_index_map_(profile_index_map),
mlir_context_(mlir_context),
- llvm_module_(llvm_module) {}
+ llvm_module_(llvm_module) {
+ mlir_context_
+ ->loadDialect<mlir::lmhlo::LmhloDialect, mlir::mhlo::MhloDialect,
+ mlir::StandardOpsDialect>();
+ }
// Disallow copy and assign.
IrEmitterContext(const IrEmitterContext&) = delete;
IrEmitterContext& operator=(const IrEmitterContext&) = delete;
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD
index 786f28c..af670eb 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD
+++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD
@@ -41,9 +41,12 @@
srcs = ["emission_context.cc"],
hdrs = ["emission_context.h"],
deps = [
+ "//tensorflow/compiler/mlir/hlo",
+ "//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/xla/service:hlo",
"@com_google_absl//absl/strings",
"@llvm-project//mlir:IR",
+ "@llvm-project//mlir:StandardOps",
],
)
diff --git a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc
index cb5ea94..06c7ebd 100644
--- a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc
+++ b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc
@@ -16,8 +16,11 @@
#include "tensorflow/compiler/xla/service/mlir_gpu/emission_context.h"
#include "absl/strings/substitute.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
namespace xla {
@@ -25,7 +28,8 @@
EmissionContext::EmissionContext(std::unique_ptr<HloModule> module)
: module_(std::move(module)), context_() {
- context_.loadAllGloballyRegisteredDialects();
+ context_.loadDialect<mlir::mhlo::MhloDialect, mlir::lmhlo::LmhloDialect,
+ mlir::StandardOpsDialect>();
error_handler_ = [](const ErrorMap& instructions_with_error,
HloModule* module) {
std::set<const HloComputation*> computations_with_error;