blob: 3baccd649f31c3e96a247bab8c83752e6a122289 [file] [log] [blame]
from typing import Callable
import torch
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import make_fx
from torch._prims.utils import getnvFuserDtype, Number
from torch._prims.context import TorchRefsMode
import torch.overrides
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
if torch.cuda.is_available():
from torch._C._nvfuser import Fusion, FusionDefinition # type: ignore[import]
def execute(gm: GraphModule, *args, executor: str = "aten"):
    """
    Prototype trace executor: run a traced GraphModule on the requested backend.

    Args:
        gm: the traced graph to execute.
        *args: positional arguments forwarded to the graph. ``make_traced``
            passes a single packed list here. For the nvFuser path, every
            tensor leaf of ``args`` (at any pytree nesting depth) becomes a
            fusion input.
        executor: ``"aten"`` calls ``gm.forward`` directly. ``"nvfuser"``
            lowers every ``call_function`` node to its ``impl_nvfuser``
            implementation, builds an nvFuser fusion from the result and
            runs it.

    Returns:
        Whatever the graph returns. For ``"nvfuser"`` the fusion outputs are
        re-packed into the same pytree structure the graph produced.

    Raises:
        RuntimeError: if ``"nvfuser"`` is requested but CUDA is unavailable.
        ValueError: if ``executor`` is not one of the supported values.
    """
    if executor == "aten":
        return gm.forward(*args)

    if executor == "nvfuser":
        if not torch.cuda.is_available():
            raise RuntimeError(
                "Attempting to use nvFuser trace executor but CUDA is not available!"
            )

        # PROTOTYPE nvfuser executor
        # Everything in the graph must support nvfuser
        fusion = Fusion()
        with FusionDefinition(fusion) as fd:

            def _to_nvfuser_constant(arg):
                # Python scalars appearing as op arguments are turned into
                # fusion constants; everything else passes through unchanged.
                if isinstance(arg, Number):
                    return fd.define_constant(arg)
                return arg

            class FusionInterpreter(torch.fx.Interpreter):
                def call_function(self, target, args, kwargs):
                    # Dispatch each node to the prim's nvFuser lowering, which
                    # receives the FusionDefinition as its first argument.
                    args = tuple(map(_to_nvfuser_constant, args))
                    return target.impl_nvfuser(fd, *args, **kwargs)

            def to_nv(arg):
                # Each tensor leaf is declared (size, stride, dtype) and
                # registered as a fusion input; non-tensors pass through.
                if isinstance(arg, torch.Tensor):
                    x = fd.define_tensor(
                        arg.size(), arg.stride(), getnvFuserDtype(arg.dtype)
                    )
                    fd.add_input(x)
                    return x
                return arg

            # Transforms graph to call nvfuser lowerings
            nv_args = tree_map(to_nv, args)
            out = FusionInterpreter(gm).run(*nv_args)
            flat_out, unflatten_spec = tree_flatten(out)
            for o in flat_out:
                fd.add_output(o)

        # Feed the runtime tensors in the same order the fusion inputs were
        # declared above. tree_map visits leaves in tree_flatten order, so
        # flattening the whole of `args` here keeps nested tensors aligned
        # with their inputs. Scanning only the top level of a packed list
        # would drop or misorder nested tensors.
        flat_args, _ = tree_flatten(args)
        tensor_args = tuple(a for a in flat_args if isinstance(a, torch.Tensor))
        return tree_unflatten(fusion.execute(tensor_args), unflatten_spec)

    raise ValueError(
        f"Received unexpected value for 'executor': {executor}. "
        "Allowed values are: aten, nvfuser."
    )
def make_traced(fn: Callable):
    """
    Wrap ``fn`` so that each call traces it to prims and runs the trace.

    The returned callable accepts the same positional and keyword arguments
    as ``fn``, plus a keyword-only ``executor`` ('aten' or 'nvfuser'). On
    every call it:

    1. packs the positional args followed by the keyword-arg values into a
       single list,
    2. traces ``fn`` over that list with ``make_fx`` while ``TorchRefsMode``
       is active, so torch calls are redirected to their reference/prim
       implementations, and
    3. hands the resulting graph and the packed list to ``execute`` on the
       chosen executor.

    Restrictions of this prototype: only the torch operations covered by
    ``_torch_to_reference_map`` in context.py are supported, those operations
    must be called with positional args, and all args must be tensors. These
    restrictions are expected to be lifted in the near future.

    Example::

        def foo(a, b):
            return torch.add(a, b)

        traced_foo = make_traced(foo)
        a = torch.randn((1, 2, 3, 4, 5), device='cuda')
        b = torch.randn((1, 2, 3, 4, 5), device='cuda')
        result = traced_foo(a, b, executor='nvfuser')
    """

    def _traced(*args, executor="aten", **kwargs):
        # TODO: caching
        num_positional = len(args)
        kwarg_names = tuple(kwargs)
        # One flat list: positional args first, then keyword values in the
        # same (insertion) order as kwarg_names.
        packed = [*args, *kwargs.values()]

        def wrapped(packed_args):
            # Split the packed list back into fn's positional/keyword form.
            positional = packed_args[:num_positional]
            keyword = dict(zip(kwarg_names, packed_args[num_positional:]))
            return fn(*positional, **keyword)

        with TorchRefsMode.push():
            gm = make_fx(wrapped)(packed)
        return execute(gm, packed, executor=executor)

    return _traced