/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
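// Command-line tool that loads a TorchScript model, applies the mobile
// optimization passes (e.g. the XNNPACK rewrite), and saves the transformed
// model, optionally in the bytecode format used by the lite interpreter.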
#include <iostream>
#include <string>

#include "c10/util/Flags.h"
#include "c10/util/Logging.h"
#include "torch/csrc/jit/api/module.h"
#include "torch/csrc/jit/passes/xnnpack_rewrite.h"
#include "torch/csrc/jit/serialization/import.h"
C10_DEFINE_string(model, "", "The given torch script model to transform.");
C10_DEFINE_string(
    output,
    "",
    "Name of the output model to be saved.");
C10_DEFINE_bool(
    save_for_mobile,
    false,
    "Save the model in the bytecode format compatible with the lite interpreter.");
int main(int argc, char** argv) {
  c10::SetUsageMessage(
      "Optimize a torch script model for mobile.\n"
      "Example usage:\n"
      "./optimize_for_mobile"
      " --model=<model_file>"
      " --output=<output_file_name>");

  if (!c10::ParseCommandLineFlags(&argc, &argv)) {
    std::cerr << "Failed to parse command line flags!" << std::endl;
    return 1;
  }
  CAFFE_ENFORCE(FLAGS_model != "", "Valid input must be provided.");

  // Default output name: strip the extension (the last '.') from the input
  // path and append a suffix. Using rfind avoids truncating at an earlier
  // '.' in paths such as "./model.pt".
  std::string output_model_name =
      FLAGS_model.substr(0, FLAGS_model.rfind('.')) + "_mobile_optimized.pt";
  if (FLAGS_output != "") {
    output_model_name = FLAGS_output;
  }
  auto module = torch::jit::load(FLAGS_model);
  // optimizeForMobile returns the optimized module by value, so there is no
  // optional to unwrap; a load or optimization failure surfaces as an
  // exception instead.
  auto optimized_module = torch::jit::optimizeForMobile(module);

  if (FLAGS_save_for_mobile) {
    optimized_module._save_for_mobile(output_model_name);
  } else {
    optimized_module.save(output_model_name);
  }

  return 0;
}
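
// A model saved with --save_for_mobile can then be loaded with the lite
// interpreter. A minimal sketch (file name and input shape are illustrative;
// assumes a build where the mobile import headers are available):
//
//   #include <torch/csrc/jit/mobile/import.h>
//   #include <torch/csrc/jit/mobile/module.h>
//
//   torch::jit::mobile::Module m =
//       torch::jit::_load_for_mobile("model_opt.pt");
//   auto out = m.forward({torch::ones({1, 3, 224, 224})});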