[ONNX] Adjust `is_train` flag for onnx pass deduplicate initializers

Previous logic didn't consider the case for TrainingMode.PRESERVE.
A more direct way is to check `model.training`, which is the accurate
training mode, set by `exporter_context(model, training)`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/74247
Approved by: https://github.com/garymm
diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py
index 6e0bdc3..cdbfe26 100644
--- a/test/onnx/test_utility_funs.py
+++ b/test/onnx/test_utility_funs.py
@@ -1284,6 +1284,13 @@
         graph = onnx.load(io.BytesIO(f.getvalue()))
         self.assertSetEqual(set([i.name for i in graph.graph.initializer]), param_name_set)
 
+        model.train()
+        f = io.BytesIO()
+        torch.onnx.export(model, (x,), f, training=TrainingMode.PRESERVE,
+                          opset_version=self.opset_version)
+        graph = onnx.load(io.BytesIO(f.getvalue()))
+        self.assertSetEqual(set([i.name for i in graph.graph.initializer]), param_name_set)
+
         # Test eval mode.
         model.eval()
         f = io.BytesIO()
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index 6bb2c29..313a378 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -818,8 +818,8 @@
                 # NOTE: cannot call DCE after this pass. DCE will remove function definition nodes.
                 node_attr_to_name = torch._C._jit_pass_onnx_function_extraction(
                     graph, export_modules_as_functions, list(params_dict.keys()))
-            params_dict = torch._C._jit_pass_onnx_deduplicate_initializers(graph, params_dict,
-                                                                           training == TrainingMode.TRAINING)
+            params_dict = torch._C._jit_pass_onnx_deduplicate_initializers(
+                graph, params_dict, getattr(model, "training", False))
             if export_params:
                 proto, export_map, val_use_external_data_format, node_names = graph._export_onnx(
                     params_dict, opset_version, dynamic_axes, defer_weight_export,