Merge branch 'drm-next' of git://people.freedesktop.org/~airlied/linux
[cascardo/linux.git] / drivers / iommu / amd_iommu_v2.c
index 90d734b..90f70d0 100644 (file)
@@ -92,13 +92,6 @@ static spinlock_t state_lock;
 
 static struct workqueue_struct *iommu_wq;
 
-/*
- * Empty page table - Used between
- * mmu_notifier_invalidate_range_start and
- * mmu_notifier_invalidate_range_end
- */
-static u64 *empty_page_table;
-
 static void free_pasid_states(struct device_state *dev_state);
 
 static u16 device_id(struct pci_dev *pdev)
@@ -279,10 +272,8 @@ static void free_pasid_state(struct pasid_state *pasid_state)
 
 static void put_pasid_state(struct pasid_state *pasid_state)
 {
-       if (atomic_dec_and_test(&pasid_state->count)) {
-               put_device_state(pasid_state->device_state);
+       if (atomic_dec_and_test(&pasid_state->count))
                wake_up(&pasid_state->wq);
-       }
 }
 
 static void put_pasid_state_wait(struct pasid_state *pasid_state)
@@ -291,9 +282,7 @@ static void put_pasid_state_wait(struct pasid_state *pasid_state)
 
        prepare_to_wait(&pasid_state->wq, &wait, TASK_UNINTERRUPTIBLE);
 
-       if (atomic_dec_and_test(&pasid_state->count))
-               put_device_state(pasid_state->device_state);
-       else
+       if (!atomic_dec_and_test(&pasid_state->count))
                schedule();
 
        finish_wait(&pasid_state->wq, &wait);
@@ -418,46 +407,21 @@ static void mn_invalidate_page(struct mmu_notifier *mn,
        __mn_flush_page(mn, address);
 }
 
-static void mn_invalidate_range_start(struct mmu_notifier *mn,
-                                     struct mm_struct *mm,
-                                     unsigned long start, unsigned long end)
-{
-       struct pasid_state *pasid_state;
-       struct device_state *dev_state;
-       unsigned long flags;
-
-       pasid_state = mn_to_state(mn);
-       dev_state   = pasid_state->device_state;
-
-       spin_lock_irqsave(&pasid_state->lock, flags);
-       if (pasid_state->mmu_notifier_count == 0) {
-               amd_iommu_domain_set_gcr3(dev_state->domain,
-                                         pasid_state->pasid,
-                                         __pa(empty_page_table));
-       }
-       pasid_state->mmu_notifier_count += 1;
-       spin_unlock_irqrestore(&pasid_state->lock, flags);
-}
-
-static void mn_invalidate_range_end(struct mmu_notifier *mn,
-                                   struct mm_struct *mm,
-                                   unsigned long start, unsigned long end)
+static void mn_invalidate_range(struct mmu_notifier *mn,
+                               struct mm_struct *mm,
+                               unsigned long start, unsigned long end)
 {
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
-       unsigned long flags;
 
        pasid_state = mn_to_state(mn);
        dev_state   = pasid_state->device_state;
 
-       spin_lock_irqsave(&pasid_state->lock, flags);
-       pasid_state->mmu_notifier_count -= 1;
-       if (pasid_state->mmu_notifier_count == 0) {
-               amd_iommu_domain_set_gcr3(dev_state->domain,
-                                         pasid_state->pasid,
-                                         __pa(pasid_state->mm->pgd));
-       }
-       spin_unlock_irqrestore(&pasid_state->lock, flags);
+       if ((start ^ (end - 1)) < PAGE_SIZE)
+               amd_iommu_flush_page(dev_state->domain, pasid_state->pasid,
+                                    start);
+       else
+               amd_iommu_flush_tlb(dev_state->domain, pasid_state->pasid);
 }
 
 static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
@@ -482,8 +446,7 @@ static struct mmu_notifier_ops iommu_mn = {
        .release                = mn_release,
        .clear_flush_young      = mn_clear_flush_young,
        .invalidate_page        = mn_invalidate_page,
-       .invalidate_range_start = mn_invalidate_range_start,
-       .invalidate_range_end   = mn_invalidate_range_end,
+       .invalidate_range       = mn_invalidate_range,
 };
 
 static void set_pri_tag_status(struct pasid_state *pasid_state,
@@ -513,45 +476,67 @@ static void finish_pri_tag(struct device_state *dev_state,
        spin_unlock_irqrestore(&pasid_state->lock, flags);
 }
 
+static void handle_fault_error(struct fault *fault)
+{
+       int status;
+
+       if (!fault->dev_state->inv_ppr_cb) {
+               set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
+               return;
+       }
+
+       status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
+                                             fault->pasid,
+                                             fault->address,
+                                             fault->flags);
+       switch (status) {
+       case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
+               set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
+               break;
+       case AMD_IOMMU_INV_PRI_RSP_INVALID:
+               set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
+               break;
+       case AMD_IOMMU_INV_PRI_RSP_FAIL:
+               set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
+               break;
+       default:
+               BUG();
+       }
+}
+
 static void do_fault(struct work_struct *work)
 {
        struct fault *fault = container_of(work, struct fault, work);
-       int npages, write;
-       struct page *page;
+       struct mm_struct *mm;
+       struct vm_area_struct *vma;
+       u64 address;
+       int ret, write;
 
        write = !!(fault->flags & PPR_FAULT_WRITE);
 
-       down_read(&fault->state->mm->mmap_sem);
-       npages = get_user_pages(NULL, fault->state->mm,
-                               fault->address, 1, write, 0, &page, NULL);
-       up_read(&fault->state->mm->mmap_sem);
-
-       if (npages == 1) {
-               put_page(page);
-       } else if (fault->dev_state->inv_ppr_cb) {
-               int status;
-
-               status = fault->dev_state->inv_ppr_cb(fault->dev_state->pdev,
-                                                     fault->pasid,
-                                                     fault->address,
-                                                     fault->flags);
-               switch (status) {
-               case AMD_IOMMU_INV_PRI_RSP_SUCCESS:
-                       set_pri_tag_status(fault->state, fault->tag, PPR_SUCCESS);
-                       break;
-               case AMD_IOMMU_INV_PRI_RSP_INVALID:
-                       set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
-                       break;
-               case AMD_IOMMU_INV_PRI_RSP_FAIL:
-                       set_pri_tag_status(fault->state, fault->tag, PPR_FAILURE);
-                       break;
-               default:
-                       BUG();
-               }
-       } else {
-               set_pri_tag_status(fault->state, fault->tag, PPR_INVALID);
+       mm = fault->state->mm;
+       address = fault->address;
+
+       down_read(&mm->mmap_sem);
+       vma = find_extend_vma(mm, address);
+       if (!vma || address < vma->vm_start) {
+               /* failed to get a vma in the right range */
+               up_read(&mm->mmap_sem);
+               handle_fault_error(fault);
+               goto out;
+       }
+
+       ret = handle_mm_fault(mm, vma, address, write);
+       if (ret & VM_FAULT_ERROR) {
+               /* failed to service fault */
+               up_read(&mm->mmap_sem);
+               handle_fault_error(fault);
+               goto out;
        }
 
+       up_read(&mm->mmap_sem);
+
+out:
        finish_pri_tag(fault->dev_state, fault->state, fault->tag);
 
        put_pasid_state(fault->state);
@@ -954,18 +939,10 @@ static int __init amd_iommu_v2_init(void)
        if (iommu_wq == NULL)
                goto out;
 
-       ret = -ENOMEM;
-       empty_page_table = (u64 *)get_zeroed_page(GFP_KERNEL);
-       if (empty_page_table == NULL)
-               goto out_destroy_wq;
-
        amd_iommu_register_ppr_notifier(&ppr_nb);
 
        return 0;
 
-out_destroy_wq:
-       destroy_workqueue(iommu_wq);
-
 out:
        return ret;
 }
@@ -999,8 +976,6 @@ static void __exit amd_iommu_v2_exit(void)
        }
 
        destroy_workqueue(iommu_wq);
-
-       free_page((unsigned long)empty_page_table);
 }
 
 module_init(amd_iommu_v2_init);