blob: f464403e78b3ab1e5c635195baa13cde44053ba2 [file] [log] [blame]
// expand.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Expand a PDT to an FST.
#ifndef FST_EXTENSIONS_PDT_EXPAND_H__
#define FST_EXTENSIONS_PDT_EXPAND_H__
#include <vector>
using std::vector;
#include <fst/extensions/pdt/pdt.h>
#include <fst/extensions/pdt/paren.h>
#include <fst/extensions/pdt/shortest-path.h>
#include <fst/extensions/pdt/reverse.h>
#include <fst/cache.h>
#include <fst/mutable-fst.h>
#include <fst/queue.h>
#include <fst/state-table.h>
#include <fst/test-properties.h>
namespace fst {
// Options for constructing an ExpandFst: the cache options inherited from
// CacheOptions plus the PDT-specific settings below.
template <class Arc>
struct ExpandFstOptions : public CacheOptions {
// When false, ExpandFstImpl::ExpandState() relabels stack push/pop arcs
// with label 0 on both sides; when true, their labels are kept.
bool keep_parentheses;
// Optional externally supplied stack. When non-null, ExpandFstImpl uses it
// without taking ownership; when null, the impl allocates and owns one.
PdtStack<typename Arc::StateId, typename Arc::Label> *stack;
// Optional externally supplied state table; same ownership rule as 'stack'.
PdtStateTable<typename Arc::StateId, typename Arc::StateId> *state_table;
// All arguments are optional: default cache options, parentheses not kept,
// and no external stack or state table.
ExpandFstOptions(
const CacheOptions &opts = CacheOptions(),
bool kp = false,
PdtStack<typename Arc::StateId, typename Arc::Label> *s = 0,
PdtStateTable<typename Arc::StateId, typename Arc::StateId> *st = 0)
: CacheOptions(opts), keep_parentheses(kp), stack(s), state_table(st) {}
};
// Computes the properties of an expanded PDT from the properties 'inprops'
// of the input: only the bits named in the mask below are carried over,
// every other bit is cleared.
inline uint64 ExpandProperties(uint64 inprops) {
  const uint64 preserved_mask =
      kAcceptor | kAcyclic | kInitialAcyclic | kUnweighted;
  return inprops & preserved_mask;
}
// Implementation class for ExpandFst.
//
// A state of the expansion is a (PDT state, stack id) tuple; tuples are
// numbered by a PdtStateTable. The start state, final weights and outgoing
// arcs are computed on first request and then served from the cache.
template <class A>
class ExpandFstImpl
    : public CacheImpl<A> {
 public:
  using FstImpl<A>::SetType;
  using FstImpl<A>::SetProperties;
  using FstImpl<A>::Properties;
  using FstImpl<A>::SetInputSymbols;
  using FstImpl<A>::SetOutputSymbols;

  using CacheBaseImpl< CacheState<A> >::PushArc;
  using CacheBaseImpl< CacheState<A> >::HasArcs;
  using CacheBaseImpl< CacheState<A> >::HasFinal;
  using CacheBaseImpl< CacheState<A> >::HasStart;
  using CacheBaseImpl< CacheState<A> >::SetArcs;
  using CacheBaseImpl< CacheState<A> >::SetFinal;
  using CacheBaseImpl< CacheState<A> >::SetStart;

  typedef A Arc;
  typedef typename A::Label Label;
  typedef typename A::Weight Weight;
  typedef typename A::StateId StateId;
  typedef StateId StackId;
  typedef PdtStateTuple<StateId, StackId> StateTuple;

  // Builds the expansion of the PDT 'fst' with parenthesis pairs 'parens'.
  // A stack or state table supplied through 'opts' is used but not owned;
  // whichever is missing is allocated here and owned by this object.
  ExpandFstImpl(const Fst<A> &fst,
                const vector<pair<typename Arc::Label,
                                  typename Arc::Label> > &parens,
                const ExpandFstOptions<A> &opts)
      : CacheImpl<A>(opts),
        fst_(fst.Copy()),
        stack_(opts.stack ? opts.stack
                          : new PdtStack<StateId, Label>(parens)),
        state_table_(opts.state_table
                         ? opts.state_table
                         : new PdtStateTable<StateId, StackId>()),
        own_stack_(opts.stack == 0),
        own_state_table_(opts.state_table == 0),
        keep_parentheses_(opts.keep_parentheses) {
    SetType("expand");

    const uint64 input_props = fst.Properties(kFstProperties, false);
    SetProperties(ExpandProperties(input_props), kCopyProperties);

    SetInputSymbols(fst.InputSymbols());
    SetOutputSymbols(fst.OutputSymbols());
  }

  // Copy constructor. The copy always owns a private copy of the stack and
  // a fresh, empty state table, regardless of what 'impl' owns.
  ExpandFstImpl(const ExpandFstImpl &impl)
      : CacheImpl<A>(impl),
        fst_(impl.fst_->Copy(true)),
        stack_(new PdtStack<StateId, Label>(*impl.stack_)),
        state_table_(new PdtStateTable<StateId, StackId>()),
        own_stack_(true),
        own_state_table_(true),
        keep_parentheses_(impl.keep_parentheses_) {
    SetType("expand");
    SetProperties(impl.Properties(), kCopyProperties);
    SetInputSymbols(impl.InputSymbols());
    SetOutputSymbols(impl.OutputSymbols());
  }

  // Releases the copied input FST, plus the stack and state table when they
  // were allocated by this object.
  ~ExpandFstImpl() {
    delete fst_;
    if (own_stack_) delete stack_;
    if (own_state_table_) delete state_table_;
  }

  // Returns the start state: the id of the tuple (PDT start, stack id 0).
  // Returns kNoStateId, without caching anything, when the PDT has no start.
  StateId Start() {
    if (HasStart()) return CacheImpl<A>::Start();

    const StateId pdt_start = fst_->Start();
    if (pdt_start == kNoStateId) return kNoStateId;

    const StateTuple start_tuple(pdt_start, 0);
    SetStart(state_table_->FindState(start_tuple));
    return CacheImpl<A>::Start();
  }

  // Returns the final weight of 's': the PDT final weight of the underlying
  // state when the stack id is 0, Weight::Zero() otherwise.
  Weight Final(StateId s) {
    if (!HasFinal(s)) {
      const StateTuple &tuple = state_table_->Tuple(s);
      const Weight pdt_final = fst_->Final(tuple.state_id);
      const bool accepting =
          pdt_final != Weight::Zero() && tuple.stack_id == 0;
      SetFinal(s, accepting ? pdt_final : Weight::Zero());
    }
    return CacheImpl<A>::Final(s);
  }

  size_t NumArcs(StateId s) {
    ExpandIfNeeded(s);
    return CacheImpl<A>::NumArcs(s);
  }

  size_t NumInputEpsilons(StateId s) {
    ExpandIfNeeded(s);
    return CacheImpl<A>::NumInputEpsilons(s);
  }

  size_t NumOutputEpsilons(StateId s) {
    ExpandIfNeeded(s);
    return CacheImpl<A>::NumOutputEpsilons(s);
  }

  void InitArcIterator(StateId s, ArcIteratorData<A> *data) {
    ExpandIfNeeded(s);
    CacheImpl<A>::InitArcIterator(s, data);
  }

  // Computes the outgoing transitions from a state, creating new destination
  // states as needed.
  void ExpandState(StateId s) {
    // Taken by value, not by reference. NOTE(review): presumably because
    // FindState() below may grow the table - confirm against PdtStateTable.
    const StateTuple src = state_table_->Tuple(s);
    for (ArcIterator< Fst<A> > aiter(*fst_, src.state_id);
         !aiter.Done(); aiter.Next()) {
      Arc arc = aiter.Value();
      const StackId next_stack = stack_->Find(src.stack_id, arc.ilabel);
      if (next_stack != -1) {  // -1 marks a non-matching close parenthesis.
        const bool stack_changed = next_stack != src.stack_id;
        if (stack_changed && !keep_parentheses_) {
          // Stack push/pop: hide the parenthesis label.
          arc.ilabel = arc.olabel = 0;
        }
        const StateTuple dest(arc.nextstate, next_stack);
        arc.nextstate = state_table_->FindState(dest);
        PushArc(s, arc);
      }
    }
    SetArcs(s);
  }

  const PdtStack<StackId, Label> &GetStack() const { return *stack_; }

  const PdtStateTable<StateId, StackId> &GetStateTable() const {
    return *state_table_;
  }

 private:
  // Expands 's' unless its arcs are already in the cache.
  void ExpandIfNeeded(StateId s) {
    if (!HasArcs(s)) ExpandState(s);
  }

  const Fst<A> *fst_;                             // Copy of the input PDT
  PdtStack<StackId, Label> *stack_;               // Stack trie
  PdtStateTable<StateId, StackId> *state_table_;  // Tuple <-> state id
  bool own_stack_;                                // Delete stack_ in dtor?
  bool own_state_table_;                          // Delete state_table_?
  bool keep_parentheses_;                         // Keep paren labels?

  void operator=(const ExpandFstImpl<A> &);  // disallow
};
// Expands a pushdown transducer (PDT) encoded as an FST into an FST.
// This version is a delayed Fst. In the PDT, some transitions are
// labeled with open or close parentheses. To be interpreted as a PDT,
// the parens must balance on a path. The open-close parenthesis label
// pairs are passed in 'parens'. The expansion enforces the
// parenthesis constraints. The PDT must be expandable as an FST.
//
// This class attaches interface to implementation and handles
// reference counting, delegating most methods to ImplToFst.
template <class A>
class ExpandFst : public ImplToFst< ExpandFstImpl<A> > {
public:
// The iterator specializations call the private GetImpl() below.
friend class ArcIterator< ExpandFst<A> >;
friend class StateIterator< ExpandFst<A> >;
typedef A Arc;
typedef typename A::Label Label;
typedef typename A::Weight Weight;
typedef typename A::StateId StateId;
typedef StateId StackId;
typedef CacheState<A> State;
typedef ExpandFstImpl<A> Impl;
// Expands 'fst' with parenthesis pairs 'parens' using default options.
ExpandFst(const Fst<A> &fst,
const vector<pair<typename Arc::Label,
typename Arc::Label> > &parens)
: ImplToFst<Impl>(new Impl(fst, parens, ExpandFstOptions<A>())) {}
// Expands 'fst' with parenthesis pairs 'parens' using the options 'opts'.
ExpandFst(const Fst<A> &fst,
const vector<pair<typename Arc::Label,
typename Arc::Label> > &parens,
const ExpandFstOptions<A> &opts)
: ImplToFst<Impl>(new Impl(fst, parens, opts)) {}
// See Fst<>::Copy() for doc.
ExpandFst(const ExpandFst<A> &fst, bool safe = false)
: ImplToFst<Impl>(fst, safe) {}
// Get a copy of this ExpandFst. See Fst<>::Copy() for further doc.
virtual ExpandFst<A> *Copy(bool safe = false) const {
return new ExpandFst<A>(*this, safe);
}
// Defined out of line, after the StateIterator specialization.
virtual inline void InitStateIterator(StateIteratorData<A> *data) const;
// Forwards to the implementation, which expands 's' first if needed.
virtual void InitArcIterator(StateId s, ArcIteratorData<A> *data) const {
GetImpl()->InitArcIterator(s, data);
}
// Read-only access to the stack used by the implementation.
const PdtStack<StackId, Label> &GetStack() const {
return GetImpl()->GetStack();
}
// Read-only access to the (PDT state, stack id) state table.
const PdtStateTable<StateId, StackId> &GetStateTable() const {
return GetImpl()->GetStateTable();
}
private:
// Makes visible to friends.
Impl *GetImpl() const { return ImplToFst<Impl>::GetImpl(); }
void operator=(const ExpandFst<A> &fst); // Disallow
};
// Specialization for ExpandFst: a CacheStateIterator constructed from the
// FST and its implementation object.
template<class A>
class StateIterator< ExpandFst<A> >
: public CacheStateIterator< ExpandFst<A> > {
public:
explicit StateIterator(const ExpandFst<A> &fst)
: CacheStateIterator< ExpandFst<A> >(fst, fst.GetImpl()) {}
};
// Specialization for ExpandFst: a CacheArcIterator over the arcs of state
// 's'. The constructor expands 's' when its arcs are not yet in the cache.
template <class A>
class ArcIterator< ExpandFst<A> >
: public CacheArcIterator< ExpandFst<A> > {
public:
typedef typename A::StateId StateId;
ArcIterator(const ExpandFst<A> &fst, StateId s)
: CacheArcIterator< ExpandFst<A> >(fst.GetImpl(), s) {
// Compute the arcs of 's' on demand.
if (!fst.GetImpl()->HasArcs(s))
fst.GetImpl()->ExpandState(s);
}
private:
DISALLOW_COPY_AND_ASSIGN(ArcIterator);
};
// Stores a newly allocated StateIterator< ExpandFst<A> > for this FST in
// 'data->base'.
// NOTE(review): presumably whoever holds 'data' deletes the iterator -
// confirm against StateIteratorData.
template <class A> inline
void ExpandFst<A>::InitStateIterator(StateIteratorData<A> *data) const
{
data->base = new StateIterator< ExpandFst<A> >(*this);
}
//
// PrunedExpand Class
//
// Prunes the delayed expansion of a pushdown transducer (PDT) encoded
// as an FST into an FST. In the PDT, some transitions are labeled
// with open or close parentheses. To be interpreted as a PDT, the
// parens must balance on a path. The open-close parenthesis label
// pairs are passed in 'parens'. The expansion enforces the
// parenthesis constraints.
//
// The algorithm works by visiting the delayed ExpandFst using a
// shortest-stack first queue discipline and relies on the
// shortest-distance information computed using a reverse
// shortest-path call to perform the pruning.
//
// The algorithm maintains the same state ordering between the ExpandFst
// being visited 'efst_' and the result of pruning written into the
// MutableFst 'ofst_' to improve readability of the code.
//
template <class A>
class PrunedExpand {
public:
typedef A Arc;
typedef typename A::Label Label;
typedef typename A::StateId StateId;
typedef typename A::Weight Weight;
typedef StateId StackId;
typedef PdtStack<StackId, Label> Stack;
typedef PdtStateTable<StateId, StackId> StateTable;
typedef typename PdtBalanceData<Arc>::SetIterator SetIterator;
// Constructor taking as input a PDT specified by 'ifst' and 'parens'.
// 'keep_parentheses' specifies whether parentheses are replaced by
// epsilons or not during the expansion. 'opts' is the cache options
// used to instantiate the underlying ExpandFst.
PrunedExpand(const Fst<A> &ifst,
const vector<pair<Label, Label> > &parens,
bool keep_parentheses = false,
const CacheOptions &opts = CacheOptions())
: ifst_(ifst.Copy()),
keep_parentheses_(keep_parentheses),
stack_(parens),
efst_(ifst, parens,
ExpandFstOptions<Arc>(opts, true, &stack_, &state_table_)),
queue_(state_table_, stack_, stack_length_, distance_, fdistance_) {
Reverse(*ifst_, parens, &rfst_);
VectorFst<Arc> path;
reverse_shortest_path_ = new SP(
rfst_, parens,
PdtShortestPathOptions<A, FifoQueue<StateId> >(true, false));
reverse_shortest_path_->ShortestPath(&path);
balance_data_ = reverse_shortest_path_->GetBalanceData()->Reverse(
rfst_.NumStates(), 10, -1);
InitCloseParenMultimap(parens);
}
~PrunedExpand() {
delete ifst_;
delete reverse_shortest_path_;
delete balance_data_;
}
// Expands and prunes with weight threshold 'threshold' the input PDT.
// Writes the result in 'ofst'.
void Expand(MutableFst<A> *ofst, const Weight &threshold);
private:
static const uint8 kEnqueued;
static const uint8 kExpanded;
static const uint8 kSourceState;
// Comparison functor used by the queue:
// 1. states corresponding to shortest stack first,
// 2. among stacks of the same length, reverse lexicographic order is used,
// 3. among states with the same stack, shortest-first order is used.
class StackCompare {
public:
StackCompare(const StateTable &st,
const Stack &s, const vector<StackId> &sl,
const vector<Weight> &d, const vector<Weight> &fd)
: state_table_(st), stack_(s), stack_length_(sl),
distance_(d), fdistance_(fd) {}
bool operator()(StateId s1, StateId s2) const {
StackId si1 = state_table_.Tuple(s1).stack_id;
StackId si2 = state_table_.Tuple(s2).stack_id;
if (stack_length_[si1] < stack_length_[si2])
return true;
if (stack_length_[si1] > stack_length_[si2])
return false;
// If stack id equal, use A*
if (si1 == si2) {
Weight w1 = (s1 < distance_.size()) && (s1 < fdistance_.size()) ?
Times(distance_[s1], fdistance_[s1]) : Weight::Zero();
Weight w2 = (s2 < distance_.size()) && (s2 < fdistance_.size()) ?
Times(distance_[s2], fdistance_[s2]) : Weight::Zero();
return less_(w1, w2);
}
// If lenghts are equal, use reverse lexico.
for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) {
if (stack_.Top(si1) < stack_.Top(si2)) return true;
if (stack_.Top(si1) > stack_.Top(si2)) return false;
}
return false;
}
private:
const StateTable &state_table_;
const Stack &stack_;
const vector<StackId> &stack_length_;
const vector<Weight> &distance_;
const vector<Weight> &fdistance_;
NaturalLess<Weight> less_;
};
class ShortestStackFirstQueue
: public ShortestFirstQueue<StateId, StackCompare> {
public:
ShortestStackFirstQueue(
const PdtStateTable<StateId, StackId> &st,
const Stack &s,
const vector<StackId> &sl,
const vector<Weight> &d, const vector<Weight> &fd)
: ShortestFirstQueue<StateId, StackCompare>(
StackCompare(st, s, sl, d, fd)) {}
};
void InitCloseParenMultimap(const vector<pair<Label, Label> > &parens);
Weight DistanceToDest(StateId state, StateId source) const;
uint8 Flags(StateId s) const;
void SetFlags(StateId s, uint8 flags, uint8 mask);
Weight Distance(StateId s) const;
void SetDistance(StateId s, Weight w);
Weight FinalDistance(StateId s) const;
void SetFinalDistance(StateId s, Weight w);
StateId SourceState(StateId s) const;
void SetSourceState(StateId s, StateId p);
void AddStateAndEnqueue(StateId s);
void Relax(StateId s, const A &arc, Weight w);
bool PruneArc(StateId s, const A &arc);
void ProcStart();
void ProcFinal(StateId s);
bool ProcNonParen(StateId s, const A &arc, bool add_arc);
bool ProcOpenParen(StateId s, const A &arc, StackId si, StackId nsi);
bool ProcCloseParen(StateId s, const A &arc);
void ProcDestStates(StateId s, StackId si);
Fst<A> *ifst_; // Input PDT
VectorFst<Arc> rfst_; // Reversed PDT
bool keep_parentheses_; // Keep parentheses in ofst?
StateTable state_table_; // State table for efst_
Stack stack_; // Stack trie
ExpandFst<Arc> efst_; // Expanded PDT
vector<StackId> stack_length_; // Length of stack for given stack id
vector<Weight> distance_; // Distance from initial state in efst_/ofst
vector<Weight> fdistance_; // Distance to final states in efst_/ofst
ShortestStackFirstQueue queue_; // Queue used to visit efst_
vector<uint8> flags_; // Status flags for states in efst_/ofst
vector<StateId> sources_; // PDT source state for each expanded state
typedef PdtShortestPath<Arc, FifoQueue<StateId> > SP;
typedef typename SP::CloseParenMultimap ParenMultimap;
SP *reverse_shortest_path_; // Shortest path for rfst_
PdtBalanceData<Arc> *balance_data_; // Not owned by shortest_path_
ParenMultimap close_paren_multimap_; // Maps open paren arcs to
// balancing close paren arcs.
MutableFst<Arc> *ofst_; // Output fst
Weight limit_; // Weight limit
typedef unordered_map<StateId, Weight> DestMap;
DestMap dest_map_;
StackId current_stack_id_;
// 'current_stack_id_' is the stack id of the states currently at the top
// of queue, i.e., the states currently being popped and processed.
// 'dest_map_' maps a state 's' in 'ifst_' that is the source
// of a close parentheses matching the top of 'current_stack_id_; to
// the shortest-distance from '(s, current_stack_id_)' to the final
// states in 'efst_'.
ssize_t current_paren_id_; // Paren id at top of current stack
ssize_t cached_stack_id_;
StateId cached_source_;
slist<pair<StateId, Weight> > cached_dest_list_;
// 'cached_dest_list_' contains the set of pair of destination
// states and weight to final states for source state
// 'cached_source_' and paren id 'cached_paren_id': the set of
// source state of a close parenthesis with paren id
// 'cached_paren_id' balancing an incoming open parenthesis with
// paren id 'cached_paren_id' in state 'cached_source_'.
NaturalLess<Weight> less_;
};
// Definitions of the per-state status bits stored in 'flags_':
// kEnqueued - set by AddStateAndEnqueue() when the state is put in the
// queue and cleared by Expand() when it is dequeued;
// kExpanded - set by Expand() once the state has been dequeued;
// kSourceState - set by ProcOpenParen() on the destination of a kept open
// paren arc.
template <class A> const uint8 PrunedExpand<A>::kEnqueued = 0x01;
template <class A> const uint8 PrunedExpand<A>::kExpanded = 0x02;
template <class A> const uint8 PrunedExpand<A>::kSourceState = 0x04;
// Initializes the close paren multimap, mapping pairs (s, paren_id) to
// all the arcs out of s labeled with the close parenthesis for paren_id.
// The paren id of a label is the index of its pair in 'parens'.
template <class A>
void PrunedExpand<A>::InitCloseParenMultimap(
    const vector<pair<Label, Label> > &parens) {
  // Maps both the open and the close label of each pair to the pair's index.
  unordered_map<Label, Label> paren_id_map;
  // Index with size_t so the bound check does not mix signed and unsigned.
  for (size_t i = 0; i < parens.size(); ++i) {
    const pair<Label, Label> &p = parens[i];
    const Label paren_id = static_cast<Label>(i);
    paren_id_map[p.first] = paren_id;
    paren_id_map[p.second] = paren_id;
  }

  for (StateIterator<Fst<Arc> > siter(*ifst_); !siter.Done(); siter.Next()) {
    StateId s = siter.Value();
    for (ArcIterator<Fst<Arc> > aiter(*ifst_, s);
         !aiter.Done(); aiter.Next()) {
      const Arc &arc = aiter.Value();
      typename unordered_map<Label, Label>::const_iterator pit
          = paren_id_map.find(arc.ilabel);
      if (pit == paren_id_map.end()) continue;  // Not a parenthesis.
      if (arc.ilabel == parens[pit->second].second) {  // Close paren
        ParenState<Arc> paren_state(pit->second, s);
        close_paren_multimap_.insert(make_pair(paren_state, arc));
      }
    }
  }
}
// Returns the weight of the shortest balanced path from 'source' to 'dest'
// in 'ifst_', 'dest' must be the source state of a close paren arc.
// The value is read from the reverse shortest-path data, using both state
// ids shifted by one.
// NOTE(review): the +1 presumably accounts for a state added by Reverse() -
// confirm.
template <class A>
typename A::Weight PrunedExpand<A>::DistanceToDest(StateId source,
                                                   StateId dest) const {
  typename SP::SearchState search_state(source + 1, dest + 1);
  const Weight d =
      reverse_shortest_path_->GetShortestPathData().Distance(search_state);
  VLOG(2) << "D(" << source << ", " << dest << ") ="
          << d;
  return d;
}
// Returns the flags for state 's' in 'ofst_'; a state with no recorded
// flags yields 0.
template <class A>
uint8 PrunedExpand<A>::Flags(StateId s) const {
  if (s < flags_.size()) return flags_[s];
  return 0;
}
// Modifies the flags for state 's' in 'ofst_': the bits selected by 'mask'
// are replaced by the corresponding bits of 'flags', all others are kept.
// The flag vector is grown with zeros as needed to hold 's'.
template <class A>
void PrunedExpand<A>::SetFlags(StateId s, uint8 flags, uint8 mask) {
  while (flags_.size() <= s) flags_.push_back(0);
  uint8 &state_flags = flags_[s];
  state_flags = (state_flags & ~mask) | (flags & mask);
}
// Returns the shortest distance from the initial state to 's' in 'ofst_',
// or Weight::Zero() when no distance has been recorded for 's'.
template <class A>
typename A::Weight PrunedExpand<A>::Distance(StateId s) const {
  if (s < distance_.size()) return distance_[s];
  return Weight::Zero();
}
// Sets the shortest distance from the initial state to 's' in 'ofst_' to 'w'.
// The distance vector is first padded with Weight::Zero() up to index 's'.
template <class A>
void PrunedExpand<A>::SetDistance(StateId s, Weight w) {
while (distance_.size() <= s ) distance_.push_back(Weight::Zero());
distance_[s] = w;
}
// Returns the shortest distance from 's' to the final states in 'ofst_',
// or Weight::Zero() when no distance has been recorded for 's'.
template <class A>
typename A::Weight PrunedExpand<A>::FinalDistance(StateId s) const {
  if (s < fdistance_.size()) return fdistance_[s];
  return Weight::Zero();
}
// Sets the shortest distance from 's' to the final states in 'ofst_' to 'w'.
// The vector is first padded with Weight::Zero() up to index 's'.
template <class A>
void PrunedExpand<A>::SetFinalDistance(StateId s, Weight w) {
while (fdistance_.size() <= s) fdistance_.push_back(Weight::Zero());
fdistance_[s] = w;
}
// Returns the PDT "source" state of state 's' in 'ofst_', or kNoStateId
// when none has been recorded for 's'.
template <class A>
typename A::StateId PrunedExpand<A>::SourceState(StateId s) const {
  if (s < sources_.size()) return sources_[s];
  return kNoStateId;
}
// Sets the PDT "source" state of state 's' in 'ofst_' to state 'p' in 'ifst_'.
// The vector is first padded with kNoStateId up to index 's'.
template <class A>
void PrunedExpand<A>::SetSourceState(StateId s, StateId p) {
while (sources_.size() <= s) sources_.push_back(kNoStateId);
sources_[s] = p;
}
// Adds state 's' of 'efst_' to 'ofst_' and inserts it in the queue,
// modifying the flags for 's' accordingly:
//  * already enqueued: only its position in the queue is updated;
//  * already expanded (and not enqueued): nothing is done;
//  * otherwise: 'ofst_' is grown to contain 's', then 's' is enqueued and
//    flagged kEnqueued.
template <class A>
void PrunedExpand<A>::AddStateAndEnqueue(StateId s) {
  const uint8 status = Flags(s);
  if (status & kEnqueued) {
    queue_.Update(s);
    return;
  }
  // TODO(allauzen): Check everything is fine when kExpanded?
  if (status & kExpanded) return;

  while (ofst_->NumStates() <= s) ofst_->AddState();
  queue_.Enqueue(s);
  SetFlags(s, kEnqueued, kEnqueued);
}
// Relaxes arc 'arc' out of state 's' in 'ofst_':
// * if the distance to 's' times the weight of 'arc' is smaller than
//   the currently stored distance for 'arc.nextstate',
//   updates 'Distance(arc.nextstate)' with the new estimate and makes
//   'arc.nextstate' inherit the source state of 's';
// * if 'fd' is less than the currently stored distance from 'arc.nextstate'
//   to the final state, updates with new estimate.
template <class A>
void PrunedExpand<A>::Relax(StateId s, const A &arc, Weight fd) {
  const StateId ns = arc.nextstate;
  const Weight nd = Times(Distance(s), arc.weight);

  const bool shorter_from_start = less_(nd, Distance(ns));
  if (shorter_from_start) {
    SetDistance(ns, nd);
    SetSourceState(ns, SourceState(s));
  }

  const bool shorter_to_final = less_(fd, FinalDistance(ns));
  if (shorter_to_final) {
    SetFinalDistance(ns, fd);
  }

  VLOG(2) << "Relax: " << s << ", d[s] = " << Distance(s) << ", to "
          << ns << ", d[ns] = " << Distance(ns)
          << ", nd = " << nd;
}
// Returns 'true' if the arc 'arc' out of state 's' in 'efst_' needs to
// be pruned.
//
// The decision compares 'limit_' with Distance(s) * arc.weight * fd, where
// 'fd' is the sum, over the cached (destination, weight) pairs, of the
// balanced distance from the arc's destination PDT state to the destination
// times the cached weight. The arc is pruned when 'limit_' is less than that
// product in the natural order. As a side effect, 'arc' is relaxed with 'fd'.
template <class A>
bool PrunedExpand<A>::PruneArc(StateId s, const A &arc) {
VLOG(2) << "Prune ?";
Weight fd = Weight::Zero();
// Rebuild the cached destination list when the source state of 's' or the
// stack id being processed differs from the cached key.
if ((cached_source_ != SourceState(s)) ||
(cached_stack_id_ != current_stack_id_)) {
cached_source_ = SourceState(s);
cached_stack_id_ = current_stack_id_;
cached_dest_list_.clear();
if (cached_source_ != ifst_->Start()) {
for (SetIterator set_iter =
balance_data_->Find(current_paren_id_, cached_source_);
!set_iter.Done(); set_iter.Next()) {
StateId dest = set_iter.Element();
// NOTE(review): the result of find() is dereferenced without checking
// for end(); this relies on 'dest_map_' already holding every element
// returned by the balance data (presumably filled by ProcDestStates()) -
// confirm.
typename DestMap::const_iterator iter = dest_map_.find(dest);
cached_dest_list_.push_front(*iter);
}
} else {
// TODO(allauzen): the queue discipline should prevent this from ever
// happening; replace by a check.
cached_dest_list_.push_front(
make_pair(rfst_.Start() -1, Weight::One()));
}
}
// Accumulate the estimated distance to the final states through each
// cached destination.
for (typename slist<pair<StateId, Weight> >::const_iterator iter =
cached_dest_list_.begin();
iter != cached_dest_list_.end();
++iter) {
fd = Plus(fd,
Times(DistanceToDest(state_table_.Tuple(arc.nextstate).state_id,
iter->first),
iter->second));
}
Relax(s, arc, fd);
Weight w = Times(Distance(s), Times(arc.weight, fd));
return less_(limit_, w);
}
// Adds start state of 'efst_' to 'ofst_', enqueues it and initializes
// the distance data structures.
//
// NOTE(review): 'rfst_.Start() - 1' is used throughout as the destination
// standing for "the final states"; presumably this follows from how
// Reverse() numbers its states - confirm.
template <class A>
void PrunedExpand<A>::ProcStart() {
StateId s = efst_.Start();
AddStateAndEnqueue(s);
ofst_->SetStart(s);
SetSourceState(s, ifst_->Start());
// Processing starts with stack id 0 and no paren on top of the stack.
current_stack_id_ = 0;
current_paren_id_ = -1;
// Stack id 0 has length 0.
stack_length_.push_back(0);
dest_map_[rfst_.Start() - 1] = Weight::One(); // not needed
// Prime the destination cache used by PruneArc() for the start state.
cached_source_ = ifst_->Start();
cached_stack_id_ = 0;
cached_dest_list_.push_front(
make_pair(rfst_.Start() -1, Weight::One()));
PdtStateTuple<StateId, StackId> tuple(rfst_.Start() - 1, 0);
SetFinalDistance(state_table_.FindState(tuple), Weight::One());
SetDistance(s, Weight::One());
SetFinalDistance(s, DistanceToDest(ifst_->Start(), rfst_.Start() - 1));
VLOG(2) << DistanceToDest(ifst_->Start(), rfst_.Start() - 1);
}
// Makes 's' final in 'ofst_' if shortest accepting path ending in 's'
// is below threshold, i.e. if 's' is final in 'efst_' and 'limit_' is not
// less than the distance to 's' times its final weight.
template <class A>
void PrunedExpand<A>::ProcFinal(StateId s) {
  const Weight final_weight = efst_.Final(s);
  const bool is_final = !(final_weight == Weight::Zero());
  if (is_final && !less_(limit_, Times(Distance(s), final_weight))) {
    ofst_->SetFinal(s, final_weight);
  }
}
// Returns true when arc (or meta-arc) 'arc' out of 's' in 'efst_' is
// below the threshold. When 'add_arc' is true, 'arc' is added to 'ofst_'.
// In either case the destination of a kept arc is added to 'ofst_' and
// enqueued.
template <class A>
bool PrunedExpand<A>::ProcNonParen(StateId s, const A &arc, bool add_arc) {
  VLOG(2) << "ProcNonParen: " << s << " to " << arc.nextstate
          << ", " << arc.ilabel << ":" << arc.olabel << " / " << arc.weight
          << ", add_arc = " << (add_arc ? "true" : "false");

  const bool kept = !PruneArc(s, arc);
  if (kept) {
    if (add_arc) {
      ofst_->AddArc(s, arc);
    }
    AddStateAndEnqueue(arc.nextstate);
  }
  return kept;
}
// Processes an open paren arc 'arc' out of state 's' in 'ofst_'.
// When 'arc' is labeled with an open paren,
// 1. considers each (shortest) balanced path starting in 's' by
// taking 'arc' and ending by a close paren balancing the open
// paren of 'arc' as a meta-arc, processes and prunes each meta-arc
// as a non-paren arc, inserting its destination to the queue;
// 2. if at least one of these meta-arcs has not been pruned,
// adds the destination of 'arc' to 'ofst_' as a new source state
// for the stack id 'nsi' and inserts it in the queue.
// 'si' is the stack id of 's' and 'nsi' the stack id reached through 'arc'.
// Returns true when at least one meta-arc was kept.
template <class A>
bool PrunedExpand<A>::ProcOpenParen(StateId s, const A &arc, StackId si,
StackId nsi) {
// Update the stack length when needed: |nsi| = |si| + 1.
while (stack_length_.size() <= nsi) stack_length_.push_back(-1);
if (stack_length_[nsi] == -1)
stack_length_[nsi] = stack_length_[si] + 1;
StateId ns = arc.nextstate;
VLOG(2) << "Open paren: " << s << "(" << state_table_.Tuple(s).state_id
<< ") to " << ns << "(" << state_table_.Tuple(ns).state_id << ")";
bool proc_arc = false;
// Accumulates the distance from 'ns' to the final states over all
// meta-arcs.
Weight fd = Weight::Zero();
ssize_t paren_id = stack_.ParenId(arc.ilabel);
// Collect the balance-data set for (paren_id, PDT state of 'ns') into a
// local list before the loop below, which iterates over the copy.
slist<StateId> sources;
for (SetIterator set_iter =
balance_data_->Find(paren_id, state_table_.Tuple(ns).state_id);
!set_iter.Done(); set_iter.Next()) {
sources.push_front(set_iter.Element());
}
for (typename slist<StateId>::const_iterator sources_iter = sources.begin();
sources_iter != sources.end();
++ sources_iter) {
StateId source = *sources_iter;
VLOG(2) << "Close paren source: " << source;
ParenState<Arc> paren_state(paren_id, source);
// Visit every close paren arc recorded for (paren_id, source).
for (typename ParenMultimap::const_iterator iter =
close_paren_multimap_.find(paren_state);
iter != close_paren_multimap_.end() && paren_state == iter->first;
++iter) {
// The meta-arc goes from 's' to the close paren's destination paired
// with the stack id 'si'.
Arc meta_arc = iter->second;
PdtStateTuple<StateId, StackId> tuple(meta_arc.nextstate, si);
meta_arc.nextstate = state_table_.FindState(tuple);
VLOG(2) << state_table_.Tuple(ns).state_id << ", " << source;
VLOG(2) << "Meta arc weight = " << arc.weight << " Times "
<< DistanceToDest(state_table_.Tuple(ns).state_id, source)
<< " Times " << meta_arc.weight;
// Meta-arc weight = open paren weight * balanced distance * close
// paren weight.
meta_arc.weight = Times(
arc.weight,
Times(DistanceToDest(state_table_.Tuple(ns).state_id, source),
meta_arc.weight));
proc_arc |= ProcNonParen(s, meta_arc, false);
fd = Plus(fd, Times(
Times(
DistanceToDest(state_table_.Tuple(ns).state_id, source),
iter->second.weight),
FinalDistance(meta_arc.nextstate)));
}
}
if (proc_arc) {
VLOG(2) << "Proc open paren " << s << " to " << arc.nextstate;
// The paren label is replaced by label 0 on both sides unless
// parentheses are kept.
ofst_->AddArc(
s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
AddStateAndEnqueue(arc.nextstate);
Weight nd = Times(Distance(s), arc.weight);
if(less_(nd, Distance(arc.nextstate)))
SetDistance(arc.nextstate, nd);
// FinalDistance not necessary for source state since pruning
// decided using the meta-arcs above. But this is a problem with
// A*, hence:
if (less_(fd, FinalDistance(arc.nextstate)))
SetFinalDistance(arc.nextstate, fd);
SetFlags(arc.nextstate, kSourceState, kSourceState);
}
return proc_arc;
}
// Checks that shortest path through close paren arc in 'efst_' is
// below threshold, if so adds it to 'ofst_' (with label 0 on both sides
// unless parentheses are kept). Returns true when the arc was added.
template <class A>
bool PrunedExpand<A>::ProcCloseParen(StateId s, const A &arc) {
  const Weight through_arc =
      Times(Distance(s),
            Times(arc.weight, FinalDistance(arc.nextstate)));
  const bool keep = !less_(limit_, through_arc);
  if (keep) {
    ofst_->AddArc(
        s, keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate));
  }
  return keep;
}
// When 's' in 'ofst_' is a source state for stack id 'si', identifies
// all the corresponding possible destination states, that is, all the
// states in 'ifst_' that have an outgoing close paren arc balancing
// the incoming open paren taken to get to 's', and for each such
// state 't', computes the shortest distance from (t, si) to the final
// states in 'ofst_'. Stores this information in 'dest_map_'.
// Does nothing when 's' is not flagged kSourceState.
template <class A>
void PrunedExpand<A>::ProcDestStates(StateId s, StackId si) {
if (!(Flags(s) & kSourceState)) return;
// A different stack id than the current one: reset the per-stack state.
if (si != current_stack_id_) {
dest_map_.clear();
current_stack_id_ = si;
current_paren_id_ = stack_.Top(current_stack_id_);
VLOG(2) << "StackID " << si << " dequeued for first time";
}
// TODO(allauzen): clean up source state business; rename current function to
// ProcSourceState.
SetSourceState(s, state_table_.Tuple(s).state_id);
ssize_t paren_id = stack_.Top(si);
for (SetIterator set_iter =
balance_data_->Find(paren_id, state_table_.Tuple(s).state_id);
!set_iter.Done(); set_iter.Next()) {
StateId dest_state = set_iter.Element();
// Already computed for this stack id.
if (dest_map_.find(dest_state) != dest_map_.end())
continue;
// Sum, over the close paren arcs out of 'dest_state', of the arc weight
// times the final distance of the arc's destination with the stack popped.
Weight dest_weight = Weight::Zero();
ParenState<Arc> paren_state(paren_id, dest_state);
for (typename ParenMultimap::const_iterator iter =
close_paren_multimap_.find(paren_state);
iter != close_paren_multimap_.end() && paren_state == iter->first;
++iter) {
const Arc &arc = iter->second;
PdtStateTuple<StateId, StackId> tuple(arc.nextstate, stack_.Pop(si));
dest_weight = Plus(dest_weight,
Times(arc.weight,
FinalDistance(state_table_.FindState(tuple))));
}
dest_map_[dest_state] = dest_weight;
VLOG(2) << "State " << dest_state << " is a dest state for stack id "
<< si << " with weight " << dest_weight;
}
}
// Expands and prunes with weight threshold 'threshold' the input PDT.
// Writes the result in 'ofst'.
//
// 'ofst' is emptied first and receives the symbol tables of the input. The
// pruning limit is the distance from the PDT start state to the final
// states times 'threshold'. States of 'efst_' are then visited in queue
// order; each outgoing arc is classified by comparing the stack ids at its
// two ends and handed to the matching Proc* method.
template <class A>
void PrunedExpand<A>::Expand(
    MutableFst<A> *ofst, const typename A::Weight &threshold) {
  ofst_ = ofst;
  ofst_->DeleteStates();
  ofst_->SetInputSymbols(ifst_->InputSymbols());
  ofst_->SetOutputSymbols(ifst_->OutputSymbols());

  const Weight total = DistanceToDest(ifst_->Start(), rfst_.Start() - 1);
  limit_ = Times(total, threshold);
  flags_.clear();

  ProcStart();

  while (!queue_.Empty()) {
    const StateId s = queue_.Head();
    queue_.Dequeue();
    SetFlags(s, kExpanded, kExpanded | kEnqueued);
    VLOG(2) << s << " dequeued!";

    ProcFinal(s);
    const StackId here = state_table_.Tuple(s).stack_id;
    ProcDestStates(s, here);

    for (ArcIterator<ExpandFst<Arc> > aiter(efst_, s);
         !aiter.Done();
         aiter.Next()) {
      const Arc arc = aiter.Value();
      const StackId there = state_table_.Tuple(arc.nextstate).stack_id;
      if (there == here) {
        // Same stack at both ends: ordinary arc.
        ProcNonParen(s, arc, true);
      } else if (stack_.Pop(there) == here) {
        // Destination stack is the current one plus one paren: push.
        ProcOpenParen(s, arc, here, there);
      } else {
        // Anything else is handled as a pop.
        ProcCloseParen(s, arc);
      }
    }
    VLOG(2) << "d[" << s << "] = " << Distance(s)
            << ", fd[" << s << "] = " << FinalDistance(s);
  }
}
//
// Expand() Functions
//
// Options for the Expand() functions below.
template <class Arc>
struct ExpandOptions {
// When true, Expand() calls Connect() on the result.
bool connect;
// Passed on to the expansion; see ExpandFstOptions::keep_parentheses.
bool keep_parentheses;
// Pruning threshold. Weight::Zero() (the default) makes Expand() perform
// the full, unpruned expansion; any other value selects PrunedExpand.
typename Arc::Weight weight_threshold;
ExpandOptions(bool c = true, bool k = false,
typename Arc::Weight w = Arc::Weight::Zero())
: connect(c), keep_parentheses(k), weight_threshold(w) {}
};
// Expands a pushdown transducer (PDT) encoded as an FST into an FST.
// This version writes the expanded PDT result to a MutableFst.
// In the PDT, some transitions are labeled with open or close
// parentheses. To be interpreted as a PDT, the parens must balance on
// a path. The open-close parenthesis label pairs are passed in
// 'parens'. The expansion enforces the parenthesis constraints. The
// PDT must be expandable as an FST.
//
// When 'opts.weight_threshold' is Weight::Zero() the full expansion is
// computed through a delayed ExpandFst assigned to '*ofst'; otherwise the
// expansion is pruned with that threshold by PrunedExpand. In both cases
// the result is trimmed with Connect() when 'opts.connect' is true.
template <class Arc>
void Expand(
    const Fst<Arc> &ifst,
    const vector<pair<typename Arc::Label, typename Arc::Label> > &parens,
    MutableFst<Arc> *ofst,
    const ExpandOptions<Arc> &opts) {
  typedef typename Arc::Weight Weight;

  ExpandFstOptions<Arc> eopts;
  eopts.gc_limit = 0;
  if (opts.weight_threshold == Weight::Zero()) {
    eopts.keep_parentheses = opts.keep_parentheses;
    *ofst = ExpandFst<Arc>(ifst, parens, eopts);
  } else {
    PrunedExpand<Arc> pruned_expand(ifst, parens, opts.keep_parentheses);
    pruned_expand.Expand(ofst, opts.weight_threshold);
  }

  if (opts.connect)
    Connect(ofst);
}
// Expands a pushdown transducer (PDT) encoded as an FST into an FST.
// This version writes the expanded PDT result to a MutableFst.
// In the PDT, some transitions are labeled with open or close
// parentheses. To be interpreted as a PDT, the parens must balance on
// a path. The open-close parenthesis label pairs are passed in
// 'parens'. The expansion enforces the parenthesis constraints. The
// PDT must be expandable as an FST.
//
// Convenience overload: forwards to the ExpandOptions version with
// 'connect' and 'keep_parentheses'. The weight threshold keeps its
// ExpandOptions default, Weight::Zero(), so no pruning is performed.
template<class Arc>
void Expand(
const Fst<Arc> &ifst,
const vector<pair<typename Arc::Label, typename Arc::Label> > &parens,
MutableFst<Arc> *ofst,
bool connect = true, bool keep_parentheses = false) {
Expand(ifst, parens, ofst, ExpandOptions<Arc>(connect, keep_parentheses));
}
} // namespace fst
#endif // FST_EXTENSIONS_PDT_EXPAND_H__