Change UDP socket ReceiveMessage to perform peek first and then create a packet of the correct size.

Change-Id: Ia16538141d821744afcf616f615b43768d6af645
Reviewed-on: https://chromium-review.googlesource.com/c/openscreen/+/1702868
Commit-Queue: Max Yakimakha <yakimakha@chromium.org>
Reviewed-by: mark a. foltz <mfoltz@chromium.org>
Reviewed-by: Ryan Keane <rwkeane@google.com>
diff --git a/osp/impl/discovery/mdns/embedder_demo.cc b/osp/impl/discovery/mdns/embedder_demo.cc
index 3cab141..2603dd5 100644
--- a/osp/impl/discovery/mdns/embedder_demo.cc
+++ b/osp/impl/discovery/mdns/embedder_demo.cc
@@ -312,8 +312,9 @@
     mdns_adapter->RunTasks();
     auto data = platform::OnePlatformLoopIteration(waiter);
     for (auto& packet : data) {
-      mdns_adapter->OnDataReceived(packet.source, packet.original_destination,
-                                   packet.data(), packet.length, packet.socket);
+      mdns_adapter->OnDataReceived(packet.source(), packet.destination(),
+                                   packet.data(), packet.size(),
+                                   packet.socket());
     }
   }
   OSP_LOG << "num services: " << g_services->size();
diff --git a/osp/impl/mdns_responder_service.cc b/osp/impl/mdns_responder_service.cc
index b688577..15836fb 100644
--- a/osp/impl/mdns_responder_service.cc
+++ b/osp/impl/mdns_responder_service.cc
@@ -56,13 +56,13 @@
 }
 
 void MdnsResponderService::HandleNewEvents(
-    const std::vector<platform::UdpReadCallback::Packet>& data) {
+    const std::vector<platform::UdpPacket>& packets) {
   if (!mdns_responder_)
     return;
-  for (auto& packet : data) {
-    mdns_responder_->OnDataReceived(packet.source, packet.original_destination,
-                                    packet.data(), packet.length,
-                                    packet.socket);
+  for (auto& packet : packets) {
+    mdns_responder_->OnDataReceived(packet.source(), packet.destination(),
+                                    packet.data(), packet.size(),
+                                    packet.socket());
   }
   mdns_responder_->RunTasks();
 
diff --git a/osp/impl/mdns_responder_service.h b/osp/impl/mdns_responder_service.h
index 996607b..d9d8a62 100644
--- a/osp/impl/mdns_responder_service.h
+++ b/osp/impl/mdns_responder_service.h
@@ -46,8 +46,7 @@
       const std::vector<platform::NetworkInterfaceIndex> whitelist,
       const std::map<std::string, std::string>& txt_data);
 
-  void HandleNewEvents(
-      const std::vector<platform::UdpReadCallback::Packet>& data);
+  void HandleNewEvents(const std::vector<platform::UdpPacket>& packets);
 
   // ServiceListenerImpl::Delegate overrides.
   void StartListener() override;
diff --git a/osp/impl/quic/quic_connection.h b/osp/impl/quic/quic_connection.h
index 8f3c82a..069b06a 100644
--- a/osp/impl/quic/quic_connection.h
+++ b/osp/impl/quic/quic_connection.h
@@ -71,8 +71,7 @@
   // Passes a received UDP packet to the QUIC implementation.  If this contains
   // any stream data, it will be passed automatically to the relevant
   // QuicStream::Delegate objects.
-  virtual void OnDataReceived(
-      const platform::UdpReadCallback::Packet& data) = 0;
+  virtual void OnDataReceived(const platform::UdpPacket& packet) = 0;
 
   virtual std::unique_ptr<QuicStream> MakeOutgoingStream(
       QuicStream::Delegate* delegate) = 0;
diff --git a/osp/impl/quic/quic_connection_factory_impl.cc b/osp/impl/quic/quic_connection_factory_impl.cc
index e3edf25..b176d24 100644
--- a/osp/impl/quic/quic_connection_factory_impl.cc
+++ b/osp/impl/quic/quic_connection_factory_impl.cc
@@ -129,34 +129,35 @@
     // QuicConnectionFactoryImpl.
     OSP_DCHECK(std::find_if(sockets_.begin(), sockets_.end(),
                             [&packet](const platform::UdpSocketUniquePtr& s) {
-                              return s.get() == packet.socket;
+                              return s.get() == packet.socket();
                             }) != sockets_.end());
 
     // TODO(btolsch): We will need to rethink this both for ICE and connection
     // migration support.
-    auto conn_it = connections_.find(packet.source);
+    auto conn_it = connections_.find(packet.source());
     if (conn_it == connections_.end()) {
       if (server_delegate_) {
-        OSP_VLOG << __func__ << ": spawning connection from " << packet.source;
+        OSP_VLOG << __func__ << ": spawning connection from "
+                 << packet.source();
         auto transport =
-            std::make_unique<UdpTransport>(packet.socket, packet.source);
+            std::make_unique<UdpTransport>(packet.socket(), packet.source());
         ::quic::QuartcSessionConfig session_config;
         session_config.perspective = ::quic::Perspective::IS_SERVER;
         session_config.packet_transport = transport.get();
 
         auto result = std::make_unique<QuicConnectionImpl>(
-            this, server_delegate_->NextConnectionDelegate(packet.source),
+            this, server_delegate_->NextConnectionDelegate(packet.source()),
             std::move(transport),
             quartc_factory_->CreateQuartcSession(session_config));
         auto* result_ptr = result.get();
-        connections_.emplace(packet.source,
-                             OpenConnection{result_ptr, packet.socket});
+        connections_.emplace(packet.source(),
+                             OpenConnection{result_ptr, packet.socket()});
         server_delegate_->OnIncomingConnection(std::move(result));
         result_ptr->OnDataReceived(packet);
       }
     } else {
       OSP_VLOG << __func__ << ": data for existing connection from "
-               << packet.source;
+               << packet.source();
       conn_it->second.connection->OnDataReceived(packet);
     }
   }
diff --git a/osp/impl/quic/quic_connection_impl.cc b/osp/impl/quic/quic_connection_impl.cc
index d56f91c..f07c5bf 100644
--- a/osp/impl/quic/quic_connection_impl.cc
+++ b/osp/impl/quic/quic_connection_impl.cc
@@ -89,10 +89,9 @@
 
 QuicConnectionImpl::~QuicConnectionImpl() = default;
 
-void QuicConnectionImpl::OnDataReceived(
-    const platform::UdpReadCallback::Packet& data) {
-  session_->OnTransportReceived(reinterpret_cast<const char*>(data.data()),
-                                data.length);
+void QuicConnectionImpl::OnDataReceived(const platform::UdpPacket& packet) {
+  session_->OnTransportReceived(reinterpret_cast<const char*>(packet.data()),
+                                packet.size());
 }
 
 std::unique_ptr<QuicStream> QuicConnectionImpl::MakeOutgoingStream(
diff --git a/osp/impl/quic/quic_connection_impl.h b/osp/impl/quic/quic_connection_impl.h
index 8e39d63..2288faf 100644
--- a/osp/impl/quic/quic_connection_impl.h
+++ b/osp/impl/quic/quic_connection_impl.h
@@ -75,7 +75,7 @@
   ~QuicConnectionImpl() override;
 
   // QuicConnection overrides.
-  void OnDataReceived(const platform::UdpReadCallback::Packet& data) override;
+  void OnDataReceived(const platform::UdpPacket& packet) override;
   std::unique_ptr<QuicStream> MakeOutgoingStream(
       QuicStream::Delegate* delegate) override;
   void Close() override;
diff --git a/osp/impl/quic/testing/fake_quic_connection.cc b/osp/impl/quic/testing/fake_quic_connection.cc
index c57b5df..06b8279 100644
--- a/osp/impl/quic/testing/fake_quic_connection.cc
+++ b/osp/impl/quic/testing/fake_quic_connection.cc
@@ -60,8 +60,7 @@
   return result;
 }
 
-void FakeQuicConnection::OnDataReceived(
-    const platform::UdpReadCallback::Packet& data) {
+void FakeQuicConnection::OnDataReceived(const platform::UdpPacket& packet) {
   OSP_DCHECK(false) << "data should go directly to fake streams";
 }
 
diff --git a/osp/impl/quic/testing/fake_quic_connection.h b/osp/impl/quic/testing/fake_quic_connection.h
index 9186a8e..90833c4 100644
--- a/osp/impl/quic/testing/fake_quic_connection.h
+++ b/osp/impl/quic/testing/fake_quic_connection.h
@@ -57,7 +57,7 @@
   std::unique_ptr<FakeQuicStream> MakeIncomingStream();
 
   // QuicConnection overrides.
-  void OnDataReceived(const platform::UdpReadCallback::Packet& data) override;
+  void OnDataReceived(const platform::UdpPacket& packet) override;
   std::unique_ptr<QuicStream> MakeOutgoingStream(
       QuicStream::Delegate* delegate) override;
   void Close() override;
diff --git a/platform/BUILD.gn b/platform/BUILD.gn
index e53e7a0..e891b1b 100644
--- a/platform/BUILD.gn
+++ b/platform/BUILD.gn
@@ -28,9 +28,10 @@
     "api/trace_logging_platform.cc",
     "api/trace_logging_platform.h",
     "api/trace_logging_types.h",
+    "api/udp_packet.h",
+    "api/udp_read_callback.h",
     "api/udp_socket.cc",
     "api/udp_socket.h",
-    "api/upd_read_callback.h",
     "base/error.cc",
     "base/error.h",
     "base/ip_address.cc",
diff --git a/platform/api/udp_packet.h b/platform/api/udp_packet.h
new file mode 100644
index 0000000..a05bdb7
--- /dev/null
+++ b/platform/api/udp_packet.h
@@ -0,0 +1,47 @@
+// Copyright 2019 The Chromium Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style license that can be
+// found in the LICENSE file
+
+#ifndef PLATFORM_API_UDP_PACKET_H_
+#define PLATFORM_API_UDP_PACKET_H_
+
+#include <vector>
+
+#include "platform/api/logging.h"
+#include "platform/base/ip_address.h"
+
+namespace openscreen {
+namespace platform {
+
+class UdpSocket;
+
+static constexpr size_t kUdpMaxPacketSize = 1 << 16;
+
+class UdpPacket : public std::vector<uint8_t> {
+ public:
+  explicit UdpPacket(size_t size) : std::vector<uint8_t>(size) {
+    OSP_DCHECK(size <= kUdpMaxPacketSize);
+  };
+  UdpPacket() : UdpPacket(0){};
+
+  const IPEndpoint& source() const { return source_; }
+  void set_source(IPEndpoint endpoint) { source_ = std::move(endpoint); }
+
+  const IPEndpoint& destination() const { return destination_; }
+  void set_destination(IPEndpoint endpoint) {
+    destination_ = std::move(endpoint);
+  }
+
+  UdpSocket* socket() const { return socket_; }
+  void set_socket(UdpSocket* socket) { socket_ = socket; }
+
+ private:
+  IPEndpoint source_ = {};
+  IPEndpoint destination_ = {};
+  UdpSocket* socket_ = nullptr;
+};
+
+}  // namespace platform
+}  // namespace openscreen
+
+#endif  // PLATFORM_API_UDP_PACKET_H_
\ No newline at end of file
diff --git a/platform/api/udp_read_callback.h b/platform/api/udp_read_callback.h
index f69efd4..4a68575 100644
--- a/platform/api/udp_read_callback.h
+++ b/platform/api/udp_read_callback.h
@@ -5,39 +5,20 @@
 #ifndef PLATFORM_API_UDP_READ_CALLBACK_H_
 #define PLATFORM_API_UDP_READ_CALLBACK_H_
 
-#include <array>
-#include <cstdint>
-#include <memory>
-
-#include "platform/base/ip_address.h"
+#include "platform/api/udp_packet.h"
 
 namespace openscreen {
 namespace platform {
 
 class NetworkRunner;
-class UdpSocket;
-
-static constexpr int kUdpMaxPacketSize = 1 << 16;
 
 class UdpReadCallback {
  public:
-  struct Packet : std::array<uint8_t, kUdpMaxPacketSize> {
-    Packet() = default;
-    ~Packet() = default;
-
-    IPEndpoint source;
-    IPEndpoint original_destination;
-    ssize_t length;
-    // TODO(btolsch): When this gets to implementation, make sure the callback
-    // is never called with a |socket| that could have been destroyed (e.g.
-    // between queueing the read data and running the task).
-    UdpSocket* socket;
-  };
-
   virtual ~UdpReadCallback() = default;
-
-  virtual void OnRead(std::unique_ptr<Packet> data,
-                      NetworkRunner* network_runner) = 0;
+  // TODO(btolsch): When this gets to implementation, make sure the callback
+  // is never called with a |packet| from a socket that could have been
+  // destroyed (e.g. between queueing the read data and running the task).
+  virtual void OnRead(UdpPacket packet, NetworkRunner* network_runner) = 0;
 };
 
 }  // namespace platform
diff --git a/platform/api/udp_socket.h b/platform/api/udp_socket.h
index 7ac9035..e060d41 100644
--- a/platform/api/udp_socket.h
+++ b/platform/api/udp_socket.h
@@ -10,6 +10,7 @@
 #include <memory>
 
 #include "platform/api/network_interface.h"
+#include "platform/api/udp_read_callback.h"
 #include "platform/base/error.h"
 #include "platform/base/ip_address.h"
 #include "platform/base/macros.h"
@@ -79,13 +80,8 @@
   // received. Note that a non-Error return value of 0 is a valid result,
   // indicating an empty message has been received. Also note that
   // Error::Code::kAgain might be returned if there is no message currently
-  // ready for receive, which can be expected during normal operation. |src| and
-  // |original_destination| are optional output arguments that provide the
-  // source of the message and its intended destination, respectively.
-  virtual ErrorOr<size_t> ReceiveMessage(void* data,
-                                         size_t length,
-                                         IPEndpoint* src,
-                                         IPEndpoint* original_destination) = 0;
+  // ready for receive, which can be expected during normal operation.
+  virtual ErrorOr<UdpPacket> ReceiveMessage() = 0;
 
   // Sends a message and returns the number of bytes sent, on success.
   // Error::Code::kAgain might be returned to indicate the operation would
diff --git a/platform/impl/event_loop.cc b/platform/impl/event_loop.cc
index 6899ce4..05454c9 100644
--- a/platform/impl/event_loop.cc
+++ b/platform/impl/event_loop.cc
@@ -12,35 +12,21 @@
 namespace openscreen {
 namespace platform {
 
-Error ReceiveDataFromEvent(const UdpSocketReadableEvent& read_event,
-                           UdpReadCallback::Packet* data) {
-  OSP_DCHECK(data);
-  ErrorOr<size_t> len = read_event.socket->ReceiveMessage(
-      &data[0], data->size(), &data->source, &data->original_destination);
-  if (!len) {
-    OSP_LOG_ERROR << "ReceiveMessage() on socket failed: "
-                  << len.error().message();
-    return len.error();
-  }
-  OSP_DCHECK_LE(len.value(), static_cast<size_t>(kUdpMaxPacketSize));
-  data->length = len.value();
-  data->socket = read_event.socket;
-  return Error::None();
-}
-
-std::vector<UdpReadCallback::Packet> HandleUdpSocketReadEvents(
-    const Events& events) {
-  std::vector<UdpReadCallback::Packet> data;
+std::vector<UdpPacket> HandleUdpSocketReadEvents(const Events& events) {
+  std::vector<UdpPacket> packets(events.udp_readable_events.size());
   for (const auto& read_event : events.udp_readable_events) {
-    UdpReadCallback::Packet next_data;
-    if (ReceiveDataFromEvent(read_event, &next_data).ok())
-      data.emplace_back(std::move(next_data));
+    ErrorOr<UdpPacket> result = read_event.socket->ReceiveMessage();
+    if (result) {
+      packets.emplace_back(result.MoveValue());
+    } else {
+      OSP_LOG_ERROR << "ReceiveMessage() on socket failed: "
+                    << result.error().message();
+    }
   }
-  return data;
+  return packets;
 }
 
-std::vector<UdpReadCallback::Packet> OnePlatformLoopIteration(
-    EventWaiterPtr waiter) {
+std::vector<UdpPacket> OnePlatformLoopIteration(EventWaiterPtr waiter) {
   ErrorOr<Events> events = WaitForEvents(waiter);
   if (!events)
     return {};
diff --git a/platform/impl/event_loop.h b/platform/impl/event_loop.h
index 2988b57..05bf458 100644
--- a/platform/impl/event_loop.h
+++ b/platform/impl/event_loop.h
@@ -16,12 +16,8 @@
 namespace openscreen {
 namespace platform {
 
-Error ReceiveDataFromEvent(const UdpSocketReadableEvent& read_event,
-                           UdpReadCallback::Packet* data);
-std::vector<UdpReadCallback::Packet> HandleUdpSocketReadEvents(
-    const Events& events);
-std::vector<UdpReadCallback::Packet> OnePlatformLoopIteration(
-    EventWaiterPtr waiter);
+std::vector<UdpPacket> HandleUdpSocketReadEvents(const Events& events);
+std::vector<UdpPacket> OnePlatformLoopIteration(EventWaiterPtr waiter);
 
 }  // namespace platform
 }  // namespace openscreen
diff --git a/platform/impl/network_reader.cc b/platform/impl/network_reader.cc
index 9eccc92..4a94705 100644
--- a/platform/impl/network_reader.cc
+++ b/platform/impl/network_reader.cc
@@ -11,24 +11,6 @@
 
 namespace openscreen {
 namespace platform {
-namespace {
-
-class ReadCallbackExecutor {
- public:
-  ReadCallbackExecutor(std::unique_ptr<UdpReadCallback::Packet> data,
-                       NetworkReader::Callback function)
-      : function_(function) {
-    data_ = std::move(data);
-  }
-
-  void operator()() { function_(std::move(data_)); }
-
- private:
-  std::unique_ptr<UdpReadCallback::Packet> data_;
-  NetworkReader::Callback function_;
-};
-
-}  // namespace
 
 NetworkReader::NetworkReader(TaskRunner* task_runner)
     : NetworkReader(task_runner, NetworkWaiter::Create()) {}
@@ -85,39 +67,28 @@
         continue;
       }
 
-      ErrorOr<std::unique_ptr<UdpReadCallback::Packet>> read_packet =
-          ReadFromSocket(mapped_socket->first);
+      ErrorOr<UdpPacket> read_packet = mapped_socket->first->ReceiveMessage();
       if (read_packet.is_error()) {
         error = read_packet.error();
         continue;
       }
 
-      // FIXME: Investigate removing ReadCallbackExecutor.
-      auto task =
-          ReadCallbackExecutor(read_packet.MoveValue(), mapped_socket->second);
-      task_runner_->PostTask(std::move(task));
+      // Capture the UdpPacket by move into |arg| here to transfer the ownership
+      // and avoid copying the UdpPacket. This move constructs the UdpPacket
+      // inside of the lambda. Then the UdpPacket |arg| is passed by move to the
+      // callback function |func|.
+      auto executor = [arg = read_packet.MoveValue(),
+                       func = mapped_socket->second]() mutable {
+        func(std::move(arg));
+      };
+
+      task_runner_->PostTask(std::move(executor));
     }
   }
 
   return error;
 }
 
-ErrorOr<std::unique_ptr<UdpReadCallback::Packet>> NetworkReader::ReadFromSocket(
-    UdpSocket* socket) {
-  // TODO(rwkeane): Use circular buffer in Socket instead of new packet.
-  auto data = std::make_unique<UdpReadCallback::Packet>();
-  ErrorOr<size_t> read_bytes = socket->ReceiveMessage(
-      &(*data)[0], data->size(), &data->source, &data->original_destination);
-  if (read_bytes.is_error()) {
-    return read_bytes.error();
-  }
-
-  data->socket = socket;
-  data->length = read_bytes.value();
-
-  return data;
-}
-
 void NetworkReader::RunUntilStopped() {
   const bool was_running = is_running_.exchange(true);
   OSP_CHECK(!was_running);
diff --git a/platform/impl/network_reader.h b/platform/impl/network_reader.h
index 21e04ae..4d714b3 100644
--- a/platform/impl/network_reader.h
+++ b/platform/impl/network_reader.h
@@ -23,8 +23,7 @@
 class NetworkReader {
  public:
   // Create a type for readability
-  using Callback =
-      std::function<void(std::unique_ptr<UdpReadCallback::Packet>)>;
+  using Callback = std::function<void(UdpPacket)>;
 
   // Creates a new instance of this object.
   // NOTE: The provided TaskRunner must be running and must live for the
@@ -61,12 +60,6 @@
   // duration of this instance's life.
   NetworkReader(TaskRunner* task_runner, std::unique_ptr<NetworkWaiter> waiter);
 
-  // Method to read data from a socket. This method will not block, but is only
-  // expected to be called by WaitAndRead when it detects that a socket has
-  // data waiting to be read.
-  virtual ErrorOr<std::unique_ptr<UdpReadCallback::Packet>> ReadFromSocket(
-      UdpSocket* socket);
-
   // Waits for any writes to occur or for timeout to pass, whichever is sooner.
   // If an error occurs when calling WaitAndRead, then no callbacks will have
   // been called during the method's execution, but it is still safe to
diff --git a/platform/impl/network_reader_unittest.cc b/platform/impl/network_reader_unittest.cc
index 303c2cc..e1dcf8e 100644
--- a/platform/impl/network_reader_unittest.cc
+++ b/platform/impl/network_reader_unittest.cc
@@ -60,28 +60,19 @@
   // Public method to call wait, since usually this method is internally
   // callable only.
   Error WaitTesting(Clock::duration timeout) { return WaitAndRead(timeout); }
-
-  MOCK_METHOD1(
-      ReadFromSocket,
-      ErrorOr<std::unique_ptr<UdpReadCallback::Packet>>(UdpSocket* socket));
 };
 
 class MockCallbacks {
  public:
-  std::function<void(std::unique_ptr<UdpReadCallback::Packet>)>
-  GetReadCallback() {
-    return [this](std::unique_ptr<UdpReadCallback::Packet> packet) {
-      this->ReadCallback(std::move(packet));
-    };
+  std::function<void(UdpPacket)> GetReadCallback() {
+    return [this](UdpPacket packet) { this->ReadCallback(std::move(packet)); };
   }
 
   std::function<void()> GetWriteCallback() {
     return [this]() { this->WriteCallback(); };
   }
 
-  void ReadCallback(std::unique_ptr<UdpReadCallback::Packet> packet) {
-    ReadCallbackInternal();
-  }
+  void ReadCallback(UdpPacket packet) { ReadCallbackInternal(); }
 
   MOCK_METHOD0(ReadCallbackInternal, void());
   MOCK_METHOD0(WriteCallback, void());
@@ -194,7 +185,7 @@
   TestingNetworkWaiter network_waiter(std::move(mock_waiter),
                                       task_runner.get());
   auto timeout = Clock::duration(0);
-  auto packet = std::make_unique<UdpReadCallback::Packet>();
+  UdpPacket packet;
   MockCallbacks callbacks;
 
   network_waiter.ReadRepeatedly(socket.get(), callbacks.GetReadCallback());
@@ -217,27 +208,26 @@
       std::unique_ptr<NetworkWaiter>(mock_waiter_ptr);
   std::unique_ptr<TaskRunner> task_runner =
       std::unique_ptr<TaskRunner>(task_runner_ptr);
-  std::unique_ptr<MockUdpSocket> socket =
-      std::make_unique<MockUdpSocket>(UdpSocket::Version::kV4);
+  MockUdpSocket socket(UdpSocket::Version::kV4);
   TestingNetworkWaiter network_waiter(std::move(mock_waiter),
                                       task_runner.get());
   auto timeout = Clock::duration(0);
-  auto packet = std::make_unique<UdpReadCallback::Packet>();
+  UdpPacket packet;
   MockCallbacks callbacks;
 
-  network_waiter.ReadRepeatedly(socket.get(), callbacks.GetReadCallback());
+  network_waiter.ReadRepeatedly(&socket, callbacks.GetReadCallback());
 
   EXPECT_CALL(*mock_waiter_ptr, AwaitSocketsReadable(_, timeout))
-      .WillOnce(Return(ByMove(std::vector<UdpSocket*>{socket.get()})));
+      .WillOnce(Return(ByMove(std::vector<UdpSocket*>{&socket})));
   EXPECT_CALL(callbacks, ReadCallbackInternal()).Times(1);
-  EXPECT_CALL(network_waiter, ReadFromSocket(socket.get()))
+  EXPECT_CALL(socket, ReceiveMessage())
       .WillOnce(Return(ByMove(std::move(packet))));
   EXPECT_EQ(network_waiter.WaitTesting(timeout), Error::Code::kNone);
   EXPECT_EQ(task_runner_ptr->tasks_posted, uint32_t{1});
 
   // Set deletion callback because otherwise the destructor tries to call a
   // callback on the deleted object when it goes out of scope.
-  socket->SetDeletionCallback([](UdpSocket* socket) {});
+  socket.SetDeletionCallback([](UdpSocket* socket) {});
 }
 
 TEST(NetworkReaderTest, WaitFailsIfReadingSocketFails) {
@@ -246,27 +236,25 @@
       std::unique_ptr<NetworkWaiter>(mock_waiter_ptr);
   std::unique_ptr<TaskRunner> task_runner =
       std::unique_ptr<TaskRunner>(new MockTaskRunner());
-  std::unique_ptr<MockUdpSocket> socket =
-      std::make_unique<MockUdpSocket>(UdpSocket::Version::kV4);
+  MockUdpSocket socket(UdpSocket::Version::kV4);
   TestingNetworkWaiter network_waiter(std::move(mock_waiter),
                                       task_runner.get());
   auto timeout = Clock::duration(0);
-  auto packet = std::make_unique<UdpReadCallback::Packet>();
   MockCallbacks callbacks;
 
-  network_waiter.ReadRepeatedly(socket.get(), callbacks.GetReadCallback());
+  network_waiter.ReadRepeatedly(&socket, callbacks.GetReadCallback());
 
   EXPECT_CALL(*mock_waiter_ptr, AwaitSocketsReadable(_, timeout))
-      .WillOnce(Return(ByMove(std::vector<UdpSocket*>{socket.get()})));
+      .WillOnce(Return(ByMove(std::vector<UdpSocket*>{&socket})));
   EXPECT_CALL(callbacks, ReadCallbackInternal()).Times(0);
-  EXPECT_CALL(network_waiter, ReadFromSocket(socket.get()))
+  EXPECT_CALL(socket, ReceiveMessage())
       .WillOnce(Return(ByMove(Error::Code::kGenericPlatformError)));
   EXPECT_EQ(network_waiter.WaitTesting(timeout),
             Error::Code::kGenericPlatformError);
 
   // Set deletion callback because otherwise the destructor tries to call a
   // callback on the deleted object when it goes out of scope.
-  socket->SetDeletionCallback([](UdpSocket* socket) {});
+  socket.SetDeletionCallback([](UdpSocket* socket) {});
 }
 
 }  // namespace platform
diff --git a/platform/impl/network_runner.cc b/platform/impl/network_runner.cc
index ca260ba..c090c6b 100644
--- a/platform/impl/network_runner.cc
+++ b/platform/impl/network_runner.cc
@@ -24,10 +24,9 @@
 
 Error NetworkRunnerImpl::ReadRepeatedly(UdpSocket* socket,
                                         UdpReadCallback* callback) {
-  NetworkReader::Callback func =
-      [callback, this](std::unique_ptr<UdpReadCallback::Packet> packet) {
-        callback->OnRead(std::move(packet), this);
-      };
+  NetworkReader::Callback func = [callback, this](UdpPacket packet) {
+    callback->OnRead(std::move(packet), this);
+  };
   return network_loop_->ReadRepeatedly(socket, func);
 }
 
diff --git a/platform/impl/udp_socket_posix.cc b/platform/impl/udp_socket_posix.cc
index 6134b48..f4263e7 100644
--- a/platform/impl/udp_socket_posix.cc
+++ b/platform/impl/udp_socket_posix.cc
@@ -8,6 +8,7 @@
 #include <fcntl.h>
 #include <netinet/in.h>
 #include <netinet/ip.h>
+#include <sys/ioctl.h>
 #include <sys/socket.h>
 #include <sys/types.h>
 #include <unistd.h>
@@ -226,143 +227,126 @@
   return Error(hard_error_code, strerror(errno));
 }
 
-}  // namespace
-
-ErrorOr<size_t> UdpSocketPosix::ReceiveMessage(
-    void* data,
-    size_t length,
-    IPEndpoint* src,
-    IPEndpoint* original_destination) {
-  struct iovec iov = {data, length};
-  char control_buf[1024];
-  size_t cmsg_size = sizeof(control_buf) - sizeof(struct cmsghdr) + 1;
-  void* cmsg_buf = control_buf;
-  std::align(alignof(struct cmsghdr), sizeof(cmsg_buf), cmsg_buf, cmsg_size);
-  switch (version_) {
-    case UdpSocket::Version::kV4: {
-      struct sockaddr_in sa;
-      struct msghdr msg;
-      msg.msg_name = &sa;
-      msg.msg_namelen = sizeof(sa);
-      msg.msg_iov = &iov;
-      msg.msg_iovlen = 1;
-      msg.msg_control = cmsg_buf;
-      msg.msg_controllen = cmsg_size;
-      msg.msg_flags = 0;
-
-      ssize_t num_bytes_received = recvmsg(fd_, &msg, 0);
-      if (num_bytes_received == -1) {
-        return ChooseError(errno, Error::Code::kSocketReadFailure);
-      }
-      OSP_DCHECK_GE(num_bytes_received, 0);
-
-      if (src) {
-        src->address =
-            IPAddress(IPAddress::Version::kV4,
-                      reinterpret_cast<const uint8_t*>(&sa.sin_addr.s_addr));
-        src->port = ntohs(sa.sin_port);
-      }
-
-      // For multicast sockets, the packet's original destination address may be
-      // the host address (since we called bind()) but it may also be a
-      // multicast address.  This may be relevant for handling multicast data;
-      // specifically, mDNSResponder requires this information to work properly.
-      if (original_destination) {
-        *original_destination = IPEndpoint{{}, 0};
-        if ((msg.msg_flags & MSG_CTRUNC) == 0) {
-          for (struct cmsghdr* cmh = CMSG_FIRSTHDR(&msg); cmh;
-               cmh = CMSG_NXTHDR(&msg, cmh)) {
-            if (cmh->cmsg_level != IPPROTO_IP || cmh->cmsg_type != IP_PKTINFO)
-              continue;
-
-            struct sockaddr_in addr;
-            socklen_t addr_len = sizeof(addr);
-            if (getsockname(fd_, reinterpret_cast<struct sockaddr*>(&addr),
-                            &addr_len) == -1) {
-              break;
-            }
-            // |original_destination->port| will be 0 if this line isn't
-            // reached.
-            original_destination->port = ntohs(addr.sin_port);
-
-            struct in_pktinfo* pktinfo =
-                reinterpret_cast<struct in_pktinfo*>(CMSG_DATA(cmh));
-            original_destination->address =
-                IPAddress(IPAddress::Version::kV4,
-                          reinterpret_cast<const uint8_t*>(&pktinfo->ipi_addr));
-            break;
-          }
-        }
-      }
-
-      return num_bytes_received;
-    }
-
-    case UdpSocket::Version::kV6: {
-      struct sockaddr_in6 sa;
-      struct msghdr msg;
-      msg.msg_name = &sa;
-      msg.msg_namelen = sizeof(sa);
-      msg.msg_iov = &iov;
-      msg.msg_iovlen = 1;
-      msg.msg_control = cmsg_buf;
-      msg.msg_controllen = cmsg_size;
-      msg.msg_flags = 0;
-
-      ssize_t num_bytes_received = recvmsg(fd_, &msg, 0);
-      if (num_bytes_received == -1) {
-        return ChooseError(errno, Error::Code::kSocketReadFailure);
-      }
-      OSP_DCHECK_GE(num_bytes_received, 0);
-
-      if (src) {
-        src->address =
-            IPAddress(IPAddress::Version::kV6,
-                      reinterpret_cast<const uint8_t*>(&sa.sin6_addr.s6_addr));
-        src->port = ntohs(sa.sin6_port);
-      }
-
-      // For multicast sockets, the packet's original destination address may be
-      // the host address (since we called bind()) but it may also be a
-      // multicast address.  This may be relevant for handling multicast data;
-      // specifically, mDNSResponder requires this information to work properly.
-      if (original_destination) {
-        *original_destination = IPEndpoint{{}, 0};
-        if ((msg.msg_flags & MSG_CTRUNC) == 0) {
-          for (struct cmsghdr* cmh = CMSG_FIRSTHDR(&msg); cmh;
-               cmh = CMSG_NXTHDR(&msg, cmh)) {
-            if (cmh->cmsg_level != IPPROTO_IPV6 ||
-                cmh->cmsg_type != IPV6_PKTINFO) {
-              continue;
-            }
-            struct sockaddr_in6 addr;
-            socklen_t addr_len = sizeof(addr);
-            if (getsockname(fd_, reinterpret_cast<struct sockaddr*>(&addr),
-                            &addr_len) == -1) {
-              break;
-            }
-            // |original_destination->port| will be 0 if this line isn't
-            // reached.
-            original_destination->port = ntohs(addr.sin6_port);
-
-            struct in6_pktinfo* pktinfo =
-                reinterpret_cast<struct in6_pktinfo*>(CMSG_DATA(cmh));
-            original_destination->address = IPAddress(
-                IPAddress::Version::kV6,
-                reinterpret_cast<const uint8_t*>(&pktinfo->ipi6_addr));
-            break;
-          }
-        }
-      }
-
-      return num_bytes_received;
-    }
-  }
-
-  OSP_NOTREACHED();
-  return Error::Code::kGenericPlatformError;
+IPAddress GetIPAddressFromSockAddr(const sockaddr_in& sa) {
+  static_assert(IPAddress::kV4Size == sizeof(sa.sin_addr.s_addr),
+                "IPv4 address size mismatch.");
+  return IPAddress(IPAddress::Version::kV4,
+                   reinterpret_cast<const uint8_t*>(&sa.sin_addr.s_addr));
 }
 
+IPAddress GetIPAddressFromPktInfo(const in_pktinfo& pktinfo) {
+  static_assert(IPAddress::kV4Size == sizeof(pktinfo.ipi_addr),
+                "IPv4 address size mismatch.");
+  return IPAddress(IPAddress::Version::kV4,
+                   reinterpret_cast<const uint8_t*>(&pktinfo.ipi_addr));
+}
+
+uint16_t GetPortFromFromSockAddr(const sockaddr_in& sa) {
+  return ntohs(sa.sin_port);
+}
+
+IPAddress GetIPAddressFromSockAddr(const sockaddr_in6& sa) {
+  return IPAddress(sa.sin6_addr.s6_addr);
+}
+
+IPAddress GetIPAddressFromPktInfo(const in6_pktinfo& pktinfo) {
+  return IPAddress(pktinfo.ipi6_addr.s6_addr);
+}
+
+uint16_t GetPortFromFromSockAddr(const sockaddr_in6& sa) {
+  return ntohs(sa.sin6_port);
+}
+
+template <class PktInfoType>
+bool IsPacketInfo(cmsghdr* cmh);
+
+template <>
+bool IsPacketInfo<in_pktinfo>(cmsghdr* cmh) {
+  return cmh->cmsg_level == IPPROTO_IP && cmh->cmsg_type == IP_PKTINFO;
+}
+
+template <>
+bool IsPacketInfo<in6_pktinfo>(cmsghdr* cmh) {
+  return cmh->cmsg_level == IPPROTO_IPV6 && cmh->cmsg_type == IPV6_PKTINFO;
+}
+
+template <class SockAddrType, class PktInfoType>
+Error ReceiveMessageInternal(int fd, UdpPacket* packet) {
+  SockAddrType sa;
+  iovec iov = {packet->data(), packet->size()};
+  alignas(alignof(cmsghdr)) uint8_t control_buffer[1024];
+  msghdr msg;
+  msg.msg_name = &sa;
+  msg.msg_namelen = sizeof(sa);
+  msg.msg_iov = &iov;
+  msg.msg_iovlen = 1;
+  msg.msg_control = control_buffer;
+  msg.msg_controllen = sizeof(control_buffer);
+  msg.msg_flags = 0;
+
+  ssize_t bytes_received = recvmsg(fd, &msg, 0);
+  if (bytes_received == -1) {
+    return ChooseError(errno, Error::Code::kSocketReadFailure);
+  }
+
+  OSP_DCHECK_EQ(static_cast<size_t>(bytes_received), packet->size());
+
+  IPEndpoint source_endpoint = {.address = GetIPAddressFromSockAddr(sa),
+                                .port = GetPortFromFromSockAddr(sa)};
+  packet->set_source(std::move(source_endpoint));
+
+  // For multicast sockets, the packet's original destination address may be
+  // the host address (since we called bind()) but it may also be a
+  // multicast address.  This may be relevant for handling multicast data;
+  // specifically, mDNSResponder requires this information to work properly.
+
+  socklen_t sa_len = sizeof(sa);
+  if (((msg.msg_flags & MSG_CTRUNC) != 0) ||
+      (getsockname(fd, reinterpret_cast<sockaddr*>(&sa), &sa_len) == -1)) {
+    return Error::Code::kNone;
+  }
+  for (cmsghdr* cmh = CMSG_FIRSTHDR(&msg); cmh; cmh = CMSG_NXTHDR(&msg, cmh)) {
+    if (IsPacketInfo<PktInfoType>(cmh)) {
+      PktInfoType* pktinfo = reinterpret_cast<PktInfoType*>(CMSG_DATA(cmh));
+      IPEndpoint destination_endpoint = {
+          .address = GetIPAddressFromPktInfo(*pktinfo),
+          .port = GetPortFromFromSockAddr(sa)};
+      packet->set_destination(std::move(destination_endpoint));
+      break;
+    }
+  }
+  return Error::Code::kNone;
+}
+
+}  // namespace
+
+ErrorOr<UdpPacket> UdpSocketPosix::ReceiveMessage() {
+  ssize_t bytes_available = recv(fd_, nullptr, 0, MSG_PEEK | MSG_TRUNC);
+  if (bytes_available == -1) {
+    return ChooseError(errno, Error::Code::kSocketReadFailure);
+  }
+  UdpPacket packet(bytes_available);
+  packet.set_socket(this);
+  Error result = Error::Code::kGenericPlatformError;
+  switch (version_) {
+    case UdpSocket::Version::kV4: {
+      result = ReceiveMessageInternal<sockaddr_in, in_pktinfo>(fd_, &packet);
+      break;
+    }
+    case UdpSocket::Version::kV6: {
+      result = ReceiveMessageInternal<sockaddr_in6, in6_pktinfo>(fd_, &packet);
+      break;
+    }
+    default: {
+      OSP_NOTREACHED();
+    }
+  }
+  return result.ok() ? ErrorOr<UdpPacket>(std::move(packet))
+                     : ErrorOr<UdpPacket>(std::move(result));
+}
+
+// TODO(yakimakha): Consider changing the interface to accept UdpPacket as
+// an input parameter
 Error UdpSocketPosix::SendMessage(const void* data,
                                   size_t length,
                                   const IPEndpoint& dest) {
diff --git a/platform/impl/udp_socket_posix.h b/platform/impl/udp_socket_posix.h
index f5bc63d..17983b1 100644
--- a/platform/impl/udp_socket_posix.h
+++ b/platform/impl/udp_socket_posix.h
@@ -22,10 +22,7 @@
   Error SetMulticastOutboundInterface(NetworkInterfaceIndex ifindex) final;
   Error JoinMulticastGroup(const IPAddress& address,
                            NetworkInterfaceIndex ifindex) final;
-  ErrorOr<size_t> ReceiveMessage(void* data,
-                                 size_t length,
-                                 IPEndpoint* src,
-                                 IPEndpoint* original_destination) final;
+  ErrorOr<UdpPacket> ReceiveMessage() final;
   Error SendMessage(const void* data,
                     size_t length,
                     const IPEndpoint& dest) final;
diff --git a/platform/test/mock_udp_socket.h b/platform/test/mock_udp_socket.h
index 90ed352..8fbdb87 100644
--- a/platform/test/mock_udp_socket.h
+++ b/platform/test/mock_udp_socket.h
@@ -27,8 +27,7 @@
   MOCK_METHOD1(SetMulticastOutboundInterface, Error(NetworkInterfaceIndex));
   MOCK_METHOD2(JoinMulticastGroup,
                Error(const IPAddress&, NetworkInterfaceIndex));
-  MOCK_METHOD4(ReceiveMessage,
-               ErrorOr<size_t>(void*, size_t, IPEndpoint*, IPEndpoint*));
+  MOCK_METHOD0(ReceiveMessage, ErrorOr<UdpPacket>());
   MOCK_METHOD3(SendMessage, Error(const void*, size_t, const IPEndpoint&));
   MOCK_METHOD1(SetDscp, Error(DscpMode));