Allow remapping storages at load time and serialize data in little-endian order
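
Storages are now tagged with a location at save time ('cpu' for host
storages, 'cuda:device_id' for CUDA storages), and torch.load accepts a
new map_location argument, either a callable or a dict, that remaps
those tags while the file is being loaded. User extensions can hook
into the same mechanism through register_package. On big-endian
machines, storage data is byte-swapped to little-endian order on write
and swapped back on read, so serialized files are portable across
platforms.

A rough usage sketch, mirroring the new tests (f is a file object that
torch.save has already written to):

    # dict form: remap storages saved on GPU 1 onto GPU 0
    x = torch.load(f, map_location={'cuda:1': 'cuda:0'})

    # callable form: returning None falls back to the default location
    def gpu_remap(storage, location):
        if location == 'cuda:1':
            return storage.cuda(0)

    x = torch.load(f, map_location=gpu_remap)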
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 0da0f21..797eaff 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -406,6 +406,68 @@
             self.assertEqual(x, y)
             self.assertEqual(torch.cuda.initial_seed(), 2)
 
+    def test_serialization(self):
+        x = torch.randn(4, 4).cuda()
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(x, f)
+            f.seek(0)
+            x_copy = torch.load(f)
+        self.assertEqual(x_copy, x)
+        self.assertIs(type(x_copy), type(x))
+        self.assertEqual(x_copy.get_device(), x.get_device())
+
+    def test_serialization_empty(self):
+        x = [torch.randn(4, 4).cuda(), torch.cuda.FloatTensor()]
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(x, f)
+            f.seek(0)
+            x_copy = torch.load(f)
+        for original, copy in zip(x, x_copy):
+            self.assertEqual(copy, original)
+            self.assertIs(type(copy), type(original))
+            self.assertEqual(copy.get_device(), original.get_device())
+
+    @unittest.skipIf(torch.cuda.device_count() < 2, "detected only one GPU")
+    def test_multigpu_serialization(self):
+        x = [torch.randn(4, 4).cuda(0), torch.randn(4, 4).cuda(1)]
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(x, f)
+            f.seek(0)
+            x_copy = torch.load(f)
+        for original, copy in zip(x, x_copy):
+            self.assertEqual(copy, original)
+            self.assertIs(type(copy), type(original))
+            self.assertEqual(copy.get_device(), original.get_device())
+
+    @unittest.skipIf(torch.cuda.device_count() < 2, "detected only one GPU")
+    def test_multigpu_serialization_remap(self):
+        x = [torch.randn(4, 4).cuda(0), torch.randn(4, 4).cuda(1)]
+        def gpu_remap(storage, location):
+            if location == 'cuda:1':
+                return storage.cuda(0)
+
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(x, f)
+            f.seek(0)
+            x_copy = torch.load(f, map_location=gpu_remap)
+
+        for original, copy in zip(x, x_copy):
+            self.assertEqual(copy, original)
+            self.assertIs(type(copy), type(original))
+            self.assertEqual(copy.get_device(), 0)
+
+    @unittest.skipIf(torch.cuda.device_count() < 2, "detected only one GPU")
+    def test_multigpu_serialization_remap_dict(self):
+        x = [torch.randn(4, 4).cuda(0), torch.randn(4, 4).cuda(1)]
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(x, f)
+            f.seek(0)
+            x_copy = torch.load(f, map_location={'cuda:1': 'cuda:0'})
+        for original, copy in zip(x, x_copy):
+            self.assertEqual(copy, original)
+            self.assertIs(type(copy), type(original))
+            self.assertEqual(copy.get_device(), 0)
+
 
 for decl in tests:
     for t in types:
diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp
index e031f10..5fe10df 100644
--- a/torch/csrc/Storage.cpp
+++ b/torch/csrc/Storage.cpp
@@ -5,7 +5,6 @@
 #include <TH/TH.h>
 #include <libshm.h>
 #include "THP.h"
-#include "byte_order.h"
 
 #include "generic/Storage.cpp"
 #include <TH/THGenerateAllTypes.h>
diff --git a/torch/csrc/THP.h b/torch/csrc/THP.h
index a628a05..440523a 100644
--- a/torch/csrc/THP.h
+++ b/torch/csrc/THP.h
@@ -25,6 +25,7 @@
 #include "Tensor.h"
 #include "Module.h"
 #include "utils.h" // This requires defined Storage and Tensor types
+#include "byte_order.h"
 
 #ifdef _THP_CORE
 #include "serialization.h"
diff --git a/torch/csrc/byte_order.cpp b/torch/csrc/byte_order.cpp
index 014522d..10edc81 100644
--- a/torch/csrc/byte_order.cpp
+++ b/torch/csrc/byte_order.cpp
@@ -1,5 +1,7 @@
 #include "byte_order.h"
 
+#include <string.h>
+
 static inline uint16_t decodeUInt16LE(const uint8_t *data) {
   return (data[0]<<0) | (data[1]<<8);
 }
@@ -79,3 +81,71 @@
     src += sizeof(double);
   }
 }
+
+// Reverse the byte order of a single `size`-byte value in place.
+template<size_t size>
+static void swapBytes(uint8_t *ptr)
+{
+  uint8_t tmp;
+  for (size_t i = 0; i < size / 2; i++) {
+    tmp = ptr[i];
+    ptr[i] = ptr[size-1-i];
+    ptr[size-1-i] = tmp;
+  }
+}
+
+void THP_encodeInt16Buffer(uint8_t* dst, const int16_t* src, THPByteOrder order, size_t len)
+{
+  memcpy(dst, src, sizeof(int16_t) * len);
+  if (order != THP_nativeByteOrder()) {
+    for (size_t i = 0; i < len; i++) {
+      swapBytes<sizeof(int16_t)>(dst);
+      dst += sizeof(int16_t);
+    }
+  }
+}
+
+void THP_encodeInt32Buffer(uint8_t* dst, const int32_t* src, THPByteOrder order, size_t len)
+{
+  memcpy(dst, src, sizeof(int32_t) * len);
+  if (order != THP_nativeByteOrder()) {
+    for (size_t i = 0; i < len; i++) {
+      swapBytes<sizeof(int32_t)>(dst);
+      dst += sizeof(int32_t);
+    }
+  }
+}
+
+void THP_encodeInt64Buffer(uint8_t* dst, const int64_t* src, THPByteOrder order, size_t len)
+{
+  memcpy(dst, src, sizeof(int64_t) * len);
+  if (order != THP_nativeByteOrder()) {
+    for (size_t i = 0; i < len; i++) {
+      swapBytes<sizeof(int64_t)>(dst);
+      dst += sizeof(int64_t);
+    }
+  }
+}
+
+void THP_encodeFloatBuffer(uint8_t* dst, const float* src, THPByteOrder order, size_t len)
+{
+  memcpy(dst, src, sizeof(float) * len);
+  if (order != THP_nativeByteOrder()) {
+    for (size_t i = 0; i < len; i++) {
+      swapBytes<sizeof(float)>(dst);
+      dst += sizeof(float);
+    }
+  }
+}
+
+void THP_encodeDoubleBuffer(uint8_t* dst, const double* src, THPByteOrder order, size_t len)
+{
+  memcpy(dst, src, sizeof(double) * len);
+  if (order != THP_nativeByteOrder()) {
+    for (size_t i = 0; i < len; i++) {
+      swapBytes<sizeof(double)>(dst);
+      dst += sizeof(double);
+    }
+  }
+}
+
diff --git a/torch/csrc/byte_order.h b/torch/csrc/byte_order.h
index bf5041e..0dcb79b 100644
--- a/torch/csrc/byte_order.h
+++ b/torch/csrc/byte_order.h
@@ -1,3 +1,6 @@
+#ifndef THP_BYTE_ORDER_H
+#define THP_BYTE_ORDER_H
+
 #include <stdint.h>
 #include <stddef.h>
 
@@ -7,8 +10,17 @@
 };
 
 THPByteOrder THP_nativeByteOrder();
+
 void THP_decodeInt16Buffer(int16_t* dst, const uint8_t* src, THPByteOrder order, size_t len);
 void THP_decodeInt32Buffer(int32_t* dst, const uint8_t* src, THPByteOrder order, size_t len);
 void THP_decodeInt64Buffer(int64_t* dst, const uint8_t* src, THPByteOrder order, size_t len);
 void THP_decodeFloatBuffer(float* dst, const uint8_t* src, THPByteOrder order, size_t len);
 void THP_decodeDoubleBuffer(double* dst, const uint8_t* src, THPByteOrder order, size_t len);
+
+void THP_encodeInt16Buffer(uint8_t* dst, const int16_t* src, THPByteOrder order, size_t len);
+void THP_encodeInt32Buffer(uint8_t* dst, const int32_t* src, THPByteOrder order, size_t len);
+void THP_encodeInt64Buffer(uint8_t* dst, const int64_t* src, THPByteOrder order, size_t len);
+void THP_encodeFloatBuffer(uint8_t* dst, const float* src, THPByteOrder order, size_t len);
+void THP_encodeDoubleBuffer(uint8_t* dst, const double* src, THPByteOrder order, size_t len);
+
+#endif
diff --git a/torch/csrc/generic/serialization.cpp b/torch/csrc/generic/serialization.cpp
index cefb7de..1f0a8fe 100644
--- a/torch/csrc/generic/serialization.cpp
+++ b/torch/csrc/generic/serialization.cpp
@@ -38,7 +38,33 @@
   THCudaCheck(cudaMemcpy(data, self->data, self->size * sizeof(real), cudaMemcpyDeviceToHost));
 #endif
   SYSCHECK(write(fd, &self->size, sizeof(long)));
-  SYSCHECK(write(fd, data, sizeof(real) * self->size));
+  // fast path: single-byte types and little-endian hosts need no conversion
+  if (sizeof(real) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) {
+    SYSCHECK(write(fd, data, sizeof(real) * self->size));
+  } else {
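+    // Big-endian host: byte-swap each chunk into a little-endian staging
+    // buffer before writing, so the temporary allocation stays bounded
+    // regardless of the storage size.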
+    long buffer_size = std::min(self->size, (long)5000);
+    std::unique_ptr<uint8_t[]> le_buffer(new uint8_t[buffer_size * sizeof(real)]);
+    for (long i = 0; i < self->size; i += buffer_size) {
+      size_t to_convert = std::min(self->size - i, buffer_size);
+      if (sizeof(real) == 2) {
+        THP_encodeInt16Buffer((uint8_t*)le_buffer.get(),
+            (const int16_t*)data + i,
+            THPByteOrder::THP_LITTLE_ENDIAN,
+            to_convert);
+      } else if (sizeof(real) == 4) {
+        THP_encodeInt32Buffer((uint8_t*)le_buffer.get(),
+            (const int32_t*)data + i,
+            THPByteOrder::THP_LITTLE_ENDIAN,
+            to_convert);
+      } else if (sizeof(real) == 8) {
+        THP_encodeInt64Buffer((uint8_t*)le_buffer.get(),
+            (const int64_t*)data + i,
+            THPByteOrder::THP_LITTLE_ENDIAN,
+            to_convert);
+      }
+      SYSCHECK(write(fd, le_buffer.get(), to_convert * sizeof(real)));
+    }
+  }
 }
 
 THStorage * THPStorage_(readFileRaw)(int fd)
@@ -55,7 +81,33 @@
   data = (real*)cpu_data.get();
 #endif
 
-  SYSCHECK(read(fd, data, sizeof(real) * storage->size));
+  // fast path: single-byte types and little-endian hosts need no conversion
+  if (sizeof(real) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) {
+    SYSCHECK(read(fd, data, sizeof(real) * storage->size));
+  } else {
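+    // Big-endian host: the file is always little-endian, so read each chunk
+    // into a staging buffer and byte-swap it into the native order.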
+    long buffer_size = std::min(size, (long)5000);
+    std::unique_ptr<uint8_t[]> le_buffer(new uint8_t[buffer_size * sizeof(real)]);
+    for (long i = 0; i < size; i += buffer_size) {
+      size_t to_convert = std::min(size - i, buffer_size);
+      SYSCHECK(read(fd, le_buffer.get(), sizeof(real) * to_convert));
+      if (sizeof(real) == 2) {
+        THP_decodeInt16Buffer((int16_t*)data + i,
+            le_buffer.get(),
+            THPByteOrder::THP_LITTLE_ENDIAN,
+            to_convert);
+      } else if (sizeof(real) == 4) {
+        THP_decodeInt32Buffer((int32_t*)data + i,
+            le_buffer.get(),
+            THPByteOrder::THP_LITTLE_ENDIAN,
+            to_convert);
+      } else if (sizeof(real) == 8) {
+        THP_decodeInt64Buffer((int64_t*)data + i,
+            le_buffer.get(),
+            THPByteOrder::THP_LITTLE_ENDIAN,
+            to_convert);
+      }
+    }
+  }
 
 #ifdef THC_GENERIC_FILE
   THCudaCheck(cudaMemcpy(storage->data, data, size * sizeof(real), cudaMemcpyHostToDevice));
diff --git a/torch/serialization.py b/torch/serialization.py
index a026a4a..6358ab7 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -12,6 +12,7 @@
     import pickle
 
 import torch
+from ._utils import _import_dotted_name
 
 DEFAULT_PROTOCOL = 2
 
@@ -36,8 +37,78 @@
     shutil.rmtree(path)
 
 
+_package_registry = []
+
+
+def register_package(priority, tagger, deserializer):
+    queue_elem = (priority, tagger, deserializer)
+    _package_registry.append(queue_elem)
+    _package_registry.sort()
+
+
+def _cpu_tag(obj):
+    if type(obj).__module__ == 'torch':
+        return 'cpu'
+
+
+def _cuda_tag(obj):
+    if type(obj).__module__ == 'torch.cuda':
+        return 'cuda:' + str(obj.get_device())
+
+
+def _cpu_deserialize(obj, location):
+    if location == 'cpu':
+        return obj
+
+
+def _cuda_deserialize(obj, location):
+    if location.startswith('cuda'):
+        device_id = max(int(location[5:]), 0)
+        return obj.cuda(device_id)
+
+
+register_package(10, _cpu_tag, _cpu_deserialize)
+register_package(20, _cuda_tag, _cuda_deserialize)
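+
+# Third-party code can plug into the same registry. A sketch (the 'xpu'
+# names below are hypothetical, not part of torch):
+#
+#   def _xpu_tag(obj):
+#       if type(obj).__module__ == 'torch.xpu':
+#           return 'xpu'
+#
+#   def _xpu_deserialize(obj, location):
+#       if location == 'xpu':
+#           return obj.xpu()
+#
+#   register_package(30, _xpu_tag, _xpu_deserialize)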
+
+
+def location_tag(storage):
+    for _, tagger, _ in _package_registry:
+        location = tagger(storage)
+        if location:
+            return location
+    raise RuntimeError("don't know how to determine data location of " +
+            torch.typename(storage))
+
+
+def default_restore_location(storage, location):
+    for _, _, fn in _package_registry:
+        result = fn(storage, location)
+        if result is not None:
+            return result
+    raise RuntimeError("don't know how to restore data location of " +
+            torch.typename(storage) + " (tagged with " + location + ")")
+
+
+def normalize_storage_type(storage_type):
+    return getattr(torch, storage_type.__name__)
+
+
+def storage_to_tensor_type(storage):
+    storage_type = type(storage)
+    module = _import_dotted_name(storage_type.__module__)
+    return getattr(module, storage_type.__name__.replace('Storage', 'Tensor'))
+
+
 # TODO: choose pickle protocol
 def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
+    """Saves an object to a disk file.
+
+    Args:
+        obj: the object to be saved
+        f: a file-like object (has to implement a fileno method that
+            returns a file descriptor)
+        pickle_module: module used for pickling metadata and objects
+        pickle_protocol: can be specified to override the default protocol
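+
+    Example:
+        >>> import tempfile
+        >>> x = torch.randn(4, 4)
+        >>> with tempfile.NamedTemporaryFile() as f:
+        ...     torch.save(x, f)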
+    """
     serialized_tensors = {}
     serialized_storages = {}
 
@@ -60,7 +131,8 @@
             else:
                 storage_id = None
 
-            pickle_module.dump((key, type(tensor), storage_id), f, protocol=pickle_protocol)
+            pickle_module.dump((key, storage_id, type(tensor)), f,
+                    protocol=pickle_protocol)
             f.flush()
             tensor._write_metadata(f)
 
@@ -80,7 +152,10 @@
 
         pickle_module.dump(len(serialized_storages), f, protocol=pickle_protocol)
         for key, storage in serialized_storages.items():
-            pickle_module.dump((key, type(storage)), f, protocol=pickle_protocol)
+            location = location_tag(storage)
+            storage_type = normalize_storage_type(type(storage))
+            pickle_module.dump((key, location, storage_type), f,
+                    protocol=pickle_protocol)
             f.flush()
             storage._write_file(f)
 
@@ -110,26 +185,58 @@
         _add_to_tar(save_storages, tar, 'storages')
 
 
-def load(f, pickle_module=pickle):
+def load(f, map_location=None, pickle_module=pickle):
+    """Loads an object saved with torch.save from a disk file.
+
+    torch.load can dynamically remap storages to be loaded on a different device
+    using the map_location argument. If it's a callable, it will be called with
+    two arguments: storage and location tag. It's expected to either return a
+    storage that's been moved to a different location, or None (and the location
+    will be resolved using the default method). If this argument is a dict it's
+    expected to be a mapping from location tags used in a file, to location
+    tags of the current system.
+
+    By default the location tags are 'cpu' for host tensors and 'cuda:device_id'
+    (e.g. 'cuda:2') for cuda tensors. User extensions can register their own
+    tagging and deserialization methods using register_package.
+
+    Args:
+        f: a file-like object (has to implement fileno that returns a file descriptor)
+        map_location: a function or a dict specifying how to remap storage locations
+        pickle_module: module used for unpickling metadata and objects (has to match
+            the pickle_module used to serialize file)
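+
+    Example:
+        >>> torch.save(x, f)
+        >>> f.seek(0)
+        >>> # remap storages tagged 'cuda:1' onto 'cuda:0'
+        >>> x_copy = torch.load(f, map_location={'cuda:1': 'cuda:0'})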
+    """
     deserialized_objects = {}
 
+    if map_location is None:
+        restore_location = default_restore_location
+    elif isinstance(map_location, dict):
+        def restore_location(storage, location):
+            location = map_location.get(location, location)
+            return default_restore_location(storage, location)
+    else:
+        def restore_location(storage, location):
+            result = map_location(storage, location)
+            if result is None:
+                result = default_restore_location(storage, location)
+            return result
+
     def persistent_load(saved_id):
         return deserialized_objects[int(saved_id)]
 
     with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
          mkdtemp() as tmpdir:
 
-        def extract(f, init):
+        tar.extract('storages', path=tmpdir)
+        with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
             num_storages = pickle_module.load(f)
             for i in range(num_storages):
                 args = pickle_module.load(f)
-                key, args = args[0], args[1:]
-                obj = init(*args)
+                key, location, storage_type = args
+                obj = storage_type._new_with_file(f)
+                obj = restore_location(obj, location)
                 deserialized_objects[key] = obj
 
-        tar.extract('storages', path=tmpdir)
-        with open(os.path.join(tmpdir, 'storages'), 'rb', 0) as f:
-            extract(f, lambda storage_type: storage_type._new_with_file(f))
             storage_views = pickle_module.load(f)
             for target_cdata, root_cdata, offset, size in storage_views:
                 root = deserialized_objects[root_cdata]
@@ -137,10 +244,17 @@
 
         tar.extract('tensors', path=tmpdir)
         with open(os.path.join(tmpdir, 'tensors'), 'rb', 0) as f:
-            def deserialize_tensor(tensor_type, storage_id):
+            num_tensors = pickle_module.load(f)
+            for i in range(num_tensors):
+                args = pickle_module.load(f)
+                key, storage_id, original_tensor_type = args
                 storage = deserialized_objects.get(storage_id, None)
-                return tensor_type._new_with_metadata_file(f, storage)
-            extract(f, deserialize_tensor)
+                if storage is not None:
+                    tensor_type = storage_to_tensor_type(storage)
+                    tensor = tensor_type._new_with_metadata_file(f, storage)
+                else:
+                    tensor = original_tensor_type._new_with_metadata_file(f, storage)
+                deserialized_objects[key] = tensor
 
         pickle_file = tar.extractfile('pickle')
         unpickler = pickle_module.Unpickler(pickle_file)