/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

| #include "tensorflow/compiler/aot/compile.h" |
| |
| #include <memory> |
| #include <string> |
| #include <utility> |
| #include <vector> |
| |
| #include "llvm-c/Target.h" |
| #include "tensorflow/compiler/aot/codegen.h" |
| #include "tensorflow/compiler/aot/flags.h" |
| #include "tensorflow/compiler/tf2xla/tf2xla.h" |
| #include "tensorflow/compiler/tf2xla/tf2xla_util.h" |
| #include "tensorflow/compiler/xla/client/client_library.h" |
| #include "tensorflow/compiler/xla/client/compile_only_client.h" |
| #include "tensorflow/compiler/xla/client/xla_computation.h" |
| #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" |
| #include "tensorflow/compiler/xla/statusor.h" |
| #include "tensorflow/compiler/xla/util.h" |
| #include "tensorflow/compiler/xla/xla_data.pb.h" |
| #include "tensorflow/core/framework/graph.pb.h" |
| #include "tensorflow/core/lib/core/errors.h" |
| #include "tensorflow/core/lib/io/path.h" |
| #include "tensorflow/core/lib/strings/proto_serialization.h" |
| #include "tensorflow/core/platform/env.h" |
| #include "tensorflow/core/platform/logging.h" |
| #include "tensorflow/core/platform/types.h" |

namespace tensorflow {
namespace tfcompile {

namespace {

// Compiles the XLA computation into executable code.
Status CompileXla(xla::CompileOnlyClient* client,
                  const xla::XlaComputation& computation,
                  const xla::cpu::CpuAotCompilationOptions& aot_opts,
                  CompileResult* compile_result) {
  // Retrieves arg and result layouts from the computation.
  // TODO(toddw): Should we let the user choose the major/minor ordering?
  xla::StatusOr<std::unique_ptr<xla::ProgramShape>> pshape_or =
      client->GetComputationShape(computation);
  if (!pshape_or.ok()) {
    return errors::Unknown("Couldn't get XLA program shape: ",
                           pshape_or.status().error_message());
  }
  compile_result->program_shape = pshape_or.ValueOrDie()->ToProto();
  xla::ProgramShapeProto* pshape = &compile_result->program_shape;

  // AotXlaComputationInstance::argument_layouts is a vector of Shape
  // pointers. Accumulate the Shape objects themselves in a separate vector
  // while building the vector of pointers.
  std::vector<const xla::Shape*> arg_layout_ptrs(pshape->parameters_size());
  std::vector<xla::Shape> arg_layouts(pshape->parameters_size());
  for (int i = 0; i < pshape->parameters_size(); ++i) {
    arg_layouts[i] = xla::Shape(*pshape->mutable_parameters(i));
    arg_layout_ptrs[i] = &arg_layouts[i];
  }
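  // Build a single AOT compilation instance: the computation, the argument
  // layouts gathered above, and the result layout taken from the program
  // shape.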
  xla::CompileOnlyClient::AotXlaComputationInstance instance;
  instance.computation = &computation;
  instance.argument_layouts = std::move(arg_layout_ptrs);
  xla::Shape result_shape(pshape->result());
  instance.result_layout = &result_shape;
  xla::StatusOr<std::vector<std::unique_ptr<xla::AotCompilationResult>>>
      aot_or = client->CompileAheadOfTime({instance}, aot_opts);
  if (!aot_or.ok()) {
    return errors::Unknown("XLA compilation failed: ",
                           aot_or.status().error_message());
  }
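  // Only one instance was compiled, so CompileAheadOfTime returns exactly one
  // result; take ownership of it and record the entry point name and target
  // pointer size for codegen.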
  compile_result->aot =
      xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
          std::move(aot_or.ValueOrDie().back()));
  compile_result->entry_point = aot_opts.entry_point_name();
  compile_result->pointer_size =
      xla::CompileOnlyClient::PointerSizeForTriple(aot_opts.triple());
  return Status::OK();
}

}  // namespace

Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config,
                    const MainFlags& flags, CompileResult* compile_result) {
  // Converts the graph into an XLA computation, and compiles the
  // computation.
  // TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client?
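  // tfcompile always compiles for the CPU backend: look up the "Host"
  // platform and get a compile-only client for it, which compiles the
  // computation ahead of time instead of running it.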
  se::Platform* cpu_platform =
      se::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
  xla::CompileOnlyClient* client =
      xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
          .ValueOrDie();
  xla::XlaComputation computation;
  if (flags.mlir_components == "Bridge") {
    TF_RETURN_IF_ERROR(
        ConvertGraphDefToXlaViaMlir(graph_def, config, &computation));
  } else {
    if (!flags.mlir_components.empty()) {
      return errors::Unknown("Unknown mlir_components ",
                             flags.mlir_components);
    }
    TF_RETURN_IF_ERROR(ConvertGraphDefToXla(std::move(graph_def), config,
                                            client, &computation));
  }
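  // If out_session_module is set, also write the computation out as a
  // serialized HloSnapshot proto.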
  if (!flags.out_session_module.empty()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::HloSnapshot> module,
                        computation.Snapshot());
    // Serialize the HloSnapshot deterministically so that all the outputs of
    // a tf_library genrule are deterministic.
    const size_t size = module->ByteSizeLong();
    auto serialized = absl::make_unique<char[]>(size);
    TF_RET_CHECK(
        SerializeToBufferDeterministic(*module, serialized.get(), size));
    TF_RETURN_IF_ERROR(
        WriteStringToFile(Env::Default(), flags.out_session_module,
                          absl::string_view(serialized.get(), size)));
  }
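  // CPU AOT options: the target triple/cpu/features select the LLVM backend
  // configuration, and the object code is emitted with the big-PIC relocation
  // model.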
  xla::cpu::CpuAotCompilationOptions aot_opts(
      flags.target_triple, flags.target_cpu, flags.target_features,
      flags.entry_point,
      xla::cpu::CpuAotCompilationOptions::RelocationModel::BigPic);

  return CompileXla(client, computation, aot_opts, compile_result);
}

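// Reads the proto at fname into *proto, parsing it as a text proto when the
// file name ends in .pbtxt and as a binary proto otherwise.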
static Status ReadProtoFile(const string& fname, protobuf::Message* proto) {
  if (absl::EndsWith(fname, ".pbtxt")) {
    return ReadTextProto(Env::Default(), fname, proto);
  } else {
    return ReadBinaryProto(Env::Default(), fname, proto);
  }
}

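// LLVM target registration mutates global state, so it must happen exactly
// once per process; Main() guards the call with std::call_once.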
static std::once_flag targets_init;

static void InitializeTargets() {
  // Initialize the LLVM targets we link in (AArch64 when available, ARM,
  // PowerPC and X86) so we can cross compile for any of them.
#if TF_LLVM_AARCH64_AVAILABLE
  LLVMInitializeAArch64Target();
  LLVMInitializeAArch64TargetInfo();
  LLVMInitializeAArch64TargetMC();
  LLVMInitializeAArch64AsmPrinter();
#endif
  LLVMInitializeARMTarget();
  LLVMInitializeARMTargetInfo();
  LLVMInitializeARMTargetMC();
  LLVMInitializeARMAsmPrinter();
  LLVMInitializePowerPCTarget();
  LLVMInitializePowerPCTargetInfo();
  LLVMInitializePowerPCTargetMC();
  LLVMInitializePowerPCAsmPrinter();
  LLVMInitializeX86Target();
  LLVMInitializeX86TargetInfo();
  LLVMInitializeX86TargetMC();
  LLVMInitializeX86AsmPrinter();
}

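// Drives the end-to-end tfcompile flow: reads the tf2xla config and the
// GraphDef named by the flags, compiles the graph to a function object file,
// then emits the metadata object file and the generated C++ header.
//
// A typical invocation of the tfcompile binary built on top of this is
// sketched below; the flag names mirror the MainFlags fields used here and
// the values are purely illustrative:
//   tfcompile --graph=graph.pb --config=config.pbtxt \
//       --cpp_class="mynamespace::MyComputation" \
//       --out_function_object=out_model.o \
//       --out_metadata_object=out_helper.o --out_header=out_model.h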
Status Main(const MainFlags& flags) {
  std::call_once(targets_init, &InitializeTargets);

  // Process config.
  tf2xla::Config config;
  if (flags.config.empty()) {
    return errors::InvalidArgument("Must specify --config");
  }
  TF_RETURN_IF_ERROR(ReadProtoFile(flags.config, &config));
  TF_RETURN_IF_ERROR(ValidateConfig(config));
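  // With dump_fetch_nodes set we only print the comma-separated set of fetch
  // node names and return, without compiling anything.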
  if (flags.dump_fetch_nodes) {
    std::set<string> nodes;
    for (const tf2xla::Fetch& fetch : config.fetch()) {
      nodes.insert(fetch.id().node_name());
    }
    std::cout << absl::StrJoin(nodes, ",");
    return Status::OK();
  }

  // Read and initialize the graph.
  if (flags.graph.empty()) {
    return errors::InvalidArgument("Must specify --graph");
  }
  GraphDef graph_def;
  TF_RETURN_IF_ERROR(ReadProtoFile(flags.graph, &graph_def));
  CompileResult compile_result;
  TF_RETURN_IF_ERROR(
      CompileGraph(std::move(graph_def), config, flags, &compile_result));

  // Write output files.
  Env* env = Env::Default();
  const std::vector<char>& obj = compile_result.aot->object_file_data();
  TF_RETURN_IF_ERROR(
      WriteStringToFile(env, flags.out_function_object,
                        absl::string_view(obj.data(), obj.size())));
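  // Translate the remaining flags into codegen options for the metadata and
  // header generators.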
  CodegenOpts codegen_opts;
  codegen_opts.gen_name_to_index = flags.gen_name_to_index;
  codegen_opts.gen_program_shape = flags.gen_program_shape;
  codegen_opts.target_triple = flags.target_triple;
  if (flags.cpp_class.empty()) {
    return errors::InvalidArgument("Must specify --cpp_class");
  }
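  // Whether to emit the HLO profile printer data is keyed off the
  // xla_hlo_profile debug option.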
  codegen_opts.gen_hlo_profile_printer_data =
      xla::GetDebugOptionsFromFlags().xla_hlo_profile();
  TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name,
                                   &codegen_opts.namespaces));

  MetadataResult metadata_result;
  TF_RETURN_IF_ERROR(
      GenerateMetadata(codegen_opts, compile_result, &metadata_result));
  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object,
                                       metadata_result.object_file_data));
  string header;
  TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result,
                                    metadata_result, &header));
  TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header));
  return Status::OK();
}

}  // namespace tfcompile
}  // namespace tensorflow