ANDROID: qtaguid: Fix the UAF problem with tag_ref_tree

When multiple threads try to tag/delete the same socket at the same
time, there is a chance that the tag_ref_entry of the target socket
becomes null before the uid_tag_data entry is freed. This is caused by
the ctrl_cmd_tag function, which doesn't correctly grab the spinlocks
when tagging a socket.

Signed-off-by: Chenbo Feng <fengc@google.com>
Bug: 65853158
Change-Id: I5d89885918054cf835370a52bff2d693362ac5f0
diff --git a/net/netfilter/xt_qtaguid.c b/net/netfilter/xt_qtaguid.c
index 3bf0c59..7536cac 100644
--- a/net/netfilter/xt_qtaguid.c
+++ b/net/netfilter/xt_qtaguid.c
@@ -519,13 +519,11 @@
 
 	DR_DEBUG("qtaguid: get_tag_ref(0x%llx)\n",
 		 full_tag);
-	spin_lock_bh(&uid_tag_data_tree_lock);
 	tr_entry = lookup_tag_ref(full_tag, &utd_entry);
 	BUG_ON(IS_ERR_OR_NULL(utd_entry));
 	if (!tr_entry)
 		tr_entry = new_tag_ref(full_tag, utd_entry);
 
-	spin_unlock_bh(&uid_tag_data_tree_lock);
 	if (utd_res)
 		*utd_res = utd_entry;
 	DR_DEBUG("qtaguid: get_tag_ref(0x%llx) utd=%p tr=%p\n",
@@ -2038,6 +2036,7 @@
 
 	/* Delete socket tags */
 	spin_lock_bh(&sock_tag_list_lock);
+	spin_lock_bh(&uid_tag_data_tree_lock);
 	node = rb_first(&sock_tag_tree);
 	while (node) {
 		st_entry = rb_entry(node, struct sock_tag, sock_node);
@@ -2067,6 +2066,7 @@
 				list_del(&st_entry->list);
 		}
 	}
+	spin_unlock_bh(&uid_tag_data_tree_lock);
 	spin_unlock_bh(&sock_tag_list_lock);
 
 	sock_tag_tree_erase(&st_to_free_tree);
@@ -2276,10 +2276,12 @@
 	full_tag = combine_atag_with_uid(acct_tag, uid_int);
 
 	spin_lock_bh(&sock_tag_list_lock);
+	spin_lock_bh(&uid_tag_data_tree_lock);
 	sock_tag_entry = get_sock_stat_nl(el_socket->sk);
 	tag_ref_entry = get_tag_ref(full_tag, &uid_tag_data_entry);
 	if (IS_ERR(tag_ref_entry)) {
 		res = PTR_ERR(tag_ref_entry);
+		spin_unlock_bh(&uid_tag_data_tree_lock);
 		spin_unlock_bh(&sock_tag_list_lock);
 		goto err_put;
 	}
@@ -2313,15 +2315,19 @@
 			pr_err("qtaguid: ctrl_tag(%s): "
 			       "socket tag alloc failed\n",
 			       input);
+			BUG_ON(tag_ref_entry->num_sock_tags <= 0);
+			tag_ref_entry->num_sock_tags--;
+			free_tag_ref_from_utd_entry(tag_ref_entry,
+						    uid_tag_data_entry);
+			spin_unlock_bh(&uid_tag_data_tree_lock);
 			spin_unlock_bh(&sock_tag_list_lock);
 			res = -ENOMEM;
-			goto err_tag_unref_put;
+			goto err_put;
 		}
 		sock_tag_entry->sk = el_socket->sk;
 		sock_tag_entry->socket = el_socket;
 		sock_tag_entry->pid = current->tgid;
 		sock_tag_entry->tag = combine_atag_with_uid(acct_tag, uid_int);
-		spin_lock_bh(&uid_tag_data_tree_lock);
 		pqd_entry = proc_qtu_data_tree_search(
 			&proc_qtu_data_tree, current->tgid);
 		/*
@@ -2339,11 +2345,11 @@
 		else
 			list_add(&sock_tag_entry->list,
 				 &pqd_entry->sock_tag_list);
-		spin_unlock_bh(&uid_tag_data_tree_lock);
 
 		sock_tag_tree_insert(sock_tag_entry, &sock_tag_tree);
 		atomic64_inc(&qtu_events.sockets_tagged);
 	}
+	spin_unlock_bh(&uid_tag_data_tree_lock);
 	spin_unlock_bh(&sock_tag_list_lock);
 	/* We keep the ref to the socket (file) until it is untagged */
 	CT_DEBUG("qtaguid: ctrl_tag(%s): done st@%p ...->f_count=%ld\n",
@@ -2351,10 +2357,6 @@
 		 atomic_long_read(&el_socket->file->f_count));
 	return 0;
 
-err_tag_unref_put:
-	BUG_ON(tag_ref_entry->num_sock_tags <= 0);
-	tag_ref_entry->num_sock_tags--;
-	free_tag_ref_from_utd_entry(tag_ref_entry, uid_tag_data_entry);
 err_put:
 	CT_DEBUG("qtaguid: ctrl_tag(%s): done. ...->f_count=%ld\n",
 		 input, atomic_long_read(&el_socket->file->f_count) - 1);