Allow remapping storages at load time and serialize data in little endian order
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 0da0f21..797eaff 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -406,6 +406,68 @@
self.assertEqual(x, y)
self.assertEqual(torch.cuda.initial_seed(), 2)
+ def test_serialization(self):
+ x = torch.randn(4, 4).cuda()
+ with tempfile.NamedTemporaryFile() as f:
+ torch.save(x, f)
+ f.seek(0)
+ x_copy = torch.load(f)
+ self.assertEqual(x_copy, x)
+ self.assertIs(type(x_copy), type(x))
+ self.assertEqual(x_copy.get_device(), x.get_device())
+
+ def test_serialization_empty(self):
+ x = [torch.randn(4, 4).cuda(), torch.cuda.FloatTensor()]
+ with tempfile.NamedTemporaryFile() as f:
+ torch.save(x, f)
+ f.seek(0)
+ x_copy = torch.load(f)
+ for original, copy in zip(x, x_copy):
+ self.assertEqual(copy, original)
+ self.assertIs(type(copy), type(original))
+ self.assertEqual(copy.get_device(), original.get_device())
+
+ @unittest.skipIf(torch.cuda.device_count() < 2, "detected only one GPU")
+ def test_multigpu_serialization(self):
+ x = [torch.randn(4, 4).cuda(0), torch.randn(4, 4).cuda(1)]
+ with tempfile.NamedTemporaryFile() as f:
+ torch.save(x, f)
+ f.seek(0)
+ x_copy = torch.load(f)
+ for original, copy in zip(x, x_copy):
+ self.assertEqual(copy, original)
+ self.assertIs(type(copy), type(original))
+ self.assertEqual(copy.get_device(), original.get_device())
+
+ @unittest.skipIf(torch.cuda.device_count() < 2, "detected only one GPU")
+ def test_multigpu_serialization_remap(self):
+ x = [torch.randn(4, 4).cuda(0), torch.randn(4, 4).cuda(1)]
+ def gpu_remap(storage, location):
+ if location == 'cuda:1':
+ return storage.cuda(0)
+
+ with tempfile.NamedTemporaryFile() as f:
+ torch.save(x, f)
+ f.seek(0)
+ x_copy = torch.load(f, map_location=gpu_remap)
+
+ for original, copy in zip(x, x_copy):
+ self.assertEqual(copy, original)
+ self.assertIs(type(copy), type(original))
+ self.assertEqual(copy.get_device(), 0)
+
+ @unittest.skipIf(torch.cuda.device_count() < 2, "detected only one GPU")
+ def test_multigpu_serialization_remap_dict(self):
+ x = [torch.randn(4, 4).cuda(0), torch.randn(4, 4).cuda(1)]
+ with tempfile.NamedTemporaryFile() as f:
+ torch.save(x, f)
+ f.seek(0)
+ x_copy = torch.load(f, map_location={'cuda:1': 'cuda:0'})
+ for original, copy in zip(x, x_copy):
+ self.assertEqual(copy, original)
+ self.assertIs(type(copy), type(original))
+ self.assertEqual(copy.get_device(), 0)
+
for decl in tests:
for t in types:
diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp
index e031f10..5fe10df 100644
--- a/torch/csrc/Storage.cpp
+++ b/torch/csrc/Storage.cpp
@@ -5,7 +5,6 @@
#include <TH/TH.h>
#include <libshm.h>
#include "THP.h"
-#include "byte_order.h"
#include "generic/Storage.cpp"
#include <TH/THGenerateAllTypes.h>
diff --git a/torch/csrc/THP.h b/torch/csrc/THP.h
index a628a05..440523a 100644
--- a/torch/csrc/THP.h
+++ b/torch/csrc/THP.h
@@ -25,6 +25,7 @@
#include "Tensor.h"
#include "Module.h"
#include "utils.h" // This requires defined Storage and Tensor types
+#include "byte_order.h"
#ifdef _THP_CORE
#include "serialization.h"
diff --git a/torch/csrc/byte_order.cpp b/torch/csrc/byte_order.cpp
index 014522d..10edc81 100644
--- a/torch/csrc/byte_order.cpp
+++ b/torch/csrc/byte_order.cpp
@@ -1,5 +1,7 @@
#include "byte_order.h"
+#include <string.h>
+
static inline uint16_t decodeUInt16LE(const uint8_t *data) {
return (data[0]<<0) | (data[1]<<8);
}
@@ -79,3 +81,71 @@
src += sizeof(double);
}
}
+
+template<size_t size>
+static void swapBytes(uint8_t *ptr)
+{
+ uint8_t tmp;
+ for (size_t i = 0; i < size / 2; i++) {
+ tmp = ptr[i];
+ ptr[i] = ptr[size-i];
+ ptr[size-i] = tmp;
+ }
+}
+
+
+void THP_encodeInt16Buffer(uint8_t* dst, const int16_t* src, THPByteOrder order, size_t len)
+{
+ memcpy(dst, src, sizeof(int16_t) * len);
+ if (order != THP_nativeByteOrder()) {
+ for (size_t i = 0; i < len; i++) {
+ swapBytes<sizeof(int16_t)>(dst);
+ dst += sizeof(int16_t);
+ }
+ }
+}
+
+void THP_encodeInt32Buffer(uint8_t* dst, const int32_t* src, THPByteOrder order, size_t len)
+{
+ memcpy(dst, src, sizeof(int32_t) * len);
+ if (order != THP_nativeByteOrder()) {
+ for (size_t i = 0; i < len; i++) {
+ swapBytes<sizeof(int32_t)>(dst);
+ dst += sizeof(int32_t);
+ }
+ }
+}
+
+void THP_encodeInt64Buffer(uint8_t* dst, const int64_t* src, THPByteOrder order, size_t len)
+{
+ memcpy(dst, src, sizeof(int64_t) * len);
+ if (order != THP_nativeByteOrder()) {
+ for (size_t i = 0; i < len; i++) {
+ swapBytes<sizeof(int64_t)>(dst);
+ dst += sizeof(int64_t);
+ }
+ }
+}
+
+void THP_encodeFloatBuffer(uint8_t* dst, const float* src, THPByteOrder order, size_t len)
+{
+ memcpy(dst, src, sizeof(float) * len);
+ if (order != THP_nativeByteOrder()) {
+ for (size_t i = 0; i < len; i++) {
+ swapBytes<sizeof(float)>(dst);
+ dst += sizeof(float);
+ }
+ }
+}
+
+void THP_encodeDoubleBuffer(uint8_t* dst, const double* src, THPByteOrder order, size_t len)
+{
+ memcpy(dst, src, sizeof(double) * len);
+ if (order != THP_nativeByteOrder()) {
+ for (size_t i = 0; i < len; i++) {
+ swapBytes<sizeof(double)>(dst);
+ dst += sizeof(double);
+ }
+ }
+}
+
diff --git a/torch/csrc/byte_order.h b/torch/csrc/byte_order.h
index bf5041e..0dcb79b 100644
--- a/torch/csrc/byte_order.h
+++ b/torch/csrc/byte_order.h
@@ -1,3 +1,6 @@
+#ifndef THP_BYTE_ORDER_H
+#define THP_BYTE_ORDER_H
+
#include <stdint.h>
#include <stddef.h>
@@ -7,8 +10,17 @@
};
THPByteOrder THP_nativeByteOrder();
+
void THP_decodeInt16Buffer(int16_t* dst, const uint8_t* src, THPByteOrder order, size_t len);
void THP_decodeInt32Buffer(int32_t* dst, const uint8_t* src, THPByteOrder order, size_t len);
void THP_decodeInt64Buffer(int64_t* dst, const uint8_t* src, THPByteOrder order, size_t len);
void THP_decodeFloatBuffer(float* dst, const uint8_t* src, THPByteOrder order, size_t len);
void THP_decodeDoubleBuffer(double* dst, const uint8_t* src, THPByteOrder order, size_t len);
+
+void THP_encodeInt16Buffer(uint8_t* dst, const int16_t* src, THPByteOrder order, size_t len);
+void THP_encodeInt32Buffer(uint8_t* dst, const int32_t* src, THPByteOrder order, size_t len);
+void THP_encodeInt64Buffer(uint8_t* dst, const int64_t* src, THPByteOrder order, size_t len);
+void THP_encodeFloatBuffer(uint8_t* dst, const float* src, THPByteOrder order, size_t len);
+void THP_encodeDoubleBuffer(uint8_t* dst, const double* src, THPByteOrder order, size_t len);
+
+#endif
diff --git a/torch/csrc/generic/serialization.cpp b/torch/csrc/generic/serialization.cpp
index cefb7de..1f0a8fe 100644
--- a/torch/csrc/generic/serialization.cpp
+++ b/torch/csrc/generic/serialization.cpp
@@ -38,7 +38,33 @@
THCudaCheck(cudaMemcpy(data, self->data, self->size * sizeof(real), cudaMemcpyDeviceToHost));
#endif
SYSCHECK(write(fd, &self->size, sizeof(long)));
- SYSCHECK(write(fd, data, sizeof(real) * self->size));
+ // fast track for bytes and little endian
+ if (sizeof(real) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) {
+ SYSCHECK(write(fd, data, sizeof(real) * self->size));
+ } else {
+ long buffer_size = std::min(self->size, (long)5000);
+ std::unique_ptr<uint8_t[]> le_buffer(new uint8_t[buffer_size * sizeof(real)]);
+ for (long i = 0; i < self->size; i += buffer_size) {
+ size_t to_convert = std::min(self->size - i, buffer_size);
+ if (sizeof(real) == 2) {
+ THP_encodeInt16Buffer((uint8_t*)le_buffer.get(),
+ (const int16_t*)data + i,
+ THPByteOrder::THP_LITTLE_ENDIAN,
+ to_convert);
+ } else if (sizeof(real) == 4) {
+ THP_encodeInt32Buffer((uint8_t*)le_buffer.get(),
+ (const int32_t*)data + i,
+ THPByteOrder::THP_LITTLE_ENDIAN,
+ to_convert);
+ } else if (sizeof(real) == 8) {
+ THP_encodeInt64Buffer((uint8_t*)le_buffer.get(),
+ (const int64_t*)data + i,
+ THPByteOrder::THP_LITTLE_ENDIAN,
+ to_convert);
+ }
+ SYSCHECK(write(fd, data, to_convert * sizeof(real)));
+ }
+ }
}
THStorage * THPStorage_(readFileRaw)(int fd)
@@ -55,7 +81,33 @@
data = (real*)cpu_data.get();
#endif
- SYSCHECK(read(fd, data, sizeof(real) * storage->size));
+ // fast track for bytes and little endian
+ if (sizeof(real) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) {
+ SYSCHECK(read(fd, data, sizeof(real) * storage->size));
+ } else {
+ long buffer_size = std::min(size, (long)5000);
+ std::unique_ptr<uint8_t[]> le_buffer(new uint8_t[buffer_size * sizeof(real)]);
+ for (long i = 0; i < size; i += buffer_size) {
+ size_t to_convert = std::min(size - i, buffer_size);
+ SYSCHECK(read(fd, le_buffer.get(), sizeof(real) * to_convert));
+ if (sizeof(real) == 2) {
+ THP_decodeInt16Buffer((int16_t*)data + i,
+ le_buffer.get(),
+ THPByteOrder::THP_LITTLE_ENDIAN,
+ to_convert);
+ } else if (sizeof(real) == 4) {
+ THP_decodeInt32Buffer((int32_t*)data + i,
+ le_buffer.get(),
+ THPByteOrder::THP_LITTLE_ENDIAN,
+ to_convert);
+ } else if (sizeof(real) == 8) {
+ THP_decodeInt64Buffer((int64_t*)data + i,
+ le_buffer.get(),
+ THPByteOrder::THP_LITTLE_ENDIAN,
+ to_convert);
+ }
+ }
+ }
#ifdef THC_GENERIC_FILE
THCudaCheck(cudaMemcpy(storage->data, data, size * sizeof(real), cudaMemcpyHostToDevice));
diff --git a/torch/serialization.py b/torch/serialization.py
index a026a4a..6358ab7 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -12,6 +12,7 @@
import pickle
import torch
+from ._utils import _import_dotted_name
DEFAULT_PROTOCOL = 2
@@ -36,8 +37,78 @@
shutil.rmtree(path)
+_package_registry = []
+
+
+def register_package(priority, tagger, deserializer):
+ queue_elem = (priority, tagger, deserializer)
+ _package_registry.append(queue_elem)
+ _package_registry.sort()
+
+
+def _cpu_tag(obj):
+ if type(obj).__module__ == 'torch':
+ return 'cpu'
+
+
+def _cuda_tag(obj):
+ if type(obj).__module__ == 'torch.cuda':
+ return 'cuda:' + str(obj.get_device())
+
+
+def _cpu_deserialize(obj, location):
+ if location == 'cpu':
+ return obj
+
+
+def _cuda_deserialize(obj, location):
+ if location.startswith('cuda'):
+ device_id = max(int(location[5:]), 0)
+ return obj.cuda(device_id)
+
+
+register_package(10, _cpu_tag, _cpu_deserialize)
+register_package(20, _cuda_tag, _cuda_deserialize)
+
+
+def location_tag(storage):
+ for _, tagger, _ in _package_registry:
+ location = tagger(storage)
+ if location:
+ return location
+ raise RuntimeError("don't know how to determine data location of " +
+ torch.typename(storage))
+
+
+def default_restore_location(storage, location):
+ for _, _, fn in _package_registry:
+ result = fn(storage, location)
+ if result is not None:
+ return result
+ raise RuntimeError("don't know how to restore data location of " +
+ torch.typename(storage) + " (tagged with " + location + ")")
+
+
+def normalize_storage_type(storage_type):
+ return getattr(torch, storage_type.__name__)
+
+
+def storage_to_tensor_type(storage):
+ storage_type = type(storage)
+ module = _import_dotted_name(storage_type.__module__)
+ return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
+
+
# TODO: choose pickle protocol
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
+ """Saves an object to a disk file.
+
+ Args:
+ obj: saved object
+ f: a file-like object (has to implement fileno that returns a file descriptor)
+ pickle_module: module used for pickling metadata and objects
+ pickle_protocol: can be specified to override the default protocol
+ """
serialized_tensors = {}
serialized_storages = {}
@@ -60,7 +131,8 @@
else:
storage_id = None
- pickle_module.dump((key, type(tensor), storage_id), f, protocol=pickle_protocol)
+ pickle_module.dump((key, storage_id, type(tensor)), f,
+ protocol=pickle_protocol)
f.flush()
tensor._write_metadata(f)
@@ -80,7 +152,10 @@
pickle_module.dump(len(serialized_storages), f, protocol=pickle_protocol)
for key, storage in serialized_storages.items():
- pickle_module.dump((key, type(storage)), f, protocol=pickle_protocol)
+ location = location_tag(storage)
+ storage_type = normalize_storage_type(type(storage))
+ pickle_module.dump((key, location, storage_type), f,
+ protocol=pickle_protocol)
f.flush()
storage._write_file(f)
@@ -110,26 +185,58 @@
_add_to_tar(save_storages, tar, 'storages')
-def load(f, pickle_module=pickle):
+def load(f, map_location=None, pickle_module=pickle):
+ """Loads an object saved with torch.save from a disk file.
+
+ torch.load can dynamically remap storages to be loaded on a different device
+ using the map_location argument. If it's a callable, it will be called with
+ two arguments: storage and location tag. It's expected to either return a
+ storage that's been moved to a different location, or None (and the location
+ will be resolved using the default method). If this argument is a dict it's
+ expected to be a mapping from location tags used in a file, to location
+ tags of the current system.
+
+ By default the location tags are 'cpu' for host tensors and 'cuda:device_id'
+ (e.g. 'cuda:2') for cuda tensors. User extensions can register their own
+ tagging and deserialization methods using register_package.
+
+ Args:
+ f: a file-like object (has to implement fileno that returns a file descriptor)
+ map_location: a function or a dict specifying how to remap storage locations
+ pickle_module: module used for unpickling metadata and objects (has to match
+ the pickle_module used to serialize file)
+ """
deserialized_objects = {}
+ if map_location is None:
+ restore_location = default_restore_location
+ elif isinstance(map_location, dict):
+ def restore_location(storage, location):
+ location = map_location.get(location, location)
+ return default_restore_location(storage, location)
+ else:
+ def restore_location(storage, location):
+ result = map_location(storage, location)
+ if not result:
+ result = default_restore_location(storage, location)
+ return result
+
def persistent_load(saved_id):
return deserialized_objects[int(saved_id)]
with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
mkdtemp() as tmpdir:
- def extract(f, init):
+ tar.extract('storages', path=tmpdir)
+ with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
num_storages = pickle_module.load(f)
for i in range(num_storages):
args = pickle_module.load(f)
- key, args = args[0], args[1:]
- obj = init(*args)
+ key, location, storage_type = args
+ obj = storage_type._new_with_file(f)
+ obj = restore_location(obj, location)
deserialized_objects[key] = obj
- tar.extract('storages', path=tmpdir)
- with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
- extract(f, lambda storage_type: storage_type._new_with_file(f))
storage_views = pickle_module.load(f)
for target_cdata, root_cdata, offset, size in storage_views:
root = deserialized_objects[root_cdata]
@@ -137,10 +244,17 @@
tar.extract('tensors', path=tmpdir)
with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
- def deserialize_tensor(tensor_type, storage_id):
+ num_tensors = pickle_module.load(f)
+ for i in range(num_tensors):
+ args = pickle_module.load(f)
+ key, storage_id, original_tensor_type = args
storage = deserialized_objects.get(storage_id, None)
- return tensor_type._new_with_metadata_file(f, storage)
- extract(f, deserialize_tensor)
+ if storage:
+ tensor_type = storage_to_tensor_type(storage)
+ tensor = tensor_type._new_with_metadata_file(f, storage)
+ else:
+ tensor = original_tensor_type._new_with_metadata_file(f, storage)
+ deserialized_objects[key] = tensor
pickle_file = tar.extractfile('pickle')
unpickler = pickle_module.Unpickler(pickle_file)