mm/gup, x86/mm/pkeys: Check VMAs and PTEs for protection keys
[cascardo/linux.git] / arch / x86 / mm / gup.c
index 6d5eb59..bab259e 100644 (file)
@@ -11,6 +11,7 @@
 #include <linux/swap.h>
 #include <linux/memremap.h>
 
+#include <asm/mmu_context.h>
 #include <asm/pgtable.h>
 
 static inline pte_t gup_get_pte(pte_t *ptep)
@@ -74,6 +75,28 @@ static void undo_dev_pagemap(int *nr, int nr_start, struct page **pages)
        }
 }
 
+/*
+ * 'pteval' can come from a pte, pmd or pud.  We only check
+ * _PAGE_PRESENT, _PAGE_USER, and _PAGE_RW in here which are the
+ * same value on all 3 types.
+ */
+static inline int pte_allows_gup(unsigned long pteval, int write)
+{
+       unsigned long need_pte_bits = _PAGE_PRESENT|_PAGE_USER;
+
+       if (write)
+               need_pte_bits |= _PAGE_RW;
+
+       if ((pteval & need_pte_bits) != need_pte_bits)
+               return 0;
+
+       /* Check memory protection keys permissions. */
+       if (!__pkru_allows_pkey(pte_flags_pkey(pteval), write))
+               return 0;
+
+       return 1;
+}
+
 /*
  * The performance critical leaf functions are made noinline otherwise gcc
  * inlines everything into a single function which results in too much
@@ -83,14 +106,9 @@ static noinline int gup_pte_range(pmd_t pmd, unsigned long addr,
                unsigned long end, int write, struct page **pages, int *nr)
 {
        struct dev_pagemap *pgmap = NULL;
-       unsigned long mask;
        int nr_start = *nr;
        pte_t *ptep;
 
-       mask = _PAGE_PRESENT|_PAGE_USER;
-       if (write)
-               mask |= _PAGE_RW;
-
        ptep = pte_offset_map(&pmd, addr);
        do {
                pte_t pte = gup_get_pte(ptep);
@@ -110,7 +128,8 @@ static noinline int gup_pte_range(pmd_t pmd, unsigned long addr,
                                pte_unmap(ptep);
                                return 0;
                        }
-               } else if ((pte_flags(pte) & (mask | _PAGE_SPECIAL)) != mask) {
+               } else if (!pte_allows_gup(pte_val(pte), write) ||
+                          pte_special(pte)) {
                        pte_unmap(ptep);
                        return 0;
                }
@@ -164,14 +183,10 @@ static int __gup_device_huge_pmd(pmd_t pmd, unsigned long addr,
 static noinline int gup_huge_pmd(pmd_t pmd, unsigned long addr,
                unsigned long end, int write, struct page **pages, int *nr)
 {
-       unsigned long mask;
        struct page *head, *page;
        int refs;
 
-       mask = _PAGE_PRESENT|_PAGE_USER;
-       if (write)
-               mask |= _PAGE_RW;
-       if ((pmd_flags(pmd) & mask) != mask)
+       if (!pte_allows_gup(pmd_val(pmd), write))
                return 0;
 
        VM_BUG_ON(!pfn_valid(pmd_pfn(pmd)));
@@ -231,14 +246,10 @@ static int gup_pmd_range(pud_t pud, unsigned long addr, unsigned long end,
 static noinline int gup_huge_pud(pud_t pud, unsigned long addr,
                unsigned long end, int write, struct page **pages, int *nr)
 {
-       unsigned long mask;
        struct page *head, *page;
        int refs;
 
-       mask = _PAGE_PRESENT|_PAGE_USER;
-       if (write)
-               mask |= _PAGE_RW;
-       if ((pud_flags(pud) & mask) != mask)
+       if (!pte_allows_gup(pud_val(pud), write))
                return 0;
        /* hugepages are never "special" */
        VM_BUG_ON(pud_flags(pud) & _PAGE_SPECIAL);
@@ -422,7 +433,7 @@ slow_irqon:
                start += nr << PAGE_SHIFT;
                pages += nr;
 
-               ret = get_user_pages_unlocked(current, mm, start,
+               ret = get_user_pages_unlocked(start,
                                              (end - start) >> PAGE_SHIFT,
                                              write, 0, pages);