Remove dependency on Dialect global registration from //tensorflow/compiler/mlir/lite/...

PiperOrigin-RevId: 328109152
Change-Id: Ia460e89f785e9a2aaf21538083733e7e13730299
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 23c2e67..2d3a58b 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -760,7 +760,7 @@
     deps = [
         ":flatbuffer_translate_registeration",
         # TODO(b/155809683): Link only necessary dialects.
-        "@llvm-project//mlir:AllPassesAndDialects",
+        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
     ],
 )
 
@@ -812,7 +812,7 @@
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:Support",
         # TODO(b/155809683): Link only necessary dialects.
-        "@llvm-project//mlir:AllPassesAndDialects",
+        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Support",
@@ -836,19 +836,18 @@
     deps = [
         ":flatbuffer_translate_lib",
         ":flatbuffer_translate_registeration",
-        "@com_google_absl//absl/strings",
-        "@llvm-project//llvm:Support",
-        # TODO(b/155809683): Link only necessary dialects.
-        "@llvm-project//mlir:AllPassesAndDialects",
-        "@llvm-project//mlir:IR",
-        "@llvm-project//mlir:Parser",
-        "@llvm-project//mlir:Support",
-        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
+        ":tensorflow_lite",
+        "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/core:lib",
         "//tensorflow/core/platform:logging",
         "//tensorflow/lite:framework",
         "//tensorflow/lite/delegates/flex:delegate",
         "//tensorflow/lite/kernels:builtin_ops",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Parser",
+        "@llvm-project//mlir:StandardOps",
     ],
 )
 
@@ -875,7 +874,7 @@
         "//tensorflow/compiler/mlir/tensorflow:translate_lib",
         "//tensorflow/core:core_cpu_base",
         "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:AllPassesAndDialects",
+        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Pass",
         "@llvm-project//mlir:Transforms",
@@ -909,7 +908,7 @@
         "//tensorflow/stream_executor/lib",
         "@com_google_absl//absl/types:span",
         "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:AllPassesAndDialects",
+        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
         "@llvm-project//mlir:IR",
         "@llvm-project//mlir:Parser",
         "@llvm-project//mlir:Pass",
diff --git a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc
index f6da6eb..35a58a0 100644
--- a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc
+++ b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc
@@ -30,12 +30,16 @@
 #include "llvm/Support/MemoryBuffer.h"
 #include "llvm/Support/SMLoc.h"
 #include "llvm/Support/SourceMgr.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/Dialect.h"  // from @llvm-project
 #include "mlir/IR/Function.h"  // from @llvm-project
 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
 #include "mlir/IR/Module.h"  // from @llvm-project
 #include "mlir/Parser.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
 #include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
 #include "tensorflow/core/platform/init_main.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/lite/delegates/flex/delegate.h"
@@ -98,7 +102,10 @@
 
   // Load the MLIR module.
   mlir::MLIRContext context;
-  context.loadAllGloballyRegisteredDialects();
+  context.getDialectRegistry()
+      .insert<mlir::TF::TensorFlowDialect, mlir::TFL::TensorFlowLiteDialect,
+              mlir::StandardOpsDialect>();
+
   llvm::SourceMgr source_mgr;
   source_mgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());
   mlir::OwningModuleRef module(mlir::parseSourceFile(source_mgr, &context));
diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
index 935ad3c..e786bed 100644
--- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
@@ -49,7 +49,6 @@
                                          const GraphDef& input,
                                          string* result) {
   mlir::MLIRContext context;
-  context.loadAllGloballyRegisteredDialects();
   GraphImportConfig specs;
   mlir::TFL::QuantizationSpecs quant_specs;
 
diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
index 5229ee3..529c9ee 100644
--- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
@@ -122,7 +122,6 @@
     const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
     string* result) {
   mlir::MLIRContext context;
-  context.loadAllGloballyRegisteredDialects();
   mlir::TFL::QuantizationSpecs quant_specs;
 
   // Parse input arrays.
diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
index 6299a70..7e7d467 100644
--- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
@@ -62,6 +62,10 @@
 
   void runOnFunction() override;
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<quant::QuantizationDialect>();
+  }
+
   // Parses the serialized quant stats protobuf and initialize the internal
   // data structure. This method must be called after the pass is created.
   bool ParseQuantStats(const std::string &stats_str);
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
index 31c0e4c..38c7ad8 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
@@ -28,6 +28,7 @@
     deps = [
         "//tensorflow/compiler/mlir/lite:common",
         "//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib",
+        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
         "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
         "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
         "//tensorflow/compiler/mlir/tensorflow:error_util",
@@ -74,6 +75,6 @@
         "//tensorflow/lite/schema:schema_fbs",
         "@com_google_absl//absl/strings",
         "@llvm-project//llvm:Support",
-        "@llvm-project//mlir:AllPassesAndDialects",
+        "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
     ],
 )
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc
index 599d809..238710b 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc
@@ -25,6 +25,7 @@
 #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
 #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
 #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
 #include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
@@ -52,7 +53,7 @@
   }
 
   MLIRContext context;
-  context.loadAllGloballyRegisteredDialects();
+  context.getDialectRegistry().insert<mlir::TFL::TensorFlowLiteDialect>();
   StatusScopedDiagnosticHandler statusHandler(&context,
                                               /*propagate=*/true);
 
diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc
index e9e0341..8d9228e 100644
--- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc
+++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc
@@ -37,7 +37,6 @@
                            flatbuffers::FlatBufferBuilder* builder,
                            tflite::ErrorReporter* error_reporter) {
   MLIRContext context;
-  context.loadAllGloballyRegisteredDialects();
   StatusScopedDiagnosticHandler statusHandler(&context,
                                               /*propagate=*/true);
 
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index d69666d..c521ca0 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -40,6 +40,7 @@
 #include "llvm/Support/Debug.h"
 #include "mlir/Analysis/LoopAnalysis.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/FakeQuantSupport.h"  // from @llvm-project
+#include "mlir/Dialect/Quant/QuantOps.h"  // from @llvm-project
 #include "mlir/Dialect/Quant/UniformSupport.h"  // from @llvm-project
 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
 #include "mlir/IR/Attributes.h"  // from @llvm-project
@@ -84,6 +85,11 @@
       : unfold_batch_matmul_(unfold_batch_matmul) {}
   void runOnFunction() override;
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mhlo::MhloDialect, quant::QuantizationDialect,
+                    TFL::TensorFlowLiteDialect>();
+  }
+
  private:
   bool unfold_batch_matmul_;
 };
diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc
index 081ba7a..f26689f 100644
--- a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc
+++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc
@@ -93,8 +93,9 @@
   LstmUtilsTest() {}
 
   void SetUp() override {
-    RegisterDialects();
     context_ = std::make_unique<mlir::MLIRContext>();
+    context_->loadDialect<mlir::StandardOpsDialect, mlir::TF::TensorFlowDialect,
+                          TensorFlowLiteDialect>();
     builder_ = std::unique_ptr<mlir::Builder>(new Builder(context_.get()));
     fused_lstm_func_ = createLstmCompositeFunc(builder_.get(), false, false);
     fused_lstm_func_cifg_ =
@@ -109,12 +110,6 @@
     builder_.reset();
   }
 
-  void RegisterDialects() {
-    mlir::registerDialect<mlir::StandardOpsDialect>();
-    mlir::registerDialect<mlir::TF::TensorFlowDialect>();
-    mlir::registerDialect<TensorFlowLiteDialect>();
-  }
-
   FuncOp fused_lstm_func_;
   FuncOp fused_lstm_func_cifg_;
   FuncOp fused_ln_lstm_func_;