Uses new Packet struct instead of Message for shm
This reduces copying overhead and groups the data more explicitly.
Before starting the worker's loops, they each create a Packet which is
then used to hold the data transmitted. This requires a more low-level
approach than is provided by tcp_socket.
BUG: 72654144
Change-Id: Id23d3d548f15b62cc7b5f61c570f24d6ccd92aee
diff --git a/common/commands/socket_forward_proxy/Android.bp b/common/commands/socket_forward_proxy/Android.bp
index 07d02ce..9819e9d 100644
--- a/common/commands/socket_forward_proxy/Android.bp
+++ b/common/commands/socket_forward_proxy/Android.bp
@@ -5,6 +5,7 @@
],
shared_libs: [
"libbase",
+ "libcuttlefish_fs",
"vsoc_lib",
"liblog",
"cuttlefish_tcp_socket",
diff --git a/common/commands/socket_forward_proxy/main.cpp b/common/commands/socket_forward_proxy/main.cpp
index cec1b05..cb1fadb 100644
--- a/common/commands/socket_forward_proxy/main.cpp
+++ b/common/commands/socket_forward_proxy/main.cpp
@@ -26,13 +26,14 @@
#include <unistd.h>
+#include "common/libs/fs/shared_fd.h"
#include "common/vsoc/lib/socket_forward_region_view.h"
-#include "common/libs/tcp_socket/tcp_socket.h"
#ifdef CUTTLEFISH_HOST
#include "host/libs/config/host_config.h"
#endif
+using vsoc::socket_forward::Packet;
using vsoc::socket_forward::SocketForwardRegionView;
#ifdef CUTTLEFISH_HOST
@@ -43,7 +44,7 @@
class Worker {
public:
Worker(SocketForwardRegionView::Connection shm_connection,
- cvd::ClientSocket socket)
+ cvd::SharedFD socket)
: shm_connection_(std::move(shm_connection)),
socket_(std::move(socket)){}
@@ -54,7 +55,7 @@
return true;
}
}
- if (shm_connection_.closed() || socket_.closed()) {
+ if (shm_connection_.closed() || !socket_->IsOpen()) {
std::lock_guard<std::mutex> guard(closed_lock_);
closed_ = true;
}
@@ -76,35 +77,54 @@
private:
void SocketToShmImpl() {
- constexpr int kRecvSize = 8192;
+ auto shm_sender = shm_connection_.MakeSender();
- auto sender = shm_connection_.MakeSender();
-
+ auto packet = Packet::MakeData();
while (true) {
if (closed()) {
break;
}
- auto msg = socket_.RecvAny(kRecvSize);
- if (msg.empty()) {
+ auto size = socket_->Recv(packet.payload(), sizeof packet.payload(), 0);
+ if (size <= 0) {
break;
}
- sender.Send(std::move(msg));
+ packet.set_payload_length(size);
+ shm_sender.Send(packet);
}
LOG(INFO) << "Socket to shm exiting";
close();
}
+ ssize_t SocketSendAll(const Packet& packet) {
+ ssize_t written{};
+ while (written < static_cast<ssize_t>(packet.payload_length())) {
+ if (!socket_->IsOpen()) {
+ return -1;
+ }
+ auto just_written = socket_->Write(packet.payload() + written,
+ packet.payload_length() - written);
+ if (just_written <= 0) {
+ LOG(INFO) << "Couldn't write to client: "
+ << strerror(socket_->GetErrno());
+ return just_written;
+ }
+ written += just_written;
+ }
+ return written;
+ }
+
void ShmToSocketImpl() {
- auto receiver = shm_connection_.MakeReceiver();
+ auto shm_receiver = shm_connection_.MakeReceiver();
+ Packet packet{};
while (true) {
if (closed()) {
break;
}
- auto msg = receiver.Recv();
- if (msg.empty() || socket_.closed()) {
+ shm_receiver.Recv(&packet);
+ if (packet.IsEnd()) {
break;
}
- if (socket_.Send(msg) < 0) {
+ if (SocketSendAll(packet) < 0) {
break;
}
}
@@ -113,7 +133,7 @@
}
SocketForwardRegionView::Connection shm_connection_;
- cvd::ClientSocket socket_;
+ cvd::SharedFD socket_;
bool closed_{};
std::mutex closed_lock_;
};
@@ -121,7 +141,7 @@
// One thread for reading from shm and writing into a socket.
// One thread for reading from a socket and writing into shm.
void LaunchWorkers(SocketForwardRegionView::Connection conn,
- cvd::ClientSocket socket) {
+ cvd::SharedFD socket) {
auto worker = std::make_shared<Worker>(std::move(conn), std::move(socket));
std::thread threads[] = {std::thread(Worker::SocketToShm, worker),
std::thread(Worker::ShmToSocket, worker)};
@@ -133,9 +153,11 @@
#ifdef CUTTLEFISH_HOST
[[noreturn]] void host(SocketForwardRegionView* shm, int port) {
LOG(INFO) << "starting server on " << port;
- cvd::ServerSocket server(port);
+ auto server = cvd::SharedFD::SocketLocalServer(port, SOCK_STREAM);
+ CHECK(server->IsOpen()) << "Could not start server on port " << port;
while (true) {
- auto client_socket = server.Accept();
+ auto client_socket = cvd::SharedFD::Accept(*server);
+ CHECK(client_socket->IsOpen()) << "error creating client socket";
LOG(INFO) << "client socket accepted";
auto conn = shm->OpenConnection(port);
LOG(INFO) << "shm connection opened";
@@ -148,7 +170,8 @@
while (true) {
auto conn = shm->AcceptConnection();
LOG(INFO) << "shm connection accepted";
- auto sock = cvd::ClientSocket(conn.port());
+ auto sock = cvd::SharedFD::SocketLocalClient(conn.port(), SOCK_STREAM);
+ CHECK(sock->IsOpen()) << "Could not open socket to port " << conn.port();
LOG(INFO) << "socket opened to " << conn.port();
LaunchWorkers(std::move(conn), std::move(sock));
}
@@ -170,7 +193,7 @@
// makes sure we're running as root on the guest, no-op on the host
void assert_correct_user() {
#ifndef CUTTLEFISH_HOST
- CHECK_EQ(getuid(), 0u) << "must run as root!";
+ CHECK_EQ(getuid(), 0u) << "must run as root!";
#endif
}
diff --git a/common/vsoc/lib/socket_forward_region_view.cpp b/common/vsoc/lib/socket_forward_region_view.cpp
index b56dde4..102a827 100644
--- a/common/vsoc/lib/socket_forward_region_view.cpp
+++ b/common/vsoc/lib/socket_forward_region_view.cpp
@@ -39,69 +39,44 @@
guest_to_host;
#endif
-using vsoc::socket_forward::Message;
using vsoc::socket_forward::SocketForwardRegionView;
-constexpr std::int32_t kConnectionBegin = -1;
-constexpr std::int32_t kConnectionEnd = -2;
-
-Message SocketForwardRegionView::Recv(int connection_id) {
- std::int32_t len{};
- (data()->queues_[connection_id].*ReadDirection)
- .Read(this, reinterpret_cast<char*>(&len), sizeof len);
- if (len == kConnectionEnd) {
- return {};
- }
- CHECK_NE(len, 0) << "zero-size message received";
- CHECK_GT(len, 0) << "invalid size";
- Message message(len);
- (data()->queues_[connection_id].*ReadDirection)
- .Read(this, reinterpret_cast<char*>(message.data()), message.size());
- return message;
+void SocketForwardRegionView::Recv(int connection_id, Packet* packet) {
+ CHECK(packet != nullptr);
+ do {
+ (data()->queues_[connection_id].*ReadDirection)
+ .Read(this, reinterpret_cast<char*>(packet), sizeof *packet);
+ } while (packet->IsBegin());
+ // TODO(haining) check packet generation number
+ CHECK(!packet->empty()) << "zero-size data message received";
+ CHECK_LE(packet->payload_length(), kMaxPayloadSize) << "invalid size";
}
-void SocketForwardRegionView::Send(int connection_id, const Message& message) {
- if (message.empty()) {
+void SocketForwardRegionView::Send(int connection_id, const Packet& packet) {
+ if (packet.empty()) {
+ LOG(WARNING) << "ignoring empty packet (not sending)";
return;
}
- std::int32_t len = message.size();
+ // TODO(haining) set packet generation number
+ CHECK_LE(packet.payload_length(), kMaxPayloadSize);
(data()->queues_[connection_id].*WriteDirection)
- .Write(this, reinterpret_cast<const char*>(&len), sizeof len);
- (data()->queues_[connection_id].*WriteDirection)
- .Write(this, reinterpret_cast<const char*>(message.data()),
- message.size());
+ .Write(this, packet.raw_data(), packet.raw_data_length());
}
void SocketForwardRegionView::SendBegin(int connection_id) {
- (data()->queues_[connection_id].*WriteDirection)
- .Write(this, reinterpret_cast<const char*>(&kConnectionBegin),
- sizeof kConnectionBegin);
+ Send(connection_id, Packet::MakeBegin());
}
void SocketForwardRegionView::SendEnd(int connection_id) {
- (data()->queues_[connection_id].*WriteDirection)
- .Write(this, reinterpret_cast<const char*>(&kConnectionEnd),
- sizeof kConnectionEnd);
+ Send(connection_id, Packet::MakeEnd());
}
void SocketForwardRegionView::IgnoreUntilBegin(int connection_id) {
- Message ignored(128);
- while (true) {
- std::int32_t len{};
+ Packet packet{};
+ do {
(data()->queues_[connection_id].*ReadDirection)
- .Read(this, reinterpret_cast<char*>(&len), sizeof len);
- if (len == kConnectionBegin) {
- break;
- } else if (len == kConnectionEnd) {
- continue;
- }
-
- CHECK_NE(len, 0) << "zero-size message received";
- CHECK_GT(len, 0) << "invalid size";
- ignored.resize(len);
- (data()->queues_[connection_id].*ReadDirection)
- .Read(this, reinterpret_cast<char*>(ignored.data()), ignored.size());
- }
+ .Read(this, reinterpret_cast<char*>(&packet), sizeof packet);
+ } while (!packet.IsBegin()); // TODO(haining) check generation number
}
#ifdef CUTTLEFISH_HOST
@@ -231,8 +206,8 @@
view_->IgnoreUntilBegin(connection_id_);
}
-Message SocketForwardRegionView::Connection::Recv() {
- return view_->Recv(connection_id_);
+void SocketForwardRegionView::Connection::Recv(Packet* packet) {
+ return view_->Recv(connection_id_, packet);
}
bool SocketForwardRegionView::Connection::closed() const {
@@ -247,10 +222,10 @@
view_->SendBegin(connection_id_);
}
-void SocketForwardRegionView::Connection::Send(const Message& message) {
+void SocketForwardRegionView::Connection::Send(const Packet& packet) {
if (closed()) {
LOG(INFO) << "connection closed, not sending\n";
return;
}
- view_->Send(connection_id_, message);
+ view_->Send(connection_id_, packet);
}
diff --git a/common/vsoc/lib/socket_forward_region_view.h b/common/vsoc/lib/socket_forward_region_view.h
index 27bea34..89da339 100644
--- a/common/vsoc/lib/socket_forward_region_view.h
+++ b/common/vsoc/lib/socket_forward_region_view.h
@@ -25,8 +25,6 @@
namespace vsoc {
namespace socket_forward {
-using Message = std::vector<std::uint8_t>;
-
struct Header {
std::uint32_t payload_length;
std::uint32_t generation;
@@ -44,26 +42,30 @@
struct Packet {
private:
Header header_;
- char payload_data_[kMaxPayloadSize];
+ using Payload = char[kMaxPayloadSize];
+ Payload payload_data_;
- static Packet MakePacket(Header::MessageType type, std::uint32_t generation) {
+ static Packet MakePacket(Header::MessageType type) {
Packet packet{};
- packet.set_generation(generation);
packet.header_.message_type = type;
return packet;
}
public:
- static Packet MakeBegin(std::uint32_t generation) {
- return MakePacket(Header::BEGIN, generation);
+ static Packet MakeBegin() {
+ return MakePacket(Header::BEGIN);
}
- static Packet MakeEnd(std::uint32_t generation) {
- return MakePacket(Header::END, generation);
+ static Packet MakeEnd() {
+ return MakePacket(Header::END);
}
- static Packet MakeData(std::uint32_t generation) {
- return MakePacket(Header::DATA, generation);
+ static Packet MakeData() {
+ return MakePacket(Header::DATA);
+ }
+
+ bool empty() const {
+ return header_.message_type == Header::DATA && header_.payload_length == 0;
}
void set_payload_length(std::uint32_t length) {
@@ -80,7 +82,11 @@
header_.generation = generation;
}
- char* payload() {
+ Payload& payload() {
+ return payload_data_;
+ }
+
+ const Payload& payload() const {
return payload_data_;
}
@@ -88,10 +94,17 @@
return header_.payload_length;
}
- Header::MessageType message_type() const {
- return header_.message_type;
+ bool IsBegin() const {
+ return header_.message_type == Header::BEGIN;
}
+ bool IsEnd() const {
+ return header_.message_type == Header::END;
+ }
+
+ bool IsData() const {
+ return header_.message_type == Header::DATA;
+ }
char* raw_data() {
return reinterpret_cast<char*>(this);
@@ -101,7 +114,7 @@
return reinterpret_cast<const char*>(this);
}
- size_t raw_data_size() const {
+ size_t raw_data_length() const {
return payload_length() + sizeof header_;
}
};
@@ -122,17 +135,17 @@
void ReleaseConnectionID(int connection_id);
std::pair<int, int> GetWaitingConnectionIDAndPort();
- // Returns an empty Message if the other side is closed.
- Message Recv(int connection_id);
- // Does nothing if message is empty
- void Send(int connection_id, const Message& message);
+ // Returns an empty data packet if the other side is closed.
+ void Recv(int connection_id, Packet* packet);
+ // Does nothing if packet is empty
+ void Send(int connection_id, const Packet& packet);
void SendBegin(int connection_id);
void SendEnd(int connection_id);
- // skip everything in the connection queue until seeing the beginning of
- // the next message
+ // skip everything in the connection queue until seeing a BEGIN
void IgnoreUntilBegin(int connection_id);
+
bool IsOtherSideClosed(int connection_id);
public:
@@ -164,12 +177,12 @@
private:
// Sends should be done using a Sender.
- void Send(const Message& message);
+ void Send(const Packet& packet);
void SendBegin();
void SendEnd();
// Receives should be done using a Receiver.
- Message Recv();
+ void Recv(Packet* packet);
void IgnoreUntilBegin();
struct Releaser {
@@ -201,8 +214,8 @@
connection_->SendBegin();
}
- void Send(const Message& message) {
- connection_->Send(message);
+ void Send(const Packet& packet) {
+ connection_->Send(packet);
}
private:
@@ -225,8 +238,8 @@
connection_->IgnoreUntilBegin();
}
- Message Recv() {
- return connection_->Recv();
+ void Recv(Packet* packet) {
+ return connection_->Recv(packet);
}
private: