blob: b2c16f6e1af3c269caaab2234a78523479e52ca1 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <sys/socket.h>
#include <unistd.h>

#include <array>
#include <cctype>
#include <condition_variable>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <limits>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>

#include <gflags/gflags.h>
#include <glog/logging.h>
#include "common/libs/fs/shared_fd.h"
#include "common/libs/strings/str_split.h"
#include "common/vsoc/lib/socket_forward_region_view.h"
#ifdef CUTTLEFISH_HOST
#include "host/libs/config/cuttlefish_config.h"
#include "host/libs/adb_connection_maintainer/adb_connection_maintainer.h"
#endif
using vsoc::socket_forward::Packet;
using vsoc::socket_forward::SocketForwardRegionView;
#ifdef CUTTLEFISH_HOST
// Host-only flags. The two lists are paired positionally: the i-th entry of
// --host_ports gets a TCP server whose clients are forwarded to the i-th entry
// of --guest_ports (see ParsePortsList).
DEFINE_string(guest_ports, "",
              "Comma-separated list of ports on which to forward TCP "
              "connections to the guest.");
DEFINE_string(host_ports, "",
              "Comma-separated list of ports on which to run TCP servers on "
              "the host.");
#endif
namespace {
// Owns the sending side of a client socket: writes packet payloads to it and
// calls Shutdown(SHUT_WR) on destruction. Move-only.
class SocketSender {
 public:
  explicit SocketSender(cvd::SharedFD socket) : socket_{std::move(socket)} {}
  SocketSender(SocketSender&&) = default;
  SocketSender& operator=(SocketSender&&) = default;
  // Copying is disallowed.
  SocketSender(const SocketSender&) = delete;
  SocketSender& operator=(const SocketSender&) = delete;

  ~SocketSender() {
    // Skip the shutdown if socket_ was moved-from (operator-> yields null).
    if (socket_.operator->()) {
      socket_->Shutdown(SHUT_WR);
    }
  }

  // Writes the entire payload of `packet`, looping over short writes.
  // Returns the number of bytes written on success, -1 if the socket is not
  // open, or the non-positive result of the failing Send() call otherwise.
  // NOTE(review): a Send() result of 0 is returned as 0, which callers that
  // only test `< 0` treat as success — confirm that is intended.
  ssize_t SendAll(const Packet& packet) {
    ssize_t written{};
    while (written < static_cast<ssize_t>(packet.payload_length())) {
      if (!socket_->IsOpen()) {
        return -1;
      }
      // MSG_NOSIGNAL: report a closed peer as an error instead of raising
      // SIGPIPE.
      auto just_written =
          socket_->Send(packet.payload() + written,
                        packet.payload_length() - written, MSG_NOSIGNAL);
      if (just_written <= 0) {
        LOG(INFO) << "Couldn't write to client: "
                  << strerror(socket_->GetErrno());
        return just_written;
      }
      written += just_written;
    }
    return written;
  }

 private:
  cvd::SharedFD socket_;
};
// Owns the receiving side of a client socket and reads its data into packets.
// Move-only.
class SocketReceiver {
 public:
  explicit SocketReceiver(cvd::SharedFD socket) : socket_{std::move(socket)} {}
  SocketReceiver(SocketReceiver&&) = default;
  SocketReceiver& operator=(SocketReceiver&&) = default;
  // Copying is disallowed.
  SocketReceiver(const SocketReceiver&) = delete;
  SocketReceiver& operator=(const SocketReceiver&) = delete;

  // Performs one Read() into packet's payload buffer and records the number
  // of bytes read as its payload length. *packet will be empty if Read
  // returns 0 or an error (negative results are clamped to 0).
  void Recv(Packet* packet) {
    // NOTE(review): `sizeof packet->payload()` is the buffer capacity only if
    // payload() returns a reference to a fixed-size array rather than a
    // pointer — confirm against Packet's declaration.
    auto size = socket_->Read(packet->payload(), sizeof packet->payload());
    if (size < 0) {
      size = 0;
    }
    packet->set_payload_length(size);
  }

 private:
  cvd::SharedFD socket_;
};
// Copies data from a client socket into shared memory. Stops when a socket
// read produces an empty packet (Read returned 0 or an error) or when sending
// to shm fails.
void SocketToShm(SocketReceiver socket_receiver,
                 SocketForwardRegionView::ShmSender shm_sender) {
  for (;;) {
    auto data_packet = Packet::MakeData();
    socket_receiver.Recv(&data_packet);
    if (data_packet.empty()) {
      break;
    }
    if (!shm_sender.Send(data_packet)) {
      break;
    }
  }
  LOG(INFO) << "Socket to shm exiting";
}
// Copies data packets received from shared memory to a client socket. Stops on
// an empty packet or when writing to the socket fails; aborts (CHECK) if a
// non-data packet is received. The single Packet buffer is reused across reads.
void ShmToSocket(SocketSender socket_sender,
                 SocketForwardRegionView::ShmReceiver shm_receiver) {
  Packet packet{};
  for (;;) {
    shm_receiver.Recv(&packet);
    CHECK(packet.IsData());
    const bool got_empty = packet.empty();
    if (got_empty || socket_sender.SendAll(packet) < 0) {
      break;
    }
  }
  LOG(INFO) << "Shm to socket exiting";
}
// Bridges one shm connection and one client socket in both directions.
// A helper thread copies socket -> shm (SocketToShm) while the calling thread
// copies shm -> socket (ShmToSocket). Returns once both directions finish.
void HandleConnection(SocketForwardRegionView::ShmSenderReceiverPair shm_sender_and_receiver,
                      cvd::SharedFD socket) {
  auto& to_shm = shm_sender_and_receiver.first;
  auto& from_shm = shm_sender_and_receiver.second;
  std::thread socket_to_shm_thread{SocketToShm, SocketReceiver{socket},
                                   std::move(to_shm)};
  ShmToSocket(SocketSender{socket}, std::move(from_shm));
  socket_to_shm_thread.join();
}
#ifdef CUTTLEFISH_HOST
// A guest port together with the host port whose TCP server forwards client
// connections to it. Aggregate-initialized as {guest_port, host_port}, so the
// member order matters.
struct PortPair {
  int guest_port;
  int host_port;
};
// Occupancy of a SocketConnectionInfo slot (one per shm connection queue):
// kFree means it may be claimed by TryAllocateConnection; kUsed means a
// client socket has been assigned and host_thread is serving it.
enum class QueueState {
  kFree,
  kUsed,
};
// Hand-off slot between an accepting thread (TryAllocateConnection) and the
// host_thread that owns one shm connection queue. Every field is read and
// written under `lock`; `cv` is notified after the slot becomes kUsed.
struct SocketConnectionInfo {
  std::mutex lock;
  std::condition_variable cv;
  cvd::SharedFD socket{};
  int guest_port{0};
  QueueState state{QueueState::kFree};
};
// One host_thread, and so one SocketConnectionInfo slot, per shm connection
// queue. (constexpr already gives internal linkage here.)
constexpr auto kNumHostThreads = vsoc::layout::socket_forward::kNumQueues;
using SocketConnectionInfoCollection =
    std::array<SocketConnectionInfo, kNumHostThreads>;
void LaunchConnectionMaintainer(int port) {
std::thread(cvd::EstablishAndMaintainConnection, port).detach();
}
// Releases `conn` so TryAllocateConnection can hand it to a new client: drops
// the slot's socket reference, clears the guest port and marks it kFree.
void MarkAsFree(SocketConnectionInfo* conn) {
  std::lock_guard<std::mutex> lock_holder(conn->lock);
  conn->state = QueueState::kFree;
  conn->guest_port = 0;
  conn->socket = cvd::SharedFD{};
}
// Blocks until `conn` is in state kUsed, then returns its guest port and a
// copy of its client socket. The slot keeps its own socket reference until
// MarkAsFree.
std::pair<int, cvd::SharedFD> WaitForConnection(SocketConnectionInfo* conn) {
  std::unique_lock<std::mutex> guard{conn->lock};
  conn->cv.wait(guard, [conn] { return conn->state == QueueState::kUsed; });
  return std::make_pair(conn->guest_port, conn->socket);
}
// Worker bound to one shm connection view. Repeatedly waits for a client
// socket to be assigned to `conn`, bridges it to the requested guest port via
// HandleConnection, then frees the slot. Never returns.
[[noreturn]] void host_thread(SocketForwardRegionView::ShmConnectionView view,
                              SocketConnectionInfo* conn) {
  for (;;) {
    auto assignment = WaitForConnection(conn);
    const int guest_port = assignment.first;
    LOG(INFO) << "Establishing connection to guest port " << guest_port
              << " with connection_id: " << view.connection_id();
    HandleConnection(view.EstablishConnection(guest_port),
                     std::move(assignment.second));
    LOG(INFO) << "Connection to guest port " << guest_port
              << " closed. Marking queue " << view.connection_id()
              << " as free.";
    MarkAsFree(conn);
  }
}
// Tries to claim `conn` for a client. If the slot is kFree, stores `socket`
// and `guest_port`, marks it kUsed and wakes one thread waiting on its cv.
// Returns whether the slot was claimed; on failure `conn` is left untouched.
bool TryAllocateConnection(SocketConnectionInfo* conn, int guest_port,
                           cvd::SharedFD socket) {
  {
    std::lock_guard<std::mutex> guard{conn->lock};
    if (conn->state != QueueState::kFree) {
      return false;
    }
    conn->socket = std::move(socket);
    conn->guest_port = guest_port;
    conn->state = QueueState::kUsed;
  }
  // Notify after the lock is released so the woken thread need not block on it.
  conn->cv.notify_one();
  return true;
}
// Gives `socket` to the first free connection slot. If none is free, logs,
// sleeps 5 seconds and scans again, so this blocks until a slot is claimed.
void AllocateWorkers(cvd::SharedFD socket,
                     SocketConnectionInfoCollection* socket_connection_info,
                     int guest_port) {
  auto& slots = *socket_connection_info;
  for (;;) {
    for (std::size_t i = 0; i < slots.size(); ++i) {
      if (TryAllocateConnection(&slots[i], guest_port, socket)) {
        return;
      }
    }
    LOG(INFO) << "no queues available. sleeping and retrying";
    sleep(5);
  }
}
// Serves ports[index] on the calling thread. It runs a TCP server on the host
// port, launches a connection maintainer for that port, and hands every
// accepted client to a free shm connection slot targeting the guest port.
// Before that it spawns a detached thread running host_impl for
// ports[index + 1], so each entry of `ports` gets its own server thread.
// `index` must be < ports.size(). Never returns.
[[noreturn]] void host_impl(
    SocketForwardRegionView* shm,
    SocketConnectionInfoCollection* socket_connection_info,
    std::vector<PortPair> ports, std::size_t index) {
  CHECK_LT(index, ports.size()) << "port index out of range";
  // Launch a worker for the following port before handling the current port.
  // Recursion (instead of a loop) removes the need for any join() and keeps
  // the calling thread serving a port instead of idling.
  if (index + 1 < ports.size()) {
    std::thread(host_impl, shm, socket_connection_info, ports, index + 1)
        .detach();
  }
  const auto guest_port = ports[index].guest_port;
  const auto host_port = ports[index].host_port;
  LOG(INFO) << "starting server on " << host_port << " for guest port "
            << guest_port;
  auto server = cvd::SharedFD::SocketLocalServer(host_port, SOCK_STREAM);
  CHECK(server->IsOpen()) << "Could not start server on port " << host_port;
  // Note: If generically forwarding ports, the adb connection maintainer should
  // be disabled
  LaunchConnectionMaintainer(host_port);
  while (true) {
    auto client_socket = cvd::SharedFD::Accept(*server);
    // NOTE(review): a single failed Accept() aborts the whole proxy; consider
    // logging and retrying for transient errors instead.
    CHECK(client_socket->IsOpen()) << "error creating client socket";
    LOG(INFO) << "client socket accepted";
    // Blocks (polling) until a free shm connection slot takes this client, so
    // no further clients are accepted on this port meanwhile.
    AllocateWorkers(std::move(client_socket), socket_connection_info,
                    guest_port);
  }
}
// Host entry point. Binds one detached host_thread to each shm connection view,
// giving each its own SocketConnectionInfo slot, then serves `ports` through
// host_impl. Because this never returns, the local slot array outlives the
// detached threads that point into it.
[[noreturn]] void host(SocketForwardRegionView* shm,
                       std::vector<PortPair> ports) {
  CHECK(!ports.empty());
  SocketConnectionInfoCollection socket_connection_info{};
  auto conn_info_iter = socket_connection_info.begin();
  for (auto& shm_connection_view : shm->AllConnections()) {
    CHECK_NE(conn_info_iter, std::end(socket_connection_info));
    SocketConnectionInfo* const slot = &*conn_info_iter;
    std::thread worker{host_thread, std::move(shm_connection_view), slot};
    worker.detach();
    ++conn_info_iter;
  }
  // There must be exactly one connection view per slot.
  CHECK_EQ(conn_info_iter, std::end(socket_connection_info));
  host_impl(shm, &socket_connection_info, std::move(ports), 0);
}
std::vector<PortPair> ParsePortsList(const std::string& guest_ports_str,
const std::string& host_ports_str) {
std::vector<PortPair> ports{};
auto guest_ports = cvd::StrSplit(guest_ports_str, ',');
auto host_ports = cvd::StrSplit(host_ports_str, ',');
CHECK(guest_ports.size() == host_ports.size());
for (std::size_t i = 0; i < guest_ports.size(); ++i) {
ports.push_back({std::stoi(guest_ports[i]), std::stoi(host_ports[i])});
}
return ports;
}
#else
// Connects to the local TCP server on `port`. On failure it logs a warning,
// sleeps one second and tries again, so this returns only with an open socket.
cvd::SharedFD OpenSocketConnection(int port) {
  auto sock = cvd::SharedFD::SocketLocalClient(port, SOCK_STREAM);
  while (!sock->IsOpen()) {
    LOG(WARNING) << "could not connect on port " << port
                 << ". sleeping for 1 second";
    sleep(1);
    sock = cvd::SharedFD::SocketLocalClient(port, SOCK_STREAM);
  }
  return sock;
}
// Guest worker for one shm connection view. It repeatedly waits for a new
// connection on the view, connects to the local server on view.port(), and
// bridges the two with HandleConnection until they close. Never returns.
[[noreturn]] void guest_thread(
    SocketForwardRegionView::ShmConnectionView view) {
  for (;;) {
    LOG(INFO) << "waiting for new connection";
    auto shm_pair = view.WaitForNewConnection();
    LOG(INFO) << "new connection for port " << view.port();
    HandleConnection(std::move(shm_pair), OpenSocketConnection(view.port()));
    LOG(INFO) << "connection closed on port " << view.port();
  }
}
// Guest entry point: starts one detached guest_thread per shm connection view,
// then parks the calling thread forever.
[[noreturn]] void guest(SocketForwardRegionView* shm) {
  LOG(INFO) << "Starting guest mainloop";
  auto connection_views = shm->AllConnections();
  for (auto&& view : connection_views) {
    std::thread worker{guest_thread, std::move(view)};
    worker.detach();
  }
  for (;;) {
    sleep(std::numeric_limits<unsigned int>::max());
  }
}
#endif
// Opens the socket-forward shared memory region (passing the configured vsoc
// domain on the host). Aborts if the region cannot be opened, and calls
// CleanUpPreviousConnections() on it before returning.
SocketForwardRegionView* GetShm() {
#ifdef CUTTLEFISH_HOST
  auto region =
      SocketForwardRegionView::GetInstance(vsoc::GetDomain().c_str());
#else
  auto region = SocketForwardRegionView::GetInstance();
#endif
  if (!region) {
    LOG(FATAL) << "Could not open SHM. Aborting.";
  }
  region->CleanUpPreviousConnections();
  return region;
}
// On the guest, aborts unless the process runs as root (uid 0).
// On the host this does nothing.
void assert_correct_user() {
#ifdef CUTTLEFISH_HOST
  // No user requirement on the host.
#else
  CHECK_EQ(getuid(), 0u) << "must run as root!";
#endif
}
} // namespace
// Entry point. On the host, --guest_ports / --host_ports are validated and
// parsed before the shared memory region is opened. A malformed invocation
// therefore aborts before GetShm() calls CleanUpPreviousConnections() and
// before the region worker is started. On the guest, the process must run as
// root. Neither host() nor guest() returns.
int main(int argc, char* argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  assert_correct_user();
#ifdef CUTTLEFISH_HOST
  CHECK(!FLAGS_guest_ports.empty()) << "Must specify --guest_ports flag";
  CHECK(!FLAGS_host_ports.empty()) << "Must specify --host_ports flag";
  auto ports = ParsePortsList(FLAGS_guest_ports, FLAGS_host_ports);
#endif
  auto shm = GetShm();
  // Held in a local so it lives for as long as host()/guest() run (forever).
  auto worker = shm->StartWorker();
#ifdef CUTTLEFISH_HOST
  host(shm, std::move(ports));
#else
  guest(shm);
#endif
}