| /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" |
| |
| #include <stdint.h> |
| |
| #include <algorithm> |
| #include <list> |
| #include <utility> |
| |
| #include "absl/memory/memory.h" |
| #include "llvm/ExecutionEngine/ExecutionEngine.h" |
| #include "llvm/ExecutionEngine/JITSymbol.h" |
| #include "llvm/ExecutionEngine/SectionMemoryManager.h" |
| #include "llvm/IR/Mangler.h" |
| #include "llvm/Support/CodeGen.h" |
| #include "llvm/Support/Host.h" |
| #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" |
| #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" |
| #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" |
| #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h" |
| #include "tensorflow/compiler/xla/service/cpu/runtime_fft.h" |
| #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" |
| #include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h" |
| #include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h" |
| #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" |
| #include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h" |
| #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" |
| #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h" |
| #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" |
| #include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" |
| #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" |
| #include "tensorflow/compiler/xla/types.h" |
| #include "tensorflow/core/platform/logging.h" |
| |
| namespace xla { |
| namespace cpu { |
| namespace { |
| |
| llvm::SmallVector<std::string, 0> DetectMachineAttributes() { |
| llvm::SmallVector<std::string, 0> result; |
| llvm::StringMap<bool> host_features; |
| if (llvm::sys::getHostCPUFeatures(host_features)) { |
| for (auto& feature : host_features) { |
| if (feature.second) { |
| result.push_back(feature.first()); |
| } |
| } |
| } |
| return result; |
| } |
| |
| } // namespace |
| |
| /*static*/ std::unique_ptr<llvm::TargetMachine> |
| SimpleOrcJIT::InferTargetMachineForJIT( |
| const llvm::TargetOptions& target_options, |
| llvm::CodeGenOpt::Level opt_level) { |
| std::unique_ptr<llvm::TargetMachine> target_machine( |
| llvm::EngineBuilder() |
| .setTargetOptions(target_options) |
| .setOptLevel(opt_level) |
| .selectTarget( |
| /*TargetTriple=*/llvm::Triple(), /*MArch=*/"", |
| /*MCPU=*/llvm::sys::getHostCPUName(), |
| /*MAttrs=*/DetectMachineAttributes())); |
| CHECK(target_machine != nullptr); |
| return target_machine; |
| } |
| |
| SimpleOrcJIT::SimpleOrcJIT( |
| const llvm::TargetOptions& target_options, |
| llvm::CodeGenOpt::Level opt_level, bool optimize_for_size, |
| bool disable_expensive_passes, |
| LLVMCompiler::ModuleHook pre_optimization_hook, |
| LLVMCompiler::ModuleHook post_optimization_hook, |
| std::function<void(const llvm::object::ObjectFile&)> post_codegen_hook) |
| : target_machine_(InferTargetMachineForJIT(target_options, opt_level)), |
| data_layout_(target_machine_->createDataLayout()), |
| symbol_resolver_(llvm::orc::createLegacyLookupResolver( |
| execution_session_, |
| [this](const std::string& name) -> llvm::JITSymbol { |
| return this->ResolveRuntimeSymbol(name); |
| }, |
| [](llvm::Error Err) { |
| cantFail(std::move(Err), "lookupFlags failed"); |
| })), |
| object_layer_( |
| execution_session_, |
| [this](llvm::orc::VModuleKey) { |
| llvm::orc::LegacyRTDyldObjectLinkingLayer::Resources result; |
| result.MemMgr = std::make_shared<llvm::SectionMemoryManager>( |
| orc_jit_memory_mapper::GetInstance()); |
| result.Resolver = symbol_resolver_; |
| return result; |
| }, |
| /*NotifyLoaded=*/ |
| llvm::orc::LegacyRTDyldObjectLinkingLayer::NotifyLoadedFtor(), |
| /*NotifyFinalized=*/ |
| [this](VModuleKeyT, const llvm::object::ObjectFile& object, |
| const llvm::RuntimeDyld::LoadedObjectInfo& object_info) { |
| this->NotifyObjectFinalized(object, object_info); |
| }, |
| /*NotifyFreed=*/ |
| [this](VModuleKeyT, const llvm::object::ObjectFile& object) { |
| this->NotifyObjectFreed(object); |
| }), |
| compile_layer_( |
| object_layer_, |
| CompilerFunctor( |
| target_machine_.get(), opt_level, optimize_for_size, |
| disable_expensive_passes, std::move(pre_optimization_hook), |
| std::move(post_optimization_hook), std::move(post_codegen_hook))), |
| gdb_jit_event_listener_( |
| llvm::JITEventListener::createGDBRegistrationListener()) { |
| VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str() |
| << " features: " << target_machine_->getTargetFeatureString().str(); |
| } |
| |
| llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { |
| void* func_addr = nullptr; |
| if (name.size() > 1 && name.front() == data_layout_.getGlobalPrefix()) { |
| // On Mac OS X, 'name' may have a leading underscore prefix, even though the |
| // registered name may not. |
| std::string stripped_name(name.begin() + 1, name.end()); |
| func_addr = |
| xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name, "Host"); |
| } else { |
| func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(name, "Host"); |
| } |
| |
| if (func_addr == nullptr) { |
| LOG(ERROR) |
| << "Unable to resolve runtime symbol: `" << name |
| << "'. Hint: if the symbol a custom call target, make sure you've " |
| "registered it with the JIT using " |
| "XLA_CPU_REGISTER_CUSTOM_CALL_TARGET."; |
| return nullptr; |
| } |
| llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast<uint64_t>(func_addr), |
| llvm::JITSymbolFlags::None); |
| return symbol_info; |
| } |
| |
| void SimpleOrcJIT::NotifyObjectFinalized( |
| const llvm::object::ObjectFile& object, |
| const llvm::RuntimeDyld::LoadedObjectInfo& object_info) { |
| uint64_t key = static_cast<uint64_t>( |
| reinterpret_cast<uintptr_t>(object.getData().data())); |
| gdb_jit_event_listener_->notifyObjectLoaded(key, object, object_info); |
| } |
| |
| void SimpleOrcJIT::NotifyObjectFreed(const llvm::object::ObjectFile& object) { |
| uint64_t key = static_cast<uint64_t>( |
| reinterpret_cast<uintptr_t>(object.getData().data())); |
| gdb_jit_event_listener_->notifyFreeingObject(key); |
| } |
| |
| SimpleOrcJIT::VModuleKeyT SimpleOrcJIT::AddModule( |
| std::unique_ptr<llvm::Module> module) { |
| auto key = execution_session_.allocateVModule(); |
| cantFail(compile_layer_.addModule(key, std::move(module))); |
| module_keys_.push_back(key); |
| return key; |
| } |
| |
| void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::VModuleKeyT key) { |
| module_keys_.erase(std::remove(module_keys_.begin(), module_keys_.end(), key), |
| module_keys_.end()); |
| cantFail(compile_layer_.removeModule(key)); |
| } |
| |
| llvm::JITSymbol SimpleOrcJIT::FindCompiledSymbol(const std::string& name) { |
| // Resolve symbol from last module to first, allowing later redefinitions of |
| // symbols shadow earlier ones. |
| for (auto& key : |
| llvm::make_range(module_keys_.rbegin(), module_keys_.rend())) { |
| if (auto symbol = |
| compile_layer_.findSymbolIn(key, name, |
| /*ExportedSymbolsOnly=*/true)) { |
| return symbol; |
| } |
| } |
| |
| return nullptr; |
| } |
| |
| namespace { |
| // Register some known symbols with the CustomCallTargetRegistry. |
| bool RegisterKnownJITSymbols() { |
| xla::CustomCallTargetRegistry* registry = |
| xla::CustomCallTargetRegistry::Global(); |
| |
| #define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \ |
| do { \ |
| auto* function_address = \ |
| reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \ |
| registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \ |
| function_address, "Host"); \ |
| CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \ |
| "__xla_cpu_runtime_" #base_name); \ |
| } while (false) |
| |
| REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue); |
| REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation); |
| REGISTER_CPU_RUNTIME_SYMBOL(AllReduce); |
| REGISTER_CPU_RUNTIME_SYMBOL(ReplicaId); |
| REGISTER_CPU_RUNTIME_SYMBOL(MKLConvF32); |
| REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16); |
| REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32); |
| REGISTER_CPU_RUNTIME_SYMBOL(EigenFft); |
| REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16); |
| REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32); |
| REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64); |
| REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulS32); |
| REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF32); |
| REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF64); |
| REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF32); |
| REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64); |
| REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16); |
| REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); |
| REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedFft); |
| REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16); |
| REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); |
| REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); |
| REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulS32); |
| REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); |
| REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); |
| REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); |
| REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSort); |
| REGISTER_CPU_RUNTIME_SYMBOL(TracingStart); |
| REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd); |
| |
| registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee), |
| "Host"); |
| registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee), |
| "Host"); |
| |
| #undef REGISTER_CPU_RUNTIME_SYMBOL |
| |
| // Register both the f32 (float) and f64 (double) versions of a libm symbol. |
| // Unfortunately the double versions are overloaded on some systems, e.g. |
| // Mac so we need an explicit cast. This requires passing the function signature |
| // for that case. |
| #define REGISTER_LIBM_SYMBOL(name, double_sig) \ |
| do { \ |
| registry->Register(#name "f", reinterpret_cast<void*>(name##f), "Host"); \ |
| registry->Register(#name, \ |
| reinterpret_cast<void*>(static_cast<double_sig>(name)), \ |
| "Host"); \ |
| } while (false) |
| |
| REGISTER_LIBM_SYMBOL(acos, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(acosh, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(asin, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(asinh, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(atan, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(atan2, double (*)(double, double)); |
| REGISTER_LIBM_SYMBOL(atanh, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(cbrt, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(ceil, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(copysign, double (*)(double, double)); |
| REGISTER_LIBM_SYMBOL(cos, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(cosh, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(erf, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(erfc, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(exp, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(exp2, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(expm1, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(fabs, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(fdim, double (*)(double, double)); |
| REGISTER_LIBM_SYMBOL(floor, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(fma, double (*)(double, double, double)); |
| REGISTER_LIBM_SYMBOL(fmax, double (*)(double, double)); |
| REGISTER_LIBM_SYMBOL(fmin, double (*)(double, double)); |
| REGISTER_LIBM_SYMBOL(fmod, double (*)(double, double)); |
| REGISTER_LIBM_SYMBOL(frexp, double (*)(double, int*)); |
| REGISTER_LIBM_SYMBOL(hypot, double (*)(double, double)); |
| REGISTER_LIBM_SYMBOL(ilogb, int (*)(double)); |
| REGISTER_LIBM_SYMBOL(ldexp, double (*)(double, int)); |
| REGISTER_LIBM_SYMBOL(lgamma, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(llrint, long long (*)(double)); // NOLINT(runtime/int) |
| REGISTER_LIBM_SYMBOL(llround, long long (*)(double)); // NOLINT(runtime/int) |
| REGISTER_LIBM_SYMBOL(log, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(log10, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(log1p, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(log2, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(logb, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(lrint, long (*)(double)); // NOLINT(runtime/int) |
| REGISTER_LIBM_SYMBOL(lround, long (*)(double)); // NOLINT(runtime/int) |
| REGISTER_LIBM_SYMBOL(modf, double (*)(double, double*)); |
| REGISTER_LIBM_SYMBOL(nan, double (*)(const char*)); |
| REGISTER_LIBM_SYMBOL(nearbyint, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(nextafter, double (*)(double, double)); |
| REGISTER_LIBM_SYMBOL(nexttoward, double (*)(double, long double)); |
| REGISTER_LIBM_SYMBOL(pow, double (*)(double, double)); |
| REGISTER_LIBM_SYMBOL(remainder, double (*)(double, double)); |
| REGISTER_LIBM_SYMBOL(remquo, double (*)(double, double, int*)); |
| REGISTER_LIBM_SYMBOL(rint, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(round, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(scalbln, |
| double (*)(double, long)); // NOLINT(runtime/int) |
| REGISTER_LIBM_SYMBOL(scalbn, double (*)(double, int)); |
| REGISTER_LIBM_SYMBOL(sin, double (*)(double)); |
| #ifdef __APPLE__ |
| REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*)); |
| registry->Register("__sincosf_stret", |
| reinterpret_cast<void*>(__sincosf_stret), "Host"); |
| registry->Register("__sincos_stret", reinterpret_cast<void*>(__sincos_stret), |
| "Host"); |
| #else |
| REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*)); |
| #endif |
| REGISTER_LIBM_SYMBOL(sinh, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(sqrt, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(tan, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(tanh, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(tgamma, double (*)(double)); |
| REGISTER_LIBM_SYMBOL(trunc, double (*)(double)); |
| |
| #undef REGISTER_LIBM_SYMBOL |
| |
| registry->Register("memcpy", reinterpret_cast<void*>(memcpy), "Host"); |
| registry->Register("memmove", reinterpret_cast<void*>(memmove), "Host"); |
| registry->Register("memset", reinterpret_cast<void*>(memset), "Host"); |
| |
| #ifdef __APPLE__ |
| registry->Register("__bzero", reinterpret_cast<void*>(bzero), "Host"); |
| registry->Register("memset_pattern16", |
| reinterpret_cast<void*>(memset_pattern16), "Host"); |
| #endif |
| |
| #ifdef MEMORY_SANITIZER |
| registry->Register("__msan_unpoison", |
| reinterpret_cast<void*>(__msan_unpoison), "Host"); |
| #endif |
| |
| return true; |
| } |
| |
| bool unused = RegisterKnownJITSymbols(); |
| } // namespace |
| |
| } // namespace cpu |
| } // namespace xla |