Merge branch 'for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/dtor/input
[cascardo/linux.git] / drivers / iommu / amd_iommu_v2.c
index 499b436..5f578e8 100644 (file)
@@ -47,12 +47,13 @@ struct pasid_state {
        atomic_t count;                         /* Reference count */
        unsigned mmu_notifier_count;            /* Counting nested mmu_notifier
                                                   calls */
-       struct task_struct *task;               /* Task bound to this PASID */
        struct mm_struct *mm;                   /* mm_struct for the faults */
-       struct mmu_notifier mn;                 /* mmu_otifier handle */
+       struct mmu_notifier mn;                 /* mmu_notifier handle */
        struct pri_queue pri[PRI_QUEUE_SIZE];   /* PRI tag states */
        struct device_state *device_state;      /* Link to our device_state */
        int pasid;                              /* PASID index */
+       bool invalid;                           /* Used during setup and
+                                                  teardown of the pasid */
        spinlock_t lock;                        /* Protect pri_queues and
                                                   mmu_notifer_count */
        wait_queue_head_t wq;                   /* To wait for count == 0 */
@@ -99,7 +100,6 @@ static struct workqueue_struct *iommu_wq;
 static u64 *empty_page_table;
 
 static void free_pasid_states(struct device_state *dev_state);
-static void unbind_pasid(struct device_state *dev_state, int pasid);
 
 static u16 device_id(struct pci_dev *pdev)
 {
@@ -297,37 +297,29 @@ static void put_pasid_state_wait(struct pasid_state *pasid_state)
                schedule();
 
        finish_wait(&pasid_state->wq, &wait);
-       mmput(pasid_state->mm);
        free_pasid_state(pasid_state);
 }
 
-static void __unbind_pasid(struct pasid_state *pasid_state)
+static void unbind_pasid(struct pasid_state *pasid_state)
 {
        struct iommu_domain *domain;
 
        domain = pasid_state->device_state->domain;
 
+       /*
+        * Mark pasid_state as invalid, no more faults will we added to the
+        * work queue after this is visible everywhere.
+        */
+       pasid_state->invalid = true;
+
+       /* Make sure this is visible */
+       smp_wmb();
+
+       /* After this the device/pasid can't access the mm anymore */
        amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid);
-       clear_pasid_state(pasid_state->device_state, pasid_state->pasid);
 
        /* Make sure no more pending faults are in the queue */
        flush_workqueue(iommu_wq);
-
-       mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);
-
-       put_pasid_state(pasid_state); /* Reference taken in bind() function */
-}
-
-static void unbind_pasid(struct device_state *dev_state, int pasid)
-{
-       struct pasid_state *pasid_state;
-
-       pasid_state = get_pasid_state(dev_state, pasid);
-       if (pasid_state == NULL)
-               return;
-
-       __unbind_pasid(pasid_state);
-       put_pasid_state_wait(pasid_state); /* Reference taken in this function */
 }
 
 static void free_pasid_states_level1(struct pasid_state **tbl)
@@ -373,6 +365,12 @@ static void free_pasid_states(struct device_state *dev_state)
                 * unbind the PASID
                 */
                mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);
+
+               put_pasid_state_wait(pasid_state); /* Reference taken in
+                                                     amd_iommu_bind_pasid */
+
+               /* Drop reference taken in amd_iommu_bind_pasid */
+               put_device_state(dev_state);
        }
 
        if (dev_state->pasid_levels == 2)
@@ -411,14 +409,6 @@ static int mn_clear_flush_young(struct mmu_notifier *mn,
        return 0;
 }
 
-static void mn_change_pte(struct mmu_notifier *mn,
-                         struct mm_struct *mm,
-                         unsigned long address,
-                         pte_t pte)
-{
-       __mn_flush_page(mn, address);
-}
-
 static void mn_invalidate_page(struct mmu_notifier *mn,
                               struct mm_struct *mm,
                               unsigned long address)
@@ -472,22 +462,23 @@ static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm)
 {
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
+       bool run_inv_ctx_cb;
 
        might_sleep();
 
-       pasid_state = mn_to_state(mn);
-       dev_state   = pasid_state->device_state;
+       pasid_state    = mn_to_state(mn);
+       dev_state      = pasid_state->device_state;
+       run_inv_ctx_cb = !pasid_state->invalid;
 
-       if (pasid_state->device_state->inv_ctx_cb)
+       if (run_inv_ctx_cb && pasid_state->device_state->inv_ctx_cb)
                dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid);
 
-       unbind_pasid(dev_state, pasid_state->pasid);
+       unbind_pasid(pasid_state);
 }
 
 static struct mmu_notifier_ops iommu_mn = {
        .release                = mn_release,
        .clear_flush_young      = mn_clear_flush_young,
-       .change_pte             = mn_change_pte,
        .invalidate_page        = mn_invalidate_page,
        .invalidate_range_start = mn_invalidate_range_start,
        .invalidate_range_end   = mn_invalidate_range_end,
@@ -529,7 +520,7 @@ static void do_fault(struct work_struct *work)
        write = !!(fault->flags & PPR_FAULT_WRITE);
 
        down_read(&fault->state->mm->mmap_sem);
-       npages = get_user_pages(fault->state->task, fault->state->mm,
+       npages = get_user_pages(NULL, fault->state->mm,
                                fault->address, 1, write, 0, &page, NULL);
        up_read(&fault->state->mm->mmap_sem);
 
@@ -587,7 +578,7 @@ static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
                goto out;
 
        pasid_state = get_pasid_state(dev_state, iommu_fault->pasid);
-       if (pasid_state == NULL) {
+       if (pasid_state == NULL || pasid_state->invalid) {
                /* We know the device but not the PASID -> send INVALID */
                amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid,
                                       PPR_INVALID, tag);
@@ -612,6 +603,7 @@ static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
        fault->state     = pasid_state;
        fault->tag       = tag;
        fault->finish    = finish;
+       fault->pasid     = iommu_fault->pasid;
        fault->flags     = iommu_fault->flags;
        INIT_WORK(&fault->work, do_fault);
 
@@ -620,6 +612,10 @@ static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data)
        ret = NOTIFY_OK;
 
 out_drop_state:
+
+       if (ret != NOTIFY_OK && pasid_state)
+               put_pasid_state(pasid_state);
+
        put_device_state(dev_state);
 
 out:
@@ -635,6 +631,7 @@ int amd_iommu_bind_pasid(struct pci_dev *pdev, int pasid,
 {
        struct pasid_state *pasid_state;
        struct device_state *dev_state;
+       struct mm_struct *mm;
        u16 devid;
        int ret;
 
@@ -658,20 +655,23 @@ int amd_iommu_bind_pasid(struct pci_dev *pdev, int pasid,
        if (pasid_state == NULL)
                goto out;
 
+
        atomic_set(&pasid_state->count, 1);
        init_waitqueue_head(&pasid_state->wq);
        spin_lock_init(&pasid_state->lock);
 
-       pasid_state->task         = task;
-       pasid_state->mm           = get_task_mm(task);
+       mm                        = get_task_mm(task);
+       pasid_state->mm           = mm;
        pasid_state->device_state = dev_state;
        pasid_state->pasid        = pasid;
+       pasid_state->invalid      = true; /* Mark as valid only if we are
+                                            done with setting up the pasid */
        pasid_state->mn.ops       = &iommu_mn;
 
        if (pasid_state->mm == NULL)
                goto out_free;
 
-       mmu_notifier_register(&pasid_state->mn, pasid_state->mm);
+       mmu_notifier_register(&pasid_state->mn, mm);
 
        ret = set_pasid_state(dev_state, pasid_state, pasid);
        if (ret)
@@ -682,15 +682,26 @@ int amd_iommu_bind_pasid(struct pci_dev *pdev, int pasid,
        if (ret)
                goto out_clear_state;
 
+       /* Now we are ready to handle faults */
+       pasid_state->invalid = false;
+
+       /*
+        * Drop the reference to the mm_struct here. We rely on the
+        * mmu_notifier release call-back to inform us when the mm
+        * is going away.
+        */
+       mmput(mm);
+
        return 0;
 
 out_clear_state:
        clear_pasid_state(dev_state, pasid);
 
 out_unregister:
-       mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);
+       mmu_notifier_unregister(&pasid_state->mn, mm);
 
 out_free:
+       mmput(mm);
        free_pasid_state(pasid_state);
 
 out:
@@ -728,10 +739,22 @@ void amd_iommu_unbind_pasid(struct pci_dev *pdev, int pasid)
         */
        put_pasid_state(pasid_state);
 
-       /* This will call the mn_release function and unbind the PASID */
+       /* Clear the pasid state so that the pasid can be re-used */
+       clear_pasid_state(dev_state, pasid_state->pasid);
+
+       /*
+        * Call mmu_notifier_unregister to drop our reference
+        * to pasid_state->mm
+        */
        mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm);
 
+       put_pasid_state_wait(pasid_state); /* Reference taken in
+                                             amd_iommu_bind_pasid */
 out:
+       /* Drop reference taken in this function */
+       put_device_state(dev_state);
+
+       /* Drop reference taken in amd_iommu_bind_pasid */
        put_device_state(dev_state);
 }
 EXPORT_SYMBOL(amd_iommu_unbind_pasid);