Update for TFL to TOSA Passes not being statically registered
PiperOrigin-RevId: 392949629
Change-Id: If2d73a9fd99ffeafcaf5ae3fa635526f85e37bd7
diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD
index 8afa83c..d1f58b4 100644
--- a/tensorflow/compiler/mlir/BUILD
+++ b/tensorflow/compiler/mlir/BUILD
@@ -104,6 +104,8 @@
"//tensorflow/compiler/mlir/tensorflow:tensorflow_test_passes",
"//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes",
"//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops",
+ "//tensorflow/compiler/mlir/tosa:tf_passes",
+ "//tensorflow/compiler/mlir/tosa:tfl_passes",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
"//tensorflow/compiler/mlir/xla:xla_passes",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc
index 9c98168..2442f2e 100644
--- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc
+++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc
@@ -26,6 +26,9 @@
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/test_passes.h"
#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
+#include "tensorflow/compiler/mlir/tosa/tf_passes.h"
+#include "tensorflow/compiler/mlir/tosa/tfl_passes.h"
+#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/core/platform/init_main.h"
@@ -43,6 +46,9 @@
mlir::mhlo::registerLegalizeTFPass();
mlir::mhlo::registerLegalizeTFControlFlowPass();
mlir::mhlo::registerLegalizeTfTypesPassPass();
+ mlir::tosa::registerLegalizeTosaPasses();
+ mlir::tosa::registerTFtoTOSALegalizationPipeline();
+ mlir::tosa::registerTFLtoTOSALegalizationPipeline();
mlir::tf_test::registerTensorFlowTestPasses();
mlir::DialectRegistry registry;
diff --git a/tensorflow/compiler/mlir/tosa/BUILD b/tensorflow/compiler/mlir/tosa/BUILD
index 84b005c..2b0d17d 100644
--- a/tensorflow/compiler/mlir/tosa/BUILD
+++ b/tensorflow/compiler/mlir/tosa/BUILD
@@ -88,7 +88,6 @@
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:TosaDialect",
],
- alwayslink = 1,
)
gentbl_cc_library(
@@ -137,7 +136,6 @@
"@llvm-project//mlir:TosaDialect",
"@llvm-project//mlir:Transforms",
],
- alwayslink = 1,
)
gentbl_cc_library(
@@ -187,5 +185,4 @@
"@llvm-project//mlir:TosaDialect",
"@llvm-project//mlir:Transforms",
],
- alwayslink = 1,
)
diff --git a/tensorflow/compiler/mlir/tosa/tf_passes.cc b/tensorflow/compiler/mlir/tosa/tf_passes.cc
index 9f58c87..113b761 100644
--- a/tensorflow/compiler/mlir/tosa/tf_passes.cc
+++ b/tensorflow/compiler/mlir/tosa/tf_passes.cc
@@ -56,10 +56,11 @@
pm.addPass(mlir::createSymbolDCEPass());
}
-static mlir::PassPipelineRegistration<TOSATFLegalizationPipelineOptions>
- tf_tosa_pipeline("tf-to-tosa-pipeline",
- "TensorFlow to TOSA legalization pipeline",
- createTFtoTOSALegalizationPipeline);
+void registerTFtoTOSALegalizationPipeline() {
+ mlir::PassPipelineRegistration<TOSATFLegalizationPipelineOptions>(
+ "tf-to-tosa-pipeline", "TensorFlow to TOSA legalization pipeline",
+ createTFtoTOSALegalizationPipeline);
+}
} // namespace tosa
} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/tf_passes.h b/tensorflow/compiler/mlir/tosa/tf_passes.h
index 18d11cd..741bc23 100644
--- a/tensorflow/compiler/mlir/tosa/tf_passes.h
+++ b/tensorflow/compiler/mlir/tosa/tf_passes.h
@@ -29,6 +29,8 @@
void createTFtoTOSALegalizationPipeline(
OpPassManager& pm, const TOSATFLegalizationPipelineOptions& opts);
+void registerTFtoTOSALegalizationPipeline();
+
} // namespace tosa
} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/tfl_passes.cc b/tensorflow/compiler/mlir/tosa/tfl_passes.cc
index ff06d06..a807fbf 100644
--- a/tensorflow/compiler/mlir/tosa/tfl_passes.cc
+++ b/tensorflow/compiler/mlir/tosa/tfl_passes.cc
@@ -59,10 +59,11 @@
pm.addPass(mlir::createSymbolDCEPass());
}
-static mlir::PassPipelineRegistration<TOSATFLLegalizationPipelineOptions>
- tfl_tosa_pipeline("tfl-to-tosa-pipeline",
- "TensorFlow Lite to TOSA legalization pipeline",
- createTFLtoTOSALegalizationPipeline);
+void registerTFLtoTOSALegalizationPipeline() {
+ mlir::PassPipelineRegistration<TOSATFLLegalizationPipelineOptions>(
+ "tfl-to-tosa-pipeline", "TensorFlow Lite to TOSA legalization pipeline",
+ createTFLtoTOSALegalizationPipeline);
+}
} // namespace tosa
} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/tfl_passes.h b/tensorflow/compiler/mlir/tosa/tfl_passes.h
index 255418a..21e239f 100644
--- a/tensorflow/compiler/mlir/tosa/tfl_passes.h
+++ b/tensorflow/compiler/mlir/tosa/tfl_passes.h
@@ -29,6 +29,8 @@
void createTFLtoTOSALegalizationPipeline(
OpPassManager& pm, const TOSATFLLegalizationPipelineOptions& opts);
+void registerTFLtoTOSALegalizationPipeline();
+
} // namespace tosa
} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc
index 5202a50..33557ff 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc
@@ -352,9 +352,6 @@
return std::make_unique<ConvertUint8ToInt8>();
}
-static PassRegistration<ConvertUint8ToInt8> pass(
- PASS_NAME, "Convert uint8 graph to int8.");
-
} // namespace tosa
} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc
index 8dd1277..dda1e87 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc
@@ -123,9 +123,6 @@
return std::make_unique<FuseBiasTF>();
}
-static PassRegistration<FuseBiasTF> pass(
- PASS_NAME, "Fuse tf.Op + tf.BiasAdd and legalized to TOSA.");
-
} // namespace tosa
} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
index 8d2af6c..da146aa 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc
@@ -2369,9 +2369,6 @@
return std::make_unique<LegalizeTF>();
}
-static PassRegistration<LegalizeTF> pass(
- PASS_NAME, "Legalize from TensorFlow to TOSA dialect");
-
} // namespace tosa
} // namespace mlir
diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
index fc8f2e2..fa56317 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc
@@ -3086,8 +3086,5 @@
return std::make_unique<LegalizeTFL>();
}
-static PassRegistration<LegalizeTFL> pass(
- PASS_NAME, "Legalize from TensorFlow Lite to TOSA dialect");
-
} // namespace tosa
} // namespace mlir