blob: cec1b053f3a029ef20a506bc61a72c15acb60c75 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <glog/logging.h>
#include <gflags/gflags.h>
#include <unistd.h>
#include "common/vsoc/lib/socket_forward_region_view.h"
#include "common/libs/tcp_socket/tcp_socket.h"
#ifdef CUTTLEFISH_HOST
#include "host/libs/config/host_config.h"
#endif
using vsoc::socket_forward::SocketForwardRegionView;
#ifdef CUTTLEFISH_HOST
// Host-only flag: the TCP port host() listens on and also the port number it
// passes when opening each shm connection. Defaults to 0, which main()
// rejects, so it must be supplied explicitly.
DEFINE_uint32(port, 0, "Port from which to forward TCP connections.");
#endif
namespace {
// Owns one proxied connection: the shm side, the TCP socket side, and a
// "closed" flag shared by the two pump threads started by LaunchWorkers().
// When either direction stops it calls close(), so the other direction exits
// at its next loop iteration.
//
// NOTE(review): close() only sets a flag. A thread blocked inside
// socket_.RecvAny() or receiver.Recv() will not observe it until that call
// returns. Confirm whether the socket / shm connection should be shut down
// explicitly to unblock the peer thread.
class Worker {
 public:
  Worker(SocketForwardRegionView::Connection shm_connection,
         cvd::ClientSocket socket)
      : shm_connection_(std::move(shm_connection)),
        socket_(std::move(socket)) {}

  // Returns true once this worker is closed, meaning either close() was
  // called or one of the two endpoints reports itself closed. In the second
  // case the flag is latched, so later calls return true without re-querying
  // the endpoints.
  [[nodiscard]] bool closed() {
    {
      std::lock_guard<std::mutex> guard(closed_lock_);
      if (closed_) {
        return true;
      }
    }
    // Endpoint state is queried without holding closed_lock_.
    const bool endpoint_closed = shm_connection_.closed() || socket_.closed();
    std::lock_guard<std::mutex> guard(closed_lock_);
    if (endpoint_closed) {
      closed_ = true;
    }
    // The flag must be read while holding the lock: the other pump thread
    // may be writing it concurrently through close().
    return closed_;
  }

  // Marks the worker closed so both pump loops stop at their next check.
  void close() {
    std::lock_guard<std::mutex> guard(closed_lock_);
    closed_ = true;
  }

  // Thread entry points. Each receives its own shared_ptr by value, so the
  // Worker stays alive until both directions have returned.
  static void SocketToShm(std::shared_ptr<Worker> worker) {
    worker->SocketToShmImpl();
  }
  static void ShmToSocket(std::shared_ptr<Worker> worker) {
    worker->ShmToSocketImpl();
  }

 private:
  // Reads from the TCP socket in chunks of up to kRecvSize bytes and sends
  // each chunk into shm. Stops when the worker is closed or the socket
  // yields an empty message, then closes the worker.
  void SocketToShmImpl() {
    constexpr int kRecvSize = 8192;
    auto sender = shm_connection_.MakeSender();
    while (true) {
      if (closed()) {
        break;
      }
      auto msg = socket_.RecvAny(kRecvSize);
      if (msg.empty()) {
        break;
      }
      sender.Send(std::move(msg));
    }
    LOG(INFO) << "Socket to shm exiting";
    close();
  }

  // Receives messages from shm and writes them to the TCP socket. Stops when
  // any of these happens, then closes the worker:
  //   - the worker is closed;
  //   - an empty message arrives;
  //   - the socket reports closed;
  //   - socket_.Send() returns a negative value.
  void ShmToSocketImpl() {
    auto receiver = shm_connection_.MakeReceiver();
    while (true) {
      if (closed()) {
        break;
      }
      auto msg = receiver.Recv();
      if (msg.empty() || socket_.closed()) {
        break;
      }
      if (socket_.Send(msg) < 0) {
        break;
      }
    }
    LOG(INFO) << "Shm to socket exiting";
    close();
  }

  SocketForwardRegionView::Connection shm_connection_;
  cvd::ClientSocket socket_;
  // Guarded by closed_lock_; written by both pump threads.
  bool closed_{};
  std::mutex closed_lock_;
};
// Wraps |conn| and |socket| in a shared Worker and starts two detached
// threads on it:
//   - SocketToShm reads from the TCP socket and writes into shm;
//   - ShmToSocket reads from shm and writes into the TCP socket.
// Each thread holds its own reference to the Worker.
void LaunchWorkers(SocketForwardRegionView::Connection conn,
                   cvd::ClientSocket socket) {
  auto shared_worker =
      std::make_shared<Worker>(std::move(conn), std::move(socket));
  std::thread socket_to_shm(Worker::SocketToShm, shared_worker);
  std::thread shm_to_socket(Worker::ShmToSocket, std::move(shared_worker));
  socket_to_shm.detach();
  shm_to_socket.detach();
}
#ifdef CUTTLEFISH_HOST
// Host main loop. Listens on |port| with a cvd::ServerSocket. For every
// accepted client it opens a shm connection for the same port and hands both
// ends to LaunchWorkers(). Loops forever.
[[noreturn]] void host(SocketForwardRegionView* shm, int port) {
  LOG(INFO) << "starting server on " << port;
  cvd::ServerSocket listener(port);
  for (;;) {
    auto accepted = listener.Accept();
    LOG(INFO) << "client socket accepted";
    auto shm_side = shm->OpenConnection(port);
    LOG(INFO) << "shm connection opened";
    LaunchWorkers(std::move(shm_side), std::move(accepted));
  }
}
#else
// Guest main loop. Waits for the next shm connection, opens a
// cvd::ClientSocket to the port that connection reports, and hands both ends
// to LaunchWorkers(). Loops forever.
[[noreturn]] void guest(SocketForwardRegionView* shm) {
  LOG(INFO) << "Starting guest mainloop";
  for (;;) {
    auto shm_side = shm->AcceptConnection();
    LOG(INFO) << "shm connection accepted";
    cvd::ClientSocket tcp_side(shm_side.port());
    LOG(INFO) << "socket opened to " << shm_side.port();
    LaunchWorkers(std::move(shm_side), std::move(tcp_side));
  }
}
#endif
// Returns the SocketForwardRegionView instance. On the host the instance is
// looked up using the domain from vsoc::GetDomain(). If no instance is
// available, the process is aborted through LOG(FATAL).
SocketForwardRegionView* GetShm() {
#ifdef CUTTLEFISH_HOST
  auto region = SocketForwardRegionView::GetInstance(vsoc::GetDomain().c_str());
#else
  auto region = SocketForwardRegionView::GetInstance();
#endif
  if (!region) {
    LOG(FATAL) << "Could not open SHM. Aborting.";
  }
  return region;
}
// On the guest, CHECK-fails unless the process runs as root (uid 0).
// On the host this function does nothing.
void assert_correct_user() {
#ifdef CUTTLEFISH_HOST
  // No user requirement on the host.
#else
  CHECK_EQ(getuid(), 0u) << "must run as root!";
#endif
}
} // namespace
// Entry point. On the host, forwards TCP connections accepted on --port over
// the socket_forward shm region. On the guest, accepts shm connections and
// forwards each one to a TCP socket opened on the port it names. Does not
// return normally, because host() and guest() loop forever.
int main(int argc, char* argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
#ifdef CUTTLEFISH_HOST
  // Validate flags before touching shm so misconfiguration fails fast.
  CHECK_NE(FLAGS_port, 0u) << "Must specify --port flag";
  // host() takes the port as an int, and TCP ports are 16-bit. Reject
  // out-of-range values here instead of relying on the socket layer to
  // handle them.
  CHECK_LE(FLAGS_port, 65535u) << "--port must be in the range [1, 65535]";
#endif
  assert_correct_user();
  auto shm = GetShm();
  // Held in a named local so it lives as long as main(). Presumably this is
  // a handle that must outlive the main loop. NOTE(review): confirm against
  // SocketForwardRegionView::StartWorker().
  auto worker = shm->StartWorker();
#ifdef CUTTLEFISH_HOST
  host(shm, FLAGS_port);
#else
  guest(shm);
#endif
}