priority LB: avoid possibility of rescheduling a timer before it fires (#29188) (#29241)

* priority LB: avoid possibility of rescheduling a timer before it fires

* clang-format

* fix memory leak

* small change, just to be paranoid

* inline StartFailoverTimerLocked()

* initialize timer_pending_ to true

* don't check shutting_down_ in timer callbacks

Co-authored-by: Mark D. Roth <roth@google.com>
diff --git a/src/core/ext/filters/client_channel/lb_policy/priority/priority.cc b/src/core/ext/filters/client_channel/lb_policy/priority/priority.cc
index 257c36b..d681673 100644
--- a/src/core/ext/filters/client_channel/lb_policy/priority/priority.cc
+++ b/src/core/ext/filters/client_channel/lb_policy/priority/priority.cc
@@ -106,7 +106,6 @@
     void ResetBackoffLocked();
     void DeactivateLocked();
     void MaybeReactivateLocked();
-    void MaybeCancelFailoverTimerLocked();
 
     void Orphan() override;
 
@@ -122,9 +121,7 @@
       return connectivity_status_;
     }
 
-    bool failover_timer_callback_pending() const {
-      return failover_timer_callback_pending_;
-    }
+    bool FailoverTimerPending() const { return failover_timer_ != nullptr; }
 
    private:
     // A simple wrapper for ref-counting a picker from the child policy.
@@ -170,6 +167,38 @@
       RefCountedPtr<ChildPriority> priority_;
     };
 
+    class DeactivationTimer : public InternallyRefCounted<DeactivationTimer> {
+     public:
+      explicit DeactivationTimer(RefCountedPtr<ChildPriority> child_priority);
+
+      void Orphan() override;
+
+     private:
+      static void OnTimer(void* arg, grpc_error_handle error);
+      void OnTimerLocked(grpc_error_handle);
+
+      RefCountedPtr<ChildPriority> child_priority_;
+      grpc_timer timer_;
+      grpc_closure on_timer_;
+      bool timer_pending_ = true;
+    };
+
+    class FailoverTimer : public InternallyRefCounted<FailoverTimer> {
+     public:
+      explicit FailoverTimer(RefCountedPtr<ChildPriority> child_priority);
+
+      void Orphan() override;
+
+     private:
+      static void OnTimer(void* arg, grpc_error_handle error);
+      void OnTimerLocked(grpc_error_handle);
+
+      RefCountedPtr<ChildPriority> child_priority_;
+      grpc_timer timer_;
+      grpc_closure on_timer_;
+      bool timer_pending_ = true;
+    };
+
     // Methods for dealing with the child policy.
     OrphanablePtr<LoadBalancingPolicy> CreateChildPolicyLocked(
         const grpc_channel_args* args);
@@ -178,13 +207,6 @@
         grpc_connectivity_state state, const absl::Status& status,
         std::unique_ptr<SubchannelPicker> picker);
 
-    void StartFailoverTimerLocked();
-
-    static void OnFailoverTimer(void* arg, grpc_error_handle error);
-    void OnFailoverTimerLocked(grpc_error_handle error);
-    static void OnDeactivationTimer(void* arg, grpc_error_handle error);
-    void OnDeactivationTimerLocked(grpc_error_handle error);
-
     RefCountedPtr<PriorityLb> priority_policy_;
     const std::string name_;
     bool ignore_reresolution_requests_ = false;
@@ -195,15 +217,8 @@
     absl::Status connectivity_status_;
     RefCountedPtr<RefCountedPicker> picker_wrapper_;
 
-    // States for delayed removal.
-    grpc_timer deactivation_timer_;
-    grpc_closure on_deactivation_timer_;
-    bool deactivation_timer_callback_pending_ = false;
-
-    // States of failover.
-    grpc_timer failover_timer_;
-    grpc_closure on_failover_timer_;
-    bool failover_timer_callback_pending_ = false;
+    OrphanablePtr<DeactivationTimer> deactivation_timer_;
+    OrphanablePtr<FailoverTimer> failover_timer_;
   };
 
   ~PriorityLb() override;
@@ -451,7 +466,7 @@
     }
     // Child is not READY or IDLE.
     // If its failover timer is still pending, give it time to fire.
-    if (child->failover_timer_callback_pending()) {
+    if (child->FailoverTimerPending()) {
       if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) {
         gpr_log(GPR_INFO,
                 "[priority_lb %p] priority %u, child %s: child still "
@@ -502,6 +517,132 @@
 }
 
 //
+// PriorityLb::ChildPriority::DeactivationTimer
+//
+
+PriorityLb::ChildPriority::DeactivationTimer::DeactivationTimer(
+    RefCountedPtr<PriorityLb::ChildPriority> child_priority)
+    : child_priority_(std::move(child_priority)) {
+  if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) {
+    gpr_log(GPR_INFO,
+            "[priority_lb %p] child %s (%p): deactivating -- will remove in "
+            "%" PRId64 "ms",
+            child_priority_->priority_policy_.get(),
+            child_priority_->name_.c_str(), child_priority_.get(),
+            kChildRetentionInterval.millis());
+  }
+  GRPC_CLOSURE_INIT(&on_timer_, OnTimer, this, nullptr);
+  Ref(DEBUG_LOCATION, "Timer").release();
+  grpc_timer_init(&timer_, ExecCtx::Get()->Now() + kChildRetentionInterval,
+                  &on_timer_);
+}
+
+void PriorityLb::ChildPriority::DeactivationTimer::Orphan() {
+  if (timer_pending_) {
+    if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) {
+      gpr_log(GPR_INFO, "[priority_lb %p] child %s (%p): reactivating",
+              child_priority_->priority_policy_.get(),
+              child_priority_->name_.c_str(), child_priority_.get());
+    }
+    timer_pending_ = false;
+    grpc_timer_cancel(&timer_);
+  }
+  Unref();
+}
+
+void PriorityLb::ChildPriority::DeactivationTimer::OnTimer(
+    void* arg, grpc_error_handle error) {
+  auto* self = static_cast<DeactivationTimer*>(arg);
+  (void)GRPC_ERROR_REF(error);  // ref owned by lambda
+  self->child_priority_->priority_policy_->work_serializer()->Run(
+      [self, error]() { self->OnTimerLocked(error); }, DEBUG_LOCATION);
+}
+
+void PriorityLb::ChildPriority::DeactivationTimer::OnTimerLocked(
+    grpc_error_handle error) {
+  if (error == GRPC_ERROR_NONE && timer_pending_) {
+    if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) {
+      gpr_log(GPR_INFO,
+              "[priority_lb %p] child %s (%p): deactivation timer fired, "
+              "deleting child",
+              child_priority_->priority_policy_.get(),
+              child_priority_->name_.c_str(), child_priority_.get());
+    }
+    timer_pending_ = false;
+    child_priority_->priority_policy_->DeleteChild(child_priority_.get());
+  }
+  Unref(DEBUG_LOCATION, "Timer");
+  GRPC_ERROR_UNREF(error);
+}
+
+//
+// PriorityLb::ChildPriority::FailoverTimer
+//
+
+PriorityLb::ChildPriority::FailoverTimer::FailoverTimer(
+    RefCountedPtr<PriorityLb::ChildPriority> child_priority)
+    : child_priority_(std::move(child_priority)) {
+  if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) {
+    gpr_log(
+        GPR_INFO,
+        "[priority_lb %p] child %s (%p): starting failover timer for %" PRId64
+        "ms",
+        child_priority_->priority_policy_.get(), child_priority_->name_.c_str(),
+        child_priority_.get(),
+        child_priority_->priority_policy_->child_failover_timeout_.millis());
+  }
+  GRPC_CLOSURE_INIT(&on_timer_, OnTimer, this, nullptr);
+  Ref(DEBUG_LOCATION, "Timer").release();
+  grpc_timer_init(
+      &timer_,
+      ExecCtx::Get()->Now() +
+          child_priority_->priority_policy_->child_failover_timeout_,
+      &on_timer_);
+}
+
+void PriorityLb::ChildPriority::FailoverTimer::Orphan() {
+  if (timer_pending_) {
+    if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) {
+      gpr_log(GPR_INFO,
+              "[priority_lb %p] child %s (%p): cancelling failover timer",
+              child_priority_->priority_policy_.get(),
+              child_priority_->name_.c_str(), child_priority_.get());
+    }
+    timer_pending_ = false;
+    grpc_timer_cancel(&timer_);
+  }
+  Unref();
+}
+
+void PriorityLb::ChildPriority::FailoverTimer::OnTimer(
+    void* arg, grpc_error_handle error) {
+  auto* self = static_cast<FailoverTimer*>(arg);
+  (void)GRPC_ERROR_REF(error);  // ref owned by lambda
+  self->child_priority_->priority_policy_->work_serializer()->Run(
+      [self, error]() { self->OnTimerLocked(error); }, DEBUG_LOCATION);
+}
+
+void PriorityLb::ChildPriority::FailoverTimer::OnTimerLocked(
+    grpc_error_handle error) {
+  if (error == GRPC_ERROR_NONE && timer_pending_) {
+    if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) {
+      gpr_log(GPR_INFO,
+              "[priority_lb %p] child %s (%p): failover timer fired, "
+              "reporting TRANSIENT_FAILURE",
+              child_priority_->priority_policy_.get(),
+              child_priority_->name_.c_str(), child_priority_.get());
+    }
+    timer_pending_ = false;
+    child_priority_->OnConnectivityStateUpdateLocked(
+        GRPC_CHANNEL_TRANSIENT_FAILURE,
+        absl::Status(absl::StatusCode::kUnavailable, "failover timer fired"),
+        nullptr);
+  }
+  Unref(DEBUG_LOCATION, "Timer");
+  GRPC_ERROR_UNREF(error);
+}
+
+//
 // PriorityLb::ChildPriority
 //
 
@@ -512,12 +653,8 @@
     gpr_log(GPR_INFO, "[priority_lb %p] creating child %s (%p)",
             priority_policy_.get(), name_.c_str(), this);
   }
-  GRPC_CLOSURE_INIT(&on_failover_timer_, OnFailoverTimer, this,
-                    grpc_schedule_on_exec_ctx);
-  GRPC_CLOSURE_INIT(&on_deactivation_timer_, OnDeactivationTimer, this,
-                    grpc_schedule_on_exec_ctx);
   // Start the failover timer.
-  StartFailoverTimerLocked();
+  failover_timer_ = MakeOrphanable<FailoverTimer>(Ref());
 }
 
 void PriorityLb::ChildPriority::Orphan() {
@@ -525,10 +662,8 @@
     gpr_log(GPR_INFO, "[priority_lb %p] child %s (%p): orphaned",
             priority_policy_.get(), name_.c_str(), this);
   }
-  MaybeCancelFailoverTimerLocked();
-  if (deactivation_timer_callback_pending_) {
-    grpc_timer_cancel(&deactivation_timer_);
-  }
+  failover_timer_.reset();
+  deactivation_timer_.reset();
   // Remove the child policy's interested_parties pollset_set from the
   // xDS policy.
   grpc_pollset_set_del_pollset_set(child_policy_->interested_parties(),
@@ -537,9 +672,6 @@
   // Drop our ref to the child's picker, in case it's holding a ref to
   // the child.
   picker_wrapper_.reset();
-  if (deactivation_timer_callback_pending_) {
-    grpc_timer_cancel(&deactivation_timer_);
-  }
   Unref(DEBUG_LOCATION, "ChildPriority+Orphan");
 }
 
@@ -600,9 +732,8 @@
 }
 
 void PriorityLb::ChildPriority::ExitIdleLocked() {
-  if (connectivity_state_ == GRPC_CHANNEL_IDLE &&
-      !failover_timer_callback_pending_) {
-    StartFailoverTimerLocked();
+  if (connectivity_state_ == GRPC_CHANNEL_IDLE && failover_timer_ == nullptr) {
+    failover_timer_ = MakeOrphanable<FailoverTimer>(Ref());
   }
   child_policy_->ExitIdleLocked();
 }
@@ -628,122 +759,21 @@
   // If READY or IDLE or TRANSIENT_FAILURE, cancel failover timer.
   if (state == GRPC_CHANNEL_READY || state == GRPC_CHANNEL_IDLE ||
       state == GRPC_CHANNEL_TRANSIENT_FAILURE) {
-    MaybeCancelFailoverTimerLocked();
+    failover_timer_.reset();
   }
   // Notify the parent policy.
   priority_policy_->HandleChildConnectivityStateChangeLocked(this);
 }
 
-void PriorityLb::ChildPriority::StartFailoverTimerLocked() {
-  if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) {
-    gpr_log(
-        GPR_INFO,
-        "[priority_lb %p] child %s (%p): starting failover timer for %" PRId64
-        "ms",
-        priority_policy_.get(), name_.c_str(), this,
-        priority_policy_->child_failover_timeout_.millis());
-  }
-  Ref(DEBUG_LOCATION, "ChildPriority+OnFailoverTimerLocked").release();
-  grpc_timer_init(
-      &failover_timer_,
-      ExecCtx::Get()->Now() + priority_policy_->child_failover_timeout_,
-      &on_failover_timer_);
-  failover_timer_callback_pending_ = true;
-}
-
-void PriorityLb::ChildPriority::MaybeCancelFailoverTimerLocked() {
-  if (failover_timer_callback_pending_) {
-    if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) {
-      gpr_log(GPR_INFO,
-              "[priority_lb %p] child %s (%p): cancelling failover timer",
-              priority_policy_.get(), name_.c_str(), this);
-    }
-    grpc_timer_cancel(&failover_timer_);
-    failover_timer_callback_pending_ = false;
-  }
-}
-
-void PriorityLb::ChildPriority::OnFailoverTimer(void* arg,
-                                                grpc_error_handle error) {
-  ChildPriority* self = static_cast<ChildPriority*>(arg);
-  (void)GRPC_ERROR_REF(error);  // ref owned by lambda
-  self->priority_policy_->work_serializer()->Run(
-      [self, error]() { self->OnFailoverTimerLocked(error); }, DEBUG_LOCATION);
-}
-
-void PriorityLb::ChildPriority::OnFailoverTimerLocked(grpc_error_handle error) {
-  if (error == GRPC_ERROR_NONE && failover_timer_callback_pending_ &&
-      !priority_policy_->shutting_down_) {
-    if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) {
-      gpr_log(GPR_INFO,
-              "[priority_lb %p] child %s (%p): failover timer fired, "
-              "reporting TRANSIENT_FAILURE",
-              priority_policy_.get(), name_.c_str(), this);
-    }
-    failover_timer_callback_pending_ = false;
-    OnConnectivityStateUpdateLocked(
-        GRPC_CHANNEL_TRANSIENT_FAILURE,
-        absl::Status(absl::StatusCode::kUnavailable, "failover timer fired"),
-        nullptr);
-  }
-  Unref(DEBUG_LOCATION, "ChildPriority+OnFailoverTimerLocked");
-  GRPC_ERROR_UNREF(error);
-}
-
 void PriorityLb::ChildPriority::DeactivateLocked() {
   // If already deactivated, don't do it again.
-  if (deactivation_timer_callback_pending_) return;
-  if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) {
-    gpr_log(GPR_INFO,
-            "[priority_lb %p] child %s (%p): deactivating -- will remove in "
-            "%" PRId64 "ms.",
-            priority_policy_.get(), name_.c_str(), this,
-            kChildRetentionInterval.millis());
-  }
-  MaybeCancelFailoverTimerLocked();
-  // Start a timer to delete the child.
-  Ref(DEBUG_LOCATION, "ChildPriority+timer").release();
-  grpc_timer_init(&deactivation_timer_,
-                  ExecCtx::Get()->Now() + kChildRetentionInterval,
-                  &on_deactivation_timer_);
-  deactivation_timer_callback_pending_ = true;
+  if (deactivation_timer_ != nullptr) return;
+  failover_timer_.reset();
+  deactivation_timer_ = MakeOrphanable<DeactivationTimer>(Ref());
 }
 
 void PriorityLb::ChildPriority::MaybeReactivateLocked() {
-  if (deactivation_timer_callback_pending_) {
-    if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) {
-      gpr_log(GPR_INFO, "[priority_lb %p] child %s (%p): reactivating",
-              priority_policy_.get(), name_.c_str(), this);
-    }
-    deactivation_timer_callback_pending_ = false;
-    grpc_timer_cancel(&deactivation_timer_);
-  }
-}
-
-void PriorityLb::ChildPriority::OnDeactivationTimer(void* arg,
-                                                    grpc_error_handle error) {
-  ChildPriority* self = static_cast<ChildPriority*>(arg);
-  (void)GRPC_ERROR_REF(error);  // ref owned by lambda
-  self->priority_policy_->work_serializer()->Run(
-      [self, error]() { self->OnDeactivationTimerLocked(error); },
-      DEBUG_LOCATION);
-}
-
-void PriorityLb::ChildPriority::OnDeactivationTimerLocked(
-    grpc_error_handle error) {
-  if (error == GRPC_ERROR_NONE && deactivation_timer_callback_pending_ &&
-      !priority_policy_->shutting_down_) {
-    if (GRPC_TRACE_FLAG_ENABLED(grpc_lb_priority_trace)) {
-      gpr_log(GPR_INFO,
-              "[priority_lb %p] child %s (%p): deactivation timer fired, "
-              "deleting child",
-              priority_policy_.get(), name_.c_str(), this);
-    }
-    deactivation_timer_callback_pending_ = false;
-    priority_policy_->DeleteChild(this);
-  }
-  Unref(DEBUG_LOCATION, "ChildPriority+timer");
-  GRPC_ERROR_UNREF(error);
+  deactivation_timer_.reset();
 }
 
 //