blob: d95822558f1224d43141e2891d177bced139cb1d [file] [log] [blame]
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Mechanism to instantiate classes by name.
//
// This mechanism is useful if the concrete classes to be instantiated are not
// statically known (e.g., if their names are read from a dynamically-provided
// config).
//
// In that case, the first step is to define the API implemented by the
// instantiated classes. E.g.,
//
// // In a header file function.h:
//
// // Abstract function that takes a double and returns a double.
// class Function : public RegisterableClass<Function> {
// public:
// virtual ~Function() {}
// virtual double Evaluate(double x) = 0;
// };
//
// // Should be inside namespace libtextclassifier::nlp_core.
// TC_DECLARE_CLASS_REGISTRY_NAME(Function);
//
// Notice the inheritance from RegisterableClass<Function>. RegisterableClass
// is defined by this file (registry.h). Under the hood, this inheritanace
// defines a "registry" that maps names (zero-terminated arrays of chars) to
// factory methods that create Functions. You should give a human-readable name
// to this registry. To do that, use the following macro in a .cc file (it has
// to be a .cc file, as it defines some static data):
//
// // Inside function.cc
// // Should be inside namespace libtextclassifier::nlp_core.
// TC_DEFINE_CLASS_REGISTRY_NAME("function", Function);
//
// Now, let's define a few concrete Functions: e.g.,
//
// class Cos : public Function {
// public:
// double Evaluate(double x) override { return cos(x); }
// TC_DEFINE_REGISTRATION_METHOD("cos", Cos);
// };
//
// class Exp : public Function {
// public:
// double Evaluate(double x) override { return exp(x); }
// TC_DEFINE_REGISTRATION_METHOD("sin", Sin);
// };
//
// Each concrete Function implementation should have (in the public section) the
// macro
//
// TC_DEFINE_REGISTRATION_METHOD("name", implementation_class);
//
// This defines a RegisterClass static method that, when invoked, associates
// "name" with a factory method that creates instances of implementation_class.
//
// Before instantiating Functions by name, we need to tell our system which
// Functions we may be interested in. This is done by calling the
// Foo::RegisterClass() for each relevant Foo implementation of Function. It is
// ok to call Foo::RegisterClass() multiple times (even in parallel): only the
// first call will perform something, the others will return immediately.
//
// Cos::RegisterClass();
// Exp::RegisterClass();
//
// Now, let's instantiate a Function based on its name. This get a lot more
// interesting if the Function name is not statically known (i.e.,
// read from an input proto:
//
// std::unique_ptr<Function> f(Function::Create("cos"));
// double result = f->Evaluate(arg);
//
// NOTE: the same binary can use this mechanism for different APIs. E.g., one
// can also have (in the binary with Function, Sin, Cos, etc):
//
// class IntFunction : public RegisterableClass<IntFunction> {
// public:
// virtual ~IntFunction() {}
// virtual int Evaluate(int k) = 0;
// };
//
// TC_DECLARE_CLASS_REGISTRY_NAME(IntFunction);
//
// TC_DEFINE_CLASS_REGISTRY_NAME("int function", IntFunction);
//
// class Inc : public IntFunction {
// public:
// int Evaluate(int k) override { return k + 1; }
// TC_DEFINE_REGISTRATION_METHOD("inc", Inc);
// };
//
// RegisterableClass<Function> and RegisterableClass<IntFunction> define their
// own registries: each maps string names to implementation of the corresponding
// API.
#ifndef LIBTEXTCLASSIFIER_COMMON_REGISTRY_H_
#define LIBTEXTCLASSIFIER_COMMON_REGISTRY_H_
#include <stdlib.h>
#include <string.h>
#include <string>
#include "util/base/logging.h"
namespace libtextclassifier {
namespace nlp_core {
namespace internal {
// Registry that associates keys (zero-terminated array of chars) with values.
// Values are pointers to type T (the template parameter). This is used to
// store the association between component names and factory methods that
// produce those components; the error messages are focused on that case.
//
// Internally, this registry uses a linked list of (key, value) pairs. We do
// not use an STL map, list, etc because we aim for small code size.
template <class T>
class ComponentRegistry {
public:
explicit ComponentRegistry(const char *name) : name_(name), head_(nullptr) {}
// Adds a the (key, value) pair to this registry (if the key does not already
// exists in this registry) and returns true. If the registry already has a
// mapping for key, returns false and does not modify the registry. NOTE: the
// error (false) case happens even if the existing value for key is equal with
// the new one.
//
// This method does not take ownership of key, nor of value.
bool Add(const char *key, T *value) {
const Cell *old_cell = FindCell(key);
if (old_cell != nullptr) {
TC_LOG(ERROR) << "Duplicate component: " << key;
return false;
}
Cell *new_cell = new Cell(key, value, head_);
head_ = new_cell;
return true;
}
// Returns the value attached to a key in this registry. Returns nullptr on
// error (e.g., unknown key).
T *Lookup(const char *key) const {
const Cell *cell = FindCell(key);
if (cell == nullptr) {
TC_LOG(ERROR) << "Unknown " << name() << " component: " << key;
}
return (cell == nullptr) ? nullptr : cell->value();
}
T *Lookup(const std::string &key) const { return Lookup(key.c_str()); }
// Returns name of this ComponentRegistry.
const char *name() const { return name_; }
private:
// Cell for the singly-linked list underlying this ComponentRegistry. Each
// cell contains a key, the value for that key, as well as a pointer to the
// next Cell from the list.
class Cell {
public:
// Constructs a new Cell.
Cell(const char *key, T *value, Cell *next)
: key_(key), value_(value), next_(next) {}
const char *key() const { return key_; }
T *value() const { return value_; }
Cell *next() const { return next_; }
private:
const char *const key_;
T *const value_;
Cell *const next_;
};
// Finds Cell for indicated key in the singly-linked list pointed to by head_.
// Returns pointer to that first Cell with that key, or nullptr if no such
// Cell (i.e., unknown key).
//
// Caller does NOT own the returned pointer.
const Cell *FindCell(const char *key) const {
Cell *c = head_;
while (c != nullptr && strcmp(key, c->key()) != 0) {
c = c->next();
}
return c;
}
// Human-readable description for this ComponentRegistry. For debug purposes.
const char *const name_;
// Pointer to the first Cell from the underlying list of (key, value) pairs.
Cell *head_;
};
} // namespace internal
// Base class for registerable classes.
template <class T>
class RegisterableClass {
public:
// Factory function type.
typedef T *(Factory)();
// Registry type.
typedef internal::ComponentRegistry<Factory> Registry;
// Creates a new instance of T. Returns pointer to new instance or nullptr in
// case of errors (e.g., unknown component).
//
// Passes ownership of the returned pointer to the caller.
static T *Create(const std::string &name) { // NOLINT
auto *factory = registry()->Lookup(name);
if (factory == nullptr) {
TC_LOG(ERROR) << "Unknown RegisterableClass " << name;
return nullptr;
}
return factory();
}
// Returns registry for class.
static Registry *registry() {
static Registry *registry_for_type_t = new Registry(kRegistryName);
return registry_for_type_t;
}
protected:
// Factory method for subclass ComponentClass. Used internally by the static
// method RegisterClass() defined by TC_DEFINE_REGISTRATION_METHOD.
template <class ComponentClass>
static T *_internal_component_factory() {
return new ComponentClass();
}
private:
// Human-readable name for the registry for this class.
static const char kRegistryName[];
};
// Defines the static method component_class::RegisterClass() that should be
// called before trying to instantiate component_class by name. Should be used
// inside the public section of the declaration of component_class. See
// comments at the top-level of this file.
#define TC_DEFINE_REGISTRATION_METHOD(component_name, component_class) \
static void RegisterClass() { \
static bool once = registry()->Add( \
component_name, &_internal_component_factory<component_class>); \
if (!once) { \
TC_LOG(ERROR) << "Problem registering " << component_name; \
} \
TC_DCHECK(once); \
}
// Defines the human-readable name of the registry associated with base_class.
#define TC_DECLARE_CLASS_REGISTRY_NAME(base_class) \
template <> \
const char ::libtextclassifier::nlp_core::RegisterableClass< \
base_class>::kRegistryName[]
// Defines the human-readable name of the registry associated with base_class.
#define TC_DEFINE_CLASS_REGISTRY_NAME(registry_name, base_class) \
template <> \
const char ::libtextclassifier::nlp_core::RegisterableClass< \
base_class>::kRegistryName[] = registry_name
} // namespace nlp_core
} // namespace libtextclassifier
#endif // LIBTEXTCLASSIFIER_COMMON_REGISTRY_H_