| /* |
| * Copyright (C) 2019 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
#include "adbwifi/pairing/pairing_server.h"

#include <fcntl.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <sys/socket.h>
#include <unistd.h>

#include <atomic>
#include <condition_variable>
#include <deque>
#include <iomanip>
#include <memory>
#include <mutex>
#include <optional>
#include <sstream>
#include <string>
#include <thread>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#include <adbwifi/pairing/pairing_connection.h>
#include <android-base/logging.h>
#include <android-base/parsenetaddress.h>
#include <android-base/thread_annotations.h>
#include <android-base/unique_fd.h>
#include <cutils/sockets.h>
| |
| namespace adbwifi { |
| namespace pairing { |
| |
| using android::base::ScopedLockAssertion; |
| using android::base::unique_fd; |
| |
| namespace { |
| |
| // The implimentation has two background threads running: one to handle and |
| // accept any new pairing connection requests (socket accept), and the other to |
| // handle connection events (connection started, connection finished). |
| class PairingServerImpl : public PairingServer { |
| public: |
| virtual ~PairingServerImpl(); |
| |
| // All parameters must be non-empty. |
| explicit PairingServerImpl(const Data& pswd, const PeerInfo& peer_info, const Data& cert, |
| const Data& priv_key, int port); |
| |
| // Starts the pairing server. This call is non-blocking. Upon completion, |
| // if the pairing was successful, then |cb| will be called with the PublicKeyHeader |
| // containing the info of the trusted peer. Otherwise, |cb| will be |
| // called with an empty value. Start can only be called once in the lifetime |
| // of this object. |
| // |
| // Returns true if PairingServer was successfully started. Otherwise, |
| // returns false. |
| virtual bool start(PairingConnection::ResultCallback cb, void* opaque) override; |
| |
| private: |
| // Setup the server socket to accept incoming connections |
| bool setupServer(); |
| // Force stop the server thread. |
| void stopServer(); |
| |
| // handles a new pairing client connection |
| bool handleNewClientConnection(int fd) EXCLUDES(conn_mutex_); |
| |
| // ======== connection events thread ============= |
| std::mutex conn_mutex_; |
| std::condition_variable conn_cv_; |
| |
| using FdVal = int; |
| using ConnectionPtr = std::unique_ptr<PairingConnection>; |
| using NewConnectionEvent = std::tuple<unique_fd, ConnectionPtr>; |
| // <fd, PeerInfo.name, PeerInfo.guid, certificate> |
| using ConnectionFinishedEvent = std::tuple<FdVal, std::optional<std::string>, |
| std::optional<std::string>, std::optional<Data>>; |
| using ConnectionEvent = std::variant<NewConnectionEvent, ConnectionFinishedEvent>; |
| // Queue for connections to write into. We have a separate queue to read |
| // from, in order to minimize the time the server thread is blocked. |
| std::deque<ConnectionEvent> conn_write_queue_ GUARDED_BY(conn_mutex_); |
| std::deque<ConnectionEvent> conn_read_queue_; |
| // Map of fds to their PairingConnections currently running. |
| std::unordered_map<FdVal, ConnectionPtr> connections_; |
| |
| // Two threads launched when starting the pairing server: |
| // 1) A server thread that waits for incoming client connections, and |
| // 2) A connection events thread that synchonizes events from all of the |
| // clients, since each PairingConnection is running in it's own thread. |
| void startConnectionEventsThread(); |
| void startServerThread(); |
| |
| std::thread conn_events_thread_; |
| void connectionEventsWorker(); |
| std::thread server_thread_; |
| void serverWorker(); |
| bool is_terminate_ GUARDED_BY(conn_mutex_) = false; |
| |
| enum class State { |
| Ready, |
| Running, |
| Stopped, |
| }; |
| State state_ = State::Ready; |
| Data pswd_; |
| PeerInfo peer_info_; |
| Data cert_; |
| Data priv_key_; |
| int port_ = -1; |
| |
| PairingConnection::ResultCallback cb_; |
| void* opaque_ = nullptr; |
| bool got_valid_pairing_ = false; |
| |
| static const int kEpollConstSocket = 0; |
| // Used to break the server thread from epoll_wait |
| static const int kEpollConstEventFd = 1; |
| unique_fd epoll_fd_; |
| unique_fd server_fd_; |
| unique_fd event_fd_; |
| }; // PairingServerImpl |
| |
| PairingServerImpl::PairingServerImpl(const Data& pswd, const PeerInfo& peer_info, const Data& cert, |
| const Data& priv_key, int port) |
| : pswd_(pswd), peer_info_(peer_info), cert_(cert), priv_key_(priv_key), port_(port) { |
| CHECK(!pswd_.empty() && !cert_.empty() && !priv_key_.empty() && port_ > 0); |
| CHECK('\0' == peer_info.name[kPeerNameLength - 1] && |
| '\0' == peer_info.guid[kPeerGuidLength - 1] && strlen(peer_info.name) > 0 && |
| strlen(peer_info.guid) > 0); |
| } |
| |
| PairingServerImpl::~PairingServerImpl() { |
| // Since these connections have references to us, let's make sure they |
| // destruct before us. |
| if (server_thread_.joinable()) { |
| stopServer(); |
| server_thread_.join(); |
| } |
| |
| { |
| std::lock_guard<std::mutex> lock(conn_mutex_); |
| is_terminate_ = true; |
| } |
| conn_cv_.notify_one(); |
| if (conn_events_thread_.joinable()) { |
| conn_events_thread_.join(); |
| } |
| |
| // Notify the cb_ if it hasn't already. |
| if (!got_valid_pairing_ && cb_ != nullptr) { |
| cb_(nullptr, nullptr, opaque_); |
| } |
| } |
| |
| bool PairingServerImpl::start(PairingConnection::ResultCallback cb, void* opaque) { |
| cb_ = cb; |
| opaque_ = opaque; |
| |
| if (state_ != State::Ready) { |
| LOG(ERROR) << "PairingServer already running or stopped"; |
| return false; |
| } |
| |
| if (!setupServer()) { |
| LOG(ERROR) << "Unable to start PairingServer"; |
| state_ = State::Stopped; |
| return false; |
| } |
| |
| state_ = State::Running; |
| return true; |
| } |
| |
| void PairingServerImpl::stopServer() { |
| if (event_fd_.get() == -1) { |
| return; |
| } |
| uint64_t value = 1; |
| ssize_t rc = write(event_fd_.get(), &value, sizeof(value)); |
| if (rc == -1) { |
| // This can happen if the server didn't start. |
| PLOG(ERROR) << "write to eventfd failed"; |
| } else if (rc != sizeof(value)) { |
| LOG(FATAL) << "write to event returned short (" << rc << ")"; |
| } |
| } |
| |
| bool PairingServerImpl::setupServer() { |
| epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC)); |
| if (epoll_fd_ == -1) { |
| PLOG(ERROR) << "failed to create epoll fd"; |
| return false; |
| } |
| |
| event_fd_.reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK)); |
| if (event_fd_ == -1) { |
| PLOG(ERROR) << "failed to create eventfd"; |
| return false; |
| } |
| |
| server_fd_.reset(socket_inaddr_any_server(port_, SOCK_STREAM)); |
| if (server_fd_.get() == -1) { |
| PLOG(ERROR) << "Failed to start pairing connection server"; |
| return false; |
| } else if (fcntl(server_fd_.get(), F_SETFD, FD_CLOEXEC) != 0) { |
| PLOG(ERROR) << "Failed to make server socket cloexec"; |
| return false; |
| } else if (fcntl(server_fd_.get(), F_SETFD, O_NONBLOCK) != 0) { |
| PLOG(ERROR) << "Failed to make server socket nonblocking"; |
| return false; |
| } |
| |
| startConnectionEventsThread(); |
| startServerThread(); |
| return true; |
| } |
| |
| void PairingServerImpl::startServerThread() { |
| server_thread_ = std::thread([this]() { serverWorker(); }); |
| } |
| |
| void PairingServerImpl::startConnectionEventsThread() { |
| conn_events_thread_ = std::thread([this]() { connectionEventsWorker(); }); |
| } |
| |
| void PairingServerImpl::serverWorker() { |
| { |
| struct epoll_event event; |
| event.events = EPOLLIN; |
| event.data.u64 = kEpollConstSocket; |
| CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, server_fd_.get(), &event)); |
| } |
| |
| { |
| struct epoll_event event; |
| event.events = EPOLLIN; |
| event.data.u64 = kEpollConstEventFd; |
| CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, event_fd_.get(), &event)); |
| } |
| |
| while (true) { |
| struct epoll_event events[2]; |
| int rc = TEMP_FAILURE_RETRY(epoll_wait(epoll_fd_.get(), events, 2, -1)); |
| if (rc == -1) { |
| PLOG(ERROR) << "epoll_wait failed"; |
| return; |
| } else if (rc == 0) { |
| LOG(ERROR) << "epoll_wait returned 0"; |
| return; |
| } |
| |
| for (int i = 0; i < rc; ++i) { |
| struct epoll_event& event = events[i]; |
| switch (event.data.u64) { |
| case kEpollConstSocket: |
| handleNewClientConnection(server_fd_.get()); |
| break; |
| case kEpollConstEventFd: |
| uint64_t dummy; |
| int rc = TEMP_FAILURE_RETRY(read(event_fd_.get(), &dummy, sizeof(dummy))); |
| if (rc != sizeof(dummy)) { |
| PLOG(FATAL) << "failed to read from eventfd (rc=" << rc << ")"; |
| } |
| return; |
| } |
| } |
| } |
| } |
| |
| void PairingServerImpl::connectionEventsWorker() { |
| for (;;) { |
| // Transfer the write queue to the read queue. |
| { |
| std::unique_lock<std::mutex> lock(conn_mutex_); |
| ScopedLockAssertion assume_locked(conn_mutex_); |
| |
| if (is_terminate_) { |
| // We check |is_terminate_| twice because condition_variable's |
| // notify() only wakes up a thread if it is in the wait state |
| // prior to notify(). Furthermore, we aren't holding the mutex |
| // when processing the events in |conn_read_queue_|. |
| return; |
| } |
| if (conn_write_queue_.empty()) { |
| // We need to wait for new events, or the termination signal. |
| conn_cv_.wait(lock, [this]() REQUIRES(conn_mutex_) { |
| return (is_terminate_ || !conn_write_queue_.empty()); |
| }); |
| } |
| if (is_terminate_) { |
| // We're done. |
| return; |
| } |
| // Move all events into the read queue. |
| conn_read_queue_ = std::move(conn_write_queue_); |
| conn_write_queue_.clear(); |
| } |
| |
| // Process all events in the read queue. |
| while (conn_read_queue_.size() > 0) { |
| auto& event = conn_read_queue_.front(); |
| if (auto* p = std::get_if<NewConnectionEvent>(&event)) { |
| // Ignore if we are already at the max number of connections |
| if (connections_.size() >= internal::kMaxConnections) { |
| conn_read_queue_.pop_front(); |
| continue; |
| } |
| auto [ufd, connection] = std::move(*p); |
| int fd = ufd.release(); |
| bool started = connection->start( |
| fd, |
| [fd](const PeerInfo* peer_info, const Data* cert, void* opaque) { |
| auto* p = reinterpret_cast<PairingServerImpl*>(opaque); |
| |
| ConnectionFinishedEvent event; |
| if (peer_info != nullptr && cert != nullptr) { |
| event = std::make_tuple(fd, std::string(peer_info->name), |
| std::string(peer_info->guid), Data(*cert)); |
| } else { |
| event = std::make_tuple(fd, std::nullopt, std::nullopt, |
| std::nullopt); |
| } |
| { |
| std::lock_guard<std::mutex> lock(p->conn_mutex_); |
| p->conn_write_queue_.push_back(std::move(event)); |
| } |
| p->conn_cv_.notify_one(); |
| }, |
| this); |
| if (!started) { |
| LOG(ERROR) << "PairingServer unable to start a PairingConnection fd=" << fd; |
| ufd.reset(fd); |
| } else { |
| connections_[fd] = std::move(connection); |
| } |
| } else if (auto* p = std::get_if<ConnectionFinishedEvent>(&event)) { |
| auto [fd, name, guid, cert] = std::move(*p); |
| if (name.has_value() && guid.has_value() && cert.has_value() && !name->empty() && |
| !guid->empty() && !cert->empty()) { |
| // Valid pairing. Let's shutdown the server and close any |
| // pairing connections in progress. |
| stopServer(); |
| connections_.clear(); |
| |
| CHECK_LE(name->size(), kPeerNameLength); |
| CHECK_LE(guid->size(), kPeerGuidLength); |
| PeerInfo info = {}; |
| strncpy(info.name, name->data(), name->size()); |
| strncpy(info.guid, guid->data(), guid->size()); |
| |
| cb_(&info, &*cert, opaque_); |
| |
| got_valid_pairing_ = true; |
| return; |
| } |
| // Invalid pairing. Close the invalid connection. |
| if (connections_.find(fd) != connections_.end()) { |
| connections_.erase(fd); |
| } |
| } |
| conn_read_queue_.pop_front(); |
| } |
| } |
| } |
| |
| bool PairingServerImpl::handleNewClientConnection(int fd) { |
| unique_fd ufd(TEMP_FAILURE_RETRY(accept4(fd, nullptr, nullptr, SOCK_CLOEXEC))); |
| if (ufd == -1) { |
| PLOG(WARNING) << "adb_socket_accept failed fd=" << fd; |
| return false; |
| } |
| auto connection = PairingConnection::create(PairingConnection::Role::Server, pswd_, peer_info_, |
| cert_, priv_key_); |
| if (connection == nullptr) { |
| LOG(ERROR) << "PairingServer unable to create a PairingConnection fd=" << fd; |
| return false; |
| } |
| // send the new connection to the connection thread for further processing |
| NewConnectionEvent event = std::make_tuple(std::move(ufd), std::move(connection)); |
| { |
| std::lock_guard<std::mutex> lock(conn_mutex_); |
| conn_write_queue_.push_back(std::move(event)); |
| } |
| conn_cv_.notify_one(); |
| |
| return true; |
| } |
| |
| } // namespace |
| |
| // static |
| std::unique_ptr<PairingServer> PairingServer::create(const Data& pswd, const PeerInfo& peer_info, |
| const Data& cert, const Data& priv_key, |
| int port) { |
| if (pswd.empty() || cert.empty() || priv_key.empty() || port <= 0) { |
| return nullptr; |
| } |
| // Make sure peer_info has a non-empty, null-terminated string for guid and |
| // name. |
| if ('\0' != peer_info.name[kPeerNameLength - 1] || |
| '\0' != peer_info.guid[kPeerGuidLength - 1] || strlen(peer_info.name) == 0 || |
| strlen(peer_info.guid) == 0) { |
| LOG(ERROR) << "The GUID/short name fields are empty or not null-terminated"; |
| return nullptr; |
| } |
| |
| if (port != kDefaultPairingPort) { |
| LOG(WARNING) << "Starting server with non-default pairing port=" << port; |
| } |
| |
| return std::unique_ptr<PairingServer>( |
| new PairingServerImpl(pswd, peer_info, cert, priv_key, port)); |
| } |
| |
| } // namespace pairing |
| } // namespace adbwifi |