Specify if mismatch is input or output in export (#107145)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107145
Approved by: https://github.com/suo, https://github.com/gmagogsfm
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index a9b3605..840c0d8 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -916,7 +916,7 @@
 ):
     orig_args, orig_kwargs = pytree.tree_unflatten(flat_args, in_spec)
 
-    def produce_matching(source_args, candidate_args):
+    def produce_matching(source_args, candidate_args, loc):
         matched_elements_positions = []
         dict_of_source_args = dict()
         for i in range(0, len(source_args)):
@@ -935,23 +935,27 @@
                     )
                 else:
                     raise AssertionError(
-                        "Dynamo input/output is not consistent with traced input/output"
+                        f"Dynamo {loc} is not consistent with traced {loc}"
                     )
             else:
                 assert (
                     id(arg) in dict_of_source_args
-                ), "Dynamo input and output is a strict subset of traced input/output"
+                ), f"Dynamo {loc} is a strict subset of traced {loc}"
                 matched_elements_positions.append(dict_of_source_args[id(arg)])
 
         return matched_elements_positions
 
-    matched_input_elements_positions = produce_matching(flat_args, graph_captured_input)
+    matched_input_elements_positions = produce_matching(
+        flat_args, graph_captured_input, "input"
+    )
 
     flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result)
 
     assert graph_captured_output is not None
     flat_both = list(graph_captured_output) + flat_args
-    matched_output_elements_positions = produce_matching(flat_both, flat_results_traced)
+    matched_output_elements_positions = produce_matching(
+        flat_both, flat_results_traced, "output"
+    )
 
     new_graph = FlattenInputOutputSignature(
         graph,