Handle Turn error response to RefreshRequest, CreatePermissionRequest, and ChanelBindRequest

BUG=webrtc:5116
R=pthatcher@webrtc.org

Review URL: https://codereview.webrtc.org/1453823004 .

Cr-Commit-Position: refs/heads/master@{#10994}
diff --git a/webrtc/p2p/base/turnport.cc b/webrtc/p2p/base/turnport.cc
index 4279b0f..55b3b94 100644
--- a/webrtc/p2p/base/turnport.cc
+++ b/webrtc/p2p/base/turnport.cc
@@ -38,6 +38,8 @@
 // STUN_ERROR_ALLOCATION_MISMATCH error per rfc5766.
 static const size_t MAX_ALLOCATE_MISMATCH_RETRIES = 2;
 
+static const int TURN_SUCCESS_RESULT_CODE = 0;
+
 inline bool IsTurnChannelData(uint16_t msg_type) {
   return ((msg_type & 0xC000) == 0x4000);  // MSB are 0b01
 }
@@ -137,6 +139,9 @@
   TurnPort* port() { return port_; }
 
   int channel_id() const { return channel_id_; }
+  // For testing only.
+  void set_channel_id(int channel_id) { channel_id_ = channel_id; }
+
   const rtc::SocketAddress& address() const { return ext_addr_; }
   BindState state() const { return state_; }
 
@@ -155,8 +160,10 @@
 
   void OnCreatePermissionSuccess();
   void OnCreatePermissionError(StunMessage* response, int code);
+  void OnCreatePermissionTimeout();
   void OnChannelBindSuccess();
   void OnChannelBindError(StunMessage* response, int code);
+  void OnChannelBindTimeout();
   // Signal sent when TurnEntry is destroyed.
   sigslot::signal1<TurnEntry*> SignalDestroyed;
 
@@ -464,6 +471,15 @@
   return NULL;
 }
 
+bool TurnPort::DestroyConnection(const rtc::SocketAddress& address) {
+  Connection* conn = GetConnection(address);
+  if (conn != nullptr) {
+    conn->Destroy();
+    return true;
+  }
+  return false;
+}
+
 int TurnPort::SetOption(rtc::Socket::Option opt, int value) {
   if (!socket_) {
     // If socket is not created yet, these options will be applied during socket
@@ -698,34 +714,45 @@
   // We will send SignalPortError asynchronously as this can be sent during
   // port initialization. This way it will not be blocking other port
   // creation.
-  thread()->Post(this, MSG_ERROR);
+  thread()->Post(this, MSG_ALLOCATE_ERROR);
+}
+
+void TurnPort::Close() {
+  // Stop the port from creating new connections.
+  state_ = STATE_DISCONNECTED;
+  // Delete all existing connections; stop sending data.
+  for (auto kv : connections()) {
+    kv.second->Destroy();
+  }
 }
 
 void TurnPort::OnMessage(rtc::Message* message) {
-  if (message->message_id == MSG_ERROR) {
-    SignalPortError(this);
-    return;
-  } else if (message->message_id == MSG_ALLOCATE_MISMATCH) {
-    OnAllocateMismatch();
-    return;
-  } else if (message->message_id == MSG_TRY_ALTERNATE_SERVER) {
-    if (server_address().proto == PROTO_UDP) {
-      // Send another allocate request to alternate server, with the received
-      // realm and nonce values.
-      SendRequest(new TurnAllocateRequest(this), 0);
-    } else {
-      // Since it's TCP, we have to delete the connected socket and reconnect
-      // with the alternate server. PrepareAddress will send stun binding once
-      // the new socket is connected.
-      ASSERT(server_address().proto == PROTO_TCP);
-      ASSERT(!SharedSocket());
-      delete socket_;
-      socket_ = NULL;
-      PrepareAddress();
-    }
-    return;
+  switch (message->message_id) {
+    case MSG_ALLOCATE_ERROR:
+      SignalPortError(this);
+      break;
+    case MSG_ALLOCATE_MISMATCH:
+      OnAllocateMismatch();
+      break;
+    case MSG_TRY_ALTERNATE_SERVER:
+      if (server_address().proto == PROTO_UDP) {
+        // Send another allocate request to alternate server, with the received
+        // realm and nonce values.
+        SendRequest(new TurnAllocateRequest(this), 0);
+      } else {
+        // Since it's TCP, we have to delete the connected socket and reconnect
+        // with the alternate server. PrepareAddress will send stun binding once
+        // the new socket is connected.
+        ASSERT(server_address().proto == PROTO_TCP);
+        ASSERT(!SharedSocket());
+        delete socket_;
+        socket_ = NULL;
+        PrepareAddress();
+      }
+      break;
+    default:
+      Port::OnMessage(message);
   }
-  Port::OnMessage(message);
 }
 
 void TurnPort::OnAllocateRequestTimeout() {
@@ -968,6 +995,16 @@
   entry->set_destruction_timestamp(0);
 }
 
+bool TurnPort::SetEntryChannelId(const rtc::SocketAddress& address,
+                                 int channel_id) {
+  TurnEntry* entry = FindEntry(address);
+  if (!entry) {
+    return false;
+  }
+  entry->set_channel_id(channel_id);
+  return true;
+}
+
 TurnAllocateRequest::TurnAllocateRequest(TurnPort* port)
     : StunRequest(new TurnMessage()),
       port_(port) {
@@ -1181,16 +1218,12 @@
 
   // Schedule a refresh based on the returned lifetime value.
   port_->ScheduleRefresh(lifetime_attr->value());
+  port_->SignalTurnRefreshResult(port_, TURN_SUCCESS_RESULT_CODE);
 }
 
 void TurnRefreshRequest::OnErrorResponse(StunMessage* response) {
   const StunErrorCodeAttribute* error_code = response->GetErrorCode();
 
-  LOG_J(LS_INFO, port_) << "Received TURN refresh error response"
-                        << ", id=" << rtc::hex_encode(id())
-                        << ", code=" << error_code->code()
-                        << ", rtt=" << Elapsed();
-
   if (error_code->code() == STUN_ERROR_STALE_NONCE) {
     if (port_->UpdateNonce(response)) {
       // Send RefreshRequest immediately.
@@ -1201,11 +1234,14 @@
                              << ", id=" << rtc::hex_encode(id())
                              << ", code=" << error_code->code()
                              << ", rtt=" << Elapsed();
+    port_->OnTurnRefreshError();
+    port_->SignalTurnRefreshResult(port_, error_code->code());
   }
 }
 
 void TurnRefreshRequest::OnTimeout() {
   LOG_J(LS_WARNING, port_) << "TURN refresh timeout " << rtc::hex_encode(id());
+  port_->OnTurnRefreshError();
 }
 
 TurnCreatePermissionRequest::TurnCreatePermissionRequest(
@@ -1258,6 +1294,9 @@
 void TurnCreatePermissionRequest::OnTimeout() {
   LOG_J(LS_WARNING, port_) << "TURN create permission timeout "
                            << rtc::hex_encode(id());
+  if (entry_) {
+    entry_->OnCreatePermissionTimeout();
+  }
 }
 
 void TurnCreatePermissionRequest::OnEntryDestroyed(TurnEntry* entry) {
@@ -1325,6 +1364,9 @@
 void TurnChannelBindRequest::OnTimeout() {
   LOG_J(LS_WARNING, port_) << "TURN channel bind timeout "
                            << rtc::hex_encode(id());
+  if (entry_) {
+    entry_->OnChannelBindTimeout();
+  }
 }
 
 void TurnChannelBindRequest::OnEntryDestroyed(TurnEntry* entry) {
@@ -1385,8 +1427,8 @@
   LOG_J(LS_INFO, port_) << "Create permission for "
                         << ext_addr_.ToSensitiveString()
                         << " succeeded";
-  // For success result code will be 0.
-  port_->SignalCreatePermissionResult(port_, ext_addr_, 0);
+  port_->SignalCreatePermissionResult(port_, ext_addr_,
+                                      TURN_SUCCESS_RESULT_CODE);
 
   // If |state_| is STATE_BOUND, the permission will be refreshed
   // by ChannelBindRequest.
@@ -1406,6 +1448,7 @@
       SendCreatePermissionRequest(0);
     }
   } else {
+    port_->DestroyConnection(ext_addr_);
     // Send signal with error code.
     port_->SignalCreatePermissionResult(port_, ext_addr_, code);
     Connection* c = port_->GetConnection(ext_addr_);
@@ -1417,6 +1460,10 @@
   }
 }
 
+void TurnEntry::OnCreatePermissionTimeout() {
+  port_->DestroyConnection(ext_addr_);
+}
+
 void TurnEntry::OnChannelBindSuccess() {
   LOG_J(LS_INFO, port_) << "Channel bind for " << ext_addr_.ToSensitiveString()
                         << " succeeded";
@@ -1425,14 +1472,21 @@
 }
 
 void TurnEntry::OnChannelBindError(StunMessage* response, int code) {
-  // TODO(mallinath) - Implement handling of error response for channel
-  // bind request as per http://tools.ietf.org/html/rfc5766#section-11.3
+  // If the channel bind fails due to errors other than STATE_NONCE,
+  // we just destroy the connection and rely on ICE restart to re-establish
+  // the connection.
   if (code == STUN_ERROR_STALE_NONCE) {
     if (port_->UpdateNonce(response)) {
       // Send channel bind request with fresh nonce.
       SendChannelBindRequest(0);
     }
+  } else {
+    state_ = STATE_UNBOUND;
+    port_->DestroyConnection(ext_addr_);
   }
 }
-
+void TurnEntry::OnChannelBindTimeout() {
+  state_ = STATE_UNBOUND;
+  port_->DestroyConnection(ext_addr_);
+}
 }  // namespace cricket
diff --git a/webrtc/p2p/base/turnport.h b/webrtc/p2p/base/turnport.h
index 62e3c41..7b88364 100644
--- a/webrtc/p2p/base/turnport.h
+++ b/webrtc/p2p/base/turnport.h
@@ -133,10 +133,17 @@
                    const rtc::SocketAddress&,
                    const rtc::SocketAddress&> SignalResolvedServerAddress;
 
+  // All public methods/signals below are for testing only.
+  sigslot::signal2<TurnPort*, int> SignalTurnRefreshResult;
   sigslot::signal3<TurnPort*, const rtc::SocketAddress&, int>
       SignalCreatePermissionResult;
-  // For testing only.
   void FlushRequests() { request_manager_.Flush(); }
+  void set_credentials(RelayCredentials& credentials) {
+    credentials_ = credentials;
+  }
+  // Finds the turn entry with |address| and sets its channel id.
+  // Returns true if the entry is found.
+  bool SetEntryChannelId(const rtc::SocketAddress& address, int channel_id);
 
  protected:
   TurnPort(rtc::Thread* thread,
@@ -165,7 +172,7 @@
 
  private:
   enum {
-    MSG_ERROR = MSG_FIRST_AVAILABLE,
+    MSG_ALLOCATE_ERROR = MSG_FIRST_AVAILABLE,
     MSG_ALLOCATE_MISMATCH,
     MSG_TRY_ALTERNATE_SERVER
   };
@@ -186,6 +193,9 @@
     }
   }
 
+  // Shuts down the turn port, usually because of some fatal errors.
+  void Close();
+  void OnTurnRefreshError() { Close(); }
   bool SetAlternateServer(const rtc::SocketAddress& address);
   void ResolveTurnAddress(const rtc::SocketAddress& address);
   void OnResolveResult(rtc::AsyncResolverInterface* resolver);
@@ -228,6 +238,10 @@
   void CancelEntryDestruction(TurnEntry* entry);
   void OnConnectionDestroyed(Connection* conn);
 
+  // Destroys the connection with remote address |address|. Returns true if
+  // a connection is found and destroyed.
+  bool DestroyConnection(const rtc::SocketAddress& address);
+
   ProtocolAddress server_address_;
   RelayCredentials credentials_;
   AttemptedServerSet attempted_server_addresses_;
diff --git a/webrtc/p2p/base/turnport_unittest.cc b/webrtc/p2p/base/turnport_unittest.cc
index d9353cc..5720748 100644
--- a/webrtc/p2p/base/turnport_unittest.cc
+++ b/webrtc/p2p/base/turnport_unittest.cc
@@ -13,6 +13,7 @@
 
 #include "webrtc/p2p/base/basicpacketsocketfactory.h"
 #include "webrtc/p2p/base/constants.h"
+#include "webrtc/p2p/base/portallocator.h"
 #include "webrtc/p2p/base/tcpport.h"
 #include "webrtc/p2p/base/testturnserver.h"
 #include "webrtc/p2p/base/turnport.h"
@@ -172,12 +173,15 @@
                             bool /*port_muxed*/) {
     turn_unknown_address_ = true;
   }
-  void OnTurnCreatePermissionResult(TurnPort* port, const SocketAddress& addr,
-                                     int code) {
+  void OnTurnCreatePermissionResult(TurnPort* port,
+                                    const SocketAddress& addr,
+                                    int code) {
     // Ignoring the address.
-    if (code == 0) {
-      turn_create_permission_success_ = true;
-    }
+    turn_create_permission_success_ = (code == 0);
+  }
+
+  void OnTurnRefreshResult(TurnPort* port, int code) {
+    turn_refresh_success_ = (code == 0);
   }
   void OnTurnReadPacket(Connection* conn, const char* data, size_t size,
                         const rtc::PacketTime& packet_time) {
@@ -190,6 +194,7 @@
                        const rtc::PacketTime& packet_time) {
     udp_packets_.push_back(rtc::Buffer(data, size));
   }
+  void OnConnectionDestroyed(Connection* conn) { connection_destroyed_ = true; }
   void OnSocketReadPacket(rtc::AsyncPacketSocket* socket,
                           const char* data, size_t size,
                           const rtc::SocketAddress& remote_addr,
@@ -273,6 +278,11 @@
         &TurnPortTest::OnTurnUnknownAddress);
     turn_port_->SignalCreatePermissionResult.connect(this,
         &TurnPortTest::OnTurnCreatePermissionResult);
+    turn_port_->SignalTurnRefreshResult.connect(
+        this, &TurnPortTest::OnTurnRefreshResult);
+  }
+  void ConnectConnectionDestroyedSignal(Connection* conn) {
+    conn->SignalDestroyed.connect(this, &TurnPortTest::OnConnectionDestroyed);
   }
 
   void CreateUdpPort() { CreateUdpPort(kLocalAddr2); }
@@ -287,6 +297,23 @@
         this, &TurnPortTest::OnUdpPortComplete);
   }
 
+  void PrepareTurnAndUdpPorts() {
+    // turn_port_ should have been created.
+    ASSERT_TRUE(turn_port_ != nullptr);
+    turn_port_->PrepareAddress();
+    ASSERT_TRUE_WAIT(turn_ready_, kTimeout);
+
+    CreateUdpPort();
+    udp_port_->PrepareAddress();
+    ASSERT_TRUE_WAIT(udp_ready_, kTimeout);
+  }
+
+  bool CheckConnectionDestroyed() {
+    turn_port_->FlushRequests();
+    rtc::Thread::Current()->ProcessMessages(50);
+    return connection_destroyed_;
+  }
+
   void TestTurnAlternateServer(cricket::ProtocolType protocol_type) {
     std::vector<rtc::SocketAddress> redirect_addresses;
     redirect_addresses.push_back(kTurnAlternateIntAddr);
@@ -370,12 +397,7 @@
 
   void TestTurnConnection() {
     // Create ports and prepare addresses.
-    ASSERT_TRUE(turn_port_ != NULL);
-    turn_port_->PrepareAddress();
-    ASSERT_TRUE_WAIT(turn_ready_, kTimeout);
-    CreateUdpPort();
-    udp_port_->PrepareAddress();
-    ASSERT_TRUE_WAIT(udp_ready_, kTimeout);
+    PrepareTurnAndUdpPorts();
 
     // Send ping from UDP to TURN.
     Connection* conn1 = udp_port_->CreateConnection(
@@ -406,12 +428,7 @@
   }
 
   void TestDestroyTurnConnection() {
-    turn_port_->PrepareAddress();
-    ASSERT_TRUE_WAIT(turn_ready_, kTimeout);
-    // Create a remote UDP port
-    CreateUdpPort();
-    udp_port_->PrepareAddress();
-    ASSERT_TRUE_WAIT(udp_ready_, kTimeout);
+    PrepareTurnAndUdpPorts();
 
     // Create connections on both ends.
     Connection* conn1 = udp_port_->CreateConnection(turn_port_->Candidates()[0],
@@ -448,11 +465,8 @@
   }
 
   void TestTurnSendData() {
-    turn_port_->PrepareAddress();
-    EXPECT_TRUE_WAIT(turn_ready_, kTimeout);
-    CreateUdpPort();
-    udp_port_->PrepareAddress();
-    EXPECT_TRUE_WAIT(udp_ready_, kTimeout);
+    PrepareTurnAndUdpPorts();
+
     // Create connections and send pings.
     Connection* conn1 = turn_port_->CreateConnection(
         udp_port_->Candidates()[0], Port::ORIGIN_MESSAGE);
@@ -508,6 +522,8 @@
   bool turn_create_permission_success_;
   bool udp_ready_;
   bool test_finish_;
+  bool turn_refresh_success_ = false;
+  bool connection_destroyed_ = false;
   std::vector<rtc::Buffer> turn_packets_;
   std::vector<rtc::Buffer> udp_packets_;
   rtc::PacketOptions options;
@@ -675,16 +691,31 @@
   EXPECT_NE(first_addr, turn_port_->socket()->GetLocalAddress());
 }
 
+TEST_F(TurnPortTest, TestRefreshRequestGetsErrorResponse) {
+  CreateTurnPort(kTurnUsername, kTurnPassword, kTurnUdpProtoAddr);
+  turn_port_->PrepareAddress();
+  EXPECT_TRUE_WAIT(turn_ready_, kTimeout);
+  // Set bad credentials.
+  cricket::RelayCredentials bad_credentials("bad_user", "bad_pwd");
+  turn_port_->set_credentials(bad_credentials);
+  turn_refresh_success_ = false;
+  // This sends out the first RefreshRequest with correct credentials.
+  // When this succeeds, it will schedule a new RefreshRequest with the bad
+  // credential.
+  turn_port_->FlushRequests();
+  EXPECT_TRUE_WAIT(turn_refresh_success_, kTimeout);
+  // Flush it again, it will receive a bad response.
+  turn_port_->FlushRequests();
+  EXPECT_TRUE_WAIT(!turn_refresh_success_, kTimeout);
+  EXPECT_TRUE(turn_port_->connections().empty());
+  EXPECT_FALSE(turn_port_->connected());
+}
+
 // Test that CreateConnection will return null if port becomes disconnected.
 TEST_F(TurnPortTest, TestCreateConnectionWhenSocketClosed) {
   turn_server_.AddInternalSocket(kTurnTcpIntAddr, cricket::PROTO_TCP);
   CreateTurnPort(kTurnUsername, kTurnPassword, kTurnTcpProtoAddr);
-  turn_port_->PrepareAddress();
-  ASSERT_TRUE_WAIT(turn_ready_, kTimeout);
-
-  CreateUdpPort();
-  udp_port_->PrepareAddress();
-  ASSERT_TRUE_WAIT(udp_ready_, kTimeout);
+  PrepareTurnAndUdpPorts();
   // Create a connection.
   Connection* conn1 = turn_port_->CreateConnection(udp_port_->Candidates()[0],
                                                    Port::ORIGIN_MESSAGE);
@@ -792,25 +823,51 @@
 }
 
 // Test that CreatePermissionRequest will be scheduled after the success
-// of the first create permission request.
+// of the first create permission request and the request will get an
+// ErrorResponse if the ufrag and pwd are incorrect.
 TEST_F(TurnPortTest, TestRefreshCreatePermissionRequest) {
   CreateTurnPort(kTurnUsername, kTurnPassword, kTurnUdpProtoAddr);
-
-  ASSERT_TRUE(turn_port_ != NULL);
-  turn_port_->PrepareAddress();
-  ASSERT_TRUE_WAIT(turn_ready_, kTimeout);
-  CreateUdpPort();
-  udp_port_->PrepareAddress();
-  ASSERT_TRUE_WAIT(udp_ready_, kTimeout);
+  PrepareTurnAndUdpPorts();
 
   Connection* conn = turn_port_->CreateConnection(udp_port_->Candidates()[0],
                                                   Port::ORIGIN_MESSAGE);
+  ConnectConnectionDestroyedSignal(conn);
   ASSERT_TRUE(conn != NULL);
   ASSERT_TRUE_WAIT(turn_create_permission_success_, kTimeout);
   turn_create_permission_success_ = false;
   // A create-permission-request should be pending.
+  // After the next create-permission-response is received, it will schedule
+  // another request with bad_ufrag and bad_pwd.
+  cricket::RelayCredentials bad_credentials("bad_user", "bad_pwd");
+  turn_port_->set_credentials(bad_credentials);
   turn_port_->FlushRequests();
   ASSERT_TRUE_WAIT(turn_create_permission_success_, kTimeout);
+  // Flush the requests again; the create-permission-request will fail.
+  turn_port_->FlushRequests();
+  EXPECT_TRUE_WAIT(!turn_create_permission_success_, kTimeout);
+  EXPECT_TRUE_WAIT(connection_destroyed_, kTimeout);
+}
+
+TEST_F(TurnPortTest, TestChannelBindGetErrorResponse) {
+  CreateTurnPort(kTurnUsername, kTurnPassword, kTurnUdpProtoAddr);
+  PrepareTurnAndUdpPorts();
+  Connection* conn1 = turn_port_->CreateConnection(udp_port_->Candidates()[0],
+                                                   Port::ORIGIN_MESSAGE);
+  ASSERT_TRUE(conn1 != nullptr);
+  Connection* conn2 = udp_port_->CreateConnection(turn_port_->Candidates()[0],
+                                                  Port::ORIGIN_MESSAGE);
+  ASSERT_TRUE(conn2 != nullptr);
+  ConnectConnectionDestroyedSignal(conn1);
+  conn1->Ping(0);
+  ASSERT_TRUE_WAIT(conn1->writable(), kTimeout);
+
+  std::string data = "ABC";
+  conn1->Send(data.data(), data.length(), options);
+  bool success =
+      turn_port_->SetEntryChannelId(udp_port_->Candidates()[0].address(), -1);
+  ASSERT_TRUE(success);
+  // Next time when the binding request is sent, it will get an ErrorResponse.
+  EXPECT_TRUE_WAIT(CheckConnectionDestroyed(), kTimeout);
 }
 
 // Do a TURN allocation, establish a UDP connection, and send some data.