Support non-ASCII characters in model file paths (#99453)
Fixes #98918
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99453
Approved by: https://github.com/albanD, https://github.com/malfet
diff --git a/test/test_serialization.py b/test/test_serialization.py
index 5941e69..754d824 100644
--- a/test/test_serialization.py
+++ b/test/test_serialization.py
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
# Owner(s): ["module: serialization"]
import torch
@@ -21,8 +22,9 @@
from torch._utils import _rebuild_tensor
from torch.serialization import check_module_version_greater_or_equal
-from torch.testing._internal.common_utils import TestCase, IS_WINDOWS, TEST_DILL, \
- run_tests, download_file, BytesIOContext, TemporaryFileName, parametrize, instantiate_parametrized_tests
+from torch.testing._internal.common_utils import IS_FILESYSTEM_UTF8_ENCODING, TemporaryDirectoryName, \
+ TestCase, IS_WINDOWS, TEST_DILL, run_tests, download_file, BytesIOContext, TemporaryFileName, \
+ parametrize, instantiate_parametrized_tests
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_dtype import all_types_and_complex_and
@@ -924,6 +926,11 @@
with TemporaryFileName() as fname:
test(fname)
+ if IS_FILESYSTEM_UTF8_ENCODING:
+ with TemporaryDirectoryName(suffix='非ASCIIパス') as dname:
+ with TemporaryFileName(dir=dname) as fname:
+ test(fname)
+
test(io.BytesIO())
def test_serialization_zipfile_actually_jit(self):
diff --git a/torch/serialization.py b/torch/serialization.py
index 98fb642..2d3150c 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -332,10 +332,23 @@
class _open_zipfile_writer_file(_opener):
def __init__(self, name) -> None:
- super().__init__(torch._C.PyTorchFileWriter(str(name)))
+ self.file_stream = None
+ self.name = str(name)
+ try:
+ self.name.encode('ascii')
+ except UnicodeEncodeError:
+ # PyTorchFileWriter only supports ascii filename.
+ # For filenames with non-ascii characters, we rely on Python
+ # for writing out the file.
+ self.file_stream = io.FileIO(self.name, mode='w')
+ super().__init__(torch._C.PyTorchFileWriter(self.file_stream))
+ else:
+ super().__init__(torch._C.PyTorchFileWriter(self.name))
def __exit__(self, *args) -> None:
self.file_like.write_end_of_file()
+ if self.file_stream is not None:
+ self.file_stream.close()
class _open_zipfile_writer_buffer(_opener):