blob: 380c5947f79466bcfb1b0f869fc1449627da1e6d [file] [log] [blame] [edit]
/*
* Copyright (C) 2023 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/traced_relay/socket_relay_handler.h"
#include <chrono>
#include <cstring>
#include <memory>
#include <random>
#include <string>
#include <thread>
#include <utility>
#include "perfetto/ext/base/threading/thread_pool.h"
#include "perfetto/ext/base/unix_socket.h"
#include "test/gtest_and_gmock.h"
using testing::Values;
namespace perfetto {
namespace {
using RawSocketPair = std::pair<base::UnixSocketRaw, base::UnixSocketRaw>;
using RngValueType = std::minstd_rand0::result_type;
struct TestClient {
RawSocketPair endpoint_sockets;
std::minstd_rand0 data_prng;
std::thread client_thread;
};
class SocketRelayHandlerTest : public ::testing::TestWithParam<uint32_t> {
protected:
void SetUp() override {
socket_relay_handler_ = std::make_unique<SocketRelayHandler>();
for (uint32_t i = 0; i < GetParam(); i++) {
TestClient client{SetUpEndToEndSockets(), std::minstd_rand0(i), {}};
test_clients_.push_back(std::move(client));
}
}
void TearDown() override { socket_relay_handler_ = nullptr; }
RawSocketPair SetUpEndToEndSockets() {
// Creates 2 SocketPairs:
// sock1 <-> sock2 <-> SocketRelayHandler <-> sock3 <-> sock4.
// sock2 and sock3 are transferred to the SocketRelayHandler.
// We test by reading and writing bidirectionally using sock1 and sock4.
auto [sock1, sock2] = base::UnixSocketRaw::CreatePairPosix(
base::SockFamily::kUnix, base::SockType::kStream);
sock2.SetBlocking(false);
auto [sock3, sock4] = base::UnixSocketRaw::CreatePairPosix(
base::SockFamily::kUnix, base::SockType::kStream);
sock3.SetBlocking(false);
auto socket_pair = std::make_unique<SocketPair>();
socket_pair->first.sock = std::move(sock2);
socket_pair->second.sock = std::move(sock3);
socket_relay_handler_->AddSocketPair(std::move(socket_pair));
RawSocketPair endpoint_sockets;
endpoint_sockets.first = std::move(sock1);
endpoint_sockets.second = std::move(sock4);
return endpoint_sockets;
}
std::unique_ptr<SocketRelayHandler> socket_relay_handler_;
std::vector<TestClient> test_clients_;
// Use fewer receiver threads than sender threads.
base::ThreadPool receiver_thread_pool_{1 + GetParam() / 10};
};
TEST(SocketWithBufferTest, EnqueueDequeue) {
SocketWithBuffer socket_with_buffer;
// No data initially.
EXPECT_EQ(0u, socket_with_buffer.data_size());
// Has room for writing some bytes into.
std::string data = "12345678901234567890";
EXPECT_GT(socket_with_buffer.available_bytes(), data.size());
memcpy(socket_with_buffer.buffer(), data.data(), data.size());
socket_with_buffer.EnqueueData(data.size());
EXPECT_EQ(data.size(), socket_with_buffer.data_size());
// Dequeue some bytes.
socket_with_buffer.DequeueData(5);
EXPECT_EQ(socket_with_buffer.data_size(), data.size() - 5);
std::string buffered_data(reinterpret_cast<char*>(socket_with_buffer.data()),
socket_with_buffer.data_size());
EXPECT_EQ(buffered_data, "678901234567890");
}
// Test the SocketRelayHander with randomized request and response data.
TEST_P(SocketRelayHandlerTest, RandomizedRequestResponse) {
#if defined(ADDRESS_SANITIZER) || defined(THREAD_SANITIZER) || \
defined(MEMORY_SANITIZER) || defined(LEAK_SANITIZER)
// Reduce the test strength for sanitizer builds.
constexpr size_t kMaxMsgSizeRng = 1 << 16;
constexpr size_t kMaxNumRequests = 10;
#else
// The max message size in the number of RNG calls.
constexpr size_t kMaxMsgSizeRng = 1 << 18;
// The max number of requests.
constexpr size_t kMaxNumRequests = 25;
#endif
// Create the threads for sending and receiving data through the
// SocketRelayHandler.
for (auto& client : test_clients_) {
auto* thread_pool = &receiver_thread_pool_;
auto thread_func = [&client, thread_pool]() {
auto& rng = client.data_prng;
// The max number of requests.
const size_t num_requests = rng() % kMaxNumRequests;
for (size_t j = 0; j < num_requests; j++) {
auto& send_endpoint = client.endpoint_sockets.first;
auto& receive_endpoint = client.endpoint_sockets.second;
auto req_size = rng() % kMaxMsgSizeRng;
// Generate the random request.
std::vector<RngValueType> request;
request.reserve(req_size);
for (size_t r = 0; r < req_size; r++) {
request.emplace_back(rng());
}
// Create a buffer for receiving the request.
std::vector<RngValueType> received_request(request.size());
std::mutex mutex;
std::condition_variable cv;
std::unique_lock<std::mutex> lock(mutex);
bool done = false;
// Blocking receive on the thread pool.
thread_pool->PostTask([&]() {
const size_t bytes_to_receive =
received_request.size() * sizeof(RngValueType);
uint8_t* receive_buffer =
reinterpret_cast<uint8_t*>(received_request.data());
size_t bytes_received = 0;
// Perform a blocking read until we received the expected bytes.
while (bytes_received < bytes_to_receive) {
ssize_t rsize = PERFETTO_EINTR(
receive_endpoint.Receive(receive_buffer + bytes_received,
bytes_to_receive - bytes_received));
if (rsize <= 0)
break;
bytes_received += static_cast<size_t>(rsize);
std::this_thread::yield(); // Adds some scheduling randomness.
}
std::lock_guard<std::mutex> inner_lock(mutex);
done = true;
cv.notify_one();
});
// Perform a blocking send of the request data.
PERFETTO_EINTR(send_endpoint.Send(
request.data(), request.size() * sizeof(RngValueType)));
// Wait until the request is fully received.
cv.wait(lock, [&done] { return done; });
// Check data integrity.
EXPECT_EQ(request, received_request);
// Add some randomness to timing.
std::this_thread::sleep_for(std::chrono::microseconds(rng() % 1000));
// Emulate the response by reversing the data flow direction.
std::swap(send_endpoint, receive_endpoint);
}
};
client.client_thread = std::thread(std::move(thread_func));
}
for (auto& client : test_clients_) {
client.client_thread.join();
}
}
#if PERFETTO_BUILDFLAG(PERFETTO_OS_ANDROID)
INSTANTIATE_TEST_SUITE_P(ByConnections, SocketRelayHandlerTest, Values(1, 5));
#else
INSTANTIATE_TEST_SUITE_P(ByConnections,
SocketRelayHandlerTest,
Values(1, 5, 25));
#endif
} // namespace
} // namespace perfetto