blob: 3099b8755a5d60515e8b0feae861f0e55106484a [file] [log] [blame]
// string.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: allauzen@google.com (Cyril Allauzen)
//
// \file
// Utilities to convert strings into FSTs.
//
#ifndef FST_LIB_STRING_H_
#define FST_LIB_STRING_H_
#include <fst/compact-fst.h>
#include <fst/mutable-fst.h>
DECLARE_string(fst_field_separator);
namespace fst {
// Functor compiling a string in an FST
template <class A>
class StringCompiler {
public:
typedef A Arc;
typedef typename A::Label Label;
typedef typename A::Weight Weight;
enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
StringCompiler(TokenType type, const SymbolTable *syms = 0,
Label unknown_label = kNoLabel,
bool allow_negative = false)
: token_type_(type), syms_(syms), unknown_label_(unknown_label),
allow_negative_(allow_negative) {}
// Compile string 's' into FST 'fst'.
template <class F>
bool operator()(const string &s, F *fst) {
vector<Label> labels;
if (!ConvertStringToLabels(s, &labels))
return false;
Compile(labels, fst);
return true;
}
private:
bool ConvertStringToLabels(const string &str, vector<Label> *labels) const {
labels->clear();
if (token_type_ == BYTE) {
for (size_t i = 0; i < str.size(); ++i)
labels->push_back(static_cast<unsigned char>(str[i]));
} else if (token_type_ == UTF8) {
return UTF8StringToLabels(str, labels);
} else {
char *c_str = new char[str.size() + 1];
str.copy(c_str, str.size());
c_str[str.size()] = 0;
vector<char *> vec;
string separator = "\n" + FLAGS_fst_field_separator;
SplitToVector(c_str, separator.c_str(), &vec, true);
for (size_t i = 0; i < vec.size(); ++i) {
Label label;
if (!ConvertSymbolToLabel(vec[i], &label))
return false;
labels->push_back(label);
}
delete[] c_str;
}
return true;
}
void Compile(const vector<Label> &labels, MutableFst<A> *fst) const {
fst->DeleteStates();
while (fst->NumStates() <= labels.size())
fst->AddState();
for (size_t i = 0; i < labels.size(); ++i)
fst->AddArc(i, Arc(labels[i], labels[i], Weight::One(), i + 1));
fst->SetStart(0);
fst->SetFinal(labels.size(), Weight::One());
}
template <class Unsigned>
void Compile(const vector<Label> &labels, CompactFst<A, StringCompactor<A>,
Unsigned> *fst) const {
fst->SetCompactElements(labels.begin(), labels.end());
}
bool ConvertSymbolToLabel(const char *s, Label* output) const {
int64 n;
if (syms_) {
n = syms_->Find(s);
if ((n == -1) && (unknown_label_ != kNoLabel))
n = unknown_label_;
if (n == -1 || (!allow_negative_ && n < 0)) {
VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Symbol \"" << s
<< "\" is not mapped to any integer label, symbol table = "
<< syms_->Name();
return false;
}
} else {
char *p;
n = strtoll(s, &p, 10);
if (p < s + strlen(s) || (!allow_negative_ && n < 0)) {
VLOG(1) << "StringCompiler::ConvertSymbolToLabel: Bad label integer "
<< "= \"" << s << "\"";
return false;
}
}
*output = n;
return true;
}
TokenType token_type_; // Token type: symbol, byte or utf8 encoded
const SymbolTable *syms_; // Symbol table used when token type is symbol
Label unknown_label_; // Label for token missing from symbol table
bool allow_negative_; // Negative labels allowed?
DISALLOW_COPY_AND_ASSIGN(StringCompiler);
};
// Functor to print a string FST as a string.
template <class A>
class StringPrinter {
public:
typedef A Arc;
typedef typename A::Label Label;
typedef typename A::StateId StateId;
typedef typename A::Weight Weight;
enum TokenType { SYMBOL = 1, BYTE = 2, UTF8 = 3 };
StringPrinter(TokenType token_type,
const SymbolTable *syms = 0)
: token_type_(token_type), syms_(syms) {}
// Convert the FST 'fst' into the string 'output'
bool operator()(const Fst<A> &fst, string *output) {
bool is_a_string = FstToLabels(fst);
if (!is_a_string) {
VLOG(1) << "StringPrinter::operator(): Fst is not a string.";
return false;
}
output->clear();
if (token_type_ == SYMBOL) {
stringstream sstrm;
for (size_t i = 0; i < labels_.size(); ++i) {
if (i)
sstrm << *(FLAGS_fst_field_separator.rbegin());
if (!PrintLabel(labels_[i], sstrm))
return false;
}
*output = sstrm.str();
} else if (token_type_ == BYTE) {
for (size_t i = 0; i < labels_.size(); ++i) {
output->push_back(labels_[i]);
}
} else if (token_type_ == UTF8) {
return LabelsToUTF8String(labels_, output);
} else {
VLOG(1) << "StringPrinter::operator(): Unknown token type: "
<< token_type_;
return false;
}
return true;
}
private:
bool FstToLabels(const Fst<A> &fst) {
labels_.clear();
StateId s = fst.Start();
if (s == kNoStateId) {
VLOG(2) << "StringPrinter::FstToLabels: Invalid starting state for "
<< "string fst.";
return false;
}
while (fst.Final(s) == Weight::Zero()) {
ArcIterator<Fst<A> > aiter(fst, s);
if (aiter.Done()) {
VLOG(2) << "StringPrinter::FstToLabels: String fst traversal does "
<< "not reach final state.";
return false;
}
const A& arc = aiter.Value();
labels_.push_back(arc.olabel);
s = arc.nextstate;
if (s == kNoStateId) {
VLOG(2) << "StringPrinter::FstToLabels: Transition to invalid "
<< "state.";
return false;
}
aiter.Next();
if (!aiter.Done()) {
VLOG(2) << "StringPrinter::FstToLabels: State with multiple "
<< "outgoing arcs found.";
return false;
}
}
return true;
}
bool PrintLabel(Label lab, ostream& ostrm) {
if (syms_) {
string symbol = syms_->Find(lab);
if (symbol == "") {
VLOG(2) << "StringPrinter::PrintLabel: Integer " << lab << " is not "
<< "mapped to any textual symbol, symbol table = "
<< syms_->Name();
return false;
}
ostrm << symbol;
} else {
ostrm << lab;
}
return true;
}
TokenType token_type_; // Token type: symbol, byte or utf8 encoded
const SymbolTable *syms_; // Symbol table used when token type is symbol
vector<Label> labels_; // Input FST labels.
DISALLOW_COPY_AND_ASSIGN(StringPrinter);
};
} // namespace fst
#endif // FST_LIB_STRING_H_