blkcg: remove unnecessary NULL checks from __cfqg_set_weight_device()
[cascardo/linux.git] / drivers / vhost / vhost.c
index 9e8e004..a9fe859 100644 (file)
 #include <linux/file.h>
 #include <linux/highmem.h>
 #include <linux/slab.h>
+#include <linux/vmalloc.h>
 #include <linux/kthread.h>
 #include <linux/cgroup.h>
 #include <linux/module.h>
+#include <linux/sort.h>
 
 #include "vhost.h"
 
+static ushort max_mem_regions = 64;
+module_param(max_mem_regions, ushort, 0444);
+MODULE_PARM_DESC(max_mem_regions,
+       "Maximum number of memory regions in memory map. (default: 64)");
+
 enum {
-       VHOST_MEMORY_MAX_NREGIONS = 64,
        VHOST_MEMORY_F_LOG = 0x1,
 };
 
@@ -543,7 +549,7 @@ void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
                fput(dev->log_file);
        dev->log_file = NULL;
        /* No one will access memory at this point */
-       kfree(dev->memory);
+       kvfree(dev->memory);
        dev->memory = NULL;
        WARN_ON(!list_empty(&dev->work_list));
        if (dev->worker) {
@@ -663,6 +669,28 @@ int vhost_vq_access_ok(struct vhost_virtqueue *vq)
 }
 EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
 
+static int vhost_memory_reg_sort_cmp(const void *p1, const void *p2)
+{
+       const struct vhost_memory_region *r1 = p1, *r2 = p2;
+       if (r1->guest_phys_addr < r2->guest_phys_addr)
+               return 1;
+       if (r1->guest_phys_addr > r2->guest_phys_addr)
+               return -1;
+       return 0;
+}
+
+static void *vhost_kvzalloc(unsigned long size)
+{
+       void *n = kzalloc(size, GFP_KERNEL | __GFP_NOWARN | __GFP_REPEAT);
+
+       if (!n) {
+               n = vzalloc(size);
+               if (!n)
+                       return ERR_PTR(-ENOMEM);
+       }
+       return n;
+}
+
 static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
 {
        struct vhost_memory mem, *newmem, *oldmem;
@@ -673,21 +701,23 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
                return -EFAULT;
        if (mem.padding)
                return -EOPNOTSUPP;
-       if (mem.nregions > VHOST_MEMORY_MAX_NREGIONS)
+       if (mem.nregions > max_mem_regions)
                return -E2BIG;
-       newmem = kmalloc(size + mem.nregions * sizeof *m->regions, GFP_KERNEL);
+       newmem = vhost_kvzalloc(size + mem.nregions * sizeof(*m->regions));
        if (!newmem)
                return -ENOMEM;
 
        memcpy(newmem, &mem, size);
        if (copy_from_user(newmem->regions, m->regions,
                           mem.nregions * sizeof *m->regions)) {
-               kfree(newmem);
+               kvfree(newmem);
                return -EFAULT;
        }
+       sort(newmem->regions, newmem->nregions, sizeof(*newmem->regions),
+               vhost_memory_reg_sort_cmp, NULL);
 
        if (!memory_access_ok(d, newmem, 0)) {
-               kfree(newmem);
+               kvfree(newmem);
                return -EFAULT;
        }
        oldmem = d->memory;
@@ -699,7 +729,7 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
                d->vqs[i]->memory = newmem;
                mutex_unlock(&d->vqs[i]->mutex);
        }
-       kfree(oldmem);
+       kvfree(oldmem);
        return 0;
 }
 
@@ -992,17 +1022,22 @@ EXPORT_SYMBOL_GPL(vhost_dev_ioctl);
 static const struct vhost_memory_region *find_region(struct vhost_memory *mem,
                                                     __u64 addr, __u32 len)
 {
-       struct vhost_memory_region *reg;
-       int i;
+       const struct vhost_memory_region *reg;
+       int start = 0, end = mem->nregions;
 
-       /* linear search is not brilliant, but we really have on the order of 6
-        * regions in practice */
-       for (i = 0; i < mem->nregions; ++i) {
-               reg = mem->regions + i;
-               if (reg->guest_phys_addr <= addr &&
-                   reg->guest_phys_addr + reg->memory_size - 1 >= addr)
-                       return reg;
+       while (start < end) {
+               int slot = start + (end - start) / 2;
+               reg = mem->regions + slot;
+               if (addr >= reg->guest_phys_addr)
+                       end = slot;
+               else
+                       start = slot + 1;
        }
+
+       reg = mem->regions + start;
+       if (addr >= reg->guest_phys_addr &&
+               reg->guest_phys_addr + reg->memory_size > addr)
+               return reg;
        return NULL;
 }