/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
| |
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/match.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "third_party/llvm/llvm/include/llvm-c/Target.h"
#include "tensorflow/compiler/aot/codegen.h"
#include "tensorflow/compiler/aot/compile.h"
#include "tensorflow/compiler/aot/flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/command_line_flags.h"
| |
| namespace tensorflow { |
| namespace tfcompile { |
| |
// Banner printed before the per-flag help text; main() appends the flag
// list produced by tensorflow::Flags::Usage() to this header.
const char kUsageHeader[] =
    "tfcompile performs ahead-of-time compilation of a TensorFlow graph,\n"
    "resulting in an object file compiled for your target architecture, and a\n"
    "header file that gives access to the functionality in the object file.\n"
    "A typical invocation looks like this:\n"
    "\n"
    "   $ tfcompile --graph=mygraph.pb --config=myfile.pbtxt "
    "--cpp_class=\"mynamespace::MyComputation\"\n"
    "\n";
| |
| Status ReadProtoFile(const string& fname, protobuf::Message* proto) { |
| if (absl::EndsWith(fname, ".pbtxt")) { |
| return ReadTextProto(Env::Default(), fname, proto); |
| } else { |
| return ReadBinaryProto(Env::Default(), fname, proto); |
| } |
| } |
| |
| Status Main(const MainFlags& flags) { |
| // Initialize all LLVM targets so we can cross compile. |
| LLVMInitializeAArch64Target(); |
| LLVMInitializeAArch64TargetInfo(); |
| LLVMInitializeAArch64TargetMC(); |
| LLVMInitializeAArch64AsmPrinter(); |
| LLVMInitializeARMTarget(); |
| LLVMInitializeARMTargetInfo(); |
| LLVMInitializeARMTargetMC(); |
| LLVMInitializeARMAsmPrinter(); |
| LLVMInitializePowerPCTarget(); |
| LLVMInitializePowerPCTargetInfo(); |
| LLVMInitializePowerPCTargetMC(); |
| LLVMInitializePowerPCAsmPrinter(); |
| LLVMInitializeX86Target(); |
| LLVMInitializeX86TargetInfo(); |
| LLVMInitializeX86TargetMC(); |
| LLVMInitializeX86AsmPrinter(); |
| |
| // Process config. |
| tf2xla::Config config; |
| if (flags.config.empty()) { |
| return errors::InvalidArgument("Must specify --config"); |
| } |
| TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config)); |
| TF_RETURN_IF_ERROR(ValidateConfig(config)); |
| if (flags.dump_fetch_nodes) { |
| std::set<string> nodes; |
| for (const tf2xla::Fetch& fetch : config.fetch()) { |
| nodes.insert(fetch.id().node_name()); |
| } |
| std::cout << absl::StrJoin(nodes, ","); |
| return Status::OK(); |
| } |
| |
| // Read and initialize the graph. |
| if (flags.graph.empty()) { |
| return errors::InvalidArgument("Must specify --graph"); |
| } |
| GraphDef graph_def; |
| TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def)); |
| CompileResult compile_result; |
| TF_RETURN_IF_ERROR(CompileGraph(graph_def, config, flags, &compile_result)); |
| |
| // Write output files. |
| Env* env = Env::Default(); |
| const std::vector<char>& obj = compile_result.aot->object_file_data(); |
| TF_RETURN_IF_ERROR( |
| WriteStringToFile(env, flags.out_function_object, |
| absl::string_view(obj.data(), obj.size()))); |
| CodegenOpts codegen_opts; |
| codegen_opts.gen_name_to_index = flags.gen_name_to_index; |
| codegen_opts.gen_program_shape = flags.gen_program_shape; |
| codegen_opts.target_triple = flags.target_triple; |
| if (flags.cpp_class.empty()) { |
| return errors::InvalidArgument("Must specify --cpp_class"); |
| } |
| codegen_opts.gen_hlo_profile_printer_data = |
| xla::GetDebugOptionsFromFlags().xla_hlo_profile(); |
| TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name, |
| &codegen_opts.namespaces)); |
| |
| MetadataResult metadata_result; |
| TF_RETURN_IF_ERROR( |
| GenerateMetadata(codegen_opts, compile_result, &metadata_result)); |
| TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object, |
| metadata_result.object_file_data)); |
| string header; |
| TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result, |
| metadata_result, &header)); |
| TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header)); |
| return Status::OK(); |
| } |
| |
| } // end namespace tfcompile |
| } // end namespace tensorflow |
| |
| int main(int argc, char** argv) { |
| tensorflow::tfcompile::MainFlags flags; |
| flags.target_triple = "x86_64-pc-linux"; |
| flags.out_function_object = "out_model.o"; |
| flags.out_metadata_object = "out_helper.o"; |
| flags.out_header = "out.h"; |
| flags.entry_point = "entry"; |
| |
| std::vector<tensorflow::Flag> flag_list; |
| AppendMainFlags(&flag_list, &flags); |
| xla::AppendDebugOptionsFlags(&flag_list); |
| |
| tensorflow::string usage = tensorflow::tfcompile::kUsageHeader; |
| usage += tensorflow::Flags::Usage(argv[0], flag_list); |
| if (argc > 1 && absl::string_view(argv[1]) == "--help") { |
| std::cerr << usage << "\n"; |
| return 0; |
| } |
| bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); |
| QCHECK(parsed_flags_ok) << "\n" << usage; |
| |
| tensorflow::port::InitMain(usage.c_str(), &argc, &argv); |
| QCHECK(argc == 1) << "\nERROR: This command does not take any arguments " |
| "other than flags\n\n" |
| << usage; |
| tensorflow::Status status = tensorflow::tfcompile::Main(flags); |
| if (status.code() == tensorflow::error::INVALID_ARGUMENT) { |
| std::cerr << "INVALID ARGUMENTS: " << status.error_message() << "\n\n" |
| << usage; |
| return 1; |
| } else { |
| TF_QCHECK_OK(status); |
| } |
| return 0; |
| } |