| // Copyright 2018 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "osp/public/message_demuxer.h" |
| |
| #include <memory> |
| #include <utility> |
| |
| #include "osp/impl/quic/quic_connection.h" |
| #include "platform/base/error.h" |
| #include "util/big_endian.h" |
| #include "util/osp_logging.h" |
| |
| namespace openscreen { |
| namespace osp { |
| |
| // static |
| // Decodes a varUint, expecting it to follow the encoding format described here: |
| // https://tools.ietf.org/html/draft-ietf-quic-transport-16#section-16 |
| ErrorOr<uint64_t> MessageTypeDecoder::DecodeVarUint( |
| const std::vector<uint8_t>& buffer, |
| size_t* num_bytes_decoded) { |
| if (buffer.size() == 0) { |
| return Error::Code::kCborIncompleteMessage; |
| } |
| |
| uint8_t num_type_bytes = static_cast<uint8_t>(buffer[0] >> 6 & 0x03); |
| *num_bytes_decoded = 0x1 << num_type_bytes; |
| |
| // Ensure that ReadBigEndian won't read beyond the end of the buffer. Also, |
| // since we expect the id to be followed by the message, equality is not valid |
| if (buffer.size() <= *num_bytes_decoded) { |
| return Error::Code::kCborIncompleteMessage; |
| } |
| |
| switch (num_type_bytes) { |
| case 0: |
| return ReadBigEndian<uint8_t>(&buffer[0]) & ~0xC0; |
| case 1: |
| return ReadBigEndian<uint16_t>(&buffer[0]) & ~(0xC0 << 8); |
| case 2: |
| return ReadBigEndian<uint32_t>(&buffer[0]) & ~(0xC0 << 24); |
| case 3: |
| return ReadBigEndian<uint64_t>(&buffer[0]) & ~(uint64_t{0xC0} << 56); |
| default: |
| OSP_NOTREACHED(); |
| } |
| } |
| |
| // static |
| // Decodes the Type of message, expecting it to follow the encoding format |
| // described here: |
| // https://tools.ietf.org/html/draft-ietf-quic-transport-16#section-16 |
| ErrorOr<msgs::Type> MessageTypeDecoder::DecodeType( |
| const std::vector<uint8_t>& buffer, |
| size_t* num_bytes_decoded) { |
| ErrorOr<uint64_t> message_type = |
| MessageTypeDecoder::DecodeVarUint(buffer, num_bytes_decoded); |
| if (message_type.is_error()) { |
| return message_type.error(); |
| } |
| |
| msgs::Type parsed_type = |
| msgs::TypeEnumValidator::SafeCast(message_type.value()); |
| if (parsed_type == msgs::Type::kUnknown) { |
| return Error::Code::kCborInvalidMessage; |
| } |
| |
| return parsed_type; |
| } |
| |
// static
// Out-of-line definition of the in-class constexpr constant, which serves as
// the default |buffer_limit| of the MessageDemuxer constructor. Required for
// ODR-use before C++17; redundant but harmless once static constexpr data
// members are implicitly inline.
constexpr size_t MessageDemuxer::kDefaultBufferLimit;
| |
// Creates an empty watch, as returned when a registration attempt fails and as
// used by StopWatching(). NOTE(review): relies on the header giving |parent_|
// a nullptr default member initializer so the destructor is a no-op; confirm.
MessageDemuxer::MessageWatch::MessageWatch() = default;
| |
| MessageDemuxer::MessageWatch::MessageWatch(MessageDemuxer* parent, |
| bool is_default, |
| uint64_t endpoint_id, |
| msgs::Type message_type) |
| : parent_(parent), |
| is_default_(is_default), |
| endpoint_id_(endpoint_id), |
| message_type_(message_type) {} |
| |
| MessageDemuxer::MessageWatch::MessageWatch( |
| MessageDemuxer::MessageWatch&& other) noexcept |
| : parent_(other.parent_), |
| is_default_(other.is_default_), |
| endpoint_id_(other.endpoint_id_), |
| message_type_(other.message_type_) { |
| other.parent_ = nullptr; |
| } |
| |
| MessageDemuxer::MessageWatch::~MessageWatch() { |
| if (parent_) { |
| if (is_default_) { |
| OSP_VLOG << "dropping default handler for type: " |
| << static_cast<int>(message_type_); |
| parent_->StopDefaultMessageTypeWatch(message_type_); |
| } else { |
| OSP_VLOG << "dropping handler for type: " |
| << static_cast<int>(message_type_); |
| parent_->StopWatchingMessageType(endpoint_id_, message_type_); |
| } |
| } |
| } |
| |
| MessageDemuxer::MessageWatch& MessageDemuxer::MessageWatch::operator=( |
| MessageWatch&& other) noexcept { |
| using std::swap; |
| swap(parent_, other.parent_); |
| swap(is_default_, other.is_default_); |
| swap(endpoint_id_, other.endpoint_id_); |
| swap(message_type_, other.message_type_); |
| return *this; |
| } |
| |
| MessageDemuxer::MessageDemuxer(ClockNowFunctionPtr now_function, |
| size_t buffer_limit = kDefaultBufferLimit) |
| : now_function_(now_function), buffer_limit_(buffer_limit) { |
| OSP_DCHECK(now_function_); |
| } |
| |
// Defaulted out of line. NOTE(review): presumably kept in the .cc so that the
// header need not see complete definitions of the member types; confirm.
MessageDemuxer::~MessageDemuxer() = default;
| |
| MessageDemuxer::MessageWatch MessageDemuxer::WatchMessageType( |
| uint64_t endpoint_id, |
| msgs::Type message_type, |
| MessageCallback* callback) { |
| auto callbacks_entry = message_callbacks_.find(endpoint_id); |
| if (callbacks_entry == message_callbacks_.end()) { |
| callbacks_entry = |
| message_callbacks_ |
| .emplace(endpoint_id, std::map<msgs::Type, MessageCallback*>{}) |
| .first; |
| } |
| auto emplace_result = callbacks_entry->second.emplace(message_type, callback); |
| if (!emplace_result.second) |
| return MessageWatch(); |
| auto endpoint_entry = buffers_.find(endpoint_id); |
| if (endpoint_entry != buffers_.end()) { |
| for (auto& buffer : endpoint_entry->second) { |
| if (buffer.second.empty()) |
| continue; |
| auto buffered_type = static_cast<msgs::Type>(buffer.second[0]); |
| if (message_type == buffered_type) { |
| HandleStreamBufferLoop(endpoint_id, buffer.first, callbacks_entry, |
| &buffer.second); |
| } |
| } |
| } |
| return MessageWatch(this, false, endpoint_id, message_type); |
| } |
| |
| MessageDemuxer::MessageWatch MessageDemuxer::SetDefaultMessageTypeWatch( |
| msgs::Type message_type, |
| MessageCallback* callback) { |
| auto emplace_result = default_callbacks_.emplace(message_type, callback); |
| if (!emplace_result.second) |
| return MessageWatch(); |
| for (auto& endpoint_buffers : buffers_) { |
| auto endpoint_id = endpoint_buffers.first; |
| for (auto& stream_map : endpoint_buffers.second) { |
| if (stream_map.second.empty()) |
| continue; |
| auto buffered_type = static_cast<msgs::Type>(stream_map.second[0]); |
| if (message_type == buffered_type) { |
| auto connection_id = stream_map.first; |
| auto callbacks_entry = message_callbacks_.find(endpoint_id); |
| HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry, |
| &stream_map.second); |
| } |
| } |
| } |
| return MessageWatch(this, true, 0, message_type); |
| } |
| |
| void MessageDemuxer::OnStreamData(uint64_t endpoint_id, |
| uint64_t connection_id, |
| const uint8_t* data, |
| size_t data_size) { |
| OSP_VLOG << __func__ << ": [" << endpoint_id << ", " << connection_id |
| << "] - (" << data_size << ")"; |
| auto& stream_map = buffers_[endpoint_id]; |
| if (!data_size) { |
| stream_map.erase(connection_id); |
| if (stream_map.empty()) |
| buffers_.erase(endpoint_id); |
| return; |
| } |
| std::vector<uint8_t>& buffer = stream_map[connection_id]; |
| buffer.insert(buffer.end(), data, data + data_size); |
| |
| auto callbacks_entry = message_callbacks_.find(endpoint_id); |
| HandleStreamBufferLoop(endpoint_id, connection_id, callbacks_entry, &buffer); |
| |
| if (buffer.size() > buffer_limit_) |
| stream_map.erase(connection_id); |
| } |
| |
| void MessageDemuxer::StopWatchingMessageType(uint64_t endpoint_id, |
| msgs::Type message_type) { |
| auto& message_map = message_callbacks_[endpoint_id]; |
| auto it = message_map.find(message_type); |
| message_map.erase(it); |
| } |
| |
| void MessageDemuxer::StopDefaultMessageTypeWatch(msgs::Type message_type) { |
| default_callbacks_.erase(message_type); |
| } |
| |
| MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBufferLoop( |
| uint64_t endpoint_id, |
| uint64_t connection_id, |
| std::map<uint64_t, std::map<msgs::Type, MessageCallback*>>::iterator |
| callbacks_entry, |
| std::vector<uint8_t>* buffer) { |
| HandleStreamBufferResult result; |
| do { |
| result = {false, 0}; |
| if (callbacks_entry != message_callbacks_.end()) { |
| OSP_VLOG << "attempting endpoint-specific handling"; |
| result = HandleStreamBuffer(endpoint_id, connection_id, |
| &callbacks_entry->second, buffer); |
| } |
| if (!result.handled) { |
| if (!default_callbacks_.empty()) { |
| OSP_VLOG << "attempting generic message handling"; |
| result = HandleStreamBuffer(endpoint_id, connection_id, |
| &default_callbacks_, buffer); |
| } |
| } |
| OSP_VLOG_IF(!result.handled) << "no message handler matched"; |
| } while (result.consumed && !buffer->empty()); |
| return result; |
| } |
| |
| // TODO(rwkeane) Use absl::Span for the buffer |
| MessageDemuxer::HandleStreamBufferResult MessageDemuxer::HandleStreamBuffer( |
| uint64_t endpoint_id, |
| uint64_t connection_id, |
| std::map<msgs::Type, MessageCallback*>* message_callbacks, |
| std::vector<uint8_t>* buffer) { |
| size_t consumed = 0; |
| size_t total_consumed = 0; |
| bool handled = false; |
| do { |
| consumed = 0; |
| size_t msg_type_byte_length; |
| ErrorOr<msgs::Type> message_type = |
| MessageTypeDecoder::DecodeType(*buffer, &msg_type_byte_length); |
| if (message_type.is_error()) { |
| buffer->clear(); |
| break; |
| } |
| auto callback_entry = message_callbacks->find(message_type.value()); |
| if (callback_entry == message_callbacks->end()) |
| break; |
| handled = true; |
| OSP_VLOG << "handling message type " |
| << static_cast<int>(message_type.value()); |
| auto consumed_or_error = callback_entry->second->OnStreamMessage( |
| endpoint_id, connection_id, message_type.value(), |
| buffer->data() + msg_type_byte_length, |
| buffer->size() - msg_type_byte_length, now_function_()); |
| if (!consumed_or_error) { |
| if (consumed_or_error.error().code() != |
| Error::Code::kCborIncompleteMessage) { |
| buffer->clear(); |
| break; |
| } |
| } else { |
| consumed = consumed_or_error.value(); |
| buffer->erase(buffer->begin(), |
| buffer->begin() + consumed + msg_type_byte_length); |
| } |
| total_consumed += consumed; |
| } while (consumed && !buffer->empty()); |
| return HandleStreamBufferResult{handled, total_consumed}; |
| } |
| |
| void StopWatching(MessageDemuxer::MessageWatch* watch) { |
| *watch = MessageDemuxer::MessageWatch(); |
| } |
| |
| } // namespace osp |
| } // namespace openscreen |