blob: 986bac9b022e3c51f7e2f61ae1a0cd1352094d9d [file] [log] [blame]
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "osp/public/message_demuxer.h"
#include <memory>
#include <utility>
#include "osp/impl/quic/quic_connection.h"
#include "platform/base/error.h"
#include "util/big_endian.h"
#include "util/osp_logging.h"
namespace openscreen {
namespace osp {
// static
// Decodes a variable-length unsigned integer (varUint) from the front of
// |buffer|, using the encoding described here:
// https://tools.ietf.org/html/draft-ietf-quic-transport-16#section-16
//
// The two most significant bits of the first byte select the encoded length
// (1, 2, 4 or 8 bytes); the remaining bits, read big-endian, are the value.
// Whenever |buffer| is non-empty, |*num_bytes_decoded| is set to that encoded
// length, even if an error is returned.
//
// Returns kCborIncompleteMessage if |buffer| is empty or contains no bytes
// after the varUint: the value is expected to be followed by a message body,
// so a buffer that ends exactly at the varUint is treated as incomplete.
ErrorOr<uint64_t> MessageTypeDecoder::DecodeVarUint(
    const std::vector<uint8_t>& buffer,
    size_t* num_bytes_decoded) {
  if (buffer.empty()) {
    return Error::Code::kCborIncompleteMessage;
  }

  const uint8_t num_type_bytes = static_cast<uint8_t>((buffer[0] >> 6) & 0x03);
  *num_bytes_decoded = size_t{1} << num_type_bytes;

  // Ensure that ReadBigEndian won't read beyond the end of the buffer. Also,
  // since we expect the id to be followed by the message, equality is not
  // valid.
  if (buffer.size() <= *num_bytes_decoded) {
    return Error::Code::kCborIncompleteMessage;
  }

  // Clear the two length-prefix bits. The masks are written as unsigned
  // constants so that no shift of a signed int can overflow (a signed
  // `0xC0 << 24` does not fit in a 32-bit int).
  switch (num_type_bytes) {
    case 0:
      return ReadBigEndian<uint8_t>(&buffer[0]) & uint8_t{0x3F};
    case 1:
      return ReadBigEndian<uint16_t>(&buffer[0]) & uint16_t{0x3FFF};
    case 2:
      return ReadBigEndian<uint32_t>(&buffer[0]) & uint32_t{0x3FFFFFFF};
    case 3:
      return ReadBigEndian<uint64_t>(&buffer[0]) &
             uint64_t{0x3FFFFFFFFFFFFFFF};
    default:
      OSP_NOTREACHED();
  }
}
// static
// Decodes the varUint message-type prefix at the front of |buffer| (same
// encoding as DecodeVarUint; see
// https://tools.ietf.org/html/draft-ietf-quic-transport-16#section-16) and
// maps it onto msgs::Type.
//
// |*num_bytes_decoded| receives the length of the prefix as reported by
// DecodeVarUint. Errors from DecodeVarUint are passed through unchanged; a
// value that msgs::TypeEnumValidator maps to kUnknown yields
// kCborInvalidMessage.
ErrorOr<msgs::Type> MessageTypeDecoder::DecodeType(
    const std::vector<uint8_t>& buffer,
    size_t* num_bytes_decoded) {
  ErrorOr<uint64_t> decoded =
      MessageTypeDecoder::DecodeVarUint(buffer, num_bytes_decoded);
  if (decoded.is_error())
    return decoded.error();

  const msgs::Type type = msgs::TypeEnumValidator::SafeCast(decoded.value());
  if (type != msgs::Type::kUnknown)
    return type;
  return Error::Code::kCborInvalidMessage;
}
// static
// Out-of-line definition of the constexpr static data member. Under C++14
// this definition is required if the constant is ODR-used (for example, bound
// to a const reference). From C++17 on, static constexpr data members are
// implicitly inline, and this line is redundant but still permitted.
constexpr size_t MessageDemuxer::kDefaultBufferLimit;
MessageDemuxer::MessageWatch::MessageWatch() = default;
MessageDemuxer::MessageWatch::MessageWatch(MessageDemuxer* parent,
bool is_default,
uint64_t endpoint_id,
msgs::Type message_type)
: parent_(parent),
is_default_(is_default),
endpoint_id_(endpoint_id),
message_type_(message_type) {}
// Takes over |other|'s registration. |other| is left with a null parent, so
// its destructor will not unregister the callback a second time.
MessageDemuxer::MessageWatch::MessageWatch(
    MessageDemuxer::MessageWatch&& other) noexcept
    : parent_(std::exchange(other.parent_, nullptr)),
      is_default_(other.is_default_),
      endpoint_id_(other.endpoint_id_),
      message_type_(other.message_type_) {}
// Unregisters the watched callback from |parent_|. A watch with no parent
// (default-constructed or moved-from) does nothing.
MessageDemuxer::MessageWatch::~MessageWatch() {
  if (!parent_)
    return;

  if (!is_default_) {
    OSP_VLOG << "dropping handler for type: "
             << static_cast<int>(message_type_);
    parent_->StopWatchingMessageType(endpoint_id_, message_type_);
    return;
  }

  OSP_VLOG << "dropping default handler for type: "
           << static_cast<int>(message_type_);
  parent_->StopDefaultMessageTypeWatch(message_type_);
}
// Move assignment implemented as a member-wise swap. This object takes
// |other|'s registration, and |other| receives whatever registration this
// object held. The latter is released when |other| is destroyed.
MessageDemuxer::MessageWatch& MessageDemuxer::MessageWatch::operator=(
    MessageWatch&& other) noexcept {
  std::swap(parent_, other.parent_);
  std::swap(is_default_, other.is_default_);
  std::swap(endpoint_id_, other.endpoint_id_);
  std::swap(message_type_, other.message_type_);
  return *this;
}
// |now_function| supplies the timestamp passed to every callback's
// OnStreamMessage and must be non-null. After each round of dispatch, a
// stream whose unconsumed buffered data exceeds |buffer_limit| bytes has its
// buffer dropped.
//
// There is deliberately no default argument on this out-of-line definition.
// Such a default would be visible only to code later in this .cc file, not to
// callers including message_demuxer.h. If a default such as
// kDefaultBufferLimit is wanted, it belongs on the header declaration.
MessageDemuxer::MessageDemuxer(ClockNowFunctionPtr now_function,
                               size_t buffer_limit)
    : now_function_(now_function), buffer_limit_(buffer_limit) {
  OSP_DCHECK(now_function_);
}

MessageDemuxer::~MessageDemuxer() = default;
// Registers |callback| for messages of |message_type| arriving from
// |endpoint_id|. Returns a MessageWatch whose destruction unregisters the
// callback. If a callback is already registered for this endpoint/type pair,
// nothing is registered and a default-constructed MessageWatch is returned.
//
// Any data already buffered for this endpoint whose leading message has
// |message_type| is dispatched immediately.
MessageDemuxer::MessageWatch MessageDemuxer::WatchMessageType(
    uint64_t endpoint_id,
    msgs::Type message_type,
    MessageCallback* callback) {
  auto callbacks_entry = message_callbacks_.find(endpoint_id);
  if (callbacks_entry == message_callbacks_.end()) {
    callbacks_entry =
        message_callbacks_
            .emplace(endpoint_id, std::map<msgs::Type, MessageCallback*>{})
            .first;
  }
  auto emplace_result = callbacks_entry->second.emplace(message_type, callback);
  if (!emplace_result.second)
    return MessageWatch();

  auto endpoint_entry = buffers_.find(endpoint_id);
  if (endpoint_entry != buffers_.end()) {
    for (auto& buffer : endpoint_entry->second) {
      // Decode the varUint type prefix instead of reading buffer[0] directly.
      // For types that need more than one byte, the first byte carries
      // length bits, so a raw byte comparison misidentifies them. Buffers that
      // are empty or too short to decode stay buffered; OnStreamData retries
      // them once more data arrives.
      size_t type_length = 0;
      ErrorOr<msgs::Type> buffered_type =
          MessageTypeDecoder::DecodeType(buffer.second, &type_length);
      if (buffered_type.is_error() || buffered_type.value() != message_type)
        continue;
      HandleStreamBufferLoop(endpoint_id, buffer.first, callbacks_entry,
                             &buffer.second);
    }
  }
  return MessageWatch(this, false, endpoint_id, message_type);
}
// Registers |callback| as the default handler for |message_type| on every
// endpoint. HandleStreamBufferLoop offers a message to the default callbacks
// only when no endpoint-specific callback handled it. Returns a
// default-constructed MessageWatch if a default callback for this type
// already exists.
//
// Data already buffered on any endpoint whose leading message has
// |message_type| is dispatched immediately.
MessageDemuxer::MessageWatch MessageDemuxer::SetDefaultMessageTypeWatch(
    msgs::Type message_type,
    MessageCallback* callback) {
  auto emplace_result = default_callbacks_.emplace(message_type, callback);
  if (!emplace_result.second)
    return MessageWatch();

  for (auto& endpoint_buffers : buffers_) {
    auto endpoint_id = endpoint_buffers.first;
    for (auto& stream_map : endpoint_buffers.second) {
      // Decode the varUint type prefix instead of reading the first byte
      // directly. For types that need more than one byte, that byte carries
      // length bits. Buffers that are empty or too short to decode stay
      // buffered for OnStreamData to retry.
      size_t type_length = 0;
      ErrorOr<msgs::Type> buffered_type =
          MessageTypeDecoder::DecodeType(stream_map.second, &type_length);
      if (buffered_type.is_error() || buffered_type.value() != message_type)
        continue;
      auto connection_id = stream_map.first;
      auto callbacks_entry = message_callbacks_.find(endpoint_id);
      HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry,
                             &stream_map.second);
    }
  }
  return MessageWatch(this, true, 0, message_type);
}
void MessageDemuxer::OnStreamData(uint64_t endpoint_id,
uint64_t connection_id,
const uint8_t* data,
size_t data_size) {
OSP_VLOG << __func__ << ": [" << endpoint_id << ", " << connection_id
<< "] - (" << data_size << ")";
auto& stream_map = buffers_[endpoint_id];
if (!data_size) {
stream_map.erase(connection_id);
if (stream_map.empty())
buffers_.erase(endpoint_id);
return;
}
std::vector<uint8_t>& buffer = stream_map[connection_id];
buffer.insert(buffer.end(), data, data + data_size);
auto callbacks_entry = message_callbacks_.find(endpoint_id);
HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry, &buffer);
if (buffer.size() > buffer_limit_)
stream_map.erase(connection_id);
}
// Removes the endpoint-specific callback for (|endpoint_id|, |message_type|).
// If no such callback is registered, this is a no-op. The previous version
// called map::erase(end()) in that case, which is undefined behavior, and
// default-inserted an endpoint entry through operator[].
//
// The (possibly now empty) per-endpoint map is intentionally kept.
// HandleStreamBufferLoop holds an iterator into message_callbacks_ while
// callbacks run, and a callback may drop its watch and land here.
void MessageDemuxer::StopWatchingMessageType(uint64_t endpoint_id,
                                             msgs::Type message_type) {
  auto endpoint_entry = message_callbacks_.find(endpoint_id);
  if (endpoint_entry == message_callbacks_.end())
    return;
  endpoint_entry->second.erase(message_type);
}
// Removes the default (all-endpoint) callback for |message_type|, if one is
// registered.
void MessageDemuxer::StopDefaultMessageTypeWatch(msgs::Type message_type) {
  const auto entry = default_callbacks_.find(message_type);
  if (entry != default_callbacks_.end())
    default_callbacks_.erase(entry);
}
// Dispatches messages from the front of one stream's |buffer| in passes.
// Each pass first offers the buffer to the endpoint-specific callbacks
// (when |callbacks_entry| is a valid entry of message_callbacks_). If those
// did not handle it, the pass offers it to the default callbacks, when any
// exist. Passes repeat while the last pass consumed bytes and data remains.
// Returns the result of the final pass.
MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBufferLoop(
    uint64_t endpoint_id,
    uint64_t connection_id,
    std::map<uint64_t, std::map<msgs::Type, MessageCallback*>>::iterator
        callbacks_entry,
    std::vector<uint8_t>* buffer) {
  HandleStreamBufferResult pass_result{false, 0};
  do {
    pass_result = HandleStreamBufferResult{false, 0};

    const bool has_endpoint_callbacks =
        callbacks_entry != message_callbacks_.end();
    if (has_endpoint_callbacks) {
      OSP_VLOG << "attempting endpoint-specific handling";
      pass_result = HandleStreamBuffer(endpoint_id, connection_id,
                                       &callbacks_entry->second, buffer);
    }

    const bool try_defaults =
        !pass_result.handled && !default_callbacks_.empty();
    if (try_defaults) {
      OSP_VLOG << "attempting generic message handling";
      pass_result = HandleStreamBuffer(endpoint_id, connection_id,
                                       &default_callbacks_, buffer);
    }

    OSP_VLOG_IF(!pass_result.handled) << "no message handler matched";
  } while (pass_result.consumed != 0 && !buffer->empty());
  return pass_result;
}
// TODO(rwkeane) Use absl::Span for the buffer
// Dispatches messages from the front of |buffer| to the callbacks in
// |message_callbacks|, one at a time. For each message it erases the type
// prefix plus the number of body bytes the callback reports as consumed.
//
// Dispatch stops when:
//  - the buffer is empty;
//  - the next message's type has no callback in |message_callbacks|; or
//  - a callback consumes no bytes. This includes a kCborIncompleteMessage
//    error, after which the buffer is kept so more data can be appended.
// A truncated type prefix also keeps the buffer. Any other decode or
// callback error discards the whole buffer, because the next message
// boundary can no longer be found.
//
// Returns whether at least one callback was invoked, and the total number of
// body bytes consumed by callbacks (type-prefix bytes are not counted).
MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBuffer(
    uint64_t endpoint_id,
    uint64_t connection_id,
    std::map<msgs::Type, MessageCallback*>* message_callbacks,
    std::vector<uint8_t>* buffer) {
  size_t consumed = 0;
  size_t total_consumed = 0;
  bool handled = false;
  do {
    consumed = 0;
    size_t msg_type_byte_length = 0;
    ErrorOr<msgs::Type> message_type =
        MessageTypeDecoder::DecodeType(*buffer, &msg_type_byte_length);
    if (message_type.is_error()) {
      // The stream may be split at any byte, so so far only the type prefix
      // of the next message may have arrived. Keep those bytes so a later
      // OnStreamData call can append the rest. Any other error (e.g. an
      // unknown type) leaves no way to find the next message boundary, so
      // the buffer is dropped.
      if (message_type.error().code() != Error::Code::kCborIncompleteMessage) {
        buffer->clear();
      }
      break;
    }
    auto callback_entry = message_callbacks->find(message_type.value());
    if (callback_entry == message_callbacks->end())
      break;
    handled = true;
    OSP_VLOG << "handling message type "
             << static_cast<int>(message_type.value());
    const size_t body_size = buffer->size() - msg_type_byte_length;
    auto consumed_or_error = callback_entry->second->OnStreamMessage(
        endpoint_id, connection_id, message_type.value(),
        buffer->data() + msg_type_byte_length, body_size, now_function_());
    if (!consumed_or_error) {
      // An incomplete message is retried when more data arrives. Any other
      // error makes the rest of the stream unparseable.
      if (consumed_or_error.error().code() !=
          Error::Code::kCborIncompleteMessage) {
        buffer->clear();
        break;
      }
    } else {
      consumed = consumed_or_error.value();
      // A callback claiming more bytes than it was given would make the erase
      // below run past the end of |buffer|.
      OSP_DCHECK(consumed <= body_size);
      buffer->erase(buffer->begin(),
                    buffer->begin() + consumed + msg_type_byte_length);
    }
    total_consumed += consumed;
  } while (consumed && !buffer->empty());
  return HandleStreamBufferResult{handled, total_consumed};
}
// Cancels |watch| by move-assigning an empty MessageWatch into it. Move
// assignment swaps state, so the old registration ends up in |released|.
// Its destructor unregisters the callback when this function returns.
void StopWatching(MessageDemuxer::MessageWatch* watch) {
  MessageDemuxer::MessageWatch released;
  *watch = std::move(released);
}
} // namespace osp
} // namespace openscreen