// blob: 40f13578a384189456934fa64fa2d2655562c16c [file] [log] [blame]
/*
* Copyright (C) 2020 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <host/libs/websocket/websocket_server.h>
#include <string>
#include <unordered_map>
#include <android-base/logging.h>
#include <libwebsockets.h>
#include <common/libs/utils/files.h>
#include <host/libs/websocket/websocket_handler.h>
namespace cuttlefish {
// Creates the libwebsockets context backing this server.
//
// protocol_name: name of the one websocket protocol, served by ServerCallback.
// certs_dir:     directory holding server.crt and server.key; CA.crt is used
//                as well when that file exists.
// assets_dir:    directory mounted at "/" and served as plain files, with
//                index.html as the default document.
// server_port:   port handed to libwebsockets for the listening socket.
//
// Aborts through LOG(FATAL) when lws_create_context() fails.
//
// NOTE(review): mount_ stores assets_dir.c_str() and the protocol table stores
// protocol_name without copying them; presumably both must outlive the
// context -- confirm against the callers.
WebSocketServer::WebSocketServer(
    const char* protocol_name,
    const std::string &certs_dir,
    const std::string &assets_dir,
    int server_port) {
  // Static file mount: everything below assets_dir is exposed at "/".
  mount_ = {
      .mount_next = nullptr,
      .mountpoint = "/",
      .mountpoint_len = 1,
      .origin = assets_dir.c_str(),
      .def = "index.html",
      .protocol = nullptr,
      .cgienv = nullptr,
      .extra_mimetypes = nullptr,
      .interpret = nullptr,
      .cgi_timeout = 0,
      .cache_max_age = 0,
      .auth_mask = 0,
      .cache_reusable = 0,
      .cache_revalidate = 0,
      .cache_intermediaries = 0,
      .origin_protocol = LWSMPRO_FILE,  // files in a dir
      .basic_auth_login_file = nullptr,
  };

  // Extra response header carrying the content security policy.
  headers_ = {nullptr, nullptr, "content-security-policy:",
              "default-src 'self' https://ajax.googleapis.com; "
              "style-src 'self' https://fonts.googleapis.com/; "
              "font-src https://fonts.gstatic.com/; "};

  // Ping / hangup timing handed to libwebsockets below.
  retry_ = {
      .secs_since_valid_ping = 3,
      .secs_since_valid_hangup = 10,
  };

  // One real protocol followed by the all-null terminator entry.
  struct lws_protocols protocol_table[] = {
      {protocol_name, ServerCallback, 4096, 0, 0, nullptr, 0},
      {nullptr, nullptr, 0, 0, 0, nullptr, 0}};

  // These strings must stay alive until lws_create_context() has returned,
  // since `info` only borrows their buffers.
  const std::string cert_path = certs_dir + "/server.crt";
  const std::string key_path = certs_dir + "/server.key";
  const std::string ca_path = certs_dir + "/CA.crt";

  struct lws_context_creation_info info;
  memset(&info, 0, sizeof info);
  info.port = server_port;
  info.vhost_name = "localhost";
  info.protocols = protocol_table;
  info.mounts = &mount_;
  info.headers = &headers_;
  info.ws_ping_pong_interval = 10;
  info.retry_and_idle_policy = &retry_;
  info.options |= LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
  info.ssl_cert_filepath = cert_path.c_str();
  info.ssl_private_key_filepath = key_path.c_str();
  // The CA file is optional; leave the field null when it is absent.
  if (FileExists(ca_path)) {
    info.ssl_ca_filepath = ca_path.c_str();
  }

  context_ = lws_create_context(&info);
  if (!context_) {
    LOG(FATAL) << "Failed to create websocket context";
  }
}
// Associates `path` with `handler_factory_p`, taking ownership of the factory.
// A factory previously registered under the same path is replaced.
// InstantiateHandler() looks factories up by this exact path string.
void WebSocketServer::RegisterHandlerFactory(
    const std::string &path,
    std::unique_ptr<WebSocketHandlerFactory> handler_factory_p) {
  handler_factories_.insert_or_assign(path, std::move(handler_factory_p));
}
// Runs the libwebsockets service loop on the calling thread. The loop keeps
// calling lws_service() until it returns a negative value, after which the
// context is destroyed.
void WebSocketServer::Serve() {
  while (lws_service(context_, 0) >= 0) {
    // All work happens inside lws_service() and the callbacks it invokes.
  }
  lws_context_destroy(context_);
}
// Handlers of the connections accepted in ServerCallback, keyed by their
// libwebsockets connection object. Static, hence shared by every
// WebSocketServer instance in the process.
std::unordered_map<struct lws*, std::shared_ptr<WebSocketHandler>>
    WebSocketServer::handlers_{};

// Factories registered through RegisterHandlerFactory(), keyed by URI path.
// Static as well, so registrations are visible to all instances.
std::unordered_map<std::string, std::unique_ptr<WebSocketHandlerFactory>>
    WebSocketServer::handler_factories_{};
std::string WebSocketServer::GetPath(struct lws* wsi) {
auto len = lws_hdr_total_length(wsi, WSI_TOKEN_GET_URI);
std::string path(len + 1, '\0');
auto ret = lws_hdr_copy(wsi, path.data(), path.size(), WSI_TOKEN_GET_URI);
if (ret <= 0) {
len = lws_hdr_total_length(wsi, WSI_TOKEN_HTTP_COLON_PATH);
path.resize(len + 1, '\0');
ret = lws_hdr_copy(wsi, path.data(), path.size(), WSI_TOKEN_HTTP_COLON_PATH);
}
if (ret < 0) {
LOG(FATAL) << "Something went wrong getting the path";
}
path.resize(len);
return path;
}
// libwebsockets protocol callback shared by all connections of this class.
//
// Connection events are routed to the WebSocketHandler stored for `wsi` in
// handlers_:
//   LWS_CALLBACK_ESTABLISHED      - builds a handler for the request path and
//                                   calls OnConnected(); unknown paths are
//                                   rejected.
//   LWS_CALLBACK_CLOSED           - calls OnClosed() and forgets the handler.
//   LWS_CALLBACK_SERVER_WRITEABLE - calls OnWritable(); a true result closes
//                                   the connection.
//   LWS_CALLBACK_RECEIVE          - forwards the payload to OnReceive().
// Every other reason is delegated to lws_callback_http_dummy().
//
// wsi:    connection the event refers to.
// reason: event being reported.
// user:   per-connection user pointer, only passed through to the default
//         handler.
// in/len: event payload; for LWS_CALLBACK_RECEIVE the received bytes.
//
// Returns 0 for handled events that keep the connection open and a non-zero
// value when the connection is to be closed.
//
// Lookups use find() instead of operator[]: operator[] inserts a null entry
// for every unknown wsi, and those entries were never removed.
int WebSocketServer::ServerCallback(struct lws* wsi, enum lws_callback_reasons reason,
                                    void* user, void* in, size_t len) {
  switch (reason) {
    case LWS_CALLBACK_ESTABLISHED: {
      auto path = GetPath(wsi);
      auto handler = InstantiateHandler(path, wsi);
      if (!handler) {
        // This message came on an unexpected uri, close the connection.
        // lws_close_reason() takes a mutable buffer, so hand it a local array
        // rather than casting away the constness of a string literal.
        unsigned char not_found[] = "404";
        lws_close_reason(wsi, LWS_CLOSE_STATUS_NOSTATUS, not_found,
                         sizeof(not_found) - 1);
        return -1;
      }
      handlers_[wsi] = handler;
      handler->OnConnected();
      break;
    }
    case LWS_CALLBACK_CLOSED: {
      auto it = handlers_.find(wsi);
      if (it != handlers_.end()) {
        // Copy the pointer so the handler stays alive until OnClosed() has
        // returned, independently of the map entry.
        auto handler = it->second;
        if (handler) {
          handler->OnClosed();
        }
        handlers_.erase(wsi);
      }
      break;
    }
    case LWS_CALLBACK_SERVER_WRITEABLE: {
      auto it = handlers_.find(wsi);
      if (it == handlers_.end() || !it->second) {
        LOG(WARNING) << "Unknown wsi became writable";
        return -1;
      }
      // Own a reference for the duration of the call.
      auto handler = it->second;
      auto should_close = handler->OnWritable();
      if (should_close) {
        lws_close_reason(wsi, LWS_CLOSE_STATUS_NORMAL, nullptr, 0);
        return 1;
      }
      break;
    }
    case LWS_CALLBACK_RECEIVE: {
      auto it = handlers_.find(wsi);
      if (it == handlers_.end() || !it->second) {
        LOG(WARNING) << "Unkwnown wsi sent data";
        break;
      }
      // Own a reference for the duration of the call.
      auto handler = it->second;
      // The message is complete only when nothing of the current packet is
      // pending and this is the last fragment.
      bool is_final = (lws_remaining_packet_payload(wsi) == 0) &&
                      lws_is_final_fragment(wsi);
      handler->OnReceive(static_cast<const uint8_t*>(in), len,
                         lws_frame_is_binary(wsi), is_final);
      break;
    }
    default:
      return lws_callback_http_dummy(wsi, reason, user, in, len);
  }
  return 0;
}
// Builds a handler for the connection `wsi` using the factory registered for
// exactly `uri_path`. Returns nullptr, after logging an error, when no factory
// is registered under that path.
std::shared_ptr<WebSocketHandler> WebSocketServer::InstantiateHandler(
    const std::string& uri_path, struct lws* wsi) {
  const auto entry = handler_factories_.find(uri_path);
  if (entry != handler_factories_.end()) {
    LOG(INFO) << "Creating handler for " << uri_path;
    return entry->second->Build(wsi);
  }
  LOG(ERROR) << "Wrong path provided in URI: " << uri_path;
  return nullptr;
}
} // namespace cuttlefish