Slimmer rewrite of socket_forward_proxy

Key differences:
 1. No QueueState
 2. No sequence number
 3. No generation number

Instead, the guest-side monitors all queues for new connections and the
host-side keeps track of which queues are allocated

design: https://docs.google.com/document/d/1z43c9LGeEEU6G-ojNtEeQP9-ezK890MU3GP6df3byYs

Change-Id: If0396de2ef8080ed78e7afc36ac0d661f99b6d3c
Merged-In: If0396de2ef8080ed78e7afc36ac0d661f99b6d3c
Bug: 80104636
Bug: 110707067
Test: run local while restarting the host-side process and guest-side
(cherry picked from commit b5fb338315f4d82807cda42242d4f1488ec613fe)
diff --git a/common/frontend/socket_forward_proxy/main.cpp b/common/frontend/socket_forward_proxy/main.cpp
index e361059..b2c16f6 100644
--- a/common/frontend/socket_forward_proxy/main.cpp
+++ b/common/frontend/socket_forward_proxy/main.cpp
@@ -14,9 +14,11 @@
  * limitations under the License.
  */
 
+#include <array>
 #include <cstdint>
 #include <cstdlib>
 #include <iostream>
+#include <limits>
 #include <memory>
 #include <mutex>
 #include <sstream>
@@ -114,14 +116,11 @@
 };
 
 void SocketToShm(SocketReceiver socket_receiver,
-                 SocketForwardRegionView::Sender shm_sender) {
-  auto packet = Packet::MakeData();
+                 SocketForwardRegionView::ShmSender shm_sender) {
   while (true) {
+    auto packet = Packet::MakeData();
     socket_receiver.Recv(&packet);
-    if (packet.empty()) {
-      break;
-    }
-    if (!shm_sender.Send(packet)) {
+    if (packet.empty() || !shm_sender.Send(packet)) {
       break;
     }
   }
@@ -129,11 +128,12 @@
 }
 
 void ShmToSocket(SocketSender socket_sender,
-                 SocketForwardRegionView::Receiver shm_receiver) {
-  Packet packet{};
+                 SocketForwardRegionView::ShmReceiver shm_receiver) {
+  auto packet = Packet{};
   while (true) {
     shm_receiver.Recv(&packet);
-    if (packet.IsEnd()) {
+    CHECK(packet.IsData());
+    if (packet.empty()) {
       break;
     }
     if (socket_sender.SendAll(packet) < 0) {
@@ -145,15 +145,12 @@
 
 // One thread for reading from shm and writing into a socket.
 // One thread for reading from a socket and writing into shm.
-void LaunchWorkers(std::pair<SocketForwardRegionView::Sender,
-                             SocketForwardRegionView::Receiver>
-                       conn,
-                   cvd::SharedFD socket) {
-  // TODO create the SocketSender/Receiver in their respective threads?
-  std::thread(
-      SocketToShm, SocketReceiver{socket}, std::move(conn.first)).detach();
-  std::thread(
-      ShmToSocket, SocketSender{socket}, std::move(conn.second)).detach();
+void HandleConnection(SocketForwardRegionView::ShmSenderReceiverPair shm_sender_and_receiver,
+                      cvd::SharedFD socket) {
+  auto socket_to_shm =
+      std::thread(SocketToShm, SocketReceiver{socket}, std::move(shm_sender_and_receiver.first));
+  ShmToSocket(SocketSender{socket}, std::move(shm_sender_and_receiver.second));
+  socket_to_shm.join();
 }
 
 #ifdef CUTTLEFISH_HOST
@@ -162,44 +159,142 @@
   int host_port;
 };
 
+enum class QueueState {
+  kFree,
+  kUsed,
+};
+
+struct SocketConnectionInfo {
+  std::mutex lock{};
+  std::condition_variable cv{};
+  cvd::SharedFD socket{};
+  int guest_port{};
+  QueueState state = QueueState::kFree;
+};
+
+static constexpr auto kNumHostThreads =
+    vsoc::layout::socket_forward::kNumQueues;
+
+using SocketConnectionInfoCollection =
+    std::array<SocketConnectionInfo, kNumHostThreads>;
+
 void LaunchConnectionMaintainer(int port) {
   std::thread(cvd::EstablishAndMaintainConnection, port).detach();
 }
 
+void MarkAsFree(SocketConnectionInfo* conn) {
+  std::lock_guard<std::mutex> guard{conn->lock};
+  conn->socket = cvd::SharedFD{};
+  conn->guest_port = 0;
+  conn->state = QueueState::kFree;
+}
 
-[[noreturn]] void host_impl(SocketForwardRegionView* shm,
-                            std::vector<PortPair> ports, std::size_t index) {
+std::pair<int, cvd::SharedFD> WaitForConnection(SocketConnectionInfo* conn) {
+  std::unique_lock<std::mutex> guard{conn->lock};
+  while (conn->state != QueueState::kUsed) {
+    conn->cv.wait(guard);
+  }
+  return {conn->guest_port, conn->socket};
+}
+
+[[noreturn]] void host_thread(SocketForwardRegionView::ShmConnectionView view,
+                              SocketConnectionInfo* conn) {
+  while (true) {
+    int guest_port{};
+    cvd::SharedFD socket{};
+    // TODO structured binding in C++17
+    std::tie(guest_port, socket) = WaitForConnection(conn);
+
+    LOG(INFO) << "Establishing connection to guest port " << guest_port
+              << " with connection_id: " << view.connection_id();
+    HandleConnection(view.EstablishConnection(guest_port), std::move(socket));
+    LOG(INFO) << "Connection to guest port " << guest_port
+              << " closed. Marking queue " << view.connection_id()
+              << " as free.";
+    MarkAsFree(conn);
+  }
+}
+
+bool TryAllocateConnection(SocketConnectionInfo* conn, int guest_port,
+                           cvd::SharedFD socket) {
+  bool success = false;
+  {
+    std::lock_guard<std::mutex> guard{conn->lock};
+    if (conn->state == QueueState::kFree) {
+      conn->socket = std::move(socket);
+      conn->guest_port = guest_port;
+      conn->state = QueueState::kUsed;
+      success = true;
+    }
+  }
+  if (success) {
+    conn->cv.notify_one();
+  }
+  return success;
+}
+
+void AllocateWorkers(cvd::SharedFD socket,
+                     SocketConnectionInfoCollection* socket_connection_info,
+                     int guest_port) {
+  while (true) {
+    for (auto& conn : *socket_connection_info) {
+      if (TryAllocateConnection(&conn, guest_port, socket)) {
+        return;
+      }
+    }
+    LOG(INFO) << "no queues available. sleeping and retrying";
+    sleep(5);
+  }
+}
+
+[[noreturn]] void host_impl(
+    SocketForwardRegionView* shm,
+    SocketConnectionInfoCollection* socket_connection_info,
+    std::vector<PortPair> ports, std::size_t index) {
   // launch a worker for the following port before handling the current port.
   // recursion (instead of a loop) removes the need fore any join() or having
   // the main thread do no work.
   if (index + 1 < ports.size()) {
-    std::thread(host_impl, shm, ports, index + 1).detach();
+    std::thread(host_impl, shm, socket_connection_info, ports, index + 1)
+        .detach();
   }
   auto guest_port = ports[index].guest_port;
   auto host_port = ports[index].host_port;
-  LOG(INFO) << "starting server on " << host_port
-            << " for guest port " << guest_port;
+  LOG(INFO) << "starting server on " << host_port << " for guest port "
+            << guest_port;
   auto server = cvd::SharedFD::SocketLocalServer(host_port, SOCK_STREAM);
   CHECK(server->IsOpen()) << "Could not start server on port " << host_port;
+  // Note: If generically forwarding ports, the adb connection maintainer should
+  // be disabled
   LaunchConnectionMaintainer(host_port);
   while (true) {
     auto client_socket = cvd::SharedFD::Accept(*server);
     CHECK(client_socket->IsOpen()) << "error creating client socket";
     LOG(INFO) << "client socket accepted";
-    auto conn = shm->OpenConnection(guest_port);
-    LOG(INFO) << "shm connection opened";
-    LaunchWorkers(std::move(conn), std::move(client_socket));
+    AllocateWorkers(std::move(client_socket), socket_connection_info,
+                    guest_port);
   }
 }
 
 [[noreturn]] void host(SocketForwardRegionView* shm,
                        std::vector<PortPair> ports) {
   CHECK(!ports.empty());
-  host_impl(shm, ports, 0);
+
+  SocketConnectionInfoCollection socket_connection_info{};
+
+  auto conn_info_iter = std::begin(socket_connection_info);
+  for (auto& shm_connection_view : shm->AllConnections()) {
+    CHECK_NE(conn_info_iter, std::end(socket_connection_info));
+    std::thread(host_thread, std::move(shm_connection_view), &*conn_info_iter)
+        .detach();
+    ++conn_info_iter;
+  }
+  CHECK_EQ(conn_info_iter, std::end(socket_connection_info));
+  host_impl(shm, &socket_connection_info, ports, 0);
 }
 
 std::vector<PortPair> ParsePortsList(const std::string& guest_ports_str,
-                                const std::string& host_ports_str) {
+                                     const std::string& host_ports_str) {
   std::vector<PortPair> ports{};
   auto guest_ports = cvd::StrSplit(guest_ports_str, ',');
   auto host_ports = cvd::StrSplit(host_ports_str, ',');
@@ -208,7 +303,6 @@
     ports.push_back({std::stoi(guest_ports[i]), std::stoi(host_ports[i])});
   }
   return ports;
-
 }
 
 #else
@@ -224,17 +318,28 @@
   }
 }
 
-[[noreturn]] void guest(SocketForwardRegionView* shm) {
-  LOG(INFO) << "Starting guest mainloop";
+[[noreturn]] void guest_thread(
+    SocketForwardRegionView::ShmConnectionView view) {
   while (true) {
-    auto conn = shm->AcceptConnection();
-    LOG(INFO) << "shm connection accepted";
-    auto sock = OpenSocketConnection(conn.first.port());
-    CHECK(sock->IsOpen());
-    LOG(INFO) << "socket opened to " << conn.first.port();
-    LaunchWorkers(std::move(conn), std::move(sock));
+    LOG(INFO) << "waiting for new connection";
+    auto shm_sender_and_receiver = view.WaitForNewConnection();
+    LOG(INFO) << "new connection for port " << view.port();
+    HandleConnection(std::move(shm_sender_and_receiver), OpenSocketConnection(view.port()));
+    LOG(INFO) << "connection closed on port " << view.port();
   }
 }
+
+[[noreturn]] void guest(SocketForwardRegionView* shm) {
+  LOG(INFO) << "Starting guest mainloop";
+  auto connection_views = shm->AllConnections();
+  for (auto&& shm_connection_view : connection_views) {
+    std::thread(guest_thread, std::move(shm_connection_view)).detach();
+  }
+  while (true) {
+    sleep(std::numeric_limits<unsigned int>::max());
+  }
+}
+
 #endif
 
 SocketForwardRegionView* GetShm() {
diff --git a/common/vsoc/lib/socket_forward_region_view.cpp b/common/vsoc/lib/socket_forward_region_view.cpp
index c8a153b..2dbd2e8 100644
--- a/common/vsoc/lib/socket_forward_region_view.cpp
+++ b/common/vsoc/lib/socket_forward_region_view.cpp
@@ -24,7 +24,6 @@
 
 using vsoc::layout::socket_forward::Queue;
 using vsoc::layout::socket_forward::QueuePair;
-namespace QueueState = vsoc::layout::socket_forward::QueueState;
 // store the read and write direction as variables to keep the ifdefs and macros
 // in later code to a minimum
 constexpr auto ReadDirection = &QueuePair::
@@ -41,29 +40,22 @@
                                     guest_to_host;
 #endif
 
-constexpr auto kOtherSideClosed = QueueState::
-#ifdef CUTTLEFISH_HOST
-    GUEST_CLOSED;
-#else
-    HOST_CLOSED;
-#endif
-
-constexpr auto kThisSideClosed = QueueState::
-#ifdef CUTTLEFISH_HOST
-    HOST_CLOSED;
-#else
-    GUEST_CLOSED;
-#endif
-
 using vsoc::socket_forward::SocketForwardRegionView;
 
+vsoc::socket_forward::Packet vsoc::socket_forward::Packet::MakeBegin(
+    std::uint16_t port) {
+  auto packet = MakePacket(Header::BEGIN);
+  std::memcpy(packet.payload(), &port, sizeof port);
+  packet.set_payload_length(sizeof port);
+  return packet;
+}
+
 void SocketForwardRegionView::Recv(int connection_id, Packet* packet) {
   CHECK(packet != nullptr);
   do {
     (data()->queues_[connection_id].*ReadDirection)
         .queue.Read(this, reinterpret_cast<char*>(packet), sizeof *packet);
   } while (packet->IsBegin());
-  // TODO(haining) check packet generation number
   CHECK(!packet->empty()) << "zero-size data message received";
   CHECK_LE(packet->payload_length(), kMaxPayloadSize) << "invalid size";
 }
@@ -72,228 +64,187 @@
   CHECK(!packet.empty());
   CHECK_LE(packet.payload_length(), kMaxPayloadSize);
 
-  // NOTE this is check-then-act but I think that it's okay. Worst case is that
-  // we send one-too-many packets.
-  auto& queue_pair = data()->queues_[connection_id];
-  {
-    auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
-    if ((queue_pair.*WriteDirection).queue_state_ == kOtherSideClosed) {
-      LOG(INFO) << "connection closed, not sending\n";
-      return false;
-    }
-    CHECK((queue_pair.*WriteDirection).queue_state_ != QueueState::INACTIVE);
-  }
-  // TODO(haining) set packet generation number
   (data()->queues_[connection_id].*WriteDirection)
       .queue.Write(this, packet.raw_data(), packet.raw_data_length());
   return true;
 }
 
-void SocketForwardRegionView::IgnoreUntilBegin(int connection_id,
-                                               std::uint32_t generation) {
+int SocketForwardRegionView::IgnoreUntilBegin(int connection_id) {
   Packet packet{};
   do {
     (data()->queues_[connection_id].*ReadDirection)
         .queue.Read(this, reinterpret_cast<char*>(&packet), sizeof packet);
-  } while (!packet.IsBegin() || packet.generation() < generation);
+  } while (!packet.IsBegin());
+  return packet.port();
 }
 
-bool SocketForwardRegionView::IsOtherSideRecvClosed(int connection_id) {
-  auto& queue_pair = data()->queues_[connection_id];
-  auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
-  auto& queue = queue_pair.*WriteDirection;
-  return queue.queue_state_ == kOtherSideClosed ||
-         queue.queue_state_ == QueueState::INACTIVE;
-}
-
-void SocketForwardRegionView::ResetQueueStates(QueuePair* queue_pair) {
-  using vsoc::layout::socket_forward::Queue;
-  auto guard = make_lock_guard(&queue_pair->queue_state_lock_);
-  Queue* queues[] = {&queue_pair->host_to_guest, &queue_pair->guest_to_host};
-  for (auto* queue : queues) {
-    auto& state = queue->queue_state_;
-    switch (state) {
-      case QueueState::HOST_CONNECTED:
-      case kOtherSideClosed:
-        LOG(DEBUG)
-            << "host_connected or other side is closed, marking inactive";
-        state = QueueState::INACTIVE;
-        break;
-
-      case QueueState::BOTH_CONNECTED:
-        LOG(DEBUG) << "both_connected, marking this side closed";
-        state = kThisSideClosed;
-        break;
-
-      case kThisSideClosed:
-        [[fallthrough]];
-      case QueueState::INACTIVE:
-        LOG(DEBUG) << "inactive or this side closed, not changing state";
-        break;
-    }
-  }
-}
+constexpr int kNumQueues =
+    static_cast<int>(vsoc::layout::socket_forward::kNumQueues);
 
 void SocketForwardRegionView::CleanUpPreviousConnections() {
   data()->Recover();
-  int connection_id = 0;
-  auto current_generation = generation();
-  auto begin_packet = Packet::MakeBegin();
-  begin_packet.set_generation(current_generation);
-  auto end_packet = Packet::MakeEnd();
-  end_packet.set_generation(current_generation);
-  for (auto&& queue_pair : data()->queues_) {
-    std::uint32_t state{};
-    {
-      auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
-      state = (queue_pair.*WriteDirection).queue_state_;
-#ifndef CUTTLEFISH_HOST
-      if (state == QueueState::HOST_CONNECTED) {
-        state = (queue_pair.*WriteDirection).queue_state_ =
-            (queue_pair.*ReadDirection).queue_state_ =
-                QueueState::BOTH_CONNECTED;
-      }
-#endif
-    }
 
-    if (state == QueueState::BOTH_CONNECTED
-#ifdef CUTTLEFISH_HOST
-        || state == QueueState::HOST_CONNECTED
-#endif
-    ) {
-      LOG(INFO) << "found connected write queue state, sending begin and end";
-      Send(connection_id, begin_packet);
-      Send(connection_id, end_packet);
-    }
-    ResetQueueStates(&queue_pair);
-    ++connection_id;
-  }
-  ++data()->generation_num;
-}
-
-void SocketForwardRegionView::MarkQueueDisconnected(
-    int connection_id, Queue QueuePair::*direction) {
-  auto& queue_pair = data()->queues_[connection_id];
-  auto& queue = queue_pair.*direction;
-
-#ifdef CUTTLEFISH_HOST
-  // if the host has connected but the guest hasn't seen it yet, wait for the
-  // guest to connect so the protocol can follow the normal state transition.
-  while (queue.queue_state_ == QueueState::HOST_CONNECTED) {
-    LOG(WARNING) << "closing queue[" << connection_id
-                 << "] in HOST_CONNECTED state. waiting";
-    WaitForSignal(&queue.queue_state_, QueueState::HOST_CONNECTED);
-  }
-#endif
-
-  auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
-
-  queue.queue_state_ = queue.queue_state_ == kOtherSideClosed
-                           ? QueueState::INACTIVE
-                           : kThisSideClosed;
-}
-
-void SocketForwardRegionView::MarkSendQueueDisconnected(int connection_id) {
-  MarkQueueDisconnected(connection_id, WriteDirection);
-}
-
-void SocketForwardRegionView::MarkRecvQueueDisconnected(int connection_id) {
-  MarkQueueDisconnected(connection_id, ReadDirection);
-}
-
-int SocketForwardRegionView::port(int connection_id) {
-  return data()->queues_[connection_id].port_;
-}
-
-std::uint32_t SocketForwardRegionView::generation() {
-  return data()->generation_num;
-}
-
-#ifdef CUTTLEFISH_HOST
-int SocketForwardRegionView::AcquireConnectionID(int port) {
-  while (true) {
-    int id = 0;
-    for (auto&& queue_pair : data()->queues_) {
-      LOG(DEBUG) << "locking and checking queue at index " << id;
-      auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
-      if (queue_pair.host_to_guest.queue_state_ == QueueState::INACTIVE &&
-          queue_pair.guest_to_host.queue_state_ == QueueState::INACTIVE) {
-        queue_pair.port_ = port;
-        queue_pair.host_to_guest.queue_state_ = QueueState::HOST_CONNECTED;
-        queue_pair.guest_to_host.queue_state_ = QueueState::HOST_CONNECTED;
-        LOG(DEBUG) << "acquired queue " << id
-                   << ". current seq_num: " << data()->seq_num;
-        ++data()->seq_num;
-        SendSignal(layout::Sides::Peer, &data()->seq_num);
-        return id;
-      }
-      ++id;
-    }
-    LOG(ERROR) << "no remaining shm queues for connection, sleeping.";
-    sleep(10);
+  static constexpr auto kRestartPacket = Packet::MakeRestart();
+  for (int connection_id = 0; connection_id < kNumQueues; ++connection_id) {
+    Send(connection_id, kRestartPacket);
   }
 }
 
-std::pair<SocketForwardRegionView::Sender, SocketForwardRegionView::Receiver>
-SocketForwardRegionView::OpenConnection(int port) {
-  int connection_id = AcquireConnectionID(port);
-  LOG(INFO) << "Acquired connection with id " << connection_id;
-  auto current_generation = generation();
-  return {Sender{this, connection_id, current_generation},
-          Receiver{this, connection_id, current_generation}};
-}
-#else
-int SocketForwardRegionView::GetWaitingConnectionID() {
-  while (data()->seq_num == last_seq_number_) {
-    WaitForSignal(&data()->seq_num, last_seq_number_);
+SocketForwardRegionView::ConnectionViewCollection
+SocketForwardRegionView::AllConnections() {
+  SocketForwardRegionView::ConnectionViewCollection all_queues;
+  for (int connection_id = 0; connection_id < kNumQueues; ++connection_id) {
+    all_queues.emplace_back(this, connection_id);
   }
-  ++last_seq_number_;
-  int id = 0;
-  for (auto&& queue_pair : data()->queues_) {
-    LOG(DEBUG) << "locking and checking queue at index " << id;
-    auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
-    if (queue_pair.host_to_guest.queue_state_ == QueueState::HOST_CONNECTED) {
-      CHECK(queue_pair.guest_to_host.queue_state_ ==
-            QueueState::HOST_CONNECTED);
-      LOG(DEBUG) << "found waiting connection at index " << id;
-      queue_pair.host_to_guest.queue_state_ = QueueState::BOTH_CONNECTED;
-      queue_pair.guest_to_host.queue_state_ = QueueState::BOTH_CONNECTED;
-      SendSignal(layout::Sides::Peer, &queue_pair.host_to_guest.queue_state_);
-      SendSignal(layout::Sides::Peer, &queue_pair.guest_to_host.queue_state_);
-      return id;
-    }
-    ++id;
-  }
-  return -1;
+  return all_queues;
 }
 
-std::pair<SocketForwardRegionView::Sender, SocketForwardRegionView::Receiver>
-SocketForwardRegionView::AcceptConnection() {
-  int connection_id = -1;
-  while (connection_id < 0) {
-    connection_id = GetWaitingConnectionID();
-  }
-  LOG(INFO) << "Accepted connection with id " << connection_id;
-
-  auto current_generation = generation();
-  return {Sender{this, connection_id, current_generation},
-          Receiver{this, connection_id, current_generation}};
-}
-#endif
-
 // --- Connection ---- //
-void SocketForwardRegionView::Receiver::Recv(Packet* packet) {
-  if (!got_begin_) {
-    view_->IgnoreUntilBegin(connection_id_, generation_);
-    got_begin_ = true;
+
+void SocketForwardRegionView::ShmConnectionView::Receiver::Recv(Packet* packet) {
+  std::unique_lock<std::mutex> guard(receive_thread_data_lock_);
+  while (received_packet_free_) {
+    receive_thread_data_cv_.wait(guard);
   }
-  return view_->Recv(connection_id_, packet);
+  CHECK(received_packet_.IsData());
+  *packet = received_packet_;
+  received_packet_free_ = true;
+  receive_thread_data_cv_.notify_one();
 }
 
-bool SocketForwardRegionView::Sender::closed() const {
-  return view_->IsOtherSideRecvClosed(connection_id_);
+bool SocketForwardRegionView::ShmConnectionView::Receiver::GotRecvClosed() const {
+      return received_packet_.IsRecvClosed() || (received_packet_.IsRestart()
+#ifdef CUTTLEFISH_HOST
+                                              && saw_data_
+#endif
+                                              );
 }
 
-bool SocketForwardRegionView::Sender::Send(const Packet& packet) {
-  return view_->Send(connection_id_, packet);
+bool SocketForwardRegionView::ShmConnectionView::Receiver::ShouldReceiveAnotherPacket() const {
+        return (received_packet_.IsRecvClosed() && !saw_end_) ||
+             (saw_end_ && received_packet_.IsEnd())
+#ifdef CUTTLEFISH_HOST
+             || (received_packet_.IsRestart() && !saw_data_) ||
+             (received_packet_.IsBegin())
+#endif
+             ;
+}
+
+void SocketForwardRegionView::ShmConnectionView::Receiver::ReceivePacket() {
+  view_->region_view()->Recv(view_->connection_id(), &received_packet_);
+}
+
+void SocketForwardRegionView::ShmConnectionView::Receiver::CheckPacketForRecvClosed() {
+      if (GotRecvClosed()) {
+        saw_recv_closed_ = true;
+        view_->MarkOtherSideRecvClosed();
+      }
+#ifdef CUTTLEFISH_HOST
+      if (received_packet_.IsData()) {
+        saw_data_ = true;
+      }
+#endif
+}
+
+void SocketForwardRegionView::ShmConnectionView::Receiver::CheckPacketForEnd() {
+  if (received_packet_.IsEnd() || received_packet_.IsRestart()) {
+    CHECK(!saw_end_ || received_packet_.IsRestart());
+    saw_end_ = true;
+  }
+}
+
+
+bool SocketForwardRegionView::ShmConnectionView::Receiver::ExpectMorePackets() const {
+  return !saw_recv_closed_ || !saw_end_;
+}
+
+void SocketForwardRegionView::ShmConnectionView::Receiver::UpdatePacketAndSignalAvailable() {
+  if (!received_packet_.IsData()) {
+    static constexpr auto kEmptyPacket = Packet::MakeData();
+    received_packet_ = kEmptyPacket;
+  }
+  received_packet_free_ = false;
+  receive_thread_data_cv_.notify_one();
+}
+
+void SocketForwardRegionView::ShmConnectionView::Receiver::Start() {
+  while (ExpectMorePackets()) {
+    std::unique_lock<std::mutex> guard(receive_thread_data_lock_);
+    while (!received_packet_free_) {
+      receive_thread_data_cv_.wait(guard);
+    }
+
+    do {
+      ReceivePacket();
+      CheckPacketForRecvClosed();
+    } while (ShouldReceiveAnotherPacket());
+
+    if (received_packet_.empty()) {
+      LOG(ERROR) << "Received empty packet.";
+    }
+
+    CheckPacketForEnd();
+
+    UpdatePacketAndSignalAvailable();
+  }
+}
+
+auto SocketForwardRegionView::ShmConnectionView::ResetAndConnect()
+    -> ShmSenderReceiverPair {
+  if (receiver_) {
+    receiver_->Join();
+  }
+
+  {
+    std::lock_guard<std::mutex> guard(*other_side_receive_closed_lock_);
+    other_side_receive_closed_ = false;
+  }
+
+#ifdef CUTTLEFISH_HOST
+  region_view()->IgnoreUntilBegin(connection_id());
+  region_view()->Send(connection_id(), Packet::MakeBegin(port_));
+#else
+  region_view()->Send(connection_id(), Packet::MakeBegin(port_));
+  port_ =
+      region_view()->IgnoreUntilBegin(connection_id());
+#endif
+
+  receiver_.reset(new Receiver{this});
+  return {ShmSender{this}, ShmReceiver{this}};
+}
+
+#ifdef CUTTLEFISH_HOST
+auto SocketForwardRegionView::ShmConnectionView::EstablishConnection(int port)
+    -> ShmSenderReceiverPair {
+  port_ = port;
+  return ResetAndConnect();
+}
+#else
+auto SocketForwardRegionView::ShmConnectionView::WaitForNewConnection()
+    -> ShmSenderReceiverPair {
+  port_ = 0;
+  return ResetAndConnect();
+}
+#endif
+
+bool SocketForwardRegionView::ShmConnectionView::Send(const Packet& packet) {
+  if (packet.empty()) {
+    LOG(ERROR) << "Sending empty packet";
+  }
+  if (packet.IsData() && IsOtherSideRecvClosed()) {
+    return false;
+  }
+  return region_view()->Send(connection_id(), packet);
+}
+
+void SocketForwardRegionView::ShmConnectionView::Recv(Packet* packet) {
+  receiver_->Recv(packet);
+}
+
+void SocketForwardRegionView::ShmReceiver::Recv(Packet* packet) {
+  view_->Recv(packet);
+}
+
+bool SocketForwardRegionView::ShmSender::Send(const Packet& packet) {
+  return view_->Send(packet);
 }
diff --git a/common/vsoc/lib/socket_forward_region_view.h b/common/vsoc/lib/socket_forward_region_view.h
index ce6958a..c41517b 100644
--- a/common/vsoc/lib/socket_forward_region_view.h
+++ b/common/vsoc/lib/socket_forward_region_view.h
@@ -15,6 +15,7 @@
  */
 #pragma once
 
+#include <cstdlib>
 #include <utility>
 #include <vector>
 #include <memory>
@@ -27,11 +28,12 @@
 
 struct Header {
   std::uint32_t payload_length;
-  std::uint32_t generation;
   enum MessageType : std::uint32_t {
     DATA = 0,
     BEGIN,
     END,
+    RECV_CLOSED,  // indicate that this side's receive end is closed
+    RESTART,
   };
   MessageType message_type;
 };
@@ -45,51 +47,73 @@
   using Payload = char[kMaxPayloadSize];
   Payload payload_data_;
 
-  static Packet MakePacket(Header::MessageType type) {
+  static constexpr Packet MakePacket(Header::MessageType type) {
     Packet packet{};
     packet.header_.message_type = type;
     return packet;
   }
 
  public:
-  static Packet MakeBegin() { return MakePacket(Header::BEGIN); }
+  // port is only revelant on the host-side.
+  static Packet MakeBegin(std::uint16_t port);
 
-  static Packet MakeEnd() { return MakePacket(Header::END); }
+  static constexpr Packet MakeEnd() { return MakePacket(Header::END); }
+
+  static constexpr Packet MakeRecvClosed() {
+    return MakePacket(Header::RECV_CLOSED);
+  }
+
+  static constexpr Packet MakeRestart() { return MakePacket(Header::RESTART); }
 
   // NOTE payload and payload_length must still be set.
-  static Packet MakeData() { return MakePacket(Header::DATA); }
+  static constexpr Packet MakeData() { return MakePacket(Header::DATA); }
 
   bool empty() const { return IsData() && header_.payload_length == 0; }
 
   void set_payload_length(std::uint32_t length) {
     CHECK_LE(length, sizeof payload_data_);
-    header_.message_type = Header::DATA;
     header_.payload_length = length;
   }
 
-  std::uint32_t generation() const { return header_.generation; }
-
-  void set_generation(std::uint32_t generation) {
-    header_.generation = generation;
-  }
-
   Payload& payload() { return payload_data_; }
 
   const Payload& payload() const { return payload_data_; }
 
-  std::uint32_t payload_length() const { return header_.payload_length; }
+  constexpr std::uint32_t payload_length() const {
+    return header_.payload_length;
+  }
 
-  bool IsBegin() const { return header_.message_type == Header::BEGIN; }
+  constexpr bool IsBegin() const {
+    return header_.message_type == Header::BEGIN;
+  }
 
-  bool IsEnd() const { return header_.message_type == Header::END; }
+  constexpr bool IsEnd() const { return header_.message_type == Header::END; }
 
-  bool IsData() const { return header_.message_type == Header::DATA; }
+  constexpr bool IsData() const { return header_.message_type == Header::DATA; }
+
+  constexpr bool IsRecvClosed() const {
+    return header_.message_type == Header::RECV_CLOSED;
+  }
+
+  constexpr bool IsRestart() const {
+    return header_.message_type == Header::RESTART;
+  }
+
+  constexpr std::uint16_t port() const {
+    CHECK(IsBegin());
+    std::uint16_t port_number{};
+    CHECK_EQ(payload_length(), sizeof port_number);
+    std::memcpy(&port_number, payload(), sizeof port_number);
+    return port_number;
+  }
 
   char* raw_data() { return reinterpret_cast<char*>(this); }
 
   const char* raw_data() const { return reinterpret_cast<const char*>(this); }
 
-  size_t raw_data_length() const { return payload_length() + sizeof header_; }
+  constexpr size_t raw_data_length() const {
+    return payload_length() + sizeof header_;
+  }
 };
 
 static_assert(sizeof(Packet) == layout::socket_forward::kMaxPacketSize, "");
@@ -101,128 +125,186 @@
     : public TypedRegionView<SocketForwardRegionView,
                              layout::socket_forward::SocketForwardLayout> {
  private:
-#ifdef CUTTLEFISH_HOST
-  int AcquireConnectionID(int port);
-#else
-  int GetWaitingConnectionID();
-#endif
-
   // Returns an empty data packet if the other side is closed.
   void Recv(int connection_id, Packet* packet);
   // Returns true on success
   bool Send(int connection_id, const Packet& packet);
 
-  // skip everything in the connection queue until seeing a BEGIN for the
-  // current generation
-  void IgnoreUntilBegin(int connection_id, std::uint32_t generation);
-
-  bool IsOtherSideRecvClosed(int connection_id);
-
-  void ResetQueueStates(layout::socket_forward::QueuePair* queue_pair);
-
-  void MarkQueueDisconnected(int connection_id,
-                             layout::socket_forward::Queue
-                                 layout::socket_forward::QueuePair::*direction);
+  // skip everything in the connection queue until seeing a BEGIN packet.
+  // returns port from begin packet.
+  int IgnoreUntilBegin(int connection_id);
 
  public:
-  // Helper class that will send a ConnectionBegin marker when constructed and a
-  // ConnectionEnd marker when destroyed.
-  class Sender {
+  class ShmSender;
+  class ShmReceiver;
+
+  using ShmSenderReceiverPair = std::pair<ShmSender, ShmReceiver>;
+
+  class ShmConnectionView {
    public:
-    explicit Sender(SocketForwardRegionView* view, int connection_id,
-                    std::uint32_t generation)
-        : view_{view, {connection_id, generation}},
-          connection_id_{connection_id} {
-      auto packet = Packet::MakeBegin();
-      packet.set_generation(generation);
-      view_->Send(connection_id, packet);
+    ShmConnectionView(SocketForwardRegionView* region_view, int connection_id)
+        : region_view_{region_view}, connection_id_{connection_id} {}
+
+#ifdef CUTTLEFISH_HOST
+    ShmSenderReceiverPair EstablishConnection(int port);
+#else
+    // Should not be called while there is an active ShmSender or ShmReceiver
+    // for this connection.
+    ShmSenderReceiverPair WaitForNewConnection();
+#endif
+
+    int port() const { return port_; }
+
+    bool Send(const Packet& packet);
+    void Recv(Packet* packet);
+
+    ShmConnectionView(const ShmConnectionView&) = delete;
+    ShmConnectionView& operator=(const ShmConnectionView&) = delete;
+
+    // Moving invalidates all existing ShmSenders and ShmReceiver
+    ShmConnectionView(ShmConnectionView&&) = default;
+    ShmConnectionView& operator=(ShmConnectionView&&) = default;
+    ~ShmConnectionView() = default;
+
+    // NOTE should only be used for debugging/logging purposes.
+    // connection_ids are an implementation detail that are currently useful for
+    // debugging, but may go away in the future.
+    int connection_id() const { return connection_id_; }
+
+   private:
+    SocketForwardRegionView* region_view() const { return region_view_; }
+
+    bool IsOtherSideRecvClosed() {
+      std::lock_guard<std::mutex> guard(*other_side_receive_closed_lock_);
+      return other_side_receive_closed_;
     }
 
-    Sender(const Sender&) = delete;
-    Sender& operator=(const Sender&) = delete;
+    void MarkOtherSideRecvClosed() {
+      std::lock_guard<std::mutex> guard(*other_side_receive_closed_lock_);
+      other_side_receive_closed_ = true;
+    }
 
-    Sender(Sender&&) = default;
-    Sender& operator=(Sender&&) = default;
-    ~Sender() = default;
+    void ReceiverThread();
+    ShmSenderReceiverPair ResetAndConnect();
+
+    class Receiver {
+     public:
+      Receiver(ShmConnectionView* view)
+          : view_{view}
+      {
+        receiver_thread_ = std::thread([this] { Start(); });
+      }
+
+      void Recv(Packet* packet);
+
+      void Join() { receiver_thread_.join(); }
+
+      Receiver(const Receiver&) = delete;
+      Receiver& operator=(const Receiver&) = delete;
+
+      ~Receiver() = default;
+     private:
+      void Start();
+      bool GotRecvClosed() const;
+      void ReceivePacket();
+      void CheckPacketForRecvClosed();
+      void CheckPacketForEnd();
+      void UpdatePacketAndSignalAvailable();
+      bool ShouldReceiveAnotherPacket() const;
+      bool ExpectMorePackets() const;
+
+      std::mutex receive_thread_data_lock_;
+      std::condition_variable receive_thread_data_cv_;
+      bool received_packet_free_ = true;
+      Packet received_packet_{};
+
+      ShmConnectionView* view_{};
+      bool saw_recv_closed_ = false;
+      bool saw_end_ = false;
+#ifdef CUTTLEFISH_HOST
+      bool saw_data_ = false;
+#endif
+
+      std::thread receiver_thread_;
+    };
+
+    SocketForwardRegionView* region_view_{};
+    int connection_id_ = -1;
+    int port_ = -1;
+
+    std::unique_ptr<std::mutex> other_side_receive_closed_lock_ =
+        std::unique_ptr<std::mutex>{new std::mutex{}};
+    bool other_side_receive_closed_ = false;
+
+    std::unique_ptr<Receiver> receiver_;
+  };
+
+  class ShmSender {
+   public:
+    explicit ShmSender(ShmConnectionView* view) : view_{view} {}
+
+    ShmSender(const ShmSender&) = delete;
+    ShmSender& operator=(const ShmSender&) = delete;
+
+    ShmSender(ShmSender&&) = default;
+    ShmSender& operator=(ShmSender&&) = default;
+    ~ShmSender() = default;
 
     // Returns true on success
     bool Send(const Packet& packet);
-    int port() const { return view_->port(connection_id_); }
 
    private:
-    bool closed() const;
-
     struct EndSender {
-      int connection_id = -1;
-      std::uint32_t generation{};
-      void operator()(SocketForwardRegionView* view) const {
+      void operator()(ShmConnectionView* view) const {
         if (view) {
-          CHECK(connection_id >= 0);
-          auto packet = Packet::MakeEnd();
-          packet.set_generation(generation);
-          view->Send(connection_id, packet);
-          view->MarkSendQueueDisconnected(connection_id);
+          view->Send(Packet::MakeEnd());
         }
       }
     };
+
     // Doesn't actually own the View, responsible for sending the End
     // indicator and marking the sending side as disconnected.
-    std::unique_ptr<SocketForwardRegionView, EndSender> view_;
-    int connection_id_{};
+    std::unique_ptr<ShmConnectionView, EndSender> view_;
   };
 
-  class Receiver {
+  class ShmReceiver {
    public:
-    explicit Receiver(SocketForwardRegionView* view, int connection_id,
-                      std::uint32_t generation)
-        : view_{view, {connection_id}},
-          connection_id_{connection_id},
-          generation_{generation} {}
-    Receiver(const Receiver&) = delete;
-    Receiver& operator=(const Receiver&) = delete;
+    explicit ShmReceiver(ShmConnectionView* view) : view_{view} {}
+    ShmReceiver(const ShmReceiver&) = delete;
+    ShmReceiver& operator=(const ShmReceiver&) = delete;
 
-    Receiver(Receiver&&) = default;
-    Receiver& operator=(Receiver&&) = default;
-    ~Receiver() = default;
+    ShmReceiver(ShmReceiver&&) = default;
+    ShmReceiver& operator=(ShmReceiver&&) = default;
+    ~ShmReceiver() = default;
 
     void Recv(Packet* packet);
-    int port() const { return view_->port(connection_id_); }
 
    private:
-    struct QueueCloser {
-      int connection_id = -1;
-      void operator()(SocketForwardRegionView* view) const {
+    struct RecvClosedSender {
+      void operator()(ShmConnectionView* view) const {
         if (view) {
-          CHECK(connection_id >= 0);
-          view->MarkRecvQueueDisconnected(connection_id);
+          view->Send(Packet::MakeRecvClosed());
         }
       }
     };
 
-    // Doesn't actually own the View, responsible for marking the receiving
-    // side as disconnected
-    std::unique_ptr<SocketForwardRegionView, QueueCloser> view_;
-    int connection_id_{};
-    std::uint32_t generation_{};
-    bool got_begin_ = false;
+    // Doesn't actually own the view, responsible for sending the RecvClosed
+    // indicator
+    std::unique_ptr<ShmConnectionView, RecvClosedSender> view_{};
   };
 
+  friend ShmConnectionView;
+
   SocketForwardRegionView() = default;
   ~SocketForwardRegionView() = default;
   SocketForwardRegionView(const SocketForwardRegionView&) = delete;
   SocketForwardRegionView& operator=(const SocketForwardRegionView&) = delete;
 
-#ifdef CUTTLEFISH_HOST
-  std::pair<Sender, Receiver> OpenConnection(int port);
-#else
-  std::pair<Sender, Receiver> AcceptConnection();
-#endif
+  using ConnectionViewCollection = std::vector<ShmConnectionView>;
+  ConnectionViewCollection AllConnections();
 
   int port(int connection_id);
-  std::uint32_t generation();
   void CleanUpPreviousConnections();
-  void MarkSendQueueDisconnected(int connection_id);
-  void MarkRecvQueueDisconnected(int connection_id);
 
  private:
 #ifndef CUTTLEFISH_HOST
diff --git a/common/vsoc/shm/socket_forward_layout.h b/common/vsoc/shm/socket_forward_layout.h
index 02523af..4a9beda 100644
--- a/common/vsoc/shm/socket_forward_layout.h
+++ b/common/vsoc/shm/socket_forward_layout.h
@@ -27,53 +27,35 @@
 constexpr std::size_t kMaxPacketSize = 8192;
 constexpr std::size_t kNumQueues = 16;
 
-namespace QueueState {
-constexpr std::uint32_t INACTIVE = 0;
-constexpr std::uint32_t HOST_CONNECTED = 1;
-constexpr std::uint32_t BOTH_CONNECTED = 2;
-constexpr std::uint32_t HOST_CLOSED = 3;
-constexpr std::uint32_t GUEST_CLOSED = 4;
-// If both are closed then the queue goes back to INACTIVE
-// BOTH_CLOSED = 0,
-}  // namespace QueueState
-
 struct Queue {
   static constexpr size_t layout_size =
-      CircularPacketQueue<16, kMaxPacketSize>::layout_size + 4;
+      CircularPacketQueue<16, kMaxPacketSize>::layout_size;
 
   CircularPacketQueue<16, kMaxPacketSize> queue;
 
-  std::atomic_uint32_t queue_state_;
-
   bool Recover() { return queue.Recover(); }
 };
 ASSERT_SHM_COMPATIBLE(Queue);
 
 struct QueuePair {
-  static constexpr size_t layout_size = 2 * Queue::layout_size + 8;
+  static constexpr size_t layout_size = 2 * Queue::layout_size;
 
   // Traffic originating from host that proceeds towards guest.
   Queue host_to_guest;
   // Traffic originating from guest that proceeds towards host.
   Queue guest_to_host;
 
-  std::uint32_t port_;
-
-  SpinLock queue_state_lock_;
-
   bool Recover() {
-    // TODO: Put queue_state_ and port_ recovery here, probably after grabbing
     bool recovered = false;
     recovered = recovered || host_to_guest.Recover();
     recovered = recovered || guest_to_host.Recover();
-    recovered = recovered || queue_state_lock_.Recover();
     return recovered;
   }
 };
 ASSERT_SHM_COMPATIBLE(QueuePair);
 
 struct SocketForwardLayout : public RegionLayout {
-  static constexpr size_t layout_size = QueuePair::layout_size * kNumQueues + 8;
+  static constexpr size_t layout_size = QueuePair::layout_size * kNumQueues;
 
   bool Recover() {
     bool recovered = false;
@@ -81,14 +63,10 @@
       bool rval = i.Recover();
       recovered = recovered || rval;
     }
-    // TODO: consider handling the sequence number here
     return recovered;
   }
 
   QueuePair queues_[kNumQueues];
-  std::atomic_uint32_t seq_num;  // incremented for every new connection
-  std::atomic_uint32_t
-      generation_num;  // incremented for every new socket forward process
   static const char* region_name;
 };