blob: 4926214b349184e96567db53064d21251dc24786 [file] [log] [blame]
#include <ATen/Context.h>
#include <torch/mps.h>
namespace torch {
namespace mps {
bool is_available() {
return at::detail::getMPSHooks().hasMPS();
}
/// Sets the seed for the MPS's default generator.
void manual_seed(uint64_t seed) {
if (is_available()) {
auto gen = at::detail::getMPSHooks().getDefaultMPSGenerator();
{
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen.mutex());
gen.set_current_seed(seed);
}
}
}
void synchronize() {
at::detail::getMPSHooks().deviceSynchronize();
}
void commit() {
at::detail::getMPSHooks().commitStream();
}
MTLCommandBuffer_t get_command_buffer() {
return at::detail::getMPSHooks().getCommandBuffer();
}
DispatchQueue_t get_dispatch_queue() {
return at::detail::getMPSHooks().getDispatchQueue();
}
} // namespace mps
} // namespace torch