]> asedeno.scripts.mit.edu Git - linux.git/blobdiff - drivers/vfio/vfio_iommu_type1.c
Merge branches 'pm-core', 'pm-qos', 'pm-domains' and 'pm-opp'
[linux.git] / drivers / vfio / vfio_iommu_type1.c
index f3726ba12aa6ceccff8a0c96523aa683537be8f4..bd6f293c4ebd59b5f283ebc1193767c982e1f1ef 100644 (file)
 #include <linux/uaccess.h>
 #include <linux/vfio.h>
 #include <linux/workqueue.h>
-#include <linux/pid_namespace.h>
 #include <linux/mdev.h>
 #include <linux/notifier.h>
+#include <linux/dma-iommu.h>
+#include <linux/irqdomain.h>
 
 #define DRIVER_VERSION  "0.2"
 #define DRIVER_AUTHOR   "Alex Williamson <alex.williamson@redhat.com>"
@@ -268,28 +269,38 @@ static void vfio_lock_acct(struct task_struct *task, long npage)
 {
        struct vwork *vwork;
        struct mm_struct *mm;
+       bool is_current;
 
        if (!npage)
                return;
 
-       mm = get_task_mm(task);
+       is_current = (task->mm == current->mm);
+
+       mm = is_current ? task->mm : get_task_mm(task);
        if (!mm)
-               return; /* process exited or nothing to do */
+               return; /* process exited */
 
        if (down_write_trylock(&mm->mmap_sem)) {
                mm->locked_vm += npage;
                up_write(&mm->mmap_sem);
-               mmput(mm);
+               if (!is_current)
+                       mmput(mm);
                return;
        }
 
+       if (is_current) {
+               mm = get_task_mm(task);
+               if (!mm)
+                       return;
+       }
+
        /*
         * Couldn't get mmap_sem lock, so must setup to update
         * mm->locked_vm later. If locked_vm were atomic, we
         * wouldn't need this silliness
         */
        vwork = kmalloc(sizeof(struct vwork), GFP_KERNEL);
-       if (!vwork) {
+       if (WARN_ON(!vwork)) {
                mmput(mm);
                return;
        }
@@ -393,77 +404,71 @@ static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
 static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
                                  long npage, unsigned long *pfn_base)
 {
-       unsigned long limit;
-       bool lock_cap = ns_capable(task_active_pid_ns(dma->task)->user_ns,
-                                  CAP_IPC_LOCK);
-       struct mm_struct *mm;
-       long ret, i = 0, lock_acct = 0;
+       unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
+       bool lock_cap = capable(CAP_IPC_LOCK);
+       long ret, pinned = 0, lock_acct = 0;
        bool rsvd;
        dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
 
-       mm = get_task_mm(dma->task);
-       if (!mm)
+       /* This code path is only user initiated */
+       if (!current->mm)
                return -ENODEV;
 
-       ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base);
+       ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, pfn_base);
        if (ret)
-               goto pin_pg_remote_exit;
+               return ret;
 
+       pinned++;
        rsvd = is_invalid_reserved_pfn(*pfn_base);
-       limit = task_rlimit(dma->task, RLIMIT_MEMLOCK) >> PAGE_SHIFT;
 
        /*
         * Reserved pages aren't counted against the user, externally pinned
         * pages are already counted against the user.
         */
        if (!rsvd && !vfio_find_vpfn(dma, iova)) {
-               if (!lock_cap && mm->locked_vm + 1 > limit) {
+               if (!lock_cap && current->mm->locked_vm + 1 > limit) {
                        put_pfn(*pfn_base, dma->prot);
                        pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__,
                                        limit << PAGE_SHIFT);
-                       ret = -ENOMEM;
-                       goto pin_pg_remote_exit;
+                       return -ENOMEM;
                }
                lock_acct++;
        }
 
-       i++;
-       if (likely(!disable_hugepages)) {
-               /* Lock all the consecutive pages from pfn_base */
-               for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; i < npage;
-                    i++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) {
-                       unsigned long pfn = 0;
+       if (unlikely(disable_hugepages))
+               goto out;
 
-                       ret = vaddr_get_pfn(mm, vaddr, dma->prot, &pfn);
-                       if (ret)
-                               break;
+       /* Lock all the consecutive pages from pfn_base */
+       for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; pinned < npage;
+            pinned++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) {
+               unsigned long pfn = 0;
+
+               ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, &pfn);
+               if (ret)
+                       break;
+
+               if (pfn != *pfn_base + pinned ||
+                   rsvd != is_invalid_reserved_pfn(pfn)) {
+                       put_pfn(pfn, dma->prot);
+                       break;
+               }
 
-                       if (pfn != *pfn_base + i ||
-                           rsvd != is_invalid_reserved_pfn(pfn)) {
+               if (!rsvd && !vfio_find_vpfn(dma, iova)) {
+                       if (!lock_cap &&
+                           current->mm->locked_vm + lock_acct + 1 > limit) {
                                put_pfn(pfn, dma->prot);
+                               pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
+                                       __func__, limit << PAGE_SHIFT);
                                break;
                        }
-
-                       if (!rsvd && !vfio_find_vpfn(dma, iova)) {
-                               if (!lock_cap &&
-                                   mm->locked_vm + lock_acct + 1 > limit) {
-                                       put_pfn(pfn, dma->prot);
-                                       pr_warn("%s: RLIMIT_MEMLOCK (%ld) "
-                                               "exceeded\n", __func__,
-                                               limit << PAGE_SHIFT);
-                                       break;
-                               }
-                               lock_acct++;
-                       }
+                       lock_acct++;
                }
        }
 
-       vfio_lock_acct(dma->task, lock_acct);
-       ret = i;
+out:
+       vfio_lock_acct(current, lock_acct);
 
-pin_pg_remote_exit:
-       mmput(mm);
-       return ret;
+       return pinned;
 }
 
 static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
@@ -473,10 +478,10 @@ static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
        long unlocked = 0, locked = 0;
        long i;
 
-       for (i = 0; i < npage; i++) {
+       for (i = 0; i < npage; i++, iova += PAGE_SIZE) {
                if (put_pfn(pfn++, dma->prot)) {
                        unlocked++;
-                       if (vfio_find_vpfn(dma, iova + (i << PAGE_SHIFT)))
+                       if (vfio_find_vpfn(dma, iova))
                                locked++;
                }
        }
@@ -491,8 +496,7 @@ static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
                                  unsigned long *pfn_base, bool do_accounting)
 {
        unsigned long limit;
-       bool lock_cap = ns_capable(task_active_pid_ns(dma->task)->user_ns,
-                                  CAP_IPC_LOCK);
+       bool lock_cap = has_capability(dma->task, CAP_IPC_LOCK);
        struct mm_struct *mm;
        int ret;
        bool rsvd;
@@ -1177,6 +1181,28 @@ static struct vfio_group *find_iommu_group(struct vfio_domain *domain,
        return NULL;
 }
 
+static bool vfio_iommu_has_resv_msi(struct iommu_group *group,
+                                   phys_addr_t *base)
+{
+       struct list_head group_resv_regions;
+       struct iommu_resv_region *region, *next;
+       bool ret = false;
+
+       INIT_LIST_HEAD(&group_resv_regions);
+       iommu_get_group_resv_regions(group, &group_resv_regions);
+       list_for_each_entry(region, &group_resv_regions, list) {
+               if (region->type & IOMMU_RESV_MSI) {
+                       *base = region->start;
+                       ret = true;
+                       goto out;
+               }
+       }
+out:
+       list_for_each_entry_safe(region, next, &group_resv_regions, list)
+               kfree(region);
+       return ret;
+}
+
 static int vfio_iommu_type1_attach_group(void *iommu_data,
                                         struct iommu_group *iommu_group)
 {
@@ -1185,6 +1211,8 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
        struct vfio_domain *domain, *d;
        struct bus_type *bus = NULL, *mdev_bus;
        int ret;
+       bool resv_msi, msi_remap;
+       phys_addr_t resv_msi_base;
 
        mutex_lock(&iommu->lock);
 
@@ -1254,11 +1282,15 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
        if (ret)
                goto out_domain;
 
+       resv_msi = vfio_iommu_has_resv_msi(iommu_group, &resv_msi_base);
+
        INIT_LIST_HEAD(&domain->group_list);
        list_add(&group->next, &domain->group_list);
 
-       if (!allow_unsafe_interrupts &&
-           !iommu_capable(bus, IOMMU_CAP_INTR_REMAP)) {
+       msi_remap = resv_msi ? irq_domain_check_msi_remap() :
+                               iommu_capable(bus, IOMMU_CAP_INTR_REMAP);
+
+       if (!allow_unsafe_interrupts && !msi_remap) {
                pr_warn("%s: No interrupt remapping support.  Use the module param \"allow_unsafe_interrupts\" to enable VFIO IOMMU support on this platform\n",
                       __func__);
                ret = -EPERM;
@@ -1300,6 +1332,12 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
        if (ret)
                goto out_detach;
 
+       if (resv_msi) {
+               ret = iommu_get_msi_cookie(domain->domain, resv_msi_base);
+               if (ret)
+                       goto out_detach;
+       }
+
        list_add(&domain->next, &iommu->domain_list);
 
        mutex_unlock(&iommu->lock);