Rewrite serialization to correctly handle partial reads/writes in all cases (#12143)

Summary:
Previously, doRead/doWrite were functions that could return partial reads/writes,
and we checked for this case inconsistently in the call sites of serialization.cpp.
Now, these functions do NOT return the amount of bytes read/written, and instead
handle the necessary checking loop themselves.

Fixes #12042. Maybe.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12143

Differential Revision: D10097027

Pulled By: ezyang

fbshipit-source-id: fd222ab8a825bed352153648ad396acfe124a3e1
diff --git a/torch/csrc/generic/serialization.cpp b/torch/csrc/generic/serialization.cpp
index 2299cce..1e4e7bf 100644
--- a/torch/csrc/generic/serialization.cpp
+++ b/torch/csrc/generic/serialization.cpp
@@ -2,8 +2,6 @@
 #define TH_GENERIC_FILE "generic/serialization.cpp"
 #else
 
-#define SYSCHECK(call) { ssize_t __result = call; if (__result < 0) throw std::system_error((int) __result, std::system_category()); }
-
 template <class io>
 void THPStorage_(writeFileRaw)(THWStorage *self, io fd)
 {
@@ -16,23 +14,10 @@
   data = (scalar_t*)cpu_data.get();
   THCudaCheck(cudaMemcpy(data, THWStorage_(data)(LIBRARY_STATE self), size * sizeof(scalar_t), cudaMemcpyDeviceToHost));
 #endif
-  ssize_t result = doWrite(fd, &size, sizeof(int64_t));
-  if (result != sizeof(int64_t))
-    throw std::system_error(result, std::system_category());
+  doWrite(fd, &size, sizeof(int64_t));
   // fast track for bytes and little endian
   if (sizeof(scalar_t) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) {
-    char *bytes = (char *) data;
-    int64_t remaining = sizeof(scalar_t) * size;
-    while (remaining > 0) {
-      // we write and read in 1GB blocks to avoid bugs on some OSes
-      ssize_t result = doWrite(fd, bytes, THMin(remaining, 1073741824));
-      if (result < 0)
-        throw std::system_error(result, std::system_category());
-      bytes += result;
-      remaining -= result;
-    }
-    if (remaining != 0)
-      throw std::system_error(result, std::system_category());
+    doWrite(fd, data, sizeof(scalar_t) * size);
   } else {
     int64_t buffer_size = std::min(size, (int64_t)5000);
     std::unique_ptr<uint8_t[]> le_buffer(new uint8_t[buffer_size * sizeof(scalar_t)]);
@@ -54,7 +39,7 @@
             THPByteOrder::THP_LITTLE_ENDIAN,
             to_convert);
       }
-      SYSCHECK(doWrite(fd, le_buffer.get(), to_convert * sizeof(scalar_t)));
+      doWrite(fd, le_buffer.get(), to_convert * sizeof(scalar_t));
     }
   }
 }
@@ -67,11 +52,7 @@
 {
   scalar_t *data;
   int64_t size;
-  ssize_t result = doRead(file, &size, sizeof(int64_t));
-  if (result == 0)
-    throw std::runtime_error("unexpected EOF. The file might be corrupted.");
-  if (result != sizeof(int64_t))
-    throw std::system_error(result, std::system_category());
+  doRead(file, &size, sizeof(int64_t));
   THWStoragePtr storage;
   if (_storage == nullptr) {
     storage = THWStorage_(newWithSize)(LIBRARY_STATE size);
@@ -91,20 +72,7 @@
 
   // fast track for bytes and little endian
   if (sizeof(scalar_t) == 1 || THP_nativeByteOrder() == THPByteOrder::THP_LITTLE_ENDIAN) {
-    char *bytes = (char *) data;
-    int64_t remaining = sizeof(scalar_t) * THWStorage_(size)(LIBRARY_STATE storage);
-    while (remaining > 0) {
-      // we write and read in 1GB blocks to avoid bugs on some OSes
-      ssize_t result = doRead(file, bytes, THMin(remaining, 1073741824));
-      if (result == 0) // 0 means EOF, which is also an error
-        throw std::runtime_error("unexpected EOF. The file might be corrupted.");
-      if (result < 0)
-        throw std::system_error(result, std::system_category());
-      bytes += result;
-      remaining -= result;
-    }
-    if (remaining != 0)
-      throw std::system_error(result, std::system_category());
+    doRead(file, data, sizeof(scalar_t) * THWStorage_(size)(LIBRARY_STATE storage));
   } else {
     int64_t buffer_size = std::min(size, (int64_t)5000);
     std::unique_ptr<uint8_t[]> le_buffer(new uint8_t[buffer_size * sizeof(scalar_t)]);
@@ -112,7 +80,7 @@
 
     for (int64_t i = 0; i < size; i += buffer_size) {
       size_t to_convert = std::min(size - i, buffer_size);
-      SYSCHECK(doRead(file, le_buffer.get(), sizeof(scalar_t) * to_convert));
+      doRead(file, le_buffer.get(), sizeof(scalar_t) * to_convert);
 
       if (sizeof(scalar_t) == 2) {
         THP_decodeInt16Buffer((int16_t*)data + i,
@@ -142,6 +110,4 @@
 template THWStorage* THPStorage_(readFileRaw<int>)(int fd, THWStorage* storage);
 template THWStorage* THPStorage_(readFileRaw<PyObject*>)(PyObject* fd, THWStorage* storage);
 
-#undef SYSCHECK
-
 #endif
diff --git a/torch/csrc/serialization.cpp b/torch/csrc/serialization.cpp
index eaf93b9..de98d27 100644
--- a/torch/csrc/serialization.cpp
+++ b/torch/csrc/serialization.cpp
@@ -4,34 +4,41 @@
 #include "THP.h"
 #include "serialization.h"
 
-static ssize_t doPythonReadBuffered(PyObject* fildes, void* buf, size_t nbytes);
-static ssize_t doPythonReadInto(PyObject* fildes, void* buf, size_t nbytes);
-static ssize_t doPythonWrite(PyObject* fildes, void* buf, size_t nbytes);
+template <class io>
+ssize_t doPartialRead(io fildes, void* buf, size_t nbytes);
+
+template <class io>
+ssize_t doPartialWrite(io fildes, void* buf, size_t nbytes);
+
+static ssize_t doPartialPythonReadBuffered(PyObject* fildes, void* buf, size_t nbytes);
+static ssize_t doPartialPythonReadInto(PyObject* fildes, void* buf, size_t nbytes);
+static ssize_t doPartialPythonWrite(PyObject* fildes, void* buf, size_t nbytes);
 
 template <>
-ssize_t doRead<int>(int fildes, void* buf, size_t nbytes) {
+ssize_t doPartialRead<int>(int fildes, void* buf, size_t nbytes) {
   return read(fildes, buf, nbytes);
 }
 
 template <>
-ssize_t doRead<PyObject*>(PyObject* fildes, void* buf, size_t nbytes) {
+ssize_t doPartialRead<PyObject*>(PyObject* fildes, void* buf, size_t nbytes) {
   // Try to use fildes.readinto() instead of fildes.read()
   // because it is more memory efficient.
+  // TODO: Stop calling PyObject_HasAttrString() in a loop on our read loop
   auto has_readinto = PyObject_HasAttrString(fildes, "readinto") == 1;
   if (has_readinto) {
-    return doPythonReadInto(fildes, buf, nbytes);
+    return doPartialPythonReadInto(fildes, buf, nbytes);
   }
-  return doPythonReadBuffered(fildes, buf, nbytes);
+  return doPartialPythonReadBuffered(fildes, buf, nbytes);
 }
 
 template <>
-ssize_t doWrite<int>(int fildes, void* buf, size_t nbytes) {
+ssize_t doPartialWrite<int>(int fildes, void* buf, size_t nbytes) {
   return write(fildes, buf, nbytes);
 }
 
 template <>
-ssize_t doWrite<PyObject*>(PyObject* fildes, void* buf, size_t nbytes) {
-  return doPythonWrite(fildes, buf, nbytes);
+ssize_t doPartialWrite<PyObject*>(PyObject* fildes, void* buf, size_t nbytes) {
+  return doPartialPythonWrite(fildes, buf, nbytes);
 }
 
 static inline bool isUnsupportedOperation() {
@@ -43,39 +50,39 @@
 }
 
 // Call Python fildes.read(nbytes) and copy it to buf.
-static inline ssize_t doPythonReadBuffered(PyObject* fildes, void* buf, size_t nbytes) {
-  const size_t buffer_size = 262144;  // 2^18
-  size_t read_bytes = 0;
+static inline ssize_t doPartialPythonReadBuffered(PyObject* fildes, void* buf, size_t raw_nbytes) {
+  // If we request a large amount of data, f.read() will internally try to
+  // allocate a buffer of that size.  This is counterproductive, because
+  // it's not the buffer we ultimately want to write the data into.  Read
+  // less than that and avoid allocating too much extra memory.
+  // TODO: Maybe 256 KB is a bit small...
+  const size_t nbytes = std::min<size_t>(raw_nbytes, 262144u); // 2^18 (256 KB)
 
-  while (read_bytes < nbytes) {
-    auto remaining = nbytes - read_bytes;
-    auto to_read = remaining > buffer_size ? buffer_size : remaining;
-    THPObjectPtr r(PyObject_CallMethod(fildes, "read", "i", to_read));
-    if (!r) throw python_error();
+  THPObjectPtr r(PyObject_CallMethod(fildes, "read", "i", nbytes));
+  if (!r) throw python_error();
 
-    // read output is String (Python 2) / Bytes (Python 3)
+  // read output is String (Python 2) / Bytes (Python 3)
 #if PY_MAJOR_VERSION >= 3
-    auto size = PyBytes_GET_SIZE(r.get());
-    const void* bytes = PyBytes_AsString(r.get());
+  auto size = PyBytes_GET_SIZE(r.get());
+  const void* py_buf = PyBytes_AsString(r.get());
 #else
-    auto size = PyString_GET_SIZE(r.get());
-    const void* bytes = PyString_AsString(r.get());
+  auto size = PyString_GET_SIZE(r.get());
+  const void* py_buf = PyString_AsString(r.get());
 #endif
 
-    // we read EOF
-    if (size == 0) {
-      return read_bytes;
-    }
+  // we read EOF
+  if (size == 0) {
+    return 0;
+  }
 
-    memcpy(reinterpret_cast<char*>(buf) + read_bytes, bytes, size);
-    read_bytes += size;
-  } // Reading loop
+  // Slurp it into the buffer we actually want
+  memcpy(buf, py_buf, size);
 
-  return read_bytes;
+  return size;
 }
 
 // Either does fildes.readinto(buf) or fildes.write(buf)
-static inline ssize_t doPythonIO(PyObject* fildes, void* buf, size_t nbytes, bool is_read) {
+static inline ssize_t doPartialPythonIO(PyObject* fildes, void* buf, size_t nbytes, bool is_read) {
 #if PY_MAJOR_VERSION >= 3
   auto rw_flag = is_read ? PyBUF_WRITE : PyBUF_READ;
   THPObjectPtr memview(PyMemoryView_FromMemory(
@@ -97,19 +104,77 @@
   // fildes.readinto can return UnsupportedOperation so fall back to fildes.read.
   if (is_read && isUnsupportedOperation()) {
     PyErr_Clear();
-    return doPythonReadBuffered(fildes, buf, nbytes);
+    return doPartialPythonReadBuffered(fildes, buf, nbytes);
   }
   throw python_error();
 }
 
 // Call Python fildes.readinto(buf)
-static ssize_t doPythonReadInto(PyObject* fildes, void* buf, size_t nbytes) {
-  return doPythonIO(fildes, buf, nbytes, /* is_read */ true);
+static ssize_t doPartialPythonReadInto(PyObject* fildes, void* buf, size_t nbytes) {
+  return doPartialPythonIO(fildes, buf, nbytes, /* is_read */ true);
 }
 
 // Call Python fildes.write(buf)
-static ssize_t doPythonWrite(PyObject* fildes, void* buf, size_t nbytes) {
-  return doPythonIO(fildes, buf, nbytes, /* is_read */ false);
+static ssize_t doPartialPythonWrite(PyObject* fildes, void* buf, size_t nbytes) {
+  return doPartialPythonIO(fildes, buf, nbytes, /* is_read */ false);
+}
+
+// Requires that we read EXACTLY nbytes; fails if we don't.
+template <typename io>
+void doRead(io fildes, void* raw_buf, size_t nbytes) {
+  char* buf = static_cast<char*>(raw_buf);
+  while (nbytes > 0) {
+    errno = 0; // doPartialRead may not set errno
+    // we read in 1GB blocks to avoid bugs on Mac OS X Lion
+    // see https://github.com/pytorch/pytorch/issues/1031 for more details
+    ssize_t r = doPartialRead(fildes, buf, std::min<size_t>(nbytes, 1073741824));
+    if (r < 0) {
+      int err = errno;
+      AT_ASSERTM(err != 0, "read(): impossible! r < 0, but no errno was set");
+      AT_ASSERTM(err != EAGAIN, "read(): non-blocking fd ", fildes,
+                                " read EAGAIN; cowardly refusing to spin-wait");
+      if (err == EINTR) {
+        continue;
+      } else {
+        AT_ERROR("read(): fd ", fildes, " failed with ", strerror(err));
+      }
+    } else if (r == 0) {
+      break;
+    }
+    buf += r;
+    // This is guaranteed by POSIX, but I just want to be double-sure
+    // to not underflow a signed integer.
+    AT_ASSERT(static_cast<size_t>(r) <= nbytes);
+    nbytes -= r;
+  }
+  if (nbytes != 0) {
+    AT_ERROR("unexpected EOF, expected ", nbytes, " more bytes. The file might be corrupted.");
+  }
+}
+
+template <typename io>
+void doWrite(io fildes, void* raw_buf, size_t nbytes) {
+  char* buf = static_cast<char*>(raw_buf);
+  while (nbytes > 0) {
+    errno = 0; // doPartialWrite may not set errno
+    // we write in 1GB blocks to avoid bugs on Mac OS X Lion
+    // see https://github.com/pytorch/pytorch/issues/1031 for more details
+    ssize_t r = doPartialWrite(fildes, buf, std::min<size_t>(nbytes, 1073741824));
+    if (r < 0) {
+      int err = errno;
+      AT_ASSERTM(err != 0, "write(): impossible! r < 0, but no errno was set");
+      AT_ASSERTM(err != EAGAIN, "write(): non-blocking fd ", fildes,
+                                " write EAGAIN; cowardly refusing to spin-wait");
+      if (err == EINTR) {
+        continue;
+      } else {
+        AT_ERROR("write(): fd ", fildes, " failed with ", strerror(err));
+      }
+    }
+    buf += r;
+    AT_ASSERT(static_cast<size_t>(r) <= nbytes);
+    nbytes -= r;
+  }
 }
 
 #include "generic/serialization.cpp"
diff --git a/torch/csrc/serialization.h b/torch/csrc/serialization.h
index 410619a..df81105 100644
--- a/torch/csrc/serialization.h
+++ b/torch/csrc/serialization.h
@@ -8,9 +8,9 @@
 #include <TH/THGenerateHalfType.h>
 
 template <class io>
-ssize_t doRead(io fildes, void* buf, size_t nbytes);
+void doRead(io fildes, void* buf, size_t nbytes);
 
 template <class io>
-ssize_t doWrite(io fildes, void* buf, size_t nbytes);
+void doWrite(io fildes, void* buf, size_t nbytes);
 
 #endif