import logging
import random
import weakref
import functorch
import torch
from torch import _prims
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
from torch.overrides import TorchFunctionMode
from . import config
from .utils import decode_device, is_cpu_device

log = logging.getLogger(__name__)


class AutogradMonkeypatch(TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
if not kwargs:
kwargs = {}
        return replace_fn(func, is_cpu_device(args))(*args, **kwargs)


patch_functions = AutogradMonkeypatch
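
# Illustrative sketch (not part of the original module): instantiating the
# TorchFunctionMode above routes eligible calls through replace_fn(), so a
# plain F.dropout call is rewritten on the fly. Assumes config.lowmem_dropout
# is enabled, config.fallback_random is disabled, and a non-CPU input.
#
#   x = torch.randn(8, 8, device="cuda", requires_grad=True)
#   with patch_functions():
#       y = torch.nn.functional.dropout(x, p=0.5)
#   # y was produced by lowmem_dropout rather than the eager dropout kernel.
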
def replace_fx(gm: torch.fx.GraphModule, example_inputs):
# Sometimes patch_functions() misses things already in the graph
changed = 0
is_cpu = is_cpu_device(example_inputs)
for node in reversed(list(gm.graph.nodes)):
if (
node.op == "call_function"
and replace_fn(node.target, is_cpu) is not node.target
):
with gm.graph.inserting_before(node):
node.replace_all_uses_with(
gm.graph.call_function(
replace_fn(node.target, is_cpu), node.args, node.kwargs
)
)
gm.graph.erase_node(node)
changed += 1
if changed:
gm.graph.lint()
gm.recompile()
return gm
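
# Illustrative sketch (not part of the original module): replace_fx() rewrites
# call_function nodes that patch_functions() did not intercept at trace time.
# The device check matters because the dropout replacement is skipped on CPU.
#
#   def f(x):
#       return torch.nn.functional.dropout(x, 0.1)
#
#   gm = torch.fx.symbolic_trace(f)
#   gm = replace_fx(gm, (torch.randn(8, device="cuda"),))
#   # dropout call_function nodes now target lowmem_dropout.
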
def _philox_rand_like_meta(input, seed, offset):
    return _prims.TensorMeta(input)


def _philox_rand_like(input, seed, offset):
# placeholder only used in tracing
    return torch.rand_like(input)


philox_rand_like = _prims._make_prim(
schema="philox_rand_like(Tensor input, Tensor seed, SymInt offset) -> Tensor",
return_type=_prims.RETURN_TYPE.NEW,
meta=_philox_rand_like_meta,
impl_aten=_philox_rand_like,
doc="",
)


def _philox_seed_like_meta(x):
    return _prims.TensorMeta(_philox_seed_like(x))


def _philox_seed_like(x):
    # We need a tensor input here so AOT Autograd properly captures this op;
    # with just a device input, it would be baked in as a constant.
    return torch.tensor(random.randrange(2**31), device=x.device, dtype=torch.int32)


philox_seed_like = _prims._make_prim(
schema="philox_seed_like(Tensor other) -> Tensor",
return_type=_prims.RETURN_TYPE.NEW,
meta=_philox_seed_like_meta,
impl_aten=_philox_seed_like,
doc="",
)
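
# Illustrative note (not part of the original module): Philox is a
# counter-based RNG, so the (seed, offset) pair plus the input's metadata
# fully determines the sample. That determinism is what lets LowmemDropout
# below regenerate its dropout mask in backward() instead of saving it:
#
#   seed = philox_seed_like(x)
#   u1 = philox_rand_like(x, seed, 0)
#   u2 = philox_rand_like(x, seed, 0)
#   # u1 and u2 are bitwise identical.
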
def null_ref():
    return None


class PhiloxRandomState:
next_offset = 0
seed = {}
    last_tracer_ref = null_ref

    @classmethod
def reset(cls, tracer=None):
cls.next_offset = 0
cls.seed = {}
        cls.last_tracer_ref = weakref.ref(tracer) if tracer is not None else null_ref

    @classmethod
def get_seed_offset(cls, x, device=None):
modes = torch.fx.experimental.proxy_tensor.get_torch_dispatch_modes()
proxy_modes = [m for m in modes if isinstance(m, ProxyTorchDispatchMode)]
if proxy_modes:
tracer = proxy_modes[0].tracer
if cls.last_tracer_ref() is not tracer:
# tracer changed, need to reset state
cls.reset(tracer)
else:
# no tracer, need to reset state
cls.reset()
if device is None:
device = x.device
device = decode_device(device)
if device not in cls.seed:
# Compute the seed just once per trace so that we pass fewer
# things from forward to backward
cls.seed[device] = philox_seed_like(x)
seed = cls.seed[device]
offset = cls.next_offset
cls.next_offset += x.numel()
return seed, offset
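
# Illustrative sketch (not part of the original module): consecutive draws on
# one device share a single seed tensor and receive disjoint Philox offsets,
# advanced by numel() so random streams never overlap within a trace:
#
#   x = torch.randn(4, 4, device="cuda")
#   s0, o0 = PhiloxRandomState.get_seed_offset(x)  # o0 == 0
#   s1, o1 = PhiloxRandomState.get_seed_offset(x)  # s1 is s0, o1 == 16
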
class LowmemDropout(torch.autograd.Function):
@staticmethod
def forward(ctx, x, p):
ctx.p = p
scale = float(0.0) if p == 1.0 else float(1.0 / (1.0 - p))
seed, offset = PhiloxRandomState.get_seed_offset(x)
ctx.save_for_backward(seed)
ctx.offset = offset
bool_mask = philox_rand_like(x, seed, offset) > p
        return bool_mask.to(x.dtype) * x * scale

    @staticmethod
def backward(ctx, grad_output):
p = ctx.p
scale = float(0.0) if p == 1.0 else float(1.0 / (1.0 - p))
(seed,) = ctx.saved_tensors
bool_mask = philox_rand_like(grad_output, seed, ctx.offset) > p
return bool_mask.to(grad_output.dtype) * grad_output * scale, None
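
# Illustrative note (not part of the original module): eager dropout has to
# keep its mask alive for backward, costing one mask element per input
# element. LowmemDropout saves only the scalar Philox seed tensor plus an
# integer offset and regenerates the mask in backward(), trading a little
# recompute for activation memory.
#
#   x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
#   y = LowmemDropout.apply(x, 0.5)
#   y.sum().backward()  # mask is rebuilt from (seed, offset), not stored
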
@torch.fx.wrap
def lowmem_dropout(input, p=0.5, training=True, inplace=False):
if isinstance(input, torch.fx.Proxy):
        # extra guard so FX doesn't trace into this function; record a single
        # call_function node instead
return input.tracer.create_proxy(
"call_function",
lowmem_dropout,
(input, p, training),
{},
)
if not training or p == 0:
return input
result = LowmemDropout.apply(input, p)
if inplace:
input.copy_(result)
    return result


@torch.fx.wrap
def rand_like(x, **kwargs):
if isinstance(x, torch.fx.Proxy):
        # extra guard so FX doesn't trace into this function; record a single
        # call_function node instead
        return x.tracer.create_proxy("call_function", rand_like, (x,), kwargs)
device = kwargs.get("device", x.device)
seed, offset = PhiloxRandomState.get_seed_offset(x, device)
return philox_rand_like(x.to(device), seed, offset).to(
kwargs.get("dtype", torch.float32)
    )


def replace_fn(fn, is_cpu):
"""
Perform any applicable replacements on `fn`
"""
if config.fallback_random:
return fn
if config.lowmem_dropout and fn is torch.nn.functional.dropout and not is_cpu:
return lowmem_dropout
replacements = {}
# TODO: Revisit the functionalize_rng_ops for lowmem dropout
if not functorch.compile.config.functionalize_rng_ops:
replacements.update({torch.rand_like: rand_like})
return replacements.get(fn, fn)
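
# Illustrative sketch (not part of the original module): assuming
# config.lowmem_dropout is enabled and config.fallback_random is disabled,
# replace_fn() maps dropout on non-CPU inputs to the low-memory variant and
# leaves everything else untouched:
#
#   assert replace_fn(torch.nn.functional.dropout, is_cpu=False) is lowmem_dropout
#   assert replace_fn(torch.nn.functional.dropout, is_cpu=True) is torch.nn.functional.dropout
#   assert replace_fn(torch.add, is_cpu=False) is torch.add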