blob: 7e3666ef98493dd98011dbd9739c8e68abee95fa [file] [log] [blame]
#include <ATen/native/DispatchStub.h>
#include <c10/util/Exception.h>
#include <cpuinfo.h>
#include <cstdlib>
#include <cstring>
namespace at { namespace native {
// Decides which CPU kernel flavour this process should dispatch to.
//
// Precedence:
//   1. The ATEN_CPU_CAPABILITY environment variable, if it holds a value this
//      build recognises: "vsx" on VSX builds, "avx2" / "avx" on other builds,
//      and "default" on every build. An unrecognised value only produces a
//      warning and detection continues below.
//   2. On targets other than PowerPC and s390x, the cpuinfo feature queries:
//      AVX2 together with FMA3 selects AVX2, otherwise AVX selects AVX.
//   3. The build-time fallback: VSX on VSX builds, DEFAULT everywhere else.
//
// A recognised environment override is returned as-is; it is not checked
// against what cpuinfo reports.
static CPUCapability compute_cpu_capability() {
  const char* const requested = std::getenv("ATEN_CPU_CAPABILITY");
  if (requested != nullptr) {
    // Exact, case-sensitive comparison against one accepted spelling.
    const auto is_value = [requested](const char* candidate) {
      return std::strcmp(requested, candidate) == 0;
    };
#ifdef HAVE_VSX_CPU_DEFINITION
    if (is_value("vsx")) {
      return CPUCapability::VSX;
    }
#else
    if (is_value("avx2")) {
      return CPUCapability::AVX2;
    }
    if (is_value("avx")) {
      return CPUCapability::AVX;
    }
#endif
    if (is_value("default")) {
      return CPUCapability::DEFAULT;
    }
    TORCH_WARN("ignoring invalid value for ATEN_CPU_CAPABILITY: ", requested);
  }

#if !defined(__powerpc__) && !defined(__s390x__)
  // The x86 feature queries are only consulted once cpuinfo reports a
  // successful initialisation.
  if (cpuinfo_initialize()) {
    if (cpuinfo_has_x86_avx2() && cpuinfo_has_x86_fma3()) {
      return CPUCapability::AVX2;
    }
    if (cpuinfo_has_x86_avx()) {
      return CPUCapability::AVX;
    }
  }
#endif

#ifdef HAVE_VSX_CPU_DEFINITION
  return CPUCapability::VSX;
#else
  return CPUCapability::DEFAULT;
#endif
}
// Returns the CPU capability chosen for this process.
//
// The value comes from compute_cpu_capability() and is held in a
// function-local static, so the computation runs on the first call only and
// every later call returns the same cached result.
CPUCapability get_cpu_capability() {
  static const CPUCapability cached_capability = compute_cpu_capability();
  return cached_capability;
}
// Returns the kernel pointer to call for `device_type`.
//
// Parameters:
//   device_type - the device the caller wants a kernel for.
//   DEFAULT / AVX / AVX2 / VSX - the CPU kernel candidates compiled into this
//     build (which ones exist depends on the HAVE_*_CPU_DEFINITION macros).
//     They are only used for DeviceType::CPU, and only while no CPU kernel
//     has been cached yet.
//
// Behaviour per device:
//   CPU  - returns the pointer cached in cpu_dispatch_ptr; when it is still
//          null, picks one via choose_cpu_impl(), caches it and returns it.
//   CUDA - returns cuda_dispatch_ptr, asserting that it is non-null.
//   HIP  - returns hip_dispatch_ptr, asserting that it is non-null.
//   any other device - raises an error through AT_ERROR.
void* DispatchStubImpl::get_call_ptr(
  DeviceType device_type
  , void *DEFAULT
#ifdef HAVE_AVX_CPU_DEFINITION
  , void *AVX
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
  , void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
  , void *VSX
#endif
) {
  switch (device_type) {
    case DeviceType::CPU: {
      // Use memory_order_relaxed here since even if two threads race,
      // they will still compute the same value for cpu_dispatch_ptr.
      auto fptr = cpu_dispatch_ptr.load(std::memory_order_relaxed);
      if (!fptr) {
        fptr = choose_cpu_impl(
          DEFAULT
#ifdef HAVE_AVX_CPU_DEFINITION
          , AVX
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
          , AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
          , VSX
#endif
        );
        cpu_dispatch_ptr.store(fptr, std::memory_order_relaxed);
      }
      return fptr;
    }

    case DeviceType::CUDA:
      TORCH_INTERNAL_ASSERT(cuda_dispatch_ptr, "DispatchStub: missing CUDA kernel");
      return cuda_dispatch_ptr;

    case DeviceType::HIP:
      TORCH_INTERNAL_ASSERT(hip_dispatch_ptr, "DispatchStub: missing HIP kernel");
      return hip_dispatch_ptr;

    default:
      // The message pieces are concatenated verbatim, so the literal has to
      // carry its own separator before the device name.
      AT_ERROR("DispatchStub: unsupported device type: ", device_type);
  }
}
void* DispatchStubImpl::choose_cpu_impl(
void *DEFAULT
#ifdef HAVE_AVX_CPU_DEFINITION
, void *AVX
#endif
#ifdef HAVE_AVX2_CPU_DEFINITION
, void *AVX2
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
, void *VSX
#endif
) {
auto capability = static_cast<int>(get_cpu_capability());
(void)capability;
#ifdef HAVE_AVX2_CPU_DEFINITION
if (capability >= static_cast<int>(CPUCapability::AVX2)) {
TORCH_INTERNAL_ASSERT(AVX2, "DispatchStub: missing AVX2 kernel");
return AVX2;
}
#endif
#ifdef HAVE_AVX_CPU_DEFINITION
if (capability >= static_cast<int>(CPUCapability::AVX)) {
TORCH_INTERNAL_ASSERT(AVX, "DispatchStub: missing AVX kernel");
return AVX;
}
#endif
#ifdef HAVE_VSX_CPU_DEFINITION
if (capability >= static_cast<int>(CPUCapability::VSX)) {
TORCH_INTERNAL_ASSERT(VSX, "DispatchStub: missing VSX kernel");
return VSX;
}
#endif
TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel");
return DEFAULT;
}
}} // namespace at::native