Cast checkpoint weights to match model parameter's dtype (#122100)

When exporting with fake tensors, checkpoint weights can be stored in a
dtype different from the one the model's parameters expect, either because
the user script explicitly cast them or because PyTorch's type promotion
did so implicitly. This change passes the exported graph input's type to
`_create_tensor_proto_with_external_data` as a `dtype_override` and casts
the initializer accordingly before serializing it, adds a
`JitScalarType.from_onnx_type` helper to map ONNX element types back to
scalar types, and covers the scenario with a fake-mode export test for
`openai/whisper-large-v3`. It also removes a stale xfail for the tiny-gpt2
exporter test whose uint8-vs-bool initializer mismatch this cast resolves.

Fixes #121986
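
For illustration, a minimal sketch of the new casting path (the tensor name,
file paths, and direct use of the private helper are illustrative only, not
part of this change):

```python
import onnx
import torch

from torch._C import _onnx as _C_onnx
from torch.onnx import _type_utils as jit_type_utils
from torch.onnx._internal.fx import serialization

# The new JitScalarType.from_onnx_type maps an ONNX element type back to a
# scalar type, so the exporter can compare it against the checkpoint dtype.
scalar_type = jit_type_utils.JitScalarType.from_onnx_type(
    _C_onnx.TensorProtoDataType.FLOAT
)
assert scalar_type.dtype() == torch.float32

# The graph input declares float32 while the checkpoint tensor is float16;
# the initializer is now cast to float32 before being written as external data.
dtype_override = onnx.TypeProto()
dtype_override.tensor_type.elem_type = onnx.TensorProto.FLOAT
tensor_proto = serialization._create_tensor_proto_with_external_data(
    torch.ones(2, 2, dtype=torch.float16),  # checkpoint weight, stored as fp16
    name="fc.weight",
    location="fc.weight.bin",
    basepath="/tmp",
    dtype_override=dtype_override,
)
assert tensor_proto.data_type == onnx.TensorProto.FLOAT
```

In `save_model_with_external_data`, `dtype_override` is looked up from the
exported graph's inputs via `model_input_types.pop(name, None)`, so only
initializers that back a graph input are cast.
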
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122100
Approved by: https://github.com/BowenBao
diff --git a/.ci/docker/common/install_onnx.sh b/.ci/docker/common/install_onnx.sh
index c4a68fb..ccaba19 100755
--- a/.ci/docker/common/install_onnx.sh
+++ b/.ci/docker/common/install_onnx.sh
@@ -38,7 +38,7 @@
 # Cache the transformers model to be used later by ONNX tests. We need to run the transformers
 # package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
 IMPORT_SCRIPT_FILENAME="/tmp/onnx_import_script.py"
-as_jenkins echo 'import transformers; transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2"); transformers.AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2");' > "${IMPORT_SCRIPT_FILENAME}"
+as_jenkins echo 'import transformers; transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2"); transformers.AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2"); transformers.AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-large-v3");' > "${IMPORT_SCRIPT_FILENAME}"
 
 # Need a PyTorch version for transformers to work
 pip_install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py
index 917dde6..b413345 100644
--- a/test/onnx/test_fx_to_onnx.py
+++ b/test/onnx/test_fx_to_onnx.py
@@ -677,6 +677,37 @@
 
         _ = torch.onnx.dynamo_export(Float8Module(), torch.randn(1, 2, 3, 4))
 
+    def test_checkpoint_cast(self):
+        model_id = "openai/whisper-large-v3"
+        feature_extractor = transformers.WhisperFeatureExtractor(feature_size=128)
+        batch = 4
+
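+        # The model is instantiated under fake mode, so its real weights are
+        # materialized from the cached checkpoint only when the program is saved.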
+        with torch.onnx.enable_fake_mode() as ctx:
+            model = transformers.AutoModelForSpeechSeq2Seq.from_pretrained(
+                model_id, low_cpu_mem_usage=False, use_safetensors=False
+            )
+            input = {
+                "input_features": torch.randn(
+                    (
+                        batch,
+                        feature_extractor.feature_size,
+                        feature_extractor.nb_max_frames,
+                    )
+                ),
+                "decoder_input_ids": torch.tensor([[1, 1]]) * 8001,
+                "return_dict": False,
+            }
+
+        export_options = torch.onnx.ExportOptions(fake_context=ctx)
+        onnx_program = torch.onnx.dynamo_export(
+            model, **input, export_options=export_options
+        )
+        with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
+            onnx_program.save(tmp_onnx_file.name)
+            onnx.checker.check_model(tmp_onnx_file.name, full_check=True)
+
 
 if __name__ == "__main__":
     common_utils.run_tests()
diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
index 72d2e5a..5c03c25 100644
--- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py
+++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -887,11 +887,6 @@
     @pytorch_test_common.xfail_dynamic_fx_test(
         error_message="shape_env should be set if tracing with 'symbolic'"
     )
-    @pytorch_test_common.xfail(
-        error_message="Type Error: Data in initializer 'h_0_attn_bias' has element type tensor(uint8) "
-        "but usage of initializer in graph expects tensor(bool)",
-        reason="https://github.com/huggingface/transformers/issues/21013",
-    )
     def test_fx_symbolic_tracer_large_scale_exporter_with_tiny_gpt2(self):
         model_name = "sshleifer/tiny-gpt2"
         device = "cpu"
diff --git a/torch/onnx/_internal/fx/serialization.py b/torch/onnx/_internal/fx/serialization.py
index d472095..e4c295e 100644
--- a/torch/onnx/_internal/fx/serialization.py
+++ b/torch/onnx/_internal/fx/serialization.py
@@ -3,7 +3,7 @@
 import io
 import logging
 import os
-from typing import Tuple, TYPE_CHECKING, Union
+from typing import Optional, Tuple, TYPE_CHECKING, Union
 
 import torch
 from torch.onnx import _type_utils as jit_type_utils
@@ -17,7 +17,11 @@
 
 @_beartype.beartype
 def _create_tensor_proto_with_external_data(
-    tensor: torch.Tensor, name: str, location: str, basepath: str
+    tensor: torch.Tensor,
+    name: str,
+    location: str,
+    basepath: str,
+    dtype_override: Optional["onnx.TypeProto"] = None,  # type: ignore[name-defined]
 ) -> onnx.TensorProto:  # type: ignore[name-defined]
     """Create a TensorProto with external data from a PyTorch tensor.
     The external data is saved to os.path.join(basepath, location).
@@ -41,11 +45,24 @@
     # FIXME: Avoid importing onnx into torch.onnx.
     import onnx
 
+    scalar_type = (
+        jit_type_utils.JitScalarType.from_onnx_type(
+            dtype_override.tensor_type.elem_type
+        )
+        if dtype_override is not None
+        else jit_type_utils.JitScalarType.from_dtype(tensor.dtype)
+    )
+
+    # Checkpoints can be stored with a dtype different from the one the model
+    # expects, either because the user script explicitly cast the original type
+    # or because PyTorch's type promotion did so implicitly.
+    if dtype_override is not None and scalar_type.dtype() != tensor.dtype:
+        tensor = tensor.to(scalar_type.dtype())
+
     tensor_proto = onnx.TensorProto()  # type: ignore[attr-defined]
     tensor_proto.name = name
-    tensor_proto.data_type = jit_type_utils.JitScalarType.from_dtype(
-        tensor.dtype
-    ).onnx_type()
+    tensor_proto.data_type = scalar_type.onnx_type()
+
     tensor_proto.dims.extend(tensor.shape)
     tensor_proto.data_location = onnx.TensorProto.EXTERNAL  # type: ignore[attr-defined]
 
@@ -200,8 +217,18 @@
             # Create one file per tensor.
             # tensor_proto.raw_data is stored to external file at
             # os.path.join(basepath, relative_tensor_file_path).
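+            # The initializer must match the dtype that the exported graph
+            # declares for the corresponding input, so look it up by name.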
+            model_input_types = {
+                k.name: k.type for k in onnx_model_with_initializers.graph.input
+            }
+
             tensor_proto = _create_tensor_proto_with_external_data(
-                tensor, name, relative_tensor_file_path, basepath
+                tensor,
+                name,
+                relative_tensor_file_path,
+                basepath,
+                model_input_types.pop(name, None),
             )
             # Add the tensor_proto to the ONNX model as an initializer with external data.
             onnx_model_with_initializers.graph.initializer.append(tensor_proto)
diff --git a/torch/onnx/_type_utils.py b/torch/onnx/_type_utils.py
index 0470793..d132325 100644
--- a/torch/onnx/_type_utils.py
+++ b/torch/onnx/_type_utils.py
@@ -159,6 +159,33 @@
 
     @classmethod
     @_beartype.beartype
+    def from_onnx_type(
+        cls, onnx_type: Optional[Union[int, _C_onnx.TensorProtoDataType]]
+    ) -> JitScalarType:
+        """Convert a ONNX data type to JitScalarType.
+
+        Args:
+            onnx_type: An ONNX element type (int or torch._C._onnx.TensorProtoDataType) to create a JitScalarType from
+
+        Returns:
+            JitScalarType
+
+        Raises:
+            OnnxExporterError: if onnx_type is not a valid ONNX data type or if it is None.
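+
+        Example::
+
+            >>> from torch._C import _onnx as _C_onnx
+            >>> scalar_type = JitScalarType.from_onnx_type(_C_onnx.TensorProtoDataType.FLOAT)
+            >>> scalar_type.dtype()
+            torch.float32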
+        """
+        if onnx_type not in _ONNX_TO_SCALAR_TYPE:
+            raise errors.OnnxExporterError(f"Unknown onnx_type: {onnx_type}")
+        return _ONNX_TO_SCALAR_TYPE[typing.cast(_C_onnx.TensorProtoDataType, onnx_type)]
+
+    @classmethod
+    @_beartype.beartype
     def from_value(
         cls, value: Union[None, torch._C.Value, torch.Tensor], default=None
     ) -> JitScalarType:
@@ -352,6 +372,8 @@
     JitScalarType.FLOAT8E4M3FNUZ: _C_onnx.TensorProtoDataType.FLOAT8E4M3FNUZ,
 }
 
+_ONNX_TO_SCALAR_TYPE = {v: k for k, v in _SCALAR_TYPE_TO_ONNX.items()}
+
 # source of truth is
 # https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_dtypes.cpp
 _SCALAR_TYPE_TO_DTYPE = {