Remove dependency on Dialect global registration from //tensorflow/compiler/mlir/lite/...
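Instead of relying on mlir::registerDialect<>() and
MLIRContext::loadAllGloballyRegisteredDialects(), tools now register only
the dialects they actually use on the context, and passes declare the
dialects they may create via getDependentDialects(). A minimal sketch of
the pattern used throughout this change (the exact dialect list varies per
target):

  // Tool entry point: register only the dialects this binary needs.
  mlir::MLIRContext context;
  context.getDialectRegistry()
      .insert<mlir::TF::TensorFlowDialect, mlir::TFL::TensorFlowLiteDialect,
              mlir::StandardOpsDialect>();

  // Pass: declare any dialect the pass may introduce into the IR.
  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    registry.insert<quant::QuantizationDialect>();
  }
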
PiperOrigin-RevId: 328109152
Change-Id: Ia460e89f785e9a2aaf21538083733e7e13730299
diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD
index 23c2e67..2d3a58b 100644
--- a/tensorflow/compiler/mlir/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/BUILD
@@ -760,7 +760,7 @@
deps = [
":flatbuffer_translate_registeration",
# TODO(b/155809683): Link only necessary dialects.
- "@llvm-project//mlir:AllPassesAndDialects",
+ "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
],
)
@@ -812,7 +812,7 @@
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
# TODO(b/155809683): Link only necessary dialects.
- "@llvm-project//mlir:AllPassesAndDialects",
+ "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
@@ -836,19 +836,18 @@
deps = [
":flatbuffer_translate_lib",
":flatbuffer_translate_registeration",
- "@com_google_absl//absl/strings",
- "@llvm-project//llvm:Support",
- # TODO(b/155809683): Link only necessary dialects.
- "@llvm-project//mlir:AllPassesAndDialects",
- "@llvm-project//mlir:IR",
- "@llvm-project//mlir:Parser",
- "@llvm-project//mlir:Support",
- "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
+ ":tensorflow_lite",
+ "//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"//tensorflow/lite:framework",
"//tensorflow/lite/delegates/flex:delegate",
"//tensorflow/lite/kernels:builtin_ops",
+ "@com_google_absl//absl/strings",
+ "@llvm-project//llvm:Support",
+ "@llvm-project//mlir:IR",
+ "@llvm-project//mlir:Parser",
+ "@llvm-project//mlir:StandardOps",
],
)
@@ -875,7 +874,7 @@
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/core:core_cpu_base",
"@llvm-project//llvm:Support",
- "@llvm-project//mlir:AllPassesAndDialects",
+ "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Transforms",
@@ -909,7 +908,7 @@
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
- "@llvm-project//mlir:AllPassesAndDialects",
+ "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
diff --git a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc
index f6da6eb..35a58a0 100644
--- a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc
+++ b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc
@@ -30,12 +30,16 @@
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
+#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Parser.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/delegates/flex/delegate.h"
@@ -98,7 +102,10 @@
// Load the MLIR module.
mlir::MLIRContext context;
- context.loadAllGloballyRegisteredDialects();
+ context.getDialectRegistry()
+ .insert<mlir::TF::TensorFlowDialect, mlir::TFL::TensorFlowLiteDialect,
+ mlir::StandardOpsDialect>();
+
llvm::SourceMgr source_mgr;
source_mgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc());
mlir::OwningModuleRef module(mlir::parseSourceFile(source_mgr, &context));
diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
index 935ad3c..e786bed 100644
--- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc
@@ -49,7 +49,6 @@
const GraphDef& input,
string* result) {
mlir::MLIRContext context;
- context.loadAllGloballyRegisteredDialects();
GraphImportConfig specs;
mlir::TFL::QuantizationSpecs quant_specs;
diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
index 5229ee3..529c9ee 100644
--- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
+++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc
@@ -122,7 +122,6 @@
const toco::ModelFlags& model_flags, const toco::TocoFlags& toco_flags,
string* result) {
mlir::MLIRContext context;
- context.loadAllGloballyRegisteredDialects();
mlir::TFL::QuantizationSpecs quant_specs;
// Parse input arrays.
diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
index 6299a70..7e7d467 100644
--- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc
@@ -62,6 +62,10 @@
void runOnFunction() override;
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<quant::QuantizationDialect>();
+ }
+
// Parses the serialized quant stats protobuf and initialize the internal
// data structure. This method must be called after the pass is created.
bool ParseQuantStats(const std::string &stats_str);
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
index 31c0e4c..38c7ad8 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD
@@ -28,6 +28,7 @@
deps = [
"//tensorflow/compiler/mlir/lite:common",
"//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib",
+ "//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize",
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"//tensorflow/compiler/mlir/tensorflow:error_util",
@@ -74,6 +75,6 @@
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
- "@llvm-project//mlir:AllPassesAndDialects",
+ "@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
],
)
diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc
index 599d809..238710b 100644
--- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc
+++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc
@@ -25,6 +25,7 @@
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h"
+#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h"
@@ -52,7 +53,7 @@
}
MLIRContext context;
- context.loadAllGloballyRegisteredDialects();
+ context.getDialectRegistry().insert<mlir::TFL::TensorFlowLiteDialect>();
StatusScopedDiagnosticHandler statusHandler(&context,
/*propagate=*/true);
diff --git a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc
index e9e0341..8d9228e 100644
--- a/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc
+++ b/tensorflow/compiler/mlir/lite/sparsity/sparsify_model.cc
@@ -37,7 +37,6 @@
flatbuffers::FlatBufferBuilder* builder,
tflite::ErrorReporter* error_reporter) {
MLIRContext context;
- context.loadAllGloballyRegisteredDialects();
StatusScopedDiagnosticHandler statusHandler(&context,
/*propagate=*/true);
diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
index d69666d..c521ca0 100644
--- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
+++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc
@@ -40,6 +40,7 @@
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project
#include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project
+#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
@@ -84,6 +85,11 @@
: unfold_batch_matmul_(unfold_batch_matmul) {}
void runOnFunction() override;
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<mhlo::MhloDialect, quant::QuantizationDialect,
+ TFL::TensorFlowLiteDialect>();
+ }
+
private:
bool unfold_batch_matmul_;
};
diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc
index 081ba7a..f26689f 100644
--- a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc
+++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc
@@ -93,8 +93,9 @@
LstmUtilsTest() {}
void SetUp() override {
- RegisterDialects();
context_ = std::make_unique<mlir::MLIRContext>();
+ context_->loadDialect<mlir::StandardOpsDialect, mlir::TF::TensorFlowDialect,
+ TensorFlowLiteDialect>();
builder_ = std::unique_ptr<mlir::Builder>(new Builder(context_.get()));
fused_lstm_func_ = createLstmCompositeFunc(builder_.get(), false, false);
fused_lstm_func_cifg_ =
@@ -109,12 +110,6 @@
builder_.reset();
}
- void RegisterDialects() {
- mlir::registerDialect<mlir::StandardOpsDialect>();
- mlir::registerDialect<mlir::TF::TensorFlowDialect>();
- mlir::registerDialect<TensorFlowLiteDialect>();
- }
-
FuncOp fused_lstm_func_;
FuncOp fused_lstm_func_cifg_;
FuncOp fused_ln_lstm_func_;