mm: gup: add get_user_pages_locked and get_user_pages_unlocked
diff --git a/mm/gup.c b/mm/gup.c
index 1a8ab05..71a3773 100644
--- a/mm/gup.c
+++ b/mm/gup.c
@@ -575,6 +575,165 @@ int fixup_user_fault(struct task_struct *tsk, struct mm_struct *mm,
        return 0;
 }
 
+static __always_inline long __get_user_pages_locked(struct task_struct *tsk,
+                                               struct mm_struct *mm,
+                                               unsigned long start,
+                                               unsigned long nr_pages,
+                                               int write, int force,
+                                               struct page **pages,
+                                               struct vm_area_struct **vmas,
+                                               int *locked, bool notify_drop)
+{
+       int flags = FOLL_TOUCH;
+       long ret, pages_done;
+       bool lock_dropped;
+
+       if (locked) {
+               /* if VM_FAULT_RETRY can be returned, vmas become invalid */
+               BUG_ON(vmas);
+               /* check caller initialized locked */
+               BUG_ON(*locked != 1);
+       }
+
+       if (pages)
+               flags |= FOLL_GET;
+       if (write)
+               flags |= FOLL_WRITE;
+       if (force)
+               flags |= FOLL_FORCE;
+
+       pages_done = 0;
+       lock_dropped = false;
+       for (;;) {
+               ret = __get_user_pages(tsk, mm, start, nr_pages, flags, pages,
+                                      vmas, locked);
+               if (!locked)
+                       /* VM_FAULT_RETRY couldn't trigger, bypass */
+                       return ret;
+
+               /* VM_FAULT_RETRY cannot return errors */
+               if (!*locked) {
+                       BUG_ON(ret < 0);
+                       BUG_ON(ret >= nr_pages);
+               }
+
+               if (!pages)
+                       /* If it's a prefault, don't insist harder */
+                       return ret;
+
+               if (ret > 0) {
+                       nr_pages -= ret;
+                       pages_done += ret;
+                       if (!nr_pages)
+                               break;
+               }
+               if (*locked) {
+                       /* VM_FAULT_RETRY didn't trigger */
+                       if (!pages_done)
+                               pages_done = ret;
+                       break;
+               }
+               /* VM_FAULT_RETRY triggered, so seek to the faulting offset */
+               pages += ret;
+               start += ret << PAGE_SHIFT;
+
+               /*
+                * Repeat on the address that fired VM_FAULT_RETRY
+                * without FAULT_FLAG_ALLOW_RETRY but with
+                * FAULT_FLAG_TRIED.
+                */
+               *locked = 1;
+               lock_dropped = true;
+               down_read(&mm->mmap_sem);
+               ret = __get_user_pages(tsk, mm, start, 1, flags | FOLL_TRIED,
+                                      pages, NULL, NULL);
+               if (ret != 1) {
+                       BUG_ON(ret > 1);
+                       if (!pages_done)
+                               pages_done = ret;
+                       break;
+               }
+               nr_pages--;
+               pages_done++;
+               if (!nr_pages)
+                       break;
+               pages++;
+               start += PAGE_SIZE;
+       }
+       if (notify_drop && lock_dropped && *locked) {
+               /*
+                * We must let the caller know we temporarily dropped the lock
+                * and so the critical section protected by it was lost.
+                */
+               up_read(&mm->mmap_sem);
+               *locked = 0;
+       }
+       return pages_done;
+}
+
+/*
+ * We can leverage the VM_FAULT_RETRY functionality in the page fault
+ * paths better by using either get_user_pages_locked() or
+ * get_user_pages_unlocked().
+ *
+ * get_user_pages_locked() is suitable to replace the form:
+ *
+ *      down_read(&mm->mmap_sem);
+ *      do_something()
+ *      get_user_pages(tsk, mm, ..., pages, NULL);
+ *      up_read(&mm->mmap_sem);
+ *
+ *  to:
+ *
+ *      int locked = 1;
+ *      down_read(&mm->mmap_sem);
+ *      do_something()
+ *      get_user_pages_locked(tsk, mm, ..., pages, &locked);
+ *      if (locked)
+ *          up_read(&mm->mmap_sem);
+ */
+long get_user_pages_locked(struct task_struct *tsk, struct mm_struct *mm,
+                          unsigned long start, unsigned long nr_pages,
+                          int write, int force, struct page **pages,
+                          int *locked)
+{
+       return __get_user_pages_locked(tsk, mm, start, nr_pages, write, force,
+                                      pages, NULL, locked, true);
+}
+EXPORT_SYMBOL(get_user_pages_locked);
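
To make the conversion above concrete, here is a hedged sketch of a caller
following the get_user_pages_locked() protocol. The do_something() and
redo_something() helpers are hypothetical stand-ins for the caller's
critical section and are not part of this patch:

        static long pin_page_locked_example(struct task_struct *tsk,
                                            struct mm_struct *mm,
                                            unsigned long addr,
                                            struct page **page)
        {
                int locked = 1;
                long ret;

                down_read(&mm->mmap_sem);
                do_something();         /* work that depends on mmap_sem */
                ret = get_user_pages_locked(tsk, mm, addr, 1, 1 /* write */,
                                            0 /* force */, page, &locked);
                if (locked) {
                        /* mmap_sem was never dropped: section intact */
                        up_read(&mm->mmap_sem);
                } else {
                        /*
                         * mmap_sem was dropped to handle a fault and is no
                         * longer held on return, so whatever do_something()
                         * computed under it must be revalidated.
                         */
                        redo_something();
                }
                return ret;
        }
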
+
+/*
+ * get_user_pages_unlocked() is suitable to replace the form:
+ *
+ *      down_read(&mm->mmap_sem);
+ *      get_user_pages(tsk, mm, ..., pages, NULL);
+ *      up_read(&mm->mmap_sem);
+ *
+ *  with:
+ *
+ *      get_user_pages_unlocked(tsk, mm, ..., pages);
+ *
+ * It is functionally equivalent to get_user_pages_fast, so
+ * get_user_pages_fast should be used instead when the two parameters
+ * "tsk" and "mm" are respectively equal to current and current->mm
+ * and "force" is 0 (get_user_pages_fast lacks a "force"
+ * parameter).
+ */
+long get_user_pages_unlocked(struct task_struct *tsk, struct mm_struct *mm,
+                            unsigned long start, unsigned long nr_pages,
+                            int write, int force, struct page **pages)
+{
+       long ret;
+       int locked = 1;
+       down_read(&mm->mmap_sem);
+       ret = __get_user_pages_locked(tsk, mm, start, nr_pages, write, force,
+                                     pages, NULL, &locked, false);
+       if (locked)
+               up_read(&mm->mmap_sem);
+       return ret;
+}
+EXPORT_SYMBOL(get_user_pages_unlocked);
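
Likewise, a minimal hedged sketch of the unlocked variant; the consuming
step is elided, and the put_page() calls drop the references FOLL_GET took:

        static long pin_buffer_unlocked_example(struct task_struct *tsk,
                                                struct mm_struct *mm,
                                                unsigned long start,
                                                unsigned long nr_pages,
                                                struct page **pages)
        {
                long pinned, i;

                /* no mmap_sem management in the caller at all */
                pinned = get_user_pages_unlocked(tsk, mm, start, nr_pages,
                                                 0 /* write */, 0 /* force */,
                                                 pages);
                if (pinned <= 0)
                        return pinned;

                /* ... read from pages[0..pinned-1] ... */

                for (i = 0; i < pinned; i++)
                        put_page(pages[i]);
                return pinned;
        }
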
+
 /*
  * get_user_pages() - pin user pages in memory
  * @tsk:       the task_struct to use for page fault accounting, or
@@ -624,22 +783,18 @@ int fixup_user_fault(struct task_struct *tsk, struct mm_struct *mm,
  * use the correct cache flushing APIs.
  *
  * See also get_user_pages_fast, for performance critical applications.
+ *
+ * get_user_pages should be phased out in favor of
+ * get_user_pages_locked|unlocked or get_user_pages_fast. Nothing
+ * should use get_user_pages because it cannot pass
+ * FAULT_FLAG_ALLOW_RETRY to handle_mm_fault.
  */
 long get_user_pages(struct task_struct *tsk, struct mm_struct *mm,
                unsigned long start, unsigned long nr_pages, int write,
                int force, struct page **pages, struct vm_area_struct **vmas)
 {
-       int flags = FOLL_TOUCH;
-
-       if (pages)
-               flags |= FOLL_GET;
-       if (write)
-               flags |= FOLL_WRITE;
-       if (force)
-               flags |= FOLL_FORCE;
-
-       return __get_user_pages(tsk, mm, start, nr_pages, flags, pages, vmas,
-                               NULL);
+       return __get_user_pages_locked(tsk, mm, start, nr_pages, write, force,
+                                      pages, vmas, NULL, false);
 }
 EXPORT_SYMBOL(get_user_pages);
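
As a before/after illustration of the phase-out guidance in the comment
above (a hypothetical caller pinning one writable page of its own address
space; "addr", "page" and "ret" are assumed locals):

        /* before: mmap_sem is held across the fault, so handle_mm_fault()
         * can never be given FAULT_FLAG_ALLOW_RETRY */
        down_read(&current->mm->mmap_sem);
        ret = get_user_pages(current, current->mm, addr, 1, 1, 0,
                             &page, NULL);
        up_read(&current->mm->mmap_sem);

        /* after: same result, but the fault handler may drop mmap_sem;
         * with tsk == current, mm == current->mm and force == 0,
         * get_user_pages_fast(addr, 1, 1, &page) is the better fit */
        ret = get_user_pages_unlocked(current, current->mm, addr, 1, 1, 0,
                                      &page);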