// blob: d7f4d6bce2c665b5af1038333b835bed8806c946 [file] [log] [blame]
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Authors: allauzen@google.com (Cyril Allauzen)
// ttai@google.com (Terry Tai)
// jpr@google.com (Jake Ratkiewicz)
#ifndef FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_
#define FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_
#include <libgen.h>
#include <string>
#include <vector>
using std::vector;
#include <fst/extensions/far/far.h>
#include <fst/string.h>
namespace fst {
// Reads strings from an input stream and compiles each one into an FST.
//
// The stream is split into entries according to the entry type:
//   LINE - every line of the stream is one entry;
//   FILE - everything remaining in the stream is a single entry (a newline is
//          appended after every line that is read).
//
// Note that a freshly constructed reader is already positioned on the first
// entry. The FSTs returned by GetVectorFst()/GetCompactFst() are allocated
// with new and owned by the caller; NULL is returned when the current entry
// fails to compile.
//
// Sample Usage:
//   for (StringReader<Arc> reader(...); !reader.Done(); reader.Next()) {
//     VectorFst<Arc> *fst = reader.GetVectorFst();
//     if (fst) {
//       ...
//       delete fst;
//     }
//   }
template <class A>
class StringReader {
 public:
  typedef A Arc;
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;
  typedef typename StringCompiler<A>::TokenType TokenType;

  enum EntryType { LINE = 1, FILE = 2 };

  // istrm:                 stream to read from; held by reference, so it must
  //                        outlive the reader.
  // source:                name of the input, used only in log output.
  // entry_type:            how the stream is split into entries (see above).
  // token_type, syms, unknown_label, allow_negative_labels:
  //                        forwarded unchanged to the StringCompiler.
  //
  // The default for unknown_label is the "no label" sentinel kNoLabel: the
  // parameter is a label, not a state id.
  StringReader(istream &istrm,
               const string &source,
               EntryType entry_type,
               TokenType token_type,
               bool allow_negative_labels,
               const SymbolTable *syms = 0,
               Label unknown_label = kNoLabel)
      : nline_(0), strm_(istrm), source_(source), entry_type_(entry_type),
        token_type_(token_type), done_(false),
        compiler_(token_type, syms, unknown_label, allow_negative_labels) {
    Next();  // Initialize the reader to the first input.
  }

  // Returns true once there is no current entry left to compile.
  bool Done() const {
    return done_;
  }

  // Advances to the next entry, or marks the reader as done when the stream
  // is exhausted.
  void Next() {
    VLOG(1) << "Processing source " << source_ << " at line " << nline_;
    if (!strm_) {  // We're done if we have no more input.
      done_ = true;
      return;
    }
    if (entry_type_ == LINE) {
      getline(strm_, content_);
      ++nline_;
    } else {
      // FILE mode: gather every remaining line into one entry.
      content_.clear();
      string line;
      while (getline(strm_, line)) {
        ++nline_;
        content_.append(line);
        content_.append("\n");
      }
    }
    // We're also done if the read failed and produced nothing, e.g. after
    // reading off all the whitespace at the end of a file.
    if (!strm_ && content_.empty())
      done_ = true;
  }

  // Compiles the current entry into a newly allocated VectorFst. Returns
  // NULL on failure; otherwise the caller owns the result.
  VectorFst<A> *GetVectorFst() {
    return CompileContent<VectorFst<A> >();
  }

  // Compiles the current entry into a newly allocated string-compacted
  // CompactFst. Returns NULL on failure; otherwise the caller owns the result.
  CompactFst<A, StringCompactor<A> > *GetCompactFst() {
    return CompileContent<CompactFst<A, StringCompactor<A> > >();
  }

 private:
  // Shared implementation of the Get*Fst() methods: allocates an F, compiles
  // the current entry into it and frees it again if compilation fails.
  template <class F>
  F *CompileContent() {
    F *fst = new F;
    if (!compiler_(content_, fst)) {
      delete fst;
      return NULL;
    }
    return fst;
  }

  size_t nline_;          // Number of lines read so far.
  istream &strm_;         // Input stream (not owned).
  string source_;         // Name of the input, for logging.
  EntryType entry_type_;
  TokenType token_type_;
  bool done_;
  StringCompiler<A> compiler_;
  string content_;  // The actual content of the input stream's next FST.

  DISALLOW_COPY_AND_ASSIGN(StringReader);
};
// Compute the minimal length required to encode each line number of the file
// named `filename` as a decimal number.
//
// Only the declaration is visible in this header. FarCompileStrings() uses the
// returned value as the zero-padded width of the numeric part of per-line keys
// when the caller did not request an explicit key width.
int KeySize(const char *filename);
template <class Arc>
void FarCompileStrings(const vector<string> &in_fnames,
const string &out_fname,
const string &fst_type,
const FarType &far_type,
int32 generate_keys,
FarEntryType fet,
FarTokenType tt,
const string &symbols_fname,
const string &unknown_symbol,
bool allow_negative_labels,
bool file_list_input,
const string &key_prefix,
const string &key_suffix) {
typename StringReader<Arc>::EntryType entry_type;
if (fet == FET_LINE) {
entry_type = StringReader<Arc>::LINE;
} else if (fet == FET_FILE) {
entry_type = StringReader<Arc>::FILE;
} else {
FSTERROR() << "FarCompileStrings: unknown entry type";
return;
}
typename StringCompiler<Arc>::TokenType token_type;
if (tt == FTT_SYMBOL) {
token_type = StringCompiler<Arc>::SYMBOL;
} else if (tt == FTT_BYTE) {
token_type = StringCompiler<Arc>::BYTE;
} else if (tt == FTT_UTF8) {
token_type = StringCompiler<Arc>::UTF8;
} else {
FSTERROR() << "FarCompileStrings: unknown token type";
return;
}
bool compact;
if (fst_type.empty() || (fst_type == "vector")) {
compact = false;
} else if (fst_type == "compact") {
compact = true;
} else {
FSTERROR() << "FarCompileStrings: unknown fst type: "
<< fst_type;
return;
}
const SymbolTable *syms = 0;
typename Arc::Label unknown_label = kNoLabel;
if (!symbols_fname.empty()) {
syms = SymbolTable::ReadText(symbols_fname,
allow_negative_labels);
if (!syms) {
FSTERROR() << "FarCompileStrings: error reading symbol table: "
<< symbols_fname;
return;
}
if (!unknown_symbol.empty()) {
unknown_label = syms->Find(unknown_symbol);
if (unknown_label == kNoLabel) {
FSTERROR() << "FarCompileStrings: unknown label \"" << unknown_label
<< "\" missing from symbol table: " << symbols_fname;
return;
}
}
}
FarWriter<Arc> *far_writer =
FarWriter<Arc>::Create(out_fname, far_type);
if (!far_writer) return;
vector<string> inputs;
if (file_list_input) {
for (int i = 1; i < in_fnames.size(); ++i) {
ifstream istrm(in_fnames[i].c_str());
string str;
while (getline(istrm, str))
inputs.push_back(str);
}
} else {
inputs = in_fnames;
}
for (int i = 0, n = 0; i < inputs.size(); ++i) {
int key_size = generate_keys ? generate_keys :
(entry_type == StringReader<Arc>::FILE ? 1 :
KeySize(inputs[i].c_str()));
ifstream istrm(inputs[i].c_str());
for (StringReader<Arc> reader(
istrm, inputs[i], entry_type, token_type,
allow_negative_labels, syms, unknown_label);
!reader.Done();
reader.Next()) {
++n;
const Fst<Arc> *fst;
if (compact)
fst = reader.GetCompactFst();
else
fst = reader.GetVectorFst();
if (!fst) {
FSTERROR() << "FarCompileStrings: compiling string number " << n
<< " in file " << inputs[i] << " failed with token_type = "
<< (tt == FTT_BYTE ? "byte" :
(tt == FTT_UTF8 ? "utf8" :
(tt == FTT_SYMBOL ? "symbol" : "unknown")))
<< " and entry_type = "
<< (fet == FET_LINE ? "line" :
(fet == FET_FILE ? "file" : "unknown"));
delete far_writer;
delete syms;
return;
}
ostringstream keybuf;
keybuf.width(key_size);
keybuf.fill('0');
keybuf << n;
string key;
if (generate_keys > 0) {
key = keybuf.str();
} else {
char* filename = new char[inputs[i].size() + 1];
strcpy(filename, inputs[i].c_str());
key = basename(filename);
if (entry_type != StringReader<Arc>::FILE) {
key += "-";
key += keybuf.str();
}
delete[] filename;
}
far_writer->Add(key_prefix + key + key_suffix, *fst);
delete fst;
}
if (generate_keys == 0)
n = 0;
}
delete far_writer;
}
} // namespace fst
#endif // FST_EXTENSIONS_FAR_COMPILE_STRINGS_H_