Extract RunInitProcess and SendPid/RecvPid

Also properly check status of send and use one-byte messages
to avoid issues with partial send, receive.

PiperOrigin-RevId: 258362495
Change-Id: I889b4699c100c80d15b129bf3a254f5442405bc2
diff --git a/sandboxed_api/sandbox2/BUILD.bazel b/sandboxed_api/sandbox2/BUILD.bazel
index 2f9f488..c478dfa 100644
--- a/sandboxed_api/sandbox2/BUILD.bazel
+++ b/sandboxed_api/sandbox2/BUILD.bazel
@@ -393,6 +393,8 @@
         "//sandboxed_api/sandbox2/util:fileops",
         "//sandboxed_api/sandbox2/util:strerror",
         "//sandboxed_api/util:raw_logging",
+        "//sandboxed_api/util:status",
+        "//sandboxed_api/util:statusor",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/synchronization",
diff --git a/sandboxed_api/sandbox2/forkserver.cc b/sandboxed_api/sandbox2/forkserver.cc
index f4ac4d6..eb4f356 100644
--- a/sandboxed_api/sandbox2/forkserver.cc
+++ b/sandboxed_api/sandbox2/forkserver.cc
@@ -56,6 +56,9 @@
 #include "sandboxed_api/sandbox2/util/fileops.h"
 #include "sandboxed_api/sandbox2/util/strerror.h"
 #include "sandboxed_api/util/raw_logging.h"
+#include "sandboxed_api/util/canonical_errors.h"
+#include "sandboxed_api/util/status.h"
+#include "sandboxed_api/util/statusor.h"
 
 namespace {
 // "Moves" the old FD to the new FD number.
@@ -74,6 +77,104 @@
 
   *old_fd = new_fd;
 }
+
+void RunInitProcess(std::set<int> open_fds) {
+  if (prctl(PR_SET_NAME, "S2-INIT-PROC", 0, 0, 0) != 0) {
+    SAPI_RAW_PLOG(WARNING, "prctl(PR_SET_NAME, 'S2-INIT-PROC')");
+  }
+  // Close all open fds (equals to CloseAllFDsExcept but does not require /proc
+  // to be available).
+  for (const auto& fd : open_fds) {
+    close(fd);
+  }
+
+  // Apply seccomp.
+  struct sock_filter code[] = {
+      LOAD_ARCH,
+      JNE32(sandbox2::Syscall::GetHostAuditArch(), DENY),
+
+      LOAD_SYSCALL_NR,
+#ifdef __NR_waitpid
+      SYSCALL(__NR_waitpid, ALLOW),
+#endif
+      SYSCALL(__NR_wait4, ALLOW),
+      SYSCALL(__NR_exit, ALLOW),
+      SYSCALL(__NR_exit_group, ALLOW),
+      DENY,
+  };
+
+  struct sock_fprog prog {};
+  prog.len = ABSL_ARRAYSIZE(code);
+  prog.filter = code;
+
+  SAPI_RAW_CHECK(prctl(PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0) == 0,
+                 "Denying new privs");
+  SAPI_RAW_CHECK(prctl(PR_SET_KEEPCAPS, 0) == 0, "Dropping caps");
+  SAPI_RAW_CHECK(
+      syscall(__NR_seccomp, SECCOMP_SET_MODE_FILTER, SECCOMP_FILTER_FLAG_TSYNC,
+              reinterpret_cast<uintptr_t>(&prog)) == 0,
+      "Enabling seccomp filter");
+
+  pid_t pid;
+  int status = 0;
+
+  // Reap children.
+  while (true) {
+    // Wait until we don't have any children anymore.
+    // We cannot watch for the child pid as ptrace steals our waitpid
+    // notifications. (See man ptrace / man waitpid).
+    pid = TEMP_FAILURE_RETRY(waitpid(-1, &status, __WALL));
+    if (pid < 0) {
+      if (errno == ECHILD) {
+        _exit(0);
+      }
+      _exit(1);
+    }
+  }
+}
+
+::sapi::Status SendPid(int signaling_fd) {
+  // Send our PID (the actual sandboxee process) via SCM_CREDENTIALS.
+  // The ancillary message will be attached to the message as SO_PASSCRED is set
+  // on the socket.
+  char dummy = ' ';
+  if (TEMP_FAILURE_RETRY(send(signaling_fd, &dummy, 1, 0)) != 1) {
+    return ::sapi::InternalError(
+        absl::StrCat("Sending PID: send: ", sandbox2::StrError(errno)));
+  }
+  return ::sapi::OkStatus();
+}
+
+::sapi::StatusOr<pid_t> ReceivePid(int signaling_fd) {
+  union {
+    struct cmsghdr cmh;
+    char ctrl[CMSG_SPACE(sizeof(struct ucred))];
+  } ucred_msg{};
+
+  struct msghdr msgh {};
+  struct iovec iov {};
+
+  msgh.msg_iov = &iov;
+  msgh.msg_iovlen = 1;
+  msgh.msg_control = ucred_msg.ctrl;
+  msgh.msg_controllen = sizeof(ucred_msg);
+
+  char dummy;
+  iov.iov_base = &dummy;
+  iov.iov_len = sizeof(char);
+
+  if (TEMP_FAILURE_RETRY(recvmsg(signaling_fd, &msgh, MSG_WAITALL)) != 1) {
+    return ::sapi::InternalError(absl::StrCat("Receiving pid failed: recvmsg: ",
+                                              sandbox2::StrError(errno)));
+  }
+  struct cmsghdr* cmsgp = CMSG_FIRSTHDR(&msgh);
+  if (cmsgp->cmsg_len != CMSG_LEN(sizeof(struct ucred)) ||
+      cmsgp->cmsg_level != SOL_SOCKET || cmsgp->cmsg_type != SCM_CREDENTIALS) {
+    return ::sapi::InternalError("Receiving pid failed");
+  }
+  struct ucred* ucredp = reinterpret_cast<struct ucred*>(CMSG_DATA(cmsgp));
+  return ucredp->pid;
+}
 }  // namespace
 
 namespace sandbox2 {
@@ -153,84 +254,6 @@
                 absl::StrJoin(*args, "', '"), absl::StrJoin(*envp, "', '"));
 }
 
-static void RunInitProcess(int signaling_fd, std::set<int> open_fds) {
-  // Spawn a child process and wait until it is dead.
-  pid_t child = fork();
-  if (child < 0) {
-    SAPI_RAW_LOG(FATAL, "Could not spawn init process");
-  } else if (child == 0) {
-    // Send our PID (the actual sandboxee process) via SCM_CREDENTIALS.
-    struct msghdr msgh {};
-    struct iovec iov {};
-    msgh.msg_name = nullptr;
-    msgh.msg_namelen = 0;
-
-    msgh.msg_iov = &iov;
-    msgh.msg_iovlen = 1;
-    int data = 1;
-    iov.iov_base = &data;
-    iov.iov_len = sizeof(int);
-    msgh.msg_control = nullptr;
-    msgh.msg_controllen = 0;
-    SAPI_RAW_CHECK(sendmsg(signaling_fd, &msgh, 0), "Sending child PID");
-    return;
-  } else if (child > 0) {
-    if (prctl(PR_SET_NAME, "S2-INIT-PROC", 0, 0, 0) != 0) {
-      SAPI_RAW_PLOG(WARNING, "prctl(PR_SET_NAME, 'S2-INIT-PROC')");
-    }
-
-    // Close all open fds, do not use CloseAllFDsExcept as /proc might not be
-    // mounted here
-    for (const auto& fd : open_fds) {
-      close(fd);
-    }
-
-    // Apply seccomp.
-    struct sock_filter code[] = {
-        LOAD_ARCH,
-        JNE32(Syscall::GetHostAuditArch(), DENY),
-
-        LOAD_SYSCALL_NR,
-#ifdef __NR_waitpid
-        SYSCALL(__NR_waitpid, ALLOW),
-#endif
-        SYSCALL(__NR_wait4, ALLOW),
-        SYSCALL(__NR_exit, ALLOW),
-        SYSCALL(__NR_exit_group, ALLOW),
-        DENY,
-    };
-
-    struct sock_fprog prog {};
-    prog.len = ABSL_ARRAYSIZE(code);
-    prog.filter = code;
-
-    SAPI_RAW_CHECK(prctl(PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0) == 0,
-                   "Denying new privs");
-    SAPI_RAW_CHECK(prctl(PR_SET_KEEPCAPS, 0) == 0, "Dropping caps");
-    SAPI_RAW_CHECK(syscall(__NR_seccomp, SECCOMP_SET_MODE_FILTER,
-                           SECCOMP_FILTER_FLAG_TSYNC,
-                           reinterpret_cast<uintptr_t>(&prog)) == 0,
-                   "Enabling seccomp filter");
-
-    pid_t pid;
-    int status = 0;
-
-    // Reap children.
-    while (true) {
-      // Wait until we don't have any children anymore.
-      // We cannot watch for the child pid as ptrace steals our waitpid
-      // notifications. (See man ptrace / man waitpid).
-      pid = TEMP_FAILURE_RETRY(waitpid(-1, &status, __WALL));
-      if (pid < 0) {
-        if (errno == ECHILD) {
-          _exit(0);
-        }
-        _exit(1);
-      }
-    }
-  }
-}
-
 void ForkServer::LaunchChild(const ForkRequest& request, int execve_fd,
                              int client_fd, uid_t uid, gid_t gid,
                              int user_ns_fd, int signaling_fd) {
@@ -276,8 +299,21 @@
 
   // A custom init process is only needed if a new PID NS is created.
   if (request.clone_flags() & CLONE_NEWPID) {
-    RunInitProcess(signaling_fd, open_fds);
+    // Spawn a child process
+    pid_t child = fork();
+    if (child < 0) {
+      SAPI_RAW_PLOG(FATAL, "Could not spawn init process");
+    }
+    if (child != 0) {
+      RunInitProcess(open_fds);
+    }
+    // Send sandboxee pid
+    auto status = SendPid(signaling_fd);
+    if (!status.ok()) {
+      SAPI_RAW_LOG(FATAL, "%s", status.message());
+    }
   }
+
   if (request.mode() == FORKSERVER_FORK_EXECVE_SANDBOX ||
       request.mode() == FORKSERVER_FORK_JOIN_SANDBOX_UNWIND) {
     // Sandboxing can be enabled either here - just before execve, or somewhere
@@ -412,42 +448,19 @@
   fd_closer1.Close();
 
   if (fork_request.clone_flags() & CLONE_NEWPID) {
-    union {
-      struct cmsghdr cmh;
-      char ctrl[CMSG_SPACE(sizeof(struct ucred))];
-    } test_msg{};
-
-    struct msghdr msgh {};
-    struct iovec iov {};
-
-    msgh.msg_iov = &iov;
-    msgh.msg_iovlen = 1;
-    msgh.msg_control = test_msg.ctrl;
-    msgh.msg_controllen = sizeof(test_msg);
-
-    int data = 0;
-    iov.iov_base = &data;
-    iov.iov_len = sizeof(int);
-
     // The pid of the init process is equal to the child process that we've
     // previously forked.
     init_pid = sandboxee_pid;
-
-    // And the actual sandboxee will be forked from the init process, so we
-    // need to receive the actual PID.
-    struct cmsghdr* cmsgp = nullptr;
-    if (TEMP_FAILURE_RETRY(recvmsg(fd_closer0.get(), &msgh, MSG_WAITALL)) <=
-            0 ||
-        !(cmsgp = CMSG_FIRSTHDR(&msgh)) || /* Assigning here on purpose */
-        cmsgp->cmsg_len != CMSG_LEN(sizeof(struct ucred)) ||
-        cmsgp->cmsg_level != SOL_SOCKET ||
-        cmsgp->cmsg_type != SCM_CREDENTIALS) {
-      SAPI_RAW_LOG(ERROR, "Receiving sandboxee pid failed");
-      sandboxee_pid = -1;
+    sandboxee_pid = -1;
+    // And the actual sandboxee is forked from the init process, so we need to
+    // receive the actual PID.
+    auto pid_or = ReceivePid(fd_closer0.get());
+    if (!pid_or.ok()) {
+      SAPI_RAW_LOG(ERROR, "%s", pid_or.status().message());
       kill(init_pid, SIGKILL);
+      init_pid = -1;
     } else {
-      struct ucred* ucredp = reinterpret_cast<struct ucred*>(CMSG_DATA(cmsgp));
-      sandboxee_pid = ucredp->pid;
+      sandboxee_pid = pid_or.ValueOrDie();
     }
   }