Support a few corner cases for nvFuser executor (#84416)

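This PR adds asserts to the `nvfuser_execute` path (they live in `make_nvfuser_fusion`) for the cases that do not work: a graph with no arguments, a graph with no `call_function` nodes, and a graph that uses constant tensors as arguments. The partitioned `nvfuser` executor falls back to eager mode in those cases.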
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84416
Approved by: https://github.com/jjsjann123, https://github.com/ngimel
diff --git a/test/test_prims.py b/test/test_prims.py
index c3cb339..1d1b08c 100644
--- a/test/test_prims.py
+++ b/test/test_prims.py
@@ -7,7 +7,7 @@
 
 import torch
 from torch.testing import make_tensor
-from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_SCIPY
+from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_SCIPY, skipCUDAMemoryLeakCheckIf
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     onlyCUDA,
@@ -159,6 +159,75 @@
 
     @onlyCUDA
     @skipCUDAIfRocm
+    def test_nvfuser_empty_fusion(self, device):
+        from torch.fx.experimental.proxy_tensor import make_fx
+        from torch._prims.executor import execute
+
+        a = torch.randn(3, 3, device=device)
+
+        def func(a, b, c):
+            return (a, b, c)
+
+        gm = make_fx(func)(a, a, a)
+
+        with self.assertRaisesRegex(AssertionError, "Graph must contain at least one call_function node"):
+            execute(gm, a, a, a, executor="strictly_nvfuser")
+
+        # The partitioned "nvfuser" executor must accept a graph that only returns its inputs
+        out = execute(gm, a, a, a, executor="nvfuser")
+        self.assertEqual(out, (a, a, a))
+
+    @skipCUDAMemoryLeakCheckIf(True)  # https://github.com/pytorch/pytorch/issues/84529
+    @onlyCUDA
+    @skipCUDAIfRocm
+    def test_nvfuser_no_args(self, device):
+        from torch._prims.context import TorchRefsNvfuserCapabilityMode
+        from torch.fx.experimental.proxy_tensor import make_fx
+        from torch._prims.executor import execute
+
+        a = torch.randn(3, 3, device=device)
+
+        def func():
+            return torch.sigmoid(a)
+
+        with TorchRefsNvfuserCapabilityMode():
+            gm = make_fx(func)()
+
+        with self.assertRaisesRegex(AssertionError, "There must be at least one argument"):
+            execute(gm, executor="strictly_nvfuser")
+
+        with self.assertRaisesRegex(AssertionError, "Number of placeholder nodes in the graph must match"):
+            execute(gm, a, executor="strictly_nvfuser")
+
+        # The partitioned "nvfuser" executor must accept a graph that takes no inputs
+        out = execute(gm, executor="nvfuser")
+        self.assertEqual(out, func())
+
+    @onlyCUDA
+    @skipCUDAIfRocm
+    def test_nvfuser_constant_tensors(self, device):
+        from torch._prims.context import TorchRefsNvfuserCapabilityMode
+        from torch.fx.experimental.proxy_tensor import make_fx
+        from torch._prims.executor import execute
+
+        a = torch.randn(3, 3, device=device)
+        b = torch.randn(3, 3, device=device)
+
+        def func(b):
+            return a + b
+
+        with TorchRefsNvfuserCapabilityMode():
+            gm = make_fx(func)(b)
+
+        with self.assertRaisesRegex(AssertionError, "not supported yet"):
+            execute(gm, b, executor="strictly_nvfuser")
+
+        # The partitioned "nvfuser" executor must accept a graph that uses the captured tensor `a`
+        out = execute(gm, b, executor="nvfuser")
+        self.assertEqual(out, gm(b))
+
+    @onlyCUDA
+    @skipCUDAIfRocm
     def test_nvfuser_executor_cached_noncontiguous(self, device):
         # This test is to ensure that nvfuser computes correct results for noncontiguous tensors
         from torch.fx.experimental.proxy_tensor import make_fx
diff --git a/torch/_prims/nvfuser_executor.py b/torch/_prims/nvfuser_executor.py
index 7a8f032..72bc71e 100644
--- a/torch/_prims/nvfuser_executor.py
+++ b/torch/_prims/nvfuser_executor.py
@@ -51,10 +51,26 @@
     return tree_map(to_nvfuser, args)
 
 
+def _any_get_attr_used(call_function_nodes):
+    return any(
+        filter(
+            # the ignores below work around https://github.com/python/mypy/issues/12682
+            lambda n: any(  # type: ignore[arg-type]
+                a.op == "get_attr" for a in n.args if isinstance(a, torch.fx.Node)  # type: ignore[attr-defined]
+            ),
+            call_function_nodes,
+        )
+    )
+
+
 # MyPy bug: https://github.com/python/mypy/issues/5107
 @lru_cache(maxsize=1024)  # type: ignore[arg-type]
 def make_nvfuser_fusion(gm: GraphModule, *nv_args_templates):
-    # PROTOTYPE nvfuser executor
+    if not torch.cuda.is_available():
+        raise RuntimeError(
+            "Attempting to use nvFuser trace executor but CUDA is not available!"
+        )
+
     # Everything in the graph must support nvfuser
     for node in gm.graph.nodes:
         if node.op == "call_function" and "getitem" in node.name:
@@ -68,6 +84,21 @@
                 f"Node {node} with target {node.target} does not support nvfuser"
             )
 
+    graph_input_nodes = list(filter(lambda n: n.op == "placeholder", gm.graph.nodes))
+    call_function_nodes = list(
+        filter(lambda n: n.op == "call_function", gm.graph.nodes)
+    )
+    assert len(graph_input_nodes) == len(
+        nv_args_templates
+    ), "Number of placeholder nodes in the graph must match number of args"
+    assert len(nv_args_templates) > 0, "There must be at least one argument"
+    assert (
+        len(call_function_nodes) > 0
+    ), "Graph must contain at least one call_function node"
+    assert not _any_get_attr_used(
+        call_function_nodes
+    ), "Constant tensors that are saved in the graph and used as arguments are not supported yet"
+
     fusion = Fusion()
     with FusionDefinition(fusion) as fd:
 
@@ -122,11 +153,6 @@
 
 
 def nvfuser_execute(gm: GraphModule, *args):
-    if not torch.cuda.is_available():
-        raise RuntimeError(
-            "Attempting to use nvFuser trace executor but CUDA is not available!"
-        )
-
     flat_args, _ = tree_flatten(args)
 
     # Construction of the fusion is expensive and cached based on the GraphModule
@@ -180,11 +206,25 @@
 @lru_cache()  # type: ignore[arg-type]
 def maybe_partition_graph(gm: GraphModule):
     supported_ops = NvfuserPrimOperatorSupport()
-    call_function_nodes = filter(lambda n: n.op == "call_function", gm.graph.nodes)
+    call_function_nodes = list(
+        filter(lambda n: n.op == "call_function", gm.graph.nodes)
+    )
     # the graph is partitioned only if at least one node is not supported by nvFuser
     any_unsupported = any(
         not supported_ops.is_node_supported(None, node) for node in call_function_nodes
     )
+    any_unsupported |= len(call_function_nodes) == 0
+
+    # When there are constant tensors in the graph, we can't partition it
+    # because deepcopy fails. Here we just return the original graph to be
+    # executed by eager mode
+    # https://github.com/pytorch/pytorch/issues/84415
+    if (
+        _any_get_attr_used(call_function_nodes)
+        or len(list(filter(lambda n: n.op == "placeholder", gm.graph.nodes))) == 0
+    ):
+        return gm, True
+
     if any_unsupported:
         # CapabilityBasedPartitioner modifies the graph in-place so we need to make a copy of the graph
         gm = deepcopy(gm)