Specify if mismatch is input or output in export (#107145)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107145
Approved by: https://github.com/suo, https://github.com/gmagogsfm
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index a9b3605..840c0d8 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -916,7 +916,7 @@
):
orig_args, orig_kwargs = pytree.tree_unflatten(flat_args, in_spec)
- def produce_matching(source_args, candidate_args):
+ def produce_matching(source_args, candidate_args, loc):
matched_elements_positions = []
dict_of_source_args = dict()
for i in range(0, len(source_args)):
@@ -935,23 +935,27 @@
)
else:
raise AssertionError(
- "Dynamo input/output is not consistent with traced input/output"
+ f"Dynamo {loc} is not consistent with traced {loc}"
)
else:
assert (
id(arg) in dict_of_source_args
- ), "Dynamo input and output is a strict subset of traced input/output"
+ ), f"Dynamo {loc} is a strict subset of traced {loc}"
matched_elements_positions.append(dict_of_source_args[id(arg)])
return matched_elements_positions
- matched_input_elements_positions = produce_matching(flat_args, graph_captured_input)
+ matched_input_elements_positions = produce_matching(
+ flat_args, graph_captured_input, "input"
+ )
flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result)
assert graph_captured_output is not None
flat_both = list(graph_captured_output) + flat_args
- matched_output_elements_positions = produce_matching(flat_both, flat_results_traced)
+ matched_output_elements_positions = produce_matching(
+ flat_both, flat_results_traced, "output"
+ )
new_graph = FlattenInputOutputSignature(
graph,