blob: 7b9887fd9a8b0f3ffce71e642f4f5368cefa9e90 [file] [log] [blame]
// paren.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// Common classes for PDT parentheses
// \file
#ifndef FST_EXTENSIONS_PDT_PAREN_H_
#define FST_EXTENSIONS_PDT_PAREN_H_
#include <algorithm>
#include <unordered_map>
using std::tr1::unordered_map;
using std::tr1::unordered_multimap;
#include <tr1/unordered_set>
using std::tr1::unordered_set;
using std::tr1::unordered_multiset;
#include <set>
#include <fst/extensions/pdt/pdt.h>
#include <fst/extensions/pdt/collection.h>
#include <fst/fst.h>
#include <fst/dfs-visit.h>
namespace fst {
//
// ParenState: Pair of an open (close) parenthesis and
// its destination (source) state.
//
// The pair can be used as a key in hashed containers (through the nested
// Hash functor together with operator==) and in ordered containers
// (through operator<).
//
template <class A>
class ParenState {
 public:
  typedef typename A::Label Label;
  typedef typename A::StateId StateId;

  // Hash functor: the paren ID plus the state ID scaled by kPrime.
  struct Hash {
    size_t operator()(const ParenState<A> &p) const {
      return p.paren_id + p.state_id * kPrime;
    }
  };

  Label paren_id;     // ID of open (close) paren
  StateId state_id;   // destination (source) state of open (close) paren

  // Default-constructed pairs carry kNoLabel / kNoStateId in both fields.
  ParenState() : paren_id(kNoLabel), state_id(kNoStateId) {}

  ParenState(Label p, StateId s) : paren_id(p), state_id(s) {}

  // Two pairs are equal when both the paren ID and the state ID agree.
  bool operator==(const ParenState<A> &p) const {
    if (&p == this)
      return true;
    return p.paren_id == this->paren_id && p.state_id == this->state_id;
  }

  bool operator!=(const ParenState<A> &p) const { return !(p == *this); }

  // Strict lexicographic order on (paren_id, state_id): the paren ID is
  // compared first and the state ID only breaks ties.
  bool operator<(const ParenState<A> &p) const {
    return paren_id < p.paren_id ||
        (paren_id == p.paren_id && state_id < p.state_id);
  }

 private:
  static const size_t kPrime;   // multiplier used by Hash
};

template <class A>
const size_t ParenState<A>::kPrime = 7853;
// Creates an FST-style iterator from STL map and iterator.
//
// The iterator walks forward from the position it was constructed with and
// reports Done() as soon as it reaches the end of the map or an entry whose
// key differs from the key at the starting position. Only a reference to the
// map is stored, so the map must outlive the iterator.
template <class M>
class MapIterator {
 public:
  typedef typename M::const_iterator StlIterator;
  typedef typename M::value_type PairType;
  typedef typename PairType::second_type ValueType;

  // 'm' is the container being traversed; 'iter' is the starting position
  // (it may be m.end(), in which case the iterator is immediately Done()).
  MapIterator(const M &m, StlIterator iter)
      : map_(m), begin_(iter), iter_(iter) {}

  // True once the run of entries sharing the starting key is exhausted.
  bool Done() const {
    if (iter_ == map_.end())
      return true;
    return iter_->first != begin_->first;
  }

  // Mapped value at the current position (returned by value).
  ValueType Value() const { return (*iter_).second; }

  // Advances to the following entry.
  void Next() { ++iter_; }

  // Rewinds to the position given at construction.
  void Reset() { iter_ = begin_; }

 private:
  const M &map_;        // container being traversed
  StlIterator begin_;   // starting position, kept for Reset() and Done()
  StlIterator iter_;    // current position
};
//
// PdtParenReachable: Provides various parenthesis reachability information
// on a PDT.
//
template <class A>
class PdtParenReachable {
public:
typedef typename A::StateId StateId;
typedef typename A::Label Label;
public:
// Maps from state ID to reachable paren IDs from (to) that state.
typedef unordered_multimap<StateId, Label> ParenMultiMap;
// Maps from paren ID and state ID to reachable state set ID
typedef unordered_map<ParenState<A>, ssize_t,
typename ParenState<A>::Hash> StateSetMap;
// Maps from paren ID and state ID to arcs exiting that state with that
// Label.
typedef unordered_multimap<ParenState<A>, A,
typename ParenState<A>::Hash> ParenArcMultiMap;
typedef MapIterator<ParenMultiMap> ParenIterator;
typedef MapIterator<ParenArcMultiMap> ParenArcIterator;
typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
// Computes close (open) parenthesis reachabilty information for
// a PDT with bounded stack.
PdtParenReachable(const Fst<A> &fst,
const vector<pair<Label, Label> > &parens, bool close)
: fst_(fst),
parens_(parens),
close_(close) {
for (Label i = 0; i < parens.size(); ++i) {
const pair<Label, Label> &p = parens[i];
paren_id_map_[p.first] = i;
paren_id_map_[p.second] = i;
}
if (close_) {
StateId start = fst.Start();
if (start == kNoStateId)
return;
DFSearch(start, start);
} else {
FSTERROR() << "PdtParenReachable: open paren info not implemented";
}
}
// Given a state ID, returns an iterator over paren IDs
// for close (open) parens reachable from that state along balanced
// paths.
ParenIterator FindParens(StateId s) const {
return ParenIterator(paren_multimap_, paren_multimap_.find(s));
}
// Given a paren ID and a state ID s, returns an iterator over
// states that can be reached along balanced paths from (to) s that
// have have close (open) parentheses matching the paren ID exiting
// (entering) those states.
SetIterator FindStates(Label paren_id, StateId s) const {
ParenState<A> paren_state(paren_id, s);
typename StateSetMap::const_iterator id_it = set_map_.find(paren_state);
if (id_it == set_map_.end()) {
return state_sets_.FindSet(-1);
} else {
return state_sets_.FindSet(id_it->second);
}
}
// Given a paren Id and a state ID s, return an iterator over
// arcs that exit (enter) s and are labeled with a close (open)
// parenthesis matching the paren ID.
ParenArcIterator FindParenArcs(Label paren_id, StateId s) const {
ParenState<A> paren_state(paren_id, s);
return ParenArcIterator(paren_arc_multimap_,
paren_arc_multimap_.find(paren_state));
}
private:
// DFS that gathers paren and state set information.
// Bool returns false when cycle detected.
bool DFSearch(StateId s, StateId start);
// Unions state sets together gathered by the DFS.
void ComputeStateSet(StateId s);
// Gather state set(s) from state 'nexts'.
void UpdateStateSet(StateId nexts, set<Label> *paren_set,
vector< set<StateId> > *state_sets) const;
const Fst<A> &fst_;
const vector<pair<Label, Label> > &parens_; // Paren ID -> Labels
bool close_; // Close/open paren info?
unordered_map<Label, Label> paren_id_map_; // Paren labels -> ID
ParenMultiMap paren_multimap_; // Paren reachability
ParenArcMultiMap paren_arc_multimap_; // Paren Arcs
vector<char> state_color_; // DFS state
mutable Collection<ssize_t, StateId> state_sets_; // Reachable states -> ID
StateSetMap set_map_; // ID -> Reachable states
DISALLOW_COPY_AND_ASSIGN(PdtParenReachable);
};
// DFS that gathers paren and state set information.
//
// Visits state 's', recursing through its arcs, and finishes by calling
// ComputeStateSet(s) and colouring 's' black.
//   - Non-paren arcs are followed with the same 'start'.
//   - An open-paren arc starts a fresh search at its destination; afterwards
//     the search resumes from the destination of every matching close-paren
//     arc that FindStates()/FindParenArcs() report for that destination.
//   - Close-paren arcs are not followed here; ComputeStateSet() records them.
//
// Returns false when 's' is already on the DFS stack (grey). Returns true
// when 's' was finished earlier, when the visit completes, and also right
// after reporting "Underlying cyclicity not supported" for a non-paren arc
// that leads back onto the stack (in that case 's' is left grey).
template <class A>
bool PdtParenReachable<A>::DFSearch(StateId s, StateId start) {
  // The colour table grows on demand; unseen states start out white.
  if (s >= state_color_.size())
    state_color_.resize(s + 1, kDfsWhite);

  if (state_color_[s] == kDfsBlack)
    return true;

  if (state_color_[s] == kDfsGrey)
    return false;

  state_color_[s] = kDfsGrey;

  for (ArcIterator<Fst<A> > aiter(fst_, s);
       !aiter.Done();
       aiter.Next()) {
    const A &arc = aiter.Value();

    typename unordered_map<Label, Label>::const_iterator pit
        = paren_id_map_.find(arc.ilabel);
    if (pit != paren_id_map_.end()) {               // paren?
      Label paren_id = pit->second;
      if (arc.ilabel == parens_[paren_id].first) {  // open paren
        DFSearch(arc.nextstate, arc.nextstate);
        for (SetIterator set_iter = FindStates(paren_id, arc.nextstate);
             !set_iter.Done(); set_iter.Next()) {
          // Copy the close-paren destinations out before recursing: the
          // recursive calls insert into paren_arc_multimap_ (through
          // ComputeStateSet), and a rehash of that unordered multimap would
          // invalidate an iterator still walking it. All arcs stored under
          // this key were inserted before the state showed up in a state
          // set, so the snapshot is complete.
          vector<StateId> close_dests;
          for (ParenArcIterator paren_arc_iter =
                   FindParenArcs(paren_id, set_iter.Element());
               !paren_arc_iter.Done();
               paren_arc_iter.Next()) {
            close_dests.push_back(paren_arc_iter.Value().nextstate);
          }
          for (size_t i = 0; i < close_dests.size(); ++i)
            DFSearch(close_dests[i], start);
        }
      }
    } else {                                        // non-paren
      if (!DFSearch(arc.nextstate, start)) {
        FSTERROR() << "PdtReachable: Underlying cyclicity not supported";
        return true;
      }
    }
  }
  ComputeStateSet(s);
  state_color_[s] = kDfsBlack;
  return true;
}
// Unions state sets together gathered by the DFS.
//
// For state 's', collects every paren ID reachable from 's' together with,
// per paren ID, the set of states it can be taken from:
//   - a close-paren arc leaving 's' contributes 's' itself and is stored in
//     paren_arc_multimap_ under (paren ID, s);
//   - a non-paren arc contributes whatever is recorded for its destination;
//   - an open-paren arc contributes whatever is recorded for the destinations
//     of the matching close-paren arcs found for the open paren's
//     destination.
// The results are written to paren_multimap_ and set_map_.
template <class A>
void PdtParenReachable<A>::ComputeStateSet(StateId s) {
  set<Label> reachable_parens;
  vector< set<StateId> > states_by_paren(parens_.size());

  for (ArcIterator< Fst<A> > aiter(fst_, s); !aiter.Done(); aiter.Next()) {
    const A &arc = aiter.Value();
    typename unordered_map<Label, Label>::const_iterator pit =
        paren_id_map_.find(arc.ilabel);

    if (pit == paren_id_map_.end()) {               // non-paren
      UpdateStateSet(arc.nextstate, &reachable_parens, &states_by_paren);
      continue;
    }

    const Label paren_id = pit->second;

    if (arc.ilabel != parens_[paren_id].first) {    // close paren
      reachable_parens.insert(paren_id);
      states_by_paren[paren_id].insert(s);
      ParenState<A> arc_key(paren_id, s);
      paren_arc_multimap_.insert(make_pair(arc_key, arc));
      continue;
    }

    // Open paren: continue from the far side of each matching close paren.
    for (SetIterator set_iter = FindStates(paren_id, arc.nextstate);
         !set_iter.Done(); set_iter.Next()) {
      for (ParenArcIterator paren_arc_iter =
               FindParenArcs(paren_id, set_iter.Element());
           !paren_arc_iter.Done(); paren_arc_iter.Next()) {
        UpdateStateSet(paren_arc_iter.Value().nextstate,
                       &reachable_parens, &states_by_paren);
      }
    }
  }

  // Publish one entry per reachable paren ID: the paren itself, and the ID of
  // its (sorted, since it comes from a std::set) list of states.
  vector<StateId> flat_states;
  for (typename set<Label>::iterator it = reachable_parens.begin();
       it != reachable_parens.end(); ++it) {
    Label paren_id = *it;
    paren_multimap_.insert(make_pair(s, paren_id));

    const set<StateId> &members = states_by_paren[paren_id];
    flat_states.assign(members.begin(), members.end());

    ParenState<A> set_key(paren_id, s);
    set_map_[set_key] = state_sets_.FindId(flat_states);
  }
}
// Gather state set(s) from state 'nexts'.
//
// For every paren ID recorded as reachable from 'nexts', adds the ID to
// '*paren_set' and adds the states recorded for (paren ID, nexts) to
// '(*state_sets)[paren ID]'. '*state_sets' is indexed by paren ID.
template <class A>
void PdtParenReachable<A>::UpdateStateSet(
    StateId nexts, set<Label> *paren_set,
    vector< set<StateId> > *state_sets) const {
  ParenIterator paren_iter = FindParens(nexts);
  while (!paren_iter.Done()) {
    const Label paren_id = paren_iter.Value();
    paren_set->insert(paren_id);

    set<StateId> &target = (*state_sets)[paren_id];
    SetIterator set_iter = FindStates(paren_id, nexts);
    while (!set_iter.Done()) {
      target.insert(set_iter.Element());
      set_iter.Next();
    }

    paren_iter.Next();
  }
}
// Store balancing parenthesis data for a PDT. Allows on-the-fly
// construction (e.g. in PdtShortestPath) unlike PdtParenReachable above.
template <class A>
class PdtBalanceData {
public:
typedef typename A::StateId StateId;
typedef typename A::Label Label;
// Hash set for open parens
typedef unordered_set<ParenState<A>, typename ParenState<A>::Hash> OpenParenSet;
// Maps from open paren destination state to parenthesis ID.
typedef unordered_multimap<StateId, Label> OpenParenMap;
// Maps from open paren state to source states of matching close parens
typedef unordered_multimap<ParenState<A>, StateId,
typename ParenState<A>::Hash> CloseParenMap;
// Maps from open paren state to close source set ID
typedef unordered_map<ParenState<A>, ssize_t,
typename ParenState<A>::Hash> CloseSourceMap;
typedef typename Collection<ssize_t, StateId>::SetIterator SetIterator;
PdtBalanceData() {}
void Clear() {
open_paren_map_.clear();
close_paren_map_.clear();
}
// Adds an open parenthesis with destination state 'open_dest'.
void OpenInsert(Label paren_id, StateId open_dest) {
ParenState<A> key(paren_id, open_dest);
if (!open_paren_set_.count(key)) {
open_paren_set_.insert(key);
open_paren_map_.insert(make_pair(open_dest, paren_id));
}
}
// Adds a matching closing parenthesis with source state
// 'close_source' that balances an open_parenthesis with destination
// state 'open_dest' if OpenInsert() previously called
// (o.w. CloseInsert() does nothing).
void CloseInsert(Label paren_id, StateId open_dest, StateId close_source) {
ParenState<A> key(paren_id, open_dest);
if (open_paren_set_.count(key))
close_paren_map_.insert(make_pair(key, close_source));
}
// Find close paren source states matching an open parenthesis.
// Methods that follow, iterate through those matching states.
// Should be called only after FinishInsert(open_dest).
SetIterator Find(Label paren_id, StateId open_dest) {
ParenState<A> close_key(paren_id, open_dest);
typename CloseSourceMap::const_iterator id_it =
close_source_map_.find(close_key);
if (id_it == close_source_map_.end()) {
return close_source_sets_.FindSet(-1);
} else {
return close_source_sets_.FindSet(id_it->second);
}
}
// Call when all open and close parenthesis insertions wrt open
// parentheses entering 'open_dest' are finished. Must be called
// before Find(open_dest). Stores close paren source state sets
// efficiently.
void FinishInsert(StateId open_dest) {
vector<StateId> close_sources;
for (typename OpenParenMap::iterator oit = open_paren_map_.find(open_dest);
oit != open_paren_map_.end() && oit->first == open_dest;) {
Label paren_id = oit->second;
close_sources.clear();
ParenState<A> okey(paren_id, open_dest);
open_paren_set_.erase(open_paren_set_.find(okey));
for (typename CloseParenMap::iterator cit = close_paren_map_.find(okey);
cit != close_paren_map_.end() && cit->first == okey;) {
close_sources.push_back(cit->second);
close_paren_map_.erase(cit++);
}
sort(close_sources.begin(), close_sources.end());
typename vector<StateId>::iterator unique_end =
unique(close_sources.begin(), close_sources.end());
close_sources.resize(unique_end - close_sources.begin());
if (!close_sources.empty())
close_source_map_[okey] = close_source_sets_.FindId(close_sources);
open_paren_map_.erase(oit++);
}
}
// Return a new balance data object representing the reversed balance
// information.
PdtBalanceData<A> *Reverse(StateId num_states,
StateId num_split,
StateId state_id_shift) const;
private:
OpenParenSet open_paren_set_; // open par. at dest?
OpenParenMap open_paren_map_; // open parens per state
ParenState<A> open_dest_; // cur open dest. state
typename OpenParenMap::const_iterator open_iter_; // cur open parens/state
CloseParenMap close_paren_map_; // close states/open
// paren and state
CloseSourceMap close_source_map_; // paren, state to set ID
mutable Collection<ssize_t, StateId> close_source_sets_;
};
// Return a new balance data object representing the reversed balance
// information.
//
// In the result the roles are swapped: every stored close paren source state
// becomes an open paren destination, and the open paren destination it
// balanced becomes its close source. 'state_id_shift' is added to every
// state ID written to the result.
//
// The state range [0, num_states) is processed in consecutive slices of
// num_states / num_split states, and FinishInsert() is called on the result
// for the close sources of a slice before the next slice starts. A slice
// width below one (num_split larger than num_states) is raised to one, and a
// non-positive num_split is treated as a single slice, so the loop always
// advances.
//
// The returned object is allocated with new; the caller takes ownership.
template <class A>
PdtBalanceData<A> *PdtBalanceData<A>::Reverse(
    StateId num_states,
    StateId num_split,
    StateId state_id_shift) const {
  PdtBalanceData<A> *bd = new PdtBalanceData<A>;
  unordered_set<StateId> close_sources;

  // Guard the stride: a zero stride would loop forever and a zero divisor is
  // undefined behaviour.
  StateId split_size = num_split > 0 ? num_states / num_split : num_states;
  if (split_size < 1)
    split_size = 1;

  for (StateId i = 0; i < num_states; i += split_size) {
    close_sources.clear();

    for (typename CloseSourceMap::const_iterator
             sit = close_source_map_.begin();
         sit != close_source_map_.end();
         ++sit) {
      ParenState<A> okey = sit->first;
      StateId open_dest = okey.state_id;
      Label paren_id = okey.paren_id;
      for (SetIterator set_iter = close_source_sets_.FindSet(sit->second);
           !set_iter.Done(); set_iter.Next()) {
        StateId close_source = set_iter.Element();
        // Only close sources inside the current slice are handled now.
        if ((close_source < i) || (close_source >= i + split_size))
          continue;
        close_sources.insert(close_source + state_id_shift);
        bd->OpenInsert(paren_id, close_source + state_id_shift);
        bd->CloseInsert(paren_id, close_source + state_id_shift,
                        open_dest + state_id_shift);
      }
    }

    // Compact what this slice inserted before moving on to the next one.
    for (typename unordered_set<StateId>::const_iterator it
             = close_sources.begin();
         it != close_sources.end();
         ++it) {
      bd->FinishInsert(*it);
    }
  }
  return bd;
}
} // namespace fst
#endif // FST_EXTENSIONS_PDT_PAREN_H_