mm/gup, x86/mm/pkeys: Check VMAs and PTEs for protection keys
[cascardo/linux.git] / arch / x86 / mm / fault.c
index eef44d9..319331a 100644 (file)
 #include <linux/context_tracking.h>    /* exception_enter(), ...       */
 #include <linux/uaccess.h>             /* faulthandler_disabled()      */
 
+#include <asm/cpufeature.h>            /* boot_cpu_has, ...            */
 #include <asm/traps.h>                 /* dotraplinkage, ...           */
 #include <asm/pgalloc.h>               /* pgd_*(), ...                 */
 #include <asm/kmemcheck.h>             /* kmemcheck_*(), ...           */
 #include <asm/fixmap.h>                        /* VSYSCALL_ADDR                */
 #include <asm/vsyscall.h>              /* emulate_vsyscall             */
 #include <asm/vm86.h>                  /* struct vm86                  */
+#include <asm/mmu_context.h>           /* vma_pkey()                   */
 
 #define CREATE_TRACE_POINTS
 #include <asm/trace/exceptions.h>
@@ -33,6 +35,7 @@
  *   bit 2 ==   0: kernel-mode access  1: user-mode access
  *   bit 3 ==                          1: use of reserved bit detected
  *   bit 4 ==                          1: fault was an instruction fetch
+ *   bit 5 ==                          1: protection keys block access
  */
 enum x86_pf_error_code {
 
@@ -41,6 +44,7 @@ enum x86_pf_error_code {
        PF_USER         =               1 << 2,
        PF_RSVD         =               1 << 3,
        PF_INSTR        =               1 << 4,
+       PF_PK           =               1 << 5,
 };
 
 /*
@@ -167,9 +171,60 @@ is_prefetch(struct pt_regs *regs, unsigned long error_code, unsigned long addr)
        return prefetch;
 }
 
+/*
+ * A protection key fault means that the PKRU value did not allow
+ * access to some PTE.  Userspace can figure out what PKRU was
+ * from the XSAVE state, and this function fills out a field in
+ * siginfo so userspace can discover which protection key was set
+ * on the PTE.
+ *
+ * If we get here, we know that the hardware signaled a PF_PK
+ * fault and that there was a VMA once we got in the fault
+ * handler.  It does *not* guarantee that the VMA we find here
+ * was the one that we faulted on.
+ *
+ * 1. T1   : mprotect_key(foo, PAGE_SIZE, pkey=4);
+ * 2. T1   : set PKRU to deny access to pkey=4, touches page
+ * 3. T1   : faults...
+ * 4.    T2: mprotect_key(foo, PAGE_SIZE, pkey=5);
+ * 5. T1   : enters fault handler, takes mmap_sem, etc...
+ * 6. T1   : reaches here, sees vma_pkey(vma)=5, when we really
+ *          faulted on a pte with its pkey=4.
+ */
+static void fill_sig_info_pkey(int si_code, siginfo_t *info,
+               struct vm_area_struct *vma)
+{
+       /* This is effectively an #ifdef */
+       if (!boot_cpu_has(X86_FEATURE_OSPKE))
+               return;
+
+       /* Fault not from Protection Keys: nothing to do */
+       if (si_code != SEGV_PKUERR)
+               return;
+       /*
+        * force_sig_info_fault() is called from a number of
+        * contexts, some of which have a VMA and some of which
+        * do not.  The PF_PK handing happens after we have a
+        * valid VMA, so we should never reach this without a
+        * valid VMA.
+        */
+       if (!vma) {
+               WARN_ONCE(1, "PKU fault with no VMA passed in");
+               info->si_pkey = 0;
+               return;
+       }
+       /*
+        * si_pkey should be thought of as a strong hint, but not
+        * absolutely guranteed to be 100% accurate because of
+        * the race explained above.
+        */
+       info->si_pkey = vma_pkey(vma);
+}
+
 static void
 force_sig_info_fault(int si_signo, int si_code, unsigned long address,
-                    struct task_struct *tsk, int fault)
+                    struct task_struct *tsk, struct vm_area_struct *vma,
+                    int fault)
 {
        unsigned lsb = 0;
        siginfo_t info;
@@ -184,6 +239,8 @@ force_sig_info_fault(int si_signo, int si_code, unsigned long address,
                lsb = PAGE_SHIFT;
        info.si_addr_lsb = lsb;
 
+       fill_sig_info_pkey(si_code, &info, vma);
+
        force_sig_info(si_signo, &info, tsk);
 }
 
@@ -654,6 +711,8 @@ no_context(struct pt_regs *regs, unsigned long error_code,
        struct task_struct *tsk = current;
        unsigned long flags;
        int sig;
+       /* No context means no VMA to pass down */
+       struct vm_area_struct *vma = NULL;
 
        /* Are we prepared to handle this kernel fault? */
        if (fixup_exception(regs)) {
@@ -677,7 +736,8 @@ no_context(struct pt_regs *regs, unsigned long error_code,
                        tsk->thread.cr2 = address;
 
                        /* XXX: hwpoison faults will set the wrong code. */
-                       force_sig_info_fault(signal, si_code, address, tsk, 0);
+                       force_sig_info_fault(signal, si_code, address,
+                                            tsk, vma, 0);
                }
 
                /*
@@ -754,7 +814,8 @@ show_signal_msg(struct pt_regs *regs, unsigned long error_code,
 
 static void
 __bad_area_nosemaphore(struct pt_regs *regs, unsigned long error_code,
-                      unsigned long address, int si_code)
+                      unsigned long address, struct vm_area_struct *vma,
+                      int si_code)
 {
        struct task_struct *tsk = current;
 
@@ -797,7 +858,7 @@ __bad_area_nosemaphore(struct pt_regs *regs, unsigned long error_code,
                tsk->thread.error_code  = error_code;
                tsk->thread.trap_nr     = X86_TRAP_PF;
 
-               force_sig_info_fault(SIGSEGV, si_code, address, tsk, 0);
+               force_sig_info_fault(SIGSEGV, si_code, address, tsk, vma, 0);
 
                return;
        }
@@ -810,14 +871,14 @@ __bad_area_nosemaphore(struct pt_regs *regs, unsigned long error_code,
 
 static noinline void
 bad_area_nosemaphore(struct pt_regs *regs, unsigned long error_code,
-                    unsigned long address)
+                    unsigned long address, struct vm_area_struct *vma)
 {
-       __bad_area_nosemaphore(regs, error_code, address, SEGV_MAPERR);
+       __bad_area_nosemaphore(regs, error_code, address, vma, SEGV_MAPERR);
 }
 
 static void
 __bad_area(struct pt_regs *regs, unsigned long error_code,
-          unsigned long address, int si_code)
+          unsigned long address,  struct vm_area_struct *vma, int si_code)
 {
        struct mm_struct *mm = current->mm;
 
@@ -827,25 +888,43 @@ __bad_area(struct pt_regs *regs, unsigned long error_code,
         */
        up_read(&mm->mmap_sem);
 
-       __bad_area_nosemaphore(regs, error_code, address, si_code);
+       __bad_area_nosemaphore(regs, error_code, address, vma, si_code);
 }
 
 static noinline void
 bad_area(struct pt_regs *regs, unsigned long error_code, unsigned long address)
 {
-       __bad_area(regs, error_code, address, SEGV_MAPERR);
+       __bad_area(regs, error_code, address, NULL, SEGV_MAPERR);
+}
+
+static inline bool bad_area_access_from_pkeys(unsigned long error_code,
+               struct vm_area_struct *vma)
+{
+       if (!boot_cpu_has(X86_FEATURE_OSPKE))
+               return false;
+       if (error_code & PF_PK)
+               return true;
+       return false;
 }
 
 static noinline void
 bad_area_access_error(struct pt_regs *regs, unsigned long error_code,
-                     unsigned long address)
+                     unsigned long address, struct vm_area_struct *vma)
 {
-       __bad_area(regs, error_code, address, SEGV_ACCERR);
+       /*
+        * This OSPKE check is not strictly necessary at runtime.
+        * But, doing it this way allows compiler optimizations
+        * if pkeys are compiled out.
+        */
+       if (bad_area_access_from_pkeys(error_code, vma))
+               __bad_area(regs, error_code, address, vma, SEGV_PKUERR);
+       else
+               __bad_area(regs, error_code, address, vma, SEGV_ACCERR);
 }
 
 static void
 do_sigbus(struct pt_regs *regs, unsigned long error_code, unsigned long address,
-         unsigned int fault)
+         struct vm_area_struct *vma, unsigned int fault)
 {
        struct task_struct *tsk = current;
        int code = BUS_ADRERR;
@@ -872,12 +951,13 @@ do_sigbus(struct pt_regs *regs, unsigned long error_code, unsigned long address,
                code = BUS_MCEERR_AR;
        }
 #endif
-       force_sig_info_fault(SIGBUS, code, address, tsk, fault);
+       force_sig_info_fault(SIGBUS, code, address, tsk, vma, fault);
 }
 
 static noinline void
 mm_fault_error(struct pt_regs *regs, unsigned long error_code,
-              unsigned long address, unsigned int fault)
+              unsigned long address, struct vm_area_struct *vma,
+              unsigned int fault)
 {
        if (fatal_signal_pending(current) && !(error_code & PF_USER)) {
                no_context(regs, error_code, address, 0, 0);
@@ -901,9 +981,9 @@ mm_fault_error(struct pt_regs *regs, unsigned long error_code,
        } else {
                if (fault & (VM_FAULT_SIGBUS|VM_FAULT_HWPOISON|
                             VM_FAULT_HWPOISON_LARGE))
-                       do_sigbus(regs, error_code, address, fault);
+                       do_sigbus(regs, error_code, address, vma, fault);
                else if (fault & VM_FAULT_SIGSEGV)
-                       bad_area_nosemaphore(regs, error_code, address);
+                       bad_area_nosemaphore(regs, error_code, address, vma);
                else
                        BUG();
        }
@@ -916,6 +996,12 @@ static int spurious_fault_check(unsigned long error_code, pte_t *pte)
 
        if ((error_code & PF_INSTR) && !pte_exec(*pte))
                return 0;
+       /*
+        * Note: We do not do lazy flushing on protection key
+        * changes, so no spurious fault will ever set PF_PK.
+        */
+       if ((error_code & PF_PK))
+               return 1;
 
        return 1;
 }
@@ -1005,6 +1091,15 @@ int show_unhandled_signals = 1;
 static inline int
 access_error(unsigned long error_code, struct vm_area_struct *vma)
 {
+       /*
+        * Access or read was blocked by protection keys. We do
+        * this check before any others because we do not want
+        * to, for instance, confuse a protection-key-denied
+        * write with one for which we should do a COW.
+        */
+       if (error_code & PF_PK)
+               return 1;
+
        if (error_code & PF_WRITE) {
                /* write, present and write, not present: */
                if (unlikely(!(vma->vm_flags & VM_WRITE)))
@@ -1111,7 +1206,7 @@ __do_page_fault(struct pt_regs *regs, unsigned long error_code,
                 * Don't take the mm semaphore here. If we fixup a prefetch
                 * fault we could otherwise deadlock:
                 */
-               bad_area_nosemaphore(regs, error_code, address);
+               bad_area_nosemaphore(regs, error_code, address, NULL);
 
                return;
        }
@@ -1124,7 +1219,7 @@ __do_page_fault(struct pt_regs *regs, unsigned long error_code,
                pgtable_bad(regs, error_code, address);
 
        if (unlikely(smap_violation(error_code, regs))) {
-               bad_area_nosemaphore(regs, error_code, address);
+               bad_area_nosemaphore(regs, error_code, address, NULL);
                return;
        }
 
@@ -1133,7 +1228,7 @@ __do_page_fault(struct pt_regs *regs, unsigned long error_code,
         * in a region with pagefaults disabled then we must not take the fault
         */
        if (unlikely(faulthandler_disabled() || !mm)) {
-               bad_area_nosemaphore(regs, error_code, address);
+               bad_area_nosemaphore(regs, error_code, address, NULL);
                return;
        }
 
@@ -1177,7 +1272,7 @@ __do_page_fault(struct pt_regs *regs, unsigned long error_code,
        if (unlikely(!down_read_trylock(&mm->mmap_sem))) {
                if ((error_code & PF_USER) == 0 &&
                    !search_exception_tables(regs->ip)) {
-                       bad_area_nosemaphore(regs, error_code, address);
+                       bad_area_nosemaphore(regs, error_code, address, NULL);
                        return;
                }
 retry:
@@ -1225,7 +1320,7 @@ retry:
         */
 good_area:
        if (unlikely(access_error(error_code, vma))) {
-               bad_area_access_error(regs, error_code, address);
+               bad_area_access_error(regs, error_code, address, vma);
                return;
        }
 
@@ -1263,7 +1358,7 @@ good_area:
 
        up_read(&mm->mmap_sem);
        if (unlikely(fault & VM_FAULT_ERROR)) {
-               mm_fault_error(regs, error_code, address, fault);
+               mm_fault_error(regs, error_code, address, vma, fault);
                return;
        }