blob: e003bde74f15a72736878a5693ad1f96389f5281 [file] [log] [blame]
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Notes on thread-safety: All of the classes here are thread-compatible. More
// specifically, the registry machinery is thread-safe, as long as each thread
// performs feature extraction on a different Sentence object.
#ifndef LIBTEXTCLASSIFIER_COMMON_WORKSPACE_H_
#define LIBTEXTCLASSIFIER_COMMON_WORKSPACE_H_
#include <stddef.h>

#include <algorithm>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "util/base/logging.h"
#include "util/base/macros.h"
namespace libtextclassifier {
namespace nlp_core {
// Root class for every shared workspace type. Concrete subclasses are expected
// to provide a static member function TypeName() returning a human readable
// std::string name for the class; WorkspaceRegistry::Request() calls
// W::TypeName() when a workspace type is registered.
class Workspace {
 public:
  // Virtual so a derived workspace can be deleted through a Workspace pointer.
  virtual ~Workspace() = default;

 protected:
  // Only subclasses can create a (empty) workspace.
  Workspace() = default;

 private:
  TC_DISALLOW_COPY_AND_ASSIGN(Workspace);
};
// Hands out a new int on every call; per its contract the results are
// strictly increasing. Only declared here; the definition lives elsewhere.
int GetFreshTypeId();

// Stand-in for typeid that needs no RTTI: each distinct T gets its own int,
// obtained from GetFreshTypeId() when TypeId<T>::type_id is initialized.
// NOTE(review): this is dynamic initialization of a template static data
// member, whose order relative to other static initializers is unspecified.
// Code running during static initialization may observe the zero-initialized
// value. Confirm there are no such callers.
template <class T>
struct TypeId {
  static int type_id;
};

template <class T>
int TypeId<T>::type_id{GetFreshTypeId()};
// A registry that keeps track of workspaces, keyed first by workspace type
// (the TypeId of the workspace class) and then, within a type, by name.
class WorkspaceRegistry {
 public:
  // Creates an empty registry.
  WorkspaceRegistry() {}

  // Returns the index of the workspace of type W called `name`, adding it to
  // the registry first if necessary. Indices are per type: they start at 0 and
  // count up in the order names are first requested. W must provide a static
  // TypeName() function.
  template <class W>
  int Request(const std::string &name) {
    const int id = TypeId<W>::type_id;
    max_workspace_id_ = std::max(id, max_workspace_id_);
    workspace_types_[id] = W::TypeName();
    std::vector<std::string> &names = workspace_names_[id];

    // Reuse the existing slot if this name was requested before.
    const auto found = std::find(names.begin(), names.end(), name);
    if (found != names.end()) {
      return static_cast<int>(found - names.begin());
    }
    names.push_back(name);
    return static_cast<int>(names.size() - 1);
  }

  // Returns the maximum workspace type id that has been registered, or 0 if
  // nothing has been registered yet.
  int MaxId() const { return max_workspace_id_; }

  // Returns the registered workspace names as a map from type id to names.
  // A name's position in its vector is the index Request() returned for it.
  const std::unordered_map<int, std::vector<std::string> > &WorkspaceNames()
      const {
    return workspace_names_;
  }

  // Returns a std::string describing the registered workspaces.
  std::string DebugString() const;

 private:
  // Workspace type names, indexed as workspace_types_[typeid].
  std::unordered_map<int, std::string> workspace_types_;

  // Workspace names, indexed as workspace_names_[typeid][workspace].
  std::unordered_map<int, std::vector<std::string> > workspace_names_;

  // The maximum workspace id that has been registered.
  int max_workspace_id_ = 0;

  TC_DISALLOW_COPY_AND_ASSIGN(WorkspaceRegistry);
};
// A typed collected of workspaces. The workspaces are indexed according to an
// external WorkspaceRegistry. If the WorkspaceSet is const, the contents are
// also immutable.
class WorkspaceSet {
public:
~WorkspaceSet() { Reset(WorkspaceRegistry()); }
// Returns true if a workspace has been set.
template <class W>
bool Has(int index) const {
const int id = TypeId<W>::type_id;
TC_DCHECK_GE(id, 0);
TC_DCHECK_LT(id, workspaces_.size());
TC_DCHECK_GE(index, 0);
TC_DCHECK_LT(index, workspaces_[id].size());
if (id >= workspaces_.size()) return false;
return workspaces_[id][index] != nullptr;
}
// Returns an indexed workspace; the workspace must have been set.
template <class W>
const W &Get(int index) const {
TC_DCHECK(Has<W>(index));
const int id = TypeId<W>::type_id;
const Workspace *w = workspaces_[id][index];
return reinterpret_cast<const W &>(*w);
}
// Sets an indexed workspace; this takes ownership of the workspace, which
// must have been new-allocated. It is an error to set a workspace twice.
template <class W>
void Set(int index, W *workspace) {
const int id = TypeId<W>::type_id;
TC_DCHECK_GE(id, 0);
TC_DCHECK_LT(id, workspaces_.size());
TC_DCHECK_GE(index, 0);
TC_DCHECK_LT(index, workspaces_[id].size());
TC_DCHECK(workspaces_[id][index] == nullptr);
TC_DCHECK(workspace != nullptr);
workspaces_[id][index] = workspace;
}
void Reset(const WorkspaceRegistry &registry) {
// Deallocate current workspaces.
for (auto &it : workspaces_) {
for (size_t index = 0; index < it.size(); ++index) {
delete it[index];
}
}
workspaces_.clear();
workspaces_.resize(registry.MaxId() + 1, std::vector<Workspace *>());
for (auto &it : registry.WorkspaceNames()) {
workspaces_[it.first].resize(it.second.size());
}
}
private:
// The set of workspaces, indexed as workspaces_[typeid][index].
std::vector<std::vector<Workspace *> > workspaces_;
};
// Workspace that holds exactly one int.
class SingletonIntWorkspace : public Workspace {
 public:
  // Leaves the stored int at its default of 0.
  SingletonIntWorkspace() {}

  // Stores `value` as the initial int.
  explicit SingletonIntWorkspace(int value) : value_{value} {}

  // Human readable name of this workspace type.
  static std::string TypeName() { return "SingletonInt"; }

  // Reads the stored int.
  int get() const { return value_; }

  // Overwrites the stored int with `value`.
  void set(int value) { value_ = value; }

 private:
  int value_ = 0;
};
// Workspace that holds a std::vector<int>. The element accessors use
// operator[], so they do not bounds-check.
class VectorIntWorkspace : public Workspace {
 public:
  // Builds a vector with `size` elements. The initial element values are set
  // by the out-of-line definition; presumably zero, but confirm.
  explicit VectorIntWorkspace(int size);

  // Builds a vector initialized from `elements`.
  explicit VectorIntWorkspace(const std::vector<int> &elements);

  // Builds a vector of `size` elements, each initialized to `value`.
  VectorIntWorkspace(int size, int value);

  // Human readable name of this workspace type (defined out of line).
  static std::string TypeName();

  // Reads element `index`; `index` must be valid.
  int element(int index) const { return elements_[index]; }

  // Writes `value` into element `index`; `index` must be valid.
  void set_element(int index, int value) { elements_[index] = value; }

 private:
  std::vector<int> elements_;
};
// Workspace that holds a vector of int vectors. Accessors use operator[], so
// they do not bounds-check.
class VectorVectorIntWorkspace : public Workspace {
 public:
  // Builds `size` inner vectors, each initially empty.
  explicit VectorVectorIntWorkspace(int size);

  // Human readable name of this workspace type (defined out of line).
  static std::string TypeName();

  // Read-only view of inner vector `index`.
  const std::vector<int> &elements(int index) const {
    return elements_[index];
  }

  // Pointer through which inner vector `index` can be modified.
  std::vector<int> *mutable_elements(int index) { return &elements_[index]; }

 private:
  std::vector<std::vector<int> > elements_;
};
} // namespace nlp_core
} // namespace libtextclassifier
#endif // LIBTEXTCLASSIFIER_COMMON_WORKSPACE_H_