// minimize.h

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: johans@google.com (Johan Schalkwyk)
//
// \file Functions and classes to minimize finite-state automata.
//

#ifndef FST_LIB_MINIMIZE_H__
#define FST_LIB_MINIMIZE_H__

#include <cmath>

#include <algorithm>
#include <map>
#include <queue>
#include <vector>
using std::vector;

#include <fst/arcsort.h>
#include <fst/connect.h>
#include <fst/dfs-visit.h>
#include <fst/encode.h>
#include <fst/factor-weight.h>
#include <fst/fst.h>
#include <fst/mutable-fst.h>
#include <fst/partition.h>
#include <fst/push.h>
#include <fst/queue.h>
#include <fst/reverse.h>
#include <fst/state-map.h>


namespace fst {

// comparator for creating partition based on sorting on
// - states
// - final weight
// - out degree,
// -  (input label, output label, weight, destination_block)
template <class A>
class StateComparator {
 public:
  typedef typename A::StateId StateId;
  typedef typename A::Weight Weight;

  static const uint32 kCompareFinal     = 0x00000001;
  static const uint32 kCompareOutDegree = 0x00000002;
  static const uint32 kCompareArcs      = 0x00000004;
  static const uint32 kCompareAll       = 0x00000007;

  StateComparator(const Fst<A>& fst,
                  const Partition<typename A::StateId>& partition,
                  uint32 flags = kCompareAll)
      : fst_(fst), partition_(partition), flags_(flags) {}

  // compare state x with state y based on sort criteria
  bool operator()(const StateId x, const StateId y) const {
    // check for final state equivalence
    if (flags_ & kCompareFinal) {
      const size_t xfinal = fst_.Final(x).Hash();
      const size_t yfinal = fst_.Final(y).Hash();
      if      (xfinal < yfinal) return true;
      else if (xfinal > yfinal) return false;
    }

    if (flags_ & kCompareOutDegree) {
      // check for # arcs
      if (fst_.NumArcs(x) < fst_.NumArcs(y)) return true;
      if (fst_.NumArcs(x) > fst_.NumArcs(y)) return false;

      if (flags_ & kCompareArcs) {
        // # arcs are equal, check for arc match
        for (ArcIterator<Fst<A> > aiter1(fst_, x), aiter2(fst_, y);
             !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) {
          const A& arc1 = aiter1.Value();
          const A& arc2 = aiter2.Value();
          if (arc1.ilabel < arc2.ilabel) return true;
          if (arc1.ilabel > arc2.ilabel) return false;

          if (partition_.class_id(arc1.nextstate) <
              partition_.class_id(arc2.nextstate)) return true;
          if (partition_.class_id(arc1.nextstate) >
              partition_.class_id(arc2.nextstate)) return false;
        }
      }
    }

    return false;
  }

 private:
  const Fst<A>& fst_;
  const Partition<typename A::StateId>& partition_;
  const uint32 flags_;
};

template <class A> const uint32 StateComparator<A>::kCompareFinal;
template <class A> const uint32 StateComparator<A>::kCompareOutDegree;
template <class A> const uint32 StateComparator<A>::kCompareArcs;
template <class A> const uint32 StateComparator<A>::kCompareAll;


// Computes equivalence classes for cyclic Fsts. For cyclic minimization
// we use the classic HopCroft minimization algorithm, which is of
//
//   O(E)log(N),
//
// where E is the number of edges in the machine and N is number of states.
//
// The following paper describes the original algorithm
//  An N Log N algorithm for minimizing states in a finite automaton
//  by John HopCroft, January 1971
//
template <class A, class Queue>
class CyclicMinimizer {
 public:
  typedef typename A::Label Label;
  typedef typename A::StateId StateId;
  typedef typename A::StateId ClassId;
  typedef typename A::Weight Weight;
  typedef ReverseArc<A> RevA;

  CyclicMinimizer(const ExpandedFst<A>& fst) {
    Initialize(fst);
    Compute(fst);
  }

  ~CyclicMinimizer() {
    delete aiter_queue_;
  }

  const Partition<StateId>& partition() const {
    return P_;
  }

  // helper classes
 private:
  typedef ArcIterator<Fst<RevA> > ArcIter;
  class ArcIterCompare {
   public:
    ArcIterCompare(const Partition<StateId>& partition)
        : partition_(partition) {}

    ArcIterCompare(const ArcIterCompare& comp)
        : partition_(comp.partition_) {}

    // compare two iterators based on there input labels, and proto state
    // (partition class Ids)
    bool operator()(const ArcIter* x, const ArcIter* y) const {
      const RevA& xarc = x->Value();
      const RevA& yarc = y->Value();
      return (xarc.ilabel > yarc.ilabel);
    }

   private:
    const Partition<StateId>& partition_;
  };

  typedef priority_queue<ArcIter*, vector<ArcIter*>, ArcIterCompare>
  ArcIterQueue;

  // helper methods
 private:
  // prepartitions the space into equivalence classes with
  //   same final weight
  //   same # arcs per state
  //   same outgoing arcs
  void PrePartition(const Fst<A>& fst) {
    VLOG(5) << "PrePartition";

    typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
    StateComparator<A> comp(fst, P_, StateComparator<A>::kCompareFinal);
    EquivalenceMap equiv_map(comp);

    StateIterator<Fst<A> > siter(fst);
    StateId class_id = P_.AddClass();
    P_.Add(siter.Value(), class_id);
    equiv_map[siter.Value()] = class_id;
    L_.Enqueue(class_id);
    for (siter.Next(); !siter.Done(); siter.Next()) {
      StateId  s = siter.Value();
      typename EquivalenceMap::const_iterator it = equiv_map.find(s);
      if (it == equiv_map.end()) {
        class_id = P_.AddClass();
        P_.Add(s, class_id);
        equiv_map[s] = class_id;
        L_.Enqueue(class_id);
      } else {
        P_.Add(s, it->second);
        equiv_map[s] = it->second;
      }
    }

    VLOG(5) << "Initial Partition: " << P_.num_classes();
  }

  // - Create inverse transition Tr_ = rev(fst)
  // - loop over states in fst and split on final, creating two blocks
  //   in the partition corresponding to final, non-final
  void Initialize(const Fst<A>& fst) {
    // construct Tr
    Reverse(fst, &Tr_);
    ILabelCompare<RevA> ilabel_comp;
    ArcSort(&Tr_, ilabel_comp);

    // initial split (F, S - F)
    P_.Initialize(Tr_.NumStates() - 1);

    // prep partition
    PrePartition(fst);

    // allocate arc iterator queue
    ArcIterCompare comp(P_);
    aiter_queue_ = new ArcIterQueue(comp);
  }

  // partition all classes with destination C
  void Split(ClassId C) {
    // Prep priority queue. Open arc iterator for each state in C, and
    // insert into priority queue.
    for (PartitionIterator<StateId> siter(P_, C);
         !siter.Done(); siter.Next()) {
      StateId s = siter.Value();
      if (Tr_.NumArcs(s + 1))
        aiter_queue_->push(new ArcIterator<Fst<RevA> >(Tr_, s + 1));
    }

    // Now pop arc iterator from queue, split entering equivalence class
    // re-insert updated iterator into queue.
    Label prev_label = -1;
    while (!aiter_queue_->empty()) {
      ArcIterator<Fst<RevA> >* aiter = aiter_queue_->top();
      aiter_queue_->pop();
      if (aiter->Done()) {
        delete aiter;
        continue;
     }

      const RevA& arc = aiter->Value();
      StateId from_state = aiter->Value().nextstate - 1;
      Label   from_label = arc.ilabel;
      if (prev_label != from_label)
        P_.FinalizeSplit(&L_);

      StateId from_class = P_.class_id(from_state);
      if (P_.class_size(from_class) > 1)
        P_.SplitOn(from_state);

      prev_label = from_label;
      aiter->Next();
      if (aiter->Done())
        delete aiter;
      else
        aiter_queue_->push(aiter);
    }
    P_.FinalizeSplit(&L_);
  }

  // Main loop for hopcroft minimization.
  void Compute(const Fst<A>& fst) {
    // process active classes (FIFO, or FILO)
    while (!L_.Empty()) {
      ClassId C = L_.Head();
      L_.Dequeue();

      // split on C, all labels in C
      Split(C);
    }
  }

  // helper data
 private:
  // Partioning of states into equivalence classes
  Partition<StateId> P_;

  // L = set of active classes to be processed in partition P
  Queue L_;

  // reverse transition function
  VectorFst<RevA> Tr_;

  // Priority queue of open arc iterators for all states in the 'splitter'
  // equivalence class
  ArcIterQueue* aiter_queue_;
};


// Computes equivalence classes for acyclic Fsts. The implementation details
// for this algorithm are documented in the following paper:
//
// Minimization of acyclic deterministic automata in linear time
//  Dominique Revuz
//
// Complexity O(|E|)
//
// States are first grouped by height (see HeightVisitor). Each height class
// is then refined, lowest height first, with StateComparator. The arcs of
// every state are expected to be sorted by input label (AcceptorMinimize
// arc-sorts before constructing this class).
template <class A>
class AcyclicMinimizer {
 public:
  typedef typename A::Label Label;
  typedef typename A::StateId StateId;
  typedef typename A::StateId ClassId;
  typedef typename A::Weight Weight;

  AcyclicMinimizer(const ExpandedFst<A>& fst) {
    Initialize(fst);
    Refine(fst);
  }

  // Returns the partition of the input states into equivalence classes.
  // Const-qualified so it is also usable on const minimizers; non-const
  // callers are unaffected.
  const Partition<StateId>& partition() const {
    return partition_;
  }

  // helper classes
 private:
  // DFS visitor computing each state's height: the length of the longest
  // path from the state to a state without outgoing arcs. For trimmed
  // (Connect()ed) inputs, such sink states are final.
  class HeightVisitor {
   public:
    HeightVisitor() : max_height_(0), num_states_(0) { }

    // invoked before dfs visit
    void InitVisit(const Fst<A>& fst) {}

    // invoked when state is discovered (2nd arg is DFS tree root)
    bool InitState(StateId s, StateId root) {
      // Extend the height array; -1 marks "height not computed yet".
      if (static_cast<size_t>(s) >= height_.size())
        height_.resize(s + 1, -1);

      if (static_cast<size_t>(s) >= num_states_) num_states_ = s + 1;
      return true;
    }

    // invoked when tree arc examined (to undiscovered state)
    bool TreeArc(StateId s, const A& arc) {
      return true;
    }

    // invoked when back arc examined (to unfinished state)
    bool BackArc(StateId s, const A& arc) {
      return true;
    }

    // Invoked when a forward or cross arc is examined. The destination is
    // already finished, so its height is final and can raise s's height.
    bool ForwardOrCrossArc(StateId s, const A& arc) {
      if (height_[arc.nextstate] + 1 > height_[s])
        height_[s] = height_[arc.nextstate] + 1;
      return true;
    }

    // Invoked when a state is finished (parent is kNoStateId for a tree
    // root). The state's height is final here; propagate it to the parent.
    void FinishState(StateId s, StateId parent, const A* parent_arc) {
      if (height_[s] == -1) height_[s] = 0;
      // Track the maximum over ALL states, including DFS roots and heights
      // raised only through forward/cross arcs. Otherwise a multi-root DFS
      // could yield a height above max_height_, and Initialize() would add
      // a state to a class that was never allocated.
      if (static_cast<size_t>(height_[s]) > max_height_)
        max_height_ = height_[s];
      if (parent >= 0) {
        const StateId h = height_[s] + 1;
        if (h > height_[parent]) height_[parent] = h;
      }
    }

    // invoked after DFS visit
    void FinishVisit() {}

    size_t max_height() const { return max_height_; }

    const vector<StateId>& height() const { return height_; }

    size_t num_states() const { return num_states_; }

   private:
    vector<StateId> height_;
    size_t max_height_;
    size_t num_states_;
  };

  // helper methods
 private:
  // Clusters states by height: state s goes into class height(s), so class
  // ids 0..max_height coincide with heights.
  void Initialize(const Fst<A>& fst) {
    HeightVisitor hvisitor;
    DfsVisit(fst, &hvisitor);

    // create initial partition based on height
    partition_.Initialize(hvisitor.num_states());
    partition_.AllocateClasses(hvisitor.max_height() + 1);
    const vector<StateId>& hstates = hvisitor.height();
    for (size_t s = 0; s < hstates.size(); ++s)
      partition_.Add(s, hstates[s]);
  }

  // Refines each height class by final weight, out degree and arcs
  // (StateComparator with kCompareAll).
  void Refine(const Fst<A>& fst) {
    typedef map<StateId, StateId, StateComparator<A> > EquivalenceMap;
    StateComparator<A> comp(fst, partition_);

    // Go bottom-up from height 0. Every arc leads to a strictly lower height,
    // so destination classes are already final when a class is refined.
    // Only the original height classes are visited, not the ones added here.
    const size_t height = partition_.num_classes();
    for (size_t h = 0; h < height; ++h) {
      EquivalenceMap equiv_classes(comp);

      // Assign each state to a (possibly new) class. The first state keeps
      // class h; a state with no equivalent seen so far opens a new class.
      PartitionIterator<StateId> siter(partition_, h);
      if (siter.Done()) continue;  // defensive: empty height class
      equiv_classes[siter.Value()] = h;
      for (siter.Next(); !siter.Done(); siter.Next()) {
        const StateId s = siter.Value();
        if (equiv_classes.find(s) == equiv_classes.end())
          equiv_classes[s] = partition_.AddClass();
      }

      // create refined partition
      for (siter.Reset(); !siter.Done();) {
        const StateId s = siter.Value();
        const StateId old_class = partition_.class_id(s);
        const StateId new_class = equiv_classes[s];

        // A move can invalidate the iterator, so advance it to the next
        // element before moving the current one out of the list.
        siter.Next();
        if (old_class != new_class)
          partition_.Move(s, new_class);
      }
    }
  }

 private:
  Partition<StateId> partition_;
};


// Given a partition and a mutable fst, merges the states of the Fst in place
// (i.e. destructively). The first state of each class becomes the class
// representative. Every arc is redirected to the representative of its
// destination's class, and the arcs of non-representative members are copied
// onto their representative. The start state is remapped to its
// representative, and Connect() then removes the now unreachable members.
// Duplicate arcs may remain; AcceptorMinimize removes them afterwards.
template <class A>
void MergeStates(
    const Partition<typename A::StateId>& partition, MutableFst<A>* fst) {
  typedef typename A::StateId StateId;

  const size_t num_classes = partition.num_classes();

  // state_map[c] = representative (first state) of class c.
  vector<StateId> state_map(num_classes);
  for (size_t c = 0; c < num_classes; ++c) {
    PartitionIterator<StateId> siter(partition, c);
    state_map[c] = siter.Value();
  }

  // Relabel destination states and move arcs onto representatives.
  for (size_t c = 0; c < num_classes; ++c) {
    const StateId rep = state_map[c];
    for (PartitionIterator<StateId> siter(partition, c);
         !siter.Done(); siter.Next()) {
      const StateId s = siter.Value();
      for (MutableArcIterator<MutableFst<A> > aiter(fst, s);
           !aiter.Done(); aiter.Next()) {
        A arc = aiter.Value();
        arc.nextstate = state_map[partition.class_id(arc.nextstate)];

        if (s == rep)  // representative: only the destination changes
          aiter.SetValue(arc);
        else
          fst->AddArc(rep, arc);
      }
    }
  }

  // An Fst without a start state (kNoStateId, i.e. negative) has no class
  // to look up; calling class_id() would index with -1.
  const StateId start = fst->Start();
  if (start >= 0)
    fst->SetStart(state_map[partition.class_id(start)]);

  Connect(fst);
}

// Minimizes an unweighted acceptor in place. The Fst is trimmed with
// Connect() first. Acyclic machines are then minimized with Revuz's algorithm
// (AcyclicMinimizer, after sorting arcs by input label) and cyclic ones with
// Hopcroft's algorithm (CyclicMinimizer). Equivalent states are merged by
// MergeStates(), and the duplicate arcs this creates are collapsed with
// ArcUniqueMapper. If the input is not an unweighted acceptor, kError is set
// on the Fst and nothing else is changed.
template <class A>
void AcceptorMinimize(MutableFst<A>* fst) {
  typedef typename A::StateId StateId;

  // Properties() returns the subset of the requested bits that hold, so BOTH
  // bits must be present. A plain non-zero test would also accept weighted
  // acceptors and unweighted transducers.
  if (fst->Properties(kAcceptor | kUnweighted, true) !=
      (kAcceptor | kUnweighted)) {
    FSTERROR() << "FST is not an unweighted acceptor";
    fst->SetProperties(kError, kError);
    return;
  }

  // connect fst before minimization, handles disconnected states
  Connect(fst);
  if (fst->NumStates() == 0) return;

  if (fst->Properties(kAcyclic, true)) {
    // Acyclic minimization (Revuz)
    VLOG(2) << "Acyclic Minimization";
    ArcSort(fst, ILabelCompare<A>());
    AcyclicMinimizer<A> minimizer(*fst);
    MergeStates(minimizer.partition(), fst);

  } else {
    // Cyclic minimization (Hopcroft)
    VLOG(2) << "Cyclic Minimization";
    CyclicMinimizer<A, LifoQueue<StateId> > minimizer(*fst);
    MergeStates(minimizer.partition(), fst);
  }

  // MergeStates() copies each member's arcs onto its representative;
  // remove the resulting duplicate arcs.
  ArcUniqueMapper<A> mapper(*fst);
  StateMap(fst, mapper);
}


// In-place minimization of deterministic (input-deterministic) weighted
// automata and transducers. Non-deterministic input sets kError on 'fst'.
//
// - Unweighted acceptors are minimized directly.
// - Weighted acceptors have their weights pushed toward the initial state
//   and quantized with 'delta'. Each (label, weight) pair is then encoded
//   into a single label, and the result is minimized as an unweighted
//   acceptor and decoded.
// - Transducers are converted to left-string Gallic acceptors and processed
//   in the same way. When 'sfst' is null, the Gallic weights are factored
//   back into an ordinary transducer in 'fst'. When 'sfst' is non-null, each
//   output string is mapped to a new output symbol in 'fst' via
//   GallicToNewSymbolsMapper(sfst), which presumably makes 'sfst' hold the
//   symbol-to-string factorization (a compact factorization of the minimal
//   transducer).
//
// In the acyclic case Dominique Revuz's algorithm is used, which is linear
// in the number of arcs:   Complexity = O(E)
// In the cyclic case the classical Hopcroft minimization is used:
//                          Complexity = O(|E| log |N|)
template <class A>
void Minimize(MutableFst<A>* fst,
              MutableFst<A>* sfst = 0,
              float delta = kDelta) {
  typedef GallicArc<A, STRING_LEFT> GArc;

  const uint64 props = fst->Properties(
      kAcceptor | kIDeterministic | kWeighted | kUnweighted, true);
  if ((props & kIDeterministic) == 0) {
    FSTERROR() << "FST is not deterministic";
    fst->SetProperties(kError, kError);
    return;
  }

  if (props & kAcceptor) {
    if (props & kWeighted) {
      // Weighted acceptor: normalize the weights, then fold them into labels.
      Push(fst, REWEIGHT_TO_INITIAL, delta);
      ArcMap(fst, QuantizeMapper<A>(delta));
      EncodeMapper<A> encoder(kEncodeLabels | kEncodeWeights, ENCODE);
      Encode(fst, &encoder);
      AcceptorMinimize(fst);
      Decode(fst, encoder);
    } else {
      // Unweighted acceptor: nothing to normalize.
      AcceptorMinimize(fst);
    }
    return;
  }

  // Transducer: move output strings into Gallic weights and minimize the
  // resulting acceptor.
  VectorFst<GArc> gallic;
  ArcMap(*fst, &gallic, ToGallicMapper<A, STRING_LEFT>());
  fst->DeleteStates();
  gallic.SetProperties(kAcceptor, kAcceptor);
  Push(&gallic, REWEIGHT_TO_INITIAL, delta);
  ArcMap(&gallic, QuantizeMapper<GArc>(delta));
  EncodeMapper<GArc> encoder(kEncodeLabels | kEncodeWeights, ENCODE);
  Encode(&gallic, &encoder);
  AcceptorMinimize(&gallic);
  Decode(&gallic, encoder);

  if (sfst != 0) {
    // Factor output strings into new symbols recorded via 'sfst'.
    sfst->SetOutputSymbols(fst->OutputSymbols());
    GallicToNewSymbolsMapper<A, STRING_LEFT> mapper(sfst);
    ArcMap(gallic, fst, &mapper);
    fst->SetOutputSymbols(sfst->InputSymbols());
    return;
  }

  // Expand the Gallic string weights back into arcs. The output symbol table
  // is saved beforehand because mapping into 'fst' replaces it.
  typedef GallicFactor<typename A::Label, typename A::Weight, STRING_LEFT>
      Factor;
  FactorWeightFst<GArc, Factor> factored(gallic);
  SymbolTable* saved_osyms = 0;
  if (fst->OutputSymbols())
    saved_osyms = fst->OutputSymbols()->Copy();
  ArcMap(factored, fst, FromGallicMapper<A, STRING_LEFT>());
  fst->SetOutputSymbols(saved_osyms);
  delete saved_osyms;
}

}  // namespace fst

#endif  // FST_LIB_MINIMIZE_H__
