| // Copyright 2019 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #ifndef OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_ |
| #define OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_ |
| |
| #include <cstddef> |
| #include <cstdint> |
| #include <type_traits> |
| #include <utility> |
| #include <vector> |
| |
| #include "absl/types/optional.h" |
| #include "osp/public/message_demuxer.h" |
| #include "osp/public/network_service_manager.h" |
| #include "osp/public/protocol_connection.h" |
| #include "platform/base/error.h" |
| #include "platform/base/macros.h" |
| #include "util/osp_logging.h" |
| |
| namespace openscreen { |
| namespace osp { |
| |
| template <typename T> |
| using MessageDecodingFunction = ssize_t (*)(const uint8_t*, size_t, T*); |
| |
| // Provides a uniform way of accessing import properties of a request/response |
| // message pair from a template: request encode function, response decode |
| // function, request serializable data member. |
| template <typename T> |
| struct DefaultRequestCoderTraits { |
| public: |
| using RequestMsgType = typename T::RequestMsgType; |
| static constexpr MessageEncodingFunction<RequestMsgType> kEncoder = |
| T::kEncoder; |
| static constexpr MessageDecodingFunction<typename T::ResponseMsgType> |
| kDecoder = T::kDecoder; |
| |
| static const RequestMsgType* serial_request(const T& data) { |
| return &data.request; |
| } |
| static RequestMsgType* serial_request(T& data) { return &data.request; } |
| }; |
| |
| // Provides a wrapper for the common pattern of sending a request message and |
| // waiting for a response message with a matching |request_id| field. It also |
| // handles the business of queueing messages to be sent until a protocol |
| // connection is available. |
| // |
| // Messages are written using WriteMessage. This will queue messages if there |
| // is no protocol connection or write them immediately if there is. When a |
| // matching response is received via the MessageDemuxer (taken from the global |
| // ProtocolConnectionClient), OnMatchedResponse is called on the provided |
| // Delegate object along with the original request that it matches. |
| template <typename RequestT, |
| typename RequestCoderTraits = DefaultRequestCoderTraits<RequestT>> |
| class RequestResponseHandler : public MessageDemuxer::MessageCallback { |
| public: |
| class Delegate { |
| public: |
| virtual ~Delegate() = default; |
| |
| virtual void OnMatchedResponse(RequestT* request, |
| typename RequestT::ResponseMsgType* response, |
| uint64_t endpoint_id) = 0; |
| virtual void OnError(RequestT* request, Error error) = 0; |
| }; |
| |
| explicit RequestResponseHandler(Delegate* delegate) : delegate_(delegate) {} |
| ~RequestResponseHandler() { Reset(); } |
| |
| void Reset() { |
| connection_ = nullptr; |
| for (auto& message : to_send_) { |
| delegate_->OnError(&message.request, Error::Code::kRequestCancelled); |
| } |
| to_send_.clear(); |
| for (auto& message : sent_) { |
| delegate_->OnError(&message.request, Error::Code::kRequestCancelled); |
| } |
| sent_.clear(); |
| response_watch_ = MessageDemuxer::MessageWatch(); |
| } |
| |
| // Write a message to the underlying protocol connection, or queue it until |
| // one is provided via SetConnection. If |id| is provided, it can be used to |
| // cancel the message via CancelMessage. |
| template <typename RequestTRval> |
| typename std::enable_if< |
| !std::is_lvalue_reference<RequestTRval>::value && |
| std::is_same<typename std::decay<RequestTRval>::type, |
| RequestT>::value, |
| Error>::type |
| WriteMessage(absl::optional<uint64_t> id, RequestTRval&& message) { |
| auto* request_msg = RequestCoderTraits::serial_request(message); |
| if (connection_) { |
| request_msg->request_id = GetNextRequestId(connection_->endpoint_id()); |
| Error result = |
| connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder); |
| if (!result.ok()) { |
| return result; |
| } |
| sent_.emplace_back(RequestWithId{id, std::move(message)}); |
| EnsureResponseWatch(); |
| } else { |
| to_send_.emplace_back(RequestWithId{id, std::move(message)}); |
| } |
| return Error::None(); |
| } |
| |
| template <typename RequestTRval> |
| typename std::enable_if< |
| !std::is_lvalue_reference<RequestTRval>::value && |
| std::is_same<typename std::decay<RequestTRval>::type, |
| RequestT>::value, |
| Error>::type |
| WriteMessage(RequestTRval&& message) { |
| return WriteMessage(absl::nullopt, std::move(message)); |
| } |
| |
| // Remove the message that was originally written with |id| from the send and |
| // sent queues so that we are no longer looking for a response. |
| void CancelMessage(uint64_t id) { |
| to_send_.erase(std::remove_if(to_send_.begin(), to_send_.end(), |
| [&id](const RequestWithId& msg) { |
| return id == msg.id; |
| }), |
| to_send_.end()); |
| sent_.erase(std::remove_if( |
| sent_.begin(), sent_.end(), |
| [&id](const RequestWithId& msg) { return id == msg.id; }), |
| sent_.end()); |
| if (sent_.empty()) { |
| response_watch_ = MessageDemuxer::MessageWatch(); |
| } |
| } |
| |
| // Assign a ProtocolConnection to this handler for writing messages. |
| void SetConnection(ProtocolConnection* connection) { |
| connection_ = connection; |
| for (auto& message : to_send_) { |
| auto* request_msg = RequestCoderTraits::serial_request(message.request); |
| request_msg->request_id = GetNextRequestId(connection_->endpoint_id()); |
| Error result = |
| connection_->WriteMessage(*request_msg, RequestCoderTraits::kEncoder); |
| if (result.ok()) { |
| sent_.emplace_back(std::move(message)); |
| } else { |
| delegate_->OnError(&message.request, result); |
| } |
| } |
| if (!to_send_.empty()) { |
| EnsureResponseWatch(); |
| } |
| to_send_.clear(); |
| } |
| |
| // MessageDemuxer::MessageCallback overrides. |
| ErrorOr<size_t> OnStreamMessage(uint64_t endpoint_id, |
| uint64_t connection_id, |
| msgs::Type message_type, |
| const uint8_t* buffer, |
| size_t buffer_size, |
| Clock::time_point now) override { |
| if (message_type != RequestT::kResponseType) { |
| return 0; |
| } |
| typename RequestT::ResponseMsgType response; |
| ssize_t result = |
| RequestCoderTraits::kDecoder(buffer, buffer_size, &response); |
| if (result < 0) { |
| return 0; |
| } |
| auto it = std::find_if( |
| sent_.begin(), sent_.end(), [&response](const RequestWithId& msg) { |
| return RequestCoderTraits::serial_request(msg.request)->request_id == |
| response.request_id; |
| }); |
| if (it != sent_.end()) { |
| delegate_->OnMatchedResponse(&it->request, &response, |
| connection_->endpoint_id()); |
| sent_.erase(it); |
| if (sent_.empty()) { |
| response_watch_ = MessageDemuxer::MessageWatch(); |
| } |
| } else { |
| OSP_LOG_WARN << "got response for unknown request id: " |
| << response.request_id; |
| } |
| return result; |
| } |
| |
| private: |
| struct RequestWithId { |
| absl::optional<uint64_t> id; |
| RequestT request; |
| }; |
| |
| void EnsureResponseWatch() { |
| if (!response_watch_) { |
| response_watch_ = NetworkServiceManager::Get() |
| ->GetProtocolConnectionClient() |
| ->message_demuxer() |
| ->WatchMessageType(connection_->endpoint_id(), |
| RequestT::kResponseType, this); |
| } |
| } |
| |
| uint64_t GetNextRequestId(uint64_t endpoint_id) { |
| return NetworkServiceManager::Get() |
| ->GetProtocolConnectionClient() |
| ->endpoint_request_ids() |
| ->GetNextRequestId(endpoint_id); |
| } |
| |
| ProtocolConnection* connection_ = nullptr; |
| Delegate* const delegate_; |
| std::vector<RequestWithId> to_send_; |
| std::vector<RequestWithId> sent_; |
| MessageDemuxer::MessageWatch response_watch_; |
| |
| OSP_DISALLOW_COPY_AND_ASSIGN(RequestResponseHandler); |
| }; |
| |
| } // namespace osp |
| } // namespace openscreen |
| |
| #endif // OSP_PUBLIC_REQUEST_RESPONSE_HANDLER_H_ |