Merge remote-tracking branch 'wireless-next/master' into iwlwifi-next
[cascardo/linux.git] / mm / mprotect.c
index 7332c17..c43d557 100644 (file)
@@ -36,6 +36,34 @@ static inline pgprot_t pgprot_modify(pgprot_t oldprot, pgprot_t newprot)
 }
 #endif
 
+/*
+ * For a prot_numa update we only hold mmap_sem for read so there is a
+ * potential race with faulting where a pmd was temporarily none. This
+ * function checks for a transhuge pmd under the appropriate lock. It
+ * returns a pte if it was successfully locked or NULL if it raced with
+ * a transhuge insertion.
+ */
+static pte_t *lock_pte_protection(struct vm_area_struct *vma, pmd_t *pmd,
+                       unsigned long addr, int prot_numa, spinlock_t **ptl)
+{
+       pte_t *pte;
+       spinlock_t *pmdl;
+
+       /* !prot_numa is protected by mmap_sem held for write */
+       if (!prot_numa)
+               return pte_offset_map_lock(vma->vm_mm, pmd, addr, ptl);
+
+       pmdl = pmd_lock(vma->vm_mm, pmd);
+       if (unlikely(pmd_trans_huge(*pmd) || pmd_none(*pmd))) {
+               spin_unlock(pmdl);
+               return NULL;
+       }
+
+       pte = pte_offset_map_lock(vma->vm_mm, pmd, addr, ptl);
+       spin_unlock(pmdl);
+       return pte;
+}
+
 static unsigned long change_pte_range(struct vm_area_struct *vma, pmd_t *pmd,
                unsigned long addr, unsigned long end, pgprot_t newprot,
                int dirty_accountable, int prot_numa)
@@ -45,7 +73,10 @@ static unsigned long change_pte_range(struct vm_area_struct *vma, pmd_t *pmd,
        spinlock_t *ptl;
        unsigned long pages = 0;
 
-       pte = pte_offset_map_lock(mm, pmd, addr, &ptl);
+       pte = lock_pte_protection(vma, pmd, addr, prot_numa, &ptl);
+       if (!pte)
+               return 0;
+
        arch_enter_lazy_mmu_mode();
        do {
                oldpte = *pte;
@@ -58,36 +89,27 @@ static unsigned long change_pte_range(struct vm_area_struct *vma, pmd_t *pmd,
                                if (pte_numa(ptent))
                                        ptent = pte_mknonnuma(ptent);
                                ptent = pte_modify(ptent, newprot);
+                               /*
+                                * Avoid taking write faults for pages we
+                                * know to be dirty.
+                                */
+                               if (dirty_accountable && pte_dirty(ptent))
+                                       ptent = pte_mkwrite(ptent);
+                               ptep_modify_prot_commit(mm, addr, pte, ptent);
                                updated = true;
                        } else {
                                struct page *page;
 
-                               ptent = *pte;
                                page = vm_normal_page(vma, addr, oldpte);
                                if (page && !PageKsm(page)) {
                                        if (!pte_numa(oldpte)) {
-                                               ptent = pte_mknuma(ptent);
-                                               set_pte_at(mm, addr, pte, ptent);
+                                               ptep_set_numa(mm, addr, pte);
                                                updated = true;
                                        }
                                }
                        }
-
-                       /*
-                        * Avoid taking write faults for pages we know to be
-                        * dirty.
-                        */
-                       if (dirty_accountable && pte_dirty(ptent)) {
-                               ptent = pte_mkwrite(ptent);
-                               updated = true;
-                       }
-
                        if (updated)
                                pages++;
-
-                       /* Only !prot_numa always clears the pte */
-                       if (!prot_numa)
-                               ptep_modify_prot_commit(mm, addr, pte, ptent);
                } else if (IS_ENABLED(CONFIG_MIGRATION) && !pte_file(oldpte)) {
                        swp_entry_t entry = pte_to_swp_entry(oldpte);
 
@@ -118,15 +140,26 @@ static inline unsigned long change_pmd_range(struct vm_area_struct *vma,
                pgprot_t newprot, int dirty_accountable, int prot_numa)
 {
        pmd_t *pmd;
+       struct mm_struct *mm = vma->vm_mm;
        unsigned long next;
        unsigned long pages = 0;
        unsigned long nr_huge_updates = 0;
+       unsigned long mni_start = 0;
 
        pmd = pmd_offset(pud, addr);
        do {
                unsigned long this_pages;
 
                next = pmd_addr_end(addr, end);
+               if (!pmd_trans_huge(*pmd) && pmd_none_or_clear_bad(pmd))
+                       continue;
+
+               /* invoke the mmu notifier if the pmd is populated */
+               if (!mni_start) {
+                       mni_start = addr;
+                       mmu_notifier_invalidate_range_start(mm, mni_start, end);
+               }
+
                if (pmd_trans_huge(*pmd)) {
                        if (next - addr != HPAGE_PMD_SIZE)
                                split_huge_page_pmd(vma, addr, pmd);
@@ -139,18 +172,21 @@ static inline unsigned long change_pmd_range(struct vm_area_struct *vma,
                                                pages += HPAGE_PMD_NR;
                                                nr_huge_updates++;
                                        }
+
+                                       /* huge pmd was handled */
                                        continue;
                                }
                        }
-                       /* fall through */
+                       /* fall through, the trans huge pmd just split */
                }
-               if (pmd_none_or_clear_bad(pmd))
-                       continue;
                this_pages = change_pte_range(vma, pmd, addr, next, newprot,
                                 dirty_accountable, prot_numa);
                pages += this_pages;
        } while (pmd++, addr = next, addr != end);
 
+       if (mni_start)
+               mmu_notifier_invalidate_range_end(mm, mni_start, end);
+
        if (nr_huge_updates)
                count_vm_numa_events(NUMA_HUGE_PTE_UPDATES, nr_huge_updates);
        return pages;
@@ -210,15 +246,12 @@ unsigned long change_protection(struct vm_area_struct *vma, unsigned long start,
                       unsigned long end, pgprot_t newprot,
                       int dirty_accountable, int prot_numa)
 {
-       struct mm_struct *mm = vma->vm_mm;
        unsigned long pages;
 
-       mmu_notifier_invalidate_range_start(mm, start, end);
        if (is_vm_hugetlb_page(vma))
                pages = hugetlb_change_protection(vma, start, end, newprot);
        else
                pages = change_protection_range(vma, start, end, newprot, dirty_accountable, prot_numa);
-       mmu_notifier_invalidate_range_end(mm, start, end);
 
        return pages;
 }