Support a few corner cases for nvFuser executor (#84416)
This PR adds asserts to the `nvfuser_execute` function for the cases that do not work. Fallback to eager is used in those cases.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84416
Approved by: https://github.com/jjsjann123, https://github.com/ngimel
diff --git a/test/test_prims.py b/test/test_prims.py
index c3cb339..1d1b08c 100644
--- a/test/test_prims.py
+++ b/test/test_prims.py
@@ -7,7 +7,7 @@
import torch
from torch.testing import make_tensor
-from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_SCIPY
+from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_SCIPY, skipCUDAMemoryLeakCheckIf
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCUDA,
@@ -159,6 +159,75 @@
@onlyCUDA
@skipCUDAIfRocm
+ def test_nvfuser_empty_fusion(self, device):
+ from torch.fx.experimental.proxy_tensor import make_fx
+ from torch._prims.executor import execute
+
+ a = torch.randn(3, 3, device=device)
+
+ def func(a, b, c):
+ return (a, b, c)
+
+ gm = make_fx(func)(a, a, a)
+
+ with self.assertRaisesRegex(AssertionError, "Graph must contain at least one call_function node"):
+ execute(gm, a, a, a, executor="strictly_nvfuser")
+
+ # Should pass with partitioned executor
+ out = execute(gm, a, a, a, executor="nvfuser")
+ self.assertEqual(out, (a, a, a))
+
+ @skipCUDAMemoryLeakCheckIf(True) # https://github.com/pytorch/pytorch/issues/84529
+ @onlyCUDA
+ @skipCUDAIfRocm
+ def test_nvfuser_no_args(self, device):
+ from torch._prims.context import TorchRefsNvfuserCapabilityMode
+ from torch.fx.experimental.proxy_tensor import make_fx
+ from torch._prims.executor import execute
+
+ a = torch.randn(3, 3, device=device)
+
+ def func():
+ return torch.sigmoid(a)
+
+ with TorchRefsNvfuserCapabilityMode():
+ gm = make_fx(func)()
+
+ with self.assertRaisesRegex(AssertionError, "There must be at least one argument"):
+ execute(gm, executor="strictly_nvfuser")
+
+ with self.assertRaisesRegex(AssertionError, "Number of placeholder nodes in the graph must match"):
+ execute(gm, a, executor="strictly_nvfuser")
+
+ # Should pass with partitioned executor
+ out = execute(gm, executor="nvfuser")
+ self.assertEqual(out, func())
+
+ @onlyCUDA
+ @skipCUDAIfRocm
+ def test_nvfuser_constant_tensors(self, device):
+ from torch._prims.context import TorchRefsNvfuserCapabilityMode
+ from torch.fx.experimental.proxy_tensor import make_fx
+ from torch._prims.executor import execute
+
+ a = torch.randn(3, 3, device=device)
+ b = torch.randn(3, 3, device=device)
+
+ def func(b):
+ return a + b
+
+ with TorchRefsNvfuserCapabilityMode():
+ gm = make_fx(func)(b)
+
+ with self.assertRaisesRegex(AssertionError, "not supported yet"):
+ execute(gm, b, executor="strictly_nvfuser")
+
+ # Should pass with partitioned executor
+ out = execute(gm, b, executor="nvfuser")
+ self.assertEqual(out, gm(b))
+
+ @onlyCUDA
+ @skipCUDAIfRocm
def test_nvfuser_executor_cached_noncontiguous(self, device):
# This test is to ensure that nvfuser computes correct results for noncontiguous tensors
from torch.fx.experimental.proxy_tensor import make_fx
diff --git a/torch/_prims/nvfuser_executor.py b/torch/_prims/nvfuser_executor.py
index 7a8f032..72bc71e 100644
--- a/torch/_prims/nvfuser_executor.py
+++ b/torch/_prims/nvfuser_executor.py
@@ -51,10 +51,26 @@
return tree_map(to_nvfuser, args)
+def _any_get_attr_used(call_function_nodes):
+ return any(
+ filter(
+ # bug in mypy https://github.com/python/mypy/issues/12682
+ lambda n: any( # type: ignore[arg-type]
+ a.op == "get_attr" for a in n.args if isinstance(a, torch.fx.Node) # type: ignore[attr-defined]
+ ),
+ call_function_nodes,
+ )
+ )
+
+
# MyPy bug: https://github.com/python/mypy/issues/5107
@lru_cache(maxsize=1024) # type: ignore[arg-type]
def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
- # PROTOTYPE nvfuser executor
+ if not torch.cuda.is_available():
+ raise RuntimeError(
+ "Attempting to use nvFuser trace executor but CUDA is not available!"
+ )
+
# Everything in the graph must support nvfuser
for node in gm.graph.nodes:
if node.op == "call_function" and "getitem" in node.name:
@@ -68,6 +84,21 @@
f"Node {node} with target {node.target} does not support nvfuser"
)
+ graph_input_nodes = list(filter(lambda n: n.op == "placeholder", gm.graph.nodes))
+ call_function_nodes = list(
+ filter(lambda n: n.op == "call_function", gm.graph.nodes)
+ )
+ assert len(graph_input_nodes) == len(
+ nv_args_templates
+ ), "Number of placeholder nodes in the graph must match number of args"
+ assert len(nv_args_templates) > 0, "There must be at least one argument"
+ assert (
+ len(call_function_nodes) > 0
+ ), "Graph must contain at least one call_function node"
+ assert not _any_get_attr_used(
+ call_function_nodes
+ ), "Constant tensors that are saved in the graph and used as arguments are not supported yet"
+
fusion = Fusion()
with FusionDefinition(fusion) as fd:
@@ -122,11 +153,6 @@
def nvfuser_execute(gm: GraphModule, *args):
- if not torch.cuda.is_available():
- raise RuntimeError(
- "Attempting to use nvFuser trace executor but CUDA is not available!"
- )
-
flat_args, _ = tree_flatten(args)
# Construction of the fusion is expensive and cached based on the GraphModule
@@ -180,11 +206,25 @@
@lru_cache() # type: ignore[arg-type]
def maybe_partition_graph(gm: GraphModule):
supported_ops = NvfuserPrimOperatorSupport()
- call_function_nodes = filter(lambda n: n.op == "call_function", gm.graph.nodes)
+ call_function_nodes = list(
+ filter(lambda n: n.op == "call_function", gm.graph.nodes)
+ )
# the graph is partitioned only if at least one node is not supported by nvFuser
any_unsupported = any(
not supported_ops.is_node_supported(None, node) for node in call_function_nodes
)
+ any_unsupported |= len(call_function_nodes) == 0
+
+ # When there are constant tensors in the graph, we can't partition it
+ # because deepcopy fails. Here we just return the original graph to be
+ # executed by eager mode
+ # https://github.com/pytorch/pytorch/issues/84415
+ if (
+ _any_get_attr_used(call_function_nodes)
+ or len(list(filter(lambda n: n.op == "placeholder", gm.graph.nodes))) == 0
+ ):
+ return gm, True
+
if any_unsupported:
# CapabilityBasedPartitioner modifies the graph in-place so we need to make a copy of the graph
gm = deepcopy(gm)