blob: 3eacab48be80affc214dfd56e2936cdde729e738 [file] [log] [blame]
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: jpr@google.com (Jake Ratkiewicz)
#ifndef FST_SCRIPT_FST_CLASS_H_
#define FST_SCRIPT_FST_CLASS_H_
#include <string>
#include <fst/fst.h>
#include <fst/mutable-fst.h>
#include <fst/vector-fst.h>
#include <iostream>
#include <fstream>
// Classes to support "boxing" all existing types of FST arcs in a single
// FstClass which hides the arc types. This allows clients to load
// and work with FSTs without knowing the arc type.
// These classes are only recommended for use in high-level scripting
// applications. Most users should use the lower-level templated versions
// corresponding to these classes.
namespace fst {
namespace script {
//
// Abstract base class declaring the operations that every implementation
// provides and that every interface class forwards to its implementation.
// Below FstClassBase the class hierarchy bifurcates: FstClassImplBase serves
// as the base class for all implementations (of which FstClassImpl is
// currently the only one), and FstClass serves as the base class for all
// interfaces.
//
// Arc-type-agnostic view of an FST. Both the arc-templated implementation
// classes and the user-facing interface classes derive from this, so these
// queries can be answered without the caller knowing the arc type.
class FstClassBase {
 public:
  // Name of the wrapped FST's arc type.
  virtual const string &ArcType() const = 0;

  // Name of the wrapped FST's type.
  virtual const string &FstType() const = 0;

  // Name of the weight type carried by the arcs.
  virtual const string &WeightType() const = 0;

  // Symbol table attached to the input side of the wrapped FST.
  virtual const SymbolTable *InputSymbols() const = 0;

  // Symbol table attached to the output side of the wrapped FST.
  virtual const SymbolTable *OutputSymbols() const = 0;

  // Writes the wrapped FST to the file named by filename.
  virtual void Write(const string &filename) const = 0;

  // Property bits of the wrapped FST selected by mask; implementations
  // forward test unchanged to the underlying FST's Properties().
  virtual uint64 Properties(uint64 mask, bool test) const = 0;

  virtual ~FstClassBase() {}
};
// Common base of the arc-templated implementation classes (currently only
// FstClassImpl). Extends the FstClassBase queries with copying and with
// replacement of the symbol tables.
class FstClassImplBase : public FstClassBase {
 public:
  // Returns a newly allocated implementation wrapping a copy of this one's
  // FST; the caller is responsible for deleting it.
  virtual FstClassImplBase *Copy() = 0;

  // Replace the input-side / output-side symbol table of the wrapped FST.
  virtual void SetInputSymbols(SymbolTable *is) = 0;
  virtual void SetOutputSymbols(SymbolTable *os) = 0;

  virtual ~FstClassImplBase() {}
};
//
// CONTAINER CLASS
// Wraps an Fst<Arc>, hiding its arc type. Whether this Fst<Arc>
// pointer refers to a special kind of FST (e.g. a MutableFst) is
// known by the type of interface class that owns the pointer to this
// container.
//
template<class Arc>
class FstClassImpl : public FstClassImplBase {
public:
explicit FstClassImpl(Fst<Arc> *impl,
bool should_own = false) :
impl_(should_own ? impl : impl->Copy()) { }
virtual const string &ArcType() const {
return Arc::Type();
}
virtual const string &FstType() const {
return impl_->Type();
}
virtual const string &WeightType() const {
return Arc::Weight::Type();
}
virtual const SymbolTable *InputSymbols() const {
return impl_->InputSymbols();
}
virtual const SymbolTable *OutputSymbols() const {
return impl_->OutputSymbols();
}
// Warning: calling this method casts the FST to a mutable FST.
virtual void SetInputSymbols(SymbolTable *is) {
static_cast<MutableFst<Arc> *>(impl_)->SetInputSymbols(is);
}
// Warning: calling this method casts the FST to a mutable FST.
virtual void SetOutputSymbols(SymbolTable *os) {
static_cast<MutableFst<Arc> *>(impl_)->SetOutputSymbols(os);
}
virtual void Write(const string &fname) const {
impl_->Write(fname);
}
virtual uint64 Properties(uint64 mask, bool test) const {
return impl_->Properties(mask, test);
}
virtual ~FstClassImpl() { delete impl_; }
Fst<Arc> *GetImpl() { return impl_; }
virtual FstClassImpl *Copy() {
return new FstClassImpl<Arc>(impl_);
}
private:
Fst<Arc> *impl_;
};
//
// BASE CLASS DEFINITIONS
//
class MutableFstClass;
class FstClass : public FstClassBase {
public:
template<class Arc>
static FstClass *Read(istream &stream,
const FstReadOptions &opts) {
if (!opts.header) {
FSTERROR() << "FstClass::Read: options header not specified";
return 0;
}
const FstHeader &hdr = *opts.header;
if (hdr.Properties() & kMutable) {
return ReadTypedFst<MutableFstClass, MutableFst<Arc> >(stream, opts);
} else {
return ReadTypedFst<FstClass, Fst<Arc> >(stream, opts);
}
}
template<class Arc>
explicit FstClass(Fst<Arc> *fst) : impl_(new FstClassImpl<Arc>(fst)) { }
explicit FstClass(const FstClass &other) : impl_(other.impl_->Copy()) { }
static FstClass *Read(const string &fname);
virtual const string &ArcType() const {
return impl_->ArcType();
}
virtual const string& FstType() const {
return impl_->FstType();
}
virtual const SymbolTable *InputSymbols() const {
return impl_->InputSymbols();
}
virtual const SymbolTable *OutputSymbols() const {
return impl_->OutputSymbols();
}
virtual const string& WeightType() const {
return impl_->WeightType();
}
virtual void Write(const string &fname) const {
impl_->Write(fname);
}
virtual uint64 Properties(uint64 mask, bool test) const {
return impl_->Properties(mask, test);
}
template<class Arc>
const Fst<Arc> *GetFst() const {
if (Arc::Type() != ArcType()) {
return NULL;
} else {
FstClassImpl<Arc> *typed_impl = static_cast<FstClassImpl<Arc> *>(impl_);
return typed_impl->GetImpl();
}
}
virtual ~FstClass() { delete impl_; }
// These methods are required by IO registration
template<class Arc>
static FstClassImplBase *Convert(const FstClass &other) {
LOG(ERROR) << "Doesn't make sense to convert any class to type FstClass.";
return 0;
}
template<class Arc>
static FstClassImplBase *Create() {
LOG(ERROR) << "Doesn't make sense to create an FstClass with a "
<< "particular arc type.";
return 0;
}
protected:
explicit FstClass(FstClassImplBase *impl) : impl_(impl) { }
// Generic template method for reading an arc-templated FST of type
// UnderlyingT, and returning it wrapped as FstClassT, with appropriate
// error checking. Called from arc-templated Read() static methods.
template<class FstClassT, class UnderlyingT>
static FstClassT* ReadTypedFst(istream &stream,
const FstReadOptions &opts) {
UnderlyingT *u = UnderlyingT::Read(stream, opts);
if (!u) {
return 0;
} else {
FstClassT *r = new FstClassT(u);
delete u;
return r;
}
}
FstClassImplBase *GetImpl() { return impl_; }
private:
FstClassImplBase *impl_;
};
//
// Specific types of FstClass with special properties
//
class MutableFstClass : public FstClass {
public:
template<class Arc>
explicit MutableFstClass(MutableFst<Arc> *fst) :
FstClass(fst) { }
template<class Arc>
MutableFst<Arc> *GetMutableFst() {
Fst<Arc> *fst = const_cast<Fst<Arc> *>(this->GetFst<Arc>());
MutableFst<Arc> *mfst = static_cast<MutableFst<Arc> *>(fst);
return mfst;
}
template<class Arc>
static MutableFstClass *Read(istream &stream,
const FstReadOptions &opts) {
MutableFst<Arc> *mfst = MutableFst<Arc>::Read(stream, opts);
if (!mfst) {
return 0;
} else {
MutableFstClass *retval = new MutableFstClass(mfst);
delete mfst;
return retval;
}
}
static MutableFstClass *Read(const string &fname, bool convert = false);
virtual void SetInputSymbols(SymbolTable *is) {
GetImpl()->SetInputSymbols(is);
}
virtual void SetOutputSymbols(SymbolTable *os) {
GetImpl()->SetOutputSymbols(os);
}
// These methods are required by IO registration
template<class Arc>
static FstClassImplBase *Convert(const FstClass &other) {
LOG(ERROR) << "Doesn't make sense to convert any class to type "
<< "MutableFstClass.";
return 0;
}
template<class Arc>
static FstClassImplBase *Create() {
LOG(ERROR) << "Doesn't make sense to create a MutableFstClass with a "
<< "particular arc type.";
return 0;
}
protected:
explicit MutableFstClass(FstClassImplBase *impl) : FstClass(impl) { }
};
class VectorFstClass : public MutableFstClass {
public:
explicit VectorFstClass(const FstClass &other);
explicit VectorFstClass(const string &arc_type);
template<class Arc>
explicit VectorFstClass(VectorFst<Arc> *fst) :
MutableFstClass(fst) { }
template<class Arc>
static VectorFstClass *Read(istream &stream,
const FstReadOptions &opts) {
VectorFst<Arc> *vfst = VectorFst<Arc>::Read(stream, opts);
if (!vfst) {
return 0;
} else {
VectorFstClass *retval = new VectorFstClass(vfst);
delete vfst;
return retval;
}
}
static VectorFstClass *Read(const string &fname);
// Converter / creator for known arc types
template<class Arc>
static FstClassImplBase *Convert(const FstClass &other) {
return new FstClassImpl<Arc>(new VectorFst<Arc>(
*other.GetFst<Arc>()), true);
}
template<class Arc>
static FstClassImplBase *Create() {
return new FstClassImpl<Arc>(new VectorFst<Arc>(), true);
}
};
} // namespace script
} // namespace fst
#endif // FST_SCRIPT_FST_CLASS_H_