Uses new Packet struct instead of Message for shm

This reduces copying overhead and groups the data more explicitly.
Before starting the worker's loops, they each create a Packet which is
then used to hold the data transmitted. This requires a more low-level
approach than is provided by tcp_socket.

BUG: 72654144
Change-Id: Id23d3d548f15b62cc7b5f61c570f24d6ccd92aee
diff --git a/common/commands/socket_forward_proxy/Android.bp b/common/commands/socket_forward_proxy/Android.bp
index 3a5d84b..ad1eabd 100644
--- a/common/commands/socket_forward_proxy/Android.bp
+++ b/common/commands/socket_forward_proxy/Android.bp
@@ -5,6 +5,7 @@
     ],
     shared_libs: [
         "libbase",
+        "libcuttlefish_fs",
         "vsoc_lib",
         "liblog",
         "cuttlefish_tcp_socket",
diff --git a/common/commands/socket_forward_proxy/main.cpp b/common/commands/socket_forward_proxy/main.cpp
index cec1b05..cb1fadb 100644
--- a/common/commands/socket_forward_proxy/main.cpp
+++ b/common/commands/socket_forward_proxy/main.cpp
@@ -26,13 +26,14 @@
 
 #include <unistd.h>
 
+#include "common/libs/fs/shared_fd.h"
 #include "common/vsoc/lib/socket_forward_region_view.h"
-#include "common/libs/tcp_socket/tcp_socket.h"
 
 #ifdef CUTTLEFISH_HOST
 #include "host/libs/config/host_config.h"
 #endif
 
+using vsoc::socket_forward::Packet;
 using vsoc::socket_forward::SocketForwardRegionView;
 
 #ifdef CUTTLEFISH_HOST
@@ -43,7 +44,7 @@
 class Worker {
  public:
   Worker(SocketForwardRegionView::Connection shm_connection,
-         cvd::ClientSocket socket)
+         cvd::SharedFD socket)
       : shm_connection_(std::move(shm_connection)),
         socket_(std::move(socket)){}
 
@@ -54,7 +55,7 @@
         return true;
       }
     }
-    if (shm_connection_.closed() || socket_.closed()) {
+    if (shm_connection_.closed() || !socket_->IsOpen()) {
       std::lock_guard<std::mutex> guard(closed_lock_);
       closed_ = true;
     }
@@ -76,35 +77,54 @@
 
  private:
   void SocketToShmImpl() {
-    constexpr int kRecvSize = 8192;
+    auto shm_sender = shm_connection_.MakeSender();
 
-    auto sender = shm_connection_.MakeSender();
-
+    auto packet = Packet::MakeData();
     while (true) {
       if (closed()) {
         break;
       }
-      auto msg = socket_.RecvAny(kRecvSize);
-      if (msg.empty()) {
+      auto size = socket_->Recv(packet.payload(), sizeof packet.payload(), 0);
+      if (size <= 0) {
         break;
       }
-      sender.Send(std::move(msg));
+      packet.set_payload_length(size);
+      shm_sender.Send(packet);
     }
     LOG(INFO) << "Socket to shm exiting";
     close();
   }
 
+  ssize_t SocketSendAll(const Packet& packet) {
+    ssize_t written{};
+    while (written < static_cast<ssize_t>(packet.payload_length())) {
+      if (!socket_->IsOpen()) {
+        return -1;
+      }
+      auto just_written = socket_->Write(packet.payload() + written,
+                                         packet.payload_length() - written);
+      if (just_written <= 0) {
+        LOG(INFO) << "Couldn't write to client: "
+                  << strerror(socket_->GetErrno());
+        return just_written;
+      }
+      written += just_written;
+    }
+    return written;
+  }
+
   void ShmToSocketImpl() {
-    auto receiver = shm_connection_.MakeReceiver();
+    auto shm_receiver = shm_connection_.MakeReceiver();
+    Packet packet{};
     while (true) {
       if (closed()) {
         break;
       }
-      auto msg = receiver.Recv();
-      if (msg.empty() || socket_.closed()) {
+      shm_receiver.Recv(&packet);
+      if (packet.IsEnd()) {
         break;
       }
-      if (socket_.Send(msg) < 0) {
+      if (SocketSendAll(packet) < 0) {
         break;
       }
     }
@@ -113,7 +133,7 @@
   }
 
   SocketForwardRegionView::Connection shm_connection_;
-  cvd::ClientSocket socket_;
+  cvd::SharedFD socket_;
   bool closed_{};
   std::mutex closed_lock_;
 };
@@ -121,7 +141,7 @@
 // One thread for reading from shm and writing into a socket.
 // One thread for reading from a socket and writing into shm.
 void LaunchWorkers(SocketForwardRegionView::Connection conn,
-                   cvd::ClientSocket socket) {
+                   cvd::SharedFD socket) {
   auto worker = std::make_shared<Worker>(std::move(conn), std::move(socket));
   std::thread threads[] = {std::thread(Worker::SocketToShm, worker),
                            std::thread(Worker::ShmToSocket, worker)};
@@ -133,9 +153,11 @@
 #ifdef CUTTLEFISH_HOST
 [[noreturn]] void host(SocketForwardRegionView* shm, int port) {
   LOG(INFO) << "starting server on " << port;
-  cvd::ServerSocket server(port);
+  auto server = cvd::SharedFD::SocketLocalServer(port, SOCK_STREAM);
+  CHECK(server->IsOpen()) << "Could not start server on port " << port;
   while (true) {
-    auto client_socket = server.Accept();
+    auto client_socket = cvd::SharedFD::Accept(*server);
+    CHECK(client_socket->IsOpen()) << "error creating client socket";
     LOG(INFO) << "client socket accepted";
     auto conn = shm->OpenConnection(port);
     LOG(INFO) << "shm connection opened";
@@ -148,7 +170,8 @@
   while (true) {
     auto conn = shm->AcceptConnection();
     LOG(INFO) << "shm connection accepted";
-    auto sock = cvd::ClientSocket(conn.port());
+    auto sock = cvd::SharedFD::SocketLocalClient(conn.port(), SOCK_STREAM);
+    CHECK(sock->IsOpen()) << "Could not open socket to port " << conn.port();
     LOG(INFO) << "socket opened to " << conn.port();
     LaunchWorkers(std::move(conn), std::move(sock));
   }
@@ -170,7 +193,7 @@
 // makes sure we're running as root on the guest, no-op on the host
 void assert_correct_user() {
 #ifndef CUTTLEFISH_HOST
-    CHECK_EQ(getuid(), 0u) << "must run as root!";
+  CHECK_EQ(getuid(), 0u) << "must run as root!";
 #endif
 }
 
diff --git a/common/vsoc/lib/socket_forward_region_view.cpp b/common/vsoc/lib/socket_forward_region_view.cpp
index b56dde4..102a827 100644
--- a/common/vsoc/lib/socket_forward_region_view.cpp
+++ b/common/vsoc/lib/socket_forward_region_view.cpp
@@ -39,69 +39,44 @@
 guest_to_host;
 #endif
 
-using vsoc::socket_forward::Message;
 using vsoc::socket_forward::SocketForwardRegionView;
 
-constexpr std::int32_t kConnectionBegin = -1;
-constexpr std::int32_t kConnectionEnd = -2;
-
-Message SocketForwardRegionView::Recv(int connection_id) {
-  std::int32_t len{};
-  (data()->queues_[connection_id].*ReadDirection)
-      .Read(this, reinterpret_cast<char*>(&len), sizeof len);
-  if (len == kConnectionEnd) {
-    return {};
-  }
-  CHECK_NE(len, 0) << "zero-size message received";
-  CHECK_GT(len, 0) << "invalid size";
-  Message message(len);
-  (data()->queues_[connection_id].*ReadDirection)
-      .Read(this, reinterpret_cast<char*>(message.data()), message.size());
-  return message;
+void SocketForwardRegionView::Recv(int connection_id, Packet* packet) {
+  CHECK(packet != nullptr);
+  do {
+    (data()->queues_[connection_id].*ReadDirection)
+        .Read(this, reinterpret_cast<char*>(packet), sizeof *packet);
+  } while (packet->IsBegin());
+  // TODO(haining) check packet generation number
+  CHECK(!packet->empty()) << "zero-size data message received";
+  CHECK_LE(packet->payload_length(), kMaxPayloadSize) << "invalid size";
 }
 
-void SocketForwardRegionView::Send(int connection_id, const Message& message) {
-  if (message.empty()) {
+void SocketForwardRegionView::Send(int connection_id, const Packet& packet) {
+  if (packet.empty()) {
+    LOG(WARNING) << "ignoring empty packet (not sending)";
     return;
   }
-  std::int32_t len = message.size();
+  // TODO(haining) set packet generation number
+  CHECK_LE(packet.payload_length(), kMaxPayloadSize);
   (data()->queues_[connection_id].*WriteDirection)
-      .Write(this, reinterpret_cast<const char*>(&len), sizeof len);
-  (data()->queues_[connection_id].*WriteDirection)
-      .Write(this, reinterpret_cast<const char*>(message.data()),
-             message.size());
+      .Write(this, packet.raw_data(), packet.raw_data_length());
 }
 
 void SocketForwardRegionView::SendBegin(int connection_id) {
-  (data()->queues_[connection_id].*WriteDirection)
-      .Write(this, reinterpret_cast<const char*>(&kConnectionBegin),
-             sizeof kConnectionBegin);
+  Send(connection_id, Packet::MakeBegin());
 }
 
 void SocketForwardRegionView::SendEnd(int connection_id) {
-  (data()->queues_[connection_id].*WriteDirection)
-      .Write(this, reinterpret_cast<const char*>(&kConnectionEnd),
-             sizeof kConnectionEnd);
+  Send(connection_id, Packet::MakeEnd());
 }
 
 void SocketForwardRegionView::IgnoreUntilBegin(int connection_id) {
-  Message ignored(128);
-  while (true) {
-    std::int32_t len{};
+  Packet packet{};
+  do {
     (data()->queues_[connection_id].*ReadDirection)
-        .Read(this, reinterpret_cast<char*>(&len), sizeof len);
-    if (len == kConnectionBegin) {
-      break;
-    } else if (len == kConnectionEnd) {
-      continue;
-    }
-
-    CHECK_NE(len, 0) << "zero-size message received";
-    CHECK_GT(len, 0) << "invalid size";
-    ignored.resize(len);
-    (data()->queues_[connection_id].*ReadDirection)
-        .Read(this, reinterpret_cast<char*>(ignored.data()), ignored.size());
-  }
+        .Read(this, reinterpret_cast<char*>(&packet), sizeof packet);
+  } while (!packet.IsBegin());  // TODO(haining) check generation number
 }
 
 #ifdef CUTTLEFISH_HOST
@@ -231,8 +206,8 @@
   view_->IgnoreUntilBegin(connection_id_);
 }
 
-Message SocketForwardRegionView::Connection::Recv() {
-  return view_->Recv(connection_id_);
+void SocketForwardRegionView::Connection::Recv(Packet* packet) {
+  return view_->Recv(connection_id_, packet);
 }
 
 bool SocketForwardRegionView::Connection::closed() const {
@@ -247,10 +222,10 @@
   view_->SendBegin(connection_id_);
 }
 
-void SocketForwardRegionView::Connection::Send(const Message& message) {
+void SocketForwardRegionView::Connection::Send(const Packet& packet) {
   if (closed()) {
     LOG(INFO) << "connection closed, not sending\n";
     return;
   }
-  view_->Send(connection_id_, message);
+  view_->Send(connection_id_, packet);
 }
diff --git a/common/vsoc/lib/socket_forward_region_view.h b/common/vsoc/lib/socket_forward_region_view.h
index 27bea34..89da339 100644
--- a/common/vsoc/lib/socket_forward_region_view.h
+++ b/common/vsoc/lib/socket_forward_region_view.h
@@ -25,8 +25,6 @@
 namespace vsoc {
 namespace socket_forward {
 
-using Message = std::vector<std::uint8_t>;
-
 struct Header {
   std::uint32_t payload_length;
   std::uint32_t generation;
@@ -44,26 +42,30 @@
 struct Packet {
  private:
   Header header_;
-  char payload_data_[kMaxPayloadSize];
+  using Payload = char[kMaxPayloadSize];
+  Payload payload_data_;
 
-  static Packet MakePacket(Header::MessageType type, std::uint32_t generation) {
+  static Packet MakePacket(Header::MessageType type) {
     Packet packet{};
-    packet.set_generation(generation);
     packet.header_.message_type = type;
     return packet;
   }
 
  public:
-  static Packet MakeBegin(std::uint32_t generation) {
-    return MakePacket(Header::BEGIN, generation);
+  static Packet MakeBegin() {
+    return MakePacket(Header::BEGIN);
   }
 
-  static Packet MakeEnd(std::uint32_t generation) {
-    return MakePacket(Header::END, generation);
+  static Packet MakeEnd() {
+    return MakePacket(Header::END);
   }
 
-  static Packet MakeData(std::uint32_t generation) {
-    return MakePacket(Header::DATA, generation);
+  static Packet MakeData() {
+    return MakePacket(Header::DATA);
+  }
+
+  bool empty() const {
+    return header_.message_type == Header::DATA && header_.payload_length == 0;
   }
 
   void set_payload_length(std::uint32_t length) {
@@ -80,7 +82,11 @@
     header_.generation = generation;
   }
 
-  char* payload() {
+  Payload& payload() {
+    return payload_data_;
+  }
+
+  const Payload& payload() const {
     return payload_data_;
   }
 
@@ -88,10 +94,17 @@
     return header_.payload_length;
   }
 
-  Header::MessageType message_type() const {
-    return header_.message_type;
+  bool IsBegin() const {
+    return header_.message_type == Header::BEGIN;
   }
 
+  bool IsEnd() const {
+    return header_.message_type == Header::END;
+  }
+
+  bool IsData() const {
+    return header_.message_type == Header::DATA;
+  }
 
   char* raw_data() {
     return reinterpret_cast<char*>(this);
@@ -101,7 +114,7 @@
     return reinterpret_cast<const char*>(this);
   }
 
-  size_t raw_data_size() const {
+  size_t raw_data_length() const {
     return payload_length() + sizeof header_;
   }
 };
@@ -122,17 +135,17 @@
   void ReleaseConnectionID(int connection_id);
   std::pair<int, int> GetWaitingConnectionIDAndPort();
 
-  // Returns an empty Message if the other side is closed.
-  Message Recv(int connection_id);
-  // Does nothing if message is empty
-  void Send(int connection_id, const Message& message);
+  // Returns an empty data packet if the other side is closed.
+  void Recv(int connection_id, Packet* packet);
+  // Does nothing if packet is empty
+  void Send(int connection_id, const Packet& packet);
 
   void SendBegin(int connection_id);
   void SendEnd(int connection_id);
 
-  // skip everything in the connection queue until seeing the beginning of
-  // the next message
+  // skip everything in the connection queue until seeing a BEGIN
   void IgnoreUntilBegin(int connection_id);
+
   bool IsOtherSideClosed(int connection_id);
 
  public:
@@ -164,12 +177,12 @@
 
    private:
     // Sends should be done using a Sender.
-    void Send(const Message& message);
+    void Send(const Packet& packet);
     void SendBegin();
     void SendEnd();
 
     // Receives should be done using a Receiver.
-    Message Recv();
+    void Recv(Packet* packet);
     void IgnoreUntilBegin();
 
     struct Releaser {
@@ -201,8 +214,8 @@
       connection_->SendBegin();
     }
 
-    void Send(const Message& message) {
-      connection_->Send(message);
+    void Send(const Packet& packet) {
+      connection_->Send(packet);
     }
 
    private:
@@ -225,8 +238,8 @@
       connection_->IgnoreUntilBegin();
     }
 
-    Message Recv() {
-      return connection_->Recv();
+    void Recv(Packet* packet) {
+      return connection_->Recv(packet);
     }
 
    private: