Change UDP socket ReceiveMessage to perform peek first and then create a packet of the correct size.
Change-Id: Ia16538141d821744afcf616f615b43768d6af645
Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/1702868
Commit-Queue: Max Yakimakha <yakimakha@chromium.org>
Reviewed-by: mark a. foltz <mfoltz@chromium.org>
Reviewed-by: Ryan Keane <rwkeane@google.com>
diff --git a/osp/impl/discovery/mdns/embedder_demo.cc b/osp/impl/discovery/mdns/embedder_demo.cc
index 3cab141..2603dd5 100644
--- a/osp/impl/discovery/mdns/embedder_demo.cc
+++ b/osp/impl/discovery/mdns/embedder_demo.cc
@@ -312,8 +312,9 @@
mdns_adapter->RunTasks();
auto data = platform::OnePlatformLoopIteration(waiter);
for (auto& packet : data) {
- mdns_adapter->OnDataReceived(packet.source, packet.original_destination,
- packet.data(), packet.length, packet.socket);
+ mdns_adapter->OnDataReceived(packet.source(), packet.destination(),
+ packet.data(), packet.size(),
+ packet.socket());
}
}
OSP_LOG << "num services: " << g_services->size();
diff --git a/osp/impl/mdns_responder_service.cc b/osp/impl/mdns_responder_service.cc
index b688577..15836fb 100644
--- a/osp/impl/mdns_responder_service.cc
+++ b/osp/impl/mdns_responder_service.cc
@@ -56,13 +56,13 @@
}
void MdnsResponderService::HandleNewEvents(
- const std::vector<platform::UdpReadCallback::Packet>& data) {
+ const std::vector<platform::UdpPacket>& packets) {
if (!mdns_responder_)
return;
- for (auto& packet : data) {
- mdns_responder_->OnDataReceived(packet.source, packet.original_destination,
- packet.data(), packet.length,
- packet.socket);
+ for (auto& packet : packets) {
+ mdns_responder_->OnDataReceived(packet.source(), packet.destination(),
+ packet.data(), packet.size(),
+ packet.socket());
}
mdns_responder_->RunTasks();
diff --git a/osp/impl/mdns_responder_service.h b/osp/impl/mdns_responder_service.h
index 996607b..d9d8a62 100644
--- a/osp/impl/mdns_responder_service.h
+++ b/osp/impl/mdns_responder_service.h
@@ -46,8 +46,7 @@
const std::vector<platform::NetworkInterfaceIndex> whitelist,
const std::map<std::string, std::string>& txt_data);
- void HandleNewEvents(
- const std::vector<platform::UdpReadCallback::Packet>& data);
+ void HandleNewEvents(const std::vector<platform::UdpPacket>& packets);
// ServiceListenerImpl::Delegate overrides.
void StartListener() override;
diff --git a/osp/impl/quic/quic_connection.h b/osp/impl/quic/quic_connection.h
index 8f3c82a..069b06a 100644
--- a/osp/impl/quic/quic_connection.h
+++ b/osp/impl/quic/quic_connection.h
@@ -71,8 +71,7 @@
// Passes a received UDP packet to the QUIC implementation. If this contains
// any stream data, it will be passed automatically to the relevant
// QuicStream::Delegate objects.
- virtual void OnDataReceived(
- const platform::UdpReadCallback::Packet& data) = 0;
+ virtual void OnDataReceived(const platform::UdpPacket& packet) = 0;
virtual std::unique_ptr<QuicStream> MakeOutgoingStream(
QuicStream::Delegate* delegate) = 0;
diff --git a/osp/impl/quic/quic_connection_factory_impl.cc b/osp/impl/quic/quic_connection_factory_impl.cc
index e3edf25..b176d24 100644
--- a/osp/impl/quic/quic_connection_factory_impl.cc
+++ b/osp/impl/quic/quic_connection_factory_impl.cc
@@ -129,34 +129,35 @@
// QuicConnectionFactoryImpl.
OSP_DCHECK(std::find_if(sockets_.begin(), sockets_.end(),
[&packet](const platform::UdpSocketUniquePtr& s) {
- return s.get() == packet.socket;
+ return s.get() == packet.socket();
}) != sockets_.end());
// TODO(btolsch): We will need to rethink this both for ICE and connection
// migration support.
- auto conn_it = connections_.find(packet.source);
+ auto conn_it = connections_.find(packet.source());
if (conn_it == connections_.end()) {
if (server_delegate_) {
- OSP_VLOG << __func__ << ": spawning connection from " << packet.source;
+ OSP_VLOG << __func__ << ": spawning connection from "
+ << packet.source();
auto transport =
- std::make_unique<UdpTransport>(packet.socket, packet.source);
+ std::make_unique<UdpTransport>(packet.socket(), packet.source());
::quic::QuartcSessionConfig session_config;
session_config.perspective = ::quic::Perspective::IS_SERVER;
session_config.packet_transport = transport.get();
auto result = std::make_unique<QuicConnectionImpl>(
- this, server_delegate_->NextConnectionDelegate(packet.source),
+ this, server_delegate_->NextConnectionDelegate(packet.source()),
std::move(transport),
quartc_factory_->CreateQuartcSession(session_config));
auto* result_ptr = result.get();
- connections_.emplace(packet.source,
- OpenConnection{result_ptr, packet.socket});
+ connections_.emplace(packet.source(),
+ OpenConnection{result_ptr, packet.socket()});
server_delegate_->OnIncomingConnection(std::move(result));
result_ptr->OnDataReceived(packet);
}
} else {
OSP_VLOG << __func__ << ": data for existing connection from "
- << packet.source;
+ << packet.source();
conn_it->second.connection->OnDataReceived(packet);
}
}
diff --git a/osp/impl/quic/quic_connection_impl.cc b/osp/impl/quic/quic_connection_impl.cc
index d56f91c..f07c5bf 100644
--- a/osp/impl/quic/quic_connection_impl.cc
+++ b/osp/impl/quic/quic_connection_impl.cc
@@ -89,10 +89,9 @@
QuicConnectionImpl::~QuicConnectionImpl() = default;
-void QuicConnectionImpl::OnDataReceived(
- const platform::UdpReadCallback::Packet& data) {
- session_->OnTransportReceived(reinterpret_cast<const char*>(data.data()),
- data.length);
+void QuicConnectionImpl::OnDataReceived(const platform::UdpPacket& packet) {
+ session_->OnTransportReceived(reinterpret_cast<const char*>(packet.data()),
+ packet.size());
}
std::unique_ptr<QuicStream> QuicConnectionImpl::MakeOutgoingStream(
diff --git a/osp/impl/quic/quic_connection_impl.h b/osp/impl/quic/quic_connection_impl.h
index 8e39d63..2288faf 100644
--- a/osp/impl/quic/quic_connection_impl.h
+++ b/osp/impl/quic/quic_connection_impl.h
@@ -75,7 +75,7 @@
~QuicConnectionImpl() override;
// QuicConnection overrides.
- void OnDataReceived(const platform::UdpReadCallback::Packet& data) override;
+ void OnDataReceived(const platform::UdpPacket& packet) override;
std::unique_ptr<QuicStream> MakeOutgoingStream(
QuicStream::Delegate* delegate) override;
void Close() override;
diff --git a/osp/impl/quic/testing/fake_quic_connection.cc b/osp/impl/quic/testing/fake_quic_connection.cc
index c57b5df..06b8279 100644
--- a/osp/impl/quic/testing/fake_quic_connection.cc
+++ b/osp/impl/quic/testing/fake_quic_connection.cc
@@ -60,8 +60,7 @@
return result;
}
-void FakeQuicConnection::OnDataReceived(
- const platform::UdpReadCallback::Packet& data) {
+void FakeQuicConnection::OnDataReceived(const platform::UdpPacket& packet) {
OSP_DCHECK(false) << "data should go directly to fake streams";
}
diff --git a/osp/impl/quic/testing/fake_quic_connection.h b/osp/impl/quic/testing/fake_quic_connection.h
index 9186a8e..90833c4 100644
--- a/osp/impl/quic/testing/fake_quic_connection.h
+++ b/osp/impl/quic/testing/fake_quic_connection.h
@@ -57,7 +57,7 @@
std::unique_ptr<FakeQuicStream> MakeIncomingStream();
// QuicConnection overrides.
- void OnDataReceived(const platform::UdpReadCallback::Packet& data) override;
+ void OnDataReceived(const platform::UdpPacket& packet) override;
std::unique_ptr<QuicStream> MakeOutgoingStream(
QuicStream::Delegate* delegate) override;
void Close() override;
diff --git a/platform/BUILD.gn b/platform/BUILD.gn
index e53e7a0..e891b1b 100644
--- a/platform/BUILD.gn
+++ b/platform/BUILD.gn
@@ -28,9 +28,10 @@
"api/trace_logging_platform.cc",
"api/trace_logging_platform.h",
"api/trace_logging_types.h",
+ "api/udp_packet.h",
+ "api/udp_read_callback.h",
"api/udp_socket.cc",
"api/udp_socket.h",
- "api/upd_read_callback.h",
"base/error.cc",
"base/error.h",
"base/ip_address.cc",
diff --git a/platform/api/udp_packet.h b/platform/api/udp_packet.h
new file mode 100644
index 0000000..a05bdb7
--- /dev/null
+++ b/platform/api/udp_packet.h
@@ -0,0 +1,47 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file
+
+#ifndef PLATFORM_API_UDP_PACKET_H_
+#define PLATFORM_API_UDP_PACKET_H_
+
+#include <vector>
+
+#include "platform/api/logging.h"
+#include "platform/base/ip_address.h"
+
+namespace openscreen {
+namespace platform {
+
+class UdpSocket;
+
+static constexpr size_t kUdpMaxPacketSize = 1 << 16;
+
+class UdpPacket : public std::vector<uint8_t> {
+ public:
+ explicit UdpPacket(size_t size) : std::vector<uint8_t>(size) {
+ OSP_DCHECK(size <= kUdpMaxPacketSize);
+ };
+ UdpPacket() : UdpPacket(0){};
+
+ const IPEndpoint& source() const { return source_; }
+ void set_source(IPEndpoint endpoint) { source_ = std::move(endpoint); }
+
+ const IPEndpoint& destination() const { return destination_; }
+ void set_destination(IPEndpoint endpoint) {
+ destination_ = std::move(endpoint);
+ }
+
+ UdpSocket* socket() const { return socket_; }
+ void set_socket(UdpSocket* socket) { socket_ = socket; }
+
+ private:
+ IPEndpoint source_ = {};
+ IPEndpoint destination_ = {};
+ UdpSocket* socket_ = nullptr;
+};
+
+} // namespace platform
+} // namespace openscreen
+
+#endif // PLATFORM_API_UDP_PACKET_H_
\ No newline at end of file
diff --git a/platform/api/udp_read_callback.h b/platform/api/udp_read_callback.h
index f69efd4..4a68575 100644
--- a/platform/api/udp_read_callback.h
+++ b/platform/api/udp_read_callback.h
@@ -5,39 +5,20 @@
#ifndef PLATFORM_API_UDP_READ_CALLBACK_H_
#define PLATFORM_API_UDP_READ_CALLBACK_H_
-#include <array>
-#include <cstdint>
-#include <memory>
-
-#include "platform/base/ip_address.h"
+#include "platform/api/udp_packet.h"
namespace openscreen {
namespace platform {
class NetworkRunner;
-class UdpSocket;
-
-static constexpr int kUdpMaxPacketSize = 1 << 16;
class UdpReadCallback {
public:
- struct Packet : std::array<uint8_t, kUdpMaxPacketSize> {
- Packet() = default;
- ~Packet() = default;
-
- IPEndpoint source;
- IPEndpoint original_destination;
- ssize_t length;
- // TODO(btolsch): When this gets to implementation, make sure the callback
- // is never called with a |socket| that could have been destroyed (e.g.
- // between queueing the read data and running the task).
- UdpSocket* socket;
- };
-
virtual ~UdpReadCallback() = default;
-
- virtual void OnRead(std::unique_ptr<Packet> data,
- NetworkRunner* network_runner) = 0;
+ // TODO(btolsch): When this gets to implementation, make sure the callback
+ // is never called with a |packet| from a socket that could have been
+ // destroyed (e.g. between queueing the read data and running the task).
+ virtual void OnRead(UdpPacket packet, NetworkRunner* network_runner) = 0;
};
} // namespace platform
diff --git a/platform/api/udp_socket.h b/platform/api/udp_socket.h
index 7ac9035..e060d41 100644
--- a/platform/api/udp_socket.h
+++ b/platform/api/udp_socket.h
@@ -10,6 +10,7 @@
#include <memory>
#include "platform/api/network_interface.h"
+#include "platform/api/udp_read_callback.h"
#include "platform/base/error.h"
#include "platform/base/ip_address.h"
#include "platform/base/macros.h"
@@ -79,13 +80,8 @@
// received. Note that a non-Error return value of 0 is a valid result,
// indicating an empty message has been received. Also note that
// Error::Code::kAgain might be returned if there is no message currently
- // ready for receive, which can be expected during normal operation. |src| and
- // |original_destination| are optional output arguments that provide the
- // source of the message and its intended destination, respectively.
- virtual ErrorOr<size_t> ReceiveMessage(void* data,
- size_t length,
- IPEndpoint* src,
- IPEndpoint* original_destination) = 0;
+ // ready for receive, which can be expected during normal operation.
+ virtual ErrorOr<UdpPacket> ReceiveMessage() = 0;
// Sends a message and returns the number of bytes sent, on success.
// Error::Code::kAgain might be returned to indicate the operation would
diff --git a/platform/impl/event_loop.cc b/platform/impl/event_loop.cc
index 6899ce4..05454c9 100644
--- a/platform/impl/event_loop.cc
+++ b/platform/impl/event_loop.cc
@@ -12,35 +12,21 @@
namespace openscreen {
namespace platform {
-Error ReceiveDataFromEvent(const UdpSocketReadableEvent& read_event,
- UdpReadCallback::Packet* data) {
- OSP_DCHECK(data);
- ErrorOr<size_t> len = read_event.socket->ReceiveMessage(
- &data[0], data->size(), &data->source, &data->original_destination);
- if (!len) {
- OSP_LOG_ERROR << "ReceiveMessage() on socket failed: "
- << len.error().message();
- return len.error();
- }
- OSP_DCHECK_LE(len.value(), static_cast<size_t>(kUdpMaxPacketSize));
- data->length = len.value();
- data->socket = read_event.socket;
- return Error::None();
-}
-
-std::vector<UdpReadCallback::Packet> HandleUdpSocketReadEvents(
- const Events& events) {
- std::vector<UdpReadCallback::Packet> data;
+std::vector<UdpPacket> HandleUdpSocketReadEvents(const Events& events) {
+ std::vector<UdpPacket> packets(events.udp_readable_events.size());
for (const auto& read_event : events.udp_readable_events) {
- UdpReadCallback::Packet next_data;
- if (ReceiveDataFromEvent(read_event, &next_data).ok())
- data.emplace_back(std::move(next_data));
+ ErrorOr<UdpPacket> result = read_event.socket->ReceiveMessage();
+ if (result) {
+ packets.emplace_back(result.MoveValue());
+ } else {
+ OSP_LOG_ERROR << "ReceiveMessage() on socket failed: "
+ << result.error().message();
+ }
}
- return data;
+ return packets;
}
-std::vector<UdpReadCallback::Packet> OnePlatformLoopIteration(
- EventWaiterPtr waiter) {
+std::vector<UdpPacket> OnePlatformLoopIteration(EventWaiterPtr waiter) {
ErrorOr<Events> events = WaitForEvents(waiter);
if (!events)
return {};
diff --git a/platform/impl/event_loop.h b/platform/impl/event_loop.h
index 2988b57..05bf458 100644
--- a/platform/impl/event_loop.h
+++ b/platform/impl/event_loop.h
@@ -16,12 +16,8 @@
namespace openscreen {
namespace platform {
-Error ReceiveDataFromEvent(const UdpSocketReadableEvent& read_event,
- UdpReadCallback::Packet* data);
-std::vector<UdpReadCallback::Packet> HandleUdpSocketReadEvents(
- const Events& events);
-std::vector<UdpReadCallback::Packet> OnePlatformLoopIteration(
- EventWaiterPtr waiter);
+std::vector<UdpPacket> HandleUdpSocketReadEvents(const Events& events);
+std::vector<UdpPacket> OnePlatformLoopIteration(EventWaiterPtr waiter);
} // namespace platform
} // namespace openscreen
diff --git a/platform/impl/network_reader.cc b/platform/impl/network_reader.cc
index 9eccc92..4a94705 100644
--- a/platform/impl/network_reader.cc
+++ b/platform/impl/network_reader.cc
@@ -11,24 +11,6 @@
namespace openscreen {
namespace platform {
-namespace {
-
-class ReadCallbackExecutor {
- public:
- ReadCallbackExecutor(std::unique_ptr<UdpReadCallback::Packet> data,
- NetworkReader::Callback function)
- : function_(function) {
- data_ = std::move(data);
- }
-
- void operator()() { function_(std::move(data_)); }
-
- private:
- std::unique_ptr<UdpReadCallback::Packet> data_;
- NetworkReader::Callback function_;
-};
-
-} // namespace
NetworkReader::NetworkReader(TaskRunner* task_runner)
: NetworkReader(task_runner, NetworkWaiter::Create()) {}
@@ -85,39 +67,28 @@
continue;
}
- ErrorOr<std::unique_ptr<UdpReadCallback::Packet>> read_packet =
- ReadFromSocket(mapped_socket->first);
+ ErrorOr<UdpPacket> read_packet = mapped_socket->first->ReceiveMessage();
if (read_packet.is_error()) {
error = read_packet.error();
continue;
}
- // FIXME: Investigate removing ReadCallbackExecutor.
- auto task =
- ReadCallbackExecutor(read_packet.MoveValue(), mapped_socket->second);
- task_runner_->PostTask(std::move(task));
+ // Capture the UdpPacket by move into |arg| here to transfer the ownership
+ // and avoid copying the UdpPacket. This move constructs the UdpPacket
+ // inside of the lambda. Then the UdpPacket |arg| is passed by move to the
+ // callback function |func|.
+ auto executor = [arg = read_packet.MoveValue(),
+ func = mapped_socket->second]() mutable {
+ func(std::move(arg));
+ };
+
+ task_runner_->PostTask(std::move(executor));
}
}
return error;
}
-ErrorOr<std::unique_ptr<UdpReadCallback::Packet>> NetworkReader::ReadFromSocket(
- UdpSocket* socket) {
- // TODO(rwkeane): Use circular buffer in Socket instead of new packet.
- auto data = std::make_unique<UdpReadCallback::Packet>();
- ErrorOr<size_t> read_bytes = socket->ReceiveMessage(
- &(*data)[0], data->size(), &data->source, &data->original_destination);
- if (read_bytes.is_error()) {
- return read_bytes.error();
- }
-
- data->socket = socket;
- data->length = read_bytes.value();
-
- return data;
-}
-
void NetworkReader::RunUntilStopped() {
const bool was_running = is_running_.exchange(true);
OSP_CHECK(!was_running);
diff --git a/platform/impl/network_reader.h b/platform/impl/network_reader.h
index 21e04ae..4d714b3 100644
--- a/platform/impl/network_reader.h
+++ b/platform/impl/network_reader.h
@@ -23,8 +23,7 @@
class NetworkReader {
public:
// Create a type for readability
- using Callback =
- std::function<void(std::unique_ptr<UdpReadCallback::Packet>)>;
+ using Callback = std::function<void(UdpPacket)>;
// Creates a new instance of this object.
// NOTE: The provided TaskRunner must be running and must live for the
@@ -61,12 +60,6 @@
// duration of this instance's life.
NetworkReader(TaskRunner* task_runner, std::unique_ptr<NetworkWaiter> waiter);
- // Method to read data from a socket. This method will not block, but is only
- // expected to be called by WaitAndRead when it detects that a socket has
- // data waiting to be read.
- virtual ErrorOr<std::unique_ptr<UdpReadCallback::Packet>> ReadFromSocket(
- UdpSocket* socket);
-
// Waits for any writes to occur or for timeout to pass, whichever is sooner.
// If an error occurs when calling WaitAndRead, then no callbacks will have
// been called during the method's execution, but it is still safe to
diff --git a/platform/impl/network_reader_unittest.cc b/platform/impl/network_reader_unittest.cc
index 303c2cc..e1dcf8e 100644
--- a/platform/impl/network_reader_unittest.cc
+++ b/platform/impl/network_reader_unittest.cc
@@ -60,28 +60,19 @@
// Public method to call wait, since usually this method is internally
// callable only.
Error WaitTesting(Clock::duration timeout) { return WaitAndRead(timeout); }
-
- MOCK_METHOD1(
- ReadFromSocket,
- ErrorOr<std::unique_ptr<UdpReadCallback::Packet>>(UdpSocket* socket));
};
class MockCallbacks {
public:
- std::function<void(std::unique_ptr<UdpReadCallback::Packet>)>
- GetReadCallback() {
- return [this](std::unique_ptr<UdpReadCallback::Packet> packet) {
- this->ReadCallback(std::move(packet));
- };
+ std::function<void(UdpPacket)> GetReadCallback() {
+ return [this](UdpPacket packet) { this->ReadCallback(std::move(packet)); };
}
std::function<void()> GetWriteCallback() {
return [this]() { this->WriteCallback(); };
}
- void ReadCallback(std::unique_ptr<UdpReadCallback::Packet> packet) {
- ReadCallbackInternal();
- }
+ void ReadCallback(UdpPacket packet) { ReadCallbackInternal(); }
MOCK_METHOD0(ReadCallbackInternal, void());
MOCK_METHOD0(WriteCallback, void());
@@ -194,7 +185,7 @@
TestingNetworkWaiter network_waiter(std::move(mock_waiter),
task_runner.get());
auto timeout = Clock::duration(0);
- auto packet = std::make_unique<UdpReadCallback::Packet>();
+ UdpPacket packet;
MockCallbacks callbacks;
network_waiter.ReadRepeatedly(socket.get(), callbacks.GetReadCallback());
@@ -217,27 +208,26 @@
std::unique_ptr<NetworkWaiter>(mock_waiter_ptr);
std::unique_ptr<TaskRunner> task_runner =
std::unique_ptr<TaskRunner>(task_runner_ptr);
- std::unique_ptr<MockUdpSocket> socket =
- std::make_unique<MockUdpSocket>(UdpSocket::Version::kV4);
+ MockUdpSocket socket(UdpSocket::Version::kV4);
TestingNetworkWaiter network_waiter(std::move(mock_waiter),
task_runner.get());
auto timeout = Clock::duration(0);
- auto packet = std::make_unique<UdpReadCallback::Packet>();
+ UdpPacket packet;
MockCallbacks callbacks;
- network_waiter.ReadRepeatedly(socket.get(), callbacks.GetReadCallback());
+ network_waiter.ReadRepeatedly(&socket, callbacks.GetReadCallback());
EXPECT_CALL(*mock_waiter_ptr, AwaitSocketsReadable(_, timeout))
- .WillOnce(Return(ByMove(std::vector<UdpSocket*>{socket.get()})));
+ .WillOnce(Return(ByMove(std::vector<UdpSocket*>{&socket})));
EXPECT_CALL(callbacks, ReadCallbackInternal()).Times(1);
- EXPECT_CALL(network_waiter, ReadFromSocket(socket.get()))
+ EXPECT_CALL(socket, ReceiveMessage())
.WillOnce(Return(ByMove(std::move(packet))));
EXPECT_EQ(network_waiter.WaitTesting(timeout), Error::Code::kNone);
EXPECT_EQ(task_runner_ptr->tasks_posted, uint32_t{1});
// Set deletion callback because otherwise the destructor tries to call a
// callback on the deleted object when it goes out of scope.
- socket->SetDeletionCallback([](UdpSocket* socket) {});
+ socket.SetDeletionCallback([](UdpSocket* socket) {});
}
TEST(NetworkReaderTest, WaitFailsIfReadingSocketFails) {
@@ -246,27 +236,25 @@
std::unique_ptr<NetworkWaiter>(mock_waiter_ptr);
std::unique_ptr<TaskRunner> task_runner =
std::unique_ptr<TaskRunner>(new MockTaskRunner());
- std::unique_ptr<MockUdpSocket> socket =
- std::make_unique<MockUdpSocket>(UdpSocket::Version::kV4);
+ MockUdpSocket socket(UdpSocket::Version::kV4);
TestingNetworkWaiter network_waiter(std::move(mock_waiter),
task_runner.get());
auto timeout = Clock::duration(0);
- auto packet = std::make_unique<UdpReadCallback::Packet>();
MockCallbacks callbacks;
- network_waiter.ReadRepeatedly(socket.get(), callbacks.GetReadCallback());
+ network_waiter.ReadRepeatedly(&socket, callbacks.GetReadCallback());
EXPECT_CALL(*mock_waiter_ptr, AwaitSocketsReadable(_, timeout))
- .WillOnce(Return(ByMove(std::vector<UdpSocket*>{socket.get()})));
+ .WillOnce(Return(ByMove(std::vector<UdpSocket*>{&socket})));
EXPECT_CALL(callbacks, ReadCallbackInternal()).Times(0);
- EXPECT_CALL(network_waiter, ReadFromSocket(socket.get()))
+ EXPECT_CALL(socket, ReceiveMessage())
.WillOnce(Return(ByMove(Error::Code::kGenericPlatformError)));
EXPECT_EQ(network_waiter.WaitTesting(timeout),
Error::Code::kGenericPlatformError);
// Set deletion callback because otherwise the destructor tries to call a
// callback on the deleted object when it goes out of scope.
- socket->SetDeletionCallback([](UdpSocket* socket) {});
+ socket.SetDeletionCallback([](UdpSocket* socket) {});
}
} // namespace platform
diff --git a/platform/impl/network_runner.cc b/platform/impl/network_runner.cc
index ca260ba..c090c6b 100644
--- a/platform/impl/network_runner.cc
+++ b/platform/impl/network_runner.cc
@@ -24,10 +24,9 @@
Error NetworkRunnerImpl::ReadRepeatedly(UdpSocket* socket,
UdpReadCallback* callback) {
- NetworkReader::Callback func =
- [callback, this](std::unique_ptr<UdpReadCallback::Packet> packet) {
- callback->OnRead(std::move(packet), this);
- };
+ NetworkReader::Callback func = [callback, this](UdpPacket packet) {
+ callback->OnRead(std::move(packet), this);
+ };
return network_loop_->ReadRepeatedly(socket, func);
}
diff --git a/platform/impl/udp_socket_posix.cc b/platform/impl/udp_socket_posix.cc
index 6134b48..f4263e7 100644
--- a/platform/impl/udp_socket_posix.cc
+++ b/platform/impl/udp_socket_posix.cc
@@ -8,6 +8,7 @@
#include <fcntl.h>
#include <netinet/in.h>
#include <netinet/ip.h>
+#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
@@ -226,143 +227,126 @@
return Error(hard_error_code, strerror(errno));
}
-} // namespace
-
-ErrorOr<size_t> UdpSocketPosix::ReceiveMessage(
- void* data,
- size_t length,
- IPEndpoint* src,
- IPEndpoint* original_destination) {
- struct iovec iov = {data, length};
- char control_buf[1024];
- size_t cmsg_size = sizeof(control_buf) - sizeof(struct cmsghdr) + 1;
- void* cmsg_buf = control_buf;
- std::align(alignof(struct cmsghdr), sizeof(cmsg_buf), cmsg_buf, cmsg_size);
- switch (version_) {
- case UdpSocket::Version::kV4: {
- struct sockaddr_in sa;
- struct msghdr msg;
- msg.msg_name = &sa;
- msg.msg_namelen = sizeof(sa);
- msg.msg_iov = &iov;
- msg.msg_iovlen = 1;
- msg.msg_control = cmsg_buf;
- msg.msg_controllen = cmsg_size;
- msg.msg_flags = 0;
-
- ssize_t num_bytes_received = recvmsg(fd_, &msg, 0);
- if (num_bytes_received == -1) {
- return ChooseError(errno, Error::Code::kSocketReadFailure);
- }
- OSP_DCHECK_GE(num_bytes_received, 0);
-
- if (src) {
- src->address =
- IPAddress(IPAddress::Version::kV4,
- reinterpret_cast<const uint8_t*>(&sa.sin_addr.s_addr));
- src->port = ntohs(sa.sin_port);
- }
-
- // For multicast sockets, the packet's original destination address may be
- // the host address (since we called bind()) but it may also be a
- // multicast address. This may be relevant for handling multicast data;
- // specifically, mDNSResponder requires this information to work properly.
- if (original_destination) {
- *original_destination = IPEndpoint{{}, 0};
- if ((msg.msg_flags & MSG_CTRUNC) == 0) {
- for (struct cmsghdr* cmh = CMSG_FIRSTHDR(&msg); cmh;
- cmh = CMSG_NXTHDR(&msg, cmh)) {
- if (cmh->cmsg_level != IPPROTO_IP || cmh->cmsg_type != IP_PKTINFO)
- continue;
-
- struct sockaddr_in addr;
- socklen_t addr_len = sizeof(addr);
- if (getsockname(fd_, reinterpret_cast<struct sockaddr*>(&addr),
- &addr_len) == -1) {
- break;
- }
- // |original_destination->port| will be 0 if this line isn't
- // reached.
- original_destination->port = ntohs(addr.sin_port);
-
- struct in_pktinfo* pktinfo =
- reinterpret_cast<struct in_pktinfo*>(CMSG_DATA(cmh));
- original_destination->address =
- IPAddress(IPAddress::Version::kV4,
- reinterpret_cast<const uint8_t*>(&pktinfo->ipi_addr));
- break;
- }
- }
- }
-
- return num_bytes_received;
- }
-
- case UdpSocket::Version::kV6: {
- struct sockaddr_in6 sa;
- struct msghdr msg;
- msg.msg_name = &sa;
- msg.msg_namelen = sizeof(sa);
- msg.msg_iov = &iov;
- msg.msg_iovlen = 1;
- msg.msg_control = cmsg_buf;
- msg.msg_controllen = cmsg_size;
- msg.msg_flags = 0;
-
- ssize_t num_bytes_received = recvmsg(fd_, &msg, 0);
- if (num_bytes_received == -1) {
- return ChooseError(errno, Error::Code::kSocketReadFailure);
- }
- OSP_DCHECK_GE(num_bytes_received, 0);
-
- if (src) {
- src->address =
- IPAddress(IPAddress::Version::kV6,
- reinterpret_cast<const uint8_t*>(&sa.sin6_addr.s6_addr));
- src->port = ntohs(sa.sin6_port);
- }
-
- // For multicast sockets, the packet's original destination address may be
- // the host address (since we called bind()) but it may also be a
- // multicast address. This may be relevant for handling multicast data;
- // specifically, mDNSResponder requires this information to work properly.
- if (original_destination) {
- *original_destination = IPEndpoint{{}, 0};
- if ((msg.msg_flags & MSG_CTRUNC) == 0) {
- for (struct cmsghdr* cmh = CMSG_FIRSTHDR(&msg); cmh;
- cmh = CMSG_NXTHDR(&msg, cmh)) {
- if (cmh->cmsg_level != IPPROTO_IPV6 ||
- cmh->cmsg_type != IPV6_PKTINFO) {
- continue;
- }
- struct sockaddr_in6 addr;
- socklen_t addr_len = sizeof(addr);
- if (getsockname(fd_, reinterpret_cast<struct sockaddr*>(&addr),
- &addr_len) == -1) {
- break;
- }
- // |original_destination->port| will be 0 if this line isn't
- // reached.
- original_destination->port = ntohs(addr.sin6_port);
-
- struct in6_pktinfo* pktinfo =
- reinterpret_cast<struct in6_pktinfo*>(CMSG_DATA(cmh));
- original_destination->address = IPAddress(
- IPAddress::Version::kV6,
- reinterpret_cast<const uint8_t*>(&pktinfo->ipi6_addr));
- break;
- }
- }
- }
-
- return num_bytes_received;
- }
- }
-
- OSP_NOTREACHED();
- return Error::Code::kGenericPlatformError;
+IPAddress GetIPAddressFromSockAddr(const sockaddr_in& sa) {
+ static_assert(IPAddress::kV4Size == sizeof(sa.sin_addr.s_addr),
+ "IPv4 address size mismatch.");
+ return IPAddress(IPAddress::Version::kV4,
+ reinterpret_cast<const uint8_t*>(&sa.sin_addr.s_addr));
}
+IPAddress GetIPAddressFromPktInfo(const in_pktinfo& pktinfo) {
+ static_assert(IPAddress::kV4Size == sizeof(pktinfo.ipi_addr),
+ "IPv4 address size mismatch.");
+ return IPAddress(IPAddress::Version::kV4,
+ reinterpret_cast<const uint8_t*>(&pktinfo.ipi_addr));
+}
+
+uint16_t GetPortFromFromSockAddr(const sockaddr_in& sa) {
+ return ntohs(sa.sin_port);
+}
+
+IPAddress GetIPAddressFromSockAddr(const sockaddr_in6& sa) {
+ return IPAddress(sa.sin6_addr.s6_addr);
+}
+
+IPAddress GetIPAddressFromPktInfo(const in6_pktinfo& pktinfo) {
+ return IPAddress(pktinfo.ipi6_addr.s6_addr);
+}
+
+uint16_t GetPortFromFromSockAddr(const sockaddr_in6& sa) {
+ return ntohs(sa.sin6_port);
+}
+
+template <class PktInfoType>
+bool IsPacketInfo(cmsghdr* cmh);
+
+template <>
+bool IsPacketInfo<in_pktinfo>(cmsghdr* cmh) {
+ return cmh->cmsg_level == IPPROTO_IP && cmh->cmsg_type == IP_PKTINFO;
+}
+
+template <>
+bool IsPacketInfo<in6_pktinfo>(cmsghdr* cmh) {
+ return cmh->cmsg_level == IPPROTO_IPV6 && cmh->cmsg_type == IPV6_PKTINFO;
+}
+
+template <class SockAddrType, class PktInfoType>
+Error ReceiveMessageInternal(int fd, UdpPacket* packet) {
+ SockAddrType sa;
+ iovec iov = {packet->data(), packet->size()};
+ alignas(alignof(cmsghdr)) uint8_t control_buffer[1024];
+ msghdr msg;
+ msg.msg_name = &sa;
+ msg.msg_namelen = sizeof(sa);
+ msg.msg_iov = &iov;
+ msg.msg_iovlen = 1;
+ msg.msg_control = control_buffer;
+ msg.msg_controllen = sizeof(control_buffer);
+ msg.msg_flags = 0;
+
+ ssize_t bytes_received = recvmsg(fd, &msg, 0);
+ if (bytes_received == -1) {
+ return ChooseError(errno, Error::Code::kSocketReadFailure);
+ }
+
+ OSP_DCHECK_EQ(static_cast<size_t>(bytes_received), packet->size());
+
+ IPEndpoint source_endpoint = {.address = GetIPAddressFromSockAddr(sa),
+ .port = GetPortFromFromSockAddr(sa)};
+ packet->set_source(std::move(source_endpoint));
+
+ // For multicast sockets, the packet's original destination address may be
+ // the host address (since we called bind()) but it may also be a
+ // multicast address. This may be relevant for handling multicast data;
+ // specifically, mDNSResponder requires this information to work properly.
+
+ socklen_t sa_len = sizeof(sa);
+ if (((msg.msg_flags & MSG_CTRUNC) != 0) ||
+ (getsockname(fd, reinterpret_cast<sockaddr*>(&sa), &sa_len) == -1)) {
+ return Error::Code::kNone;
+ }
+ for (cmsghdr* cmh = CMSG_FIRSTHDR(&msg); cmh; cmh = CMSG_NXTHDR(&msg, cmh)) {
+ if (IsPacketInfo<PktInfoType>(cmh)) {
+ PktInfoType* pktinfo = reinterpret_cast<PktInfoType*>(CMSG_DATA(cmh));
+ IPEndpoint destination_endpoint = {
+ .address = GetIPAddressFromPktInfo(*pktinfo),
+ .port = GetPortFromFromSockAddr(sa)};
+ packet->set_destination(std::move(destination_endpoint));
+ break;
+ }
+ }
+ return Error::Code::kNone;
+}
+
+} // namespace
+
+ErrorOr<UdpPacket> UdpSocketPosix::ReceiveMessage() {
+ ssize_t bytes_available = recv(fd_, nullptr, 0, MSG_PEEK | MSG_TRUNC);
+ if (bytes_available == -1) {
+ return ChooseError(errno, Error::Code::kSocketReadFailure);
+ }
+ UdpPacket packet(bytes_available);
+ packet.set_socket(this);
+ Error result = Error::Code::kGenericPlatformError;
+ switch (version_) {
+ case UdpSocket::Version::kV4: {
+ result = ReceiveMessageInternal<sockaddr_in, in_pktinfo>(fd_, &packet);
+ break;
+ }
+ case UdpSocket::Version::kV6: {
+ result = ReceiveMessageInternal<sockaddr_in6, in6_pktinfo>(fd_, &packet);
+ break;
+ }
+ default: {
+ OSP_NOTREACHED();
+ }
+ }
+ return result.ok() ? ErrorOr<UdpPacket>(std::move(packet))
+ : ErrorOr<UdpPacket>(std::move(result));
+}
+
+// TODO(yakimakha): Consider changing the interface to accept UdpPacket as
+// an input parameter
Error UdpSocketPosix::SendMessage(const void* data,
size_t length,
const IPEndpoint& dest) {
diff --git a/platform/impl/udp_socket_posix.h b/platform/impl/udp_socket_posix.h
index f5bc63d..17983b1 100644
--- a/platform/impl/udp_socket_posix.h
+++ b/platform/impl/udp_socket_posix.h
@@ -22,10 +22,7 @@
Error SetMulticastOutboundInterface(NetworkInterfaceIndex ifindex) final;
Error JoinMulticastGroup(const IPAddress& address,
NetworkInterfaceIndex ifindex) final;
- ErrorOr<size_t> ReceiveMessage(void* data,
- size_t length,
- IPEndpoint* src,
- IPEndpoint* original_destination) final;
+ ErrorOr<UdpPacket> ReceiveMessage() final;
Error SendMessage(const void* data,
size_t length,
const IPEndpoint& dest) final;
diff --git a/platform/test/mock_udp_socket.h b/platform/test/mock_udp_socket.h
index 90ed352..8fbdb87 100644
--- a/platform/test/mock_udp_socket.h
+++ b/platform/test/mock_udp_socket.h
@@ -27,8 +27,7 @@
MOCK_METHOD1(SetMulticastOutboundInterface, Error(NetworkInterfaceIndex));
MOCK_METHOD2(JoinMulticastGroup,
Error(const IPAddress&, NetworkInterfaceIndex));
- MOCK_METHOD4(ReceiveMessage,
- ErrorOr<size_t>(void*, size_t, IPEndpoint*, IPEndpoint*));
+ MOCK_METHOD0(ReceiveMessage, ErrorOr<UdpPacket>());
MOCK_METHOD3(SendMessage, Error(const void*, size_t, const IPEndpoint&));
MOCK_METHOD1(SetDscp, Error(DscpMode));