// blob: a19d1dcdccd863ef5e54f2312843b447dae483d1 [file] [log] [blame]
#include <ATen/Context.h>
#include <torch/xpu.h>
namespace torch::xpu {
/// Reports the number of XPU devices, as returned by the registered XPU
/// hooks' `getNumGPUs()`.
size_t device_count() {
  const auto& xpu_hooks = at::detail::getXPUHooks();
  return xpu_hooks.getNumGPUs();
}
/// Returns true when `device_count()` reports at least one XPU device.
bool is_available() {
  const size_t visible_devices = xpu::device_count();
  return visible_devices != 0;
}
/// Seeds the default XPU generator of the device the hooks report as current.
///
/// Silently does nothing when `is_available()` is false.
///
/// @param seed Value passed to the generator's `set_current_seed`.
void manual_seed(uint64_t seed) {
  if (!is_available()) {
    return;
  }

  const auto& xpu_hooks = at::detail::getXPUHooks();
  auto current = xpu_hooks.current_device();
  auto generator = xpu_hooks.getDefaultXPUGenerator(current);

  // See Note [Acquire lock when using random generators]
  std::lock_guard<std::mutex> guard(generator.mutex());
  generator.set_current_seed(seed);
}
/// Sets the seed for all available GPUs.
void manual_seed_all(uint64_t seed) {
auto num_gpu = device_count();
for (const auto i : c10::irange(num_gpu)) {
auto gen = at::detail::getXPUHooks().getDefaultXPUGenerator(i);
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());
gen.set_current_seed(seed);
}
}
}
void synchronize(int64_t device_index) {
TORCH_CHECK(is_available(), "No XPU are available");
at::detail::getXPUHooks().deviceSynchronize(
static_cast<c10::DeviceIndex>(device_index));
}
} // namespace torch::xpu