Fix callout stop deadlock
If a callout stop is called from the executing callout's func on the
same thread, the callout will deadlock itself waiting on wait_cond.
This adds a check to prevent that.
diff --git a/usrsctplib/netinet/sctp_callout.c b/usrsctplib/netinet/sctp_callout.c
index 849516e..5f04511 100755
--- a/usrsctplib/netinet/sctp_callout.c
+++ b/usrsctplib/netinet/sctp_callout.c
@@ -78,6 +78,7 @@
* - SCTP_BASE_INFO(callqueue)
* - sctp_os_timer_next: next timer to check
* - sctp_os_timer_current: current callout callback in progress
+ * - sctp_os_timer_current_tid: current callout thread id in progress
* - sctp_os_timer_waiting: some thread is waiting for callout to complete
* - sctp_os_timer_wait_ctr: incremented every time a thread wants to wait
* for a callout to complete.
@@ -86,6 +87,7 @@
static sctp_os_timer_t *sctp_os_timer_current = NULL;
static int sctp_os_timer_waiting = 0;
static int sctp_os_timer_wait_ctr = 0;
+static userland_thread_id_t sctp_os_timer_current_tid;
/*
* SCTP_TIMERWAIT_LOCK (sctp_os_timerwait_mtx) protects:
@@ -171,6 +173,18 @@
SCTP_TIMERQ_UNLOCK();
return (0);
} else {
+ /*
+ * Deleting the callout from the currently running
+ * callout from the same thread, so just return
+ */
+ userland_thread_id_t tid;
+ sctp_userspace_thread_id(&tid);
+ if (sctp_userspace_thread_equal(tid,
+ sctp_os_timer_current_tid)) {
+ SCTP_TIMERQ_UNLOCK();
+ return (0);
+ }
+
/* need to wait until the callout is finished */
sctp_os_timer_waiting = 1;
wakeup_cookie = ++sctp_os_timer_wait_ctr;
@@ -223,6 +237,7 @@
c_arg = c->c_arg;
c->c_flags &= ~SCTP_CALLOUT_PENDING;
sctp_os_timer_current = c;
+ sctp_userspace_thread_id(&sctp_os_timer_current_tid);
SCTP_TIMERQ_UNLOCK();
c_func(c_arg);
SCTP_TIMERQ_LOCK();
diff --git a/usrsctplib/netinet/sctp_os_userspace.h b/usrsctplib/netinet/sctp_os_userspace.h
index 9b1e567..a97b515 100755
--- a/usrsctplib/netinet/sctp_os_userspace.h
+++ b/usrsctplib/netinet/sctp_os_userspace.h
@@ -76,6 +76,7 @@
typedef CONDITION_VARIABLE userland_cond_t;
#endif
typedef HANDLE userland_thread_t;
+typedef DWORD userland_thread_id_t;
#define ADDRESS_FAMILY unsigned __int8
#define IPVERSION 4
#define MAXTTL 255
@@ -282,6 +283,7 @@
typedef pthread_mutex_t userland_mutex_t;
typedef pthread_cond_t userland_cond_t;
typedef pthread_t userland_thread_t;
+typedef pthread_t userland_thread_id_t;
#endif
#if defined(__Userspace_os_Windows) || defined(__Userspace_os_NaCl)
@@ -1039,6 +1041,9 @@
void
sctp_userspace_set_threadname(const char *name);
+int sctp_userspace_thread_id(userland_thread_id_t *thread);
+int sctp_userspace_thread_equal(userland_thread_id_t t1, userland_thread_id_t t2);
+
/*
* SCTP protocol specific mbuf flags.
*/
diff --git a/usrsctplib/netinet/sctp_userspace.c b/usrsctplib/netinet/sctp_userspace.c
index 14c04c6..be54783 100755
--- a/usrsctplib/netinet/sctp_userspace.c
+++ b/usrsctplib/netinet/sctp_userspace.c
@@ -61,12 +61,39 @@
return GetLastError();
return 0;
}
+
+int
+sctp_userspace_thread_id(userland_thread_id_t *thread)
+{
+ *thread = GetCurrentThreadId();
+ return 0;
+}
+
+int
+sctp_userspace_thread_equal(userland_thread_id_t t1, userland_thread_id_t t2)
+{
+ return (t1 == t2);
+}
+
#else
int
sctp_userspace_thread_create(userland_thread_t *thread, start_routine_t start_routine)
{
return pthread_create(thread, NULL, start_routine, NULL);
}
+
+int
+sctp_userspace_thread_id(userland_thread_id_t *thread)
+{
+ *thread = pthread_self();
+ return 0;
+}
+
+int
+sctp_userspace_thread_equal(userland_thread_id_t t1, userland_thread_id_t t2)
+{
+ return pthread_equal(t1, t2);
+}
#endif
void