[ONNX] Adjust `is_train` flag for the ONNX deduplicate-initializers pass
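The previous logic passed `training == TrainingMode.TRAINING` as the flag,
which did not account for TrainingMode.PRESERVE: a model exported with
PRESERVE while in training mode was not treated as training.
A more direct approach is to check `model.training`, which reflects the
actual training mode, as set by `exporter_context(model, training)`.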
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74247
Approved by: https://github.com/garymm
diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py
index 6e0bdc3..cdbfe26 100644
--- a/test/onnx/test_utility_funs.py
+++ b/test/onnx/test_utility_funs.py
@@ -1284,6 +1284,13 @@
graph = onnx.load(io.BytesIO(f.getvalue()))
self.assertSetEqual(set([i.name for i in graph.graph.initializer]), param_name_set)
+ model.train()
+ f = io.BytesIO()
+ torch.onnx.export(model, (x,), f, training=TrainingMode.PRESERVE,
+ opset_version=self.opset_version)
+ graph = onnx.load(io.BytesIO(f.getvalue()))
+ self.assertSetEqual(set([i.name for i in graph.graph.initializer]), param_name_set)
+
# Test eval mode.
model.eval()
f = io.BytesIO()
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 6bb2c29..313a378 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -818,8 +818,8 @@
# NOTE: cannot call DCE after this pass. DCE will remove function definition nodes.
node_attr_to_name = torch._C._jit_pass_onnx_function_extraction(
graph, export_modules_as_functions, list(params_dict.keys()))
- params_dict = torch._C._jit_pass_onnx_deduplicate_initializers(graph, params_dict,
- training == TrainingMode.TRAINING)
+ params_dict = torch._C._jit_pass_onnx_deduplicate_initializers(
+ graph, params_dict, getattr(model, "training", False))
if export_params:
proto, export_map, val_use_external_data_format, node_names = graph._export_onnx(
params_dict, opset_version, dynamic_axes, defer_weight_export,