Merge "L2CAP ERTM: Drop invalid packet on reassembly" am: 16210d5a08
am: 7eac0c8aa3

Change-Id: Ia3a85b7d943eca9b155ebc5f795d1e086c625ad5
diff --git a/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.cc b/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.cc
index 2a1b930..97f9702 100644
--- a/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.cc
+++ b/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.cc
@@ -168,20 +168,20 @@
     }
   }
 
-  void recv_i_frame(Final f, uint8_t tx_seq, uint8_t req_seq, SegmentationAndReassembly sar,
+  void recv_i_frame(Final f, uint8_t tx_seq, uint8_t req_seq, SegmentationAndReassembly sar, uint16_t sdu_size,
                     const packet::PacketView<true>& payload) {
     if (rx_state_ == RxState::RECV) {
       if (f == Final::NOT_SET && with_expected_tx_seq(tx_seq) && with_valid_req_seq(req_seq) && with_valid_f_bit(f) &&
           !local_busy()) {
         increment_expected_tx_seq();
         pass_to_tx(req_seq, f);
-        data_indication(sar, payload);
+        data_indication(sar, sdu_size, payload);
         send_ack(Final::NOT_SET);
       } else if (f == Final::POLL_RESPONSE && with_expected_tx_seq(tx_seq) && with_valid_req_seq(req_seq) &&
                  with_valid_f_bit(f) && !local_busy()) {
         increment_expected_tx_seq();
         pass_to_tx(req_seq, f);
-        data_indication(sar, payload);
+        data_indication(sar, sdu_size, payload);
         if (!rej_actioned_) {
           retransmit_i_frames(req_seq);
           send_pending_i_frames();
@@ -216,14 +216,14 @@
       if (f == Final::NOT_SET && with_expected_tx_seq(tx_seq) && with_valid_req_seq(req_seq) && with_valid_f_bit(f)) {
         increment_expected_tx_seq();
         pass_to_tx(req_seq, f);
-        data_indication(sar, payload);
+        data_indication(sar, sdu_size, payload);
         send_ack(Final::NOT_SET);
         rx_state_ = RxState::RECV;
       } else if (f == Final::POLL_RESPONSE && with_expected_tx_seq(tx_seq) && with_valid_req_seq(req_seq) &&
                  with_valid_f_bit(f)) {
         increment_expected_tx_seq();
         pass_to_tx(req_seq, f);
-        data_indication(sar, payload);
+        data_indication(sar, sdu_size, payload);
         if (!rej_actioned_) {
           retransmit_i_frames(req_seq);
           send_pending_i_frames();
@@ -682,8 +682,8 @@
     recv_f_bit(f);
   }
 
-  void data_indication(SegmentationAndReassembly sar, const packet::PacketView<true>& segment) {
-    controller_->stage_for_reassembly(sar, segment);
+  void data_indication(SegmentationAndReassembly sar, uint16_t sdu_size, const packet::PacketView<true>& segment) {
+    controller_->stage_for_reassembly(sar, sdu_size, segment);  // callers pass sdu_size 0 unless sar is START
     buffer_seq_ = (buffer_seq_ + 1) % kMaxTxWin;
   }
 
@@ -825,8 +825,21 @@
       LOG_WARN("Received invalid frame");
       return;
     }
-    pimpl_->recv_i_frame(i_frame_view.GetF(), i_frame_view.GetTxSeq(), i_frame_view.GetReqSeq(), i_frame_view.GetSar(),
-                         i_frame_view.GetPayload());
+    Final f = i_frame_view.GetF();
+    uint8_t tx_seq = i_frame_view.GetTxSeq();
+    uint8_t req_seq = i_frame_view.GetReqSeq();
+    auto sar = i_frame_view.GetSar();
+    if (sar == SegmentationAndReassembly::START) {
+      auto i_frame_start_view = EnhancedInformationStartFrameView::Create(i_frame_view);
+      if (!i_frame_start_view.IsValid()) {
+        LOG_WARN("Received invalid I-Frame START");
+        return;
+      }
+      pimpl_->recv_i_frame(f, tx_seq, req_seq, sar, i_frame_start_view.GetL2capSduLength(),
+                           i_frame_start_view.GetPayload());
+    } else {
+      pimpl_->recv_i_frame(f, tx_seq, req_seq, sar, 0, i_frame_view.GetPayload());
+    }
   } else if (type == FrameType::S_FRAME) {
     auto s_frame_view = EnhancedSupervisoryFrameView::Create(standard_frame_view);
     if (!s_frame_view.IsValid()) {
@@ -872,8 +885,21 @@
       LOG_WARN("Received invalid frame");
       return;
     }
-    pimpl_->recv_i_frame(i_frame_view.GetF(), i_frame_view.GetTxSeq(), i_frame_view.GetReqSeq(), i_frame_view.GetSar(),
-                         i_frame_view.GetPayload());
+    Final f = i_frame_view.GetF();
+    uint8_t tx_seq = i_frame_view.GetTxSeq();
+    uint8_t req_seq = i_frame_view.GetReqSeq();
+    auto sar = i_frame_view.GetSar();
+    if (sar == SegmentationAndReassembly::START) {
+      auto i_frame_start_view = EnhancedInformationStartFrameWithFcsView::Create(i_frame_view);
+      if (!i_frame_start_view.IsValid()) {
+        LOG_WARN("Received invalid I-Frame START");
+        return;
+      }
+      pimpl_->recv_i_frame(f, tx_seq, req_seq, sar, i_frame_start_view.GetL2capSduLength(),
+                           i_frame_start_view.GetPayload());
+    } else {
+      pimpl_->recv_i_frame(f, tx_seq, req_seq, sar, 0, i_frame_view.GetPayload());
+    }
   } else if (type == FrameType::S_FRAME) {
     auto s_frame_view = EnhancedSupervisoryFrameWithFcsView::Create(standard_frame_view);
     if (!s_frame_view.IsValid()) {
@@ -908,10 +934,16 @@
   return next;
 }
 
-void ErtmController::stage_for_reassembly(SegmentationAndReassembly sar,
+void ErtmController::stage_for_reassembly(SegmentationAndReassembly sar, uint16_t sdu_size,
                                           const packet::PacketView<kLittleEndian>& payload) {
   switch (sar) {
     case SegmentationAndReassembly::UNSEGMENTED:
+      if (sar_state_ != SegmentationAndReassembly::END) {
+        LOG_WARN("Received invalid SAR");
+        close_channel();
+        return;
+      }
+      // TODO: Enforce MTU
       enqueue_buffer_.Enqueue(std::make_unique<packet::PacketView<kLittleEndian>>(payload), handler_);
       break;
     case SegmentationAndReassembly::START:
@@ -920,8 +952,10 @@
         close_channel();
         return;
       }
+      // TODO: Enforce MTU
       sar_state_ = SegmentationAndReassembly::START;
       reassembly_stage_ = payload;
+      remaining_sdu_continuation_packet_size_ = sdu_size - payload.size();
       break;
     case SegmentationAndReassembly::CONTINUATION:
       if (sar_state_ == SegmentationAndReassembly::END) {
@@ -930,6 +964,7 @@
         return;
       }
       reassembly_stage_.AppendPacketView(payload);
+      remaining_sdu_continuation_packet_size_ -= payload.size();
       break;
     case SegmentationAndReassembly::END:
       if (sar_state_ == SegmentationAndReassembly::END) {
@@ -937,9 +972,17 @@
         close_channel();
         return;
       }
+      sar_state_ = SegmentationAndReassembly::END;
+      remaining_sdu_continuation_packet_size_ -= payload.size();
+      if (remaining_sdu_continuation_packet_size_ != 0) {
+        LOG_WARN("Received invalid END I-Frame");
+        reassembly_stage_ = PacketViewForReassembly(std::make_shared<std::vector<uint8_t>>());
+        remaining_sdu_continuation_packet_size_ = 0;
+        close_channel();
+        return;
+      }
       reassembly_stage_.AppendPacketView(payload);
       enqueue_buffer_.Enqueue(std::make_unique<packet::PacketView<kLittleEndian>>(reassembly_stage_), handler_);
-      sar_state_ = SegmentationAndReassembly::END;
       break;
   }
 }
diff --git a/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.h b/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.h
index 14deea3..dd1ca64 100644
--- a/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.h
+++ b/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.h
@@ -59,7 +59,19 @@
   os::Handler* handler_;
   std::queue<std::unique_ptr<packet::BasePacketBuilder>> pdu_queue_;
   Scheduler* scheduler_;
+
+  // Configuration options
   bool fcs_enabled_ = false;
+  uint16_t local_tx_window_ = 10;
+  uint16_t local_max_transmit_ = 20;
+  uint16_t local_retransmit_timeout_ms_ = 2000;
+  uint16_t local_monitor_timeout_ms_ = 12000;
+
+  uint16_t remote_tx_window_ = 10;
+  uint16_t remote_mps_ = 1010;
+
+  uint16_t size_each_packet_ =
+      (remote_mps_ - 4 /* basic L2CAP header */ - 2 /* SDU length */ - 2 /* Extended control */ - 2 /* FCS */);
 
   class PacketViewForReassembly : public packet::PacketView<kLittleEndian> {
    public:
@@ -83,8 +95,10 @@
 
   PacketViewForReassembly reassembly_stage_{std::make_shared<std::vector<uint8_t>>()};
   SegmentationAndReassembly sar_state_ = SegmentationAndReassembly::END;
+  uint64_t remaining_sdu_continuation_packet_size_ = 0;  // SDU bytes left; wide so oversized data can't wrap it to 0
 
-  void stage_for_reassembly(SegmentationAndReassembly sar, const packet::PacketView<kLittleEndian>& payload);
+  void stage_for_reassembly(SegmentationAndReassembly sar, uint16_t sdu_size,
+                            const packet::PacketView<kLittleEndian>& payload);
   void send_pdu(std::unique_ptr<packet::BasePacketBuilder> pdu);
 
   void close_channel();
@@ -92,18 +106,6 @@
   void on_pdu_no_fcs(const packet::PacketView<true>& pdu);
   void on_pdu_fcs(const packet::PacketView<true>& pdu);
 
-  // Configuration options
-  uint16_t local_tx_window_ = 10;
-  uint16_t local_max_transmit_ = 20;
-  uint16_t local_retransmit_timeout_ms_ = 2000;
-  uint16_t local_monitor_timeout_ms_ = 12000;
-
-  uint16_t remote_tx_window_ = 10;
-  uint16_t remote_mps_ = 1010;
-
-  uint16_t size_each_packet_ =
-      (remote_mps_ - 4 /* basic L2CAP header */ - 2 /* SDU length */ - 2 /* Extended control */ - 2 /* FCS */);
-
   struct impl;
   std::unique_ptr<impl> pimpl_;
 };
diff --git a/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller_test.cc b/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller_test.cc
index 04ef82b..e1b5818 100644
--- a/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller_test.cc
+++ b/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller_test.cc
@@ -108,6 +108,54 @@
   EXPECT_EQ(data, "abcd");
 }
 
+// Three segments whose sizes (1 + 2 + 3) match the START frame's SDU length of 6 reassemble into one SDU.
+TEST_F(ErtmDataControllerTest, reassemble_valid_sdu) {
+  common::BidiQueue<Scheduler::UpperEnqueue, Scheduler::UpperDequeue> channel_queue{10};
+  testing::MockScheduler scheduler;
+  ErtmController controller{1, 1, channel_queue.GetDownEnd(), queue_handler_, &scheduler};
+  auto segment1 = CreateSdu({'a'});
+  auto segment2 = CreateSdu({'b', 'c'});
+  auto segment3 = CreateSdu({'d', 'e', 'f'});
+  auto builder1 = EnhancedInformationStartFrameBuilder::Create(1, 0, Final::NOT_SET, 0, 6, std::move(segment1));
+  auto base_view = GetPacketView(std::move(builder1));
+  controller.OnPdu(base_view);
+  auto builder2 = EnhancedInformationFrameBuilder::Create(1, 1, Final::NOT_SET, 0,
+                                                          SegmentationAndReassembly::CONTINUATION, std::move(segment2));
+  base_view = GetPacketView(std::move(builder2));
+  controller.OnPdu(base_view);
+  auto builder3 = EnhancedInformationFrameBuilder::Create(1, 2, Final::NOT_SET, 0, SegmentationAndReassembly::END,
+                                                          std::move(segment3));
+  base_view = GetPacketView(std::move(builder3));
+  controller.OnPdu(base_view);
+  sync_handler(queue_handler_);
+  auto payload = channel_queue.GetUpEnd()->TryDequeue();
+  ASSERT_NE(payload, nullptr);  // fatal on failure: payload is dereferenced on the next line
+  EXPECT_EQ(std::string(payload->begin(), payload->end()), "abcdef");
+}
+
+// A START frame announcing a 10-byte SDU whose segments carry only 6 bytes in total must deliver nothing upward.
+TEST_F(ErtmDataControllerTest, reassemble_invalid_sdu_size_in_start_frame) {
+  common::BidiQueue<Scheduler::UpperEnqueue, Scheduler::UpperDequeue> upper_queue{10};
+  testing::MockScheduler mock_scheduler;
+  ErtmController ertm{1, 1, upper_queue.GetDownEnd(), queue_handler_, &mock_scheduler};
+  auto head = CreateSdu({'a'});
+  auto middle = CreateSdu({'b', 'c'});
+  auto tail = CreateSdu({'d', 'e', 'f'});
+  auto start_frame = EnhancedInformationStartFrameBuilder::Create(1, 0, Final::NOT_SET, 0, 10, std::move(head));
+  auto start_view = GetPacketView(std::move(start_frame));
+  ertm.OnPdu(start_view);
+  auto continuation_frame = EnhancedInformationFrameBuilder::Create(
+      1, 1, Final::NOT_SET, 0, SegmentationAndReassembly::CONTINUATION, std::move(middle));
+  auto continuation_view = GetPacketView(std::move(continuation_frame));
+  ertm.OnPdu(continuation_view);
+  auto end_frame = EnhancedInformationFrameBuilder::Create(1, 2, Final::NOT_SET, 0,
+                                                           SegmentationAndReassembly::END, std::move(tail));
+  auto end_view = GetPacketView(std::move(end_frame));
+  ertm.OnPdu(end_view);
+  sync_handler(queue_handler_);
+  EXPECT_EQ(upper_queue.GetUpEnd()->TryDequeue(), nullptr);  // the mismatched SDU must not reach the upper layer
+}
+
 TEST_F(ErtmDataControllerTest, transmit_with_fcs) {
   common::BidiQueue<Scheduler::UpperEnqueue, Scheduler::UpperDequeue> channel_queue{10};
   testing::MockScheduler scheduler;