Cast checkpoint weights to match model parameter's dtype (#122100)
Fixes #121986
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122100
Approved by: https://github.com/BowenBao
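
The gist of the fix, as a minimal standalone sketch (the float32 checkpoint tensor and float16 target dtype below are made up for illustration): an initializer loaded from a checkpoint may carry a different dtype than the one the exported graph declares for the corresponding input, so the serializer now casts the tensor before writing its external data.

    import torch

    # Hypothetical example: the checkpoint stores the weight in float32...
    checkpoint_weight = torch.randn(4, 4, dtype=torch.float32)
    # ...but the traced model parameter (and thus the ONNX graph input) is float16.
    expected_dtype = torch.float16

    # Cast only when the dtypes disagree, mirroring the check added in
    # _create_tensor_proto_with_external_data below.
    if checkpoint_weight.dtype != expected_dtype:
        checkpoint_weight = checkpoint_weight.to(expected_dtype)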
diff --git a/.ci/docker/common/install_onnx.sh b/.ci/docker/common/install_onnx.sh
index c4a68fb..ccaba19 100755
--- a/.ci/docker/common/install_onnx.sh
+++ b/.ci/docker/common/install_onnx.sh
@@ -38,7 +38,7 @@
# Cache the transformers model to be used later by ONNX tests. We need to run the transformers
# package to download the model. By default, the model is cached at ~/.cache/huggingface/hub/
IMPORT_SCRIPT_FILENAME="/tmp/onnx_import_script.py"
-as_jenkins echo 'import transformers; transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2"); transformers.AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2");' > "${IMPORT_SCRIPT_FILENAME}"
+as_jenkins echo 'import transformers; transformers.AutoModel.from_pretrained("sshleifer/tiny-gpt2"); transformers.AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2"); transformers.AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-large-v3");' > "${IMPORT_SCRIPT_FILENAME}"
# Need a PyTorch version for transformers to work
pip_install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py
index 917dde6..b413345 100644
--- a/test/onnx/test_fx_to_onnx.py
+++ b/test/onnx/test_fx_to_onnx.py
@@ -677,6 +677,35 @@
_ = torch.onnx.dynamo_export(Float8Module(), torch.randn(1, 2, 3, 4))
+ def test_checkpoint_cast(self):
+ model_id = "openai/whisper-large-v3"
+ feature_extractor = transformers.WhisperFeatureExtractor(feature_size=128)
+ batch = 4
+
+ with torch.onnx.enable_fake_mode() as ctx:
+ model = transformers.AutoModelForSpeechSeq2Seq.from_pretrained(
+ model_id, low_cpu_mem_usage=False, use_safetensors=False
+ )
+ input = {
+ "input_features": torch.randn(
+ (
+ batch,
+ feature_extractor.feature_size,
+ feature_extractor.nb_max_frames,
+ )
+ ),
+ "decoder_input_ids": torch.tensor([[1, 1]]) * 8001,
+ "return_dict": False,
+ }
+
+ export_options = torch.onnx.ExportOptions(fake_context=ctx)
+ onnx_program = torch.onnx.dynamo_export(
+ model, **input, export_options=export_options
+ )
+ with tempfile.NamedTemporaryFile(suffix=".onnx") as tmp_onnx_file:
+ onnx_program.save(tmp_onnx_file.name)
+ onnx.checker.check_model(tmp_onnx_file.name, full_check=True)
+
if __name__ == "__main__":
common_utils.run_tests()
diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
index 72d2e5a..5c03c25 100644
--- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py
+++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -887,11 +887,6 @@
@pytorch_test_common.xfail_dynamic_fx_test(
error_message="shape_env should be set if tracing with 'symbolic'"
)
- @pytorch_test_common.xfail(
- error_message="Type Error: Data in initializer 'h_0_attn_bias' has element type tensor(uint8) "
- "but usage of initializer in graph expects tensor(bool)",
- reason="https://github.com/huggingface/transformers/issues/21013",
- )
def test_fx_symbolic_tracer_large_scale_exporter_with_tiny_gpt2(self):
model_name = "sshleifer/tiny-gpt2"
device = "cpu"
diff --git a/torch/onnx/_internal/fx/serialization.py b/torch/onnx/_internal/fx/serialization.py
index d472095..e4c295e 100644
--- a/torch/onnx/_internal/fx/serialization.py
+++ b/torch/onnx/_internal/fx/serialization.py
@@ -3,7 +3,7 @@
import io
import logging
import os
-from typing import Tuple, TYPE_CHECKING, Union
+from typing import Optional, Tuple, TYPE_CHECKING, Union
import torch
from torch.onnx import _type_utils as jit_type_utils
@@ -17,7 +17,11 @@
@_beartype.beartype
def _create_tensor_proto_with_external_data(
- tensor: torch.Tensor, name: str, location: str, basepath: str
+ tensor: torch.Tensor,
+ name: str,
+ location: str,
+ basepath: str,
+ dtype_override: Optional["onnx.TypeProto"] = None, # type: ignore[name-defined]
) -> onnx.TensorProto: # type: ignore[name-defined]
"""Create a TensorProto with external data from a PyTorch tensor.
The external data is saved to os.path.join(basepath, location).
@@ -41,11 +45,24 @@
# FIXME: Avoid importing onnx into torch.onnx.
import onnx
+ scalar_type = (
+ jit_type_utils.JitScalarType.from_onnx_type(
+ dtype_override.tensor_type.elem_type
+ )
+ if dtype_override is not None
+ else jit_type_utils.JitScalarType.from_dtype(tensor.dtype)
+ )
+
+ # Checkpoints can be stored with a different dtype than the model expects because
+ # the user script may explicitly cast the original type to something else, or
+ # PyTorch's type promotion may do it implicitly
+ if dtype_override is not None and scalar_type.dtype() != tensor.dtype:
+ tensor = tensor.to(scalar_type.dtype())
+
tensor_proto = onnx.TensorProto() # type: ignore[attr-defined]
tensor_proto.name = name
- tensor_proto.data_type = jit_type_utils.JitScalarType.from_dtype(
- tensor.dtype
- ).onnx_type()
+ tensor_proto.data_type = scalar_type.onnx_type()
+
tensor_proto.dims.extend(tensor.shape)
tensor_proto.data_location = onnx.TensorProto.EXTERNAL # type: ignore[attr-defined]
@@ -200,8 +217,16 @@
# Create one file per tensor.
# tensor_proto.raw_data is stored to external file at
# os.path.join(basepath, relative_tensor_file_path).
+ model_input_types = {
+ k.name: k.type for k in onnx_model_with_initializers.graph.input
+ }
+
tensor_proto = _create_tensor_proto_with_external_data(
- tensor, name, relative_tensor_file_path, basepath
+ tensor,
+ name,
+ relative_tensor_file_path,
+ basepath,
+ model_input_types.pop(name, None),
)
# Add the tensor_proto to the ONNX model as an initializer with external data.
onnx_model_with_initializers.graph.initializer.append(tensor_proto)
diff --git a/torch/onnx/_type_utils.py b/torch/onnx/_type_utils.py
index 0470793..d132325 100644
--- a/torch/onnx/_type_utils.py
+++ b/torch/onnx/_type_utils.py
@@ -159,6 +159,26 @@
@classmethod
@_beartype.beartype
+ def from_onnx_type(
+ cls, onnx_type: Optional[Union[int, _C_onnx.TensorProtoDataType]]
+ ) -> JitScalarType:
+ """Convert a ONNX data type to JitScalarType.
+
+ Args:
+ onnx_type: A torch._C._onnx.TensorProtoDataType to create a JitScalarType from
+
+ Returns:
+ JitScalarType
+
+ Raises:
+ OnnxExporterError: if onnx_type is not a valid ONNX data type or if it is None.
+ """
+ if onnx_type not in _ONNX_TO_SCALAR_TYPE:
+ raise errors.OnnxExporterError(f"Unknown onnx_type: {onnx_type}")
+ return _ONNX_TO_SCALAR_TYPE[typing.cast(_C_onnx.TensorProtoDataType, onnx_type)]
+
+ @classmethod
+ @_beartype.beartype
def from_value(
cls, value: Union[None, torch._C.Value, torch.Tensor], default=None
) -> JitScalarType:
@@ -352,6 +372,8 @@
JitScalarType.FLOAT8E4M3FNUZ: _C_onnx.TensorProtoDataType.FLOAT8E4M3FNUZ,
}
+_ONNX_TO_SCALAR_TYPE = {v: k for k, v in _SCALAR_TYPE_TO_ONNX.items()}
+
# source of truth is
# https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_dtypes.cpp
_SCALAR_TYPE_TO_DTYPE = {