// blob: 20fa5ecc20c4685cdfe4f704d0ffaa32e3e13c96 [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This abstracts command line arguments in toco.
// Arg<T> is a parseable type that can register a default value, be able to
// parse itself, and keep track of whether it was specified.
#ifndef TENSORFLOW_LITE_TOCO_ARGS_H_
#define TENSORFLOW_LITE_TOCO_ARGS_H_
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>
#include "tensorflow/lite/toco/toco_port.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/lite/toco/toco_types.h"
namespace toco {
// Since std::vector<int32> is in the std namespace, and we are not allowed
// to add ParseFlag/UnparseFlag to std, we introduce a simple wrapper type
// to use as the flag type:
struct IntList {
  // The integer values carried by a flag of this type.
  std::vector<int32> elements;
};
// Wrapper around a list of string-to-string maps. It is the value type of the
// Arg<toco::StringMapList> specialization declared later in this header.
struct StringMapList {
  // One unordered_map<string, string> per list entry.
  std::vector<std::unordered_map<string, string>> elements;
};
// command_line_flags.h doesn't track whether or not a flag is specified. Arg
// contains the value (which will be the default if not specified) and also
// whether the flag was specified.
// TODO(aselle): consider putting doc string and ability to construct the
// tensorflow argument into this, so declaration of parameters can be less
// distributed.
// Every template specialization of Arg is required to implement
// default_value(), specified(), value(), Parse(), bind().
template <class T>
class Arg final {
 public:
  // Creates an argument holding `default_`. Until Parse() is called, value()
  // returns this default and specified() returns false.
  explicit Arg(T default_ = T())
      : default_value_(default_), value_(std::move(default_)) {}

  // No user-declared destructor: the class is final, so a virtual destructor
  // would only add a vtable and suppress the implicit move operations.

  // Returns the default this argument was constructed with. The result is
  // independent of Parse(): it stays the same after a value has been parsed.
  T default_value() const { return default_value_; }

  // Returns true once Parse() has been called, i.e. the argument was
  // specified on the command line.
  bool specified() const { return specified_; }

  // Const reference to the current value: the parsed value if specified(),
  // the default otherwise.
  const T& value() const { return value_; }

  // Parsing callback for the tensorflow::Flags code. Stores `value_in`,
  // marks the argument as specified and always returns true.
  bool Parse(T value_in) {
    value_ = std::move(value_in);
    specified_ = true;
    return true;
  }

  // Returns a callable that forwards to Parse() on this object so that
  // tensorflow::Flags can call it. The callable captures `this`, so it must
  // not be invoked after this Arg has been destroyed (and it keeps pointing
  // at this object, not at any copy made from it).
  std::function<bool(T)> bind() {
    return [this](T value_in) { return Parse(std::move(value_in)); };
  }

 private:
  // Becomes true after parsing if the value was specified.
  bool specified_ = false;
  // Default supplied to the constructor; never modified afterwards.
  T default_value_;
  // Current value of the argument (initialized to the default).
  T value_;
};
template <>
class Arg<toco::IntList> final {
 public:
  // Parsing callback taking the flag text. Only declared here; the
  // definition lives outside this header.
  bool Parse(string text);

  // Wraps Parse() on this object in a std::function so that
  // tensorflow::Flags can call it. The wrapper stores `this`, so it must not
  // be invoked after this object is gone.
  std::function<bool(string)> bind() {
    using std::placeholders::_1;
    return std::bind(&Arg::Parse, this, _1);
  }

  // Default handed to the arg list: always the empty string.
  string default_value() const { return ""; }

  // Whether the argument was specified on the command line. Starts out
  // false; presumably set by Parse() (defined elsewhere).
  bool specified() const { return specified_; }

  // Const reference to the stored IntList.
  const toco::IntList& value() const { return parsed_value_; }

 private:
  toco::IntList parsed_value_;
  bool specified_ = false;
};
template <>
class Arg<toco::StringMapList> final {
 public:
  // Parsing callback taking the flag text. Only declared here; the
  // definition lives outside this header.
  bool Parse(string text);

  // Wraps Parse() on this object in a std::function so that
  // tensorflow::Flags can call it. The wrapper stores `this`, so it must not
  // be invoked after this object is gone.
  std::function<bool(string)> bind() {
    using std::placeholders::_1;
    return std::bind(&Arg::Parse, this, _1);
  }

  // Default handed to the arg list: always the empty string.
  string default_value() const { return ""; }

  // Whether the argument was specified on the command line. Starts out
  // false; presumably set by Parse() (defined elsewhere).
  bool specified() const { return specified_; }

  // Const reference to the stored StringMapList.
  const toco::StringMapList& value() const { return parsed_value_; }

 private:
  toco::StringMapList parsed_value_;
  bool specified_ = false;
};
// Flags that describe a model. See model_cmdline_flags.cc for details.
struct ParsedModelFlags {
  // Each member is an Arg<T>. Members without an initializer start from a
  // default-constructed Arg; the others start from the default shown in
  // braces.
  Arg<string> input_array;
  Arg<string> input_arrays;
  Arg<string> output_array;
  Arg<string> output_arrays;
  Arg<string> input_shapes;
  Arg<int> batch_size{1};
  Arg<float> mean_value{0.f};
  Arg<string> mean_values;
  Arg<float> std_value{1.f};
  Arg<string> std_values;
  Arg<string> input_data_type;
  Arg<string> input_data_types;
  Arg<bool> variable_batch{false};
  Arg<toco::IntList> input_shape;
  Arg<toco::StringMapList> rnn_states;
  Arg<toco::StringMapList> model_checks;
  Arg<bool> change_concat_input_ranges{true};

  // Debugging output options.
  // TODO(benoitjacob): these shouldn't be ModelFlags.
  Arg<string> graphviz_first_array;
  Arg<string> graphviz_last_array;
  Arg<string> dump_graphviz;
  Arg<bool> dump_graphviz_video{false};
  Arg<string> conversion_summary_dir;
  Arg<bool> allow_nonexistent_arrays{false};
  Arg<bool> allow_nonascii_arrays{false};
  Arg<string> arrays_extra_info_file;
  Arg<string> model_flags_file;
};
// Flags that describe the operation you would like to do (what conversion
// you want). See toco_cmdline_flags.cc for details.
struct ParsedTocoFlags {
  // Each member is an Arg<T>. Members without an initializer start from a
  // default-constructed Arg; the others start from the default shown in
  // braces.
  Arg<string> input_file;
  Arg<string> savedmodel_directory;
  Arg<string> output_file;
  Arg<string> input_format{"TENSORFLOW_GRAPHDEF"};
  Arg<string> output_format{"TFLITE"};
  Arg<string> savedmodel_tagset;

  // TODO(aselle): command_line_flags doesn't support doubles
  Arg<float> default_ranges_min{0.0f};
  Arg<float> default_ranges_max{0.0f};
  Arg<float> default_int16_ranges_min{0.0f};
  Arg<float> default_int16_ranges_max{0.0f};
  Arg<string> inference_type;
  Arg<string> inference_input_type;
  Arg<bool> drop_fake_quant{false};
  Arg<bool> reorder_across_fake_quant{false};
  Arg<bool> allow_custom_ops{false};
  Arg<bool> allow_dynamic_tensors{true};
  Arg<string> custom_opdefs;
  Arg<bool> post_training_quantize{false};
  Arg<bool> quantize_to_float16{false};

  // Deprecated flags
  Arg<bool> quantize_weights{false};
  Arg<string> input_type;
  Arg<string> input_types;
  Arg<bool> debug_disable_recurrent_cell_fusion{false};
  Arg<bool> drop_control_dependency{false};
  Arg<bool> propagate_fake_quant_num_bits{false};
  Arg<bool> allow_nudging_weights_to_use_fast_gemm_kernel{false};
  Arg<int64> dedupe_array_min_size_bytes{64};
  Arg<bool> split_tflite_lstm_inputs{true};

  // WARNING: Experimental interface, subject to change
  Arg<bool> enable_select_tf_ops{false};
  // WARNING: Experimental interface, subject to change
  Arg<bool> force_select_tf_ops{false};
};
} // namespace toco
#endif // TENSORFLOW_LITE_TOCO_ARGS_H_