# blob: e1cf8d4362883ef508d6cf05d26c167e7f9f33d2 [file] [log] [blame]
from typing import Optional, Tuple
import torch
from torch import _prims
from torch._prims_common import CUDARngStateHelper, make_contiguous_strides_for
from torch._prims_common.wrappers import backwards_not_supported
from torch.types import _device, _dtype
# Operator namespace under which the RNG prims in this file are registered.
rngprim_namespace = "rngprims"
# "DEF" library handle: operator schemas are declared through it
# (see the `rngprim.define(schema)` call in `register_philox_rand`).
rngprim = torch.library.Library(rngprim_namespace, "DEF")
# "IMPL" library handles for the same namespace, one per dispatch key.
# Kernels for the declared operators are attached through these via `.impl(...)`.
rngprim_impl = torch.library.Library(
rngprim_namespace, "IMPL", "CompositeExplicitAutograd"
)
rngprim_autograd_impl = torch.library.Library(rngprim_namespace, "IMPL", "Autograd")
rngprim_meta_impl = torch.library.Library(rngprim_namespace, "IMPL", "Meta")
def throw_on_non_cuda(device):
    """Raise a ``RuntimeError`` stating that RNG functionalization is unsupported.

    Args:
        device: Object exposing a ``type`` attribute (for example a
            ``torch.device``); its value is interpolated into the message.

    Raises:
        RuntimeError: Always. This function never returns normally.
    """
    kind = device.type
    message = (
        f"You are trying to functionalize a {kind} RNG operator but {kind} does not "
        f"use Philox/counter-based RNG. Therefore, functionalizing a {kind} RNG operator is "
        "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU."
    )
    raise RuntimeError(message)
def register_philox_rand():
    """Define the ``philox_rand`` operator in the RNG prim namespace and wire up its kernels.

    The steps, in order:

    1. Declare the operator schema on ``rngprim``.
    2. Register the eager kernel on ``rngprim_impl`` and the meta kernel on
       ``rngprim_meta_impl``.
    3. Tag the ``default`` overload as ``nondeterministic_seeded`` and register
       a ``backwards_not_supported`` wrapper of it on ``rngprim_autograd_impl``.
    4. Attach ``__doc__``, ``return_type``, ``schema``, ``prim_meta_impl`` and
       ``impl_aten`` to both the operator packet and its default overload.
    """
    name = "philox_rand"
    schema = "philox_rand(int[] size, Tensor seed, Tensor offset, int[]? stride, Device? device=None, ScalarType? dtype=None) -> Tensor"  # noqa: B950

    rngprim.define(schema)

    def _philox_rand_meta(
        shape: torch.Size,
        seed: torch.Tensor,
        offset: torch.Tensor,
        stride: Optional[Tuple[int, ...]],
        device: _device,
        dtype: _dtype,
    ):
        """Meta kernel: describe the output tensor without generating values."""
        # The stride argument is reserved for the distributed use case; it is
        # currently unused and must be None. The result is always contiguous.
        assert stride is None
        return _prims.TensorMeta(
            shape=shape,
            strides=make_contiguous_strides_for(shape),
            dtype=dtype,
            device=device,
        )

    def _philox_rand(
        shape: torch.Size,
        seed: torch.Tensor,
        offset: torch.Tensor,
        stride: Optional[Tuple[int, ...]],
        device: _device,
        dtype: _dtype,
    ):
        """Eager kernel: run ``torch.rand`` under a forked RNG state set from ``seed``/``offset``."""
        # The stride argument is reserved for the distributed use case; it is
        # currently unused and must be None.
        assert stride is None

        # Every non-CUDA device is rejected here, so only CUDA devices reach
        # the RNG fork below.
        if device.type != "cuda":
            raise throw_on_non_cuda(device)

        # fork_rng restores the forked devices' RNG state on exit, so setting
        # the state from (seed, offset) does not leak outside this block.
        with torch.random.fork_rng([device]):
            CUDARngStateHelper.set_torch_state_tensor(seed, offset)
            sample = torch.rand(shape, device=device, dtype=dtype)
        return sample

    rngprim_impl.impl(name, _philox_rand)
    rngprim_meta_impl.impl(name, _philox_rand_meta)

    packet = getattr(torch._ops.ops.rngprims, name)
    overload = packet.default
    overload._tags = (torch.Tag.nondeterministic_seeded,)  # type: ignore[attr-defined]

    rngprim_autograd_impl.impl(name, backwards_not_supported(overload))

    # Attributes attached to both the packet and the default overload,
    # assigned in this order.
    prim_attrs = {
        "__doc__": "Philox based stateless rand operator",
        "return_type": torch._prims_common.RETURN_TYPE.NEW,
        "schema": schema,
        "prim_meta_impl": _philox_rand_meta,
        "impl_aten": _philox_rand,
    }
    for target in (packet, overload):
        for attr, value in prim_attrs.items():
            setattr(target, attr, value)
def register_rng_prims():
    """Register the RNG prims.

    Currently this only delegates to ``register_philox_rand``.
    """
    register_philox_rand()