Host will wait on HOST_CONNECTED queues until queue_state changes.

Bug: 80486122
Bug: 80104636
Change-Id: If447364be757feb40053e85e30c7969f84b7d25a
Merged-In: If447364be757feb40053e85e30c7969f84b7d25a
Test: local boot and hammer with adb connect commands
(cherry picked from commit 0bf3b641ce6bbb328ad755442bcea3dfd43cb844)
diff --git a/common/vsoc/lib/socket_forward_region_view.cpp b/common/vsoc/lib/socket_forward_region_view.cpp
index 2dbd2e8..6f64b19 100644
--- a/common/vsoc/lib/socket_forward_region_view.cpp
+++ b/common/vsoc/lib/socket_forward_region_view.cpp
@@ -24,6 +24,7 @@
 
 using vsoc::layout::socket_forward::Queue;
 using vsoc::layout::socket_forward::QueuePair;
+namespace QueueState = vsoc::layout::socket_forward::QueueState;
 // store the read and write direction as variables to keep the ifdefs and macros
 // in later code to a minimum
 constexpr auto ReadDirection = &QueuePair::
@@ -83,6 +84,25 @@
 
 void SocketForwardRegionView::CleanUpPreviousConnections() {
   data()->Recover();
+  int connection_id = 0;
+  auto current_generation = generation();
+  auto begin_packet = Packet::MakeBegin();
+  begin_packet.set_generation(current_generation);
+  auto end_packet = Packet::MakeEnd();
+  end_packet.set_generation(current_generation);
+  for (auto&& queue_pair : data()->queues_) {
+    std::uint32_t state{};
+    {
+      auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
+      state = (queue_pair.*WriteDirection).queue_state_;
+#ifndef CUTTLEFISH_HOST
+      if (state == QueueState::HOST_CONNECTED) {
+        state = (queue_pair.*WriteDirection).queue_state_ =
+            (queue_pair.*ReadDirection).queue_state_ =
+                QueueState::BOTH_CONNECTED;
+      }
+#endif
+    }
 
   static constexpr auto kRestartPacket = Packet::MakeRestart();
   for (int connection_id = 0; connection_id < kNumQueues; ++connection_id) {
@@ -90,11 +110,18 @@
   }
 }
 
-SocketForwardRegionView::ConnectionViewCollection
-SocketForwardRegionView::AllConnections() {
-  SocketForwardRegionView::ConnectionViewCollection all_queues;
-  for (int connection_id = 0; connection_id < kNumQueues; ++connection_id) {
-    all_queues.emplace_back(this, connection_id);
+void SocketForwardRegionView::MarkQueueDisconnected(
+    int connection_id, Queue QueuePair::*direction) {
+  auto& queue_pair = data()->queues_[connection_id];
+  auto& queue = queue_pair.*direction;
+
+#ifdef CUTTLEFISH_HOST
+  // if the host has connected but the guest hasn't seen it yet, wait for the
+  // guest to connect so the protocol can follow the normal state transition.
+  while (queue.queue_state_ == QueueState::HOST_CONNECTED) {
+    LOG(WARNING) << "closing queue[" << connection_id
+                 << "] in HOST_CONNECTED state. waiting";
+    WaitForSignal(&queue.queue_state_, QueueState::HOST_CONNECTED);
   }
   return all_queues;
 }
@@ -163,15 +190,20 @@
     static constexpr auto kEmptyPacket = Packet::MakeData();
     received_packet_ = kEmptyPacket;
   }
-  received_packet_free_ = false;
-  receive_thread_data_cv_.notify_one();
-}
-
-void SocketForwardRegionView::ShmConnectionView::Receiver::Start() {
-  while (ExpectMorePackets()) {
-    std::unique_lock<std::mutex> guard(receive_thread_data_lock_);
-    while (!received_packet_free_) {
-      receive_thread_data_cv_.wait(guard);
+  ++last_seq_number_;
+  int id = 0;
+  for (auto&& queue_pair : data()->queues_) {
+    LOG(DEBUG) << "locking and checking queue at index " << id;
+    auto guard = make_lock_guard(&queue_pair.queue_state_lock_);
+    if (queue_pair.host_to_guest.queue_state_ == QueueState::HOST_CONNECTED) {
+      CHECK(queue_pair.guest_to_host.queue_state_ ==
+            QueueState::HOST_CONNECTED);
+      LOG(DEBUG) << "found waiting connection at index " << id;
+      queue_pair.host_to_guest.queue_state_ = QueueState::BOTH_CONNECTED;
+      queue_pair.guest_to_host.queue_state_ = QueueState::BOTH_CONNECTED;
+      SendSignal(layout::Sides::Peer, &queue_pair.host_to_guest.queue_state_);
+      SendSignal(layout::Sides::Peer, &queue_pair.guest_to_host.queue_state_);
+      return id;
     }
 
     do {
diff --git a/common/vsoc/shm/socket_forward_layout.h b/common/vsoc/shm/socket_forward_layout.h
index 4a9beda..141aff3 100644
--- a/common/vsoc/shm/socket_forward_layout.h
+++ b/common/vsoc/shm/socket_forward_layout.h
@@ -27,12 +27,30 @@
 constexpr std::size_t kMaxPacketSize = 8192;
 constexpr std::size_t kNumQueues = 16;
 
+<<<<<<< HEAD
+=======
+namespace QueueState {
+constexpr std::uint32_t INACTIVE = 0;
+constexpr std::uint32_t HOST_CONNECTED = 1;
+constexpr std::uint32_t BOTH_CONNECTED = 2;
+constexpr std::uint32_t HOST_CLOSED = 3;
+constexpr std::uint32_t GUEST_CLOSED = 4;
+// If both are closed then the queue goes back to INACTIVE
+// BOTH_CLOSED = 0,
+}  // namespace QueueState
+
+>>>>>>> 4d9ddd4... Host will wait on HOST_CONNECTED queues until queue_state changes.
 struct Queue {
   static constexpr size_t layout_size =
       CircularPacketQueue<16, kMaxPacketSize>::layout_size;
 
   CircularPacketQueue<16, kMaxPacketSize> queue;
 
+<<<<<<< HEAD
+=======
+  std::atomic_uint32_t queue_state_;
+
+>>>>>>> 4d9ddd4... Host will wait on HOST_CONNECTED queues until queue_state changes.
   bool Recover() { return queue.Recover(); }
 };
 ASSERT_SHM_COMPATIBLE(Queue);
@@ -45,6 +63,13 @@
   // Traffic originating from guest that proceeds towards host.
   Queue guest_to_host;
 
+<<<<<<< HEAD
+=======
+  std::uint32_t port_;
+
+  SpinLock queue_state_lock_;
+
+>>>>>>> 4d9ddd4... Host will wait on HOST_CONNECTED queues until queue_state changes.
   bool Recover() {
     bool recovered = false;
     recovered = recovered || host_to_guest.Recover();
@@ -63,10 +88,20 @@
       bool rval = i.Recover();
       recovered = recovered || rval;
     }
+<<<<<<< HEAD
+=======
+    // TODO: consider handling the sequence number here
+>>>>>>> 4d9ddd4... Host will wait on HOST_CONNECTED queues until queue_state changes.
     return recovered;
   }
 
   QueuePair queues_[kNumQueues];
+<<<<<<< HEAD
+=======
+  std::atomic_uint32_t seq_num;  // incremented for every new connection
+  std::atomic_uint32_t
+      generation_num;  // incremented for every new socket forward process
+>>>>>>> 4d9ddd4... Host will wait on HOST_CONNECTED queues until queue_state changes.
   static const char* region_name;
 };