Rewrite serialization to correctly handle partial reads/writes in all cases (#12143)
Summary:
Previously, doRead/doWrite were functions that could return partial reads/writes,
and we checked for this case inconsistently in the call sites of serialization.cpp.
Now, these functions do NOT return the number of bytes read/written, and instead
handle the necessary checking loop themselves.
Fixes #12042. Maybe.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12143
Differential Revision: D10097027
Pulled By: ezyang
fbshipit-source-id: fd222ab8a825bed352153648ad396acfe124a3e1
diff --git a/torch/csrc/generic/serialization.cpp b/torch/csrc/generic/serialization.cpp
index 2299cce..1e4e7bf 100644
--- a/torch/csrc/generic/serialization.cpp
+++ b/torch/csrc/generic/serialization.cpp
@@ -2,8 +2,6 @@
#define TH_GENERIC_FILE "generic/serialization.cpp"
#else
-#define SYSCHECK(call) { ssize_t __result = call; if (__result < 0) throw std::system_error((int) __result, std::system_category()); }
-
template <class io>
void THPStorage_(writeFileRaw)(THWStorage *self, io fd)
{
@@ -16,23 +14,10 @@
data = (scalar_t*)cpu_data.get();
THCudaCheck(cudaMemcpy(data, THWStorage_(data)(LIBRARY_STATE self), size * sizeof(scalar_t), cudaMemcpyDeviceToHost));
#endif
- ssize_t result = doWrite(fd, &size, sizeof(int64_t));
- if (result != sizeof(int64_t))
- throw std::system_error(result, std::system_category());
+ doWrite(fd, &size, sizeof(int64_t));
// fast track for bytes and little endian
if (sizeof(scalar_t) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) {
- char *bytes = (char *) data;
- int64_t remaining = sizeof(scalar_t) * size;
- while (remaining > 0) {
- // we write and read in 1GB blocks to avoid bugs on some OSes
- ssize_t result = doWrite(fd, bytes, THMin(remaining, 1073741824));
- if (result < 0)
- throw std::system_error(result, std::system_category());
- bytes += result;
- remaining -= result;
- }
- if (remaining != 0)
- throw std::system_error(result, std::system_category());
+ doWrite(fd, data, sizeof(scalar_t) * size);
} else {
int64_t buffer_size = std::min(size, (int64_t)5000);
std::unique_ptr<uint8_t[]> le_buffer(new uint8_t[buffer_size * sizeof(scalar_t)]);
@@ -54,7 +39,7 @@
THPByteOrder::THP_LITTLE_ENDIAN,
to_convert);
}
- SYSCHECK(doWrite(fd, le_buffer.get(), to_convert * sizeof(scalar_t)));
+ doWrite(fd, le_buffer.get(), to_convert * sizeof(scalar_t));
}
}
}
@@ -67,11 +52,7 @@
{
scalar_t *data;
int64_t size;
- ssize_t result = doRead(file, &size, sizeof(int64_t));
- if (result == 0)
- throw std::runtime_error("unexpected EOF. The file might be corrupted.");
- if (result != sizeof(int64_t))
- throw std::system_error(result, std::system_category());
+ doRead(file, &size, sizeof(int64_t));
THWStoragePtr storage;
if (_storage == nullptr) {
storage = THWStorage_(newWithSize)(LIBRARY_STATE size);
@@ -91,20 +72,7 @@
// fast track for bytes and little endian
if (sizeof(scalar_t) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) {
- char *bytes = (char *) data;
- int64_t remaining = sizeof(scalar_t) * THWStorage_(size)(LIBRARY_STATE storage);
- while (remaining > 0) {
- // we write and read in 1GB blocks to avoid bugs on some OSes
- ssize_t result = doRead(file, bytes, THMin(remaining, 1073741824));
- if (result == 0) // 0 means EOF, which is also an error
- throw std::runtime_error("unexpected EOF. The file might be corrupted.");
- if (result < 0)
- throw std::system_error(result, std::system_category());
- bytes += result;
- remaining -= result;
- }
- if (remaining != 0)
- throw std::system_error(result, std::system_category());
+ doRead(file, data, sizeof(scalar_t) * THWStorage_(size)(LIBRARY_STATE storage));
} else {
int64_t buffer_size = std::min(size, (int64_t)5000);
std::unique_ptr<uint8_t[]> le_buffer(new uint8_t[buffer_size * sizeof(scalar_t)]);
@@ -112,7 +80,7 @@
for (int64_t i = 0; i < size; i += buffer_size) {
size_t to_convert = std::min(size - i, buffer_size);
- SYSCHECK(doRead(file, le_buffer.get(), sizeof(scalar_t) * to_convert));
+ doRead(file, le_buffer.get(), sizeof(scalar_t) * to_convert);
if (sizeof(scalar_t) == 2) {
THP_decodeInt16Buffer((int16_t*)data + i,
@@ -142,6 +110,4 @@
template THWStorage* THPStorage_(readFileRaw<int>)(int fd, THWStorage* storage);
template THWStorage* THPStorage_(readFileRaw<PyObject*>)(PyObject* fd, THWStorage* storage);
-#undef SYSCHECK
-
#endif
diff --git a/torch/csrc/serialization.cpp b/torch/csrc/serialization.cpp
index eaf93b9..de98d27 100644
--- a/torch/csrc/serialization.cpp
+++ b/torch/csrc/serialization.cpp
@@ -4,34 +4,41 @@
#include "THP.h"
#include "serialization.h"
-static ssize_t doPythonReadBuffered(PyObject* fildes, void* buf, size_t nbytes);
-static ssize_t doPythonReadInto(PyObject* fildes, void* buf, size_t nbytes);
-static ssize_t doPythonWrite(PyObject* fildes, void* buf, size_t nbytes);
+template <class io>
+ssize_t doPartialRead(io fildes, void* buf, size_t nbytes);
+
+template <class io>
+ssize_t doPartialWrite(io fildes, void* buf, size_t nbytes);
+
+static ssize_t doPartialPythonReadBuffered(PyObject* fildes, void* buf, size_t nbytes);
+static ssize_t doPartialPythonReadInto(PyObject* fildes, void* buf, size_t nbytes);
+static ssize_t doPartialPythonWrite(PyObject* fildes, void* buf, size_t nbytes);
template <>
-ssize_t doRead<int>(int fildes, void* buf, size_t nbytes) {
+ssize_t doPartialRead<int>(int fildes, void* buf, size_t nbytes) {
return read(fildes, buf, nbytes);
}
template <>
-ssize_t doRead<PyObject*>(PyObject* fildes, void* buf, size_t nbytes) {
+ssize_t doPartialRead<PyObject*>(PyObject* fildes, void* buf, size_t nbytes) {
// Try to use fildes.readinto() instead of fildes.read()
// because it is more memory efficient.
+ // TODO: Stop calling PyObject_HasAttrString() in a loop on our read loop
auto has_readinto = PyObject_HasAttrString(fildes, "readinto") == 1;
if (has_readinto) {
- return doPythonReadInto(fildes, buf, nbytes);
+ return doPartialPythonReadInto(fildes, buf, nbytes);
}
- return doPythonReadBuffered(fildes, buf, nbytes);
+ return doPartialPythonReadBuffered(fildes, buf, nbytes);
}
template <>
-ssize_t doWrite<int>(int fildes, void* buf, size_t nbytes) {
+ssize_t doPartialWrite<int>(int fildes, void* buf, size_t nbytes) {
return write(fildes, buf, nbytes);
}
template <>
-ssize_t doWrite<PyObject*>(PyObject* fildes, void* buf, size_t nbytes) {
- return doPythonWrite(fildes, buf, nbytes);
+ssize_t doPartialWrite<PyObject*>(PyObject* fildes, void* buf, size_t nbytes) {
+ return doPartialPythonWrite(fildes, buf, nbytes);
}
static inline bool isUnsupportedOperation() {
@@ -43,39 +50,39 @@
}
// Call Python fildes.read(nbytes) and copy it to buf.
-static inline ssize_t doPythonReadBuffered(PyObject* fildes, void* buf, size_t nbytes) {
- const size_t buffer_size = 262144; // 2^18
- size_t read_bytes = 0;
+static inline ssize_t doPartialPythonReadBuffered(PyObject* fildes, void* buf, size_t raw_nbytes) {
+ // If we request a large amount of data, f.read() will internally try to
+ // allocate a buffer of that size. This is counterproductive, because
+ // it's not the buffer we ultimately want to write the data into. Read
+ // less than that and avoid allocating too much extra memory.
+ // TODO: Maybe 260 KB is a bit small...
+ const size_t nbytes = std::min<size_t>(raw_nbytes, 262144u); // 2^18 (~260 KB)
- while (read_bytes < nbytes) {
- auto remaining = nbytes - read_bytes;
- auto to_read = remaining > buffer_size ? buffer_size : remaining;
- THPObjectPtr r(PyObject_CallMethod(fildes, "read", "i", to_read));
- if (!r) throw python_error();
+ THPObjectPtr r(PyObject_CallMethod(fildes, "read", "i", nbytes));
+ if (!r) throw python_error();
- // read output is String (Python 2) / Bytes (Python 3)
+ // read output is String (Python 2) / Bytes (Python 3)
#if PY_MAJOR_VERSION >= 3
- auto size = PyBytes_GET_SIZE(r.get());
- const void* bytes = PyBytes_AsString(r.get());
+ auto size = PyBytes_GET_SIZE(r.get());
+ const void* py_buf = PyBytes_AsString(r.get());
#else
- auto size = PyString_GET_SIZE(r.get());
- const void* bytes = PyString_AsString(r.get());
+ auto size = PyString_GET_SIZE(r.get());
+ const void* py_buf = PyString_AsString(r.get());
#endif
- // we read EOF
- if (size == 0) {
- return read_bytes;
- }
+ // we read EOF
+ if (size == 0) {
+ return 0;
+ }
- memcpy(reinterpret_cast<char*>(buf) + read_bytes, bytes, size);
- read_bytes += size;
- } // Reading loop
+ // Slurp it into the buffer we actually want
+ memcpy(buf, py_buf, size);
- return read_bytes;
+ return size;
}
// Either does fildes.readinto(buf) or fildes.write(buf)
-static inline ssize_t doPythonIO(PyObject* fildes, void* buf, size_t nbytes, bool is_read) {
+static inline ssize_t doPartialPythonIO(PyObject* fildes, void* buf, size_t nbytes, bool is_read) {
#if PY_MAJOR_VERSION >= 3
auto rw_flag = is_read ? PyBUF_WRITE : PyBUF_READ;
THPObjectPtr memview(PyMemoryView_FromMemory(
@@ -97,19 +104,77 @@
// fildes.readinto can return UnsupportedOperation so fall back to fildes.read.
if (is_read && isUnsupportedOperation()) {
PyErr_Clear();
- return doPythonReadBuffered(fildes, buf, nbytes);
+ return doPartialPythonReadBuffered(fildes, buf, nbytes);
}
throw python_error();
}
// Call Python fildes.readinto(buf)
-static ssize_t doPythonReadInto(PyObject* fildes, void* buf, size_t nbytes) {
- return doPythonIO(fildes, buf, nbytes, /* is_read */ true);
+static ssize_t doPartialPythonReadInto(PyObject* fildes, void* buf, size_t nbytes) {
+ return doPartialPythonIO(fildes, buf, nbytes, /* is_read */ true);
}
// Call Python fildes.write(buf)
-static ssize_t doPythonWrite(PyObject* fildes, void* buf, size_t nbytes) {
- return doPythonIO(fildes, buf, nbytes, /* is_read */ false);
+static ssize_t doPartialPythonWrite(PyObject* fildes, void* buf, size_t nbytes) {
+ return doPartialPythonIO(fildes, buf, nbytes, /* is_read */ false);
+}
+
+// Requires that we read EXACTLY nbytes; fails if we don't.
+template <typename io>
+void doRead(io fildes, void* raw_buf, size_t nbytes) {
+ char* buf = static_cast<char*>(raw_buf);
+ while (nbytes > 0) {
+ errno = 0; // doPartialRead may not set errno
+ // we read in 1GB blocks to avoid bugs on Mac OS X Lion
+ // see https://github.com/pytorch/pytorch/issues/1031 for more details
+ ssize_t r = doPartialRead(fildes, buf, std::min<size_t>(nbytes, 1073741824));
+ if (r < 0) {
+ int err = errno;
+ AT_ASSERTM(err != 0, "read(): impossible! r < 0, but no errno was set");
+ AT_ASSERTM(err != EAGAIN, "read(): non-blocking fd ", fildes,
+ " read EAGAIN; cowardly refusing to spin-wait");
+ if (err == EINTR) {
+ continue;
+ } else {
+ AT_ERROR("read(): fd ", fildes, " failed with ", strerror(err));
+ }
+ } else if (r == 0) {
+ break;
+ }
+ buf += r;
+ // This is guaranteed by POSIX, but I just want to be double-sure
+ // to not underflow a signed integer.
+ AT_ASSERT(static_cast<size_t>(r) <= nbytes);
+ nbytes -= r;
+ }
+ if (nbytes != 0) {
+ AT_ERROR("unexpected EOF, expected ", nbytes, " more bytes. The file might be corrupted.");
+ }
+}
+
+template <typename io>
+void doWrite(io fildes, void* raw_buf, size_t nbytes) {
+ char* buf = static_cast<char*>(raw_buf);
+ while (nbytes > 0) {
+ errno = 0; // doPartialWrite may not set errno
+ // we write in 1GB blocks to avoid bugs on Mac OS X Lion
+ // see https://github.com/pytorch/pytorch/issues/1031 for more details
+ ssize_t r = doPartialWrite(fildes, buf, std::min<size_t>(nbytes, 1073741824));
+ if (r < 0) {
+ int err = errno;
+ AT_ASSERTM(err != 0, "write(): impossible! r < 0, but no errno was set");
+ AT_ASSERTM(err != EAGAIN, "write(): non-blocking fd ", fildes,
+               " wrote EAGAIN; cowardly refusing to spin-wait");
+ if (err == EINTR) {
+ continue;
+ } else {
+ AT_ERROR("write(): fd ", fildes, " failed with ", strerror(err));
+ }
+ }
+ buf += r;
+ AT_ASSERT(static_cast<size_t>(r) <= nbytes);
+ nbytes -= r;
+ }
}
#include "generic/serialization.cpp"
diff --git a/torch/csrc/serialization.h b/torch/csrc/serialization.h
index 410619a..df81105 100644
--- a/torch/csrc/serialization.h
+++ b/torch/csrc/serialization.h
@@ -8,9 +8,9 @@
#include <TH/THGenerateHalfType.h>
template <class io>
-ssize_t doRead(io fildes, void* buf, size_t nbytes);
+void doRead(io fildes, void* buf, size_t nbytes);
template <class io>
-ssize_t doWrite(io fildes, void* buf, size_t nbytes);
+void doWrite(io fildes, void* buf, size_t nbytes);
#endif