ANDROID: usb: f_accessory: Add refcounting to global 'acc_dev'

Add refcounting to track the lifetime of the global 'acc_dev' structure,
as the underlying function directories can be removed while references
still exist to the dev node.

Bug: 173789633
Signed-off-by: Will Deacon <willdeacon@google.com>
Change-Id: I248408e890d01167706c329146d63b64a6456df6
Signed-off-by: Giuliano Procida <gprocida@google.com>
diff --git a/drivers/usb/gadget/function/f_accessory.c b/drivers/usb/gadget/function/f_accessory.c
index b91cd38..97a6bd8 100644
--- a/drivers/usb/gadget/function/f_accessory.c
+++ b/drivers/usb/gadget/function/f_accessory.c
@@ -27,6 +27,7 @@
 #include <linux/interrupt.h>
 #include <linux/kthread.h>
 #include <linux/freezer.h>
+#include <linux/kref.h>
 
 #include <linux/types.h>
 #include <linux/file.h>
@@ -73,6 +74,7 @@
 	struct usb_function function;
 	struct usb_composite_dev *cdev;
 	spinlock_t lock;
+	struct acc_dev_ref *ref;
 
 	struct usb_ep *ep_in;
 	struct usb_ep *ep_out;
@@ -199,7 +201,14 @@
 	NULL,
 };
 
-static struct acc_dev *_acc_dev;
+struct acc_dev_ref {
+	struct kref	kref;
+	struct acc_dev	*acc_dev;
+};
+
+static struct acc_dev_ref _acc_dev_ref = {
+	.kref = KREF_INIT(0),
+};
 
 struct acc_instance {
 	struct usb_function_instance func_inst;
@@ -208,11 +217,26 @@
 
 static struct acc_dev *get_acc_dev(void)
 {
-	return _acc_dev;
+	struct acc_dev_ref *ref = &_acc_dev_ref;
+
+	return kref_get_unless_zero(&ref->kref) ? ref->acc_dev : NULL;
+}
+
+static void __put_acc_dev(struct kref *kref)
+{
+	struct acc_dev_ref *ref = container_of(kref, struct acc_dev_ref, kref);
+	struct acc_dev *dev = ref->acc_dev;
+
+	ref->acc_dev = NULL;
+	kfree(dev);
 }
 
 static void put_acc_dev(struct acc_dev *dev)
 {
+	struct acc_dev_ref *ref = dev->ref;
+
+	WARN_ON(ref->acc_dev != dev);
+	kref_put(&ref->kref, __put_acc_dev);
 }
 
 static inline struct acc_dev *func_to_dev(struct usb_function *f)
@@ -1230,6 +1254,7 @@
 
 static int acc_setup(void)
 {
+	struct acc_dev_ref *ref = &_acc_dev_ref;
 	struct acc_dev *dev;
 	int ret;
 
@@ -1248,7 +1273,9 @@
 	INIT_DELAYED_WORK(&dev->start_work, acc_start_work);
 	INIT_WORK(&dev->hid_work, acc_hid_work);
 
-	_acc_dev = dev;
+	dev->ref = ref;
+	kref_init(&ref->kref);
+	ref->acc_dev = dev;
 
 	ret = misc_register(&acc_device);
 	if (ret)
@@ -1276,12 +1303,11 @@
 
 static void acc_cleanup(void)
 {
-	struct acc_dev *dev = _acc_dev;
+	struct acc_dev *dev = get_acc_dev();
 
 	misc_deregister(&acc_device);
 	put_acc_dev(dev);
-	kfree(dev);
-	_acc_dev = NULL;
+	put_acc_dev(dev); /* Pairs with kref_init() in acc_setup() */
 }
 static struct acc_instance *to_acc_instance(struct config_item *item)
 {