blob: e604d4e0c056090238f90d4a48539a80d696022e [file] [log] [blame]
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "cast/common/channel/virtual_connection_router.h"

#include <algorithm>
#include <utility>

#include "cast/common/channel/cast_message_handler.h"
#include "cast/common/channel/connection_namespace_handler.h"
#include "cast/common/channel/message_util.h"
#include "cast/common/channel/proto/cast_channel.pb.h"
#include "util/osp_logging.h"
namespace openscreen {
namespace cast {
using ::cast::channel::CastMessage;
// Defaulted: members are initialized as declared in the header.
VirtualConnectionRouter::VirtualConnectionRouter() = default;
// Defaulted: destroying |sockets_| releases any CastSockets the router still
// owns (they are held by std::unique_ptr, see TakeSocket()).
VirtualConnectionRouter::~VirtualConnectionRouter() = default;
// Registers |virtual_connection| (socket_id, local_id, peer_id) together with
// |associated_data|. If an identical connection is already registered this is
// a no-op: the previously stored associated data is kept, not replaced.
void VirtualConnectionRouter::AddConnection(
    VirtualConnection virtual_connection,
    VirtualConnection::AssociatedData associated_data) {
  auto& socket_map = connections_[virtual_connection.socket_id];
  auto local_entries = socket_map.equal_range(virtual_connection.local_id);
  // Bind with `const auto&` so the parameter matches the map's value_type
  // exactly (its key is const); a mismatched pair type would silently copy
  // every visited entry into a temporary.
  auto it = std::find_if(
      local_entries.first, local_entries.second,
      [&virtual_connection](const auto& entry) {
        return entry.second.peer_id == virtual_connection.peer_id;
      });
  // find_if signals "not found" with the end of the *searched range*, which
  // equals socket_map.end() only when no greater local_id exists in the map.
  if (it != local_entries.second) {
    return;
  }
  socket_map.emplace(std::move(virtual_connection.local_id),
                     VCTail{std::move(virtual_connection.peer_id),
                            std::move(associated_data)});
}
// Removes the single connection matching |virtual_connection|'s
// (socket_id, local_id, peer_id). When that empties the socket's table, the
// socket's entry in |connections_| is dropped as well. |reason| is currently
// unused here. Returns true iff a matching connection was found and removed.
bool VirtualConnectionRouter::RemoveConnection(
    const VirtualConnection& virtual_connection,
    VirtualConnection::CloseReason reason) {
  auto socket_it = connections_.find(virtual_connection.socket_id);
  if (socket_it == connections_.end()) {
    return false;
  }

  auto& by_local_id = socket_it->second;
  auto range = by_local_id.equal_range(virtual_connection.local_id);
  // An empty range (including one positioned at end()) skips the loop.
  for (auto entry = range.first; entry != range.second; ++entry) {
    if (entry->second.peer_id != virtual_connection.peer_id) {
      continue;
    }
    by_local_id.erase(entry);
    if (by_local_id.empty()) {
      connections_.erase(socket_it);
    }
    return true;
  }
  return false;
}
// Drops every connection whose local endpoint is |local_id|, across all
// sockets. A socket whose connection table becomes empty is removed from
// |connections_|.
void VirtualConnectionRouter::RemoveConnectionsByLocalId(
    const std::string& local_id) {
  auto socket_it = connections_.begin();
  while (socket_it != connections_.end()) {
    auto& by_local_id = socket_it->second;
    auto range = by_local_id.equal_range(local_id);
    if (range.first == by_local_id.end()) {
      // Nothing at or after |local_id| in this socket's table.
      ++socket_it;
      continue;
    }
    by_local_id.erase(range.first, range.second);
    if (by_local_id.empty()) {
      socket_it = connections_.erase(socket_it);
    } else {
      ++socket_it;
    }
  }
}
// Forgets every virtual connection routed over socket |socket_id|; a no-op if
// none are registered for it.
void VirtualConnectionRouter::RemoveConnectionsBySocketId(int socket_id) {
  connections_.erase(socket_id);
}
// Looks up the connection matching |virtual_connection|'s
// (socket_id, local_id, peer_id). Returns a pointer to its associated data,
// which is owned by the router, or absl::nullopt when no such connection
// exists.
absl::optional<const VirtualConnection::AssociatedData*>
VirtualConnectionRouter::GetConnectionData(
    const VirtualConnection& virtual_connection) const {
  auto socket_it = connections_.find(virtual_connection.socket_id);
  if (socket_it != connections_.end()) {
    const auto& by_local_id = socket_it->second;
    auto range = by_local_id.equal_range(virtual_connection.local_id);
    for (auto entry = range.first; entry != range.second; ++entry) {
      if (entry->second.peer_id == virtual_connection.peer_id) {
        return &entry->second.data;
      }
    }
  }
  return absl::nullopt;
}
// Routes messages addressed to |local_id| to |endpoint|. Returns false, and
// leaves the existing registration untouched, if |local_id| already has a
// handler.
bool VirtualConnectionRouter::AddHandlerForLocalId(
    std::string local_id,
    CastMessageHandler* endpoint) {
  const auto insertion = endpoints_.emplace(std::move(local_id), endpoint);
  return insertion.second;
}
// Unregisters the handler for |local_id|. Returns true iff exactly one
// handler was removed.
bool VirtualConnectionRouter::RemoveHandlerForLocalId(
    const std::string& local_id) {
  const auto removed_count = endpoints_.erase(local_id);
  return removed_count == 1u;
}
// Takes ownership of |socket|, makes this router its client, and remembers
// |error_handler| for close/error notifications on that socket.
// NOTE(review): if |sockets_| already holds an entry for this socket id, the
// emplace will not insert and |socket| is presumably destroyed here. Confirm
// whether duplicate ids can occur.
void VirtualConnectionRouter::TakeSocket(SocketErrorHandler* error_handler,
                                         std::unique_ptr<CastSocket> socket) {
  const int socket_id = socket->socket_id();
  socket->SetClient(this);
  SocketWithHandler owned{std::move(socket), error_handler};
  sockets_.emplace(socket_id, std::move(owned));
}
void VirtualConnectionRouter::CloseSocket(int id) {
auto it = sockets_.find(id);
if (it != sockets_.end()) {
RemoveConnectionsBySocketId(id);
std::unique_ptr<CastSocket> socket = std::move(it->second.socket);
SocketErrorHandler* error_handler = it->second.error_handler;
sockets_.erase(it);
error_handler->OnClose(socket.get());
}
}
// Sends |message| over |virtual_conn|.
// - A peer id of kBroadcastId is handed to BroadcastFromLocalPeer().
// - Non-transport namespaces require an established virtual connection;
//   otherwise kNoActiveConnection is returned.
// - Returns kItemNotFound if the socket is unknown.
// On success, the source/destination ids are stamped from |virtual_conn| and
// the socket's Send() result is returned.
Error VirtualConnectionRouter::Send(VirtualConnection virtual_conn,
                                    CastMessage message) {
  if (virtual_conn.peer_id == kBroadcastId) {
    return BroadcastFromLocalPeer(std::move(virtual_conn.local_id),
                                  std::move(message));
  }

  const bool needs_connection = !IsTransportNamespace(message.namespace_());
  if (needs_connection && !GetConnectionData(virtual_conn)) {
    return Error::Code::kNoActiveConnection;
  }

  auto socket_it = sockets_.find(virtual_conn.socket_id);
  if (socket_it == sockets_.end()) {
    return Error::Code::kItemNotFound;
  }

  message.set_source_id(std::move(virtual_conn.local_id));
  message.set_destination_id(std::move(virtual_conn.peer_id));
  return socket_it->second.socket->Send(message);
}
// Broadcasts |message| from |local_id| in two steps:
// 1. Delivers it to every registered local handler except the sender (with a
//    null socket).
// 2. Sends it over every owned socket.
// Remote send failures do not stop the broadcast; the first failure is
// returned, or an OK Error if every send succeeded.
Error VirtualConnectionRouter::BroadcastFromLocalPeer(
    std::string local_id,
    ::cast::channel::CastMessage message) {
  message.set_source_id(std::move(local_id));
  message.set_destination_id(kBroadcastId);

  // Local fan-out: skip the handler that originated the message.
  for (const auto& handler_entry : endpoints_) {
    if (handler_entry.first == message.source_id()) {
      continue;
    }
    handler_entry.second->OnMessage(this, nullptr, message);
  }

  // Remote fan-out.
  // NOTE(review): if a socket's Send() can synchronously re-enter OnError(),
  // the socket would be erased from |sockets_| during this loop. Confirm that
  // this cannot happen.
  Error first_error;
  for (const auto& socket_entry : sockets_) {
    auto result = socket_entry.second.socket->Send(message);
    if (first_error.ok() && !result.ok()) {
      first_error = std::move(result);
    }
  }
  return first_error;
}
// CastSocket client callback: an owned socket reported |error|. Drops the
// socket's virtual connections and table entry, then forwards the error to the
// socket's registered error handler. Sockets not owned by this router are
// ignored.
void VirtualConnectionRouter::OnError(CastSocket* socket, Error error) {
  const int socket_id = socket->socket_id();
  auto it = sockets_.find(socket_id);
  if (it == sockets_.end()) {
    return;
  }
  RemoveConnectionsBySocketId(socket_id);
  // Hold ownership locally so |socket| stays valid through the handler call
  // even though its table entry is erased first.
  std::unique_ptr<CastSocket> keep_alive = std::move(it->second.socket);
  SocketErrorHandler* const handler = it->second.error_handler;
  sockets_.erase(it);
  handler->OnError(socket, error);
}
// CastSocket client callback: routes an incoming |message| from |socket|.
// - Messages to kBroadcastId go to every local handler.
// - Connection-namespace messages go to |connection_handler_|, if set.
// - Other messages go to the handler registered for the destination id.
//   Non-transport namespaces are dropped unless a matching virtual connection
//   exists.
void VirtualConnectionRouter::OnMessage(CastSocket* socket,
                                        CastMessage message) {
  OSP_DCHECK(socket);
  // Refers into |message|; only read before |message| is moved from below.
  const std::string& destination_id = message.destination_id();

  if (destination_id == kBroadcastId) {
    for (const auto& handler_entry : endpoints_) {
      handler_entry.second->OnMessage(this, socket, message);
    }
    return;
  }

  // Connection namespace messages are weird: The message.source_id() and
  // message.destination_id() are NOT treated as "envelope routing
  // information," like for all other namespaces. Instead, they are considered
  // part of the payload data for CONNECT/CLOSE requests. Thus, they require
  // special-case handling here.
  if (message.namespace_() == kConnectionNamespace) {
    if (connection_handler_) {
      connection_handler_->OnMessage(this, socket, std::move(message));
    }
    return;
  }

  // Drop all messages for virtual connections that do not yet exist.
  // Exception: All transport namespace messages (e.g., device auth,
  // heartbeats, etc.); because these are always assumed to have a route.
  const bool needs_connection = !IsTransportNamespace(message.namespace_());
  if (needs_connection &&
      !GetConnectionData(VirtualConnection{destination_id, message.source_id(),
                                           socket->socket_id()})) {
    return;
  }

  auto handler_it = endpoints_.find(destination_id);
  if (handler_it == endpoints_.end()) {
    return;
  }
  handler_it->second->OnMessage(this, socket, std::move(message));
}
} // namespace cast
} // namespace openscreen