[dynamo] check buffers when checking accuracy (#91037)

Tested by running `python benchmarks/dynamo/torchbench.py --accuracy --float32 -dcuda --output=inductor_torchbench_float32_training_cuda_performance.csv --training --inductor --no-skip --dashboard --only mobilenet_v2 --cold_start_latency` and setting a breakpoint after the changed code to inspect the collected buffers.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91037
Approved by: https://github.com/anijain2305
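
For context: before this change the accuracy harness compared only outputs, gradients, and parameters, so buffer mutations (e.g. BatchNorm running statistics, which is what mobilenet_v2 exercises) were never checked. Below is a minimal sketch of the mechanism this PR adds, using a toy module and illustrative names rather than the benchmark harness itself:

```python
import torch
import torch.nn as nn

# Toy module whose forward pass mutates a buffer, similar to how BatchNorm
# updates running_mean / running_var outside the parameter/gradient path.
class ToyNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(4))
        self.register_buffer("running_mean", torch.zeros(4))

    def forward(self, x):
        # Buffer updates are invisible to a parameters-only comparison.
        self.running_mean += x.mean(dim=0).detach()
        return x * self.weight

eager = ToyNorm()
# torch.compile returns an OptimizedModule wrapping the original module;
# backend="eager" keeps the sketch runnable without a working inductor setup.
opt = torch.compile(ToyNorm(), backend="eager")

x = torch.randn(8, 4)
eager(x)
opt(x)

# Mirrors named_buffers_for_optimized_module: read buffers from the wrapped
# module so names line up with the eager model ("running_mean", not
# "_orig_mod.running_mean").
opt_buffers = dict(opt._orig_mod.named_buffers())
print(torch.allclose(eager.running_mean, opt_buffers["running_mean"]))
```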
diff --git a/torch/_dynamo/debug_utils.py b/torch/_dynamo/debug_utils.py
index e67a348..b9c9599 100644
--- a/torch/_dynamo/debug_utils.py
+++ b/torch/_dynamo/debug_utils.py
@@ -531,10 +531,16 @@
         gm.zero_grad(True)
 
     # TorchInductor returned callable expects lists. So, boxing the call.
-    if not hasattr(gm, "_boxed_call") and hasattr(gm, "named_parameters"):
-        orig_named_parameters = gm.named_parameters
+    orig_named_parameters = getattr(gm, "named_parameters", None)
+    orig_named_buffers = getattr(gm, "named_buffers", None)
+    if not hasattr(gm, "_boxed_call") and (
+        orig_named_parameters is not None or orig_named_buffers is not None
+    ):
         gm = make_boxed_func(gm)
-        gm.named_parameters = orig_named_parameters
+        if orig_named_parameters is not None:
+            gm.named_parameters = orig_named_parameters
+        if orig_named_buffers is not None:
+            gm.named_buffers = orig_named_buffers
 
     out = gm(args)
     if only_fwd:
@@ -550,14 +556,19 @@
     Check two models have same accuracy.
     """
     from .eval_frame import OptimizedModule
-    from .testing import named_parameters_for_optimized_module
+    from .testing import (
+        named_buffers_for_optimized_module,
+        named_parameters_for_optimized_module,
+    )
     from .utils import same
 
     if isinstance(gm, OptimizedModule):
         gm.named_parameters = named_parameters_for_optimized_module(gm)
+        gm.named_buffers = named_buffers_for_optimized_module(gm)
 
     if isinstance(opt_gm, OptimizedModule):
         opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm)
+        opt_gm.named_buffers = named_buffers_for_optimized_module(opt_gm)
 
     ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)
 
diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py
index 601ba1a..7a2b4b3 100644
--- a/torch/_dynamo/testing.py
+++ b/torch/_dynamo/testing.py
@@ -37,6 +37,11 @@
     return mod._orig_mod.named_parameters
 
 
+def named_buffers_for_optimized_module(mod):
+    assert isinstance(mod, eval_frame.OptimizedModule)
+    return mod._orig_mod.named_buffers
+
+
 def remove_optimized_module_prefix(name):
     prefix = "_orig_mod."
     assert name.startswith(prefix)
@@ -67,6 +72,12 @@
         params[name] = param_copy
     results.append(grads)
     results.append(params)
+    buffers = dict()
+    for name, buffer in model.named_buffers():
+        if isinstance(model, eval_frame.OptimizedModule):
+            name = remove_optimized_module_prefix(name)
+        buffers[name] = buffer
+    results.append(buffers)
     for example in example_inputs:
         if isinstance(example, (tuple, list)):
             for inp in example:
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index 2972d38..318f397 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -2369,6 +2369,7 @@
     # Just for convenience
     forward.zero_grad = mod.zero_grad
     forward.named_parameters = mod.named_parameters
+    forward.named_buffers = mod.named_buffers
 
     return forward
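
As a usage note, the net effect on `collect_results` is that its result list now also includes a dict of buffers, appended right after the params dict, with the `_orig_mod.` prefix stripped so buffer names compare key-for-key against the eager model. A hypothetical standalone version of that buffer-collection loop (the function name here is illustrative, not part of the patch):

```python
from typing import Dict
import torch
from torch import nn

def collect_buffers(model: nn.Module, prefix: str = "_orig_mod.") -> Dict[str, torch.Tensor]:
    # Strip the OptimizedModule prefix so buffer names line up with the
    # eager model's, mirroring the loop added to collect_results in
    # torch/_dynamo/testing.py.
    buffers = {}
    for name, buffer in model.named_buffers():
        if name.startswith(prefix):
            name = name[len(prefix):]
        buffers[name] = buffer
    return buffers
```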