Separate translators into "from MLIR" and "to MLIR".
Translations performed by mlir-translate only have MLIR on one end.
MLIR-to-MLIR conversions (including dialect changes) should be treated as
passes and run by mlir-opt. Individual translations should not care about
reading or writing MLIR and should work on in-memory representation of MLIR
modules instead. Split the TranslateFunction interface and the translate
registry into two parts: "from MLIR" and "to MLIR".
Update mlir-translate to handle both registries together by wrapping
translation functions into source-to-source convresions. Remove MLIR parsing
and writing from individual translations and make them operate on Modules
instead. This removes the need for individual translators to include
tools/mlir-translate/mlir-translate.h, which can now be safely removed.
Remove mlir-to-mlir translation that only existed as a registration example and
use mlir-opt instead for tests.
PiperOrigin-RevId: 222398707
diff --git a/include/mlir/Translation.h b/include/mlir/Translation.h
index 9968605..78518e9 100644
--- a/include/mlir/Translation.h
+++ b/include/mlir/Translation.h
@@ -28,26 +28,42 @@
class MLIRContext;
class Module;
-using TranslateFunction =
- std::function<bool(llvm::StringRef inputFilename,
- llvm::StringRef oututFilename, MLIRContext *)>;
+/// Interface of the function that translates a file to MLIR. The
+/// implementation should create a new MLIR Module in the given context and
+/// return a pointer to it, or a nullptr in case of any error.
+using TranslateToMLIRFunction =
+ std::function<std::unique_ptr<Module>(llvm::StringRef, MLIRContext *)>;
+/// Interface of the function that translates MLIR to a different format and
+/// outputs the result to a file. The implementation should return "true" on
+/// error and "false" otherwise. It is allowed to modify the module.
+using TranslateFromMLIRFunction =
+ std::function<bool(Module *, llvm::StringRef)>;
-// Use TranslateRegistration as a global initialiser that registers a function
-// and associates it with name. This requires that a translation has not been
-// registered to a given name.
-//
-// Usage:
-//
-// // At namespace scope.
-// static TranslateRegistration Unused(&MySubCommand, [] { ... });
-//
-struct TranslateRegistration {
- TranslateRegistration(llvm::StringRef name,
- const TranslateFunction &function);
+/// Use Translate[To|From]MLIRRegistration as a global initialiser that
+/// registers a function and associates it with name. This requires that a
+/// translation has not been registered to a given name.
+///
+/// Usage:
+///
+/// // At namespace scope.
+/// static TranslateToMLIRRegistration Unused(&MySubCommand, [] { ... });
+///
+/// \{
+struct TranslateToMLIRRegistration {
+ TranslateToMLIRRegistration(llvm::StringRef name,
+ const TranslateToMLIRFunction &function);
};
+struct TranslateFromMLIRRegistration {
+ TranslateFromMLIRRegistration(llvm::StringRef name,
+ const TranslateFromMLIRFunction &function);
+};
+/// \}
+
/// Get a read-only reference to the translator registry.
-const llvm::StringMap<TranslateFunction> &getTranslationRegistry();
+const llvm::StringMap<TranslateToMLIRFunction> &getTranslationToMLIRRegistry();
+const llvm::StringMap<TranslateFromMLIRFunction> &
+getTranslationFromMLIRRegistry();
} // namespace mlir
diff --git a/lib/Translation/Translation.cpp b/lib/Translation/Translation.cpp
index 96765c71b..c1a9f9d 100644
--- a/lib/Translation/Translation.cpp
+++ b/lib/Translation/Translation.cpp
@@ -25,24 +25,50 @@
using namespace mlir;
-// Get the mutable static map between translations registered and the
-// TranslateFunctions that perform those translations.
-static llvm::StringMap<TranslateFunction> &getMutableTranslationRegistry() {
- static llvm::StringMap<TranslateFunction> translationRegistry;
- return translationRegistry;
+// Get the mutable static map between registered "to MLIR" translations and the
+// TranslateToMLIRFunctions that perform those translations.
+static llvm::StringMap<TranslateToMLIRFunction> &
+getMutableTranslationToMLIRRegistry() {
+ static llvm::StringMap<TranslateToMLIRFunction> translationToMLIRRegistry;
+ return translationToMLIRRegistry;
+}
+// Get the mutable static map between registered "from MLIR" translations and
+// the TranslateFromMLIRFunctions that perform those translations.
+static llvm::StringMap<TranslateFromMLIRFunction> &
+getMutableTranslationFromMLIRRegistry() {
+ static llvm::StringMap<TranslateFromMLIRFunction> translationFromMLIRRegistry;
+ return translationFromMLIRRegistry;
}
-TranslateRegistration::TranslateRegistration(
- StringRef name, const TranslateFunction &function) {
- auto &translationRegistry = getMutableTranslationRegistry();
- if (translationRegistry.find(name) != translationRegistry.end())
- llvm::report_fatal_error("Attempting to overwrite an existing function");
- assert(function && "Attempting to register an empty translate function");
- translationRegistry[name] = function;
+TranslateToMLIRRegistration::TranslateToMLIRRegistration(
+ StringRef name, const TranslateToMLIRFunction &function) {
+ auto &translationToMLIRRegistry = getMutableTranslationToMLIRRegistry();
+ if (translationToMLIRRegistry.find(name) != translationToMLIRRegistry.end())
+ llvm::report_fatal_error(
+ "Attempting to overwrite an existing <to> function");
+ assert(function && "Attempting to register an empty translate <to> function");
+ translationToMLIRRegistry[name] = function;
+}
+
+TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
+ StringRef name, const TranslateFromMLIRFunction &function) {
+ auto &translationFromMLIRRegistry = getMutableTranslationFromMLIRRegistry();
+ if (translationFromMLIRRegistry.find(name) !=
+ translationFromMLIRRegistry.end())
+ llvm::report_fatal_error(
+ "Attempting to overwrite an existing <from> function");
+ assert(function && "Attempting to register an empty translate <to> function");
+ translationFromMLIRRegistry[name] = function;
}
// Merely add the const qualifier to the mutable registry so that external users
// cannot modify it.
-const llvm::StringMap<TranslateFunction> &mlir::getTranslationRegistry() {
- return getMutableTranslationRegistry();
+const llvm::StringMap<TranslateToMLIRFunction> &
+mlir::getTranslationToMLIRRegistry() {
+ return getMutableTranslationToMLIRRegistry();
+}
+
+const llvm::StringMap<TranslateFromMLIRFunction> &
+mlir::getTranslationFromMLIRRegistry() {
+ return getMutableTranslationFromMLIRRegistry();
}
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index a557ac8..66d8ceb 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-translate -mlir-to-mlir %s | FileCheck %s
+// RUN: mlir-opt %s | FileCheck %s
// CHECK-DAG: #map{{[0-9]+}} = (d0, d1, d2, d3, d4)[s0] -> (d0, d1, d2, d4, d3)
#map0 = (d0, d1, d2, d3, d4)[s0] -> (d0, d1, d2, d4, d3)
diff --git a/tools/mlir-translate/mlir-translate.cpp b/tools/mlir-translate/mlir-translate.cpp
index 32e6b2e..d403cca 100644
--- a/tools/mlir-translate/mlir-translate.cpp
+++ b/tools/mlir-translate/mlir-translate.cpp
@@ -20,7 +20,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir-translate.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/Parser.h"
@@ -43,7 +42,7 @@
outputFilename("o", llvm::cl::desc("Output filename"),
llvm::cl::value_desc("filename"), llvm::cl::init("-"));
-Module *mlir::parseMLIRInput(StringRef inputFilename, MLIRContext *context) {
+static Module *parseMLIRInput(StringRef inputFilename, MLIRContext *context) {
// Set up the input file.
auto fileOrErr = llvm::MemoryBuffer::getFileOrSTDIN(inputFilename);
if (std::error_code error = fileOrErr.getError()) {
@@ -57,8 +56,8 @@
return parseSourceFile(sourceMgr, context);
}
-std::unique_ptr<llvm::ToolOutputFile>
-mlir::openOutputFile(llvm::StringRef outputFilename) {
+static std::unique_ptr<llvm::ToolOutputFile>
+openOutputFile(llvm::StringRef outputFilename) {
std::error_code error;
auto result = llvm::make_unique<llvm::ToolOutputFile>(outputFilename, error,
llvm::sys::fs::F_None);
@@ -70,8 +69,8 @@
return result;
}
-bool mlir::printMLIROutput(const Module &module,
- llvm::StringRef outputFilename) {
+static bool printMLIROutput(const Module &module,
+ llvm::StringRef outputFilename) {
auto file = openOutputFile(outputFilename);
if (!file)
return true;
@@ -80,24 +79,48 @@
return false;
}
-// Example translation registration. This performs a MLIR to MLIR "translation"
-// which simply parses and prints the MLIR input file.
-static TranslateRegistration MLIRToMLIRTranslate(
- "mlir-to-mlir", [](StringRef inputFilename, StringRef outputFilename,
- MLIRContext *context) {
- std::unique_ptr<Module> module(parseMLIRInput(inputFilename, context));
- if (!module)
- return true;
+// Common interface for source-to-source translation functions.
+using TranslateFunction =
+ std::function<bool(StringRef, StringRef, MLIRContext *)>;
- return printMLIROutput(*module, outputFilename);
- });
+// Storage for the translation function wrappers that survive the parser.
+static llvm::SmallVector<TranslateFunction, 8> wrapperStorage;
// Custom parser for TranslateFunction.
+// Wraps TranslateToMLIRFunctions and TranslateFromMLIRFunctions into
+// TranslateFunctions before registering them as options.
struct TranslationParser : public llvm::cl::parser<const TranslateFunction *> {
TranslationParser(llvm::cl::Option &opt)
: llvm::cl::parser<const TranslateFunction *>(opt) {
- for (const auto &kv : getTranslationRegistry()) {
- addLiteralOption(kv.first(), &kv.second, kv.first());
+ for (const auto &kv : getTranslationToMLIRRegistry()) {
+ TranslateToMLIRFunction function = kv.second;
+ TranslateFunction wrapper = [function](StringRef inputFilename,
+ StringRef outputFilename,
+ MLIRContext *context) {
+ std::unique_ptr<Module> module = function(inputFilename, context);
+ if (!module)
+ return true;
+ printMLIROutput(*module, outputFilename);
+ return false;
+ };
+ wrapperStorage.emplace_back(std::move(wrapper));
+
+ addLiteralOption(kv.first(), &wrapperStorage.back(), kv.first());
+ }
+ for (const auto &kv : getTranslationFromMLIRRegistry()) {
+ TranslateFromMLIRFunction function = kv.second;
+ TranslateFunction wrapper = [function](StringRef inputFilename,
+ StringRef outputFilename,
+ MLIRContext *context) {
+ auto module =
+ std::unique_ptr<Module>(parseMLIRInput(inputFilename, context));
+ if (!module)
+ return true;
+ return function(module.get(), outputFilename);
+ };
+ wrapperStorage.emplace_back(std::move(wrapper));
+
+ addLiteralOption(kv.first(), &wrapperStorage.back(), kv.first());
}
}
diff --git a/tools/mlir-translate/mlir-translate.h b/tools/mlir-translate/mlir-translate.h
deleted file mode 100644
index 758c17b..0000000
--- a/tools/mlir-translate/mlir-translate.h
+++ /dev/null
@@ -1,48 +0,0 @@
-//===- mlir-translate.h - Translation driver -----------------*- C++ -*----===//
-//
-// Copyright 2019 The MLIR Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-// =============================================================================
-//
-// Registry for user provided translations and common utility functions for
-// translations.
-//
-//===----------------------------------------------------------------------===//
-#ifndef TOOLS_MLIR_TRANSLATE_H
-#define TOOLS_MLIR_TRANSLATE_H
-
-#include "mlir/Support/LLVM.h"
-#include <memory>
-
-namespace llvm {
-class ToolOutputFile;
-}
-
-namespace mlir {
-class MLIRContext;
-class Module;
-
-/// Open a file to be used as raw_ostream.
-std::unique_ptr<llvm::ToolOutputFile>
-openOutputFile(llvm::StringRef outputFilename);
-
-// Returns module parsed from input filename or null in case of error.
-Module *parseMLIRInput(llvm::StringRef inputFilename, MLIRContext *context);
-
-// Prints module to outputFilename and returns whether printing module failed.
-bool printMLIROutput(const Module &module, llvm::StringRef outputFilename);
-
-} // namespace mlir
-
-#endif // TOOLS_MLIR_TRANSLATE_H