[compiled autograd] update benchmarks to use cli flags for fullgraph/dynamic (#127960)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127960
Approved by: https://github.com/jansel
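The benchmark harness now forwards its existing args.nopython and args.dynamic_shapes settings into maybe_enable_compiled_autograd instead of the previously hard-coded fullgraph=True, dynamic=True. A minimal usage sketch of the new signature (not part of the patch; the toy model and inputs are placeholders for whatever the harness actually runs):

    import torch
    from torch._dynamo.utils import maybe_enable_compiled_autograd

    model = torch.nn.Linear(4, 4)
    inputs = torch.randn(2, 4)

    # No-op when should_enable is False; otherwise backward() calls in the block
    # are captured by compiled autograd and compiled with the given settings.
    with maybe_enable_compiled_autograd(True, fullgraph=False, dynamic=True):
        loss = model(inputs).sum()
        loss.backward()

The keyword defaults (fullgraph=True, dynamic=True) keep the old behavior for callers that pass only should_enable.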
diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index d520b1b..b70a077 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -714,7 +714,9 @@
             maybe_mark_step(args)
 
             with maybe_mark_profile(p=p, mark="actual"), maybe_enable_compiled_autograd(
-                args.compiled_autograd
+                args.compiled_autograd,
+                fullgraph=args.nopython,
+                dynamic=args.dynamic_shapes,
             ):
                 timings[rep, 1], actual_output = timed(
                     model,
@@ -2586,7 +2588,11 @@
                     new_result = optimized_model_iter_fn(model_copy, example_inputs)
             else:
                 optimized_model_iter_fn = optimize_ctx(self.run_n_iterations)
-                with maybe_enable_compiled_autograd(self.args.compiled_autograd):
+                with maybe_enable_compiled_autograd(
+                    self.args.compiled_autograd,
+                    fullgraph=self.args.nopython,
+                    dynamic=self.args.dynamic_shapes,
+                ):
                     new_result = optimized_model_iter_fn(model_copy, example_inputs)
         except Exception as e:
             log.exception("")
@@ -2804,7 +2810,9 @@
                 aot_compilation_time = 0
             with maybe_enable_compiled_autograd(
-                self.args.compiled_autograd
+                self.args.compiled_autograd,
+                fullgraph=self.args.nopython,
+                dynamic=self.args.dynamic_shapes,
             ), maybe_snapshot_memory(
                 self.args.snapshot_memory, f"compiled_{self.args.only}"
             ):
@@ -2824,7 +2832,11 @@
             with torch.profiler.profile(
                 activities=[torch.profiler.ProfilerActivity.CPU]
             ) as prof:
-                with maybe_enable_compiled_autograd(self.args.compiled_autograd):
+                with maybe_enable_compiled_autograd(
+                    self.args.compiled_autograd,
+                    fullgraph=self.args.nopython,
+                    dynamic=self.args.dynamic_shapes,
+                ):
                     warmup(optimized_model_iter_fn, model, example_inputs, "dynamo")
 
             events = list(
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 6c85897..ebd69ea 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -2648,19 +2648,22 @@
 @contextlib.contextmanager
-def maybe_enable_compiled_autograd(should_enable):
-    def compiler_fn(gm):
-        def inner_compiler(gm_, example_inputs_):
-            torch._dynamo.utils.counters["compiled_autograd"]["compiles"] += 1
-            return torch._inductor.compile(gm_, example_inputs_)
+def maybe_enable_compiled_autograd(should_enable, fullgraph=True, dynamic=True):
+    if not should_enable:
+        yield
+    else:
 
-        return torch.compile(gm, backend=inner_compiler, fullgraph=True, dynamic=True)
+        def compiler_fn(gm):
+            def inner_compiler(gm_, example_inputs_):
+                torch._dynamo.utils.counters["compiled_autograd"]["compiles"] += 1
+                return torch._inductor.compile(gm_, example_inputs_)
 
-    if should_enable:
+            return torch.compile(
+                gm, backend=inner_compiler, fullgraph=fullgraph, dynamic=dynamic
+            )
+
         with torch._dynamo.compiled_autograd.enable(compiler_fn) as ctx:
             yield ctx
-    else:
-        yield
 
 
 def invalid_removeable_handle():