Fix saving and loading pickle files on Big Endian systems (#95881)
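This change fixes the test/test_cpp_api_parity.py tests on big-endian systems.
The pickler and unpickler previously wrote and read multi-byte fields in host
byte order. The BININT2, BININT, LONG1 and BINUNICODE integer/length fields are
now converted to and from little-endian, and the BINFLOAT byte swap is applied
only on little-endian hosts, so the double is always stored big-endian.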
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95881
Approved by: https://github.com/malfet
diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp
index a46f434..fd5b8c5 100644
--- a/torch/csrc/jit/serialization/pickler.cpp
+++ b/torch/csrc/jit/serialization/pickler.cpp
@@ -7,6 +7,7 @@
#include <c10/util/irange.h>
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/csrc/jit/serialization/pickler.h>
+#include <torch/csrc/utils/byte_order.h>
#include <string>
#include <type_traits>
@@ -232,17 +233,17 @@
n >= std::numeric_limits<uint16_t>::min() &&
n <= std::numeric_limits<uint16_t>::max()) {
push<PickleOpCode>(PickleOpCode::BININT2);
- push<uint16_t>(n);
+ push<uint16_t>(to_le16(n));
} else if (
n >= std::numeric_limits<int32_t>::min() &&
n <= std::numeric_limits<int32_t>::max()) {
push<PickleOpCode>(PickleOpCode::BININT);
- push<int32_t>(n);
+ push<int32_t>(to_le32(n));
} else {
// Push 8 byte integer
push<PickleOpCode>(PickleOpCode::LONG1);
push<uint8_t>(8);
- push<int64_t>(n);
+ push<int64_t>(to_le64(n));
}
}
@@ -264,7 +265,7 @@
// unmemoized encoding of a string
void Pickler::pushStringImpl(const std::string& string) {
push<PickleOpCode>(PickleOpCode::BINUNICODE);
- push<uint32_t>(string.size());
+ push<uint32_t>(to_le32(string.size()));
pushBytes(string);
}
@@ -542,8 +543,12 @@
void Pickler::pushDouble(double value) {
push<PickleOpCode>(PickleOpCode::BINFLOAT);
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
// Python pickle format is big endian, swap.
push<double>(swapDouble(value));
+#else /* __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ */
+ push<double>(value);
+#endif /* __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ */
}
void Pickler::pushComplexDouble(const IValue& value) {
c10::complex<double> d = value.toComplexDouble();
diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp
index d1f5379..edb8449 100644
--- a/torch/csrc/jit/serialization/unpickler.cpp
+++ b/torch/csrc/jit/serialization/unpickler.cpp
@@ -8,6 +8,7 @@
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/jit/serialization/storage_context.h>
#include <torch/csrc/jit/serialization/unpickler.h>
+#include <torch/csrc/utils/byte_order.h>
#include <string>
namespace torch::jit {
@@ -210,6 +211,7 @@
double Unpickler::readFloat() {
AT_ASSERT(sizeof(double) == 8);
double big_endian = read<double>();
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
double little_endian;
@@ -221,6 +223,9 @@
reinterpret_cast<char*>(&little_endian));
return little_endian;
+#else /* __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ */
+ return big_endian;
+#endif /* __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ */
}
void Unpickler::run() {
@@ -323,21 +328,21 @@
stack_.emplace_back(int64_t(value));
} break;
case PickleOpCode::BININT2: {
- uint16_t value = read<uint16_t>();
+ uint16_t value = from_le16(read<uint16_t>());
stack_.emplace_back(int64_t(value));
} break;
case PickleOpCode::BININT: {
- int32_t value = read<int32_t>();
+ int32_t value = from_le32(read<int32_t>());
stack_.emplace_back(int64_t(value));
} break;
case PickleOpCode::LONG1: {
// Only read LONG1s with 8 as the length
uint8_t length = read<uint8_t>();
TORCH_CHECK(length == 8, "Expected length to be 8, got ", int(length));
- stack_.emplace_back(int64_t(read<int64_t>()));
+ stack_.emplace_back(int64_t(from_le64(read<int64_t>())));
} break;
case PickleOpCode::BINUNICODE: {
- uint32_t length = read<uint32_t>();
+ uint32_t length = from_le32(read<uint32_t>());
stack_.emplace_back(readBytes(length));
} break;
case PickleOpCode::BINFLOAT:
diff --git a/torch/csrc/utils/byte_order.h b/torch/csrc/utils/byte_order.h
index 60aa8fc..8588bee 100644
--- a/torch/csrc/utils/byte_order.h
+++ b/torch/csrc/utils/byte_order.h
@@ -6,6 +6,56 @@
#include <cstddef>
#include <cstdint>
+#ifdef __FreeBSD__
+#include <sys/endian.h>
+#include <sys/types.h>
+#define thp_bswap16(x) bswap16(x)
+#define thp_bswap32(x) bswap32(x)
+#define thp_bswap64(x) bswap64(x)
+#elif defined(__APPLE__)
+#include <libkern/OSByteOrder.h>
+#define thp_bswap16(x) OSSwapInt16(x)
+#define thp_bswap32(x) OSSwapInt32(x)
+#define thp_bswap64(x) OSSwapInt64(x)
+#elif defined(__GNUC__) && !defined(__MINGW32__)
+#include <byteswap.h>
+#define thp_bswap16(x) bswap_16(x)
+#define thp_bswap32(x) bswap_32(x)
+#define thp_bswap64(x) bswap_64(x)
+#elif defined _WIN32 || defined _WIN64
+#define thp_bswap16(x) _byteswap_ushort(x)
+#define thp_bswap32(x) _byteswap_ulong(x)
+#define thp_bswap64(x) _byteswap_uint64(x)
+#endif
+
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
+#define to_be16(x) thp_bswap16(x)
+#define from_be16(x) thp_bswap16(x)
+#define to_be32(x) thp_bswap32(x)
+#define from_be32(x) thp_bswap32(x)
+#define to_be64(x) thp_bswap64(x)
+#define from_be64(x) thp_bswap64(x)
+#define to_le16(x) (x)
+#define from_le16(x) (x)
+#define to_le32(x) (x)
+#define from_le32(x) (x)
+#define to_le64(x) (x)
+#define from_le64(x) (x)
+#else
+#define to_be16(x) (x)
+#define from_be16(x) (x)
+#define to_be32(x) (x)
+#define from_be32(x) (x)
+#define to_be64(x) (x)
+#define from_be64(x) (x)
+#define to_le16(x) thp_bswap16(x)
+#define from_le16(x) thp_bswap16(x)
+#define to_le32(x) thp_bswap32(x)
+#define from_le32(x) thp_bswap32(x)
+#define to_le64(x) thp_bswap64(x)
+#define from_le64(x) thp_bswap64(x)
+#endif
+
namespace torch {
namespace utils {