blob: 31cff0b56028323512556c94ca554ea6a57a2f0d [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/socket.h>
#include <unistd.h>

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <set>
#include <thread>
#include <type_traits>
#include <utility>

#include <gflags/gflags.h>
#include <glog/logging.h>

#include "common/libs/fs/shared_fd.h"
#ifdef CUTTLEFISH_HOST
#include "host/libs/config/cuttlefish_config.h"
#endif
// Header stored at the front of every Packet: two 32-bit fields giving the
// payload size and the kind of message.
struct Header {
  std::uint32_t payload_length;
  // Message kinds, numbered explicitly (identical to implicit 0..4 numbering).
  enum MessageType : std::uint32_t {
    DATA = 0,
    BEGIN = 1,
    END = 2,
    RECV_CLOSED = 3,  // this side's receive end has been closed
    RESTART = 4,
  };
  MessageType message_type;
};

// Size in bytes of one whole Packet, header included.
constexpr std::size_t kMaxPacketSize = 8 * 1024;
// Bytes available for payload once the header has been accounted for.
constexpr std::size_t kMaxPayloadSize = kMaxPacketSize - sizeof(Header);
// A fixed-size (kMaxPacketSize-byte) buffer made of a Header followed by a
// payload area. The proxy code in this source only builds DATA packets and
// only ever transmits the payload bytes.
struct Packet {
 private:
  Header header_;
  using Payload = char[kMaxPayloadSize];
  Payload payload_data_;

  // Returns a value-initialized packet (payload_length 0, zeroed payload)
  // carrying the given message type.
  static constexpr Packet MakePacket(Header::MessageType type) {
    Packet packet{};
    packet.header_.message_type = type;
    return packet;
  }

 public:
  // port is only relevant on the host-side.
  // NOTE(review): no definition of MakeBegin appears in this source; confirm
  // it is defined elsewhere before calling it.
  static Packet MakeBegin(std::uint16_t port);
  static constexpr Packet MakeEnd() { return MakePacket(Header::END); }
  static constexpr Packet MakeRecvClosed() {
    return MakePacket(Header::RECV_CLOSED);
  }
  static constexpr Packet MakeRestart() { return MakePacket(Header::RESTART); }
  // NOTE payload and payload_length must still be set.
  static constexpr Packet MakeData() { return MakePacket(Header::DATA); }

  // True for a DATA packet whose payload_length is 0.
  bool empty() const { return IsData() && header_.payload_length == 0; }

  // Sets the number of valid payload bytes; CHECK-fails if length exceeds the
  // payload capacity (kMaxPayloadSize).
  void set_payload_length(std::uint32_t length) {
    CHECK_LE(length, sizeof payload_data_);
    header_.payload_length = length;
  }

  Payload& payload() { return payload_data_; }
  const Payload& payload() const { return payload_data_; }
  constexpr std::uint32_t payload_length() const {
    return header_.payload_length;
  }

  constexpr bool IsBegin() const {
    return header_.message_type == Header::BEGIN;
  }
  constexpr bool IsEnd() const { return header_.message_type == Header::END; }
  constexpr bool IsData() const { return header_.message_type == Header::DATA; }
  constexpr bool IsRecvClosed() const {
    return header_.message_type == Header::RECV_CLOSED;
  }
  constexpr bool IsRestart() const {
    return header_.message_type == Header::RESTART;
  }

  // Returns the port number stored in the payload of a BEGIN packet.
  // CHECK-fails unless this is a BEGIN packet whose payload is exactly
  // sizeof(std::uint16_t) bytes long.
  // Deliberately not constexpr: std::memcpy (and the CHECK macros) cannot be
  // evaluated in a constant expression, so a constexpr specifier here could
  // never be honored (ill-formed, no diagnostic required).
  std::uint16_t port() const {
    CHECK(IsBegin());
    std::uint16_t port_number{};
    CHECK_EQ(payload_length(), sizeof port_number);
    // Copy bytes instead of casting the char buffer to avoid aliasing and
    // alignment problems.
    std::memcpy(&port_number, payload(), sizeof port_number);
    return port_number;
  }

  // Views the whole packet (header followed by payload) as bytes.
  char* raw_data() { return reinterpret_cast<char*>(this); }
  const char* raw_data() const { return reinterpret_cast<const char*>(this); }
  // Number of meaningful bytes in raw_data(): the header plus payload_length().
  constexpr size_t raw_data_length() const {
    return payload_length() + sizeof header_;
  }
};
// Compile-time layout guarantees for Packet: it fills exactly kMaxPacketSize
// bytes and is a POD type, which the byte view returned by raw_data() relies on.
static_assert(sizeof(Packet) == kMaxPacketSize,
              "Packet must be exactly kMaxPacketSize bytes");
static_assert(std::is_pod<Packet>::value, "Packet must be a POD type");
// Command-line flags. tcp_port and vsock_port are required on both sides;
// vsock_guest_cid is only read (and required) by the host build.
DEFINE_uint32(tcp_port, 0, "TCP port (server on host, client on guest)");
// Help text previously lacked its closing parenthesis.
DEFINE_uint32(vsock_port, 0, "vsock port (client on host, server on guest)");
DEFINE_uint32(vsock_guest_cid, 0, "Guest identifier");
namespace {
// Writes packet payloads to a socket. Unless moved-from, the destructor calls
// Shutdown(SHUT_WR) so the peer sees end-of-stream once sending is done.
class SocketSender {
 public:
  explicit SocketSender(cvd::SharedFD socket) : socket_{socket} {}
  SocketSender(SocketSender&&) = default;
  SocketSender& operator=(SocketSender&&) = default;
  // Non-copyable: every copy would also call Shutdown(SHUT_WR) when destroyed.
  // (Previously misdeclared as a deleted `const SocketSender&&` constructor.)
  SocketSender(const SocketSender&) = delete;
  SocketSender& operator=(const SocketSender&) = delete;
  ~SocketSender() {
    if (socket_.operator->()) {  // check that socket_ was not moved-from
      socket_->Shutdown(SHUT_WR);
    }
  }

  // Sends all payload_length() bytes of packet's payload (the header is not
  // sent). Returns the number of bytes written (payload_length(), possibly 0)
  // on success, or -1 if the socket is closed or a send fails.
  ssize_t SendAll(const Packet& packet) {
    const auto total = static_cast<ssize_t>(packet.payload_length());
    ssize_t written = 0;
    while (written < total) {
      if (!socket_->IsOpen()) {
        return -1;
      }
      const auto just_written = socket_->Send(
          packet.payload() + written, total - written, MSG_NOSIGNAL);
      if (just_written <= 0) {
        // StrError() rather than strerror(): strerror is not thread-safe and
        // senders run concurrently on separate proxy threads.
        LOG(INFO) << "Couldn't write to client: " << socket_->StrError();
        // A zero-byte send made no progress; report failure so callers that
        // test for a negative result do not carry on after a partial write.
        return -1;
      }
      written += just_written;
    }
    return written;
  }

 private:
  cvd::SharedFD socket_;
};
class SocketReceiver {
public:
explicit SocketReceiver(cvd::SharedFD socket) : socket_{socket} {}
SocketReceiver(SocketReceiver&&) = default;
SocketReceiver& operator=(SocketReceiver&&) = default;
SocketReceiver(const SocketReceiver&&) = delete;
SocketReceiver& operator=(const SocketReceiver&) = delete;
// *packet will be empty if Read returns 0 or error
void Recv(Packet* packet) {
auto size = socket_->Read(packet->payload(), sizeof packet->payload());
if (size < 0) {
size = 0;
}
packet->set_payload_length(size);
}
private:
cvd::SharedFD socket_;
};
// Copies data from socket_receiver to vsock_sender until a read yields an
// empty packet (EOF or error) or a send fails. vsock_sender is destroyed on
// return, which shuts down the vsock's write side.
void SocketToVsock(SocketReceiver socket_receiver,
                   SocketSender vsock_sender) {
  // One buffer for the whole transfer. Building the packet inside the loop
  // value-initialized (zeroed) kMaxPacketSize bytes on every read. Recv
  // overwrites payload_length each time and never changes the DATA type.
  auto packet = Packet::MakeData();
  while (true) {
    socket_receiver.Recv(&packet);
    if (packet.empty() || vsock_sender.SendAll(packet) < 0) {
      break;
    }
  }
  LOG(INFO) << "Socket to vsock exiting";
}
// Copies data from vsock_receiver to socket_sender, stopping at the first
// empty read (EOF or error) or failed send. socket_sender is destroyed on
// return, which shuts down the socket's write side.
void VsockToSocket(SocketSender socket_sender,
                   SocketReceiver vsock_receiver) {
  // A single buffer is reused for every read.
  Packet buffer = Packet::MakeData();
  for (;;) {
    vsock_receiver.Recv(&buffer);
    CHECK(buffer.IsData());
    const bool reached_end = buffer.empty();
    // Short-circuit: nothing is sent once the stream has ended.
    if (reached_end || socket_sender.SendAll(buffer) < 0) {
      break;
    }
  }
  LOG(INFO) << "Vsock to socket exiting";
}
// Proxies one connection in both directions. A new thread copies
// socket -> vsock while the calling thread copies vsock -> socket. Returns
// only after both directions have finished.
void HandleConnection(cvd::SharedFD vsock,
                      cvd::SharedFD socket) {
  std::thread upstream(SocketToVsock, SocketReceiver{socket},
                       SocketSender{vsock});
  VsockToSocket(SocketSender{socket}, SocketReceiver{vsock});
  upstream.join();
}
#ifdef CUTTLEFISH_HOST
[[noreturn]] void host() {
LOG(INFO) << "starting server on " << FLAGS_tcp_port << " for vsock port "
<< FLAGS_vsock_port;
auto server = cvd::SharedFD::SocketLocalServer(FLAGS_tcp_port, SOCK_STREAM);
CHECK(server->IsOpen()) << "Could not start server on " << FLAGS_tcp_port;
LOG(INFO) << "Accepting client connections";
int last_failure_reason = 0;
while (true) {
auto client_socket = cvd::SharedFD::Accept(*server);
CHECK(client_socket->IsOpen()) << "error creating client socket";
cvd::SharedFD vsock_socket = cvd::SharedFD::VsockClient(
FLAGS_vsock_guest_cid, FLAGS_vsock_port, SOCK_STREAM);
if (vsock_socket->IsOpen()) {
last_failure_reason = 0;
LOG(INFO) << "Connected to vsock:" << FLAGS_vsock_guest_cid << ":"
<< FLAGS_vsock_port;
} else {
// Don't log if the previous connection failed with the same error
if (last_failure_reason != vsock_socket->GetErrno()) {
last_failure_reason = vsock_socket->GetErrno();
LOG(ERROR) << "Unable to connect to vsock server: "
<< vsock_socket->StrError();
}
continue;
}
auto thread = std::thread(HandleConnection, std::move(vsock_socket),
std::move(client_socket));
thread.detach();
}
}
#else
// Connects to FLAGS_tcp_port, retrying once per second until the connection
// opens. Only ever returns an open SharedFD.
cvd::SharedFD OpenSocketConnection() {
  for (;;) {
    auto connection =
        cvd::SharedFD::SocketLocalClient(FLAGS_tcp_port, SOCK_STREAM);
    if (connection->IsOpen()) {
      return connection;
    }
    LOG(WARNING) << "could not connect on port " << FLAGS_tcp_port
                 << ". sleeping for 1 second";
    sleep(1);
  }
}
// Returns false for the errno values after which retrying socket creation is
// pointless (EACCES, EAFNOSUPPORT, EINVAL, EPROTONOSUPPORT); true otherwise.
bool socketErrorIsRecoverable(int error) {
  switch (error) {
    case EACCES:
    case EAFNOSUPPORT:
    case EINVAL:
    case EPROTONOSUPPORT:
      return false;
    default:
      return true;
  }
}
// Blocks the calling thread indefinitely. sleep() is simply re-issued whenever
// it returns (e.g. after being interrupted).
[[noreturn]] static void SleepForever() {
  constexpr unsigned int kLongestSleep =
      std::numeric_limits<unsigned int>::max();
  for (;;) {
    sleep(kLongestSleep);
  }
}
[[noreturn]] void guest() {
LOG(INFO) << "Starting guest mainloop";
LOG(INFO) << "starting server on " << FLAGS_vsock_port;
cvd::SharedFD vsock;
do {
vsock = cvd::SharedFD::VsockServer(FLAGS_vsock_port, SOCK_STREAM);
if (!vsock->IsOpen() && !socketErrorIsRecoverable(vsock->GetErrno())) {
LOG(ERROR) << "Could not open vsock socket: " << vsock->StrError();
SleepForever();
}
} while (!vsock->IsOpen());
CHECK(vsock->IsOpen()) << "Could not start server on " << FLAGS_vsock_port;
while (true) {
LOG(INFO) << "waiting for vsock connection";
auto vsock_client = cvd::SharedFD::Accept(*vsock);
CHECK(vsock_client->IsOpen()) << "error creating vsock socket";
LOG(INFO) << "vsock socket accepted";
auto client = OpenSocketConnection();
CHECK(client->IsOpen()) << "error connecting to guest client";
auto thread = std::thread(HandleConnection, std::move(vsock_client),
std::move(client));
thread.detach();
}
}
#endif
} // namespace
// Entry point: parses command-line flags and verifies that the required flags
// are non-zero (both ports, plus the guest CID on the host build). It then
// enters the host or guest loop; both are [[noreturn]], so main never returns
// normally.
int main(int argc, char* argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  // CHECK includes the stringified condition in its failure message, so the
  // condition text is part of the program's output.
  CHECK(FLAGS_tcp_port != 0) << "Must specify -tcp_port flag";
  CHECK(FLAGS_vsock_port != 0) << "Must specify -vsock_port flag";
#ifdef CUTTLEFISH_HOST
  // Only the host dials out to the guest, so only it needs the guest's CID.
  CHECK(FLAGS_vsock_guest_cid != 0) << "Must specify -vsock_guest_cid flag";
  host();
#else
  guest();
#endif
}