blob: a85d0fed029825b63e842e9fe7defc299fd0b1b7 [file] [log] [blame]
// replace.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Recursively replace Fst arcs with other Fst(s) returning a PDT.
#ifndef FST_EXTENSIONS_PDT_REPLACE_H__
#define FST_EXTENSIONS_PDT_REPLACE_H__
#include <fst/replace.h>
namespace fst {
// Hash functor for the (size_t, S) keys that index parenthesis IDs.
template <typename S>
struct ReplaceParenHash {
  // Mixes the two key components as: first + second * kPrime.
  size_t operator()(const std::pair<size_t, S> &key) const {
    const size_t scaled_second = key.second * kPrime;
    return key.first + scaled_second;
  }

 private:
  // Multiplier applied to the second key component.
  static const size_t kPrime = 7853;
};

// Out-of-class definition of the static data member.
template <typename S>
const size_t ReplaceParenHash<S>::kPrime;
// Builds a pushdown transducer (PDT) from an RTN specification
// identical to that in fst/lib/replace.h. The result is a PDT
// encoded as the FST 'ofst' where some transitions are labeled with
// open or close parentheses. To be interpreted as a PDT, the parens
// must balance on a path (see PdtExpand()). The open/close
// parenthesis label pairs are returned in 'parens'.
template <class Arc>
void Replace(const vector<pair<typename Arc::Label,
const Fst<Arc>* > >& ifst_array,
MutableFst<Arc> *ofst,
vector<pair<typename Arc::Label,
typename Arc::Label> > *parens,
typename Arc::Label root) {
typedef typename Arc::Label Label;
typedef typename Arc::StateId StateId;
typedef typename Arc::Weight Weight;
ofst->DeleteStates();
parens->clear();
unordered_map<Label, size_t> label2id;
for (size_t i = 0; i < ifst_array.size(); ++i)
label2id[ifst_array[i].first] = i;
Label max_label = kNoLabel;
deque<size_t> non_term_queue; // Queue of non-terminals to replace
unordered_set<Label> non_term_set; // Set of non-terminals to replace
non_term_queue.push_back(root);
non_term_set.insert(root);
// PDT state corr. to ith replace FST start state.
vector<StateId> fst_start(ifst_array.size(), kNoLabel);
// PDT state, weight pairs corr. to ith replace FST final state & weights.
vector< vector<pair<StateId, Weight> > > fst_final(ifst_array.size());
// Builds single Fst combining all referenced input Fsts. Leaves in the
// non-termnals for now. Tabulate the PDT states that correspond to
// the start and final states of the input Fsts.
for (StateId soff = 0; !non_term_queue.empty(); soff = ofst->NumStates()) {
Label label = non_term_queue.front();
non_term_queue.pop_front();
size_t fst_id = label2id[label];
const Fst<Arc> *ifst = ifst_array[fst_id].second;
for (StateIterator< Fst<Arc> > siter(*ifst);
!siter.Done(); siter.Next()) {
StateId is = siter.Value();
StateId os = ofst->AddState();
if (is == ifst->Start()) {
fst_start[fst_id] = os;
if (label == root)
ofst->SetStart(os);
}
if (ifst->Final(is) != Weight::Zero()) {
if (label == root)
ofst->SetFinal(os, ifst->Final(is));
fst_final[fst_id].push_back(make_pair(os, ifst->Final(is)));
}
for (ArcIterator< Fst<Arc> > aiter(*ifst, is);
!aiter.Done(); aiter.Next()) {
Arc arc = aiter.Value();
if (max_label == kNoLabel || arc.olabel > max_label)
max_label = arc.olabel;
typename unordered_map<Label, size_t>::const_iterator it =
label2id.find(arc.olabel);
if (it != label2id.end()) {
size_t nfst_id = it->second;
if (ifst_array[nfst_id].second->Start() == -1)
continue;
if (non_term_set.count(arc.olabel) == 0) {
non_term_queue.push_back(arc.olabel);
non_term_set.insert(arc.olabel);
}
}
arc.nextstate += soff;
ofst->AddArc(os, arc);
}
}
}
// Changes each non-terminal transition to an open parenthesis
// transition redirected to the PDT state that corresponds to the
// start state of the input FST for the non-terminal. Adds close parenthesis
// transitions from the PDT states corr. to the final states of the
// input FST for the non-terminal to the former destination state of the
// non-terminal transition.
typedef MutableArcIterator< MutableFst<Arc> > MIter;
typedef unordered_map<pair<size_t, StateId >, size_t,
ReplaceParenHash<StateId> > ParenMap;
// Parenthesis pair ID per fst, state pair.
ParenMap paren_map;
// # of parenthesis pairs per fst.
vector<size_t> nparens(ifst_array.size(), 0);
// Initial open parenthesis label
Label first_paren = max_label + 1;
for (StateIterator< Fst<Arc> > siter(*ofst);
!siter.Done(); siter.Next()) {
StateId os = siter.Value();
MIter *aiter = new MIter(ofst, os);
for (size_t n = 0; !aiter->Done(); aiter->Next(), ++n) {
Arc arc = aiter->Value();
typename unordered_map<Label, size_t>::const_iterator lit =
label2id.find(arc.olabel);
if (lit != label2id.end()) {
size_t nfst_id = lit->second;
// Get parentheses. Ensures distinct parenthesis pair per
// non-terminal and destination state but otherwise reuses them.
Label open_paren = kNoLabel, close_paren = kNoLabel;
pair<size_t, StateId> paren_key(nfst_id, arc.nextstate);
typename ParenMap::const_iterator pit = paren_map.find(paren_key);
if (pit != paren_map.end()) {
size_t paren_id = pit->second;
open_paren = (*parens)[paren_id].first;
close_paren = (*parens)[paren_id].second;
} else {
size_t paren_id = nparens[nfst_id]++;
open_paren = first_paren + 2 * paren_id;
close_paren = open_paren + 1;
paren_map[paren_key] = paren_id;
if (paren_id >= parens->size())
parens->push_back(make_pair(open_paren, close_paren));
}
// Sets open parenthesis.
Arc sarc(open_paren, open_paren, arc.weight, fst_start[nfst_id]);
aiter->SetValue(sarc);
// Adds close parentheses.
for (size_t i = 0; i < fst_final[nfst_id].size(); ++i) {
pair<StateId, Weight> &p = fst_final[nfst_id][i];
Arc farc(close_paren, close_paren, p.second, arc.nextstate);
ofst->AddArc(p.first, farc);
if (os == p.first) { // Invalidated iterator
delete aiter;
aiter = new MIter(ofst, os);
aiter->Seek(n);
}
}
}
}
delete aiter;
}
}
} // namespace fst
#endif // FST_EXTENSIONS_PDT_REPLACE_H__