| // Copyright 2019 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "cast/common/channel/connection_namespace_handler.h" |
| |
| #include <string> |
| #include <type_traits> |
| #include <utility> |
| |
| #include "absl/types/optional.h" |
| #include "cast/common/channel/message_util.h" |
| #include "cast/common/channel/proto/cast_channel.pb.h" |
| #include "cast/common/channel/virtual_connection.h" |
| #include "cast/common/channel/virtual_connection_manager.h" |
| #include "cast/common/channel/virtual_connection_router.h" |
| #include "cast/common/public/cast_socket.h" |
| #include "util/json/json_serialization.h" |
| #include "util/json/json_value.h" |
| #include "util/osp_logging.h" |
| |
| namespace openscreen { |
| namespace cast { |
| |
| using ::cast::channel::CastMessage; |
| using ::cast::channel::CastMessage_PayloadType; |
| |
| namespace { |
| |
| bool IsValidProtocolVersion(int version) { |
| return ::cast::channel::CastMessage_ProtocolVersion_IsValid(version); |
| } |
| |
| absl::optional<int> FindMaxProtocolVersion(const Json::Value* version, |
| const Json::Value* version_list) { |
| using ArrayIndex = Json::Value::ArrayIndex; |
| static_assert(std::is_integral<ArrayIndex>::value, |
| "Assuming ArrayIndex is integral"); |
| absl::optional<int> max_version; |
| if (version_list && version_list->isArray()) { |
| max_version = ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0; |
| for (auto it = version_list->begin(), end = version_list->end(); it != end; |
| ++it) { |
| if (it->isInt()) { |
| int version_int = it->asInt(); |
| if (IsValidProtocolVersion(version_int) && version_int > *max_version) { |
| max_version = version_int; |
| } |
| } |
| } |
| } |
| if (version && version->isInt()) { |
| int version_int = version->asInt(); |
| if (IsValidProtocolVersion(version_int)) { |
| if (!max_version) { |
| max_version = ::cast::channel::CastMessage_ProtocolVersion_CASTV2_1_0; |
| } |
| if (version_int > max_version) { |
| max_version = version_int; |
| } |
| } |
| } |
| return max_version; |
| } |
| |
| VirtualConnection::CloseReason GetCloseReason( |
| const Json::Value& parsed_message) { |
| VirtualConnection::CloseReason reason = |
| VirtualConnection::CloseReason::kClosedByPeer; |
| absl::optional<int> reason_code = MaybeGetInt( |
| parsed_message, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyReasonCode)); |
| if (reason_code) { |
| int code = reason_code.value(); |
| if (code >= VirtualConnection::CloseReason::kFirstReason && |
| code <= VirtualConnection::CloseReason::kLastReason) { |
| reason = static_cast<VirtualConnection::CloseReason>(code); |
| } |
| } |
| return reason; |
| } |
| |
| } // namespace |
| |
| ConnectionNamespaceHandler::ConnectionNamespaceHandler( |
| VirtualConnectionManager* vc_manager, |
| VirtualConnectionPolicy* vc_policy) |
| : vc_manager_(vc_manager), vc_policy_(vc_policy) { |
| OSP_DCHECK(vc_manager); |
| OSP_DCHECK(vc_policy); |
| } |
| |
| ConnectionNamespaceHandler::~ConnectionNamespaceHandler() = default; |
| |
| void ConnectionNamespaceHandler::OnMessage(VirtualConnectionRouter* router, |
| CastSocket* socket, |
| CastMessage message) { |
| if (message.payload_type() != |
| CastMessage_PayloadType::CastMessage_PayloadType_STRING) { |
| return; |
| } |
| ErrorOr<Json::Value> result = json::Parse(message.payload_utf8()); |
| if (result.is_error()) { |
| return; |
| } |
| |
| Json::Value& value = result.value(); |
| if (!value.isObject()) { |
| return; |
| } |
| |
| absl::optional<absl::string_view> type = |
| MaybeGetString(value, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyType)); |
| if (!type) { |
| // TODO(btolsch): Some of these paths should have error reporting. One |
| // possibility is to pass errors back through |router| so higher-level code |
| // can decide whether to show an error to the user, stop talking to a |
| // particular device, etc. |
| return; |
| } |
| |
| absl::string_view type_str = type.value(); |
| if (type_str == kMessageTypeConnect) { |
| HandleConnect(router, socket, std::move(message), std::move(value)); |
| } else if (type_str == kMessageTypeClose) { |
| HandleClose(router, socket, std::move(message), std::move(value)); |
| } else { |
| // NOTE: Unknown message type so ignore it. |
| // TODO(btolsch): Should be included in future error reporting. |
| } |
| } |
| |
| void ConnectionNamespaceHandler::HandleConnect(VirtualConnectionRouter* router, |
| CastSocket* socket, |
| CastMessage message, |
| Json::Value parsed_message) { |
| if (message.destination_id() == kBroadcastId || |
| message.source_id() == kBroadcastId) { |
| return; |
| } |
| |
| VirtualConnection virtual_conn{std::move(message.destination_id()), |
| std::move(message.source_id()), |
| ToCastSocketId(socket)}; |
| if (!vc_policy_->IsConnectionAllowed(virtual_conn)) { |
| SendClose(router, std::move(virtual_conn)); |
| return; |
| } |
| |
| absl::optional<int> maybe_conn_type = MaybeGetInt( |
| parsed_message, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyConnType)); |
| VirtualConnection::Type conn_type = VirtualConnection::Type::kStrong; |
| if (maybe_conn_type) { |
| int int_type = maybe_conn_type.value(); |
| if (int_type < static_cast<int>(VirtualConnection::Type::kMinValue) || |
| int_type > static_cast<int>(VirtualConnection::Type::kMaxValue)) { |
| SendClose(router, std::move(virtual_conn)); |
| return; |
| } |
| conn_type = static_cast<VirtualConnection::Type>(int_type); |
| } |
| |
| VirtualConnection::AssociatedData data; |
| |
| data.type = conn_type; |
| |
| absl::optional<absl::string_view> user_agent = MaybeGetString( |
| parsed_message, JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyUserAgent)); |
| if (user_agent) { |
| data.user_agent = std::string(user_agent.value()); |
| } |
| |
| const Json::Value* sender_info_value = parsed_message.find( |
| JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeySenderInfo)); |
| if (!sender_info_value || !sender_info_value->isObject()) { |
| // TODO(btolsch): Should this be guessed from user agent? |
| OSP_DVLOG << "No sender info from protocol."; |
| } |
| |
| const Json::Value* version_value = parsed_message.find( |
| JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyProtocolVersion)); |
| const Json::Value* version_list_value = parsed_message.find( |
| JSON_EXPAND_FIND_CONSTANT_ARGS(kMessageKeyProtocolVersionList)); |
| absl::optional<int> negotiated_version = |
| FindMaxProtocolVersion(version_value, version_list_value); |
| if (negotiated_version) { |
| data.max_protocol_version = static_cast<VirtualConnection::ProtocolVersion>( |
| negotiated_version.value()); |
| } else { |
| data.max_protocol_version = VirtualConnection::ProtocolVersion::kV2_1_0; |
| } |
| |
| if (socket) { |
| data.ip_fragment = socket->GetSanitizedIpAddress(); |
| } else { |
| data.ip_fragment = {}; |
| } |
| |
| OSP_DVLOG << "Connection opened: " << virtual_conn.local_id << ", " |
| << virtual_conn.peer_id << ", " << virtual_conn.socket_id; |
| |
| // NOTE: Only send a response for senders that actually sent a version. This |
| // maintains compatibility with older senders that don't send a version and |
| // don't expect a response. |
| if (negotiated_version) { |
| SendConnectedResponse(router, virtual_conn, negotiated_version.value()); |
| } |
| |
| vc_manager_->AddConnection(std::move(virtual_conn), std::move(data)); |
| } |
| |
| void ConnectionNamespaceHandler::HandleClose(VirtualConnectionRouter* router, |
| CastSocket* socket, |
| CastMessage message, |
| Json::Value parsed_message) { |
| VirtualConnection virtual_conn{std::move(message.destination_id()), |
| std::move(message.source_id()), |
| ToCastSocketId(socket)}; |
| if (!vc_manager_->GetConnectionData(virtual_conn)) { |
| return; |
| } |
| |
| VirtualConnection::CloseReason reason = GetCloseReason(parsed_message); |
| |
| OSP_DVLOG << "Connection closed (reason: " << reason |
| << "): " << virtual_conn.local_id << ", " << virtual_conn.peer_id |
| << ", " << virtual_conn.socket_id; |
| vc_manager_->RemoveConnection(virtual_conn, reason); |
| } |
| |
| void ConnectionNamespaceHandler::SendClose(VirtualConnectionRouter* router, |
| VirtualConnection virtual_conn) { |
| Json::Value close_message(Json::ValueType::objectValue); |
| close_message[kMessageKeyType] = kMessageTypeClose; |
| |
| ErrorOr<std::string> result = json::Stringify(close_message); |
| if (result.is_error()) { |
| return; |
| } |
| |
| router->Send( |
| std::move(virtual_conn), |
| MakeSimpleUTF8Message(kConnectionNamespace, std::move(result.value()))); |
| } |
| |
| void ConnectionNamespaceHandler::SendConnectedResponse( |
| VirtualConnectionRouter* router, |
| const VirtualConnection& virtual_conn, |
| int max_protocol_version) { |
| Json::Value connected_message(Json::ValueType::objectValue); |
| connected_message[kMessageKeyType] = kMessageTypeConnected; |
| connected_message[kMessageKeyProtocolVersion] = |
| static_cast<int>(max_protocol_version); |
| |
| ErrorOr<std::string> result = json::Stringify(connected_message); |
| if (result.is_error()) { |
| return; |
| } |
| |
| router->Send(virtual_conn, MakeSimpleUTF8Message(kConnectionNamespace, |
| std::move(result.value()))); |
| } |
| |
| } // namespace cast |
| } // namespace openscreen |