firewire: cdev: reference-count client instances
[cascardo/linux.git] / drivers / firewire / fw-cdev.c
index 40cc973..81362c1 100644 (file)
 
 #include <linux/module.h>
 #include <linux/kernel.h>
+#include <linux/kref.h>
 #include <linux/wait.h>
 #include <linux/errno.h>
 #include <linux/device.h>
 #include <linux/vmalloc.h>
+#include <linux/mutex.h>
 #include <linux/poll.h>
 #include <linux/preempt.h>
 #include <linux/time.h>
 #include "fw-device.h"
 
 struct client;
+struct client_resource;
+typedef void (*client_resource_release_fn_t)(struct client *,
+                                            struct client_resource *);
 struct client_resource {
-       struct list_head link;
-       void (*release)(struct client *client, struct client_resource *r);
-       u32 handle;
+       client_resource_release_fn_t release;
+       int handle;
 };
 
 /*
@@ -77,9 +81,10 @@ struct iso_interrupt {
 struct client {
        u32 version;
        struct fw_device *device;
+
        spinlock_t lock;
-       u32 resource_handle;
-       struct list_head resource_list;
+       bool in_shutdown;
+       struct idr resource_idr;
        struct list_head event_list;
        wait_queue_head_t wait;
        u64 bus_reset_closure;
@@ -90,16 +95,33 @@ struct client {
        unsigned long vm_start;
 
        struct list_head link;
+       struct kref kref;
 };
 
-static inline void __user *
-u64_to_uptr(__u64 value)
+static inline void client_get(struct client *client)
+{
+       kref_get(&client->kref);
+}
+
+static void client_release(struct kref *kref)
+{
+       struct client *client = container_of(kref, struct client, kref);
+
+       fw_device_put(client->device);
+       kfree(client);
+}
+
+static void client_put(struct client *client)
+{
+       kref_put(&client->kref, client_release);
+}
+
+static inline void __user *u64_to_uptr(__u64 value)
 {
        return (void __user *)(unsigned long)value;
 }
 
-static inline __u64
-uptr_to_u64(void __user *ptr)
+static inline __u64 uptr_to_u64(void __user *ptr)
 {
        return (__u64)(unsigned long)ptr;
 }
@@ -108,7 +130,6 @@ static int fw_device_op_open(struct inode *inode, struct file *file)
 {
        struct fw_device *device;
        struct client *client;
-       unsigned long flags;
 
        device = fw_device_get_by_devt(inode->i_rdev);
        if (device == NULL)
@@ -126,16 +147,17 @@ static int fw_device_op_open(struct inode *inode, struct file *file)
        }
 
        client->device = device;
-       INIT_LIST_HEAD(&client->event_list);
-       INIT_LIST_HEAD(&client->resource_list);
        spin_lock_init(&client->lock);
+       idr_init(&client->resource_idr);
+       INIT_LIST_HEAD(&client->event_list);
        init_waitqueue_head(&client->wait);
+       kref_init(&client->kref);
 
        file->private_data = client;
 
-       spin_lock_irqsave(&device->client_list_lock, flags);
+       mutex_lock(&device->client_list_mutex);
        list_add_tail(&client->link, &device->client_list);
-       spin_unlock_irqrestore(&device->client_list_lock, flags);
+       mutex_unlock(&device->client_list_mutex);
 
        return 0;
 }
@@ -151,32 +173,35 @@ static void queue_event(struct client *client, struct event *event,
        event->v[1].size = size1;
 
        spin_lock_irqsave(&client->lock, flags);
-       list_add_tail(&event->link, &client->event_list);
+       if (client->in_shutdown)
+               kfree(event);
+       else
+               list_add_tail(&event->link, &client->event_list);
        spin_unlock_irqrestore(&client->lock, flags);
 
        wake_up_interruptible(&client->wait);
 }
 
-static int
-dequeue_event(struct client *client, char __user *buffer, size_t count)
+static int dequeue_event(struct client *client,
+                        char __user *buffer, size_t count)
 {
        unsigned long flags;
        struct event *event;
        size_t size, total;
-       int i, retval;
+       int i, ret;
 
-       retval = wait_event_interruptible(client->wait,
-                                         !list_empty(&client->event_list) ||
-                                         fw_device_is_shutdown(client->device));
-       if (retval < 0)
-               return retval;
+       ret = wait_event_interruptible(client->wait,
+                       !list_empty(&client->event_list) ||
+                       fw_device_is_shutdown(client->device));
+       if (ret < 0)
+               return ret;
 
        if (list_empty(&client->event_list) &&
                       fw_device_is_shutdown(client->device))
                return -ENODEV;
 
        spin_lock_irqsave(&client->lock, flags);
-       event = container_of(client->event_list.next, struct event, link);
+       event = list_first_entry(&client->event_list, struct event, link);
        list_del(&event->link);
        spin_unlock_irqrestore(&client->lock, flags);
 
@@ -184,31 +209,29 @@ dequeue_event(struct client *client, char __user *buffer, size_t count)
        for (i = 0; i < ARRAY_SIZE(event->v) && total < count; i++) {
                size = min(event->v[i].size, count - total);
                if (copy_to_user(buffer + total, event->v[i].data, size)) {
-                       retval = -EFAULT;
+                       ret = -EFAULT;
                        goto out;
                }
                total += size;
        }
-       retval = total;
+       ret = total;
 
  out:
        kfree(event);
 
-       return retval;
+       return ret;
 }
 
-static ssize_t
-fw_device_op_read(struct file *file,
-                 char __user *buffer, size_t count, loff_t *offset)
+static ssize_t fw_device_op_read(struct file *file, char __user *buffer,
+                                size_t count, loff_t *offset)
 {
        struct client *client = file->private_data;
 
        return dequeue_event(client, buffer, count);
 }
 
-static void
-fill_bus_reset_event(struct fw_cdev_event_bus_reset *event,
-                    struct client *client)
+static void fill_bus_reset_event(struct fw_cdev_event_bus_reset *event,
+                                struct client *client)
 {
        struct fw_card *card = client->device->card;
        unsigned long flags;
@@ -227,27 +250,22 @@ fill_bus_reset_event(struct fw_cdev_event_bus_reset *event,
        spin_unlock_irqrestore(&card->lock, flags);
 }
 
-static void
-for_each_client(struct fw_device *device,
-               void (*callback)(struct client *client))
+static void for_each_client(struct fw_device *device,
+                           void (*callback)(struct client *client))
 {
        struct client *c;
-       unsigned long flags;
-
-       spin_lock_irqsave(&device->client_list_lock, flags);
 
+       mutex_lock(&device->client_list_mutex);
        list_for_each_entry(c, &device->client_list, link)
                callback(c);
-
-       spin_unlock_irqrestore(&device->client_list_lock, flags);
+       mutex_unlock(&device->client_list_mutex);
 }
 
-static void
-queue_bus_reset_event(struct client *client)
+static void queue_bus_reset_event(struct client *client)
 {
        struct bus_reset *bus_reset;
 
-       bus_reset = kzalloc(sizeof(*bus_reset), GFP_ATOMIC);
+       bus_reset = kzalloc(sizeof(*bus_reset), GFP_KERNEL);
        if (bus_reset == NULL) {
                fw_notify("Out of memory when allocating bus reset event\n");
                return;
@@ -313,34 +331,49 @@ static int ioctl_get_info(struct client *client, void *buffer)
        return 0;
 }
 
-static void
-add_client_resource(struct client *client, struct client_resource *resource)
+static int add_client_resource(struct client *client,
+                              struct client_resource *resource, gfp_t gfp_mask)
 {
        unsigned long flags;
+       int ret;
+
+ retry:
+       if (idr_pre_get(&client->resource_idr, gfp_mask) == 0)
+               return -ENOMEM;
 
        spin_lock_irqsave(&client->lock, flags);
-       list_add_tail(&resource->link, &client->resource_list);
-       resource->handle = client->resource_handle++;
+       if (client->in_shutdown)
+               ret = -ECANCELED;
+       else
+               ret = idr_get_new(&client->resource_idr, resource,
+                                 &resource->handle);
+       if (ret >= 0)
+               client_get(client);
        spin_unlock_irqrestore(&client->lock, flags);
+
+       if (ret == -EAGAIN)
+               goto retry;
+
+       return ret < 0 ? ret : 0;
 }
 
-static int
-release_client_resource(struct client *client, u32 handle,
-                       struct client_resource **resource)
+static int release_client_resource(struct client *client, u32 handle,
+                                  client_resource_release_fn_t release,
+                                  struct client_resource **resource)
 {
        struct client_resource *r;
        unsigned long flags;
 
        spin_lock_irqsave(&client->lock, flags);
-       list_for_each_entry(r, &client->resource_list, link) {
-               if (r->handle == handle) {
-                       list_del(&r->link);
-                       break;
-               }
-       }
+       if (client->in_shutdown)
+               r = NULL;
+       else
+               r = idr_find(&client->resource_idr, handle);
+       if (r && r->release == release)
+               idr_remove(&client->resource_idr, handle);
        spin_unlock_irqrestore(&client->lock, flags);
 
-       if (&r->link == &client->resource_list)
+       if (!(r && r->release == release))
                return -EINVAL;
 
        if (resource)
@@ -348,11 +381,13 @@ release_client_resource(struct client *client, u32 handle,
        else
                r->release(client, r);
 
+       client_put(client);
+
        return 0;
 }
 
-static void
-release_transaction(struct client *client, struct client_resource *resource)
+static void release_transaction(struct client *client,
+                               struct client_resource *resource)
 {
        struct response *response =
                container_of(resource, struct response, resource);
@@ -360,9 +395,8 @@ release_transaction(struct client *client, struct client_resource *resource)
        fw_cancel_transaction(client->device->card, &response->transaction);
 }
 
-static void
-complete_transaction(struct fw_card *card, int rcode,
-                    void *payload, size_t length, void *data)
+static void complete_transaction(struct fw_card *card, int rcode,
+                                void *payload, size_t length, void *data)
 {
        struct response *response = data;
        struct client *client = response->client;
@@ -375,7 +409,22 @@ complete_transaction(struct fw_card *card, int rcode,
                memcpy(r->data, payload, r->length);
 
        spin_lock_irqsave(&client->lock, flags);
-       list_del(&response->resource.link);
+       /*
+        * 1. If called while in shutdown, the idr tree must be left untouched.
+        *    The idr handle will be removed and the client reference will be
+        *    dropped later.
+        * 2. If the call chain was release_client_resource ->
+        *    release_transaction -> complete_transaction (instead of a normal
+        *    conclusion of the transaction), i.e. if this resource was already
+        *    unregistered from the idr, the client reference will be dropped
+        *    by release_client_resource and we must not drop it here.
+        */
+       if (!client->in_shutdown &&
+           idr_find(&client->resource_idr, response->resource.handle)) {
+               idr_remove(&client->resource_idr, response->resource.handle);
+               /* Drop the idr's reference */
+               client_put(client);
+       }
        spin_unlock_irqrestore(&client->lock, flags);
 
        r->type   = FW_CDEV_EVENT_RESPONSE;
@@ -394,6 +443,9 @@ complete_transaction(struct fw_card *card, int rcode,
        else
                queue_event(client, &response->event, r, sizeof(*r) + r->length,
                            NULL, 0);
+
+       /* Drop the transaction callback's reference */
+       client_put(client);
 }
 
 static int ioctl_send_request(struct client *client, void *buffer)
@@ -401,6 +453,7 @@ static int ioctl_send_request(struct client *client, void *buffer)
        struct fw_device *device = client->device;
        struct fw_cdev_send_request *request = buffer;
        struct response *response;
+       int ret;
 
        /* What is the biggest size we'll accept, really? */
        if (request->length > 4096)
@@ -417,12 +470,35 @@ static int ioctl_send_request(struct client *client, void *buffer)
        if (request->data &&
            copy_from_user(response->response.data,
                           u64_to_uptr(request->data), request->length)) {
-               kfree(response);
-               return -EFAULT;
+               ret = -EFAULT;
+               goto failed;
+       }
+
+       switch (request->tcode) {
+       case TCODE_WRITE_QUADLET_REQUEST:
+       case TCODE_WRITE_BLOCK_REQUEST:
+       case TCODE_READ_QUADLET_REQUEST:
+       case TCODE_READ_BLOCK_REQUEST:
+       case TCODE_LOCK_MASK_SWAP:
+       case TCODE_LOCK_COMPARE_SWAP:
+       case TCODE_LOCK_FETCH_ADD:
+       case TCODE_LOCK_LITTLE_ADD:
+       case TCODE_LOCK_BOUNDED_ADD:
+       case TCODE_LOCK_WRAP_ADD:
+       case TCODE_LOCK_VENDOR_DEPENDENT:
+               break;
+       default:
+               ret = -EINVAL;
+               goto failed;
        }
 
        response->resource.release = release_transaction;
-       add_client_resource(client, &response->resource);
+       ret = add_client_resource(client, &response->resource, GFP_KERNEL);
+       if (ret < 0)
+               goto failed;
+
+       /* Get a reference for the transaction callback */
+       client_get(client);
 
        fw_send_request(device->card, &response->transaction,
                        request->tcode & 0x1f,
@@ -437,6 +513,10 @@ static int ioctl_send_request(struct client *client, void *buffer)
                return sizeof(request) + request->length;
        else
                return sizeof(request);
+ failed:
+       kfree(response);
+
+       return ret;
 }
 
 struct address_handler {
@@ -458,8 +538,8 @@ struct request_event {
        struct fw_cdev_event_request request;
 };
 
-static void
-release_request(struct client *client, struct client_resource *resource)
+static void release_request(struct client *client,
+                           struct client_resource *resource)
 {
        struct request *request =
                container_of(resource, struct request, resource);
@@ -469,33 +549,31 @@ release_request(struct client *client, struct client_resource *resource)
        kfree(request);
 }
 
-static void
-handle_request(struct fw_card *card, struct fw_request *r,
-              int tcode, int destination, int source,
-              int generation, int speed,
-              unsigned long long offset,
-              void *payload, size_t length, void *callback_data)
+static void handle_request(struct fw_card *card, struct fw_request *r,
+                          int tcode, int destination, int source,
+                          int generation, int speed,
+                          unsigned long long offset,
+                          void *payload, size_t length, void *callback_data)
 {
        struct address_handler *handler = callback_data;
        struct request *request;
        struct request_event *e;
        struct client *client = handler->client;
+       int ret;
 
        request = kmalloc(sizeof(*request), GFP_ATOMIC);
        e = kmalloc(sizeof(*e), GFP_ATOMIC);
-       if (request == NULL || e == NULL) {
-               kfree(request);
-               kfree(e);
-               fw_send_response(card, r, RCODE_CONFLICT_ERROR);
-               return;
-       }
+       if (request == NULL || e == NULL)
+               goto failed;
 
        request->request = r;
        request->data    = payload;
        request->length  = length;
 
        request->resource.release = release_request;
-       add_client_resource(client, &request->resource);
+       ret = add_client_resource(client, &request->resource, GFP_ATOMIC);
+       if (ret < 0)
+               goto failed;
 
        e->request.type    = FW_CDEV_EVENT_REQUEST;
        e->request.tcode   = tcode;
@@ -506,11 +584,16 @@ handle_request(struct fw_card *card, struct fw_request *r,
 
        queue_event(client, &e->event,
                    &e->request, sizeof(e->request), payload, length);
+       return;
+
+ failed:
+       kfree(request);
+       kfree(e);
+       fw_send_response(card, r, RCODE_CONFLICT_ERROR);
 }
 
-static void
-release_address_handler(struct client *client,
-                       struct client_resource *resource)
+static void release_address_handler(struct client *client,
+                                   struct client_resource *resource)
 {
        struct address_handler *handler =
                container_of(resource, struct address_handler, resource);
@@ -524,6 +607,7 @@ static int ioctl_allocate(struct client *client, void *buffer)
        struct fw_cdev_allocate *request = buffer;
        struct address_handler *handler;
        struct fw_address_region region;
+       int ret;
 
        handler = kmalloc(sizeof(*handler), GFP_KERNEL);
        if (handler == NULL)
@@ -537,13 +621,18 @@ static int ioctl_allocate(struct client *client, void *buffer)
        handler->closure = request->closure;
        handler->client = client;
 
-       if (fw_core_add_address_handler(&handler->handler, &region) < 0) {
+       ret = fw_core_add_address_handler(&handler->handler, &region);
+       if (ret < 0) {
                kfree(handler);
-               return -EBUSY;
+               return ret;
        }
 
        handler->resource.release = release_address_handler;
-       add_client_resource(client, &handler->resource);
+       ret = add_client_resource(client, &handler->resource, GFP_KERNEL);
+       if (ret < 0) {
+               release_address_handler(client, &handler->resource);
+               return ret;
+       }
        request->handle = handler->resource.handle;
 
        return 0;
@@ -553,7 +642,8 @@ static int ioctl_deallocate(struct client *client, void *buffer)
 {
        struct fw_cdev_deallocate *request = buffer;
 
-       return release_client_resource(client, request->handle, NULL);
+       return release_client_resource(client, request->handle,
+                                      release_address_handler, NULL);
 }
 
 static int ioctl_send_response(struct client *client, void *buffer)
@@ -562,8 +652,10 @@ static int ioctl_send_response(struct client *client, void *buffer)
        struct client_resource *resource;
        struct request *r;
 
-       if (release_client_resource(client, request->handle, &resource) < 0)
+       if (release_client_resource(client, request->handle,
+                                   release_request, &resource) < 0)
                return -EINVAL;
+
        r = container_of(resource, struct request, resource);
        if (request->length < r->length)
                r->length = request->length;
@@ -606,7 +698,7 @@ static int ioctl_add_descriptor(struct client *client, void *buffer)
 {
        struct fw_cdev_add_descriptor *request = buffer;
        struct descriptor *descriptor;
-       int retval;
+       int ret;
 
        if (request->length > 256)
                return -EINVAL;
@@ -618,8 +710,8 @@ static int ioctl_add_descriptor(struct client *client, void *buffer)
 
        if (copy_from_user(descriptor->data,
                           u64_to_uptr(request->data), request->length * 4)) {
-               kfree(descriptor);
-               return -EFAULT;
+               ret = -EFAULT;
+               goto failed;
        }
 
        descriptor->d.length = request->length;
@@ -627,29 +719,35 @@ static int ioctl_add_descriptor(struct client *client, void *buffer)
        descriptor->d.key = request->key;
        descriptor->d.data = descriptor->data;
 
-       retval = fw_core_add_descriptor(&descriptor->d);
-       if (retval < 0) {
-               kfree(descriptor);
-               return retval;
-       }
+       ret = fw_core_add_descriptor(&descriptor->d);
+       if (ret < 0)
+               goto failed;
 
        descriptor->resource.release = release_descriptor;
-       add_client_resource(client, &descriptor->resource);
+       ret = add_client_resource(client, &descriptor->resource, GFP_KERNEL);
+       if (ret < 0) {
+               fw_core_remove_descriptor(&descriptor->d);
+               goto failed;
+       }
        request->handle = descriptor->resource.handle;
 
        return 0;
+ failed:
+       kfree(descriptor);
+
+       return ret;
 }
 
 static int ioctl_remove_descriptor(struct client *client, void *buffer)
 {
        struct fw_cdev_remove_descriptor *request = buffer;
 
-       return release_client_resource(client, request->handle, NULL);
+       return release_client_resource(client, request->handle,
+                                      release_descriptor, NULL);
 }
 
-static void
-iso_callback(struct fw_iso_context *context, u32 cycle,
-            size_t header_length, void *header, void *data)
+static void iso_callback(struct fw_iso_context *context, u32 cycle,
+                        size_t header_length, void *header, void *data)
 {
        struct client *client = data;
        struct iso_interrupt *irq;
@@ -885,11 +983,11 @@ static int (* const ioctl_handlers[])(struct client *client, void *buffer) = {
        ioctl_get_cycle_timer,
 };
 
-static int
-dispatch_ioctl(struct client *client, unsigned int cmd, void __user *arg)
+static int dispatch_ioctl(struct client *client,
+                         unsigned int cmd, void __user *arg)
 {
        char buffer[256];
-       int retval;
+       int ret;
 
        if (_IOC_TYPE(cmd) != '#' ||
            _IOC_NR(cmd) >= ARRAY_SIZE(ioctl_handlers))
@@ -901,9 +999,9 @@ dispatch_ioctl(struct client *client, unsigned int cmd, void __user *arg)
                        return -EFAULT;
        }
 
-       retval = ioctl_handlers[_IOC_NR(cmd)](client, buffer);
-       if (retval < 0)
-               return retval;
+       ret = ioctl_handlers[_IOC_NR(cmd)](client, buffer);
+       if (ret < 0)
+               return ret;
 
        if (_IOC_DIR(cmd) & _IOC_READ) {
                if (_IOC_SIZE(cmd) > sizeof(buffer) ||
@@ -911,12 +1009,11 @@ dispatch_ioctl(struct client *client, unsigned int cmd, void __user *arg)
                        return -EFAULT;
        }
 
-       return retval;
+       return ret;
 }
 
-static long
-fw_device_op_ioctl(struct file *file,
-                  unsigned int cmd, unsigned long arg)
+static long fw_device_op_ioctl(struct file *file,
+                              unsigned int cmd, unsigned long arg)
 {
        struct client *client = file->private_data;
 
@@ -927,9 +1024,8 @@ fw_device_op_ioctl(struct file *file,
 }
 
 #ifdef CONFIG_COMPAT
-static long
-fw_device_op_compat_ioctl(struct file *file,
-                         unsigned int cmd, unsigned long arg)
+static long fw_device_op_compat_ioctl(struct file *file,
+                                     unsigned int cmd, unsigned long arg)
 {
        struct client *client = file->private_data;
 
@@ -945,7 +1041,7 @@ static int fw_device_op_mmap(struct file *file, struct vm_area_struct *vma)
        struct client *client = file->private_data;
        enum dma_data_direction direction;
        unsigned long size;
-       int page_count, retval;
+       int page_count, ret;
 
        if (fw_device_is_shutdown(client->device))
                return -ENODEV;
@@ -971,48 +1067,58 @@ static int fw_device_op_mmap(struct file *file, struct vm_area_struct *vma)
        else
                direction = DMA_FROM_DEVICE;
 
-       retval = fw_iso_buffer_init(&client->buffer, client->device->card,
-                                   page_count, direction);
-       if (retval < 0)
-               return retval;
+       ret = fw_iso_buffer_init(&client->buffer, client->device->card,
+                                page_count, direction);
+       if (ret < 0)
+               return ret;
 
-       retval = fw_iso_buffer_map(&client->buffer, vma);
-       if (retval < 0)
+       ret = fw_iso_buffer_map(&client->buffer, vma);
+       if (ret < 0)
                fw_iso_buffer_destroy(&client->buffer, client->device->card);
 
-       return retval;
+       return ret;
+}
+
+static int shutdown_resource(int id, void *p, void *data)
+{
+       struct client_resource *r = p;
+       struct client *client = data;
+
+       r->release(client, r);
+       client_put(client);
+
+       return 0;
 }
 
 static int fw_device_op_release(struct inode *inode, struct file *file)
 {
        struct client *client = file->private_data;
        struct event *e, *next_e;
-       struct client_resource *r, *next_r;
        unsigned long flags;
 
+       mutex_lock(&client->device->client_list_mutex);
+       list_del(&client->link);
+       mutex_unlock(&client->device->client_list_mutex);
+
        if (client->buffer.pages)
                fw_iso_buffer_destroy(&client->buffer, client->device->card);
 
        if (client->iso_context)
                fw_iso_context_destroy(client->iso_context);
 
-       list_for_each_entry_safe(r, next_r, &client->resource_list, link)
-               r->release(client, r);
+       /* Freeze client->resource_idr and client->event_list */
+       spin_lock_irqsave(&client->lock, flags);
+       client->in_shutdown = true;
+       spin_unlock_irqrestore(&client->lock, flags);
 
-       /*
-        * FIXME: We should wait for the async tasklets to stop
-        * running before freeing the memory.
-        */
+       idr_for_each(&client->resource_idr, shutdown_resource, client);
+       idr_remove_all(&client->resource_idr);
+       idr_destroy(&client->resource_idr);
 
        list_for_each_entry_safe(e, next_e, &client->event_list, link)
                kfree(e);
 
-       spin_lock_irqsave(&client->device->client_list_lock, flags);
-       list_del(&client->link);
-       spin_unlock_irqrestore(&client->device->client_list_lock, flags);
-
-       fw_device_put(client->device);
-       kfree(client);
+       client_put(client);
 
        return 0;
 }