]> asedeno.scripts.mit.edu Git - linux.git/blob - drivers/vfio/vfio_iommu_type1.c
Merge tag 'fsverity-for-linus' of git://git.kernel.org/pub/scm/fs/fscrypt/fscrypt
[linux.git] / drivers / vfio / vfio_iommu_type1.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * VFIO: IOMMU DMA mapping support for Type1 IOMMU
4  *
5  * Copyright (C) 2012 Red Hat, Inc.  All rights reserved.
6  *     Author: Alex Williamson <alex.williamson@redhat.com>
7  *
8  * Derived from original vfio:
9  * Copyright 2010 Cisco Systems, Inc.  All rights reserved.
10  * Author: Tom Lyon, pugs@cisco.com
11  *
12  * We arbitrarily define a Type1 IOMMU as one matching the below code.
13  * It could be called the x86 IOMMU as it's designed for AMD-Vi & Intel
14  * VT-d, but that makes it harder to re-use as theoretically anyone
15  * implementing a similar IOMMU could make use of this.  We expect the
16  * IOMMU to support the IOMMU API and have few to no restrictions around
17  * the IOVA range that can be mapped.  The Type1 IOMMU is currently
18  * optimized for relatively static mappings of a userspace process with
19  * userpsace pages pinned into memory.  We also assume devices and IOMMU
20  * domains are PCI based as the IOMMU API is still centered around a
21  * device/bus interface rather than a group interface.
22  */
23
24 #include <linux/compat.h>
25 #include <linux/device.h>
26 #include <linux/fs.h>
27 #include <linux/iommu.h>
28 #include <linux/module.h>
29 #include <linux/mm.h>
30 #include <linux/rbtree.h>
31 #include <linux/sched/signal.h>
32 #include <linux/sched/mm.h>
33 #include <linux/slab.h>
34 #include <linux/uaccess.h>
35 #include <linux/vfio.h>
36 #include <linux/workqueue.h>
37 #include <linux/mdev.h>
38 #include <linux/notifier.h>
39 #include <linux/dma-iommu.h>
40 #include <linux/irqdomain.h>
41
42 #define DRIVER_VERSION  "0.2"
43 #define DRIVER_AUTHOR   "Alex Williamson <alex.williamson@redhat.com>"
44 #define DRIVER_DESC     "Type1 IOMMU driver for VFIO"
45
46 static bool allow_unsafe_interrupts;
47 module_param_named(allow_unsafe_interrupts,
48                    allow_unsafe_interrupts, bool, S_IRUGO | S_IWUSR);
49 MODULE_PARM_DESC(allow_unsafe_interrupts,
50                  "Enable VFIO IOMMU support for on platforms without interrupt remapping support.");
51
52 static bool disable_hugepages;
53 module_param_named(disable_hugepages,
54                    disable_hugepages, bool, S_IRUGO | S_IWUSR);
55 MODULE_PARM_DESC(disable_hugepages,
56                  "Disable VFIO IOMMU support for IOMMU hugepages.");
57
58 static unsigned int dma_entry_limit __read_mostly = U16_MAX;
59 module_param_named(dma_entry_limit, dma_entry_limit, uint, 0644);
60 MODULE_PARM_DESC(dma_entry_limit,
61                  "Maximum number of user DMA mappings per container (65535).");
62
63 struct vfio_iommu {
64         struct list_head        domain_list;
65         struct vfio_domain      *external_domain; /* domain for external user */
66         struct mutex            lock;
67         struct rb_root          dma_list;
68         struct blocking_notifier_head notifier;
69         unsigned int            dma_avail;
70         bool                    v2;
71         bool                    nesting;
72 };
73
74 struct vfio_domain {
75         struct iommu_domain     *domain;
76         struct list_head        next;
77         struct list_head        group_list;
78         int                     prot;           /* IOMMU_CACHE */
79         bool                    fgsp;           /* Fine-grained super pages */
80 };
81
82 struct vfio_dma {
83         struct rb_node          node;
84         dma_addr_t              iova;           /* Device address */
85         unsigned long           vaddr;          /* Process virtual addr */
86         size_t                  size;           /* Map size (bytes) */
87         int                     prot;           /* IOMMU_READ/WRITE */
88         bool                    iommu_mapped;
89         bool                    lock_cap;       /* capable(CAP_IPC_LOCK) */
90         struct task_struct      *task;
91         struct rb_root          pfn_list;       /* Ex-user pinned pfn list */
92 };
93
94 struct vfio_group {
95         struct iommu_group      *iommu_group;
96         struct list_head        next;
97         bool                    mdev_group;     /* An mdev group */
98 };
99
100 /*
101  * Guest RAM pinning working set or DMA target
102  */
103 struct vfio_pfn {
104         struct rb_node          node;
105         dma_addr_t              iova;           /* Device address */
106         unsigned long           pfn;            /* Host pfn */
107         atomic_t                ref_count;
108 };
109
110 struct vfio_regions {
111         struct list_head list;
112         dma_addr_t iova;
113         phys_addr_t phys;
114         size_t len;
115 };
116
117 #define IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu) \
118                                         (!list_empty(&iommu->domain_list))
119
120 static int put_pfn(unsigned long pfn, int prot);
121
122 /*
123  * This code handles mapping and unmapping of user data buffers
124  * into DMA'ble space using the IOMMU
125  */
126
127 static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
128                                       dma_addr_t start, size_t size)
129 {
130         struct rb_node *node = iommu->dma_list.rb_node;
131
132         while (node) {
133                 struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
134
135                 if (start + size <= dma->iova)
136                         node = node->rb_left;
137                 else if (start >= dma->iova + dma->size)
138                         node = node->rb_right;
139                 else
140                         return dma;
141         }
142
143         return NULL;
144 }
145
146 static void vfio_link_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
147 {
148         struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
149         struct vfio_dma *dma;
150
151         while (*link) {
152                 parent = *link;
153                 dma = rb_entry(parent, struct vfio_dma, node);
154
155                 if (new->iova + new->size <= dma->iova)
156                         link = &(*link)->rb_left;
157                 else
158                         link = &(*link)->rb_right;
159         }
160
161         rb_link_node(&new->node, parent, link);
162         rb_insert_color(&new->node, &iommu->dma_list);
163 }
164
165 static void vfio_unlink_dma(struct vfio_iommu *iommu, struct vfio_dma *old)
166 {
167         rb_erase(&old->node, &iommu->dma_list);
168 }
169
170 /*
171  * Helper Functions for host iova-pfn list
172  */
173 static struct vfio_pfn *vfio_find_vpfn(struct vfio_dma *dma, dma_addr_t iova)
174 {
175         struct vfio_pfn *vpfn;
176         struct rb_node *node = dma->pfn_list.rb_node;
177
178         while (node) {
179                 vpfn = rb_entry(node, struct vfio_pfn, node);
180
181                 if (iova < vpfn->iova)
182                         node = node->rb_left;
183                 else if (iova > vpfn->iova)
184                         node = node->rb_right;
185                 else
186                         return vpfn;
187         }
188         return NULL;
189 }
190
191 static void vfio_link_pfn(struct vfio_dma *dma,
192                           struct vfio_pfn *new)
193 {
194         struct rb_node **link, *parent = NULL;
195         struct vfio_pfn *vpfn;
196
197         link = &dma->pfn_list.rb_node;
198         while (*link) {
199                 parent = *link;
200                 vpfn = rb_entry(parent, struct vfio_pfn, node);
201
202                 if (new->iova < vpfn->iova)
203                         link = &(*link)->rb_left;
204                 else
205                         link = &(*link)->rb_right;
206         }
207
208         rb_link_node(&new->node, parent, link);
209         rb_insert_color(&new->node, &dma->pfn_list);
210 }
211
212 static void vfio_unlink_pfn(struct vfio_dma *dma, struct vfio_pfn *old)
213 {
214         rb_erase(&old->node, &dma->pfn_list);
215 }
216
217 static int vfio_add_to_pfn_list(struct vfio_dma *dma, dma_addr_t iova,
218                                 unsigned long pfn)
219 {
220         struct vfio_pfn *vpfn;
221
222         vpfn = kzalloc(sizeof(*vpfn), GFP_KERNEL);
223         if (!vpfn)
224                 return -ENOMEM;
225
226         vpfn->iova = iova;
227         vpfn->pfn = pfn;
228         atomic_set(&vpfn->ref_count, 1);
229         vfio_link_pfn(dma, vpfn);
230         return 0;
231 }
232
233 static void vfio_remove_from_pfn_list(struct vfio_dma *dma,
234                                       struct vfio_pfn *vpfn)
235 {
236         vfio_unlink_pfn(dma, vpfn);
237         kfree(vpfn);
238 }
239
240 static struct vfio_pfn *vfio_iova_get_vfio_pfn(struct vfio_dma *dma,
241                                                unsigned long iova)
242 {
243         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
244
245         if (vpfn)
246                 atomic_inc(&vpfn->ref_count);
247         return vpfn;
248 }
249
250 static int vfio_iova_put_vfio_pfn(struct vfio_dma *dma, struct vfio_pfn *vpfn)
251 {
252         int ret = 0;
253
254         if (atomic_dec_and_test(&vpfn->ref_count)) {
255                 ret = put_pfn(vpfn->pfn, dma->prot);
256                 vfio_remove_from_pfn_list(dma, vpfn);
257         }
258         return ret;
259 }
260
261 static int vfio_lock_acct(struct vfio_dma *dma, long npage, bool async)
262 {
263         struct mm_struct *mm;
264         int ret;
265
266         if (!npage)
267                 return 0;
268
269         mm = async ? get_task_mm(dma->task) : dma->task->mm;
270         if (!mm)
271                 return -ESRCH; /* process exited */
272
273         ret = down_write_killable(&mm->mmap_sem);
274         if (!ret) {
275                 ret = __account_locked_vm(mm, abs(npage), npage > 0, dma->task,
276                                           dma->lock_cap);
277                 up_write(&mm->mmap_sem);
278         }
279
280         if (async)
281                 mmput(mm);
282
283         return ret;
284 }
285
286 /*
287  * Some mappings aren't backed by a struct page, for example an mmap'd
288  * MMIO range for our own or another device.  These use a different
289  * pfn conversion and shouldn't be tracked as locked pages.
290  */
291 static bool is_invalid_reserved_pfn(unsigned long pfn)
292 {
293         if (pfn_valid(pfn)) {
294                 bool reserved;
295                 struct page *tail = pfn_to_page(pfn);
296                 struct page *head = compound_head(tail);
297                 reserved = !!(PageReserved(head));
298                 if (head != tail) {
299                         /*
300                          * "head" is not a dangling pointer
301                          * (compound_head takes care of that)
302                          * but the hugepage may have been split
303                          * from under us (and we may not hold a
304                          * reference count on the head page so it can
305                          * be reused before we run PageReferenced), so
306                          * we've to check PageTail before returning
307                          * what we just read.
308                          */
309                         smp_rmb();
310                         if (PageTail(tail))
311                                 return reserved;
312                 }
313                 return PageReserved(tail);
314         }
315
316         return true;
317 }
318
319 static int put_pfn(unsigned long pfn, int prot)
320 {
321         if (!is_invalid_reserved_pfn(pfn)) {
322                 struct page *page = pfn_to_page(pfn);
323                 if (prot & IOMMU_WRITE)
324                         SetPageDirty(page);
325                 put_page(page);
326                 return 1;
327         }
328         return 0;
329 }
330
331 static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
332                          int prot, unsigned long *pfn)
333 {
334         struct page *page[1];
335         struct vm_area_struct *vma;
336         struct vm_area_struct *vmas[1];
337         unsigned int flags = 0;
338         int ret;
339
340         if (prot & IOMMU_WRITE)
341                 flags |= FOLL_WRITE;
342
343         down_read(&mm->mmap_sem);
344         if (mm == current->mm) {
345                 ret = get_user_pages(vaddr, 1, flags | FOLL_LONGTERM, page,
346                                      vmas);
347         } else {
348                 ret = get_user_pages_remote(NULL, mm, vaddr, 1, flags, page,
349                                             vmas, NULL);
350                 /*
351                  * The lifetime of a vaddr_get_pfn() page pin is
352                  * userspace-controlled. In the fs-dax case this could
353                  * lead to indefinite stalls in filesystem operations.
354                  * Disallow attempts to pin fs-dax pages via this
355                  * interface.
356                  */
357                 if (ret > 0 && vma_is_fsdax(vmas[0])) {
358                         ret = -EOPNOTSUPP;
359                         put_page(page[0]);
360                 }
361         }
362         up_read(&mm->mmap_sem);
363
364         if (ret == 1) {
365                 *pfn = page_to_pfn(page[0]);
366                 return 0;
367         }
368
369         down_read(&mm->mmap_sem);
370
371         vma = find_vma_intersection(mm, vaddr, vaddr + 1);
372
373         if (vma && vma->vm_flags & VM_PFNMAP) {
374                 *pfn = ((vaddr - vma->vm_start) >> PAGE_SHIFT) + vma->vm_pgoff;
375                 if (is_invalid_reserved_pfn(*pfn))
376                         ret = 0;
377         }
378
379         up_read(&mm->mmap_sem);
380         return ret;
381 }
382
383 /*
384  * Attempt to pin pages.  We really don't want to track all the pfns and
385  * the iommu can only map chunks of consecutive pfns anyway, so get the
386  * first page and all consecutive pages with the same locking.
387  */
388 static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
389                                   long npage, unsigned long *pfn_base,
390                                   unsigned long limit)
391 {
392         unsigned long pfn = 0;
393         long ret, pinned = 0, lock_acct = 0;
394         bool rsvd;
395         dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
396
397         /* This code path is only user initiated */
398         if (!current->mm)
399                 return -ENODEV;
400
401         ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, pfn_base);
402         if (ret)
403                 return ret;
404
405         pinned++;
406         rsvd = is_invalid_reserved_pfn(*pfn_base);
407
408         /*
409          * Reserved pages aren't counted against the user, externally pinned
410          * pages are already counted against the user.
411          */
412         if (!rsvd && !vfio_find_vpfn(dma, iova)) {
413                 if (!dma->lock_cap && current->mm->locked_vm + 1 > limit) {
414                         put_pfn(*pfn_base, dma->prot);
415                         pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__,
416                                         limit << PAGE_SHIFT);
417                         return -ENOMEM;
418                 }
419                 lock_acct++;
420         }
421
422         if (unlikely(disable_hugepages))
423                 goto out;
424
425         /* Lock all the consecutive pages from pfn_base */
426         for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; pinned < npage;
427              pinned++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) {
428                 ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, &pfn);
429                 if (ret)
430                         break;
431
432                 if (pfn != *pfn_base + pinned ||
433                     rsvd != is_invalid_reserved_pfn(pfn)) {
434                         put_pfn(pfn, dma->prot);
435                         break;
436                 }
437
438                 if (!rsvd && !vfio_find_vpfn(dma, iova)) {
439                         if (!dma->lock_cap &&
440                             current->mm->locked_vm + lock_acct + 1 > limit) {
441                                 put_pfn(pfn, dma->prot);
442                                 pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
443                                         __func__, limit << PAGE_SHIFT);
444                                 ret = -ENOMEM;
445                                 goto unpin_out;
446                         }
447                         lock_acct++;
448                 }
449         }
450
451 out:
452         ret = vfio_lock_acct(dma, lock_acct, false);
453
454 unpin_out:
455         if (ret) {
456                 if (!rsvd) {
457                         for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
458                                 put_pfn(pfn, dma->prot);
459                 }
460
461                 return ret;
462         }
463
464         return pinned;
465 }
466
467 static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
468                                     unsigned long pfn, long npage,
469                                     bool do_accounting)
470 {
471         long unlocked = 0, locked = 0;
472         long i;
473
474         for (i = 0; i < npage; i++, iova += PAGE_SIZE) {
475                 if (put_pfn(pfn++, dma->prot)) {
476                         unlocked++;
477                         if (vfio_find_vpfn(dma, iova))
478                                 locked++;
479                 }
480         }
481
482         if (do_accounting)
483                 vfio_lock_acct(dma, locked - unlocked, true);
484
485         return unlocked;
486 }
487
488 static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
489                                   unsigned long *pfn_base, bool do_accounting)
490 {
491         struct mm_struct *mm;
492         int ret;
493
494         mm = get_task_mm(dma->task);
495         if (!mm)
496                 return -ENODEV;
497
498         ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base);
499         if (!ret && do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
500                 ret = vfio_lock_acct(dma, 1, true);
501                 if (ret) {
502                         put_pfn(*pfn_base, dma->prot);
503                         if (ret == -ENOMEM)
504                                 pr_warn("%s: Task %s (%d) RLIMIT_MEMLOCK "
505                                         "(%ld) exceeded\n", __func__,
506                                         dma->task->comm, task_pid_nr(dma->task),
507                                         task_rlimit(dma->task, RLIMIT_MEMLOCK));
508                 }
509         }
510
511         mmput(mm);
512         return ret;
513 }
514
515 static int vfio_unpin_page_external(struct vfio_dma *dma, dma_addr_t iova,
516                                     bool do_accounting)
517 {
518         int unlocked;
519         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
520
521         if (!vpfn)
522                 return 0;
523
524         unlocked = vfio_iova_put_vfio_pfn(dma, vpfn);
525
526         if (do_accounting)
527                 vfio_lock_acct(dma, -unlocked, true);
528
529         return unlocked;
530 }
531
532 static int vfio_iommu_type1_pin_pages(void *iommu_data,
533                                       unsigned long *user_pfn,
534                                       int npage, int prot,
535                                       unsigned long *phys_pfn)
536 {
537         struct vfio_iommu *iommu = iommu_data;
538         int i, j, ret;
539         unsigned long remote_vaddr;
540         struct vfio_dma *dma;
541         bool do_accounting;
542
543         if (!iommu || !user_pfn || !phys_pfn)
544                 return -EINVAL;
545
546         /* Supported for v2 version only */
547         if (!iommu->v2)
548                 return -EACCES;
549
550         mutex_lock(&iommu->lock);
551
552         /* Fail if notifier list is empty */
553         if (!iommu->notifier.head) {
554                 ret = -EINVAL;
555                 goto pin_done;
556         }
557
558         /*
559          * If iommu capable domain exist in the container then all pages are
560          * already pinned and accounted. Accouting should be done if there is no
561          * iommu capable domain in the container.
562          */
563         do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
564
565         for (i = 0; i < npage; i++) {
566                 dma_addr_t iova;
567                 struct vfio_pfn *vpfn;
568
569                 iova = user_pfn[i] << PAGE_SHIFT;
570                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
571                 if (!dma) {
572                         ret = -EINVAL;
573                         goto pin_unwind;
574                 }
575
576                 if ((dma->prot & prot) != prot) {
577                         ret = -EPERM;
578                         goto pin_unwind;
579                 }
580
581                 vpfn = vfio_iova_get_vfio_pfn(dma, iova);
582                 if (vpfn) {
583                         phys_pfn[i] = vpfn->pfn;
584                         continue;
585                 }
586
587                 remote_vaddr = dma->vaddr + iova - dma->iova;
588                 ret = vfio_pin_page_external(dma, remote_vaddr, &phys_pfn[i],
589                                              do_accounting);
590                 if (ret)
591                         goto pin_unwind;
592
593                 ret = vfio_add_to_pfn_list(dma, iova, phys_pfn[i]);
594                 if (ret) {
595                         vfio_unpin_page_external(dma, iova, do_accounting);
596                         goto pin_unwind;
597                 }
598         }
599
600         ret = i;
601         goto pin_done;
602
603 pin_unwind:
604         phys_pfn[i] = 0;
605         for (j = 0; j < i; j++) {
606                 dma_addr_t iova;
607
608                 iova = user_pfn[j] << PAGE_SHIFT;
609                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
610                 vfio_unpin_page_external(dma, iova, do_accounting);
611                 phys_pfn[j] = 0;
612         }
613 pin_done:
614         mutex_unlock(&iommu->lock);
615         return ret;
616 }
617
618 static int vfio_iommu_type1_unpin_pages(void *iommu_data,
619                                         unsigned long *user_pfn,
620                                         int npage)
621 {
622         struct vfio_iommu *iommu = iommu_data;
623         bool do_accounting;
624         int i;
625
626         if (!iommu || !user_pfn)
627                 return -EINVAL;
628
629         /* Supported for v2 version only */
630         if (!iommu->v2)
631                 return -EACCES;
632
633         mutex_lock(&iommu->lock);
634
635         do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
636         for (i = 0; i < npage; i++) {
637                 struct vfio_dma *dma;
638                 dma_addr_t iova;
639
640                 iova = user_pfn[i] << PAGE_SHIFT;
641                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
642                 if (!dma)
643                         goto unpin_exit;
644                 vfio_unpin_page_external(dma, iova, do_accounting);
645         }
646
647 unpin_exit:
648         mutex_unlock(&iommu->lock);
649         return i > npage ? npage : (i > 0 ? i : -EINVAL);
650 }
651
652 static long vfio_sync_unpin(struct vfio_dma *dma, struct vfio_domain *domain,
653                             struct list_head *regions,
654                             struct iommu_iotlb_gather *iotlb_gather)
655 {
656         long unlocked = 0;
657         struct vfio_regions *entry, *next;
658
659         iommu_tlb_sync(domain->domain, iotlb_gather);
660
661         list_for_each_entry_safe(entry, next, regions, list) {
662                 unlocked += vfio_unpin_pages_remote(dma,
663                                                     entry->iova,
664                                                     entry->phys >> PAGE_SHIFT,
665                                                     entry->len >> PAGE_SHIFT,
666                                                     false);
667                 list_del(&entry->list);
668                 kfree(entry);
669         }
670
671         cond_resched();
672
673         return unlocked;
674 }
675
676 /*
677  * Generally, VFIO needs to unpin remote pages after each IOTLB flush.
678  * Therefore, when using IOTLB flush sync interface, VFIO need to keep track
679  * of these regions (currently using a list).
680  *
681  * This value specifies maximum number of regions for each IOTLB flush sync.
682  */
683 #define VFIO_IOMMU_TLB_SYNC_MAX         512
684
685 static size_t unmap_unpin_fast(struct vfio_domain *domain,
686                                struct vfio_dma *dma, dma_addr_t *iova,
687                                size_t len, phys_addr_t phys, long *unlocked,
688                                struct list_head *unmapped_list,
689                                int *unmapped_cnt,
690                                struct iommu_iotlb_gather *iotlb_gather)
691 {
692         size_t unmapped = 0;
693         struct vfio_regions *entry = kzalloc(sizeof(*entry), GFP_KERNEL);
694
695         if (entry) {
696                 unmapped = iommu_unmap_fast(domain->domain, *iova, len,
697                                             iotlb_gather);
698
699                 if (!unmapped) {
700                         kfree(entry);
701                 } else {
702                         entry->iova = *iova;
703                         entry->phys = phys;
704                         entry->len  = unmapped;
705                         list_add_tail(&entry->list, unmapped_list);
706
707                         *iova += unmapped;
708                         (*unmapped_cnt)++;
709                 }
710         }
711
712         /*
713          * Sync if the number of fast-unmap regions hits the limit
714          * or in case of errors.
715          */
716         if (*unmapped_cnt >= VFIO_IOMMU_TLB_SYNC_MAX || !unmapped) {
717                 *unlocked += vfio_sync_unpin(dma, domain, unmapped_list,
718                                              iotlb_gather);
719                 *unmapped_cnt = 0;
720         }
721
722         return unmapped;
723 }
724
725 static size_t unmap_unpin_slow(struct vfio_domain *domain,
726                                struct vfio_dma *dma, dma_addr_t *iova,
727                                size_t len, phys_addr_t phys,
728                                long *unlocked)
729 {
730         size_t unmapped = iommu_unmap(domain->domain, *iova, len);
731
732         if (unmapped) {
733                 *unlocked += vfio_unpin_pages_remote(dma, *iova,
734                                                      phys >> PAGE_SHIFT,
735                                                      unmapped >> PAGE_SHIFT,
736                                                      false);
737                 *iova += unmapped;
738                 cond_resched();
739         }
740         return unmapped;
741 }
742
743 static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
744                              bool do_accounting)
745 {
746         dma_addr_t iova = dma->iova, end = dma->iova + dma->size;
747         struct vfio_domain *domain, *d;
748         LIST_HEAD(unmapped_region_list);
749         struct iommu_iotlb_gather iotlb_gather;
750         int unmapped_region_cnt = 0;
751         long unlocked = 0;
752
753         if (!dma->size)
754                 return 0;
755
756         if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
757                 return 0;
758
759         /*
760          * We use the IOMMU to track the physical addresses, otherwise we'd
761          * need a much more complicated tracking system.  Unfortunately that
762          * means we need to use one of the iommu domains to figure out the
763          * pfns to unpin.  The rest need to be unmapped in advance so we have
764          * no iommu translations remaining when the pages are unpinned.
765          */
766         domain = d = list_first_entry(&iommu->domain_list,
767                                       struct vfio_domain, next);
768
769         list_for_each_entry_continue(d, &iommu->domain_list, next) {
770                 iommu_unmap(d->domain, dma->iova, dma->size);
771                 cond_resched();
772         }
773
774         iommu_iotlb_gather_init(&iotlb_gather);
775         while (iova < end) {
776                 size_t unmapped, len;
777                 phys_addr_t phys, next;
778
779                 phys = iommu_iova_to_phys(domain->domain, iova);
780                 if (WARN_ON(!phys)) {
781                         iova += PAGE_SIZE;
782                         continue;
783                 }
784
785                 /*
786                  * To optimize for fewer iommu_unmap() calls, each of which
787                  * may require hardware cache flushing, try to find the
788                  * largest contiguous physical memory chunk to unmap.
789                  */
790                 for (len = PAGE_SIZE;
791                      !domain->fgsp && iova + len < end; len += PAGE_SIZE) {
792                         next = iommu_iova_to_phys(domain->domain, iova + len);
793                         if (next != phys + len)
794                                 break;
795                 }
796
797                 /*
798                  * First, try to use fast unmap/unpin. In case of failure,
799                  * switch to slow unmap/unpin path.
800                  */
801                 unmapped = unmap_unpin_fast(domain, dma, &iova, len, phys,
802                                             &unlocked, &unmapped_region_list,
803                                             &unmapped_region_cnt,
804                                             &iotlb_gather);
805                 if (!unmapped) {
806                         unmapped = unmap_unpin_slow(domain, dma, &iova, len,
807                                                     phys, &unlocked);
808                         if (WARN_ON(!unmapped))
809                                 break;
810                 }
811         }
812
813         dma->iommu_mapped = false;
814
815         if (unmapped_region_cnt) {
816                 unlocked += vfio_sync_unpin(dma, domain, &unmapped_region_list,
817                                             &iotlb_gather);
818         }
819
820         if (do_accounting) {
821                 vfio_lock_acct(dma, -unlocked, true);
822                 return 0;
823         }
824         return unlocked;
825 }
826
827 static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
828 {
829         vfio_unmap_unpin(iommu, dma, true);
830         vfio_unlink_dma(iommu, dma);
831         put_task_struct(dma->task);
832         kfree(dma);
833         iommu->dma_avail++;
834 }
835
836 static unsigned long vfio_pgsize_bitmap(struct vfio_iommu *iommu)
837 {
838         struct vfio_domain *domain;
839         unsigned long bitmap = ULONG_MAX;
840
841         mutex_lock(&iommu->lock);
842         list_for_each_entry(domain, &iommu->domain_list, next)
843                 bitmap &= domain->domain->pgsize_bitmap;
844         mutex_unlock(&iommu->lock);
845
846         /*
847          * In case the IOMMU supports page sizes smaller than PAGE_SIZE
848          * we pretend PAGE_SIZE is supported and hide sub-PAGE_SIZE sizes.
849          * That way the user will be able to map/unmap buffers whose size/
850          * start address is aligned with PAGE_SIZE. Pinning code uses that
851          * granularity while iommu driver can use the sub-PAGE_SIZE size
852          * to map the buffer.
853          */
854         if (bitmap & ~PAGE_MASK) {
855                 bitmap &= PAGE_MASK;
856                 bitmap |= PAGE_SIZE;
857         }
858
859         return bitmap;
860 }
861
862 static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
863                              struct vfio_iommu_type1_dma_unmap *unmap)
864 {
865         uint64_t mask;
866         struct vfio_dma *dma, *dma_last = NULL;
867         size_t unmapped = 0;
868         int ret = 0, retries = 0;
869
870         mask = ((uint64_t)1 << __ffs(vfio_pgsize_bitmap(iommu))) - 1;
871
872         if (unmap->iova & mask)
873                 return -EINVAL;
874         if (!unmap->size || unmap->size & mask)
875                 return -EINVAL;
876         if (unmap->iova + unmap->size - 1 < unmap->iova ||
877             unmap->size > SIZE_MAX)
878                 return -EINVAL;
879
880         WARN_ON(mask & PAGE_MASK);
881 again:
882         mutex_lock(&iommu->lock);
883
884         /*
885          * vfio-iommu-type1 (v1) - User mappings were coalesced together to
886          * avoid tracking individual mappings.  This means that the granularity
887          * of the original mapping was lost and the user was allowed to attempt
888          * to unmap any range.  Depending on the contiguousness of physical
889          * memory and page sizes supported by the IOMMU, arbitrary unmaps may
890          * or may not have worked.  We only guaranteed unmap granularity
891          * matching the original mapping; even though it was untracked here,
892          * the original mappings are reflected in IOMMU mappings.  This
893          * resulted in a couple unusual behaviors.  First, if a range is not
894          * able to be unmapped, ex. a set of 4k pages that was mapped as a
895          * 2M hugepage into the IOMMU, the unmap ioctl returns success but with
896          * a zero sized unmap.  Also, if an unmap request overlaps the first
897          * address of a hugepage, the IOMMU will unmap the entire hugepage.
898          * This also returns success and the returned unmap size reflects the
899          * actual size unmapped.
900          *
901          * We attempt to maintain compatibility with this "v1" interface, but
902          * we take control out of the hands of the IOMMU.  Therefore, an unmap
903          * request offset from the beginning of the original mapping will
904          * return success with zero sized unmap.  And an unmap request covering
905          * the first iova of mapping will unmap the entire range.
906          *
907          * The v2 version of this interface intends to be more deterministic.
908          * Unmap requests must fully cover previous mappings.  Multiple
909          * mappings may still be unmaped by specifying large ranges, but there
910          * must not be any previous mappings bisected by the range.  An error
911          * will be returned if these conditions are not met.  The v2 interface
912          * will only return success and a size of zero if there were no
913          * mappings within the range.
914          */
915         if (iommu->v2) {
916                 dma = vfio_find_dma(iommu, unmap->iova, 1);
917                 if (dma && dma->iova != unmap->iova) {
918                         ret = -EINVAL;
919                         goto unlock;
920                 }
921                 dma = vfio_find_dma(iommu, unmap->iova + unmap->size - 1, 0);
922                 if (dma && dma->iova + dma->size != unmap->iova + unmap->size) {
923                         ret = -EINVAL;
924                         goto unlock;
925                 }
926         }
927
928         while ((dma = vfio_find_dma(iommu, unmap->iova, unmap->size))) {
929                 if (!iommu->v2 && unmap->iova > dma->iova)
930                         break;
931                 /*
932                  * Task with same address space who mapped this iova range is
933                  * allowed to unmap the iova range.
934                  */
935                 if (dma->task->mm != current->mm)
936                         break;
937
938                 if (!RB_EMPTY_ROOT(&dma->pfn_list)) {
939                         struct vfio_iommu_type1_dma_unmap nb_unmap;
940
941                         if (dma_last == dma) {
942                                 BUG_ON(++retries > 10);
943                         } else {
944                                 dma_last = dma;
945                                 retries = 0;
946                         }
947
948                         nb_unmap.iova = dma->iova;
949                         nb_unmap.size = dma->size;
950
951                         /*
952                          * Notify anyone (mdev vendor drivers) to invalidate and
953                          * unmap iovas within the range we're about to unmap.
954                          * Vendor drivers MUST unpin pages in response to an
955                          * invalidation.
956                          */
957                         mutex_unlock(&iommu->lock);
958                         blocking_notifier_call_chain(&iommu->notifier,
959                                                     VFIO_IOMMU_NOTIFY_DMA_UNMAP,
960                                                     &nb_unmap);
961                         goto again;
962                 }
963                 unmapped += dma->size;
964                 vfio_remove_dma(iommu, dma);
965         }
966
967 unlock:
968         mutex_unlock(&iommu->lock);
969
970         /* Report how much was unmapped */
971         unmap->size = unmapped;
972
973         return ret;
974 }
975
976 static int vfio_iommu_map(struct vfio_iommu *iommu, dma_addr_t iova,
977                           unsigned long pfn, long npage, int prot)
978 {
979         struct vfio_domain *d;
980         int ret;
981
982         list_for_each_entry(d, &iommu->domain_list, next) {
983                 ret = iommu_map(d->domain, iova, (phys_addr_t)pfn << PAGE_SHIFT,
984                                 npage << PAGE_SHIFT, prot | d->prot);
985                 if (ret)
986                         goto unwind;
987
988                 cond_resched();
989         }
990
991         return 0;
992
993 unwind:
994         list_for_each_entry_continue_reverse(d, &iommu->domain_list, next)
995                 iommu_unmap(d->domain, iova, npage << PAGE_SHIFT);
996
997         return ret;
998 }
999
1000 static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
1001                             size_t map_size)
1002 {
1003         dma_addr_t iova = dma->iova;
1004         unsigned long vaddr = dma->vaddr;
1005         size_t size = map_size;
1006         long npage;
1007         unsigned long pfn, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1008         int ret = 0;
1009
1010         while (size) {
1011                 /* Pin a contiguous chunk of memory */
1012                 npage = vfio_pin_pages_remote(dma, vaddr + dma->size,
1013                                               size >> PAGE_SHIFT, &pfn, limit);
1014                 if (npage <= 0) {
1015                         WARN_ON(!npage);
1016                         ret = (int)npage;
1017                         break;
1018                 }
1019
1020                 /* Map it! */
1021                 ret = vfio_iommu_map(iommu, iova + dma->size, pfn, npage,
1022                                      dma->prot);
1023                 if (ret) {
1024                         vfio_unpin_pages_remote(dma, iova + dma->size, pfn,
1025                                                 npage, true);
1026                         break;
1027                 }
1028
1029                 size -= npage << PAGE_SHIFT;
1030                 dma->size += npage << PAGE_SHIFT;
1031         }
1032
1033         dma->iommu_mapped = true;
1034
1035         if (ret)
1036                 vfio_remove_dma(iommu, dma);
1037
1038         return ret;
1039 }
1040
1041 static int vfio_dma_do_map(struct vfio_iommu *iommu,
1042                            struct vfio_iommu_type1_dma_map *map)
1043 {
1044         dma_addr_t iova = map->iova;
1045         unsigned long vaddr = map->vaddr;
1046         size_t size = map->size;
1047         int ret = 0, prot = 0;
1048         uint64_t mask;
1049         struct vfio_dma *dma;
1050
1051         /* Verify that none of our __u64 fields overflow */
1052         if (map->size != size || map->vaddr != vaddr || map->iova != iova)
1053                 return -EINVAL;
1054
1055         mask = ((uint64_t)1 << __ffs(vfio_pgsize_bitmap(iommu))) - 1;
1056
1057         WARN_ON(mask & PAGE_MASK);
1058
1059         /* READ/WRITE from device perspective */
1060         if (map->flags & VFIO_DMA_MAP_FLAG_WRITE)
1061                 prot |= IOMMU_WRITE;
1062         if (map->flags & VFIO_DMA_MAP_FLAG_READ)
1063                 prot |= IOMMU_READ;
1064
1065         if (!prot || !size || (size | iova | vaddr) & mask)
1066                 return -EINVAL;
1067
1068         /* Don't allow IOVA or virtual address wrap */
1069         if (iova + size - 1 < iova || vaddr + size - 1 < vaddr)
1070                 return -EINVAL;
1071
1072         mutex_lock(&iommu->lock);
1073
1074         if (vfio_find_dma(iommu, iova, size)) {
1075                 ret = -EEXIST;
1076                 goto out_unlock;
1077         }
1078
1079         if (!iommu->dma_avail) {
1080                 ret = -ENOSPC;
1081                 goto out_unlock;
1082         }
1083
1084         dma = kzalloc(sizeof(*dma), GFP_KERNEL);
1085         if (!dma) {
1086                 ret = -ENOMEM;
1087                 goto out_unlock;
1088         }
1089
1090         iommu->dma_avail--;
1091         dma->iova = iova;
1092         dma->vaddr = vaddr;
1093         dma->prot = prot;
1094
1095         /*
1096          * We need to be able to both add to a task's locked memory and test
1097          * against the locked memory limit and we need to be able to do both
1098          * outside of this call path as pinning can be asynchronous via the
1099          * external interfaces for mdev devices.  RLIMIT_MEMLOCK requires a
1100          * task_struct and VM locked pages requires an mm_struct, however
1101          * holding an indefinite mm reference is not recommended, therefore we
1102          * only hold a reference to a task.  We could hold a reference to
1103          * current, however QEMU uses this call path through vCPU threads,
1104          * which can be killed resulting in a NULL mm and failure in the unmap
1105          * path when called via a different thread.  Avoid this problem by
1106          * using the group_leader as threads within the same group require
1107          * both CLONE_THREAD and CLONE_VM and will therefore use the same
1108          * mm_struct.
1109          *
1110          * Previously we also used the task for testing CAP_IPC_LOCK at the
1111          * time of pinning and accounting, however has_capability() makes use
1112          * of real_cred, a copy-on-write field, so we can't guarantee that it
1113          * matches group_leader, or in fact that it might not change by the
1114          * time it's evaluated.  If a process were to call MAP_DMA with
1115          * CAP_IPC_LOCK but later drop it, it doesn't make sense that they
1116          * possibly see different results for an iommu_mapped vfio_dma vs
1117          * externally mapped.  Therefore track CAP_IPC_LOCK in vfio_dma at the
1118          * time of calling MAP_DMA.
1119          */
1120         get_task_struct(current->group_leader);
1121         dma->task = current->group_leader;
1122         dma->lock_cap = capable(CAP_IPC_LOCK);
1123
1124         dma->pfn_list = RB_ROOT;
1125
1126         /* Insert zero-sized and grow as we map chunks of it */
1127         vfio_link_dma(iommu, dma);
1128
1129         /* Don't pin and map if container doesn't contain IOMMU capable domain*/
1130         if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1131                 dma->size = size;
1132         else
1133                 ret = vfio_pin_map_dma(iommu, dma, size);
1134
1135 out_unlock:
1136         mutex_unlock(&iommu->lock);
1137         return ret;
1138 }
1139
1140 static int vfio_bus_type(struct device *dev, void *data)
1141 {
1142         struct bus_type **bus = data;
1143
1144         if (*bus && *bus != dev->bus)
1145                 return -EINVAL;
1146
1147         *bus = dev->bus;
1148
1149         return 0;
1150 }
1151
1152 static int vfio_iommu_replay(struct vfio_iommu *iommu,
1153                              struct vfio_domain *domain)
1154 {
1155         struct vfio_domain *d;
1156         struct rb_node *n;
1157         unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1158         int ret;
1159
1160         /* Arbitrarily pick the first domain in the list for lookups */
1161         d = list_first_entry(&iommu->domain_list, struct vfio_domain, next);
1162         n = rb_first(&iommu->dma_list);
1163
1164         for (; n; n = rb_next(n)) {
1165                 struct vfio_dma *dma;
1166                 dma_addr_t iova;
1167
1168                 dma = rb_entry(n, struct vfio_dma, node);
1169                 iova = dma->iova;
1170
1171                 while (iova < dma->iova + dma->size) {
1172                         phys_addr_t phys;
1173                         size_t size;
1174
1175                         if (dma->iommu_mapped) {
1176                                 phys_addr_t p;
1177                                 dma_addr_t i;
1178
1179                                 phys = iommu_iova_to_phys(d->domain, iova);
1180
1181                                 if (WARN_ON(!phys)) {
1182                                         iova += PAGE_SIZE;
1183                                         continue;
1184                                 }
1185
1186                                 size = PAGE_SIZE;
1187                                 p = phys + size;
1188                                 i = iova + size;
1189                                 while (i < dma->iova + dma->size &&
1190                                        p == iommu_iova_to_phys(d->domain, i)) {
1191                                         size += PAGE_SIZE;
1192                                         p += PAGE_SIZE;
1193                                         i += PAGE_SIZE;
1194                                 }
1195                         } else {
1196                                 unsigned long pfn;
1197                                 unsigned long vaddr = dma->vaddr +
1198                                                      (iova - dma->iova);
1199                                 size_t n = dma->iova + dma->size - iova;
1200                                 long npage;
1201
1202                                 npage = vfio_pin_pages_remote(dma, vaddr,
1203                                                               n >> PAGE_SHIFT,
1204                                                               &pfn, limit);
1205                                 if (npage <= 0) {
1206                                         WARN_ON(!npage);
1207                                         ret = (int)npage;
1208                                         return ret;
1209                                 }
1210
1211                                 phys = pfn << PAGE_SHIFT;
1212                                 size = npage << PAGE_SHIFT;
1213                         }
1214
1215                         ret = iommu_map(domain->domain, iova, phys,
1216                                         size, dma->prot | domain->prot);
1217                         if (ret)
1218                                 return ret;
1219
1220                         iova += size;
1221                 }
1222                 dma->iommu_mapped = true;
1223         }
1224         return 0;
1225 }
1226
1227 /*
1228  * We change our unmap behavior slightly depending on whether the IOMMU
1229  * supports fine-grained superpages.  IOMMUs like AMD-Vi will use a superpage
1230  * for practically any contiguous power-of-two mapping we give it.  This means
1231  * we don't need to look for contiguous chunks ourselves to make unmapping
1232  * more efficient.  On IOMMUs with coarse-grained super pages, like Intel VT-d
1233  * with discrete 2M/1G/512G/1T superpages, identifying contiguous chunks
1234  * significantly boosts non-hugetlbfs mappings and doesn't seem to hurt when
1235  * hugetlbfs is in use.
1236  */
1237 static void vfio_test_domain_fgsp(struct vfio_domain *domain)
1238 {
1239         struct page *pages;
1240         int ret, order = get_order(PAGE_SIZE * 2);
1241
1242         pages = alloc_pages(GFP_KERNEL | __GFP_ZERO, order);
1243         if (!pages)
1244                 return;
1245
1246         ret = iommu_map(domain->domain, 0, page_to_phys(pages), PAGE_SIZE * 2,
1247                         IOMMU_READ | IOMMU_WRITE | domain->prot);
1248         if (!ret) {
1249                 size_t unmapped = iommu_unmap(domain->domain, 0, PAGE_SIZE);
1250
1251                 if (unmapped == PAGE_SIZE)
1252                         iommu_unmap(domain->domain, PAGE_SIZE, PAGE_SIZE);
1253                 else
1254                         domain->fgsp = true;
1255         }
1256
1257         __free_pages(pages, order);
1258 }
1259
1260 static struct vfio_group *find_iommu_group(struct vfio_domain *domain,
1261                                            struct iommu_group *iommu_group)
1262 {
1263         struct vfio_group *g;
1264
1265         list_for_each_entry(g, &domain->group_list, next) {
1266                 if (g->iommu_group == iommu_group)
1267                         return g;
1268         }
1269
1270         return NULL;
1271 }
1272
1273 static bool vfio_iommu_has_sw_msi(struct iommu_group *group, phys_addr_t *base)
1274 {
1275         struct list_head group_resv_regions;
1276         struct iommu_resv_region *region, *next;
1277         bool ret = false;
1278
1279         INIT_LIST_HEAD(&group_resv_regions);
1280         iommu_get_group_resv_regions(group, &group_resv_regions);
1281         list_for_each_entry(region, &group_resv_regions, list) {
1282                 /*
1283                  * The presence of any 'real' MSI regions should take
1284                  * precedence over the software-managed one if the
1285                  * IOMMU driver happens to advertise both types.
1286                  */
1287                 if (region->type == IOMMU_RESV_MSI) {
1288                         ret = false;
1289                         break;
1290                 }
1291
1292                 if (region->type == IOMMU_RESV_SW_MSI) {
1293                         *base = region->start;
1294                         ret = true;
1295                 }
1296         }
1297         list_for_each_entry_safe(region, next, &group_resv_regions, list)
1298                 kfree(region);
1299         return ret;
1300 }
1301
1302 static struct device *vfio_mdev_get_iommu_device(struct device *dev)
1303 {
1304         struct device *(*fn)(struct device *dev);
1305         struct device *iommu_device;
1306
1307         fn = symbol_get(mdev_get_iommu_device);
1308         if (fn) {
1309                 iommu_device = fn(dev);
1310                 symbol_put(mdev_get_iommu_device);
1311
1312                 return iommu_device;
1313         }
1314
1315         return NULL;
1316 }
1317
1318 static int vfio_mdev_attach_domain(struct device *dev, void *data)
1319 {
1320         struct iommu_domain *domain = data;
1321         struct device *iommu_device;
1322
1323         iommu_device = vfio_mdev_get_iommu_device(dev);
1324         if (iommu_device) {
1325                 if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
1326                         return iommu_aux_attach_device(domain, iommu_device);
1327                 else
1328                         return iommu_attach_device(domain, iommu_device);
1329         }
1330
1331         return -EINVAL;
1332 }
1333
1334 static int vfio_mdev_detach_domain(struct device *dev, void *data)
1335 {
1336         struct iommu_domain *domain = data;
1337         struct device *iommu_device;
1338
1339         iommu_device = vfio_mdev_get_iommu_device(dev);
1340         if (iommu_device) {
1341                 if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
1342                         iommu_aux_detach_device(domain, iommu_device);
1343                 else
1344                         iommu_detach_device(domain, iommu_device);
1345         }
1346
1347         return 0;
1348 }
1349
1350 static int vfio_iommu_attach_group(struct vfio_domain *domain,
1351                                    struct vfio_group *group)
1352 {
1353         if (group->mdev_group)
1354                 return iommu_group_for_each_dev(group->iommu_group,
1355                                                 domain->domain,
1356                                                 vfio_mdev_attach_domain);
1357         else
1358                 return iommu_attach_group(domain->domain, group->iommu_group);
1359 }
1360
1361 static void vfio_iommu_detach_group(struct vfio_domain *domain,
1362                                     struct vfio_group *group)
1363 {
1364         if (group->mdev_group)
1365                 iommu_group_for_each_dev(group->iommu_group, domain->domain,
1366                                          vfio_mdev_detach_domain);
1367         else
1368                 iommu_detach_group(domain->domain, group->iommu_group);
1369 }
1370
1371 static bool vfio_bus_is_mdev(struct bus_type *bus)
1372 {
1373         struct bus_type *mdev_bus;
1374         bool ret = false;
1375
1376         mdev_bus = symbol_get(mdev_bus_type);
1377         if (mdev_bus) {
1378                 ret = (bus == mdev_bus);
1379                 symbol_put(mdev_bus_type);
1380         }
1381
1382         return ret;
1383 }
1384
1385 static int vfio_mdev_iommu_device(struct device *dev, void *data)
1386 {
1387         struct device **old = data, *new;
1388
1389         new = vfio_mdev_get_iommu_device(dev);
1390         if (!new || (*old && *old != new))
1391                 return -EINVAL;
1392
1393         *old = new;
1394
1395         return 0;
1396 }
1397
1398 static int vfio_iommu_type1_attach_group(void *iommu_data,
1399                                          struct iommu_group *iommu_group)
1400 {
1401         struct vfio_iommu *iommu = iommu_data;
1402         struct vfio_group *group;
1403         struct vfio_domain *domain, *d;
1404         struct bus_type *bus = NULL;
1405         int ret;
1406         bool resv_msi, msi_remap;
1407         phys_addr_t resv_msi_base;
1408
1409         mutex_lock(&iommu->lock);
1410
1411         list_for_each_entry(d, &iommu->domain_list, next) {
1412                 if (find_iommu_group(d, iommu_group)) {
1413                         mutex_unlock(&iommu->lock);
1414                         return -EINVAL;
1415                 }
1416         }
1417
1418         if (iommu->external_domain) {
1419                 if (find_iommu_group(iommu->external_domain, iommu_group)) {
1420                         mutex_unlock(&iommu->lock);
1421                         return -EINVAL;
1422                 }
1423         }
1424
1425         group = kzalloc(sizeof(*group), GFP_KERNEL);
1426         domain = kzalloc(sizeof(*domain), GFP_KERNEL);
1427         if (!group || !domain) {
1428                 ret = -ENOMEM;
1429                 goto out_free;
1430         }
1431
1432         group->iommu_group = iommu_group;
1433
1434         /* Determine bus_type in order to allocate a domain */
1435         ret = iommu_group_for_each_dev(iommu_group, &bus, vfio_bus_type);
1436         if (ret)
1437                 goto out_free;
1438
1439         if (vfio_bus_is_mdev(bus)) {
1440                 struct device *iommu_device = NULL;
1441
1442                 group->mdev_group = true;
1443
1444                 /* Determine the isolation type */
1445                 ret = iommu_group_for_each_dev(iommu_group, &iommu_device,
1446                                                vfio_mdev_iommu_device);
1447                 if (ret || !iommu_device) {
1448                         if (!iommu->external_domain) {
1449                                 INIT_LIST_HEAD(&domain->group_list);
1450                                 iommu->external_domain = domain;
1451                         } else {
1452                                 kfree(domain);
1453                         }
1454
1455                         list_add(&group->next,
1456                                  &iommu->external_domain->group_list);
1457                         mutex_unlock(&iommu->lock);
1458
1459                         return 0;
1460                 }
1461
1462                 bus = iommu_device->bus;
1463         }
1464
1465         domain->domain = iommu_domain_alloc(bus);
1466         if (!domain->domain) {
1467                 ret = -EIO;
1468                 goto out_free;
1469         }
1470
1471         if (iommu->nesting) {
1472                 int attr = 1;
1473
1474                 ret = iommu_domain_set_attr(domain->domain, DOMAIN_ATTR_NESTING,
1475                                             &attr);
1476                 if (ret)
1477                         goto out_domain;
1478         }
1479
1480         ret = vfio_iommu_attach_group(domain, group);
1481         if (ret)
1482                 goto out_domain;
1483
1484         resv_msi = vfio_iommu_has_sw_msi(iommu_group, &resv_msi_base);
1485
1486         INIT_LIST_HEAD(&domain->group_list);
1487         list_add(&group->next, &domain->group_list);
1488
1489         msi_remap = irq_domain_check_msi_remap() ||
1490                     iommu_capable(bus, IOMMU_CAP_INTR_REMAP);
1491
1492         if (!allow_unsafe_interrupts && !msi_remap) {
1493                 pr_warn("%s: No interrupt remapping support.  Use the module param \"allow_unsafe_interrupts\" to enable VFIO IOMMU support on this platform\n",
1494                        __func__);
1495                 ret = -EPERM;
1496                 goto out_detach;
1497         }
1498
1499         if (iommu_capable(bus, IOMMU_CAP_CACHE_COHERENCY))
1500                 domain->prot |= IOMMU_CACHE;
1501
1502         /*
1503          * Try to match an existing compatible domain.  We don't want to
1504          * preclude an IOMMU driver supporting multiple bus_types and being
1505          * able to include different bus_types in the same IOMMU domain, so
1506          * we test whether the domains use the same iommu_ops rather than
1507          * testing if they're on the same bus_type.
1508          */
1509         list_for_each_entry(d, &iommu->domain_list, next) {
1510                 if (d->domain->ops == domain->domain->ops &&
1511                     d->prot == domain->prot) {
1512                         vfio_iommu_detach_group(domain, group);
1513                         if (!vfio_iommu_attach_group(d, group)) {
1514                                 list_add(&group->next, &d->group_list);
1515                                 iommu_domain_free(domain->domain);
1516                                 kfree(domain);
1517                                 mutex_unlock(&iommu->lock);
1518                                 return 0;
1519                         }
1520
1521                         ret = vfio_iommu_attach_group(domain, group);
1522                         if (ret)
1523                                 goto out_domain;
1524                 }
1525         }
1526
1527         vfio_test_domain_fgsp(domain);
1528
1529         /* replay mappings on new domains */
1530         ret = vfio_iommu_replay(iommu, domain);
1531         if (ret)
1532                 goto out_detach;
1533
1534         if (resv_msi) {
1535                 ret = iommu_get_msi_cookie(domain->domain, resv_msi_base);
1536                 if (ret)
1537                         goto out_detach;
1538         }
1539
1540         list_add(&domain->next, &iommu->domain_list);
1541
1542         mutex_unlock(&iommu->lock);
1543
1544         return 0;
1545
1546 out_detach:
1547         vfio_iommu_detach_group(domain, group);
1548 out_domain:
1549         iommu_domain_free(domain->domain);
1550 out_free:
1551         kfree(domain);
1552         kfree(group);
1553         mutex_unlock(&iommu->lock);
1554         return ret;
1555 }
1556
1557 static void vfio_iommu_unmap_unpin_all(struct vfio_iommu *iommu)
1558 {
1559         struct rb_node *node;
1560
1561         while ((node = rb_first(&iommu->dma_list)))
1562                 vfio_remove_dma(iommu, rb_entry(node, struct vfio_dma, node));
1563 }
1564
1565 static void vfio_iommu_unmap_unpin_reaccount(struct vfio_iommu *iommu)
1566 {
1567         struct rb_node *n, *p;
1568
1569         n = rb_first(&iommu->dma_list);
1570         for (; n; n = rb_next(n)) {
1571                 struct vfio_dma *dma;
1572                 long locked = 0, unlocked = 0;
1573
1574                 dma = rb_entry(n, struct vfio_dma, node);
1575                 unlocked += vfio_unmap_unpin(iommu, dma, false);
1576                 p = rb_first(&dma->pfn_list);
1577                 for (; p; p = rb_next(p)) {
1578                         struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn,
1579                                                          node);
1580
1581                         if (!is_invalid_reserved_pfn(vpfn->pfn))
1582                                 locked++;
1583                 }
1584                 vfio_lock_acct(dma, locked - unlocked, true);
1585         }
1586 }
1587
1588 static void vfio_sanity_check_pfn_list(struct vfio_iommu *iommu)
1589 {
1590         struct rb_node *n;
1591
1592         n = rb_first(&iommu->dma_list);
1593         for (; n; n = rb_next(n)) {
1594                 struct vfio_dma *dma;
1595
1596                 dma = rb_entry(n, struct vfio_dma, node);
1597
1598                 if (WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list)))
1599                         break;
1600         }
1601         /* mdev vendor driver must unregister notifier */
1602         WARN_ON(iommu->notifier.head);
1603 }
1604
1605 static void vfio_iommu_type1_detach_group(void *iommu_data,
1606                                           struct iommu_group *iommu_group)
1607 {
1608         struct vfio_iommu *iommu = iommu_data;
1609         struct vfio_domain *domain;
1610         struct vfio_group *group;
1611
1612         mutex_lock(&iommu->lock);
1613
1614         if (iommu->external_domain) {
1615                 group = find_iommu_group(iommu->external_domain, iommu_group);
1616                 if (group) {
1617                         list_del(&group->next);
1618                         kfree(group);
1619
1620                         if (list_empty(&iommu->external_domain->group_list)) {
1621                                 vfio_sanity_check_pfn_list(iommu);
1622
1623                                 if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1624                                         vfio_iommu_unmap_unpin_all(iommu);
1625
1626                                 kfree(iommu->external_domain);
1627                                 iommu->external_domain = NULL;
1628                         }
1629                         goto detach_group_done;
1630                 }
1631         }
1632
1633         list_for_each_entry(domain, &iommu->domain_list, next) {
1634                 group = find_iommu_group(domain, iommu_group);
1635                 if (!group)
1636                         continue;
1637
1638                 vfio_iommu_detach_group(domain, group);
1639                 list_del(&group->next);
1640                 kfree(group);
1641                 /*
1642                  * Group ownership provides privilege, if the group list is
1643                  * empty, the domain goes away. If it's the last domain with
1644                  * iommu and external domain doesn't exist, then all the
1645                  * mappings go away too. If it's the last domain with iommu and
1646                  * external domain exist, update accounting
1647                  */
1648                 if (list_empty(&domain->group_list)) {
1649                         if (list_is_singular(&iommu->domain_list)) {
1650                                 if (!iommu->external_domain)
1651                                         vfio_iommu_unmap_unpin_all(iommu);
1652                                 else
1653                                         vfio_iommu_unmap_unpin_reaccount(iommu);
1654                         }
1655                         iommu_domain_free(domain->domain);
1656                         list_del(&domain->next);
1657                         kfree(domain);
1658                 }
1659                 break;
1660         }
1661
1662 detach_group_done:
1663         mutex_unlock(&iommu->lock);
1664 }
1665
1666 static void *vfio_iommu_type1_open(unsigned long arg)
1667 {
1668         struct vfio_iommu *iommu;
1669
1670         iommu = kzalloc(sizeof(*iommu), GFP_KERNEL);
1671         if (!iommu)
1672                 return ERR_PTR(-ENOMEM);
1673
1674         switch (arg) {
1675         case VFIO_TYPE1_IOMMU:
1676                 break;
1677         case VFIO_TYPE1_NESTING_IOMMU:
1678                 iommu->nesting = true;
1679                 /* fall through */
1680         case VFIO_TYPE1v2_IOMMU:
1681                 iommu->v2 = true;
1682                 break;
1683         default:
1684                 kfree(iommu);
1685                 return ERR_PTR(-EINVAL);
1686         }
1687
1688         INIT_LIST_HEAD(&iommu->domain_list);
1689         iommu->dma_list = RB_ROOT;
1690         iommu->dma_avail = dma_entry_limit;
1691         mutex_init(&iommu->lock);
1692         BLOCKING_INIT_NOTIFIER_HEAD(&iommu->notifier);
1693
1694         return iommu;
1695 }
1696
1697 static void vfio_release_domain(struct vfio_domain *domain, bool external)
1698 {
1699         struct vfio_group *group, *group_tmp;
1700
1701         list_for_each_entry_safe(group, group_tmp,
1702                                  &domain->group_list, next) {
1703                 if (!external)
1704                         vfio_iommu_detach_group(domain, group);
1705                 list_del(&group->next);
1706                 kfree(group);
1707         }
1708
1709         if (!external)
1710                 iommu_domain_free(domain->domain);
1711 }
1712
1713 static void vfio_iommu_type1_release(void *iommu_data)
1714 {
1715         struct vfio_iommu *iommu = iommu_data;
1716         struct vfio_domain *domain, *domain_tmp;
1717
1718         if (iommu->external_domain) {
1719                 vfio_release_domain(iommu->external_domain, true);
1720                 vfio_sanity_check_pfn_list(iommu);
1721                 kfree(iommu->external_domain);
1722         }
1723
1724         vfio_iommu_unmap_unpin_all(iommu);
1725
1726         list_for_each_entry_safe(domain, domain_tmp,
1727                                  &iommu->domain_list, next) {
1728                 vfio_release_domain(domain, false);
1729                 list_del(&domain->next);
1730                 kfree(domain);
1731         }
1732         kfree(iommu);
1733 }
1734
1735 static int vfio_domains_have_iommu_cache(struct vfio_iommu *iommu)
1736 {
1737         struct vfio_domain *domain;
1738         int ret = 1;
1739
1740         mutex_lock(&iommu->lock);
1741         list_for_each_entry(domain, &iommu->domain_list, next) {
1742                 if (!(domain->prot & IOMMU_CACHE)) {
1743                         ret = 0;
1744                         break;
1745                 }
1746         }
1747         mutex_unlock(&iommu->lock);
1748
1749         return ret;
1750 }
1751
1752 static long vfio_iommu_type1_ioctl(void *iommu_data,
1753                                    unsigned int cmd, unsigned long arg)
1754 {
1755         struct vfio_iommu *iommu = iommu_data;
1756         unsigned long minsz;
1757
1758         if (cmd == VFIO_CHECK_EXTENSION) {
1759                 switch (arg) {
1760                 case VFIO_TYPE1_IOMMU:
1761                 case VFIO_TYPE1v2_IOMMU:
1762                 case VFIO_TYPE1_NESTING_IOMMU:
1763                         return 1;
1764                 case VFIO_DMA_CC_IOMMU:
1765                         if (!iommu)
1766                                 return 0;
1767                         return vfio_domains_have_iommu_cache(iommu);
1768                 default:
1769                         return 0;
1770                 }
1771         } else if (cmd == VFIO_IOMMU_GET_INFO) {
1772                 struct vfio_iommu_type1_info info;
1773
1774                 minsz = offsetofend(struct vfio_iommu_type1_info, iova_pgsizes);
1775
1776                 if (copy_from_user(&info, (void __user *)arg, minsz))
1777                         return -EFAULT;
1778
1779                 if (info.argsz < minsz)
1780                         return -EINVAL;
1781
1782                 info.flags = VFIO_IOMMU_INFO_PGSIZES;
1783
1784                 info.iova_pgsizes = vfio_pgsize_bitmap(iommu);
1785
1786                 return copy_to_user((void __user *)arg, &info, minsz) ?
1787                         -EFAULT : 0;
1788
1789         } else if (cmd == VFIO_IOMMU_MAP_DMA) {
1790                 struct vfio_iommu_type1_dma_map map;
1791                 uint32_t mask = VFIO_DMA_MAP_FLAG_READ |
1792                                 VFIO_DMA_MAP_FLAG_WRITE;
1793
1794                 minsz = offsetofend(struct vfio_iommu_type1_dma_map, size);
1795
1796                 if (copy_from_user(&map, (void __user *)arg, minsz))
1797                         return -EFAULT;
1798
1799                 if (map.argsz < minsz || map.flags & ~mask)
1800                         return -EINVAL;
1801
1802                 return vfio_dma_do_map(iommu, &map);
1803
1804         } else if (cmd == VFIO_IOMMU_UNMAP_DMA) {
1805                 struct vfio_iommu_type1_dma_unmap unmap;
1806                 long ret;
1807
1808                 minsz = offsetofend(struct vfio_iommu_type1_dma_unmap, size);
1809
1810                 if (copy_from_user(&unmap, (void __user *)arg, minsz))
1811                         return -EFAULT;
1812
1813                 if (unmap.argsz < minsz || unmap.flags)
1814                         return -EINVAL;
1815
1816                 ret = vfio_dma_do_unmap(iommu, &unmap);
1817                 if (ret)
1818                         return ret;
1819
1820                 return copy_to_user((void __user *)arg, &unmap, minsz) ?
1821                         -EFAULT : 0;
1822         }
1823
1824         return -ENOTTY;
1825 }
1826
1827 static int vfio_iommu_type1_register_notifier(void *iommu_data,
1828                                               unsigned long *events,
1829                                               struct notifier_block *nb)
1830 {
1831         struct vfio_iommu *iommu = iommu_data;
1832
1833         /* clear known events */
1834         *events &= ~VFIO_IOMMU_NOTIFY_DMA_UNMAP;
1835
1836         /* refuse to register if still events remaining */
1837         if (*events)
1838                 return -EINVAL;
1839
1840         return blocking_notifier_chain_register(&iommu->notifier, nb);
1841 }
1842
1843 static int vfio_iommu_type1_unregister_notifier(void *iommu_data,
1844                                                 struct notifier_block *nb)
1845 {
1846         struct vfio_iommu *iommu = iommu_data;
1847
1848         return blocking_notifier_chain_unregister(&iommu->notifier, nb);
1849 }
1850
1851 static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
1852         .name                   = "vfio-iommu-type1",
1853         .owner                  = THIS_MODULE,
1854         .open                   = vfio_iommu_type1_open,
1855         .release                = vfio_iommu_type1_release,
1856         .ioctl                  = vfio_iommu_type1_ioctl,
1857         .attach_group           = vfio_iommu_type1_attach_group,
1858         .detach_group           = vfio_iommu_type1_detach_group,
1859         .pin_pages              = vfio_iommu_type1_pin_pages,
1860         .unpin_pages            = vfio_iommu_type1_unpin_pages,
1861         .register_notifier      = vfio_iommu_type1_register_notifier,
1862         .unregister_notifier    = vfio_iommu_type1_unregister_notifier,
1863 };
1864
1865 static int __init vfio_iommu_type1_init(void)
1866 {
1867         return vfio_register_iommu_driver(&vfio_iommu_driver_ops_type1);
1868 }
1869
1870 static void __exit vfio_iommu_type1_cleanup(void)
1871 {
1872         vfio_unregister_iommu_driver(&vfio_iommu_driver_ops_type1);
1873 }
1874
1875 module_init(vfio_iommu_type1_init);
1876 module_exit(vfio_iommu_type1_cleanup);
1877
1878 MODULE_VERSION(DRIVER_VERSION);
1879 MODULE_LICENSE("GPL v2");
1880 MODULE_AUTHOR(DRIVER_AUTHOR);
1881 MODULE_DESCRIPTION(DRIVER_DESC);