nbd: fix race in ioctl
[cascardo/linux.git] / drivers / block / nbd.c
index 6f55b26..a9e3980 100644 (file)
@@ -451,14 +451,9 @@ static int nbd_thread_recv(struct nbd_device *nbd, struct block_device *bdev)
 
        sk_set_memalloc(nbd->sock->sk);
 
-       nbd->task_recv = current;
-
        ret = device_create_file(disk_to_dev(nbd->disk), &pid_attr);
        if (ret) {
                dev_err(disk_to_dev(nbd->disk), "device_create_file failed!\n");
-
-               nbd->task_recv = NULL;
-
                return ret;
        }
 
@@ -477,9 +472,6 @@ static int nbd_thread_recv(struct nbd_device *nbd, struct block_device *bdev)
        nbd_size_clear(nbd, bdev);
 
        device_remove_file(disk_to_dev(nbd->disk), &pid_attr);
-
-       nbd->task_recv = NULL;
-
        return ret;
 }
 
@@ -788,6 +780,8 @@ static int __nbd_ioctl(struct block_device *bdev, struct nbd_device *nbd,
                if (!nbd->sock)
                        return -EINVAL;
 
+               /* We have to claim the device under the lock */
+               nbd->task_recv = current;
                mutex_unlock(&nbd->tx_lock);
 
                nbd_parse_flags(nbd, bdev);
@@ -796,6 +790,7 @@ static int __nbd_ioctl(struct block_device *bdev, struct nbd_device *nbd,
                                     nbd_name(nbd));
                if (IS_ERR(thread)) {
                        mutex_lock(&nbd->tx_lock);
+                       nbd->task_recv = NULL;
                        return PTR_ERR(thread);
                }
 
@@ -805,6 +800,7 @@ static int __nbd_ioctl(struct block_device *bdev, struct nbd_device *nbd,
                kthread_stop(thread);
 
                mutex_lock(&nbd->tx_lock);
+               nbd->task_recv = NULL;
 
                sock_shutdown(nbd);
                nbd_clear_que(nbd);