// blob: 9ec0c8f65705db335379649def746921e6b05bea [file] [log] [blame]
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include <stdint.h>
#include <algorithm>
#include <list>
#include <utility>
#include "absl/memory/memory.h"
#include "llvm/ExecutionEngine/ExecutionEngine.h"
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/Mangler.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/Host.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d_mkl.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_fft.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_key_value_sort.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul_mkl.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_fft.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace cpu {
namespace {
// Collects the names of the CPU features that LLVM reports as enabled on the
// host, for use as target-machine attributes. The list is empty when
// llvm::sys::getHostCPUFeatures() reports failure.
llvm::SmallVector<std::string, 0> DetectMachineAttributes() {
  llvm::SmallVector<std::string, 0> attributes;
  llvm::StringMap<bool> cpu_features;
  if (!llvm::sys::getHostCPUFeatures(cpu_features)) {
    return attributes;
  }
  for (const auto& entry : cpu_features) {
    const bool enabled = entry.second;
    const llvm::StringRef name = entry.first();
    // Skip avx512 for now, it isn't quite ready in LLVM.
    if (!enabled || name.startswith("avx512")) {
      continue;
    }
    attributes.push_back(name.str());
  }
  return attributes;
}
// Returns LLVM's name for the host CPU with a trailing "-avx512" suffix, if
// present, removed.
llvm::StringRef GetHostCpuName() {
  llvm::StringRef host_cpu = llvm::sys::getHostCPUName();
  // Skip avx512 for now, it isn't quite ready in LLVM.
  host_cpu.consume_back("-avx512");
  return host_cpu;
}
} // namespace
// Builds the llvm::TargetMachine used for JIT compilation from the given
// target options and optimization level, using the CPU name from
// GetHostCpuName() and the attributes from DetectMachineAttributes().
// CHECK-fails if LLVM does not produce a target machine.
/*static*/ std::unique_ptr<llvm::TargetMachine>
SimpleOrcJIT::InferTargetMachineForJIT(
    const llvm::TargetOptions& target_options,
    llvm::CodeGenOpt::Level opt_level) {
  llvm::EngineBuilder engine_builder;
  engine_builder.setTargetOptions(target_options).setOptLevel(opt_level);

  std::unique_ptr<llvm::TargetMachine> host_machine(
      engine_builder.selectTarget(
          /*TargetTriple=*/llvm::Triple(), /*MArch=*/"",
          /*MCPU=*/GetHostCpuName(),
          /*MAttrs=*/DetectMachineAttributes()));
  CHECK(host_machine != nullptr);
  return host_machine;
}
// Constructs the JIT: infers a target machine from `target_options` and
// `opt_level`, then sets up the symbol resolver, the object linking layer and
// the compile layer in the member initializers below.
// NOTE(review): members are initialized in their declaration order in the
// class definition (not visible here); the later initializers use
// target_machine_, execution_session_, symbol_resolver_ and object_layer_, so
// confirm the header declares those before their users.
SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
llvm::CodeGenOpt::Level opt_level,
bool optimize_for_size, bool enable_fast_math,
bool disable_expensive_passes,
LLVMCompiler::ModuleHook pre_optimization_hook,
LLVMCompiler::ModuleHook post_optimization_hook)
: target_machine_(InferTargetMachineForJIT(target_options, opt_level)),
disassembler_(*target_machine_),
data_layout_(target_machine_->createDataLayout()),
// Resolver whose symbol lookup forwards to ResolveRuntimeSymbol(); an error
// handed to the second callback is passed to cantFail with the message
// "lookupFlags failed".
symbol_resolver_(llvm::orc::createLegacyLookupResolver(
execution_session_,
[this](const std::string& name) -> llvm::JITSymbol {
return this->ResolveRuntimeSymbol(name);
},
[](llvm::Error Err) {
cantFail(std::move(Err), "lookupFlags failed");
})),
// For each module key, the linking layer is given a new
// SectionMemoryManager built on orc_jit_memory_mapper::GetInstance() plus the
// shared symbol_resolver_.
object_layer_(execution_session_,
[this](llvm::orc::VModuleKey) {
llvm::orc::RTDyldObjectLinkingLayer::Resources result;
result.MemMgr =
std::make_shared<llvm::SectionMemoryManager>(
orc_jit_memory_mapper::GetInstance());
result.Resolver = symbol_resolver_;
return result;
}),
// The compile layer sits on top of object_layer_; its CompilerFunctor
// receives the target machine, the disassembler, the optimization settings
// and both module hooks (moved in).
compile_layer_(object_layer_,
CompilerFunctor(target_machine_.get(), &disassembler_,
opt_level, optimize_for_size,
enable_fast_math, disable_expensive_passes,
std::move(pre_optimization_hook),
std::move(post_optimization_hook))) {
// Log the selected target CPU and feature string at verbosity level 1.
VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str()
<< " features: " << target_machine_->getTargetFeatureString().str();
}
// Looks `name` up in the global CustomCallTargetRegistry. Returns a null
// JITSymbol when the registry has no entry for it; otherwise returns an
// evaluated symbol holding the registered address with no symbol flags set.
llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
  if (void* target = CustomCallTargetRegistry::Global()->Lookup(name)) {
    return llvm::JITEvaluatedSymbol(reinterpret_cast<uint64_t>(target),
                                    llvm::JITSymbolFlags::None);
  }
  return nullptr;
}
// Hands `module` to the compile layer under a newly allocated module key,
// records that key in module_keys_, and returns it. A failure reported by
// the compile layer is passed to cantFail.
SimpleOrcJIT::VModuleKeyT SimpleOrcJIT::AddModule(
    std::unique_ptr<llvm::Module> module) {
  const auto module_key = execution_session_.allocateVModule();
  cantFail(compile_layer_.addModule(module_key, std::move(module)));
  module_keys_.push_back(module_key);
  return module_key;
}
// Drops every occurrence of `key` from module_keys_ and then removes the
// corresponding module from the compile layer. A failure reported by the
// compile layer is passed to cantFail.
void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::VModuleKeyT key) {
  auto stale_begin =
      std::remove(module_keys_.begin(), module_keys_.end(), key);
  module_keys_.erase(stale_begin, module_keys_.end());
  cantFail(compile_layer_.removeModule(key));
}
// Searches the added modules for an exported symbol called `name` and returns
// the first match, or a null JITSymbol if no module provides one.
llvm::JITSymbol SimpleOrcJIT::FindCompiledSymbol(const std::string& name) {
  // Walk the module keys from the most recently added to the oldest, so that
  // a later redefinition of a symbol shadows earlier ones.
  for (auto it = module_keys_.rbegin(); it != module_keys_.rend(); ++it) {
    llvm::JITSymbol candidate =
        compile_layer_.findSymbolIn(*it, name, /*ExportedSymbolsOnly=*/true);
    if (candidate) {
      return candidate;
    }
  }
  return nullptr;
}
namespace {
// Registers a fixed set of known symbols with the global
// CustomCallTargetRegistry: the XLA CPU runtime entry points, two fp16
// conversion helpers, the float and double libm functions, and
// memcpy/memmove/memset. Always returns true; the return value exists so the
// function can be used as the initializer of the `unused` variable below.
bool RegisterKnownJITSymbols() {
CustomCallTargetRegistry* registry = CustomCallTargetRegistry::Global();
// Registers the function __xla_cpu_runtime_<base_name> under the name held in
// xla::cpu::runtime::k<base_name>SymbolName, and CHECKs that this name is
// exactly "__xla_cpu_runtime_<base_name>".
#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \
do { \
auto* function_address = \
reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \
registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \
function_address); \
CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \
"__xla_cpu_runtime_" #base_name); \
} while (false)
REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue);
REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation);
REGISTER_CPU_RUNTIME_SYMBOL(MKLConvF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16);
REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenFft);
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF16);
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64);
REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF32);
REGISTER_CPU_RUNTIME_SYMBOL(MKLMatMulF64);
REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF32);
REGISTER_CPU_RUNTIME_SYMBOL(MKLSingleThreadedMatMulF64);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedFft);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF16);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64);
REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortPRED);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS8);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU8);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS16);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU16);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF16);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS32);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU32);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF32);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortS64);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortU64);
REGISTER_CPU_RUNTIME_SYMBOL(KeyValueSortF64);
// These two are registered directly under their own names rather than through
// the macro above (they do not carry the __xla_cpu_runtime_ prefix).
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee));
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee));
#undef REGISTER_CPU_RUNTIME_SYMBOL
// Register both the f32 (float) and f64 (double) versions of a libm symbol.
// Unfortunately the double versions are overloaded on some systems (e.g.
// Mac), so we need an explicit cast to pick the right overload. This requires
// passing the function signature of the double version as `double_sig`. The
// float version is registered as "<name>f" and the double version as "<name>".
#define REGISTER_LIBM_SYMBOL(name, double_sig) \
do { \
registry->Register(#name "f", reinterpret_cast<void*>(name##f)); \
registry->Register( \
#name, reinterpret_cast<void*>(static_cast<double_sig>(name))); \
} while (false)
REGISTER_LIBM_SYMBOL(acos, double (*)(double));
REGISTER_LIBM_SYMBOL(acosh, double (*)(double));
REGISTER_LIBM_SYMBOL(asin, double (*)(double));
REGISTER_LIBM_SYMBOL(asinh, double (*)(double));
REGISTER_LIBM_SYMBOL(atan, double (*)(double));
REGISTER_LIBM_SYMBOL(atan2, double (*)(double, double));
REGISTER_LIBM_SYMBOL(atanh, double (*)(double));
REGISTER_LIBM_SYMBOL(cbrt, double (*)(double));
REGISTER_LIBM_SYMBOL(ceil, double (*)(double));
REGISTER_LIBM_SYMBOL(copysign, double (*)(double, double));
REGISTER_LIBM_SYMBOL(cos, double (*)(double));
REGISTER_LIBM_SYMBOL(cosh, double (*)(double));
REGISTER_LIBM_SYMBOL(erf, double (*)(double));
REGISTER_LIBM_SYMBOL(erfc, double (*)(double));
REGISTER_LIBM_SYMBOL(exp, double (*)(double));
REGISTER_LIBM_SYMBOL(exp2, double (*)(double));
REGISTER_LIBM_SYMBOL(expm1, double (*)(double));
REGISTER_LIBM_SYMBOL(fabs, double (*)(double));
REGISTER_LIBM_SYMBOL(fdim, double (*)(double, double));
REGISTER_LIBM_SYMBOL(floor, double (*)(double));
REGISTER_LIBM_SYMBOL(fma, double (*)(double, double, double));
REGISTER_LIBM_SYMBOL(fmax, double (*)(double, double));
REGISTER_LIBM_SYMBOL(fmin, double (*)(double, double));
REGISTER_LIBM_SYMBOL(fmod, double (*)(double, double));
REGISTER_LIBM_SYMBOL(frexp, double (*)(double, int*));
REGISTER_LIBM_SYMBOL(hypot, double (*)(double, double));
REGISTER_LIBM_SYMBOL(ilogb, int (*)(double));
REGISTER_LIBM_SYMBOL(ldexp, double (*)(double, int));
REGISTER_LIBM_SYMBOL(lgamma, double (*)(double));
REGISTER_LIBM_SYMBOL(llrint, long long (*)(double)); // NOLINT(runtime/int)
REGISTER_LIBM_SYMBOL(llround, long long (*)(double)); // NOLINT(runtime/int)
REGISTER_LIBM_SYMBOL(log, double (*)(double));
REGISTER_LIBM_SYMBOL(log10, double (*)(double));
REGISTER_LIBM_SYMBOL(log1p, double (*)(double));
REGISTER_LIBM_SYMBOL(log2, double (*)(double));
REGISTER_LIBM_SYMBOL(logb, double (*)(double));
REGISTER_LIBM_SYMBOL(lrint, long (*)(double)); // NOLINT(runtime/int)
REGISTER_LIBM_SYMBOL(lround, long (*)(double)); // NOLINT(runtime/int)
REGISTER_LIBM_SYMBOL(modf, double (*)(double, double*));
REGISTER_LIBM_SYMBOL(nan, double (*)(const char*));
REGISTER_LIBM_SYMBOL(nearbyint, double (*)(double));
REGISTER_LIBM_SYMBOL(nextafter, double (*)(double, double));
REGISTER_LIBM_SYMBOL(nexttoward, double (*)(double, long double));
REGISTER_LIBM_SYMBOL(pow, double (*)(double, double));
REGISTER_LIBM_SYMBOL(remainder, double (*)(double, double));
REGISTER_LIBM_SYMBOL(remquo, double (*)(double, double, int*));
REGISTER_LIBM_SYMBOL(rint, double (*)(double));
REGISTER_LIBM_SYMBOL(round, double (*)(double));
REGISTER_LIBM_SYMBOL(scalbln,
double (*)(double, long)); // NOLINT(runtime/int)
REGISTER_LIBM_SYMBOL(scalbn, double (*)(double, int));
REGISTER_LIBM_SYMBOL(sin, double (*)(double));
// On Apple platforms the registered names are __sincos / __sincosf; elsewhere
// they are sincos / sincosf.
#ifdef __APPLE__
REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*));
#else
REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*));
#endif
REGISTER_LIBM_SYMBOL(sinh, double (*)(double));
REGISTER_LIBM_SYMBOL(sqrt, double (*)(double));
REGISTER_LIBM_SYMBOL(tan, double (*)(double));
REGISTER_LIBM_SYMBOL(tanh, double (*)(double));
REGISTER_LIBM_SYMBOL(tgamma, double (*)(double));
REGISTER_LIBM_SYMBOL(trunc, double (*)(double));
#undef REGISTER_LIBM_SYMBOL
// Memory routines, registered under their C names.
registry->Register("memcpy", reinterpret_cast<void*>(memcpy));
registry->Register("memmove", reinterpret_cast<void*>(memmove));
registry->Register("memset", reinterpret_cast<void*>(memset));
return true;
}
// Namespace-scope variable whose initializer invokes RegisterKnownJITSymbols()
// during initialization of this translation unit's globals; the stored value
// is not read anywhere in the code shown here.
bool unused = RegisterKnownJITSymbols();
} // namespace
} // namespace cpu
} // namespace xla