blob: 54ba85dec79facf206e7243e91d9bf06b013989d [file] [log] [blame]
// weight_test.h
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Copyright 2005-2010 Google, Inc.
// Author: riley@google.com (Michael Riley)
//
// \file
// Regression test for Fst weights.
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <new>

#include <fst/expectation-weight.h>
#include <fst/float-weight.h>
#include <fst/random-weight.h>

#include "./weight-tester.h"
DEFINE_int32(seed, -1, "random seed");
DEFINE_int32(repeat, 100000, "number of test repetitions");
using fst::TropicalWeight;
using fst::TropicalWeightGenerator;
using fst::TropicalWeightTpl;
using fst::TropicalWeightGenerator_;
using fst::LogWeight;
using fst::LogWeightGenerator;
using fst::LogWeightTpl;
using fst::LogWeightGenerator_;
using fst::MinMaxWeight;
using fst::MinMaxWeightGenerator;
using fst::MinMaxWeightTpl;
using fst::MinMaxWeightGenerator_;
using fst::StringWeight;
using fst::StringWeightGenerator;
using fst::GallicWeight;
using fst::GallicWeightGenerator;
using fst::LexicographicWeight;
using fst::LexicographicWeightGenerator;
using fst::ProductWeight;
using fst::ProductWeightGenerator;
using fst::PowerWeight;
using fst::PowerWeightGenerator;
using fst::SignedLogWeightTpl;
using fst::SignedLogWeightGenerator_;
using fst::ExpectationWeight;
using fst::SparsePowerWeight;
using fst::SparsePowerWeightGenerator;
using fst::STRING_LEFT;
using fst::STRING_RIGHT;
using fst::WeightTester;
template <class T>
void TestTemplatedWeights(int repeat, int seed) {
TropicalWeightGenerator_<T> tropical_generator(seed);
WeightTester<TropicalWeightTpl<T>, TropicalWeightGenerator_<T> >
tropical_tester(tropical_generator);
tropical_tester.Test(repeat);
LogWeightGenerator_<T> log_generator(seed);
WeightTester<LogWeightTpl<T>, LogWeightGenerator_<T> >
log_tester(log_generator);
log_tester.Test(repeat);
MinMaxWeightGenerator_<T> minmax_generator(seed);
WeightTester<MinMaxWeightTpl<T>, MinMaxWeightGenerator_<T> >
minmax_tester(minmax_generator);
minmax_tester.Test(repeat);
SignedLogWeightGenerator_<T> signedlog_generator(seed);
WeightTester<SignedLogWeightTpl<T>, SignedLogWeightGenerator_<T> >
signedlog_tester(signedlog_generator);
signedlog_tester.Test(repeat);
}
int main(int argc, char **argv) {
std::set_new_handler(FailedNewHandler);
SetFlags(argv[0], &argc, &argv, true);
int seed = FLAGS_seed >= 0 ? FLAGS_seed : time(0);
LOG(INFO) << "Seed = " << seed;
TestTemplatedWeights<float>(FLAGS_repeat, seed);
TestTemplatedWeights<double>(FLAGS_repeat, seed);
FLAGS_fst_weight_parentheses = "()";
TestTemplatedWeights<float>(FLAGS_repeat, seed);
TestTemplatedWeights<double>(FLAGS_repeat, seed);
FLAGS_fst_weight_parentheses = "";
// Make sure type names for templated weights are consistent
CHECK(TropicalWeight::Type() == "tropical");
CHECK(TropicalWeightTpl<double>::Type() != TropicalWeightTpl<float>::Type());
CHECK(LogWeight::Type() == "log");
CHECK(LogWeightTpl<double>::Type() != LogWeightTpl<float>::Type());
TropicalWeightTpl<double> w(15.0);
TropicalWeight tw(15.0);
StringWeightGenerator<int> left_string_generator(seed);
WeightTester<StringWeight<int>, StringWeightGenerator<int> >
left_string_tester(left_string_generator);
left_string_tester.Test(FLAGS_repeat);
StringWeightGenerator<int, STRING_RIGHT> right_string_generator(seed);
WeightTester<StringWeight<int, STRING_RIGHT>,
StringWeightGenerator<int, STRING_RIGHT> >
right_string_tester(right_string_generator);
right_string_tester.Test(FLAGS_repeat);
typedef GallicWeight<int, TropicalWeight> TropicalGallicWeight;
typedef GallicWeightGenerator<int, TropicalWeightGenerator>
TropicalGallicWeightGenerator;
TropicalGallicWeightGenerator tropical_gallic_generator(seed);
WeightTester<TropicalGallicWeight, TropicalGallicWeightGenerator>
tropical_gallic_tester(tropical_gallic_generator);
tropical_gallic_tester.Test(FLAGS_repeat);
typedef ProductWeight<TropicalWeight, TropicalWeight> TropicalProductWeight;
typedef ProductWeightGenerator<TropicalWeightGenerator,
TropicalWeightGenerator> TropicalProductWeightGenerator;
TropicalProductWeightGenerator tropical_product_generator(seed);
WeightTester<TropicalProductWeight, TropicalProductWeightGenerator>
tropical_product_weight_tester(tropical_product_generator);
tropical_product_weight_tester.Test(FLAGS_repeat);
typedef PowerWeight<TropicalWeight, 3> TropicalCubeWeight;
typedef PowerWeightGenerator<TropicalWeightGenerator, 3>
TropicalCubeWeightGenerator;
TropicalCubeWeightGenerator tropical_cube_generator(seed);
WeightTester<TropicalCubeWeight, TropicalCubeWeightGenerator>
tropical_cube_weight_tester(tropical_cube_generator);
tropical_cube_weight_tester.Test(FLAGS_repeat);
typedef ProductWeight<TropicalWeight, TropicalProductWeight>
SecondNestedProductWeight;
typedef ProductWeightGenerator<TropicalWeightGenerator,
TropicalProductWeightGenerator> SecondNestedProductWeightGenerator;
SecondNestedProductWeightGenerator second_nested_product_generator(seed);
WeightTester<SecondNestedProductWeight, SecondNestedProductWeightGenerator>
second_nested_product_weight_tester(second_nested_product_generator);
second_nested_product_weight_tester.Test(FLAGS_repeat);
// This only works with fst_weight_parentheses = "()"
typedef ProductWeight<TropicalProductWeight, TropicalWeight>
FirstNestedProductWeight;
typedef ProductWeightGenerator<TropicalProductWeightGenerator,
TropicalWeightGenerator> FirstNestedProductWeightGenerator;
FirstNestedProductWeightGenerator first_nested_product_generator(seed);
WeightTester<FirstNestedProductWeight, FirstNestedProductWeightGenerator>
first_nested_product_weight_tester(first_nested_product_generator);
typedef PowerWeight<FirstNestedProductWeight, 3> NestedProductCubeWeight;
typedef PowerWeightGenerator<FirstNestedProductWeightGenerator, 3>
NestedProductCubeWeightGenerator;
NestedProductCubeWeightGenerator nested_product_cube_generator(seed);
WeightTester<NestedProductCubeWeight, NestedProductCubeWeightGenerator>
nested_product_cube_weight_tester(nested_product_cube_generator);
typedef SparsePowerWeight<NestedProductCubeWeight,
size_t > SparseNestedProductCubeWeight;
typedef SparsePowerWeightGenerator<NestedProductCubeWeightGenerator,
size_t, 3> SparseNestedProductCubeWeightGenerator;
SparseNestedProductCubeWeightGenerator
sparse_nested_product_cube_generator(seed);
WeightTester<SparseNestedProductCubeWeight,
SparseNestedProductCubeWeightGenerator>
sparse_nested_product_cube_weight_tester(
sparse_nested_product_cube_generator);
typedef SparsePowerWeight<LogWeight, size_t > LogSparsePowerWeight;
typedef SparsePowerWeightGenerator<LogWeightGenerator,
size_t, 3> LogSparsePowerWeightGenerator;
LogSparsePowerWeightGenerator
log_sparse_power_weight_generator(seed);
WeightTester<LogSparsePowerWeight,
LogSparsePowerWeightGenerator>
log_sparse_power_weight_tester(
log_sparse_power_weight_generator);
typedef ExpectationWeight<LogWeight, LogWeight>
LogLogExpectWeight;
typedef ProductWeightGenerator<LogWeightGenerator, LogWeightGenerator,
LogLogExpectWeight> LogLogExpectWeightGenerator;
LogLogExpectWeightGenerator log_log_expect_weight_generator(seed);
WeightTester<LogLogExpectWeight, LogLogExpectWeightGenerator>
log_log_expect_weight_tester(log_log_expect_weight_generator);
typedef ExpectationWeight<LogWeight, LogSparsePowerWeight>
LogLogSparseExpectWeight;
typedef ProductWeightGenerator<
LogWeightGenerator,
LogSparsePowerWeightGenerator,
LogLogSparseExpectWeight> LogLogSparseExpectWeightGenerator;
LogLogSparseExpectWeightGenerator log_logsparse_expect_weight_generator(seed);
WeightTester<LogLogSparseExpectWeight, LogLogSparseExpectWeightGenerator>
log_logsparse_expect_weight_tester(log_logsparse_expect_weight_generator);
// Test all product weight I/O with parentheses
FLAGS_fst_weight_parentheses = "()";
first_nested_product_weight_tester.Test(FLAGS_repeat);
nested_product_cube_weight_tester.Test(FLAGS_repeat);
log_sparse_power_weight_tester.Test(1);
sparse_nested_product_cube_weight_tester.Test(1);
tropical_product_weight_tester.Test(5);
second_nested_product_weight_tester.Test(5);
tropical_gallic_tester.Test(5);
tropical_cube_weight_tester.Test(5);
FLAGS_fst_weight_parentheses = "";
log_sparse_power_weight_tester.Test(1);
log_log_expect_weight_tester.Test(1, false); // disables division
log_logsparse_expect_weight_tester.Test(1, false);
typedef LexicographicWeight<TropicalWeight, TropicalWeight>
TropicalLexicographicWeight;
typedef LexicographicWeightGenerator<TropicalWeightGenerator,
TropicalWeightGenerator> TropicalLexicographicWeightGenerator;
TropicalLexicographicWeightGenerator tropical_lexicographic_generator(seed);
WeightTester<TropicalLexicographicWeight,
TropicalLexicographicWeightGenerator>
tropical_lexicographic_tester(tropical_lexicographic_generator);
tropical_lexicographic_tester.Test(FLAGS_repeat);
cout << "PASS" << endl;
return 0;
}