| /* |
| * Copyright (C) 2025 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
#include <gtest/gtest.h>

#include <errno.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>
#include <unistd.h>

#include <sys/socket.h>
#include <sys/un.h>

#include <atomic>
#include <chrono>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <thread>

#include <android-base/unique_fd.h>

#include "adbd_auth.h"
#include "adbd_auth_internal.h"
| |
| void Log(const std::string& msg) { |
| LOG(INFO) << "(" << gettid() << "): " << msg; |
| } |
| |
| using namespace std::string_view_literals; |
| using namespace std::string_literals; |
| using namespace std::chrono_literals; |
| |
// Name of the Unix domain socket shared by the fake server and the emulated Framework client.
// The leading NUL byte places it in the Linux abstract socket namespace.
constexpr char kUdsNameChars[] = "\0adb_auth_test_uds";
constexpr std::string_view kUdsName(kUdsNameChars, sizeof(kUdsNameChars) - 1);
static_assert(kUdsName.size() <= sizeof(sockaddr_un::sun_path));
| |
| |
| // A convenient struct that will stop the context and wait for the context runner thread to |
| // return when it is destroyed. |
| struct ContextRunner { |
| public: |
| explicit ContextRunner(std::unique_ptr<AdbdAuthContext> context) : context_(std::move(context)) { |
| thread_ = std::thread([raw_context = context_.get()]() { |
| raw_context->Run(); |
| }); |
| |
| // Wait until the context has started running. |
| while (!context_->IsRunning()) { |
| std::this_thread::sleep_for(10ms); |
| } |
| } |
| |
| ~ContextRunner() { |
| context_->Stop(); |
| thread_.join(); |
| } |
| |
| AdbdAuthContext* Context() { |
| return context_.get(); |
| } |
| private: |
| std::thread thread_; |
| std::unique_ptr<AdbdAuthContext> context_; |
| }; |
| |
| // Emulate android_get_control_socket which adbauth uses to get the Unix Domain Socket |
| // to communicate with Framework. |
| std::optional<int> CreateServerSocket() { |
| int sockfd = socket(AF_UNIX, SOCK_SEQPACKET, 0); |
| if (sockfd == -1) { |
| Log("Failed to create server socket"); |
| return {}; |
| } |
| |
| struct sockaddr_un addr{}; |
| addr.sun_family = AF_UNIX; |
| memset(addr.sun_path, 0, sizeof(addr.sun_path)); |
| strncpy(addr.sun_path, kUdsName.data(), kUdsName.size()); |
| |
| if (bind(sockfd, (struct sockaddr*) &addr, sizeof(sockaddr_un)) == -1) { |
| Log("Failed to bind socket server to abstract namespace"); |
| return {}; |
| } |
| |
| if (listen(sockfd, 1) == -1) { |
| Log("Failed to listen on socket server"); |
| return {}; |
| } |
| |
| return sockfd; |
| } |
| |
| // Workaround to fail from anywhere. FAIL macro can only be called from a function |
| // returning void. |
| void fail(const std::string& msg) { |
| FAIL() << msg; |
| } |
| |
| // This class behaves like AdbDebuggingManager. |
| // - Open the UDS created by our adb_auth emulation layer |
| // - Allow to send messages |
| // - Allow to recv messages |
| class Framework { |
| public: |
| Framework() { |
| socket_ = Connect(); |
| } |
| |
| ~Framework() { |
| close(socket_); |
| } |
| |
| std::string Recv() const { |
| char msg[256]; |
| auto num_bytes_read = read(socket_, msg, sizeof(msg)); |
| if (num_bytes_read < 0) { |
| Log("Framework could not read: "s + strerror(errno)); |
| return ""; |
| } |
| |
| Log("Framework read "s + std::to_string(num_bytes_read) + " bytes"); |
| return std::string(msg, num_bytes_read); |
| } |
| |
| int Send(const std::string& msg) const { |
| return write(socket_, msg.data(), msg.size()); |
| } |
| |
| void SendAndWaitContext(const std::string& msg, ContextRunner* runner) { |
| auto packet_id = runner->Context()->ReceivedPackets(); |
| Send(msg); |
| |
| while(runner->Context()->ReceivedPackets() == packet_id) { |
| std::this_thread::sleep_for(10ms); |
| } |
| } |
| |
| private: |
| int socket_; |
| |
| int Connect() { |
| int fd = socket(AF_UNIX, SOCK_SEQPACKET, 0); |
| if (fd == -1) { |
| fail("Cannot create client socket"); |
| } |
| |
| |
| sockaddr_un addr = { .sun_family = AF_UNIX }; |
| strncpy(addr.sun_path, kUdsName.data(), kUdsName.size()); |
| |
| auto res = connect(fd, (struct sockaddr*) &addr, sizeof(addr)); |
| if (res == -1) { |
| fail("Cannot connect client socket"); |
| } |
| |
| return fd; |
| } |
| }; |
| |
| std::unique_ptr<ContextRunner> CreateContextRunner(const AdbdAuthCallbacks& cb) { |
| auto server_socket = CreateServerSocket(); |
| if (!server_socket.has_value()) { |
| Log("Cannot create context"); |
| return {}; |
| } |
| |
| std::unique_ptr<AdbdAuthContext> context; |
| switch (cb.version) { |
| case 1: { |
| context = std::make_unique<AdbdAuthContext>(&reinterpret_cast<const AdbdAuthCallbacksV1&>(cb), |
| server_socket.value()); |
| break; |
| } |
| case 2: { |
| context = std::make_unique<AdbdAuthContextV2>(&reinterpret_cast<const AdbdAuthCallbacksV2&>(cb), |
| server_socket.value()); |
| break; |
| } |
| case 3: { |
| context = std::make_unique<AdbdAuthContextV3>(&reinterpret_cast<const AdbdAuthCallbacksV3&>(cb), |
| server_socket.value()); |
| break; |
| } |
| default: { |
| fail("Unable to create AuthContext for version="s + std::to_string(cb.version)); |
| } |
| } |
| context->InitFrameworkHandlers(); |
| |
| return std::make_unique<ContextRunner>(std::move(context)); |
| } |
| |
| TEST(AdbAuthTest, SendTcpPort) { |
| AdbdAuthCallbacksV1 callbacks{}; |
| callbacks.version = 1; |
| auto runner = CreateContextRunner(callbacks); |
| Framework framework{}; |
| |
| // Send TLS to framework |
| const uint8_t port = 19; |
| adbd_auth_send_tls_server_port(runner->Context(), port); |
| |
| // Check that Framework received it. |
| std::string msg = framework.Recv(); |
| ASSERT_EQ(4, msg.size()); |
| ASSERT_EQ(msg[0], 'T'); |
| ASSERT_EQ(msg[1], 'P'); |
| ASSERT_EQ(msg[2], port); |
| ASSERT_EQ(msg[3], 0); |
| } |
| |
| // If user forget to set callbacks, adbauth should not crash. Instead, it should |
| // discard messages and issue a warning. |
| TEST(AdbAuthTest, UnsetCallbacks) { |
| AdbdAuthCallbacksV2 callbacks{}; |
| callbacks.version = 2; |
| auto runner= CreateContextRunner(callbacks); |
| Framework framework{}; |
| |
| // We did not set the callback "start ADB Wifi". This should not crash if |
| // the message is properly dispatched. |
| framework.Send("W1"); |
| // We did not set the callback "stop ADB Wifi". This should not crash if. |
| // the message is properly dispatched. |
| framework.Send("W0"); |
| } |
| |
| // Test Wifi lifecycle callbacks |
| TEST(AdbAuthTest, WifiLifeCycle) { |
| AdbdAuthCallbacksV2 callbacks{}; |
| callbacks.version = 2; |
| |
| static bool start_message_received = false; |
| callbacks.start_adbd_wifi = [] { |
| start_message_received= true; |
| }; |
| |
| static bool stop_message_received = false; |
| callbacks.stop_adbd_wifi = [] { |
| stop_message_received= true; |
| }; |
| |
| |
| auto runner= CreateContextRunner(callbacks); |
| Framework framework{}; |
| |
| framework.SendAndWaitContext("W1", runner.get()); |
| ASSERT_EQ(start_message_received,true); |
| |
| framework.SendAndWaitContext("W0", runner.get()); |
| ASSERT_EQ(stop_message_received,true); |
| } |
| |
| |
| TEST(AdbAuthTest, UnhandledPacket) { |
| AdbdAuthCallbacksV2 callbacks{}; |
| callbacks.version = 2; |
| auto runner= CreateContextRunner(callbacks); |
| Framework framework{}; |
| |
| uint16_t port = 19; |
| adbd_auth_send_tls_server_port(runner->Context(), port); |
| |
| // Send an unhandled packet. This should not reset the stack. |
| framework.SendAndWaitContext("XX", runner.get()); |
| |
| // Check that libauth did not reset the socket. |
| auto msg = framework.Recv(); |
| ASSERT_EQ(4, msg.size()); |
| ASSERT_EQ(msg[0], 'T'); |
| ASSERT_EQ(msg[1], 'P'); |
| ASSERT_EQ(msg[2], port); |
| ASSERT_EQ(msg[3], 0); |
| } |
| |
| TEST(AdbAuthTest, RegisterService) { |
| AdbdAuthCallbacksV2 callbacks{}; |
| callbacks.version = 2; |
| auto runner= CreateContextRunner(callbacks); |
| Framework framework{}; |
| |
| std::string instance_name = "in"; |
| std::string service_type= "t_p"; |
| uint16_t port = 65019; |
| auto result = adbd_auth_register_service(runner->Context(), |
| instance_name.c_str(), service_type.c_str(), port); |
| ASSERT_EQ(ADBD_AUTH_REGISTER_OK, result); |
| |
| |
| auto msg = framework.Recv(); |
| ASSERT_EQ(2 + 1 + instance_name.size() + 1 + service_type.size() + sizeof(port), msg.size()); |
| ASSERT_EQ(msg[0], 'R'); |
| ASSERT_EQ(msg[1], 'S'); |
| |
| ASSERT_EQ(msg[2], instance_name.size()); |
| ASSERT_EQ(msg.substr(3, 2), instance_name); |
| |
| ASSERT_EQ(msg[5], service_type.size()); |
| ASSERT_EQ(msg.substr(6,3), service_type); |
| |
| uint8_t lsb = port & 0xFF; |
| uint8_t msb = (port >> 8) & 0xFF; |
| ASSERT_EQ((uint8_t)msg[9], lsb); |
| ASSERT_EQ((uint8_t)msg[10], msb); |
| } |
| |
| TEST(AdbAuthTest, RegisterServiceBadInstance) { |
| AdbdAuthCallbacksV2 callbacks{}; |
| callbacks.version = 2; |
| auto runner= CreateContextRunner(callbacks); |
| Framework framework{}; |
| |
| std::string instance_name(256, 'I'); |
| std::string service_type= "t_p"; |
| uint16_t port = 65019; |
| auto result = adbd_auth_register_service(runner->Context(), |
| instance_name.c_str(), service_type.c_str(), port); |
| ASSERT_EQ(ADBD_AUTH_REGISTER_BAD_NAME, result); |
| } |
| |
| TEST(AdbAuthTest, RegisterServiceBadService) { |
| AdbdAuthCallbacksV2 callbacks{}; |
| callbacks.version = 2; |
| auto runner= CreateContextRunner(callbacks); |
| Framework framework{}; |
| |
| std::string instance_name = "in"; |
| std::string service_type(256, 'S'); |
| uint16_t port = 65019; |
| auto result = adbd_auth_register_service(runner->Context(), |
| instance_name.c_str(), service_type.c_str(), port); |
| ASSERT_EQ(ADBD_AUTH_REGISTER_BAD_NAME, result); |
| } |
| |
| TEST(AdbAuthTest, UnregisterService) { |
| AdbdAuthCallbacksV2 callbacks{}; |
| callbacks.version = 2; |
| auto runner= CreateContextRunner(callbacks); |
| Framework framework{}; |
| |
| const std::string instance_name = "in"; |
| const std::string service_type= "t_p"; |
| auto result = adbd_auth_unregister_service(runner->Context(), |
| instance_name.c_str(), service_type.c_str()); |
| ASSERT_EQ(ADBD_AUTH_UNREGISTER_OK, result); |
| |
| |
| auto msg = framework.Recv(); |
| ASSERT_EQ(2 + 1 + instance_name.size() + 1 + service_type.size(), msg.size()); |
| ASSERT_EQ(msg[0], 'U'); |
| ASSERT_EQ(msg[1], 'S'); |
| |
| ASSERT_EQ(msg[2], instance_name.size()); |
| ASSERT_EQ(msg.substr(3, 2), instance_name); |
| |
| ASSERT_EQ(msg[5], service_type.size()); |
| ASSERT_EQ(msg.substr(6,3), service_type); |
| } |
| |
| TEST(AdbAuthTest, UnregisterServiceBadInstance) { |
| AdbdAuthCallbacksV2 callbacks{}; |
| callbacks.version = 2; |
| auto runner= CreateContextRunner(callbacks); |
| Framework framework{}; |
| |
| std::string instance_name(256, 'I'); |
| std::string service_type= "t_p"; |
| auto result = adbd_auth_unregister_service(runner->Context(), |
| instance_name.c_str(), service_type.c_str()); |
| ASSERT_EQ(ADBD_AUTH_UNREGISTER_BAD_NAME, result); |
| } |
| |
| TEST(AdbAuthTest, UnregisterServiceBadService) { |
| AdbdAuthCallbacksV2 callbacks{}; |
| callbacks.version = 2; |
| auto runner= CreateContextRunner(callbacks); |
| Framework framework{}; |
| |
| std::string instance_name = "in"; |
| std::string service_type(256, 'S'); |
| auto result = adbd_auth_unregister_service(runner->Context(), |
| instance_name.c_str(), service_type.c_str()); |
| ASSERT_EQ(ADBD_AUTH_UNREGISTER_BAD_NAME, result); |
| } |
| |
| TEST(AdbAuthTest, FrameworkConnectedCallback) { |
| AdbdAuthCallbacksV3 callbacks{}; |
| |
| static bool framework_detected = false; |
| callbacks.on_framework_connected = [] { |
| framework_detected = true; |
| }; |
| callbacks.version = 3; |
| auto runner= CreateContextRunner(callbacks); |
| |
| // Check that the framework connection has not been detected. |
| ASSERT_FALSE(framework_detected); |
| |
| // Connect the framework |
| Framework framework{}; |
| framework.SendAndWaitContext("XX", runner.get()); |
| |
| // Check that the framework connection has been detected. |
| ASSERT_TRUE(framework_detected); |
| } |