// blob: 2492cc0e5d903307bb874742ba93b9dece432de8 [file] [log] [blame]
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstring>
#include <type_traits>

#include "rocm/include/hiprand/hiprand.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
#include "tensorflow/stream_executor/gpu/gpu_executor.h"
#include "tensorflow/stream_executor/gpu/gpu_helpers.h"
#include "tensorflow/stream_executor/gpu/gpu_rng.h"
#include "tensorflow/stream_executor/gpu/gpu_stream.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform/dso_loader.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/rocm/rocm_platform_id.h"
// Formats a hiprandStatus_t into a log stream as its symbolic name, e.g.
// "HIPRAND_STATUS_SUCCESS". A value that matches none of the cases listed
// below is written as "hiprandStatus_t(<numeric value>)".
//
// Returns the stream that was passed in, so insertions can be chained.
std::ostream& operator<<(std::ostream& in, const hiprandStatus_t& status) {
// Expands to one `case` label that writes the enumerator's name and returns.
#define OSTREAM_HIPRAND_STATUS(__name) \
  case HIPRAND_STATUS_##__name:        \
    in << "HIPRAND_STATUS_" #__name;   \
    return in;

  switch (status) {
    OSTREAM_HIPRAND_STATUS(SUCCESS)
    OSTREAM_HIPRAND_STATUS(VERSION_MISMATCH)
    OSTREAM_HIPRAND_STATUS(NOT_INITIALIZED)
    OSTREAM_HIPRAND_STATUS(ALLOCATION_FAILED)
    OSTREAM_HIPRAND_STATUS(TYPE_ERROR)
    OSTREAM_HIPRAND_STATUS(OUT_OF_RANGE)
    OSTREAM_HIPRAND_STATUS(LENGTH_NOT_MULTIPLE)
    OSTREAM_HIPRAND_STATUS(LAUNCH_FAILURE)
    OSTREAM_HIPRAND_STATUS(PREEXISTING_FAILURE)
    OSTREAM_HIPRAND_STATUS(INITIALIZATION_FAILED)
    OSTREAM_HIPRAND_STATUS(ARCH_MISMATCH)
    OSTREAM_HIPRAND_STATUS(INTERNAL_ERROR)
    default:
      in << "hiprandStatus_t(" << static_cast<int>(status) << ")";
      return in;
  }

// The helper is only meaningful inside the switch above; undefine it so it
// does not leak into the rest of the translation unit.
#undef OSTREAM_HIPRAND_STATUS
}
namespace stream_executor {
namespace gpu {
// Defines kGpuRandPlugin, the plugin ID that initialize_rocrand() in this file
// passes to the PluginRegistry when registering the rocRAND factory and when
// selecting it as the default RNG plugin.
PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kGpuRandPlugin);
namespace wrap {
// Every hipRAND entry point used by this file is wrapped in a function object
// that carries the same name as the hipRAND function. Calling
// wrap::<name>(parent, args...) constructs a ScopedActivateExecutorContext for
// `parent` that lives for the duration of the call, then forwards `args` to
// the hipRAND function and returns its status.
#ifdef PLATFORM_GOOGLE
// Direct-call variant: the shim invokes ::__name itself.
#define STREAM_EXECUTOR_HIPRAND_WRAP(__name) \
struct WrapperShim__##__name { \
template <typename... Args> \
hiprandStatus_t operator()(GpuExecutor* parent, Args... args) { \
gpu::ScopedActivateExecutorContext sac{parent}; \
return ::__name(args...); \
} \
} __name;
#else
// Dynamically loaded variant: the shim resolves the symbol by name in the DSO
// handle obtained from internal::CachedDsoLoader::GetRocrandDsoHandle().
//  - GetDsoHandle(): returns ValueOrDie() of the loader result.
//  - LoadOrDie():    looks up kName in that handle via
//                    Env::GetSymbolFromLibrary and CHECK-fails if the returned
//                    status is not ok.
//  - DynLoad():      keeps the resolved pointer in a function-local static, so
//                    LoadOrDie() is evaluated once per wrapped function.
//  - operator():     activates the executor context and calls the resolved
//                    function pointer with `args`.
// NOTE(review): operator() below is declared to return `hiprandStatus`, while
// the PLATFORM_GOOGLE variant returns `hiprandStatus_t` — presumably both name
// the same enum; confirm against hiprand.h.
#define STREAM_EXECUTOR_HIPRAND_WRAP(__name) \
struct DynLoadShim__##__name { \
static const char* kName; \
using FuncPtrT = std::add_pointer<decltype(::__name)>::type; \
static void* GetDsoHandle() { \
auto s = internal::CachedDsoLoader::GetRocrandDsoHandle(); \
return s.ValueOrDie(); \
} \
static FuncPtrT LoadOrDie() { \
void* f; \
auto s = port::Env::Default()->GetSymbolFromLibrary(GetDsoHandle(), \
kName, &f); \
CHECK(s.ok()) << "could not find " << kName \
<< " in rocrand DSO; dlerror: " << s.error_message(); \
return reinterpret_cast<FuncPtrT>(f); \
} \
static FuncPtrT DynLoad() { \
static FuncPtrT f = LoadOrDie(); \
return f; \
} \
template <typename... Args> \
hiprandStatus operator()(GpuExecutor* parent, Args... args) { \
gpu::ScopedActivateExecutorContext sac{parent}; \
return DynLoad()(args...); \
} \
} __name; \
const char* DynLoadShim__##__name::kName = #__name;
#endif
// One wrapper object per hipRAND function called from this file.
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandCreateGenerator);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandDestroyGenerator);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetStream);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateUniform);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateUniformDouble);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetPseudoRandomGeneratorSeed);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandSetGeneratorOffset);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateNormal);
STREAM_EXECUTOR_HIPRAND_WRAP(hiprandGenerateNormalDouble);
} // namespace wrap
GpuRng::GpuRng(GpuExecutor* parent) : parent_(parent), rng_(nullptr) {}
// Destroys the hipRAND generator if one was created; otherwise does nothing.
GpuRng::~GpuRng() {
  if (rng_ == nullptr) {
    return;
  }
  wrap::hiprandDestroyGenerator(parent_, rng_);
}
bool GpuRng::Init() {
absl::MutexLock lock{&mu_};
CHECK(rng_ == nullptr);
hiprandStatus_t ret =
wrap::hiprandCreateGenerator(parent_, &rng_, HIPRAND_RNG_PSEUDO_DEFAULT);
if (ret != HIPRAND_STATUS_SUCCESS) {
LOG(ERROR) << "failed to create random number generator: " << ret;
return false;
}
CHECK(rng_ != nullptr);
return true;
}
// Binds the generator to `stream` so that subsequent hipRAND calls made
// through rng_ are issued on it. Logs and returns false if hipRAND rejects
// the request; returns true otherwise.
bool GpuRng::SetStream(Stream* stream) {
  const hiprandStatus_t status =
      wrap::hiprandSetStream(parent_, rng_, AsGpuStreamValue(stream));
  if (status == HIPRAND_STATUS_SUCCESS) {
    return true;
  }

  LOG(ERROR) << "failed to set stream for random generation: " << status;
  return false;
}
// Compile-time size check used to justify treating a std::complex<T> buffer
// as a buffer of twice as many T values: std::complex<int> and
// std::complex<float> must occupy 8 bytes, and std::complex<double> 16 bytes.
// int, float and double are all checked because the float and double versions
// are independent specializations.
constexpr bool ComplexIsConsecutiveFloats() {
  return (sizeof(std::complex<double>) == 16) &&
         (sizeof(std::complex<float>) == 8) &&
         (sizeof(std::complex<int>) == 8);
}
// Fills `v` with uniformly distributed values generated on `stream`.
//
// T is one of float, double, std::complex<float> or std::complex<double>.
// Complex buffers are filled as a run of twice as many scalars of the
// underlying precision. Holds mu_ for the whole operation.
//
// Returns false (after logging where applicable) if the stream cannot be set
// or if hipRAND reports a generation failure; true otherwise.
template <typename T>
bool GpuRng::DoPopulateRandUniformInternal(Stream* stream, DeviceMemory<T>* v) {
  absl::MutexLock lock{&mu_};
  static_assert(ComplexIsConsecutiveFloats(),
                "std::complex values are not stored as consecutive values");

  if (!SetStream(stream)) {
    return false;
  }

  const bool is_complex = std::is_same<T, std::complex<float>>::value ||
                          std::is_same<T, std::complex<double>>::value;
  const bool is_single_precision =
      std::is_same<T, float>::value ||
      std::is_same<T, std::complex<float>>::value;

  // std::complex<T> is currently implemented as two consecutive T variables,
  // so a complex buffer needs two scalar samples per element.
  const uint64 scalars_per_element = is_complex ? 2 : 1;
  const uint64 scalar_count = v->ElementCount() * scalars_per_element;

  const hiprandStatus_t status =
      is_single_precision
          ? wrap::hiprandGenerateUniform(
                parent_, rng_, reinterpret_cast<float*>(GpuMemoryMutable(v)),
                scalar_count)
          : wrap::hiprandGenerateUniformDouble(
                parent_, rng_, reinterpret_cast<double*>(GpuMemoryMutable(v)),
                scalar_count);
  if (status == HIPRAND_STATUS_SUCCESS) {
    return true;
  }

  LOG(ERROR) << "failed to do uniform generation of " << v->ElementCount()
             << " " << TypeString<T>() << "s at " << v->opaque() << ": "
             << status;
  return false;
}
// Public uniform-fill entry points. Each overload forwards to the shared
// templated implementation for its element type and returns its result.

// float buffer.
bool GpuRng::DoPopulateRandUniform(Stream* stream, DeviceMemory<float>* v) {
  return DoPopulateRandUniformInternal<float>(stream, v);
}

// double buffer.
bool GpuRng::DoPopulateRandUniform(Stream* stream, DeviceMemory<double>* v) {
  return DoPopulateRandUniformInternal<double>(stream, v);
}

// std::complex<float> buffer.
bool GpuRng::DoPopulateRandUniform(Stream* stream,
                                   DeviceMemory<std::complex<float>>* v) {
  return DoPopulateRandUniformInternal<std::complex<float>>(stream, v);
}

// std::complex<double> buffer.
bool GpuRng::DoPopulateRandUniform(Stream* stream,
                                   DeviceMemory<std::complex<double>>* v) {
  return DoPopulateRandUniformInternal<std::complex<double>>(stream, v);
}
// Fills `v` with normally distributed values generated on `stream`.
//
//   stream  stream the generator is bound to before generating.
//   mean    mean passed through to the hipRAND generation function.
//   stddev  standard deviation passed through to the same function.
//   v       device buffer to fill; its ElementCount() values are requested.
//   func    hipRAND wrapper to invoke, called as
//           func(parent_, rng_, ptr, count, mean, stddev).
//
// Holds mu_ for the whole operation. Returns false if the stream cannot be set
// or if `func` reports a failure (which is logged); true otherwise.
template <typename ElemT, typename FuncT>
bool GpuRng::DoPopulateRandGaussianInternal(Stream* stream, ElemT mean,
                                            ElemT stddev,
                                            DeviceMemory<ElemT>* v,
                                            FuncT func) {
  absl::MutexLock lock{&mu_};
  if (!SetStream(stream)) {
    return false;
  }
  uint64 element_count = v->ElementCount();
  hiprandStatus_t ret =
      func(parent_, rng_, GpuMemoryMutable(v), element_count, mean, stddev);
  if (ret != HIPRAND_STATUS_SUCCESS) {
    // Name the element type that was actually requested: this template is
    // instantiated for double as well as float, so a fixed "floats" would be
    // misleading in the double case.
    const char* element_kind =
        std::is_same<ElemT, double>::value ? "doubles" : "floats";
    LOG(ERROR) << "failed to do gaussian generation of " << v->ElementCount()
               << " " << element_kind << " at " << v->opaque() << ": " << ret;
    return false;
  }
  return true;
}
// Public Gaussian-fill entry points. Each overload pairs its element type with
// the matching hipRAND normal-distribution wrapper and forwards to the shared
// templated implementation.

// Single precision: uses hiprandGenerateNormal.
bool GpuRng::DoPopulateRandGaussian(Stream* stream, float mean, float stddev,
                                    DeviceMemory<float>* v) {
  return DoPopulateRandGaussianInternal<float>(stream, mean, stddev, v,
                                               wrap::hiprandGenerateNormal);
}

// Double precision: uses hiprandGenerateNormalDouble.
bool GpuRng::DoPopulateRandGaussian(Stream* stream, double mean, double stddev,
                                    DeviceMemory<double>* v) {
  return DoPopulateRandGaussianInternal<double>(
      stream, mean, stddev, v, wrap::hiprandGenerateNormalDouble);
}
// Seeds the generator from the caller-supplied byte buffer and resets the
// generator offset to 0.
//
//   stream      stream the generator is bound to before seeding.
//   seed        pointer to the seed bytes; the first 8 bytes are used.
//   seed_bytes  number of bytes available at `seed`, validated by CheckSeed.
//
// Holds mu_ for the whole operation and CHECK-fails if Init() has not created
// the generator. Returns false if the seed fails validation, the stream cannot
// be set, or either hipRAND call fails (failures are logged); true otherwise.
bool GpuRng::SetSeed(Stream* stream, const uint8* seed, uint64 seed_bytes) {
  absl::MutexLock lock{&mu_};
  CHECK(rng_ != nullptr);
  if (!CheckSeed(seed, seed_bytes)) {
    return false;
  }
  if (!SetStream(stream)) {
    return false;
  }
  // Requires 8 bytes of seed data; checked in RngSupport::CheckSeed (above)
  // (which itself requires 16 for API consistency with host RNG fallbacks).
  //
  // Copy the bytes out instead of dereferencing a casted pointer: `seed` is a
  // byte buffer with no alignment guarantee for uint64, and reading it through
  // a uint64 lvalue is not a permitted alias.
  uint64 seed_value = 0;
  std::memcpy(&seed_value, seed, sizeof(seed_value));
  hiprandStatus_t ret = wrap::hiprandSetPseudoRandomGeneratorSeed(
      parent_, rng_, seed_value);
  if (ret != HIPRAND_STATUS_SUCCESS) {
    LOG(ERROR) << "failed to set rng seed: " << ret;
    return false;
  }
  // Restart the sequence from the beginning for the new seed.
  ret = wrap::hiprandSetGeneratorOffset(parent_, rng_, 0);
  if (ret != HIPRAND_STATUS_SUCCESS) {
    LOG(ERROR) << "failed to reset rng position: " << ret;
    return false;
  }
  return true;
}
} // namespace gpu
void initialize_rocrand() {
auto rocRandAlreadyRegistered = PluginRegistry::Instance()->HasFactory(
rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin);
if (!rocRandAlreadyRegistered) {
port::Status status =
PluginRegistry::Instance()->RegisterFactory<PluginRegistry::RngFactory>(
rocm::kROCmPlatformId, gpu::kGpuRandPlugin, "rocRAND",
[](internal::StreamExecutorInterface* parent) -> rng::RngSupport* {
gpu::GpuExecutor* rocm_executor =
dynamic_cast<gpu::GpuExecutor*>(parent);
if (rocm_executor == nullptr) {
LOG(ERROR)
<< "Attempting to initialize an instance of the hipRAND "
<< "support library with a non-ROCM StreamExecutor";
return nullptr;
}
gpu::GpuRng* rng = new gpu::GpuRng(rocm_executor);
if (!rng->Init()) {
// Note: Init() will log a more specific error.
delete rng;
return nullptr;
}
return rng;
});
if (!status.ok()) {
LOG(ERROR) << "Unable to register rocRAND factory: "
<< status.error_message();
}
PluginRegistry::Instance()->SetDefaultFactory(
rocm::kROCmPlatformId, PluginKind::kRng, gpu::kGpuRandPlugin);
}
}
} // namespace stream_executor
// Declares a module initializer named `register_rocrand` whose body calls
// stream_executor::initialize_rocrand().
// NOTE(review): presumably this runs once during static initialization —
// confirm against the REGISTER_MODULE_INITIALIZER definition in initialize.h.
REGISTER_MODULE_INITIALIZER(register_rocrand,
{ stream_executor::initialize_rocrand(); });