diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
index 8977a4e..634e676 100644
--- a/fs/userfaultfd.c
+++ b/fs/userfaultfd.c
@@ -45,6 +45,8 @@ struct userfaultfd_ctx {
        wait_queue_head_t fault_wqh;
        /* waitqueue head for the pseudo fd to wakeup poll/read */
        wait_queue_head_t fd_wqh;
+       /* a refile sequence protected by fault_pending_wqh lock */
+       struct seqcount refile_seq;
        /* pseudo fd refcounting */
        atomic_t refcount;
        /* userfaultfd syscall flags */
@@ -179,6 +181,67 @@ static inline struct uffd_msg userfault_msg(unsigned long address,
        return msg;
 }
 
+/*
+ * Verify the pagetables are still not ok after having registered into
+ * the fault_pending_wqh to avoid userland having to UFFDIO_WAKE any
+ * userfault that has already been resolved, if userfaultfd_read and
+ * UFFDIO_COPY|ZEROPAGE are being run simultaneously on two different
+ * threads.
+ */
+static inline bool userfaultfd_must_wait(struct userfaultfd_ctx *ctx,
+                                        unsigned long address,
+                                        unsigned long flags,
+                                        unsigned long reason)
+{
+       struct mm_struct *mm = ctx->mm;
+       pgd_t *pgd;
+       pud_t *pud;
+       pmd_t *pmd, _pmd;
+       pte_t *pte;
+       bool ret = true;
+
+       VM_BUG_ON(!rwsem_is_locked(&mm->mmap_sem));
+
+       pgd = pgd_offset(mm, address);
+       if (!pgd_present(*pgd))
+               goto out;
+       pud = pud_offset(pgd, address);
+       if (!pud_present(*pud))
+               goto out;
+       pmd = pmd_offset(pud, address);
+       /*
+        * READ_ONCE must function as a barrier with narrower scope
+        * and it must be equivalent to:
+        *      _pmd = *pmd; barrier();
+        *
+        * This is to deal with the instability (as in
+        * pmd_trans_unstable) of the pmd.
+        */
+       _pmd = READ_ONCE(*pmd);
+       if (!pmd_present(_pmd))
+               goto out;
+
+       ret = false;
+       if (pmd_trans_huge(_pmd))
+               goto out;
+
+       /*
+        * the pmd is stable (as in !pmd_trans_unstable) so we can re-read it
+        * and use the standard pte_offset_map() instead of parsing _pmd.
+        */
+       pte = pte_offset_map(pmd, address);
+       /*
+        * Lockless access: we're in a wait_event so it's ok if it
+        * changes under us.
+        */
+       if (pte_none(*pte))
+               ret = true;
+       pte_unmap(pte);
+
+out:
+       return ret;
+}
+
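
The recheck above only helps because handle_userfault() (modified in the hunks below) queues the fault on fault_pending_wqh and sets the task state before calling userfaultfd_must_wait(); any UFFDIO_COPY issued after that point finds the waitqueue populated and wakes the faulting thread. A minimal sketch of that queue-then-recheck ordering, for reference only (wqh, wait and condition_still_holds() are placeholders, not symbols from this file):

	spin_lock(&wqh->lock);
	__add_wait_queue(wqh, &wait);        /* 1. become visible to wakers */
	set_current_state(TASK_KILLABLE);    /* 2. set state before the test */
	spin_unlock(&wqh->lock);

	if (condition_still_holds())         /* 3. recheck after being queued */
		schedule();                  /* 4. a waker after step 1 wakes us */
	__set_current_state(TASK_RUNNING);
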
 /*
  * The locking rules involved in returning VM_FAULT_RETRY depending on
  * FAULT_FLAG_ALLOW_RETRY, FAULT_FLAG_RETRY_NOWAIT and
@@ -201,6 +264,7 @@ int handle_userfault(struct vm_area_struct *vma, unsigned long address,
        struct userfaultfd_ctx *ctx;
        struct userfaultfd_wait_queue uwq;
        int ret;
+       bool must_wait, return_to_userland;
 
        BUG_ON(!rwsem_is_locked(&mm->mmap_sem));
 
@@ -260,14 +324,14 @@ int handle_userfault(struct vm_area_struct *vma, unsigned long address,
        /* take the reference before dropping the mmap_sem */
        userfaultfd_ctx_get(ctx);
 
-       /* be gentle and immediately relinquish the mmap_sem */
-       up_read(&mm->mmap_sem);
-
        init_waitqueue_func_entry(&uwq.wq, userfaultfd_wake_function);
        uwq.wq.private = current;
        uwq.msg = userfault_msg(address, flags, reason);
        uwq.ctx = ctx;
 
+       return_to_userland = (flags & (FAULT_FLAG_USER|FAULT_FLAG_KILLABLE)) ==
+               (FAULT_FLAG_USER|FAULT_FLAG_KILLABLE);
+
        spin_lock(&ctx->fault_pending_wqh.lock);
        /*
         * After the __add_wait_queue the uwq is visible to userland
@@ -279,11 +343,16 @@ int handle_userfault(struct vm_area_struct *vma, unsigned long address,
         * following the spin_unlock to happen before the list_add in
         * __add_wait_queue.
         */
-       set_current_state(TASK_KILLABLE);
+       set_current_state(return_to_userland ? TASK_INTERRUPTIBLE :
+                         TASK_KILLABLE);
        spin_unlock(&ctx->fault_pending_wqh.lock);
 
-       if (likely(!ACCESS_ONCE(ctx->released) &&
-                  !fatal_signal_pending(current))) {
+       must_wait = userfaultfd_must_wait(ctx, address, flags, reason);
+       up_read(&mm->mmap_sem);
+
+       if (likely(must_wait && !ACCESS_ONCE(ctx->released) &&
+                  (return_to_userland ? !signal_pending(current) :
+                   !fatal_signal_pending(current)))) {
                wake_up_poll(&ctx->fd_wqh, POLLIN);
                schedule();
                ret |= VM_FAULT_MAJOR;
@@ -291,6 +360,30 @@ int handle_userfault(struct vm_area_struct *vma, unsigned long address,
 
        __set_current_state(TASK_RUNNING);
 
+       if (return_to_userland) {
+               if (signal_pending(current) &&
+                   !fatal_signal_pending(current)) {
+                       /*
+                        * If we got a SIGSTOP or SIGCONT and this is
+                        * a normal userland page fault, just let
+                        * userland return so the signal will be
+                        * handled and gdb debugging works.  The page
+                        * fault code immediately after we return from
+                        * this function is going to release the
+                        * mmap_sem and does not depend on it (unlike
+                        * gup, which would if we were not to return
+                        * VM_FAULT_RETRY).
+                        *
+                        * If a fatal signal is pending we still take
+                        * the streamlined VM_FAULT_RETRY failure path
+                        * and there's no need to retake the mmap_sem
+                        * in such case.
+                        */
+                       down_read(&mm->mmap_sem);
+                       ret = 0;
+               }
+       }
+
        /*
         * Here we race with the list_del; list_add in
         * userfaultfd_ctx_read(), however because we don't ever run
@@ -455,6 +548,15 @@ static ssize_t userfaultfd_ctx_read(struct userfaultfd_ctx *ctx, int no_wait,
                spin_lock(&ctx->fault_pending_wqh.lock);
                uwq = find_userfault(ctx);
                if (uwq) {
+                       /*
+                        * Use a seqcount to repeat the lockless check
+                        * in wake_userfault() to avoid missing
+                        * wakeups because during the refile both
+                        * waitqueues could become empty if this is the
+                        * only userfault.
+                        */
+                       write_seqcount_begin(&ctx->refile_seq);
+
                        /*
                         * The fault_pending_wqh.lock prevents the uwq
                         * to disappear from under us.
@@ -479,6 +581,8 @@ static ssize_t userfaultfd_ctx_read(struct userfaultfd_ctx *ctx, int no_wait,
                        list_del(&uwq->wq.task_list);
                        __add_wait_queue(&ctx->fault_wqh, &uwq->wq);
 
+                       write_seqcount_end(&ctx->refile_seq);
+
                        /* careful to always initialize msg if ret == 0 */
                        *msg = uwq->msg;
                        spin_unlock(&ctx->fault_pending_wqh.lock);
@@ -515,7 +619,6 @@ static ssize_t userfaultfd_read(struct file *file, char __user *buf,
 
        if (ctx->state == UFFD_STATE_WAIT_API)
                return -EINVAL;
-       BUG_ON(ctx->state != UFFD_STATE_RUNNING);
 
        for (;;) {
                if (count < sizeof(msg))
@@ -557,6 +660,9 @@ static void __wake_userfault(struct userfaultfd_ctx *ctx,
 static __always_inline void wake_userfault(struct userfaultfd_ctx *ctx,
                                           struct userfaultfd_wake_range *range)
 {
+       unsigned seq;
+       bool need_wakeup;
+
        /*
         * To be sure waitqueue_active() is not reordered by the CPU
         * before the pagetable update, use an explicit SMP memory
@@ -572,8 +678,13 @@ static __always_inline void wake_userfault(struct userfaultfd_ctx *ctx,
         * userfaults yet. So we take the spinlock only when we're
         * sure we've userfaults to wake.
         */
-       if (waitqueue_active(&ctx->fault_pending_wqh) ||
-           waitqueue_active(&ctx->fault_wqh))
+       do {
+               seq = read_seqcount_begin(&ctx->refile_seq);
+               need_wakeup = waitqueue_active(&ctx->fault_pending_wqh) ||
+                       waitqueue_active(&ctx->fault_wqh);
+               cond_resched();
+       } while (read_seqcount_retry(&ctx->refile_seq, seq));
+       if (need_wakeup)
                __wake_userfault(ctx, range);
 }
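
The refile_seq seqcount added to the context (and initialized in init_once_userfaultfd_ctx() in the last hunk) is what lets the lockless waitqueue_active() checks above notice that userfaultfd_ctx_read() was refiling a waiter between the two waitqueues. For readers unfamiliar with the API, a generic seqcount_t sketch follows; it is illustrative only, not the userfaultfd code itself, and assumes the writer is already serialized by a lock (here fault_pending_wqh.lock):

	static seqcount_t seq;                  /* seqcount_init(&seq) at setup */

	/* writer side, under the serializing lock */
	write_seqcount_begin(&seq);
	/* ... move the entry from one waitqueue to the other ... */
	write_seqcount_end(&seq);

	/* lockless reader side */
	unsigned start;
	do {
		start = read_seqcount_begin(&seq);
		/* ... sample state the writer may be changing ... */
	} while (read_seqcount_retry(&seq, start));
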
 
@@ -886,17 +997,6 @@ out:
 }
 
 /*
- * userfaultfd_wake is needed in case an userfault is in flight by the
- * time a UFFDIO_COPY (or other ioctl variants) completes. The page
- * may be well get mapped and the page fault if repeated wouldn't lead
- * to a userfault anymore, but before scheduling in TASK_KILLABLE mode
- * handle_userfault() doesn't recheck the pagetables and it doesn't
- * serialize against UFFDO_COPY (or other ioctl variants). Ultimately
- * the knowledge of which pages are mapped is left to userland who is
- * responsible for handling the race between read() userfaults and
- * background UFFDIO_COPY (or other ioctl variants), if done by
- * separate concurrent threads.
- *
  * userfaultfd_wake may be used in combination with the
  * UFFDIO_*_MODE_DONTWAKE to wakeup userfaults in batches.
  */
@@ -932,6 +1032,96 @@ out:
        return ret;
 }
 
+static int userfaultfd_copy(struct userfaultfd_ctx *ctx,
+                           unsigned long arg)
+{
+       __s64 ret;
+       struct uffdio_copy uffdio_copy;
+       struct uffdio_copy __user *user_uffdio_copy;
+       struct userfaultfd_wake_range range;
+
+       user_uffdio_copy = (struct uffdio_copy __user *) arg;
+
+       ret = -EFAULT;
+       if (copy_from_user(&uffdio_copy, user_uffdio_copy,
+                          /* don't copy "copy" last field */
+                          sizeof(uffdio_copy)-sizeof(__s64)))
+               goto out;
+
+       ret = validate_range(ctx->mm, uffdio_copy.dst, uffdio_copy.len);
+       if (ret)
+               goto out;
+       /*
+        * double check for wraparound just in case. copy_from_user()
+        * will later check that uffdio_copy.src + uffdio_copy.len
+        * fits in the userland range.
+        */
+       ret = -EINVAL;
+       if (uffdio_copy.src + uffdio_copy.len <= uffdio_copy.src)
+               goto out;
+       if (uffdio_copy.mode & ~UFFDIO_COPY_MODE_DONTWAKE)
+               goto out;
+
+       ret = mcopy_atomic(ctx->mm, uffdio_copy.dst, uffdio_copy.src,
+                          uffdio_copy.len);
+       if (unlikely(put_user(ret, &user_uffdio_copy->copy)))
+               return -EFAULT;
+       if (ret < 0)
+               goto out;
+       BUG_ON(!ret);
+       /* len == 0 would wake all */
+       range.len = ret;
+       if (!(uffdio_copy.mode & UFFDIO_COPY_MODE_DONTWAKE)) {
+               range.start = uffdio_copy.dst;
+               wake_userfault(ctx, &range);
+       }
+       ret = range.len == uffdio_copy.len ? 0 : -EAGAIN;
+out:
+       return ret;
+}
+
+static int userfaultfd_zeropage(struct userfaultfd_ctx *ctx,
+                               unsigned long arg)
+{
+       __s64 ret;
+       struct uffdio_zeropage uffdio_zeropage;
+       struct uffdio_zeropage __user *user_uffdio_zeropage;
+       struct userfaultfd_wake_range range;
+
+       user_uffdio_zeropage = (struct uffdio_zeropage __user *) arg;
+
+       ret = -EFAULT;
+       if (copy_from_user(&uffdio_zeropage, user_uffdio_zeropage,
+                          /* don't copy "zeropage" last field */
+                          sizeof(uffdio_zeropage)-sizeof(__s64)))
+               goto out;
+
+       ret = validate_range(ctx->mm, uffdio_zeropage.range.start,
+                            uffdio_zeropage.range.len);
+       if (ret)
+               goto out;
+       ret = -EINVAL;
+       if (uffdio_zeropage.mode & ~UFFDIO_ZEROPAGE_MODE_DONTWAKE)
+               goto out;
+
+       ret = mfill_zeropage(ctx->mm, uffdio_zeropage.range.start,
+                            uffdio_zeropage.range.len);
+       if (unlikely(put_user(ret, &user_uffdio_zeropage->zeropage)))
+               return -EFAULT;
+       if (ret < 0)
+               goto out;
+       /* len == 0 would wake all */
+       BUG_ON(!ret);
+       range.len = ret;
+       if (!(uffdio_zeropage.mode & UFFDIO_ZEROPAGE_MODE_DONTWAKE)) {
+               range.start = uffdio_zeropage.range.start;
+               wake_userfault(ctx, &range);
+       }
+       ret = range.len == uffdio_zeropage.range.len ? 0 : -EAGAIN;
+out:
+       return ret;
+}
+
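
For context, a userspace usage sketch of the two new ioctls follows; it assumes uffd is a userfaultfd descriptor that has already completed UFFDIO_API and UFFDIO_REGISTER, and fault_addr, src_buf and page_size are caller-supplied placeholders:

	#include <linux/userfaultfd.h>
	#include <sys/ioctl.h>

	static int resolve_with_copy(int uffd, void *fault_addr,
				     void *src_buf, unsigned long page_size)
	{
		struct uffdio_copy uc = {
			.dst  = (unsigned long) fault_addr,
			.src  = (unsigned long) src_buf,
			.len  = page_size,
			.mode = 0,      /* or UFFDIO_COPY_MODE_DONTWAKE */
		};

		/* on failure uc.copy holds bytes copied or a negative errno */
		return ioctl(uffd, UFFDIO_COPY, &uc) ? -1 : 0;
	}

	static int resolve_with_zeropage(int uffd, void *fault_addr,
					 unsigned long page_size)
	{
		struct uffdio_zeropage uz = {
			.range.start = (unsigned long) fault_addr,
			.range.len   = page_size,
			.mode        = 0, /* or UFFDIO_ZEROPAGE_MODE_DONTWAKE */
		};

		return ioctl(uffd, UFFDIO_ZEROPAGE, &uz) ? -1 : 0;
	}

If only part of the requested range could be handled, the ioctl fails and the copy/zeropage result field reports how many bytes were done, matching the "range.len == uffdio_copy.len ? 0 : -EAGAIN" logic above.
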
 /*
  * userland asks for a certain API version and we return which bits
  * and ioctl commands are implemented in this kernel for such API
@@ -974,6 +1164,9 @@ static long userfaultfd_ioctl(struct file *file, unsigned cmd,
        int ret = -EINVAL;
        struct userfaultfd_ctx *ctx = file->private_data;
 
+       if (cmd != UFFDIO_API && ctx->state == UFFD_STATE_WAIT_API)
+               return -EINVAL;
+
        switch(cmd) {
        case UFFDIO_API:
                ret = userfaultfd_api(ctx, arg);
@@ -987,6 +1180,12 @@ static long userfaultfd_ioctl(struct file *file, unsigned cmd,
        case UFFDIO_WAKE:
                ret = userfaultfd_wake(ctx, arg);
                break;
+       case UFFDIO_COPY:
+               ret = userfaultfd_copy(ctx, arg);
+               break;
+       case UFFDIO_ZEROPAGE:
+               ret = userfaultfd_zeropage(ctx, arg);
+               break;
        }
        return ret;
 }
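
The UFFD_STATE_WAIT_API guard added at the top of userfaultfd_ioctl() makes every ioctl other than UFFDIO_API fail with EINVAL until the API handshake has completed. A userspace sketch of that handshake (uffd is assumed to be a freshly created userfaultfd descriptor):

	#include <linux/userfaultfd.h>
	#include <sys/ioctl.h>

	/*
	 * Complete the API handshake; until it succeeds,
	 * UFFDIO_REGISTER/UFFDIO_COPY/... are rejected with EINVAL.
	 */
	static int uffd_handshake(int uffd)
	{
		struct uffdio_api api = { .api = UFFD_API, .features = 0 };

		return ioctl(uffd, UFFDIO_API, &api) ? -1 : 0;
	}
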
@@ -1041,6 +1240,7 @@ static void init_once_userfaultfd_ctx(void *mem)
        init_waitqueue_head(&ctx->fault_pending_wqh);
        init_waitqueue_head(&ctx->fault_wqh);
        init_waitqueue_head(&ctx->fd_wqh);
+       seqcount_init(&ctx->refile_seq);
 }
 
 /**