blob: f368988b5b40e30d1591d10317e924db5de54215 [file] [log] [blame]
//===- DialectHooks.h - MLIR DialectHooks mechanism -------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines abstraction and registration mechanism for dialect hooks.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_DIALECT_HOOKS_H
#define MLIR_IR_DIALECT_HOOKS_H
#include "mlir/IR/Dialect.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
using DialectHooksSetter = std::function<void(MLIRContext *)>;
/// Dialect hooks allow external components to register their functions to
/// be called for specific tasks specialized per dialect, such as decoding
/// of opaque constants. To register concrete dialect hooks, one should
/// define a DialectHooks subclass and use it as a template
/// argument to DialectHooksRegistration. For example,
/// class MyHooks : public DialectHooks {...};
/// static DialectHooksRegistration<MyHooks, MyDialect> hooksReg;
/// The subclass should override DialectHook methods for supported hooks.
class DialectHooks {
public:
// Returns hook to constant fold an operation.
DialectConstantFoldHook getConstantFoldHook() { return nullptr; }
// Returns hook to decode opaque constant tensor.
DialectConstantDecodeHook getDecodeHook() { return nullptr; }
// Returns hook to extract an element of an opaque constant tensor.
DialectExtractElementHook getExtractElementHook() { return nullptr; }
};
/// Registers a function that will set hooks in the registered dialects
/// based on information coming from DialectHooksRegistration.
void registerDialectHooksSetter(const DialectHooksSetter &function);
/// DialectHooksRegistration provides a global initialiser that registers
/// a dialect hooks setter routine.
/// Usage:
///
/// // At namespace scope.
/// static DialectHooksRegistration<MyHooks, MyDialect> unused;
template <typename ConcreteHooks> struct DialectHooksRegistration {
DialectHooksRegistration(StringRef dialectName) {
registerDialectHooksSetter([dialectName](MLIRContext *ctx) {
Dialect *dialect = ctx->getRegisteredDialect(dialectName);
if (!dialect) {
llvm::errs() << "error: cannot register hooks for unknown dialect '"
<< dialectName << "'\n";
abort();
}
// Set hooks.
ConcreteHooks hooks;
if (auto h = hooks.getConstantFoldHook())
dialect->constantFoldHook = h;
if (auto h = hooks.getDecodeHook())
dialect->decodeHook = h;
if (auto h = hooks.getExtractElementHook())
dialect->extractElementHook = h;
});
}
};
} // namespace mlir
#endif