Resolve public IP automatically with STUN server

Optionally, if FLAGS_public_ip is empty or "0.0.0.0", contact an
external STUN server to resolve our public IP.

This defaults to stun.l.google.com:19302 but can be pointed at other
servers via FLAGS_stun_server.

Bug: 147509789
Change-Id: Ib7d884ea7c51d5f9401727361b64bc84f8e5d33a
diff --git a/host/frontend/gcastv2/webrtc/Android.bp b/host/frontend/gcastv2/webrtc/Android.bp
index 85e0644..e860461 100644
--- a/host/frontend/gcastv2/webrtc/Android.bp
+++ b/host/frontend/gcastv2/webrtc/Android.bp
@@ -29,6 +29,7 @@
         "SCTPHandler.cpp",
         "SDP.cpp",
         "ServerState.cpp",
+        "STUNClient.cpp",
         "STUNMessage.cpp",
         "Utils.cpp",
         "VP8Packetizer.cpp",
diff --git a/host/frontend/gcastv2/webrtc/STUNClient.cpp b/host/frontend/gcastv2/webrtc/STUNClient.cpp
new file mode 100644
index 0000000..d29c715
--- /dev/null
+++ b/host/frontend/gcastv2/webrtc/STUNClient.cpp
@@ -0,0 +1,175 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Utils.h"
+
+#include <webrtc/STUNClient.h>
+#include <webrtc/STUNMessage.h>
+
+#include <https/SafeCallbackable.h>
+#include <https/Support.h>
+
+#include <android-base/logging.h>
+
+STUNClient::STUNClient(
+        std::shared_ptr<RunLoop> runLoop,
+        const sockaddr_in &addr,
+        Callback cb)
+    : mRunLoop(runLoop),
+      mRemoteAddr(addr),
+      mCallback(cb),
+      mTimeoutToken(0),
+      mNumRetriesLeft(kMaxNumRetries) {
+
+    int sock = socket(PF_INET, SOCK_DGRAM, 0);
+    makeFdNonblocking(sock);
+
+    sockaddr_in addrV4;
+    memset(addrV4.sin_zero, 0, sizeof(addrV4.sin_zero));
+    addrV4.sin_family = AF_INET;
+    addrV4.sin_port = 0;
+    addrV4.sin_addr.s_addr = INADDR_ANY;
+
+    int res = bind(
+            sock,
+            reinterpret_cast<const sockaddr *>(&addrV4),
+            sizeof(addrV4));
+
+    CHECK(!res);
+
+    sockaddr_in tmp;
+    socklen_t tmpLen = sizeof(tmp);
+
+    res = getsockname(sock, reinterpret_cast<sockaddr *>(&tmp), &tmpLen);
+    CHECK(!res);
+
+    LOG(VERBOSE) << "local port: " << ntohs(tmp.sin_port);
+
+    mSocket = std::make_shared<PlainSocket>(mRunLoop, sock);
+}
+
+void STUNClient::run() {
+    LOG(VERBOSE) << "STUNClient::run()";
+
+    scheduleRequest();
+}
+
+void STUNClient::onSendRequest() {
+    LOG(VERBOSE) << "STUNClient::onSendRequest";
+
+    std::vector<uint8_t> transactionID { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 };
+
+    STUNMessage msg(0x0001 /* Binding Request */, transactionID.data());
+    msg.addFingerprint();
+
+    ssize_t n;
+
+    do {
+        n = sendto(
+            mSocket->fd(),
+            msg.data(),
+            msg.size(),
+            0 /* flags */,
+            reinterpret_cast<const sockaddr *>(&mRemoteAddr),
+            sizeof(mRemoteAddr));
+
+    } while (n < 0 && errno == EINTR);
+
+    CHECK_GT(n, 0);
+
+    LOG(VERBOSE) << "Sent BIND request, awaiting response";
+
+    mSocket->postRecv(
+            makeSafeCallback(this, &STUNClient::onReceiveResponse));
+}
+
+void STUNClient::onReceiveResponse() {
+    LOG(VERBOSE) << "Received STUN response";
+
+    std::vector<uint8_t> buffer(kMaxUDPPayloadSize);
+
+    uint8_t *data = buffer.data();
+
+    sockaddr_storage addr;
+    socklen_t addrLen = sizeof(addr);
+
+    auto n = mSocket->recvfrom(
+            data, buffer.size(), reinterpret_cast<sockaddr *>(&addr), &addrLen);
+
+    CHECK_GT(n, 0);
+
+    STUNMessage msg(data, n);
+    CHECK(msg.isValid());
+
+    // msg.dump();
+
+    if (msg.type() == 0x0101 /* Binding Response */) {
+        const uint8_t *data;
+        size_t size;
+        if (msg.findAttribute(
+                    0x0020 /* XOR-MAPPED-ADDRESS */,
+                    reinterpret_cast<const void **>(&data),
+                    &size)) {
+
+            CHECK_EQ(size, 8u);
+            CHECK_EQ(data[1], 0x01u);  // We only deal with IPv4 for now.
+
+            static constexpr uint32_t kMagicCookie = 0x2112a442;
+
+            uint16_t port = U16_AT(&data[2]) ^ (kMagicCookie >> 16);
+            uint32_t ip = U32_AT(&data[4]) ^ kMagicCookie;
+
+            LOG(VERBOSE) << "translated port: " << port;
+
+            mCallback(
+                    0 /* result */,
+                    StringPrintf(
+                        "%u.%u.%u.%u",
+                        ip >> 24,
+                        (ip >> 16) & 0xff,
+                        (ip >> 8) & 0xff,
+                        ip & 0xff));
+
+            mRunLoop->cancelToken(mTimeoutToken);
+            mTimeoutToken = 0;
+        }
+    }
+}
+
+void STUNClient::scheduleRequest() {
+    CHECK_EQ(mTimeoutToken, 0);
+
+    mSocket->postSend(
+            makeSafeCallback(this, &STUNClient::onSendRequest));
+
+    mTimeoutToken = mRunLoop->postWithDelay(
+            kTimeoutDelay,
+            makeSafeCallback(this, &STUNClient::onTimeout));
+
+}
+
+void STUNClient::onTimeout() {
+    mTimeoutToken = 0;
+
+    if (mNumRetriesLeft == 0) {
+        mCallback(-ETIMEDOUT, "");
+        return;
+    }
+
+    --mNumRetriesLeft;
+    scheduleRequest();
+}
+
diff --git a/host/frontend/gcastv2/webrtc/include/webrtc/STUNClient.h b/host/frontend/gcastv2/webrtc/include/webrtc/STUNClient.h
new file mode 100644
index 0000000..87bce72
--- /dev/null
+++ b/host/frontend/gcastv2/webrtc/include/webrtc/STUNClient.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (C) 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <https/PlainSocket.h>
+#include <https/RunLoop.h>
+#include <memory>
+
+#include <arpa/inet.h>
+
+struct STUNClient : public std::enable_shared_from_this<STUNClient> {
+    using Callback = std::function<void(int, const std::string &)>;
+
+    explicit STUNClient(
+            std::shared_ptr<RunLoop> runLoop,
+            const sockaddr_in &addr,
+            Callback cb);
+
+    void run();
+
+private:
+    static constexpr size_t kMaxUDPPayloadSize = 1536;
+    static constexpr size_t kMaxNumRetries = 5;
+
+    static constexpr std::chrono::duration kTimeoutDelay =
+        std::chrono::seconds(1);
+
+    std::shared_ptr<RunLoop> mRunLoop;
+    sockaddr_in mRemoteAddr;
+    Callback mCallback;
+
+    std::shared_ptr<PlainSocket> mSocket;
+
+    RunLoop::Token mTimeoutToken;
+    size_t mNumRetriesLeft;
+
+    void onSendRequest();
+    void onReceiveResponse();
+
+    void scheduleRequest();
+    void onTimeout();
+};
diff --git a/host/frontend/gcastv2/webrtc/webRTC.cpp b/host/frontend/gcastv2/webrtc/webRTC.cpp
index 8aaa30e..dfa621b 100644
--- a/host/frontend/gcastv2/webrtc/webRTC.cpp
+++ b/host/frontend/gcastv2/webrtc/webRTC.cpp
@@ -14,11 +14,14 @@
  * limitations under the License.
  */
 
+#include "Utils.h"
+
 #include <webrtc/AdbWebSocketHandler.h>
 #include <webrtc/DTLS.h>
 #include <webrtc/MyWebSocketHandler.h>
 #include <webrtc/RTPSocketHandler.h>
 #include <webrtc/ServerState.h>
+#include <webrtc/STUNClient.h>
 #include <webrtc/STUNMessage.h>
 
 #include <https/HTTPServer.h>
@@ -31,6 +34,8 @@
 #include <iostream>
 #include <unordered_map>
 
+#include <netdb.h>
+
 #include <gflags/gflags.h>
 
 DEFINE_int32(http_server_port, 8443, "The port for the http server.");
@@ -55,12 +60,80 @@
 
 DEFINE_string(adb, "", "Interface:port of local adb service.");
 
+DEFINE_string(
+        stun_server,
+        "stun.l.google.com:19302",
+        "host:port of STUN server to use for public address resolution");
+
 int main(int argc, char **argv) {
     ::gflags::ParseCommandLineFlags(&argc, &argv, true);
 
     SSLSocket::Init();
     DTLS::Init();
 
+    if (FLAGS_public_ip.empty() || FLAGS_public_ip == "0.0.0.0") {
+        // NOTE: We only contact the external STUN server once upon startup
+        // to determine our own public IP.
+        // This only works if NAT does not remap ports, i.e. a local port 15550
+        // is visible to the outside world on port 15550 as well.
+        // If this condition is not met, this code will have to be modified
+        // and a STUN request made for each locally bound socket before
+        // fulfilling a "MyWebSocketHandler::getCandidate" ICE request.
+
+        const addrinfo kHints = {
+            AI_ADDRCONFIG,
+            PF_INET,
+            SOCK_DGRAM,
+            IPPROTO_UDP,
+            0,  // ai_addrlen
+            nullptr,  // ai_addr
+            nullptr,  // ai_canonname
+            nullptr  // ai_next
+        };
+
+        auto pieces = SplitString(FLAGS_stun_server, ':');
+        CHECK_EQ(pieces.size(), 2u);
+
+        addrinfo *infos;
+        CHECK(!getaddrinfo(pieces[0].c_str(), pieces[1].c_str(), &kHints, &infos));
+
+        sockaddr_storage stunAddr;
+        memcpy(&stunAddr, infos->ai_addr, infos->ai_addrlen);
+
+        freeaddrinfo(infos);
+        infos = nullptr;
+
+        CHECK_EQ(stunAddr.ss_family, AF_INET);
+
+        std::mutex lock;
+        std::condition_variable cond;
+        bool done = false;
+
+        auto runLoop = std::make_shared<RunLoop>("STUN");
+
+        auto stunClient = std::make_shared<STUNClient>(
+                runLoop,
+                reinterpret_cast<const sockaddr_in &>(stunAddr),
+                [&lock, &cond, &done](int result, const std::string &myPublicIp) {
+                    CHECK(!result);
+                    LOG(INFO)
+                        << "STUN-discovered public IP: " << myPublicIp;
+
+                    FLAGS_public_ip = myPublicIp;
+
+                    std::lock_guard autoLock(lock);
+                    done = true;
+                    cond.notify_all();
+                });
+
+        stunClient->run();
+
+        std::unique_lock autoLock(lock);
+        while (!done) {
+            cond.wait(autoLock);
+        }
+    }
+
     auto runLoop = RunLoop::main();
 
     auto state = std::make_shared<ServerState>(