Establish webRTC media transport over TCP
Add "?use_tcp=true" to the URL options to use it.
Bug: 151278089
Change-Id: If813ba419795a4767e0571866b140f0812f3b5c5
Merged-In: If813ba419795a4767e0571866b140f0812f3b5c5
diff --git a/host/frontend/gcastv2/webrtc/MyWebSocketHandler.cpp b/host/frontend/gcastv2/webrtc/MyWebSocketHandler.cpp
index eabe180..60fcfb1 100644
--- a/host/frontend/gcastv2/webrtc/MyWebSocketHandler.cpp
+++ b/host/frontend/gcastv2/webrtc/MyWebSocketHandler.cpp
@@ -180,7 +180,9 @@
"a=rtcp-fb:96 nack pli\r\n";
ss <<
-"m=video 9 UDP/TLS/RTP/SAVPF 96 97\r\n"
+"m=video 9 "
+<< ((mOptions & OptionBits::useTCP) ? "TCP" : "UDP")
+<< "/TLS/RTP/SAVPF 96 97\r\n"
"c=IN IP4 0.0.0.0\r\n"
"a=rtcp:9 IN IP4 0.0.0.0\r\n";
@@ -211,7 +213,9 @@
if (!(mOptions & OptionBits::disableAudio)) {
ss <<
-"m=audio 9 UDP/TLS/RTP/SAVPF 98\r\n"
+"m=audio 9 "
+<< ((mOptions & OptionBits::useTCP) ? "TCP" : "UDP")
+<< "/TLS/RTP/SAVPF 98\r\n"
"c=IN IP4 0.0.0.0\r\n"
"a=rtcp:9 IN IP4 0.0.0.0\r\n";
@@ -237,7 +241,9 @@
if (mOptions & OptionBits::enableData) {
ss <<
-"m=application 9 UDP/DTLS/SCTP webrtc-datachannel\r\n"
+"m=application 9 "
+<< ((mOptions & OptionBits::useTCP) ? "TCP" : "UDP")
+<< "/DTLS/SCTP webrtc-datachannel\r\n"
"c=IN IP4 0.0.0.0\r\n"
"a=sctp-port:5000\r\n";
@@ -411,6 +417,9 @@
auto rtp = std::make_shared<RTPSocketHandler>(
mRunLoop,
mServerState,
+ (mOptions & OptionBits::useTCP)
+ ? RTPSocketHandler::TransportType::TCP
+ : RTPSocketHandler::TransportType::UDP,
PF_INET,
trackMask,
session);
@@ -427,15 +436,25 @@
auto localIPString = rtp->getLocalIPString();
- // see rfc8445, 5.1.2.1. for the derivation of "2122121471" below.
- reply["candidate"] =
- "candidate:0 1 UDP 2122121471 "
- + localIPString
- + " "
- + std::to_string(rtp->getLocalPort())
- + " typ host generation 0 ufrag "
- + rtp->getLocalUFrag();
+ std::stringstream ss;
+ ss << "candidate:0 1 ";
+ if (mOptions & OptionBits::useTCP) {
+ ss << "tcp";
+ } else {
+ ss << "UDP";
+ }
+
+ // see rfc8445, 5.1.2.1. for the derivation of "2122121471" below.
+ ss << " 2122121471 " << localIPString << " " << rtp->getLocalPort() << " typ host ";
+
+ if (mOptions & OptionBits::useTCP) {
+ ss << "tcptype passive ";
+ }
+
+ ss << "generation 0 ufrag " << rtp->getLocalUFrag();
+
+ reply["candidate"] = ss.str();
reply["mlineIndex"] = static_cast<Json::UInt64>(mlineIndex);
Json::FastWriter json_writer;
@@ -609,6 +628,9 @@
} else if (name == "enable_data" && boolValue) {
auto mask = OptionBits::enableData;
mOptions = (mOptions & ~mask) | (boolValue ? mask : 0);
+ } else if (name == "use_tcp" && boolValue) {
+ auto mask = OptionBits::useTCP;
+ mOptions = (mOptions & ~mask) | (boolValue ? mask : 0);
}
}
}
diff --git a/host/frontend/gcastv2/webrtc/RTPSocketHandler.cpp b/host/frontend/gcastv2/webrtc/RTPSocketHandler.cpp
index ae18a2f..3885039 100644
--- a/host/frontend/gcastv2/webrtc/RTPSocketHandler.cpp
+++ b/host/frontend/gcastv2/webrtc/RTPSocketHandler.cpp
@@ -39,6 +39,7 @@
// These are the ports we currently open in the firewall (15550..15557)
static constexpr int kPortRangeBegin = 15550;
static constexpr int kPortRangeEnd = 15558;
+static constexpr int kPortRangeEndTcp = 15551;
static socklen_t getSockAddrLen(const sockaddr_storage &addr) {
switch (addr.ss_family) {
@@ -52,7 +53,7 @@
}
}
-static int acquirePort(int sockfd, int domain) {
+static int acquirePort(int sockfd, int domain, bool tcp) {
sockaddr_storage addr;
uint16_t* port_ptr;
@@ -85,6 +86,10 @@
if (errno != EADDRINUSE) {
return -1;
}
+ // for now, limit to one client / one tcp port to minimize
+ // complexity for using WebRTC over TCP over ssh tunnels
+ if (tcp && port == kPortRangeEndTcp)
+ break;
// else try the next port
}
@@ -94,24 +99,46 @@
RTPSocketHandler::RTPSocketHandler(
std::shared_ptr<RunLoop> runLoop,
std::shared_ptr<ServerState> serverState,
+ TransportType transportType,
int domain,
uint32_t trackMask,
std::shared_ptr<RTPSession> session)
: mRunLoop(runLoop),
mServerState(serverState),
+ mTransportType(transportType),
mTrackMask(trackMask),
mSession(session),
mSendPending(false),
- mDTLSConnected(false) {
- int sock = socket(domain, SOCK_DGRAM, 0);
+ mDTLSConnected(false),
+ mInBufferLength(0) {
+ bool tcp = mTransportType == TransportType::TCP;
+
+ int sock = socket(domain, tcp ? SOCK_STREAM : SOCK_DGRAM, 0);
+
+ if (tcp) {
+ static constexpr int yes = 1;
+ auto res = setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &yes, sizeof(yes));
+ CHECK(!res);
+ }
makeFdNonblocking(sock);
- mSocket = std::make_shared<PlainSocket>(mRunLoop, sock);
- mLocalPort = acquirePort(sock, domain);
+ mLocalPort = acquirePort(sock, domain, tcp);
CHECK(mLocalPort > 0);
+ if (tcp) {
+ auto res = listen(sock, 4);
+ CHECK(!res);
+ }
+
+ auto tmp = std::make_shared<PlainSocket>(mRunLoop, sock);
+ if (tcp) {
+ mServerSocket = tmp;
+ } else {
+ mSocket = tmp;
+ }
+
auto videoPacketizer =
(trackMask & TRACK_VIDEO)
? mServerState->getVideoPacketizer() : nullptr;
@@ -151,7 +178,76 @@
}
void RTPSocketHandler::run() {
- mSocket->postRecv(makeSafeCallback(this, &RTPSocketHandler::onReceive));
+ if (mTransportType == TransportType::TCP) {
+ mServerSocket->postRecv(
+ makeSafeCallback(this, &RTPSocketHandler::onTCPConnect));
+ } else {
+ mSocket->postRecv(makeSafeCallback(this, &RTPSocketHandler::onReceive));
+ }
+}
+
+void RTPSocketHandler::onTCPConnect() {
+ int sock = accept(mServerSocket->fd(), nullptr, 0);
+
+ if (sock < 0) {
+ LOG(ERROR) << "RTPSocketHandler: Failed to accept client";
+ mSocket->postRecv(makeSafeCallback(this, &RTPSocketHandler::onTCPConnect));
+ return;
+ }
+
+ LOG(INFO) << "RTPSocketHandler: Accepted client";
+
+ makeFdNonblocking(sock);
+
+ mClientAddrLen = sizeof(mClientAddr);
+
+ int res = getpeername(
+ sock, reinterpret_cast<sockaddr *>(&mClientAddr), &mClientAddrLen);
+
+ CHECK(!res);
+
+ mSocket = std::make_shared<PlainSocket>(mRunLoop, sock);
+
+ mSocket->postRecv(makeSafeCallback(this, &RTPSocketHandler::onTCPReceive));
+}
+
+void RTPSocketHandler::onTCPReceive() {
+ mInBuffer.resize(mInBuffer.size() + 8192);
+
+ auto n = mSocket->recv(
+ mInBuffer.data() + mInBufferLength, mInBuffer.size() - mInBufferLength);
+
+ if (n == 0) {
+ LOG(INFO) << "Client disconnected.";
+ return;
+ }
+
+ mInBufferLength += n;
+
+ size_t offset = 0;
+ while (offset + 1 < mInBufferLength) {
+ auto packetLength = U16_AT(mInBuffer.data() + offset);
+ offset += 2;
+
+ if (offset + packetLength > mInBufferLength) {
+ break;
+ }
+
+ onPacketReceived(
+ mClientAddr,
+ mClientAddrLen,
+ mInBuffer.data() + offset,
+ packetLength);
+
+ offset += packetLength;
+ }
+
+ if (offset > 0) {
+ mInBuffer.erase(mInBuffer.begin(), mInBuffer.begin() + offset);
+ mInBufferLength -= offset;
+ }
+
+ mSocket->postRecv(makeSafeCallback(this, &RTPSocketHandler::onTCPReceive));
}
void RTPSocketHandler::onReceive() {
@@ -165,6 +261,22 @@
auto n = mSocket->recvfrom(
data, buffer.size(), reinterpret_cast<sockaddr *>(&addr), &addrLen);
+ onPacketReceived(addr, addrLen, data, n);
+
+ mSocket->postRecv(makeSafeCallback(this, &RTPSocketHandler::onReceive));
+}
+
+void RTPSocketHandler::onPacketReceived(
+ const sockaddr_storage &addr,
+ socklen_t addrLen,
+ uint8_t *data,
+ size_t n) {
+#if 0
+ std::cout << "========================================" << std::endl;
+
+ hexdump(data, n);
+#endif
+
STUNMessage msg(data, n);
if (!msg.isValid()) {
if (mDTLSConnected) {
@@ -203,7 +315,6 @@
onDTLSReceive(data, static_cast<size_t>(n));
}
- run();
return;
}
@@ -212,7 +323,6 @@
if (!matchesSession(msg)) {
LOG(WARNING) << "Unknown session or no USERNAME.";
- run();
return;
}
@@ -299,15 +409,7 @@
// response.dump(answerPassword);
- auto res =
- mSocket->sendto(
- response.data(),
- response.size(),
- reinterpret_cast<const sockaddr *>(&addr),
- addrLen);
-
- CHECK_GT(res, 0);
- CHECK_EQ(static_cast<size_t>(res), response.size());
+ queueDatagram(addr, response.data(), response.size());
if (!mSession->isActive()) {
mSession->setRemoteAddress(addr);
@@ -336,8 +438,6 @@
mDTLS->connect(mSession->remoteAddress());
}
}
-
- run();
}
bool RTPSocketHandler::matchesSession(const STUNMessage &msg) const {
@@ -434,6 +534,22 @@
void RTPSocketHandler::queueDatagram(
const sockaddr_storage &addr, const void *data, size_t size) {
+ if (mTransportType == TransportType::TCP) {
+ std::vector copy(
+ static_cast<const uint8_t *>(data),
+ static_cast<const uint8_t *>(data) + size);
+
+ mRunLoop->post(
+ makeSafeCallback<RTPSocketHandler>(
+ this,
+ [copy](RTPSocketHandler *me) {
+ // addr is ignored and assumed to be the connected endpoint's.
+ me->queueTCPOutputPacket(copy.data(), copy.size());
+ }));
+
+ return;
+ }
+
auto datagram = std::make_shared<Datagram>(addr, data, size);
CHECK_LE(size, RTPSocketHandler::kMaxUDPPayloadSize);
@@ -450,6 +566,58 @@
}));
}
+void RTPSocketHandler::queueTCPOutputPacket(const uint8_t *data, size_t size) {
+ uint8_t framing[2];
+ framing[0] = size >> 8;
+ framing[1] = size & 0xff;
+
+ std::copy(framing, framing + sizeof(framing), std::back_inserter(mOutBuffer));
+ std::copy(data, data + size, std::back_inserter(mOutBuffer));
+
+ if (!mSendPending) {
+ mSendPending = true;
+
+ mSocket->postSend(
+ makeSafeCallback(this, &RTPSocketHandler::sendTCPOutputData));
+ }
+}
+
+void RTPSocketHandler::sendTCPOutputData() {
+ mSendPending = false;
+
+ const size_t size = mOutBuffer.size();
+ size_t offset = 0;
+
+ bool disconnected = false;
+
+ while (offset < size) {
+ auto n = mSocket->send(mOutBuffer.data() + offset, size - offset);
+
+ if (n < 0) {
+ if (errno == EINTR) {
+ continue;
+ }
+
+ LOG(FATAL) << "Should not be here.";
+ } else if (n == 0) {
+ offset = size;
+ disconnected = true;
+ break;
+ }
+
+ offset += static_cast<size_t>(n);
+ }
+
+ mOutBuffer.erase(mOutBuffer.begin(), mOutBuffer.begin() + offset);
+
+ if (!mOutBuffer.empty() && !disconnected) {
+ mSendPending = true;
+
+ mSocket->postSend(
+ makeSafeCallback(this, &RTPSocketHandler::sendTCPOutputData));
+ }
+}
+
void RTPSocketHandler::scheduleDrainOutQueue() {
CHECK(!mSendPending);
diff --git a/host/frontend/gcastv2/webrtc/include/webrtc/MyWebSocketHandler.h b/host/frontend/gcastv2/webrtc/include/webrtc/MyWebSocketHandler.h
index fef87cb..f22de55 100644
--- a/host/frontend/gcastv2/webrtc/include/webrtc/MyWebSocketHandler.h
+++ b/host/frontend/gcastv2/webrtc/include/webrtc/MyWebSocketHandler.h
@@ -52,6 +52,7 @@
bundleTracks = 2,
enableData = 4,
useSingleCertificateForAllTracks = 8,
+ useTCP = 16,
};
using TouchSink = android::TouchSink;
diff --git a/host/frontend/gcastv2/webrtc/include/webrtc/RTPSocketHandler.h b/host/frontend/gcastv2/webrtc/include/webrtc/RTPSocketHandler.h
index b0f8ff8..f1dc4a2 100644
--- a/host/frontend/gcastv2/webrtc/include/webrtc/RTPSocketHandler.h
+++ b/host/frontend/gcastv2/webrtc/include/webrtc/RTPSocketHandler.h
@@ -16,7 +16,7 @@
#pragma once
-#include <https/BufferedSocket.h>
+#include <https/PlainSocket.h>
#include <https/RunLoop.h>
#include <webrtc/DTLS.h>
#include <webrtc/RTPSender.h>
@@ -40,9 +40,15 @@
static constexpr uint32_t TRACK_AUDIO = 2;
static constexpr uint32_t TRACK_DATA = 4;
+ enum class TransportType {
+ UDP,
+ TCP,
+ };
+
explicit RTPSocketHandler(
std::shared_ptr<RunLoop> runLoop,
std::shared_ptr<ServerState> serverState,
+ TransportType type,
int domain,
uint32_t trackMask,
std::shared_ptr<RTPSession> session);
@@ -78,6 +84,7 @@
std::shared_ptr<RunLoop> mRunLoop;
std::shared_ptr<ServerState> mServerState;
+ TransportType mTransportType;
uint16_t mLocalPort;
uint32_t mTrackMask;
std::shared_ptr<RTPSession> mSession;
@@ -92,6 +99,16 @@
std::shared_ptr<RTPSender> mRTPSender;
+ // for TransportType TCP:
+ std::shared_ptr<PlainSocket> mServerSocket;
+ sockaddr_storage mClientAddr;
+ socklen_t mClientAddrLen;
+
+ std::vector<uint8_t> mInBuffer;
+ size_t mInBufferLength;
+
+ std::vector<uint8_t> mOutBuffer;
+
void onReceive();
void onDTLSReceive(const uint8_t *data, size_t size);
@@ -103,6 +120,18 @@
void drainOutQueue();
int onSRTPReceive(uint8_t *data, size_t size);
+
+ void onTCPConnect();
+ void onTCPReceive();
+
+ void onPacketReceived(
+ const sockaddr_storage &addr,
+ socklen_t addrLen,
+ uint8_t *data,
+ size_t size);
+
+ void queueTCPOutputPacket(const uint8_t *data, size_t size);
+ void sendTCPOutputData();
};