[webrtc] Always deliver the last frame to new clients.

Bug: 142904735
Test: locally
Change-Id: Ie03728faa10f9f0dddca6049a392b001fc521644
diff --git a/host/frontend/gcastv2/libsource/FrameBufferSource.cpp b/host/frontend/gcastv2/libsource/FrameBufferSource.cpp
index f473c19..9b742cc 100644
--- a/host/frontend/gcastv2/libsource/FrameBufferSource.cpp
+++ b/host/frontend/gcastv2/libsource/FrameBufferSource.cpp
@@ -46,7 +46,9 @@
 
     virtual void forceIDRFrame() = 0;
     virtual bool isForcingIDRFrame() const = 0;
-    virtual std::shared_ptr<SBuffer> encode(const void *frame, int64_t timeUs) = 0;
+
+    virtual void storeFrame(const void* frame) = 0;
+    virtual std::shared_ptr<SBuffer> encodeStoredFrame(int64_t timeUs) = 0;
 };
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -58,7 +60,8 @@
     void forceIDRFrame() override;
     bool isForcingIDRFrame() const override;
 
-    std::shared_ptr<SBuffer> encode(const void *frame, int64_t timeUs) override;
+    void storeFrame(const void* frame) override;
+    std::shared_ptr<SBuffer> encodeStoredFrame(int64_t timeUs) override;
 
 private:
     int mWidth, mHeight, mRefreshRateHz;
@@ -74,6 +77,7 @@
 
     std::atomic<bool> mForceIDRFrame;
     bool mFirstFrame;
+    bool mStoredFrame;
     int64_t mLastTimeUs;
 };
 
@@ -98,6 +102,7 @@
       mCodecContext(nullptr, vpx_codec_destroy),
       mForceIDRFrame(false),
       mFirstFrame(true),
+      mStoredFrame(false),
       mLastTimeUs(0) {
 
     CHECK((width & 1) == 0);
@@ -155,8 +160,7 @@
     return mForceIDRFrame;
 }
 
-std::shared_ptr<SBuffer> FrameBufferSource::VPXEncoder::encode(
-        const void *frame, int64_t timeUs) {
+void FrameBufferSource::VPXEncoder::storeFrame(const void *frame) {
     uint8_t *yPlane = static_cast<uint8_t *>(mI420Data);
     uint8_t *uPlane = yPlane + mSizeY;
     uint8_t *vPlane = uPlane + mSizeUV;
@@ -172,15 +176,18 @@
             mWidth / 2,
             mWidth,
             mHeight);
+    mStoredFrame = true;
+}
 
+std::shared_ptr<SBuffer> FrameBufferSource::VPXEncoder::encodeStoredFrame(
+        int64_t timeUs) {
+    if (!mStoredFrame) {
+        return nullptr;
+    }
     vpx_image_t raw_frame;
-    vpx_img_wrap(
-            &raw_frame,
-            VPX_IMG_FMT_I420,
-            mWidth,
-            mHeight,
-            2 /* stride_align */,
-            yPlane);
+    vpx_img_wrap(&raw_frame, VPX_IMG_FMT_I420, mWidth, mHeight,
+                 2 /* stride_align */,
+                 reinterpret_cast<unsigned char *>(mI420Data));
 
     vpx_enc_frame_flags_t flags = 0;
 
@@ -217,7 +224,8 @@
 
     std::shared_ptr<SBuffer> accessUnit;
 
-    while ((packet = vpx_codec_get_cx_data(mCodecContext.get(), &iter)) != nullptr) {
+    while ((packet = vpx_codec_get_cx_data(mCodecContext.get(), &iter)) !=
+            nullptr) {
         if (packet->kind == VPX_CODEC_CX_FRAME_PKT) {
             LOG(VERBOSE)
                 << "vpx_codec_encode returned packet of size "
@@ -258,6 +266,7 @@
       mScreenHeight(0),
       mScreenDpi(0),
       mScreenRate(0),
+      mNumConsumers(0),
       mOnFrameFn(nullptr) {
     mInitCheck = 0;
 }
@@ -369,13 +378,38 @@
     (void)size;
 
     std::lock_guard<std::mutex> autoLock(mLock);
-    if (/* noone is listening || */ mState != State::RUNNING) {
+    if (mState != State::RUNNING) {
         return;
     }
 
-    std::shared_ptr<SBuffer> accessUnit = mEncoder->encode(data, GetNowUs());
+    mEncoder->storeFrame(data);
+    // Only encode and forward the frame when there are consumers connected
+    if (mNumConsumers) {
+        auto accessUnit = mEncoder->encodeStoredFrame(GetNowUs());
+        StreamingSource::onAccessUnit(accessUnit);
+    }
+}
+
+void FrameBufferSource::notifyNewStreamConsumer() {
+    std::lock_guard<std::mutex> autoLock(mLock);
+    ++mNumConsumers;
+    if (mState != State::RUNNING) {
+        return;
+    }
+
+    mEncoder->forceIDRFrame();
+    auto accessUnit = mEncoder->encodeStoredFrame(GetNowUs());
+    if (!accessUnit) {
+        // nullptr means there isn't a stored frame to encode.
+        return;
+    }
 
     StreamingSource::onAccessUnit(accessUnit);
 }
 
+void FrameBufferSource::notifyStreamConsumerDisconnected() {
+    std::lock_guard<std::mutex> autoLock(mLock);
+    --mNumConsumers;
+}
+
 }  // namespace android
diff --git a/host/frontend/gcastv2/libsource/include/source/AudioSource.h b/host/frontend/gcastv2/libsource/include/source/AudioSource.h
index 0db8ac3..7c7771e 100644
--- a/host/frontend/gcastv2/libsource/include/source/AudioSource.h
+++ b/host/frontend/gcastv2/libsource/include/source/AudioSource.h
@@ -54,6 +54,8 @@
     int32_t stop() override;
 
     int32_t requestIDRFrame() override;
+    void notifyNewStreamConsumer() override {}
+    void notifyStreamConsumerDisconnected() override {}
 
     void inject(const void *data, size_t size);
 
diff --git a/host/frontend/gcastv2/libsource/include/source/FrameBufferSource.h b/host/frontend/gcastv2/libsource/include/source/FrameBufferSource.h
index f929b14..7990be4 100644
--- a/host/frontend/gcastv2/libsource/include/source/FrameBufferSource.h
+++ b/host/frontend/gcastv2/libsource/include/source/FrameBufferSource.h
@@ -47,6 +47,8 @@
     bool paused() const override;
 
     int32_t requestIDRFrame() override;
+    void notifyNewStreamConsumer() override;
+    void notifyStreamConsumerDisconnected() override;
 
     void setScreenParams(const int32_t screenParams[4]);
     void injectFrame(const void *data, size_t size);
@@ -69,7 +71,7 @@
 
     std::mutex mLock;
 
-    int32_t mScreenWidth, mScreenHeight, mScreenDpi, mScreenRate;
+    int32_t mScreenWidth, mScreenHeight, mScreenDpi, mScreenRate, mNumConsumers;
 
     std::function<void(const std::shared_ptr<SBuffer> &)> mOnFrameFn;
 };
diff --git a/host/frontend/gcastv2/libsource/include/source/StreamingSource.h b/host/frontend/gcastv2/libsource/include/source/StreamingSource.h
index 5e33ea2..27a30a4 100644
--- a/host/frontend/gcastv2/libsource/include/source/StreamingSource.h
+++ b/host/frontend/gcastv2/libsource/include/source/StreamingSource.h
@@ -90,6 +90,8 @@
     virtual bool paused() const { return false; }
 
     virtual int32_t requestIDRFrame() = 0;
+    virtual void notifyNewStreamConsumer() = 0;
+    virtual void notifyStreamConsumerDisconnected() = 0;
 
 protected:
     void onAccessUnit(const std::shared_ptr<SBuffer> &accessUnit);
diff --git a/host/frontend/gcastv2/webrtc/Packetizer.cpp b/host/frontend/gcastv2/webrtc/Packetizer.cpp
index 92f25dc..bbe6102 100644
--- a/host/frontend/gcastv2/webrtc/Packetizer.cpp
+++ b/host/frontend/gcastv2/webrtc/Packetizer.cpp
@@ -38,6 +38,7 @@
         auto sender = it->lock();
         if (!sender) {
             it = mSenders.erase(it);
+            mStreamingSource->notifyStreamConsumerDisconnected();
             continue;
         }
 
@@ -48,6 +49,12 @@
 
 void Packetizer::addSender(std::shared_ptr<RTPSender> sender) {
     mSenders.push_back(sender);
+    auto weak_source = std::weak_ptr<StreamingSource>(mStreamingSource);
+    mRunLoop->post([weak_source](){
+        auto source = weak_source.lock();
+        if (!source) return;
+        source->notifyNewStreamConsumer();
+    });
 }
 
 int32_t Packetizer::requestIDRFrame() {
@@ -71,6 +78,10 @@
 }
 
 void Packetizer::onFrame(const std::shared_ptr<android::SBuffer>& accessUnit) {
+    if (!accessUnit) {
+        LOG(WARNING) << "Received invalid buffer in " << __FUNCTION__;
+        return;
+    }
     int64_t timeUs = accessUnit->time_us();
     CHECK(timeUs);
 
diff --git a/host/frontend/gcastv2/webrtc/RTPSocketHandler.cpp b/host/frontend/gcastv2/webrtc/RTPSocketHandler.cpp
index 51f052f..ae18a2f 100644
--- a/host/frontend/gcastv2/webrtc/RTPSocketHandler.cpp
+++ b/host/frontend/gcastv2/webrtc/RTPSocketHandler.cpp
@@ -131,14 +131,10 @@
         mRTPSender->addSource(0xcafeb0b0);
 
         mRTPSender->addRetransInfo(0xdeadbeef, 96, 0xcafeb0b0, 97);
-
-        videoPacketizer->addSender(mRTPSender);
     }
 
     if (trackMask & TRACK_AUDIO) {
         mRTPSender->addSource(0x8badf00d);
-
-        audioPacketizer->addSender(mRTPSender);
     }
 }
 
@@ -510,6 +506,14 @@
 
     mDTLSConnected = true;
 
+    if (mTrackMask & TRACK_VIDEO) {
+        mServerState->getVideoPacketizer()->addSender(mRTPSender);
+    }
+
+    if (mTrackMask & TRACK_AUDIO) {
+        mServerState->getAudioPacketizer()->addSender(mRTPSender);
+    }
+
     if (mTrackMask & TRACK_DATA) {
         mSCTPHandler = std::make_shared<SCTPHandler>(mRunLoop, mDTLS);
         mSCTPHandler->run();