blob: 1f7a8faa91a605f5b9876ee12c546f046b2c7b78 [file] [log] [blame]
// push.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: allauzen@google.com (Cyril Allauzen)
//
// \file
// Class to reweight/push an FST.
#ifndef FST_LIB_PUSH_H__
#define FST_LIB_PUSH_H__
#include <vector>
using std::vector;
#include <fst/factor-weight.h>
#include <fst/fst.h>
#include <fst/arc-map.h>
#include <fst/reweight.h>
#include <fst/shortest-distance.h>
namespace fst {
// Private helper functions for Push
namespace internal {
// Returns the total weight of 'fst' (the sum of the weights of all accepting
// paths), derived from a 'distance' vector produced by ShortestDistance.
//
// With 'reverse == false', 'distance' holds the shortest distance from the
// initial state to each state, and the total is accumulated over all states
// covered by 'distance' as Times(distance[s], Final(s)).
//
// With 'reverse == true', 'distance' holds the shortest distance from each
// state to the final states, and the total is the entry for the start state
// (or Zero() when 'distance' has no entry for it).
template <class Arc>
typename Arc::Weight ComputeTotalWeight(
    const Fst<Arc> &fst,
    const vector<typename Arc::Weight> &distance,
    bool reverse) {
  typedef typename Arc::Weight Weight;
  typedef typename Arc::StateId StateId;

  if (!reverse) {
    // Forward distances: fold every state's contribution into the total.
    Weight total = Weight::Zero();
    for (StateId state = 0; state < distance.size(); ++state) {
      const Weight through_state = Times(distance[state], fst.Final(state));
      total = Plus(total, through_state);
    }
    return total;
  }

  // Reverse distances: the start state's entry already is the total.
  const StateId start = fst.Start();
  if (start < distance.size())
    return distance[start];
  return Weight::Zero();
}
// Divides the weight of every accepting path of 'fst' by 'w'.
//
// When 'at_final == true', every state's final weight is right-divided by
// 'w'. Otherwise the weight of each arc leaving the initial state, and the
// final weight of the initial state, are left-divided by 'w'.
//
// Nothing is changed when 'w' is One() or Zero(), or when
// 'at_final == false' and the FST has no start state.
template <class Arc>
void RemoveWeight(MutableFst<Arc> *fst, typename Arc::Weight w, bool at_final) {
  if ((w == Arc::Weight::One()) || (w == Arc::Weight::Zero()))
    return;
  if (at_final) {
    // Remove 'w' from the final states
    for (StateIterator< MutableFst<Arc> > sit(*fst);
         !sit.Done();
         sit.Next())
      fst->SetFinal(sit.Value(),
                    Divide(fst->Final(sit.Value()), w, DIVIDE_RIGHT));
  } else {  // at_final == false
    // Remove 'w' from the initial state
    typename Arc::StateId start = fst->Start();
    // A negative id means the FST has no start state; there is then no
    // accepting path to rescale, and iterating arcs of (or setting the final
    // weight of) such a state id would access an invalid state.
    if (start < 0)
      return;
    for (MutableArcIterator<MutableFst<Arc> > ait(fst, start);
         !ait.Done();
         ait.Next()) {
      Arc arc = ait.Value();
      arc.weight = Divide(arc.weight, w, DIVIDE_LEFT);
      ait.SetValue(arc);
    }
    fst->SetFinal(start, Divide(fst->Final(start), w, DIVIDE_LEFT));
  }
}
} // namespace internal
// Pushes the weights of 'fst', in place, in the direction given by 'type'.
//
// When pushing towards the initial state, in the resulting machine the sum
// of the weights of the outgoing transitions and the final weight at each
// non-initial state equals One(). When pushing towards the final states,
// the same property holds on the reverse machine.
//
// 'delta' is forwarded to ShortestDistance. If 'remove_total_weight' is
// true, the total weight of the FST (computed before reweighting) is
// divided out of the result afterwards.
//
// Weight needs to be left distributive when pushing towards the initial
// state and right distributive when pushing towards the final states.
template <class Arc>
void Push(MutableFst<Arc> *fst,
          ReweightType type,
          float delta = kDelta,
          bool remove_total_weight = false) {
  typedef typename Arc::Weight Weight;
  const bool to_initial = (type == REWEIGHT_TO_INITIAL);

  vector<Weight> distance;
  ShortestDistance(*fst, &distance, to_initial, delta);

  if (!remove_total_weight) {
    Reweight(fst, distance, type);
    return;
  }

  // The total weight has to be read off the FST before it is reweighted.
  const Weight total_weight =
      internal::ComputeTotalWeight(*fst, distance, to_initial);
  Reweight(fst, distance, type);
  internal::RemoveWeight(fst, total_weight, type == REWEIGHT_TO_FINAL);
}
// Bit flags combined into the 'ptype' argument of the two-FST Push().
const uint32 kPushWeights = 1U << 0;            // Push the weights.
const uint32 kPushLabels = 1U << 1;             // Push the labels.
const uint32 kPushRemoveTotalWeight = 1U << 2;  // Divide out the total weight.
const uint32 kPushRemoveCommonAffix = 1U << 3;  // Divide out the label part
                                                // of the total gallic weight.
// OFST obtained from IFST by pushing weights and/or labels according
// to PTYPE in the direction defined by RTYPE. Weight needs to be
// left distributive when pushing weights towards the initial state
// and right distributive when pushing weights towards the final
// states.
template <class Arc, ReweightType rtype>
void Push(const Fst<Arc> &ifst,
MutableFst<Arc> *ofst,
uint32 ptype,
float delta = kDelta) {
if ((ptype & (kPushWeights | kPushLabels)) == kPushWeights) {
*ofst = ifst;
Push(ofst, rtype, delta, ptype & kPushRemoveTotalWeight);
} else if (ptype & kPushLabels) {
const StringType stype = rtype == REWEIGHT_TO_INITIAL
? STRING_LEFT
: STRING_RIGHT;
vector<typename GallicArc<Arc, stype>::Weight> gdistance;
VectorFst<GallicArc<Arc, stype> > gfst;
ArcMap(ifst, &gfst, ToGallicMapper<Arc, stype>());
if (ptype & kPushWeights ) {
ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
} else {
ArcMapFst<Arc, Arc, RmWeightMapper<Arc> >
uwfst(ifst, RmWeightMapper<Arc>());
ArcMapFst<Arc, GallicArc<Arc, stype>, ToGallicMapper<Arc, stype> >
guwfst(uwfst, ToGallicMapper<Arc, stype>());
ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta);
}
typename GallicArc<Arc, stype>::Weight total_weight =
GallicArc<Arc, stype>::Weight::One();
if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) {
total_weight = internal::ComputeTotalWeight(
gfst, gdistance, rtype == REWEIGHT_TO_INITIAL);
total_weight = typename GallicArc<Arc, stype>::Weight(
ptype & kPushRemoveCommonAffix ? total_weight.Value1()
: StringWeight<typename Arc::Label, stype>::One(),
ptype & kPushRemoveTotalWeight ? total_weight.Value2()
: Arc::Weight::One());
}
Reweight(&gfst, gdistance, rtype);
if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix))
internal::RemoveWeight(&gfst, total_weight, rtype == REWEIGHT_TO_FINAL);
FactorWeightFst< GallicArc<Arc, stype>, GallicFactor<typename Arc::Label,
typename Arc::Weight, stype> > fwfst(gfst);
ArcMap(fwfst, ofst, FromGallicMapper<Arc, stype>());
ofst->SetOutputSymbols(ifst.OutputSymbols());
} else {
LOG(WARNING) << "Push: pushing type is set to 0: "
<< "pushing neither labels nor weights.";
*ofst = ifst;
}
}
} // namespace fst
#endif  // FST_LIB_PUSH_H__