diff options
Diffstat (limited to 'drivers/iommu/amd_iommu_v2.c')
-rw-r--r-- | drivers/iommu/amd_iommu_v2.c | 107 |
1 files changed, 65 insertions, 42 deletions
diff --git a/drivers/iommu/amd_iommu_v2.c b/drivers/iommu/amd_iommu_v2.c index 499b4366a98d..5f578e850fc5 100644 --- a/drivers/iommu/amd_iommu_v2.c +++ b/drivers/iommu/amd_iommu_v2.c @@ -47,12 +47,13 @@ struct pasid_state { atomic_t count; /* Reference count */ unsigned mmu_notifier_count; /* Counting nested mmu_notifier calls */ - struct task_struct *task; /* Task bound to this PASID */ struct mm_struct *mm; /* mm_struct for the faults */ - struct mmu_notifier mn; /* mmu_otifier handle */ + struct mmu_notifier mn; /* mmu_notifier handle */ struct pri_queue pri[PRI_QUEUE_SIZE]; /* PRI tag states */ struct device_state *device_state; /* Link to our device_state */ int pasid; /* PASID index */ + bool invalid; /* Used during setup and + teardown of the pasid */ spinlock_t lock; /* Protect pri_queues and mmu_notifer_count */ wait_queue_head_t wq; /* To wait for count == 0 */ @@ -99,7 +100,6 @@ static struct workqueue_struct *iommu_wq; static u64 *empty_page_table; static void free_pasid_states(struct device_state *dev_state); -static void unbind_pasid(struct device_state *dev_state, int pasid); static u16 device_id(struct pci_dev *pdev) { @@ -297,37 +297,29 @@ static void put_pasid_state_wait(struct pasid_state *pasid_state) schedule(); finish_wait(&pasid_state->wq, &wait); - mmput(pasid_state->mm); free_pasid_state(pasid_state); } -static void __unbind_pasid(struct pasid_state *pasid_state) +static void unbind_pasid(struct pasid_state *pasid_state) { struct iommu_domain *domain; domain = pasid_state->device_state->domain; + /* + * Mark pasid_state as invalid, no more faults will we added to the + * work queue after this is visible everywhere. + */ + pasid_state->invalid = true; + + /* Make sure this is visible */ + smp_wmb(); + + /* After this the device/pasid can't access the mm anymore */ amd_iommu_domain_clear_gcr3(domain, pasid_state->pasid); - clear_pasid_state(pasid_state->device_state, pasid_state->pasid); /* Make sure no more pending faults are in the queue */ flush_workqueue(iommu_wq); - - mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm); - - put_pasid_state(pasid_state); /* Reference taken in bind() function */ -} - -static void unbind_pasid(struct device_state *dev_state, int pasid) -{ - struct pasid_state *pasid_state; - - pasid_state = get_pasid_state(dev_state, pasid); - if (pasid_state == NULL) - return; - - __unbind_pasid(pasid_state); - put_pasid_state_wait(pasid_state); /* Reference taken in this function */ } static void free_pasid_states_level1(struct pasid_state **tbl) @@ -373,6 +365,12 @@ static void free_pasid_states(struct device_state *dev_state) * unbind the PASID */ mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm); + + put_pasid_state_wait(pasid_state); /* Reference taken in + amd_iommu_bind_pasid */ + + /* Drop reference taken in amd_iommu_bind_pasid */ + put_device_state(dev_state); } if (dev_state->pasid_levels == 2) @@ -411,14 +409,6 @@ static int mn_clear_flush_young(struct mmu_notifier *mn, return 0; } -static void mn_change_pte(struct mmu_notifier *mn, - struct mm_struct *mm, - unsigned long address, - pte_t pte) -{ - __mn_flush_page(mn, address); -} - static void mn_invalidate_page(struct mmu_notifier *mn, struct mm_struct *mm, unsigned long address) @@ -472,22 +462,23 @@ static void mn_release(struct mmu_notifier *mn, struct mm_struct *mm) { struct pasid_state *pasid_state; struct device_state *dev_state; + bool run_inv_ctx_cb; might_sleep(); - pasid_state = mn_to_state(mn); - dev_state = pasid_state->device_state; + pasid_state = mn_to_state(mn); + dev_state = pasid_state->device_state; + run_inv_ctx_cb = !pasid_state->invalid; - if (pasid_state->device_state->inv_ctx_cb) + if (run_inv_ctx_cb && pasid_state->device_state->inv_ctx_cb) dev_state->inv_ctx_cb(dev_state->pdev, pasid_state->pasid); - unbind_pasid(dev_state, pasid_state->pasid); + unbind_pasid(pasid_state); } static struct mmu_notifier_ops iommu_mn = { .release = mn_release, .clear_flush_young = mn_clear_flush_young, - .change_pte = mn_change_pte, .invalidate_page = mn_invalidate_page, .invalidate_range_start = mn_invalidate_range_start, .invalidate_range_end = mn_invalidate_range_end, @@ -529,7 +520,7 @@ static void do_fault(struct work_struct *work) write = !!(fault->flags & PPR_FAULT_WRITE); down_read(&fault->state->mm->mmap_sem); - npages = get_user_pages(fault->state->task, fault->state->mm, + npages = get_user_pages(NULL, fault->state->mm, fault->address, 1, write, 0, &page, NULL); up_read(&fault->state->mm->mmap_sem); @@ -587,7 +578,7 @@ static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data) goto out; pasid_state = get_pasid_state(dev_state, iommu_fault->pasid); - if (pasid_state == NULL) { + if (pasid_state == NULL || pasid_state->invalid) { /* We know the device but not the PASID -> send INVALID */ amd_iommu_complete_ppr(dev_state->pdev, iommu_fault->pasid, PPR_INVALID, tag); @@ -612,6 +603,7 @@ static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data) fault->state = pasid_state; fault->tag = tag; fault->finish = finish; + fault->pasid = iommu_fault->pasid; fault->flags = iommu_fault->flags; INIT_WORK(&fault->work, do_fault); @@ -620,6 +612,10 @@ static int ppr_notifier(struct notifier_block *nb, unsigned long e, void *data) ret = NOTIFY_OK; out_drop_state: + + if (ret != NOTIFY_OK && pasid_state) + put_pasid_state(pasid_state); + put_device_state(dev_state); out: @@ -635,6 +631,7 @@ int amd_iommu_bind_pasid(struct pci_dev *pdev, int pasid, { struct pasid_state *pasid_state; struct device_state *dev_state; + struct mm_struct *mm; u16 devid; int ret; @@ -658,20 +655,23 @@ int amd_iommu_bind_pasid(struct pci_dev *pdev, int pasid, if (pasid_state == NULL) goto out; + atomic_set(&pasid_state->count, 1); init_waitqueue_head(&pasid_state->wq); spin_lock_init(&pasid_state->lock); - pasid_state->task = task; - pasid_state->mm = get_task_mm(task); + mm = get_task_mm(task); + pasid_state->mm = mm; pasid_state->device_state = dev_state; pasid_state->pasid = pasid; + pasid_state->invalid = true; /* Mark as valid only if we are + done with setting up the pasid */ pasid_state->mn.ops = &iommu_mn; if (pasid_state->mm == NULL) goto out_free; - mmu_notifier_register(&pasid_state->mn, pasid_state->mm); + mmu_notifier_register(&pasid_state->mn, mm); ret = set_pasid_state(dev_state, pasid_state, pasid); if (ret) @@ -682,15 +682,26 @@ int amd_iommu_bind_pasid(struct pci_dev *pdev, int pasid, if (ret) goto out_clear_state; + /* Now we are ready to handle faults */ + pasid_state->invalid = false; + + /* + * Drop the reference to the mm_struct here. We rely on the + * mmu_notifier release call-back to inform us when the mm + * is going away. + */ + mmput(mm); + return 0; out_clear_state: clear_pasid_state(dev_state, pasid); out_unregister: - mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm); + mmu_notifier_unregister(&pasid_state->mn, mm); out_free: + mmput(mm); free_pasid_state(pasid_state); out: @@ -728,10 +739,22 @@ void amd_iommu_unbind_pasid(struct pci_dev *pdev, int pasid) */ put_pasid_state(pasid_state); - /* This will call the mn_release function and unbind the PASID */ + /* Clear the pasid state so that the pasid can be re-used */ + clear_pasid_state(dev_state, pasid_state->pasid); + + /* + * Call mmu_notifier_unregister to drop our reference + * to pasid_state->mm + */ mmu_notifier_unregister(&pasid_state->mn, pasid_state->mm); + put_pasid_state_wait(pasid_state); /* Reference taken in + amd_iommu_bind_pasid */ out: + /* Drop reference taken in this function */ + put_device_state(dev_state); + + /* Drop reference taken in amd_iommu_bind_pasid */ put_device_state(dev_state); } EXPORT_SYMBOL(amd_iommu_unbind_pasid); |