#include <uds/client_channel.h>

#include <sys/socket.h>

#include <algorithm>
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <numeric>
#include <random>
#include <thread>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <pdx/client.h>
#include <pdx/rpc/remote_method.h>
#include <pdx/service.h>
#include <pdx/service_dispatcher.h>

#include <uds/client_channel_factory.h>
#include <uds/service_endpoint.h>
| |
| using testing::Return; |
| using testing::_; |
| |
| using android::pdx::ClientBase; |
| using android::pdx::LocalChannelHandle; |
| using android::pdx::LocalHandle; |
| using android::pdx::Message; |
| using android::pdx::ServiceBase; |
| using android::pdx::ServiceDispatcher; |
| using android::pdx::Status; |
| using android::pdx::rpc::DispatchRemoteMethod; |
| using android::pdx::uds::ClientChannel; |
| using android::pdx::uds::ClientChannelFactory; |
| using android::pdx::uds::Endpoint; |
| |
| namespace { |
| |
| struct TestProtocol { |
| using DataType = int8_t; |
| enum { |
| kOpSum = 0, |
| }; |
| PDX_REMOTE_METHOD(Sum, kOpSum, int64_t(const std::vector<DataType>&)); |
| }; |
| |
| class TestService : public ServiceBase<TestService> { |
| public: |
| TestService(std::unique_ptr<Endpoint> endpoint) |
| : ServiceBase{"TestService", std::move(endpoint)} {} |
| |
| Status<void> HandleMessage(Message& message) override { |
| switch (message.GetOp()) { |
| case TestProtocol::kOpSum: |
| DispatchRemoteMethod<TestProtocol::Sum>(*this, &TestService::OnSum, |
| message); |
| return {}; |
| |
| default: |
| return Service::HandleMessage(message); |
| } |
| } |
| |
| int64_t OnSum(Message& /*message*/, |
| const std::vector<TestProtocol::DataType>& data) { |
| return std::accumulate(data.begin(), data.end(), int64_t{0}); |
| } |
| }; |
| |
| class TestClient : public ClientBase<TestClient> { |
| public: |
| using ClientBase::ClientBase; |
| |
| int64_t Sum(const std::vector<TestProtocol::DataType>& data) { |
| auto status = InvokeRemoteMethod<TestProtocol::Sum>(data); |
| return status ? status.get() : -1; |
| } |
| }; |
| |
| class TestServiceRunner { |
| public: |
| TestServiceRunner(LocalHandle channel_socket) { |
| auto endpoint = Endpoint::CreateFromSocketFd(LocalHandle{}); |
| endpoint->RegisterNewChannelForTests(std::move(channel_socket)); |
| service_ = TestService::Create(std::move(endpoint)); |
| dispatcher_ = ServiceDispatcher::Create(); |
| dispatcher_->AddService(service_); |
| dispatch_thread_ = std::thread( |
| std::bind(&ServiceDispatcher::EnterDispatchLoop, dispatcher_.get())); |
| } |
| |
| ~TestServiceRunner() { |
| dispatcher_->SetCanceled(true); |
| dispatch_thread_.join(); |
| dispatcher_->RemoveService(service_); |
| } |
| |
| private: |
| std::shared_ptr<TestService> service_; |
| std::unique_ptr<ServiceDispatcher> dispatcher_; |
| std::thread dispatch_thread_; |
| }; |
| |
| class ClientChannelTest : public testing::Test { |
| public: |
| void SetUp() override { |
| int channel_sockets[2] = {}; |
| ASSERT_EQ( |
| 0, socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, channel_sockets)); |
| LocalHandle service_channel{channel_sockets[0]}; |
| LocalHandle client_channel{channel_sockets[1]}; |
| |
| service_runner_.reset(new TestServiceRunner{std::move(service_channel)}); |
| auto factory = ClientChannelFactory::Create(std::move(client_channel)); |
| auto status = factory->Connect(android::pdx::Client::kInfiniteTimeout); |
| ASSERT_TRUE(status); |
| client_ = TestClient::Create(status.take()); |
| } |
| |
| void TearDown() override { |
| service_runner_.reset(); |
| client_.reset(); |
| } |
| |
| protected: |
| std::unique_ptr<TestServiceRunner> service_runner_; |
| std::shared_ptr<TestClient> client_; |
| }; |
| |
| TEST_F(ClientChannelTest, MultithreadedClient) { |
| constexpr int kNumTestThreads = 8; |
| constexpr size_t kDataSize = 1000; // Try to keep RPC buffer size below 4K. |
| |
| std::random_device rd; |
| std::mt19937 gen{rd()}; |
| std::uniform_int_distribution<TestProtocol::DataType> dist{ |
| std::numeric_limits<TestProtocol::DataType>::min(), |
| std::numeric_limits<TestProtocol::DataType>::max()}; |
| |
| auto worker = [](std::shared_ptr<TestClient> client, |
| std::vector<TestProtocol::DataType> data) { |
| constexpr int kMaxIterations = 500; |
| int64_t expected = std::accumulate(data.begin(), data.end(), int64_t{0}); |
| for (int i = 0; i < kMaxIterations; i++) { |
| ASSERT_EQ(expected, client->Sum(data)); |
| } |
| }; |
| |
| // Start client threads. |
| std::vector<TestProtocol::DataType> data; |
| data.resize(kDataSize); |
| std::vector<std::thread> threads; |
| for (int i = 0; i < kNumTestThreads; i++) { |
| std::generate(data.begin(), data.end(), |
| [&dist, &gen]() { return dist(gen); }); |
| threads.emplace_back(worker, client_, data); |
| } |
| |
| // Wait for threads to finish. |
| for (auto& thread : threads) |
| thread.join(); |
| } |
| |
| } // namespace |