[compiled autograd] update benchmarks to use cli flags for fullgraph/dynamic (#127960)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127960
Approved by: https://github.com/jansel
diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index d520b1b..b70a077 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -714,7 +714,9 @@
             maybe_mark_step(args)
 
             with maybe_mark_profile(p=p, mark="actual"), maybe_enable_compiled_autograd(
-                args.compiled_autograd
+                args.compiled_autograd,
+                fullgraph=args.nopython,
+                dynamic=args.dynamic_shapes,
             ):
                 timings[rep, 1], actual_output = timed(
                     model,
@@ -2586,7 +2588,11 @@
                         new_result = optimized_model_iter_fn(model_copy, example_inputs)
                 else:
                     optimized_model_iter_fn = optimize_ctx(self.run_n_iterations)
-                    with maybe_enable_compiled_autograd(self.args.compiled_autograd):
+                    with maybe_enable_compiled_autograd(
+                        self.args.compiled_autograd,
+                        fullgraph=self.args.nopython,
+                        dynamic=self.args.dynamic_shapes,
+                    ):
                         new_result = optimized_model_iter_fn(model_copy, example_inputs)
             except Exception as e:
                 log.exception("")
@@ -2804,7 +2810,9 @@
                 aot_compilation_time = 0
 
             with maybe_enable_compiled_autograd(
-                self.args.compiled_autograd
+                self.args.compiled_autograd,
+                fullgraph=self.args.nopython,
+                dynamic=self.args.dynamic_shapes,
             ), maybe_snapshot_memory(
                 self.args.snapshot_memory, f"compiled_{self.args.only}"
             ):
@@ -2824,7 +2832,11 @@
                 with torch.profiler.profile(
                     activities=[torch.profiler.ProfilerActivity.CPU]
                 ) as prof:
-                    with maybe_enable_compiled_autograd(self.args.compiled_autograd):
+                    with maybe_enable_compiled_autograd(
+                        self.args.compiled_autograd,
+                        fullgraph=self.args.nopython,
+                        dynamic=self.args.dynamic_shapes,
+                    ):
                         warmup(optimized_model_iter_fn, model, example_inputs, "dynamo")
 
                 events = list(
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 6c85897..ebd69ea 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -2648,19 +2648,22 @@
 
 
 @contextlib.contextmanager
-def maybe_enable_compiled_autograd(should_enable):
-    def compiler_fn(gm):
-        def inner_compiler(gm_, example_inputs_):
-            torch._dynamo.utils.counters["compiled_autograd"]["compiles"] += 1
-            return torch._inductor.compile(gm_, example_inputs_)
+def maybe_enable_compiled_autograd(should_enable, fullgraph=True, dynamic=True):
+    if not should_enable:
+        yield
+    else:
 
-        return torch.compile(gm, backend=inner_compiler, fullgraph=True, dynamic=True)
+        def compiler_fn(gm):
+            def inner_compiler(gm_, example_inputs_):
+                torch._dynamo.utils.counters["compiled_autograd"]["compiles"] += 1
+                return torch._inductor.compile(gm_, example_inputs_)
 
-    if should_enable:
+            return torch.compile(
+                gm, backend=inner_compiler, fullgraph=fullgraph, dynamic=dynamic
+            )
+
         with torch._dynamo.compiled_autograd.enable(compiler_fn) as ctx:
             yield ctx
-    else:
-        yield
 
 
 def invalid_removeable_handle():