1f2ddaaf3c03bd02ad43a4d20dfac1e0fe48ed04
[cascardo/linux.git] / fs / userfaultfd.c
1 /*
2  *  fs/userfaultfd.c
3  *
4  *  Copyright (C) 2007  Davide Libenzi <davidel@xmailserver.org>
5  *  Copyright (C) 2008-2009 Red Hat, Inc.
6  *  Copyright (C) 2015  Red Hat, Inc.
7  *
8  *  This work is licensed under the terms of the GNU GPL, version 2. See
9  *  the COPYING file in the top-level directory.
10  *
11  *  Some part derived from fs/eventfd.c (anon inode setup) and
12  *  mm/ksm.c (mm hashing).
13  */
14
15 #include <linux/hashtable.h>
16 #include <linux/sched.h>
17 #include <linux/mm.h>
18 #include <linux/poll.h>
19 #include <linux/slab.h>
20 #include <linux/seq_file.h>
21 #include <linux/file.h>
22 #include <linux/bug.h>
23 #include <linux/anon_inodes.h>
24 #include <linux/syscalls.h>
25 #include <linux/userfaultfd_k.h>
26 #include <linux/mempolicy.h>
27 #include <linux/ioctl.h>
28 #include <linux/security.h>
29
30 enum userfaultfd_state {
31         UFFD_STATE_WAIT_API,
32         UFFD_STATE_RUNNING,
33 };
34
35 struct userfaultfd_ctx {
36         /* pseudo fd refcounting */
37         atomic_t refcount;
38         /* waitqueue head for the userfaultfd page faults */
39         wait_queue_head_t fault_wqh;
40         /* waitqueue head for the pseudo fd to wakeup poll/read */
41         wait_queue_head_t fd_wqh;
42         /* userfaultfd syscall flags */
43         unsigned int flags;
44         /* state machine */
45         enum userfaultfd_state state;
46         /* released */
47         bool released;
48         /* mm with one ore more vmas attached to this userfaultfd_ctx */
49         struct mm_struct *mm;
50 };
51
52 struct userfaultfd_wait_queue {
53         struct uffd_msg msg;
54         wait_queue_t wq;
55         bool pending;
56         struct userfaultfd_ctx *ctx;
57 };
58
59 struct userfaultfd_wake_range {
60         unsigned long start;
61         unsigned long len;
62 };
63
64 static int userfaultfd_wake_function(wait_queue_t *wq, unsigned mode,
65                                      int wake_flags, void *key)
66 {
67         struct userfaultfd_wake_range *range = key;
68         int ret;
69         struct userfaultfd_wait_queue *uwq;
70         unsigned long start, len;
71
72         uwq = container_of(wq, struct userfaultfd_wait_queue, wq);
73         ret = 0;
74         /* don't wake the pending ones to avoid reads to block */
75         if (uwq->pending && !ACCESS_ONCE(uwq->ctx->released))
76                 goto out;
77         /* len == 0 means wake all */
78         start = range->start;
79         len = range->len;
80         if (len && (start > uwq->msg.arg.pagefault.address ||
81                     start + len <= uwq->msg.arg.pagefault.address))
82                 goto out;
83         ret = wake_up_state(wq->private, mode);
84         if (ret)
85                 /*
86                  * Wake only once, autoremove behavior.
87                  *
88                  * After the effect of list_del_init is visible to the
89                  * other CPUs, the waitqueue may disappear from under
90                  * us, see the !list_empty_careful() in
91                  * handle_userfault(). try_to_wake_up() has an
92                  * implicit smp_mb__before_spinlock, and the
93                  * wq->private is read before calling the extern
94                  * function "wake_up_state" (which in turns calls
95                  * try_to_wake_up). While the spin_lock;spin_unlock;
96                  * wouldn't be enough, the smp_mb__before_spinlock is
97                  * enough to avoid an explicit smp_mb() here.
98                  */
99                 list_del_init(&wq->task_list);
100 out:
101         return ret;
102 }
103
104 /**
105  * userfaultfd_ctx_get - Acquires a reference to the internal userfaultfd
106  * context.
107  * @ctx: [in] Pointer to the userfaultfd context.
108  *
109  * Returns: In case of success, returns not zero.
110  */
111 static void userfaultfd_ctx_get(struct userfaultfd_ctx *ctx)
112 {
113         if (!atomic_inc_not_zero(&ctx->refcount))
114                 BUG();
115 }
116
117 /**
118  * userfaultfd_ctx_put - Releases a reference to the internal userfaultfd
119  * context.
120  * @ctx: [in] Pointer to userfaultfd context.
121  *
122  * The userfaultfd context reference must have been previously acquired either
123  * with userfaultfd_ctx_get() or userfaultfd_ctx_fdget().
124  */
125 static void userfaultfd_ctx_put(struct userfaultfd_ctx *ctx)
126 {
127         if (atomic_dec_and_test(&ctx->refcount)) {
128                 VM_BUG_ON(spin_is_locked(&ctx->fault_pending_wqh.lock));
129                 VM_BUG_ON(waitqueue_active(&ctx->fault_pending_wqh));
130                 VM_BUG_ON(spin_is_locked(&ctx->fault_wqh.lock));
131                 VM_BUG_ON(waitqueue_active(&ctx->fault_wqh));
132                 VM_BUG_ON(spin_is_locked(&ctx->fd_wqh.lock));
133                 VM_BUG_ON(waitqueue_active(&ctx->fd_wqh));
134                 mmput(ctx->mm);
135                 kfree(ctx);
136         }
137 }
138
139 static inline void msg_init(struct uffd_msg *msg)
140 {
141         BUILD_BUG_ON(sizeof(struct uffd_msg) != 32);
142         /*
143          * Must use memset to zero out the paddings or kernel data is
144          * leaked to userland.
145          */
146         memset(msg, 0, sizeof(struct uffd_msg));
147 }
148
149 static inline struct uffd_msg userfault_msg(unsigned long address,
150                                             unsigned int flags,
151                                             unsigned long reason)
152 {
153         struct uffd_msg msg;
154         msg_init(&msg);
155         msg.event = UFFD_EVENT_PAGEFAULT;
156         msg.arg.pagefault.address = address;
157         if (flags & FAULT_FLAG_WRITE)
158                 /*
159                  * If UFFD_FEATURE_PAGEFAULT_FLAG_WRITE was set in the
160                  * uffdio_api.features and UFFD_PAGEFAULT_FLAG_WRITE
161                  * was not set in a UFFD_EVENT_PAGEFAULT, it means it
162                  * was a read fault, otherwise if set it means it's
163                  * a write fault.
164                  */
165                 msg.arg.pagefault.flags |= UFFD_PAGEFAULT_FLAG_WRITE;
166         if (reason & VM_UFFD_WP)
167                 /*
168                  * If UFFD_FEATURE_PAGEFAULT_FLAG_WP was set in the
169                  * uffdio_api.features and UFFD_PAGEFAULT_FLAG_WP was
170                  * not set in a UFFD_EVENT_PAGEFAULT, it means it was
171                  * a missing fault, otherwise if set it means it's a
172                  * write protect fault.
173                  */
174                 msg.arg.pagefault.flags |= UFFD_PAGEFAULT_FLAG_WP;
175         return msg;
176 }
177
178 /*
179  * The locking rules involved in returning VM_FAULT_RETRY depending on
180  * FAULT_FLAG_ALLOW_RETRY, FAULT_FLAG_RETRY_NOWAIT and
181  * FAULT_FLAG_KILLABLE are not straightforward. The "Caution"
182  * recommendation in __lock_page_or_retry is not an understatement.
183  *
184  * If FAULT_FLAG_ALLOW_RETRY is set, the mmap_sem must be released
185  * before returning VM_FAULT_RETRY only if FAULT_FLAG_RETRY_NOWAIT is
186  * not set.
187  *
188  * If FAULT_FLAG_ALLOW_RETRY is set but FAULT_FLAG_KILLABLE is not
189  * set, VM_FAULT_RETRY can still be returned if and only if there are
190  * fatal_signal_pending()s, and the mmap_sem must be released before
191  * returning it.
192  */
193 int handle_userfault(struct vm_area_struct *vma, unsigned long address,
194                      unsigned int flags, unsigned long reason)
195 {
196         struct mm_struct *mm = vma->vm_mm;
197         struct userfaultfd_ctx *ctx;
198         struct userfaultfd_wait_queue uwq;
199
200         BUG_ON(!rwsem_is_locked(&mm->mmap_sem));
201
202         ctx = vma->vm_userfaultfd_ctx.ctx;
203         if (!ctx)
204                 return VM_FAULT_SIGBUS;
205
206         BUG_ON(ctx->mm != mm);
207
208         VM_BUG_ON(reason & ~(VM_UFFD_MISSING|VM_UFFD_WP));
209         VM_BUG_ON(!(reason & VM_UFFD_MISSING) ^ !!(reason & VM_UFFD_WP));
210
211         /*
212          * If it's already released don't get it. This avoids to loop
213          * in __get_user_pages if userfaultfd_release waits on the
214          * caller of handle_userfault to release the mmap_sem.
215          */
216         if (unlikely(ACCESS_ONCE(ctx->released)))
217                 return VM_FAULT_SIGBUS;
218
219         /*
220          * Check that we can return VM_FAULT_RETRY.
221          *
222          * NOTE: it should become possible to return VM_FAULT_RETRY
223          * even if FAULT_FLAG_TRIED is set without leading to gup()
224          * -EBUSY failures, if the userfaultfd is to be extended for
225          * VM_UFFD_WP tracking and we intend to arm the userfault
226          * without first stopping userland access to the memory. For
227          * VM_UFFD_MISSING userfaults this is enough for now.
228          */
229         if (unlikely(!(flags & FAULT_FLAG_ALLOW_RETRY))) {
230                 /*
231                  * Validate the invariant that nowait must allow retry
232                  * to be sure not to return SIGBUS erroneously on
233                  * nowait invocations.
234                  */
235                 BUG_ON(flags & FAULT_FLAG_RETRY_NOWAIT);
236 #ifdef CONFIG_DEBUG_VM
237                 if (printk_ratelimit()) {
238                         printk(KERN_WARNING
239                                "FAULT_FLAG_ALLOW_RETRY missing %x\n", flags);
240                         dump_stack();
241                 }
242 #endif
243                 return VM_FAULT_SIGBUS;
244         }
245
246         /*
247          * Handle nowait, not much to do other than tell it to retry
248          * and wait.
249          */
250         if (flags & FAULT_FLAG_RETRY_NOWAIT)
251                 return VM_FAULT_RETRY;
252
253         /* take the reference before dropping the mmap_sem */
254         userfaultfd_ctx_get(ctx);
255
256         /* be gentle and immediately relinquish the mmap_sem */
257         up_read(&mm->mmap_sem);
258
259         init_waitqueue_func_entry(&uwq.wq, userfaultfd_wake_function);
260         uwq.wq.private = current;
261         uwq.msg = userfault_msg(address, flags, reason);
262         uwq.pending = true;
263         uwq.ctx = ctx;
264
265         spin_lock(&ctx->fault_wqh.lock);
266         /*
267          * After the __add_wait_queue the uwq is visible to userland
268          * through poll/read().
269          */
270         __add_wait_queue(&ctx->fault_wqh, &uwq.wq);
271         for (;;) {
272                 set_current_state(TASK_KILLABLE);
273                 if (!uwq.pending || ACCESS_ONCE(ctx->released) ||
274                     fatal_signal_pending(current))
275                         break;
276                 spin_unlock(&ctx->fault_wqh.lock);
277
278                 wake_up_poll(&ctx->fd_wqh, POLLIN);
279                 schedule();
280
281                 spin_lock(&ctx->fault_wqh.lock);
282         }
283         __remove_wait_queue(&ctx->fault_wqh, &uwq.wq);
284         __set_current_state(TASK_RUNNING);
285         spin_unlock(&ctx->fault_wqh.lock);
286
287         /*
288          * ctx may go away after this if the userfault pseudo fd is
289          * already released.
290          */
291         userfaultfd_ctx_put(ctx);
292
293         return VM_FAULT_RETRY;
294 }
295
296 static int userfaultfd_release(struct inode *inode, struct file *file)
297 {
298         struct userfaultfd_ctx *ctx = file->private_data;
299         struct mm_struct *mm = ctx->mm;
300         struct vm_area_struct *vma, *prev;
301         /* len == 0 means wake all */
302         struct userfaultfd_wake_range range = { .len = 0, };
303         unsigned long new_flags;
304
305         ACCESS_ONCE(ctx->released) = true;
306
307         /*
308          * Flush page faults out of all CPUs. NOTE: all page faults
309          * must be retried without returning VM_FAULT_SIGBUS if
310          * userfaultfd_ctx_get() succeeds but vma->vma_userfault_ctx
311          * changes while handle_userfault released the mmap_sem. So
312          * it's critical that released is set to true (above), before
313          * taking the mmap_sem for writing.
314          */
315         down_write(&mm->mmap_sem);
316         prev = NULL;
317         for (vma = mm->mmap; vma; vma = vma->vm_next) {
318                 cond_resched();
319                 BUG_ON(!!vma->vm_userfaultfd_ctx.ctx ^
320                        !!(vma->vm_flags & (VM_UFFD_MISSING | VM_UFFD_WP)));
321                 if (vma->vm_userfaultfd_ctx.ctx != ctx) {
322                         prev = vma;
323                         continue;
324                 }
325                 new_flags = vma->vm_flags & ~(VM_UFFD_MISSING | VM_UFFD_WP);
326                 prev = vma_merge(mm, prev, vma->vm_start, vma->vm_end,
327                                  new_flags, vma->anon_vma,
328                                  vma->vm_file, vma->vm_pgoff,
329                                  vma_policy(vma),
330                                  NULL_VM_UFFD_CTX);
331                 if (prev)
332                         vma = prev;
333                 else
334                         prev = vma;
335                 vma->vm_flags = new_flags;
336                 vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
337         }
338         up_write(&mm->mmap_sem);
339
340         /*
341          * After no new page faults can wait on this fault_wqh, flush
342          * the last page faults that may have been already waiting on
343          * the fault_wqh.
344          */
345         spin_lock(&ctx->fault_wqh.lock);
346         __wake_up_locked_key(&ctx->fault_wqh, TASK_NORMAL, 0, &range);
347         spin_unlock(&ctx->fault_wqh.lock);
348
349         wake_up_poll(&ctx->fd_wqh, POLLHUP);
350         userfaultfd_ctx_put(ctx);
351         return 0;
352 }
353
354 /* fault_wqh.lock must be hold by the caller */
355 static inline unsigned int find_userfault(struct userfaultfd_ctx *ctx,
356                                           struct userfaultfd_wait_queue **uwq)
357 {
358         wait_queue_t *wq;
359         struct userfaultfd_wait_queue *_uwq;
360         unsigned int ret = 0;
361
362         VM_BUG_ON(!spin_is_locked(&ctx->fault_wqh.lock));
363
364         list_for_each_entry(wq, &ctx->fault_wqh.task_list, task_list) {
365                 _uwq = container_of(wq, struct userfaultfd_wait_queue, wq);
366                 if (_uwq->pending) {
367                         ret = POLLIN;
368                         if (!uwq)
369                                 /*
370                                  * If there's at least a pending and
371                                  * we don't care which one it is,
372                                  * break immediately and leverage the
373                                  * efficiency of the LIFO walk.
374                                  */
375                                 break;
376                         /*
377                          * If we need to find which one was pending we
378                          * keep walking until we find the first not
379                          * pending one, so we read() them in FIFO order.
380                          */
381                         *uwq = _uwq;
382                 } else
383                         /*
384                          * break the loop at the first not pending
385                          * one, there cannot be pending userfaults
386                          * after the first not pending one, because
387                          * all new pending ones are inserted at the
388                          * head and we walk it in LIFO.
389                          */
390                         break;
391         }
392
393         return ret;
394 }
395
396 static unsigned int userfaultfd_poll(struct file *file, poll_table *wait)
397 {
398         struct userfaultfd_ctx *ctx = file->private_data;
399         unsigned int ret;
400
401         poll_wait(file, &ctx->fd_wqh, wait);
402
403         switch (ctx->state) {
404         case UFFD_STATE_WAIT_API:
405                 return POLLERR;
406         case UFFD_STATE_RUNNING:
407                 spin_lock(&ctx->fault_wqh.lock);
408                 ret = find_userfault(ctx, NULL);
409                 spin_unlock(&ctx->fault_wqh.lock);
410                 return ret;
411         default:
412                 BUG();
413         }
414 }
415
416 static ssize_t userfaultfd_ctx_read(struct userfaultfd_ctx *ctx, int no_wait,
417                                     struct uffd_msg *msg)
418 {
419         ssize_t ret;
420         DECLARE_WAITQUEUE(wait, current);
421         struct userfaultfd_wait_queue *uwq = NULL;
422
423         /* always take the fd_wqh lock before the fault_wqh lock */
424         spin_lock(&ctx->fd_wqh.lock);
425         __add_wait_queue(&ctx->fd_wqh, &wait);
426         for (;;) {
427                 set_current_state(TASK_INTERRUPTIBLE);
428                 spin_lock(&ctx->fault_wqh.lock);
429                 if (find_userfault(ctx, &uwq)) {
430                         /*
431                          * The fault_wqh.lock prevents the uwq to
432                          * disappear from under us.
433                          */
434                         uwq->pending = false;
435                         /* careful to always initialize msg if ret == 0 */
436                         *msg = uwq->msg;
437                         spin_unlock(&ctx->fault_wqh.lock);
438                         ret = 0;
439                         break;
440                 }
441                 spin_unlock(&ctx->fault_wqh.lock);
442                 if (signal_pending(current)) {
443                         ret = -ERESTARTSYS;
444                         break;
445                 }
446                 if (no_wait) {
447                         ret = -EAGAIN;
448                         break;
449                 }
450                 spin_unlock(&ctx->fd_wqh.lock);
451                 schedule();
452                 spin_lock(&ctx->fd_wqh.lock);
453         }
454         __remove_wait_queue(&ctx->fd_wqh, &wait);
455         __set_current_state(TASK_RUNNING);
456         spin_unlock(&ctx->fd_wqh.lock);
457
458         return ret;
459 }
460
461 static ssize_t userfaultfd_read(struct file *file, char __user *buf,
462                                 size_t count, loff_t *ppos)
463 {
464         struct userfaultfd_ctx *ctx = file->private_data;
465         ssize_t _ret, ret = 0;
466         struct uffd_msg msg;
467         int no_wait = file->f_flags & O_NONBLOCK;
468
469         if (ctx->state == UFFD_STATE_WAIT_API)
470                 return -EINVAL;
471         BUG_ON(ctx->state != UFFD_STATE_RUNNING);
472
473         for (;;) {
474                 if (count < sizeof(msg))
475                         return ret ? ret : -EINVAL;
476                 _ret = userfaultfd_ctx_read(ctx, no_wait, &msg);
477                 if (_ret < 0)
478                         return ret ? ret : _ret;
479                 if (copy_to_user((__u64 __user *) buf, &msg, sizeof(msg)))
480                         return ret ? ret : -EFAULT;
481                 ret += sizeof(msg);
482                 buf += sizeof(msg);
483                 count -= sizeof(msg);
484                 /*
485                  * Allow to read more than one fault at time but only
486                  * block if waiting for the very first one.
487                  */
488                 no_wait = O_NONBLOCK;
489         }
490 }
491
492 static void __wake_userfault(struct userfaultfd_ctx *ctx,
493                              struct userfaultfd_wake_range *range)
494 {
495         unsigned long start, end;
496
497         start = range->start;
498         end = range->start + range->len;
499
500         spin_lock(&ctx->fault_wqh.lock);
501         /* wake all in the range and autoremove */
502         __wake_up_locked_key(&ctx->fault_wqh, TASK_NORMAL, 0, range);
503         spin_unlock(&ctx->fault_wqh.lock);
504 }
505
506 static __always_inline void wake_userfault(struct userfaultfd_ctx *ctx,
507                                            struct userfaultfd_wake_range *range)
508 {
509         /*
510          * To be sure waitqueue_active() is not reordered by the CPU
511          * before the pagetable update, use an explicit SMP memory
512          * barrier here. PT lock release or up_read(mmap_sem) still
513          * have release semantics that can allow the
514          * waitqueue_active() to be reordered before the pte update.
515          */
516         smp_mb();
517
518         /*
519          * Use waitqueue_active because it's very frequent to
520          * change the address space atomically even if there are no
521          * userfaults yet. So we take the spinlock only when we're
522          * sure we've userfaults to wake.
523          */
524         if (waitqueue_active(&ctx->fault_wqh))
525                 __wake_userfault(ctx, range);
526 }
527
528 static __always_inline int validate_range(struct mm_struct *mm,
529                                           __u64 start, __u64 len)
530 {
531         __u64 task_size = mm->task_size;
532
533         if (start & ~PAGE_MASK)
534                 return -EINVAL;
535         if (len & ~PAGE_MASK)
536                 return -EINVAL;
537         if (!len)
538                 return -EINVAL;
539         if (start < mmap_min_addr)
540                 return -EINVAL;
541         if (start >= task_size)
542                 return -EINVAL;
543         if (len > task_size - start)
544                 return -EINVAL;
545         return 0;
546 }
547
548 static int userfaultfd_register(struct userfaultfd_ctx *ctx,
549                                 unsigned long arg)
550 {
551         struct mm_struct *mm = ctx->mm;
552         struct vm_area_struct *vma, *prev, *cur;
553         int ret;
554         struct uffdio_register uffdio_register;
555         struct uffdio_register __user *user_uffdio_register;
556         unsigned long vm_flags, new_flags;
557         bool found;
558         unsigned long start, end, vma_end;
559
560         user_uffdio_register = (struct uffdio_register __user *) arg;
561
562         ret = -EFAULT;
563         if (copy_from_user(&uffdio_register, user_uffdio_register,
564                            sizeof(uffdio_register)-sizeof(__u64)))
565                 goto out;
566
567         ret = -EINVAL;
568         if (!uffdio_register.mode)
569                 goto out;
570         if (uffdio_register.mode & ~(UFFDIO_REGISTER_MODE_MISSING|
571                                      UFFDIO_REGISTER_MODE_WP))
572                 goto out;
573         vm_flags = 0;
574         if (uffdio_register.mode & UFFDIO_REGISTER_MODE_MISSING)
575                 vm_flags |= VM_UFFD_MISSING;
576         if (uffdio_register.mode & UFFDIO_REGISTER_MODE_WP) {
577                 vm_flags |= VM_UFFD_WP;
578                 /*
579                  * FIXME: remove the below error constraint by
580                  * implementing the wprotect tracking mode.
581                  */
582                 ret = -EINVAL;
583                 goto out;
584         }
585
586         ret = validate_range(mm, uffdio_register.range.start,
587                              uffdio_register.range.len);
588         if (ret)
589                 goto out;
590
591         start = uffdio_register.range.start;
592         end = start + uffdio_register.range.len;
593
594         down_write(&mm->mmap_sem);
595         vma = find_vma_prev(mm, start, &prev);
596
597         ret = -ENOMEM;
598         if (!vma)
599                 goto out_unlock;
600
601         /* check that there's at least one vma in the range */
602         ret = -EINVAL;
603         if (vma->vm_start >= end)
604                 goto out_unlock;
605
606         /*
607          * Search for not compatible vmas.
608          *
609          * FIXME: this shall be relaxed later so that it doesn't fail
610          * on tmpfs backed vmas (in addition to the current allowance
611          * on anonymous vmas).
612          */
613         found = false;
614         for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) {
615                 cond_resched();
616
617                 BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
618                        !!(cur->vm_flags & (VM_UFFD_MISSING | VM_UFFD_WP)));
619
620                 /* check not compatible vmas */
621                 ret = -EINVAL;
622                 if (cur->vm_ops)
623                         goto out_unlock;
624
625                 /*
626                  * Check that this vma isn't already owned by a
627                  * different userfaultfd. We can't allow more than one
628                  * userfaultfd to own a single vma simultaneously or we
629                  * wouldn't know which one to deliver the userfaults to.
630                  */
631                 ret = -EBUSY;
632                 if (cur->vm_userfaultfd_ctx.ctx &&
633                     cur->vm_userfaultfd_ctx.ctx != ctx)
634                         goto out_unlock;
635
636                 found = true;
637         }
638         BUG_ON(!found);
639
640         if (vma->vm_start < start)
641                 prev = vma;
642
643         ret = 0;
644         do {
645                 cond_resched();
646
647                 BUG_ON(vma->vm_ops);
648                 BUG_ON(vma->vm_userfaultfd_ctx.ctx &&
649                        vma->vm_userfaultfd_ctx.ctx != ctx);
650
651                 /*
652                  * Nothing to do: this vma is already registered into this
653                  * userfaultfd and with the right tracking mode too.
654                  */
655                 if (vma->vm_userfaultfd_ctx.ctx == ctx &&
656                     (vma->vm_flags & vm_flags) == vm_flags)
657                         goto skip;
658
659                 if (vma->vm_start > start)
660                         start = vma->vm_start;
661                 vma_end = min(end, vma->vm_end);
662
663                 new_flags = (vma->vm_flags & ~vm_flags) | vm_flags;
664                 prev = vma_merge(mm, prev, start, vma_end, new_flags,
665                                  vma->anon_vma, vma->vm_file, vma->vm_pgoff,
666                                  vma_policy(vma),
667                                  ((struct vm_userfaultfd_ctx){ ctx }));
668                 if (prev) {
669                         vma = prev;
670                         goto next;
671                 }
672                 if (vma->vm_start < start) {
673                         ret = split_vma(mm, vma, start, 1);
674                         if (ret)
675                                 break;
676                 }
677                 if (vma->vm_end > end) {
678                         ret = split_vma(mm, vma, end, 0);
679                         if (ret)
680                                 break;
681                 }
682         next:
683                 /*
684                  * In the vma_merge() successful mprotect-like case 8:
685                  * the next vma was merged into the current one and
686                  * the current one has not been updated yet.
687                  */
688                 vma->vm_flags = new_flags;
689                 vma->vm_userfaultfd_ctx.ctx = ctx;
690
691         skip:
692                 prev = vma;
693                 start = vma->vm_end;
694                 vma = vma->vm_next;
695         } while (vma && vma->vm_start < end);
696 out_unlock:
697         up_write(&mm->mmap_sem);
698         if (!ret) {
699                 /*
700                  * Now that we scanned all vmas we can already tell
701                  * userland which ioctls methods are guaranteed to
702                  * succeed on this range.
703                  */
704                 if (put_user(UFFD_API_RANGE_IOCTLS,
705                              &user_uffdio_register->ioctls))
706                         ret = -EFAULT;
707         }
708 out:
709         return ret;
710 }
711
712 static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
713                                   unsigned long arg)
714 {
715         struct mm_struct *mm = ctx->mm;
716         struct vm_area_struct *vma, *prev, *cur;
717         int ret;
718         struct uffdio_range uffdio_unregister;
719         unsigned long new_flags;
720         bool found;
721         unsigned long start, end, vma_end;
722         const void __user *buf = (void __user *)arg;
723
724         ret = -EFAULT;
725         if (copy_from_user(&uffdio_unregister, buf, sizeof(uffdio_unregister)))
726                 goto out;
727
728         ret = validate_range(mm, uffdio_unregister.start,
729                              uffdio_unregister.len);
730         if (ret)
731                 goto out;
732
733         start = uffdio_unregister.start;
734         end = start + uffdio_unregister.len;
735
736         down_write(&mm->mmap_sem);
737         vma = find_vma_prev(mm, start, &prev);
738
739         ret = -ENOMEM;
740         if (!vma)
741                 goto out_unlock;
742
743         /* check that there's at least one vma in the range */
744         ret = -EINVAL;
745         if (vma->vm_start >= end)
746                 goto out_unlock;
747
748         /*
749          * Search for not compatible vmas.
750          *
751          * FIXME: this shall be relaxed later so that it doesn't fail
752          * on tmpfs backed vmas (in addition to the current allowance
753          * on anonymous vmas).
754          */
755         found = false;
756         ret = -EINVAL;
757         for (cur = vma; cur && cur->vm_start < end; cur = cur->vm_next) {
758                 cond_resched();
759
760                 BUG_ON(!!cur->vm_userfaultfd_ctx.ctx ^
761                        !!(cur->vm_flags & (VM_UFFD_MISSING | VM_UFFD_WP)));
762
763                 /*
764                  * Check not compatible vmas, not strictly required
765                  * here as not compatible vmas cannot have an
766                  * userfaultfd_ctx registered on them, but this
767                  * provides for more strict behavior to notice
768                  * unregistration errors.
769                  */
770                 if (cur->vm_ops)
771                         goto out_unlock;
772
773                 found = true;
774         }
775         BUG_ON(!found);
776
777         if (vma->vm_start < start)
778                 prev = vma;
779
780         ret = 0;
781         do {
782                 cond_resched();
783
784                 BUG_ON(vma->vm_ops);
785
786                 /*
787                  * Nothing to do: this vma is already registered into this
788                  * userfaultfd and with the right tracking mode too.
789                  */
790                 if (!vma->vm_userfaultfd_ctx.ctx)
791                         goto skip;
792
793                 if (vma->vm_start > start)
794                         start = vma->vm_start;
795                 vma_end = min(end, vma->vm_end);
796
797                 new_flags = vma->vm_flags & ~(VM_UFFD_MISSING | VM_UFFD_WP);
798                 prev = vma_merge(mm, prev, start, vma_end, new_flags,
799                                  vma->anon_vma, vma->vm_file, vma->vm_pgoff,
800                                  vma_policy(vma),
801                                  NULL_VM_UFFD_CTX);
802                 if (prev) {
803                         vma = prev;
804                         goto next;
805                 }
806                 if (vma->vm_start < start) {
807                         ret = split_vma(mm, vma, start, 1);
808                         if (ret)
809                                 break;
810                 }
811                 if (vma->vm_end > end) {
812                         ret = split_vma(mm, vma, end, 0);
813                         if (ret)
814                                 break;
815                 }
816         next:
817                 /*
818                  * In the vma_merge() successful mprotect-like case 8:
819                  * the next vma was merged into the current one and
820                  * the current one has not been updated yet.
821                  */
822                 vma->vm_flags = new_flags;
823                 vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
824
825         skip:
826                 prev = vma;
827                 start = vma->vm_end;
828                 vma = vma->vm_next;
829         } while (vma && vma->vm_start < end);
830 out_unlock:
831         up_write(&mm->mmap_sem);
832 out:
833         return ret;
834 }
835
836 /*
837  * This is mostly needed to re-wakeup those userfaults that were still
838  * pending when userland wake them up the first time. We don't wake
839  * the pending one to avoid blocking reads to block, or non blocking
840  * read to return -EAGAIN, if used with POLLIN, to avoid userland
841  * doubts on why POLLIN wasn't reliable.
842  */
843 static int userfaultfd_wake(struct userfaultfd_ctx *ctx,
844                             unsigned long arg)
845 {
846         int ret;
847         struct uffdio_range uffdio_wake;
848         struct userfaultfd_wake_range range;
849         const void __user *buf = (void __user *)arg;
850
851         ret = -EFAULT;
852         if (copy_from_user(&uffdio_wake, buf, sizeof(uffdio_wake)))
853                 goto out;
854
855         ret = validate_range(ctx->mm, uffdio_wake.start, uffdio_wake.len);
856         if (ret)
857                 goto out;
858
859         range.start = uffdio_wake.start;
860         range.len = uffdio_wake.len;
861
862         /*
863          * len == 0 means wake all and we don't want to wake all here,
864          * so check it again to be sure.
865          */
866         VM_BUG_ON(!range.len);
867
868         wake_userfault(ctx, &range);
869         ret = 0;
870
871 out:
872         return ret;
873 }
874
875 /*
876  * userland asks for a certain API version and we return which bits
877  * and ioctl commands are implemented in this kernel for such API
878  * version or -EINVAL if unknown.
879  */
880 static int userfaultfd_api(struct userfaultfd_ctx *ctx,
881                            unsigned long arg)
882 {
883         struct uffdio_api uffdio_api;
884         void __user *buf = (void __user *)arg;
885         int ret;
886
887         ret = -EINVAL;
888         if (ctx->state != UFFD_STATE_WAIT_API)
889                 goto out;
890         ret = -EFAULT;
891         if (copy_from_user(&uffdio_api, buf, sizeof(uffdio_api)))
892                 goto out;
893         if (uffdio_api.api != UFFD_API || uffdio_api.features) {
894                 memset(&uffdio_api, 0, sizeof(uffdio_api));
895                 if (copy_to_user(buf, &uffdio_api, sizeof(uffdio_api)))
896                         goto out;
897                 ret = -EINVAL;
898                 goto out;
899         }
900         uffdio_api.features = UFFD_API_FEATURES;
901         uffdio_api.ioctls = UFFD_API_IOCTLS;
902         ret = -EFAULT;
903         if (copy_to_user(buf, &uffdio_api, sizeof(uffdio_api)))
904                 goto out;
905         ctx->state = UFFD_STATE_RUNNING;
906         ret = 0;
907 out:
908         return ret;
909 }
910
911 static long userfaultfd_ioctl(struct file *file, unsigned cmd,
912                               unsigned long arg)
913 {
914         int ret = -EINVAL;
915         struct userfaultfd_ctx *ctx = file->private_data;
916
917         switch(cmd) {
918         case UFFDIO_API:
919                 ret = userfaultfd_api(ctx, arg);
920                 break;
921         case UFFDIO_REGISTER:
922                 ret = userfaultfd_register(ctx, arg);
923                 break;
924         case UFFDIO_UNREGISTER:
925                 ret = userfaultfd_unregister(ctx, arg);
926                 break;
927         case UFFDIO_WAKE:
928                 ret = userfaultfd_wake(ctx, arg);
929                 break;
930         }
931         return ret;
932 }
933
934 #ifdef CONFIG_PROC_FS
935 static void userfaultfd_show_fdinfo(struct seq_file *m, struct file *f)
936 {
937         struct userfaultfd_ctx *ctx = f->private_data;
938         wait_queue_t *wq;
939         struct userfaultfd_wait_queue *uwq;
940         unsigned long pending = 0, total = 0;
941
942         spin_lock(&ctx->fault_wqh.lock);
943         list_for_each_entry(wq, &ctx->fault_wqh.task_list, task_list) {
944                 uwq = container_of(wq, struct userfaultfd_wait_queue, wq);
945                 if (uwq->pending)
946                         pending++;
947                 total++;
948         }
949         spin_unlock(&ctx->fault_wqh.lock);
950
951         /*
952          * If more protocols will be added, there will be all shown
953          * separated by a space. Like this:
954          *      protocols: aa:... bb:...
955          */
956         seq_printf(m, "pending:\t%lu\ntotal:\t%lu\nAPI:\t%Lx:%x:%Lx\n",
957                    pending, total, UFFD_API, UFFD_API_FEATURES,
958                    UFFD_API_IOCTLS|UFFD_API_RANGE_IOCTLS);
959 }
960 #endif
961
962 static const struct file_operations userfaultfd_fops = {
963 #ifdef CONFIG_PROC_FS
964         .show_fdinfo    = userfaultfd_show_fdinfo,
965 #endif
966         .release        = userfaultfd_release,
967         .poll           = userfaultfd_poll,
968         .read           = userfaultfd_read,
969         .unlocked_ioctl = userfaultfd_ioctl,
970         .compat_ioctl   = userfaultfd_ioctl,
971         .llseek         = noop_llseek,
972 };
973
974 /**
975  * userfaultfd_file_create - Creates an userfaultfd file pointer.
976  * @flags: Flags for the userfaultfd file.
977  *
978  * This function creates an userfaultfd file pointer, w/out installing
979  * it into the fd table. This is useful when the userfaultfd file is
980  * used during the initialization of data structures that require
981  * extra setup after the userfaultfd creation. So the userfaultfd
982  * creation is split into the file pointer creation phase, and the
983  * file descriptor installation phase.  In this way races with
984  * userspace closing the newly installed file descriptor can be
985  * avoided.  Returns an userfaultfd file pointer, or a proper error
986  * pointer.
987  */
988 static struct file *userfaultfd_file_create(int flags)
989 {
990         struct file *file;
991         struct userfaultfd_ctx *ctx;
992
993         BUG_ON(!current->mm);
994
995         /* Check the UFFD_* constants for consistency.  */
996         BUILD_BUG_ON(UFFD_CLOEXEC != O_CLOEXEC);
997         BUILD_BUG_ON(UFFD_NONBLOCK != O_NONBLOCK);
998
999         file = ERR_PTR(-EINVAL);
1000         if (flags & ~UFFD_SHARED_FCNTL_FLAGS)
1001                 goto out;
1002
1003         file = ERR_PTR(-ENOMEM);
1004         ctx = kmalloc(sizeof(*ctx), GFP_KERNEL);
1005         if (!ctx)
1006                 goto out;
1007
1008         atomic_set(&ctx->refcount, 1);
1009         init_waitqueue_head(&ctx->fault_wqh);
1010         init_waitqueue_head(&ctx->fd_wqh);
1011         ctx->flags = flags;
1012         ctx->state = UFFD_STATE_WAIT_API;
1013         ctx->released = false;
1014         ctx->mm = current->mm;
1015         /* prevent the mm struct to be freed */
1016         atomic_inc(&ctx->mm->mm_users);
1017
1018         file = anon_inode_getfile("[userfaultfd]", &userfaultfd_fops, ctx,
1019                                   O_RDWR | (flags & UFFD_SHARED_FCNTL_FLAGS));
1020         if (IS_ERR(file))
1021                 kfree(ctx);
1022 out:
1023         return file;
1024 }
1025
1026 SYSCALL_DEFINE1(userfaultfd, int, flags)
1027 {
1028         int fd, error;
1029         struct file *file;
1030
1031         error = get_unused_fd_flags(flags & UFFD_SHARED_FCNTL_FLAGS);
1032         if (error < 0)
1033                 return error;
1034         fd = error;
1035
1036         file = userfaultfd_file_create(flags);
1037         if (IS_ERR(file)) {
1038                 error = PTR_ERR(file);
1039                 goto err_put_unused_fd;
1040         }
1041         fd_install(fd, file);
1042
1043         return fd;
1044
1045 err_put_unused_fd:
1046         put_unused_fd(fd);
1047
1048         return error;
1049 }