Establish WebRTC media transport over TCP

Add "?use_tcp=true" to the URL options to use it.

Bug: 151278089
Change-Id: If813ba419795a4767e0571866b140f0812f3b5c5
Merged-In: If813ba419795a4767e0571866b140f0812f3b5c5
diff --git a/host/frontend/gcastv2/webrtc/MyWebSocketHandler.cpp b/host/frontend/gcastv2/webrtc/MyWebSocketHandler.cpp
index eabe180..60fcfb1 100644
--- a/host/frontend/gcastv2/webrtc/MyWebSocketHandler.cpp
+++ b/host/frontend/gcastv2/webrtc/MyWebSocketHandler.cpp
@@ -180,7 +180,9 @@
 "a=rtcp-fb:96 nack pli\r\n";
 
         ss <<
-"m=video 9 UDP/TLS/RTP/SAVPF 96 97\r\n"
+"m=video 9 "
+<< ((mOptions & OptionBits::useTCP) ? "TCP" : "UDP")
+<< "/TLS/RTP/SAVPF 96 97\r\n"
 "c=IN IP4 0.0.0.0\r\n"
 "a=rtcp:9 IN IP4 0.0.0.0\r\n";
 
@@ -211,7 +213,9 @@
 
         if (!(mOptions & OptionBits::disableAudio)) {
             ss <<
-"m=audio 9 UDP/TLS/RTP/SAVPF 98\r\n"
+"m=audio 9 "
+<< ((mOptions & OptionBits::useTCP) ? "TCP" : "UDP")
+<< "/TLS/RTP/SAVPF 98\r\n"
 "c=IN IP4 0.0.0.0\r\n"
 "a=rtcp:9 IN IP4 0.0.0.0\r\n";
 
@@ -237,7 +241,9 @@
 
         if (mOptions & OptionBits::enableData) {
             ss <<
-"m=application 9 UDP/DTLS/SCTP webrtc-datachannel\r\n"
+"m=application 9 "
+<< ((mOptions & OptionBits::useTCP) ? "TCP" : "UDP")
+<< "/DTLS/SCTP webrtc-datachannel\r\n"
 "c=IN IP4 0.0.0.0\r\n"
 "a=sctp-port:5000\r\n";
 
@@ -411,6 +417,9 @@
         auto rtp = std::make_shared<RTPSocketHandler>(
                 mRunLoop,
                 mServerState,
+                (mOptions & OptionBits::useTCP)
+                    ? RTPSocketHandler::TransportType::TCP
+                    : RTPSocketHandler::TransportType::UDP,
                 PF_INET,
                 trackMask,
                 session);
@@ -427,15 +436,25 @@
 
     auto localIPString = rtp->getLocalIPString();
 
-    // see rfc8445, 5.1.2.1. for the derivation of "2122121471" below.
-    reply["candidate"] =
-                "candidate:0 1 UDP 2122121471 "
-                + localIPString
-                + " "
-                + std::to_string(rtp->getLocalPort())
-                + " typ host generation 0 ufrag "
-                + rtp->getLocalUFrag();
+    std::stringstream ss;
+    ss << "candidate:0 1 ";
 
+    if (mOptions & OptionBits::useTCP) {
+        ss << "tcp";
+    } else {
+        ss << "UDP";
+    }
+
+    // see rfc8445, 5.1.2.1. for the derivation of "2122121471" below.
+    ss << " 2122121471 " << localIPString << " " << rtp->getLocalPort() << " typ host ";
+
+    if (mOptions & OptionBits::useTCP) {
+        ss << "tcptype passive ";
+    }
+
+    ss << "generation 0 ufrag " << rtp->getLocalUFrag();
+
+    reply["candidate"] = ss.str();
     reply["mlineIndex"] = static_cast<Json::UInt64>(mlineIndex);
 
     Json::FastWriter json_writer;
@@ -609,6 +628,9 @@
         } else if (name == "enable_data" && boolValue) {
             auto mask = OptionBits::enableData;
             mOptions = (mOptions & ~mask) | (boolValue ? mask : 0);
+        } else if (name == "use_tcp" && boolValue) {
+            auto mask = OptionBits::useTCP;
+            mOptions = (mOptions & ~mask) | (boolValue ? mask : 0);
         }
     }
 }
diff --git a/host/frontend/gcastv2/webrtc/RTPSocketHandler.cpp b/host/frontend/gcastv2/webrtc/RTPSocketHandler.cpp
index ae18a2f..3885039 100644
--- a/host/frontend/gcastv2/webrtc/RTPSocketHandler.cpp
+++ b/host/frontend/gcastv2/webrtc/RTPSocketHandler.cpp
@@ -39,6 +39,7 @@
 // These are the ports we currently open in the firewall (15550..15557)
 static constexpr int kPortRangeBegin = 15550;
 static constexpr int kPortRangeEnd = 15558;
+static constexpr int kPortRangeEndTcp = 15551;
 
 static socklen_t getSockAddrLen(const sockaddr_storage &addr) {
     switch (addr.ss_family) {
@@ -52,7 +53,7 @@
     }
 }
 
-static int acquirePort(int sockfd, int domain) {
+static int acquirePort(int sockfd, int domain, bool tcp) {
     sockaddr_storage addr;
     uint16_t* port_ptr;
 
@@ -85,6 +86,10 @@
         if (errno != EADDRINUSE) {
             return -1;
         }
+        // for now, limit to one client / one tcp port to minimize
+        // complexity for using WebRTC over TCP over ssh tunnels
+        if (tcp && port == kPortRangeEndTcp)
+            break;
         // else try the next port
     }
 
@@ -94,24 +99,46 @@
 RTPSocketHandler::RTPSocketHandler(
         std::shared_ptr<RunLoop> runLoop,
         std::shared_ptr<ServerState> serverState,
+        TransportType transportType,
         int domain,
         uint32_t trackMask,
         std::shared_ptr<RTPSession> session)
     : mRunLoop(runLoop),
       mServerState(serverState),
+      mTransportType(transportType),
       mTrackMask(trackMask),
       mSession(session),
       mSendPending(false),
-      mDTLSConnected(false) {
-    int sock = socket(domain, SOCK_DGRAM, 0);
+      mDTLSConnected(false),
+      mInBufferLength(0) {
+    bool tcp = mTransportType == TransportType::TCP;
+
+    int sock = socket(domain, tcp ? SOCK_STREAM : SOCK_DGRAM, 0);
+
+    if (tcp) {
+        static constexpr int yes = 1;
+        auto res = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes));
+        CHECK(!res);
+    }
 
     makeFdNonblocking(sock);
-    mSocket = std::make_shared<PlainSocket>(mRunLoop, sock);
 
-    mLocalPort = acquirePort(sock, domain);
+    mLocalPort = acquirePort(sock, domain, tcp);
 
     CHECK(mLocalPort > 0);
 
+    if (tcp) {
+        auto res = listen(sock, 4);
+        CHECK(!res);
+    }
+
+    auto tmp = std::make_shared<PlainSocket>(mRunLoop, sock);
+    if (tcp) {
+        mServerSocket = tmp;
+    } else {
+        mSocket = tmp;
+    }
+
     auto videoPacketizer =
         (trackMask & TRACK_VIDEO)
             ? mServerState->getVideoPacketizer() : nullptr;
@@ -151,7 +178,109 @@
 }
 
 void RTPSocketHandler::run() {
-    mSocket->postRecv(makeSafeCallback(this, &RTPSocketHandler::onReceive));
+    // In TCP mode the only socket that exists at this point is the listening
+    // one, so wait for a client to connect; in UDP mode start receiving
+    // datagrams right away.
+    if (mTransportType == TransportType::TCP) {
+        mServerSocket->postRecv(
+                makeSafeCallback(this, &RTPSocketHandler::onTCPConnect));
+    } else {
+        mSocket->postRecv(makeSafeCallback(this, &RTPSocketHandler::onReceive));
+    }
+}
+
+// Called when the listening socket is readable: accepts the client, records
+// its address for later use as the packet source, and starts reading framed
+// packets from the new connection.
+void RTPSocketHandler::onTCPConnect() {
+    int sock = accept(mServerSocket->fd(), nullptr, nullptr);
+
+    if (sock < 0) {
+        LOG(ERROR) << "RTPSocketHandler: Failed to accept client";
+        // mSocket is only assigned once a client has been accepted, so the
+        // retry must be armed on the listening socket.
+        mServerSocket->postRecv(
+                makeSafeCallback(this, &RTPSocketHandler::onTCPConnect));
+        return;
+    }
+
+    LOG(INFO) << "RTPSocketHandler: Accepted client";
+
+    makeFdNonblocking(sock);
+
+    mClientAddrLen = sizeof(mClientAddr);
+
+    int res = getpeername(
+            sock, reinterpret_cast<sockaddr *>(&mClientAddr), &mClientAddrLen);
+
+    CHECK(!res);
+
+    mSocket = std::make_shared<PlainSocket>(mRunLoop, sock);
+
+    mSocket->postRecv(makeSafeCallback(this, &RTPSocketHandler::onTCPReceive));
+}
+
+// Called when the client connection is readable. The stream carries packets
+// prefixed with a 16-bit length (read via U16_AT); every complete packet is
+// handed to onPacketReceived() and any trailing partial packet stays
+// buffered until more data arrives.
+void RTPSocketHandler::onTCPReceive() {
+    static constexpr size_t kReadChunkSize = 8192;
+    static constexpr size_t kFrameHeaderSize = 2;
+
+    // Make room for one more chunk behind the bytes that are still buffered.
+    // Sizing relative to mInBufferLength (not the current size) keeps the
+    // vector from growing on every call.
+    mInBuffer.resize(mInBufferLength + kReadChunkSize);
+
+    auto n = mSocket->recv(
+            mInBuffer.data() + mInBufferLength, mInBuffer.size() - mInBufferLength);
+
+    if (n == 0) {
+        LOG(INFO) << "Client disconnected.";
+        return;
+    }
+
+    if (n < 0) {
+        if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK) {
+            // Nothing was read this time around; wait for the next event.
+            mSocket->postRecv(
+                    makeSafeCallback(this, &RTPSocketHandler::onTCPReceive));
+            return;
+        }
+
+        // A negative count must never be added to mInBufferLength.
+        LOG(ERROR) << "RTPSocketHandler: recv failed, errno " << errno;
+        return;
+    }
+
+    mInBufferLength += static_cast<size_t>(n);
+
+    size_t offset = 0;
+    while (mInBufferLength - offset >= kFrameHeaderSize) {
+        const size_t packetLength = U16_AT(mInBuffer.data() + offset);
+
+        if (mInBufferLength - offset - kFrameHeaderSize < packetLength) {
+            // Incomplete packet: leave its length prefix in the buffer so
+            // that parsing resumes at the frame boundary next time.
+            break;
+        }
+
+        onPacketReceived(
+                mClientAddr,
+                mClientAddrLen,
+                mInBuffer.data() + offset + kFrameHeaderSize,
+                packetLength);
+
+        offset += kFrameHeaderSize + packetLength;
+    }
+
+    if (offset > 0) {
+        mInBuffer.erase(mInBuffer.begin(), mInBuffer.begin() + offset);
+        mInBufferLength -= offset;
+    }
+
+    mSocket->postRecv(makeSafeCallback(this, &RTPSocketHandler::onTCPReceive));
 }
 
 void RTPSocketHandler::onReceive() {
@@ -165,6 +261,22 @@
     auto n = mSocket->recvfrom(
             data, buffer.size(), reinterpret_cast<sockaddr *>(&addr), &addrLen);
 
+    onPacketReceived(addr, addrLen, data, n);
+
+    mSocket->postRecv(makeSafeCallback(this, &RTPSocketHandler::onReceive));
+}
+
+void RTPSocketHandler::onPacketReceived(
+        const sockaddr_storage &addr,
+        socklen_t addrLen,
+        uint8_t *data,
+        size_t n) {
+#if 0
+    std::cout << "========================================" << std::endl;
+
+    hexdump(data, n);
+#endif
+
     STUNMessage msg(data, n);
     if (!msg.isValid()) {
         if (mDTLSConnected) {
@@ -203,7 +315,6 @@
             onDTLSReceive(data, static_cast<size_t>(n));
         }
 
-        run();
         return;
     }
 
@@ -212,7 +323,6 @@
 
         if (!matchesSession(msg)) {
             LOG(WARNING) << "Unknown session or no USERNAME.";
-            run();
             return;
         }
 
@@ -299,15 +409,7 @@
 
         // response.dump(answerPassword);
 
-        auto res =
-            mSocket->sendto(
-                    response.data(),
-                    response.size(),
-                    reinterpret_cast<const sockaddr *>(&addr),
-                    addrLen);
-
-        CHECK_GT(res, 0);
-        CHECK_EQ(static_cast<size_t>(res), response.size());
+        queueDatagram(addr, response.data(), response.size());
 
         if (!mSession->isActive()) {
             mSession->setRemoteAddress(addr);
@@ -336,8 +438,6 @@
             mDTLS->connect(mSession->remoteAddress());
         }
     }
-
-    run();
 }
 
 bool RTPSocketHandler::matchesSession(const STUNMessage &msg) const {
@@ -434,6 +534,22 @@
 
 void RTPSocketHandler::queueDatagram(
         const sockaddr_storage &addr, const void *data, size_t size) {
+    if (mTransportType == TransportType::TCP) {
+        std::vector copy(
+                static_cast<const uint8_t *>(data),
+                static_cast<const uint8_t *>(data) + size);
+
+        mRunLoop->post(
+                makeSafeCallback<RTPSocketHandler>(
+                    this,
+                    [copy](RTPSocketHandler *me) {
+                        // addr is ignored and assumed to be the connected endpoint's.
+                        me->queueTCPOutputPacket(copy.data(), copy.size());
+                    }));
+
+        return;
+    }
+
     auto datagram = std::make_shared<Datagram>(addr, data, size);
 
     CHECK_LE(size, RTPSocketHandler::kMaxUDPPayloadSize);
@@ -450,6 +566,80 @@
                 }));
 }
 
+// Appends one packet to the TCP output buffer, preceded by its length as a
+// 16-bit big-endian integer, and schedules a send if none is pending.
+void RTPSocketHandler::queueTCPOutputPacket(const uint8_t *data, size_t size) {
+    // The length prefix is only 16 bits wide; anything larger would be
+    // silently truncated and desynchronize the receiver.
+    CHECK_LE(size, static_cast<size_t>(0xffff));
+
+    const uint8_t framing[2] = {
+        static_cast<uint8_t>(size >> 8),
+        static_cast<uint8_t>(size & 0xff),
+    };
+
+    mOutBuffer.insert(mOutBuffer.end(), framing, framing + sizeof(framing));
+    mOutBuffer.insert(mOutBuffer.end(), data, data + size);
+
+    if (!mSendPending) {
+        mSendPending = true;
+
+        mSocket->postSend(
+                makeSafeCallback(this, &RTPSocketHandler::sendTCPOutputData));
+    }
+}
+
+// Writes as much of the TCP output buffer as the socket accepts. Whatever
+// could not be written stays buffered and another send is scheduled; if the
+// connection is gone the buffered data is discarded.
+void RTPSocketHandler::sendTCPOutputData() {
+    mSendPending = false;
+
+    const size_t size = mOutBuffer.size();
+    size_t offset = 0;
+
+    bool disconnected = false;
+
+    while (offset < size) {
+        auto n = mSocket->send(mOutBuffer.data() + offset, size - offset);
+
+        if (n < 0) {
+            if (errno == EINTR) {
+                continue;
+            }
+
+            if (errno == EAGAIN || errno == EWOULDBLOCK) {
+                // The accepted socket is non-blocking, so a full send
+                // buffer is expected under load: keep the unsent tail and
+                // wait until the socket is writable again.
+                break;
+            }
+
+            // Any other failure (e.g. the peer reset the connection) is
+            // handled like a disconnect instead of aborting the process.
+            LOG(ERROR) << "RTPSocketHandler: send failed, errno " << errno;
+            offset = size;
+            disconnected = true;
+            break;
+        } else if (n == 0) {
+            offset = size;
+            disconnected = true;
+            break;
+        }
+
+        offset += static_cast<size_t>(n);
+    }
+
+    mOutBuffer.erase(mOutBuffer.begin(), mOutBuffer.begin() + offset);
+
+    if (!mOutBuffer.empty() && !disconnected) {
+        mSendPending = true;
+
+        mSocket->postSend(
+                makeSafeCallback(this, &RTPSocketHandler::sendTCPOutputData));
+    }
+}
+
 void RTPSocketHandler::scheduleDrainOutQueue() {
     CHECK(!mSendPending);
 
diff --git a/host/frontend/gcastv2/webrtc/include/webrtc/MyWebSocketHandler.h b/host/frontend/gcastv2/webrtc/include/webrtc/MyWebSocketHandler.h
index fef87cb..f22de55 100644
--- a/host/frontend/gcastv2/webrtc/include/webrtc/MyWebSocketHandler.h
+++ b/host/frontend/gcastv2/webrtc/include/webrtc/MyWebSocketHandler.h
@@ -52,6 +52,7 @@
         bundleTracks                        = 2,
         enableData                          = 4,
         useSingleCertificateForAllTracks    = 8,
+        useTCP                              = 16,
     };
 
     using TouchSink = android::TouchSink;
diff --git a/host/frontend/gcastv2/webrtc/include/webrtc/RTPSocketHandler.h b/host/frontend/gcastv2/webrtc/include/webrtc/RTPSocketHandler.h
index b0f8ff8..f1dc4a2 100644
--- a/host/frontend/gcastv2/webrtc/include/webrtc/RTPSocketHandler.h
+++ b/host/frontend/gcastv2/webrtc/include/webrtc/RTPSocketHandler.h
@@ -16,7 +16,7 @@
 
 #pragma once
 
-#include <https/BufferedSocket.h>
+#include <https/PlainSocket.h>
 #include <https/RunLoop.h>
 #include <webrtc/DTLS.h>
 #include <webrtc/RTPSender.h>
@@ -40,9 +40,15 @@
     static constexpr uint32_t TRACK_AUDIO = 2;
     static constexpr uint32_t TRACK_DATA  = 4;
 
+    enum class TransportType {
+        UDP,
+        TCP,
+    };
+
     explicit RTPSocketHandler(
             std::shared_ptr<RunLoop> runLoop,
             std::shared_ptr<ServerState> serverState,
+            TransportType type,
             int domain,
             uint32_t trackMask,
             std::shared_ptr<RTPSession> session);
@@ -78,6 +84,7 @@
 
     std::shared_ptr<RunLoop> mRunLoop;
     std::shared_ptr<ServerState> mServerState;
+    TransportType mTransportType;
     uint16_t mLocalPort;
     uint32_t mTrackMask;
     std::shared_ptr<RTPSession> mSession;
@@ -92,6 +99,16 @@
 
     std::shared_ptr<RTPSender> mRTPSender;
 
+    // for TransportType TCP:
+    std::shared_ptr<PlainSocket> mServerSocket;
+    sockaddr_storage mClientAddr;
+    socklen_t mClientAddrLen;
+
+    std::vector<uint8_t> mInBuffer;
+    size_t mInBufferLength;
+
+    std::vector<uint8_t> mOutBuffer;
+
     void onReceive();
     void onDTLSReceive(const uint8_t *data, size_t size);
 
@@ -103,6 +120,18 @@
     void drainOutQueue();
 
     int onSRTPReceive(uint8_t *data, size_t size);
+
+    void onTCPConnect();
+    void onTCPReceive();
+
+    void onPacketReceived(
+            const sockaddr_storage &addr,
+            socklen_t addrLen,
+            uint8_t *data,
+            size_t size);
+
+    void queueTCPOutputPacket(const uint8_t *data, size_t size);
+    void sendTCPOutputData();
 };