]> asedeno.scripts.mit.edu Git - linux.git/blob - drivers/vfio/vfio_iommu_type1.c
sctp: fix the transport error_count check
[linux.git] / drivers / vfio / vfio_iommu_type1.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * VFIO: IOMMU DMA mapping support for Type1 IOMMU
4  *
5  * Copyright (C) 2012 Red Hat, Inc.  All rights reserved.
6  *     Author: Alex Williamson <alex.williamson@redhat.com>
7  *
8  * Derived from original vfio:
9  * Copyright 2010 Cisco Systems, Inc.  All rights reserved.
10  * Author: Tom Lyon, pugs@cisco.com
11  *
12  * We arbitrarily define a Type1 IOMMU as one matching the below code.
13  * It could be called the x86 IOMMU as it's designed for AMD-Vi & Intel
14  * VT-d, but that makes it harder to re-use as theoretically anyone
15  * implementing a similar IOMMU could make use of this.  We expect the
16  * IOMMU to support the IOMMU API and have few to no restrictions around
17  * the IOVA range that can be mapped.  The Type1 IOMMU is currently
18  * optimized for relatively static mappings of a userspace process with
19  * userpsace pages pinned into memory.  We also assume devices and IOMMU
20  * domains are PCI based as the IOMMU API is still centered around a
21  * device/bus interface rather than a group interface.
22  */
23
24 #include <linux/compat.h>
25 #include <linux/device.h>
26 #include <linux/fs.h>
27 #include <linux/iommu.h>
28 #include <linux/module.h>
29 #include <linux/mm.h>
30 #include <linux/rbtree.h>
31 #include <linux/sched/signal.h>
32 #include <linux/sched/mm.h>
33 #include <linux/slab.h>
34 #include <linux/uaccess.h>
35 #include <linux/vfio.h>
36 #include <linux/workqueue.h>
37 #include <linux/mdev.h>
38 #include <linux/notifier.h>
39 #include <linux/dma-iommu.h>
40 #include <linux/irqdomain.h>
41
42 #define DRIVER_VERSION  "0.2"
43 #define DRIVER_AUTHOR   "Alex Williamson <alex.williamson@redhat.com>"
44 #define DRIVER_DESC     "Type1 IOMMU driver for VFIO"
45
46 static bool allow_unsafe_interrupts;
47 module_param_named(allow_unsafe_interrupts,
48                    allow_unsafe_interrupts, bool, S_IRUGO | S_IWUSR);
49 MODULE_PARM_DESC(allow_unsafe_interrupts,
50                  "Enable VFIO IOMMU support for on platforms without interrupt remapping support.");
51
52 static bool disable_hugepages;
53 module_param_named(disable_hugepages,
54                    disable_hugepages, bool, S_IRUGO | S_IWUSR);
55 MODULE_PARM_DESC(disable_hugepages,
56                  "Disable VFIO IOMMU support for IOMMU hugepages.");
57
58 static unsigned int dma_entry_limit __read_mostly = U16_MAX;
59 module_param_named(dma_entry_limit, dma_entry_limit, uint, 0644);
60 MODULE_PARM_DESC(dma_entry_limit,
61                  "Maximum number of user DMA mappings per container (65535).");
62
63 struct vfio_iommu {
64         struct list_head        domain_list;
65         struct vfio_domain      *external_domain; /* domain for external user */
66         struct mutex            lock;
67         struct rb_root          dma_list;
68         struct blocking_notifier_head notifier;
69         unsigned int            dma_avail;
70         bool                    v2;
71         bool                    nesting;
72 };
73
74 struct vfio_domain {
75         struct iommu_domain     *domain;
76         struct list_head        next;
77         struct list_head        group_list;
78         int                     prot;           /* IOMMU_CACHE */
79         bool                    fgsp;           /* Fine-grained super pages */
80 };
81
82 struct vfio_dma {
83         struct rb_node          node;
84         dma_addr_t              iova;           /* Device address */
85         unsigned long           vaddr;          /* Process virtual addr */
86         size_t                  size;           /* Map size (bytes) */
87         int                     prot;           /* IOMMU_READ/WRITE */
88         bool                    iommu_mapped;
89         bool                    lock_cap;       /* capable(CAP_IPC_LOCK) */
90         struct task_struct      *task;
91         struct rb_root          pfn_list;       /* Ex-user pinned pfn list */
92 };
93
94 struct vfio_group {
95         struct iommu_group      *iommu_group;
96         struct list_head        next;
97         bool                    mdev_group;     /* An mdev group */
98 };
99
100 /*
101  * Guest RAM pinning working set or DMA target
102  */
103 struct vfio_pfn {
104         struct rb_node          node;
105         dma_addr_t              iova;           /* Device address */
106         unsigned long           pfn;            /* Host pfn */
107         atomic_t                ref_count;
108 };
109
110 struct vfio_regions {
111         struct list_head list;
112         dma_addr_t iova;
113         phys_addr_t phys;
114         size_t len;
115 };
116
117 #define IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu) \
118                                         (!list_empty(&iommu->domain_list))
119
120 static int put_pfn(unsigned long pfn, int prot);
121
122 /*
123  * This code handles mapping and unmapping of user data buffers
124  * into DMA'ble space using the IOMMU
125  */
126
127 static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
128                                       dma_addr_t start, size_t size)
129 {
130         struct rb_node *node = iommu->dma_list.rb_node;
131
132         while (node) {
133                 struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
134
135                 if (start + size <= dma->iova)
136                         node = node->rb_left;
137                 else if (start >= dma->iova + dma->size)
138                         node = node->rb_right;
139                 else
140                         return dma;
141         }
142
143         return NULL;
144 }
145
146 static void vfio_link_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
147 {
148         struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
149         struct vfio_dma *dma;
150
151         while (*link) {
152                 parent = *link;
153                 dma = rb_entry(parent, struct vfio_dma, node);
154
155                 if (new->iova + new->size <= dma->iova)
156                         link = &(*link)->rb_left;
157                 else
158                         link = &(*link)->rb_right;
159         }
160
161         rb_link_node(&new->node, parent, link);
162         rb_insert_color(&new->node, &iommu->dma_list);
163 }
164
165 static void vfio_unlink_dma(struct vfio_iommu *iommu, struct vfio_dma *old)
166 {
167         rb_erase(&old->node, &iommu->dma_list);
168 }
169
170 /*
171  * Helper Functions for host iova-pfn list
172  */
173 static struct vfio_pfn *vfio_find_vpfn(struct vfio_dma *dma, dma_addr_t iova)
174 {
175         struct vfio_pfn *vpfn;
176         struct rb_node *node = dma->pfn_list.rb_node;
177
178         while (node) {
179                 vpfn = rb_entry(node, struct vfio_pfn, node);
180
181                 if (iova < vpfn->iova)
182                         node = node->rb_left;
183                 else if (iova > vpfn->iova)
184                         node = node->rb_right;
185                 else
186                         return vpfn;
187         }
188         return NULL;
189 }
190
191 static void vfio_link_pfn(struct vfio_dma *dma,
192                           struct vfio_pfn *new)
193 {
194         struct rb_node **link, *parent = NULL;
195         struct vfio_pfn *vpfn;
196
197         link = &dma->pfn_list.rb_node;
198         while (*link) {
199                 parent = *link;
200                 vpfn = rb_entry(parent, struct vfio_pfn, node);
201
202                 if (new->iova < vpfn->iova)
203                         link = &(*link)->rb_left;
204                 else
205                         link = &(*link)->rb_right;
206         }
207
208         rb_link_node(&new->node, parent, link);
209         rb_insert_color(&new->node, &dma->pfn_list);
210 }
211
212 static void vfio_unlink_pfn(struct vfio_dma *dma, struct vfio_pfn *old)
213 {
214         rb_erase(&old->node, &dma->pfn_list);
215 }
216
217 static int vfio_add_to_pfn_list(struct vfio_dma *dma, dma_addr_t iova,
218                                 unsigned long pfn)
219 {
220         struct vfio_pfn *vpfn;
221
222         vpfn = kzalloc(sizeof(*vpfn), GFP_KERNEL);
223         if (!vpfn)
224                 return -ENOMEM;
225
226         vpfn->iova = iova;
227         vpfn->pfn = pfn;
228         atomic_set(&vpfn->ref_count, 1);
229         vfio_link_pfn(dma, vpfn);
230         return 0;
231 }
232
233 static void vfio_remove_from_pfn_list(struct vfio_dma *dma,
234                                       struct vfio_pfn *vpfn)
235 {
236         vfio_unlink_pfn(dma, vpfn);
237         kfree(vpfn);
238 }
239
240 static struct vfio_pfn *vfio_iova_get_vfio_pfn(struct vfio_dma *dma,
241                                                unsigned long iova)
242 {
243         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
244
245         if (vpfn)
246                 atomic_inc(&vpfn->ref_count);
247         return vpfn;
248 }
249
250 static int vfio_iova_put_vfio_pfn(struct vfio_dma *dma, struct vfio_pfn *vpfn)
251 {
252         int ret = 0;
253
254         if (atomic_dec_and_test(&vpfn->ref_count)) {
255                 ret = put_pfn(vpfn->pfn, dma->prot);
256                 vfio_remove_from_pfn_list(dma, vpfn);
257         }
258         return ret;
259 }
260
261 static int vfio_lock_acct(struct vfio_dma *dma, long npage, bool async)
262 {
263         struct mm_struct *mm;
264         int ret;
265
266         if (!npage)
267                 return 0;
268
269         mm = async ? get_task_mm(dma->task) : dma->task->mm;
270         if (!mm)
271                 return -ESRCH; /* process exited */
272
273         ret = down_write_killable(&mm->mmap_sem);
274         if (!ret) {
275                 ret = __account_locked_vm(mm, abs(npage), npage > 0, dma->task,
276                                           dma->lock_cap);
277                 up_write(&mm->mmap_sem);
278         }
279
280         if (async)
281                 mmput(mm);
282
283         return ret;
284 }
285
286 /*
287  * Some mappings aren't backed by a struct page, for example an mmap'd
288  * MMIO range for our own or another device.  These use a different
289  * pfn conversion and shouldn't be tracked as locked pages.
290  */
291 static bool is_invalid_reserved_pfn(unsigned long pfn)
292 {
293         if (pfn_valid(pfn)) {
294                 bool reserved;
295                 struct page *tail = pfn_to_page(pfn);
296                 struct page *head = compound_head(tail);
297                 reserved = !!(PageReserved(head));
298                 if (head != tail) {
299                         /*
300                          * "head" is not a dangling pointer
301                          * (compound_head takes care of that)
302                          * but the hugepage may have been split
303                          * from under us (and we may not hold a
304                          * reference count on the head page so it can
305                          * be reused before we run PageReferenced), so
306                          * we've to check PageTail before returning
307                          * what we just read.
308                          */
309                         smp_rmb();
310                         if (PageTail(tail))
311                                 return reserved;
312                 }
313                 return PageReserved(tail);
314         }
315
316         return true;
317 }
318
319 static int put_pfn(unsigned long pfn, int prot)
320 {
321         if (!is_invalid_reserved_pfn(pfn)) {
322                 struct page *page = pfn_to_page(pfn);
323                 if (prot & IOMMU_WRITE)
324                         SetPageDirty(page);
325                 put_page(page);
326                 return 1;
327         }
328         return 0;
329 }
330
331 static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
332                          int prot, unsigned long *pfn)
333 {
334         struct page *page[1];
335         struct vm_area_struct *vma;
336         struct vm_area_struct *vmas[1];
337         unsigned int flags = 0;
338         int ret;
339
340         if (prot & IOMMU_WRITE)
341                 flags |= FOLL_WRITE;
342
343         down_read(&mm->mmap_sem);
344         if (mm == current->mm) {
345                 ret = get_user_pages(vaddr, 1, flags | FOLL_LONGTERM, page,
346                                      vmas);
347         } else {
348                 ret = get_user_pages_remote(NULL, mm, vaddr, 1, flags, page,
349                                             vmas, NULL);
350                 /*
351                  * The lifetime of a vaddr_get_pfn() page pin is
352                  * userspace-controlled. In the fs-dax case this could
353                  * lead to indefinite stalls in filesystem operations.
354                  * Disallow attempts to pin fs-dax pages via this
355                  * interface.
356                  */
357                 if (ret > 0 && vma_is_fsdax(vmas[0])) {
358                         ret = -EOPNOTSUPP;
359                         put_page(page[0]);
360                 }
361         }
362         up_read(&mm->mmap_sem);
363
364         if (ret == 1) {
365                 *pfn = page_to_pfn(page[0]);
366                 return 0;
367         }
368
369         down_read(&mm->mmap_sem);
370
371         vma = find_vma_intersection(mm, vaddr, vaddr + 1);
372
373         if (vma && vma->vm_flags & VM_PFNMAP) {
374                 *pfn = ((vaddr - vma->vm_start) >> PAGE_SHIFT) + vma->vm_pgoff;
375                 if (is_invalid_reserved_pfn(*pfn))
376                         ret = 0;
377         }
378
379         up_read(&mm->mmap_sem);
380         return ret;
381 }
382
383 /*
384  * Attempt to pin pages.  We really don't want to track all the pfns and
385  * the iommu can only map chunks of consecutive pfns anyway, so get the
386  * first page and all consecutive pages with the same locking.
387  */
388 static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
389                                   long npage, unsigned long *pfn_base,
390                                   unsigned long limit)
391 {
392         unsigned long pfn = 0;
393         long ret, pinned = 0, lock_acct = 0;
394         bool rsvd;
395         dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
396
397         /* This code path is only user initiated */
398         if (!current->mm)
399                 return -ENODEV;
400
401         ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, pfn_base);
402         if (ret)
403                 return ret;
404
405         pinned++;
406         rsvd = is_invalid_reserved_pfn(*pfn_base);
407
408         /*
409          * Reserved pages aren't counted against the user, externally pinned
410          * pages are already counted against the user.
411          */
412         if (!rsvd && !vfio_find_vpfn(dma, iova)) {
413                 if (!dma->lock_cap && current->mm->locked_vm + 1 > limit) {
414                         put_pfn(*pfn_base, dma->prot);
415                         pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__,
416                                         limit << PAGE_SHIFT);
417                         return -ENOMEM;
418                 }
419                 lock_acct++;
420         }
421
422         if (unlikely(disable_hugepages))
423                 goto out;
424
425         /* Lock all the consecutive pages from pfn_base */
426         for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; pinned < npage;
427              pinned++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) {
428                 ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, &pfn);
429                 if (ret)
430                         break;
431
432                 if (pfn != *pfn_base + pinned ||
433                     rsvd != is_invalid_reserved_pfn(pfn)) {
434                         put_pfn(pfn, dma->prot);
435                         break;
436                 }
437
438                 if (!rsvd && !vfio_find_vpfn(dma, iova)) {
439                         if (!dma->lock_cap &&
440                             current->mm->locked_vm + lock_acct + 1 > limit) {
441                                 put_pfn(pfn, dma->prot);
442                                 pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
443                                         __func__, limit << PAGE_SHIFT);
444                                 ret = -ENOMEM;
445                                 goto unpin_out;
446                         }
447                         lock_acct++;
448                 }
449         }
450
451 out:
452         ret = vfio_lock_acct(dma, lock_acct, false);
453
454 unpin_out:
455         if (ret) {
456                 if (!rsvd) {
457                         for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
458                                 put_pfn(pfn, dma->prot);
459                 }
460
461                 return ret;
462         }
463
464         return pinned;
465 }
466
467 static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
468                                     unsigned long pfn, long npage,
469                                     bool do_accounting)
470 {
471         long unlocked = 0, locked = 0;
472         long i;
473
474         for (i = 0; i < npage; i++, iova += PAGE_SIZE) {
475                 if (put_pfn(pfn++, dma->prot)) {
476                         unlocked++;
477                         if (vfio_find_vpfn(dma, iova))
478                                 locked++;
479                 }
480         }
481
482         if (do_accounting)
483                 vfio_lock_acct(dma, locked - unlocked, true);
484
485         return unlocked;
486 }
487
488 static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
489                                   unsigned long *pfn_base, bool do_accounting)
490 {
491         struct mm_struct *mm;
492         int ret;
493
494         mm = get_task_mm(dma->task);
495         if (!mm)
496                 return -ENODEV;
497
498         ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base);
499         if (!ret && do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
500                 ret = vfio_lock_acct(dma, 1, true);
501                 if (ret) {
502                         put_pfn(*pfn_base, dma->prot);
503                         if (ret == -ENOMEM)
504                                 pr_warn("%s: Task %s (%d) RLIMIT_MEMLOCK "
505                                         "(%ld) exceeded\n", __func__,
506                                         dma->task->comm, task_pid_nr(dma->task),
507                                         task_rlimit(dma->task, RLIMIT_MEMLOCK));
508                 }
509         }
510
511         mmput(mm);
512         return ret;
513 }
514
515 static int vfio_unpin_page_external(struct vfio_dma *dma, dma_addr_t iova,
516                                     bool do_accounting)
517 {
518         int unlocked;
519         struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
520
521         if (!vpfn)
522                 return 0;
523
524         unlocked = vfio_iova_put_vfio_pfn(dma, vpfn);
525
526         if (do_accounting)
527                 vfio_lock_acct(dma, -unlocked, true);
528
529         return unlocked;
530 }
531
532 static int vfio_iommu_type1_pin_pages(void *iommu_data,
533                                       unsigned long *user_pfn,
534                                       int npage, int prot,
535                                       unsigned long *phys_pfn)
536 {
537         struct vfio_iommu *iommu = iommu_data;
538         int i, j, ret;
539         unsigned long remote_vaddr;
540         struct vfio_dma *dma;
541         bool do_accounting;
542
543         if (!iommu || !user_pfn || !phys_pfn)
544                 return -EINVAL;
545
546         /* Supported for v2 version only */
547         if (!iommu->v2)
548                 return -EACCES;
549
550         mutex_lock(&iommu->lock);
551
552         /* Fail if notifier list is empty */
553         if (!iommu->notifier.head) {
554                 ret = -EINVAL;
555                 goto pin_done;
556         }
557
558         /*
559          * If iommu capable domain exist in the container then all pages are
560          * already pinned and accounted. Accouting should be done if there is no
561          * iommu capable domain in the container.
562          */
563         do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
564
565         for (i = 0; i < npage; i++) {
566                 dma_addr_t iova;
567                 struct vfio_pfn *vpfn;
568
569                 iova = user_pfn[i] << PAGE_SHIFT;
570                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
571                 if (!dma) {
572                         ret = -EINVAL;
573                         goto pin_unwind;
574                 }
575
576                 if ((dma->prot & prot) != prot) {
577                         ret = -EPERM;
578                         goto pin_unwind;
579                 }
580
581                 vpfn = vfio_iova_get_vfio_pfn(dma, iova);
582                 if (vpfn) {
583                         phys_pfn[i] = vpfn->pfn;
584                         continue;
585                 }
586
587                 remote_vaddr = dma->vaddr + iova - dma->iova;
588                 ret = vfio_pin_page_external(dma, remote_vaddr, &phys_pfn[i],
589                                              do_accounting);
590                 if (ret)
591                         goto pin_unwind;
592
593                 ret = vfio_add_to_pfn_list(dma, iova, phys_pfn[i]);
594                 if (ret) {
595                         vfio_unpin_page_external(dma, iova, do_accounting);
596                         goto pin_unwind;
597                 }
598         }
599
600         ret = i;
601         goto pin_done;
602
603 pin_unwind:
604         phys_pfn[i] = 0;
605         for (j = 0; j < i; j++) {
606                 dma_addr_t iova;
607
608                 iova = user_pfn[j] << PAGE_SHIFT;
609                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
610                 vfio_unpin_page_external(dma, iova, do_accounting);
611                 phys_pfn[j] = 0;
612         }
613 pin_done:
614         mutex_unlock(&iommu->lock);
615         return ret;
616 }
617
618 static int vfio_iommu_type1_unpin_pages(void *iommu_data,
619                                         unsigned long *user_pfn,
620                                         int npage)
621 {
622         struct vfio_iommu *iommu = iommu_data;
623         bool do_accounting;
624         int i;
625
626         if (!iommu || !user_pfn)
627                 return -EINVAL;
628
629         /* Supported for v2 version only */
630         if (!iommu->v2)
631                 return -EACCES;
632
633         mutex_lock(&iommu->lock);
634
635         do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
636         for (i = 0; i < npage; i++) {
637                 struct vfio_dma *dma;
638                 dma_addr_t iova;
639
640                 iova = user_pfn[i] << PAGE_SHIFT;
641                 dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
642                 if (!dma)
643                         goto unpin_exit;
644                 vfio_unpin_page_external(dma, iova, do_accounting);
645         }
646
647 unpin_exit:
648         mutex_unlock(&iommu->lock);
649         return i > npage ? npage : (i > 0 ? i : -EINVAL);
650 }
651
652 static long vfio_sync_unpin(struct vfio_dma *dma, struct vfio_domain *domain,
653                                 struct list_head *regions)
654 {
655         long unlocked = 0;
656         struct vfio_regions *entry, *next;
657
658         iommu_tlb_sync(domain->domain);
659
660         list_for_each_entry_safe(entry, next, regions, list) {
661                 unlocked += vfio_unpin_pages_remote(dma,
662                                                     entry->iova,
663                                                     entry->phys >> PAGE_SHIFT,
664                                                     entry->len >> PAGE_SHIFT,
665                                                     false);
666                 list_del(&entry->list);
667                 kfree(entry);
668         }
669
670         cond_resched();
671
672         return unlocked;
673 }
674
675 /*
676  * Generally, VFIO needs to unpin remote pages after each IOTLB flush.
677  * Therefore, when using IOTLB flush sync interface, VFIO need to keep track
678  * of these regions (currently using a list).
679  *
680  * This value specifies maximum number of regions for each IOTLB flush sync.
681  */
682 #define VFIO_IOMMU_TLB_SYNC_MAX         512
683
684 static size_t unmap_unpin_fast(struct vfio_domain *domain,
685                                struct vfio_dma *dma, dma_addr_t *iova,
686                                size_t len, phys_addr_t phys, long *unlocked,
687                                struct list_head *unmapped_list,
688                                int *unmapped_cnt)
689 {
690         size_t unmapped = 0;
691         struct vfio_regions *entry = kzalloc(sizeof(*entry), GFP_KERNEL);
692
693         if (entry) {
694                 unmapped = iommu_unmap_fast(domain->domain, *iova, len);
695
696                 if (!unmapped) {
697                         kfree(entry);
698                 } else {
699                         iommu_tlb_range_add(domain->domain, *iova, unmapped);
700                         entry->iova = *iova;
701                         entry->phys = phys;
702                         entry->len  = unmapped;
703                         list_add_tail(&entry->list, unmapped_list);
704
705                         *iova += unmapped;
706                         (*unmapped_cnt)++;
707                 }
708         }
709
710         /*
711          * Sync if the number of fast-unmap regions hits the limit
712          * or in case of errors.
713          */
714         if (*unmapped_cnt >= VFIO_IOMMU_TLB_SYNC_MAX || !unmapped) {
715                 *unlocked += vfio_sync_unpin(dma, domain,
716                                              unmapped_list);
717                 *unmapped_cnt = 0;
718         }
719
720         return unmapped;
721 }
722
723 static size_t unmap_unpin_slow(struct vfio_domain *domain,
724                                struct vfio_dma *dma, dma_addr_t *iova,
725                                size_t len, phys_addr_t phys,
726                                long *unlocked)
727 {
728         size_t unmapped = iommu_unmap(domain->domain, *iova, len);
729
730         if (unmapped) {
731                 *unlocked += vfio_unpin_pages_remote(dma, *iova,
732                                                      phys >> PAGE_SHIFT,
733                                                      unmapped >> PAGE_SHIFT,
734                                                      false);
735                 *iova += unmapped;
736                 cond_resched();
737         }
738         return unmapped;
739 }
740
741 static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
742                              bool do_accounting)
743 {
744         dma_addr_t iova = dma->iova, end = dma->iova + dma->size;
745         struct vfio_domain *domain, *d;
746         LIST_HEAD(unmapped_region_list);
747         int unmapped_region_cnt = 0;
748         long unlocked = 0;
749
750         if (!dma->size)
751                 return 0;
752
753         if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
754                 return 0;
755
756         /*
757          * We use the IOMMU to track the physical addresses, otherwise we'd
758          * need a much more complicated tracking system.  Unfortunately that
759          * means we need to use one of the iommu domains to figure out the
760          * pfns to unpin.  The rest need to be unmapped in advance so we have
761          * no iommu translations remaining when the pages are unpinned.
762          */
763         domain = d = list_first_entry(&iommu->domain_list,
764                                       struct vfio_domain, next);
765
766         list_for_each_entry_continue(d, &iommu->domain_list, next) {
767                 iommu_unmap(d->domain, dma->iova, dma->size);
768                 cond_resched();
769         }
770
771         while (iova < end) {
772                 size_t unmapped, len;
773                 phys_addr_t phys, next;
774
775                 phys = iommu_iova_to_phys(domain->domain, iova);
776                 if (WARN_ON(!phys)) {
777                         iova += PAGE_SIZE;
778                         continue;
779                 }
780
781                 /*
782                  * To optimize for fewer iommu_unmap() calls, each of which
783                  * may require hardware cache flushing, try to find the
784                  * largest contiguous physical memory chunk to unmap.
785                  */
786                 for (len = PAGE_SIZE;
787                      !domain->fgsp && iova + len < end; len += PAGE_SIZE) {
788                         next = iommu_iova_to_phys(domain->domain, iova + len);
789                         if (next != phys + len)
790                                 break;
791                 }
792
793                 /*
794                  * First, try to use fast unmap/unpin. In case of failure,
795                  * switch to slow unmap/unpin path.
796                  */
797                 unmapped = unmap_unpin_fast(domain, dma, &iova, len, phys,
798                                             &unlocked, &unmapped_region_list,
799                                             &unmapped_region_cnt);
800                 if (!unmapped) {
801                         unmapped = unmap_unpin_slow(domain, dma, &iova, len,
802                                                     phys, &unlocked);
803                         if (WARN_ON(!unmapped))
804                                 break;
805                 }
806         }
807
808         dma->iommu_mapped = false;
809
810         if (unmapped_region_cnt)
811                 unlocked += vfio_sync_unpin(dma, domain, &unmapped_region_list);
812
813         if (do_accounting) {
814                 vfio_lock_acct(dma, -unlocked, true);
815                 return 0;
816         }
817         return unlocked;
818 }
819
820 static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
821 {
822         vfio_unmap_unpin(iommu, dma, true);
823         vfio_unlink_dma(iommu, dma);
824         put_task_struct(dma->task);
825         kfree(dma);
826         iommu->dma_avail++;
827 }
828
829 static unsigned long vfio_pgsize_bitmap(struct vfio_iommu *iommu)
830 {
831         struct vfio_domain *domain;
832         unsigned long bitmap = ULONG_MAX;
833
834         mutex_lock(&iommu->lock);
835         list_for_each_entry(domain, &iommu->domain_list, next)
836                 bitmap &= domain->domain->pgsize_bitmap;
837         mutex_unlock(&iommu->lock);
838
839         /*
840          * In case the IOMMU supports page sizes smaller than PAGE_SIZE
841          * we pretend PAGE_SIZE is supported and hide sub-PAGE_SIZE sizes.
842          * That way the user will be able to map/unmap buffers whose size/
843          * start address is aligned with PAGE_SIZE. Pinning code uses that
844          * granularity while iommu driver can use the sub-PAGE_SIZE size
845          * to map the buffer.
846          */
847         if (bitmap & ~PAGE_MASK) {
848                 bitmap &= PAGE_MASK;
849                 bitmap |= PAGE_SIZE;
850         }
851
852         return bitmap;
853 }
854
855 static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
856                              struct vfio_iommu_type1_dma_unmap *unmap)
857 {
858         uint64_t mask;
859         struct vfio_dma *dma, *dma_last = NULL;
860         size_t unmapped = 0;
861         int ret = 0, retries = 0;
862
863         mask = ((uint64_t)1 << __ffs(vfio_pgsize_bitmap(iommu))) - 1;
864
865         if (unmap->iova & mask)
866                 return -EINVAL;
867         if (!unmap->size || unmap->size & mask)
868                 return -EINVAL;
869         if (unmap->iova + unmap->size - 1 < unmap->iova ||
870             unmap->size > SIZE_MAX)
871                 return -EINVAL;
872
873         WARN_ON(mask & PAGE_MASK);
874 again:
875         mutex_lock(&iommu->lock);
876
877         /*
878          * vfio-iommu-type1 (v1) - User mappings were coalesced together to
879          * avoid tracking individual mappings.  This means that the granularity
880          * of the original mapping was lost and the user was allowed to attempt
881          * to unmap any range.  Depending on the contiguousness of physical
882          * memory and page sizes supported by the IOMMU, arbitrary unmaps may
883          * or may not have worked.  We only guaranteed unmap granularity
884          * matching the original mapping; even though it was untracked here,
885          * the original mappings are reflected in IOMMU mappings.  This
886          * resulted in a couple unusual behaviors.  First, if a range is not
887          * able to be unmapped, ex. a set of 4k pages that was mapped as a
888          * 2M hugepage into the IOMMU, the unmap ioctl returns success but with
889          * a zero sized unmap.  Also, if an unmap request overlaps the first
890          * address of a hugepage, the IOMMU will unmap the entire hugepage.
891          * This also returns success and the returned unmap size reflects the
892          * actual size unmapped.
893          *
894          * We attempt to maintain compatibility with this "v1" interface, but
895          * we take control out of the hands of the IOMMU.  Therefore, an unmap
896          * request offset from the beginning of the original mapping will
897          * return success with zero sized unmap.  And an unmap request covering
898          * the first iova of mapping will unmap the entire range.
899          *
900          * The v2 version of this interface intends to be more deterministic.
901          * Unmap requests must fully cover previous mappings.  Multiple
902          * mappings may still be unmaped by specifying large ranges, but there
903          * must not be any previous mappings bisected by the range.  An error
904          * will be returned if these conditions are not met.  The v2 interface
905          * will only return success and a size of zero if there were no
906          * mappings within the range.
907          */
908         if (iommu->v2) {
909                 dma = vfio_find_dma(iommu, unmap->iova, 1);
910                 if (dma && dma->iova != unmap->iova) {
911                         ret = -EINVAL;
912                         goto unlock;
913                 }
914                 dma = vfio_find_dma(iommu, unmap->iova + unmap->size - 1, 0);
915                 if (dma && dma->iova + dma->size != unmap->iova + unmap->size) {
916                         ret = -EINVAL;
917                         goto unlock;
918                 }
919         }
920
921         while ((dma = vfio_find_dma(iommu, unmap->iova, unmap->size))) {
922                 if (!iommu->v2 && unmap->iova > dma->iova)
923                         break;
924                 /*
925                  * Task with same address space who mapped this iova range is
926                  * allowed to unmap the iova range.
927                  */
928                 if (dma->task->mm != current->mm)
929                         break;
930
931                 if (!RB_EMPTY_ROOT(&dma->pfn_list)) {
932                         struct vfio_iommu_type1_dma_unmap nb_unmap;
933
934                         if (dma_last == dma) {
935                                 BUG_ON(++retries > 10);
936                         } else {
937                                 dma_last = dma;
938                                 retries = 0;
939                         }
940
941                         nb_unmap.iova = dma->iova;
942                         nb_unmap.size = dma->size;
943
944                         /*
945                          * Notify anyone (mdev vendor drivers) to invalidate and
946                          * unmap iovas within the range we're about to unmap.
947                          * Vendor drivers MUST unpin pages in response to an
948                          * invalidation.
949                          */
950                         mutex_unlock(&iommu->lock);
951                         blocking_notifier_call_chain(&iommu->notifier,
952                                                     VFIO_IOMMU_NOTIFY_DMA_UNMAP,
953                                                     &nb_unmap);
954                         goto again;
955                 }
956                 unmapped += dma->size;
957                 vfio_remove_dma(iommu, dma);
958         }
959
960 unlock:
961         mutex_unlock(&iommu->lock);
962
963         /* Report how much was unmapped */
964         unmap->size = unmapped;
965
966         return ret;
967 }
968
969 static int vfio_iommu_map(struct vfio_iommu *iommu, dma_addr_t iova,
970                           unsigned long pfn, long npage, int prot)
971 {
972         struct vfio_domain *d;
973         int ret;
974
975         list_for_each_entry(d, &iommu->domain_list, next) {
976                 ret = iommu_map(d->domain, iova, (phys_addr_t)pfn << PAGE_SHIFT,
977                                 npage << PAGE_SHIFT, prot | d->prot);
978                 if (ret)
979                         goto unwind;
980
981                 cond_resched();
982         }
983
984         return 0;
985
986 unwind:
987         list_for_each_entry_continue_reverse(d, &iommu->domain_list, next)
988                 iommu_unmap(d->domain, iova, npage << PAGE_SHIFT);
989
990         return ret;
991 }
992
993 static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
994                             size_t map_size)
995 {
996         dma_addr_t iova = dma->iova;
997         unsigned long vaddr = dma->vaddr;
998         size_t size = map_size;
999         long npage;
1000         unsigned long pfn, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1001         int ret = 0;
1002
1003         while (size) {
1004                 /* Pin a contiguous chunk of memory */
1005                 npage = vfio_pin_pages_remote(dma, vaddr + dma->size,
1006                                               size >> PAGE_SHIFT, &pfn, limit);
1007                 if (npage <= 0) {
1008                         WARN_ON(!npage);
1009                         ret = (int)npage;
1010                         break;
1011                 }
1012
1013                 /* Map it! */
1014                 ret = vfio_iommu_map(iommu, iova + dma->size, pfn, npage,
1015                                      dma->prot);
1016                 if (ret) {
1017                         vfio_unpin_pages_remote(dma, iova + dma->size, pfn,
1018                                                 npage, true);
1019                         break;
1020                 }
1021
1022                 size -= npage << PAGE_SHIFT;
1023                 dma->size += npage << PAGE_SHIFT;
1024         }
1025
1026         dma->iommu_mapped = true;
1027
1028         if (ret)
1029                 vfio_remove_dma(iommu, dma);
1030
1031         return ret;
1032 }
1033
1034 static int vfio_dma_do_map(struct vfio_iommu *iommu,
1035                            struct vfio_iommu_type1_dma_map *map)
1036 {
1037         dma_addr_t iova = map->iova;
1038         unsigned long vaddr = map->vaddr;
1039         size_t size = map->size;
1040         int ret = 0, prot = 0;
1041         uint64_t mask;
1042         struct vfio_dma *dma;
1043
1044         /* Verify that none of our __u64 fields overflow */
1045         if (map->size != size || map->vaddr != vaddr || map->iova != iova)
1046                 return -EINVAL;
1047
1048         mask = ((uint64_t)1 << __ffs(vfio_pgsize_bitmap(iommu))) - 1;
1049
1050         WARN_ON(mask & PAGE_MASK);
1051
1052         /* READ/WRITE from device perspective */
1053         if (map->flags & VFIO_DMA_MAP_FLAG_WRITE)
1054                 prot |= IOMMU_WRITE;
1055         if (map->flags & VFIO_DMA_MAP_FLAG_READ)
1056                 prot |= IOMMU_READ;
1057
1058         if (!prot || !size || (size | iova | vaddr) & mask)
1059                 return -EINVAL;
1060
1061         /* Don't allow IOVA or virtual address wrap */
1062         if (iova + size - 1 < iova || vaddr + size - 1 < vaddr)
1063                 return -EINVAL;
1064
1065         mutex_lock(&iommu->lock);
1066
1067         if (vfio_find_dma(iommu, iova, size)) {
1068                 ret = -EEXIST;
1069                 goto out_unlock;
1070         }
1071
1072         if (!iommu->dma_avail) {
1073                 ret = -ENOSPC;
1074                 goto out_unlock;
1075         }
1076
1077         dma = kzalloc(sizeof(*dma), GFP_KERNEL);
1078         if (!dma) {
1079                 ret = -ENOMEM;
1080                 goto out_unlock;
1081         }
1082
1083         iommu->dma_avail--;
1084         dma->iova = iova;
1085         dma->vaddr = vaddr;
1086         dma->prot = prot;
1087
1088         /*
1089          * We need to be able to both add to a task's locked memory and test
1090          * against the locked memory limit and we need to be able to do both
1091          * outside of this call path as pinning can be asynchronous via the
1092          * external interfaces for mdev devices.  RLIMIT_MEMLOCK requires a
1093          * task_struct and VM locked pages requires an mm_struct, however
1094          * holding an indefinite mm reference is not recommended, therefore we
1095          * only hold a reference to a task.  We could hold a reference to
1096          * current, however QEMU uses this call path through vCPU threads,
1097          * which can be killed resulting in a NULL mm and failure in the unmap
1098          * path when called via a different thread.  Avoid this problem by
1099          * using the group_leader as threads within the same group require
1100          * both CLONE_THREAD and CLONE_VM and will therefore use the same
1101          * mm_struct.
1102          *
1103          * Previously we also used the task for testing CAP_IPC_LOCK at the
1104          * time of pinning and accounting, however has_capability() makes use
1105          * of real_cred, a copy-on-write field, so we can't guarantee that it
1106          * matches group_leader, or in fact that it might not change by the
1107          * time it's evaluated.  If a process were to call MAP_DMA with
1108          * CAP_IPC_LOCK but later drop it, it doesn't make sense that they
1109          * possibly see different results for an iommu_mapped vfio_dma vs
1110          * externally mapped.  Therefore track CAP_IPC_LOCK in vfio_dma at the
1111          * time of calling MAP_DMA.
1112          */
1113         get_task_struct(current->group_leader);
1114         dma->task = current->group_leader;
1115         dma->lock_cap = capable(CAP_IPC_LOCK);
1116
1117         dma->pfn_list = RB_ROOT;
1118
1119         /* Insert zero-sized and grow as we map chunks of it */
1120         vfio_link_dma(iommu, dma);
1121
1122         /* Don't pin and map if container doesn't contain IOMMU capable domain*/
1123         if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1124                 dma->size = size;
1125         else
1126                 ret = vfio_pin_map_dma(iommu, dma, size);
1127
1128 out_unlock:
1129         mutex_unlock(&iommu->lock);
1130         return ret;
1131 }
1132
1133 static int vfio_bus_type(struct device *dev, void *data)
1134 {
1135         struct bus_type **bus = data;
1136
1137         if (*bus && *bus != dev->bus)
1138                 return -EINVAL;
1139
1140         *bus = dev->bus;
1141
1142         return 0;
1143 }
1144
1145 static int vfio_iommu_replay(struct vfio_iommu *iommu,
1146                              struct vfio_domain *domain)
1147 {
1148         struct vfio_domain *d;
1149         struct rb_node *n;
1150         unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1151         int ret;
1152
1153         /* Arbitrarily pick the first domain in the list for lookups */
1154         d = list_first_entry(&iommu->domain_list, struct vfio_domain, next);
1155         n = rb_first(&iommu->dma_list);
1156
1157         for (; n; n = rb_next(n)) {
1158                 struct vfio_dma *dma;
1159                 dma_addr_t iova;
1160
1161                 dma = rb_entry(n, struct vfio_dma, node);
1162                 iova = dma->iova;
1163
1164                 while (iova < dma->iova + dma->size) {
1165                         phys_addr_t phys;
1166                         size_t size;
1167
1168                         if (dma->iommu_mapped) {
1169                                 phys_addr_t p;
1170                                 dma_addr_t i;
1171
1172                                 phys = iommu_iova_to_phys(d->domain, iova);
1173
1174                                 if (WARN_ON(!phys)) {
1175                                         iova += PAGE_SIZE;
1176                                         continue;
1177                                 }
1178
1179                                 size = PAGE_SIZE;
1180                                 p = phys + size;
1181                                 i = iova + size;
1182                                 while (i < dma->iova + dma->size &&
1183                                        p == iommu_iova_to_phys(d->domain, i)) {
1184                                         size += PAGE_SIZE;
1185                                         p += PAGE_SIZE;
1186                                         i += PAGE_SIZE;
1187                                 }
1188                         } else {
1189                                 unsigned long pfn;
1190                                 unsigned long vaddr = dma->vaddr +
1191                                                      (iova - dma->iova);
1192                                 size_t n = dma->iova + dma->size - iova;
1193                                 long npage;
1194
1195                                 npage = vfio_pin_pages_remote(dma, vaddr,
1196                                                               n >> PAGE_SHIFT,
1197                                                               &pfn, limit);
1198                                 if (npage <= 0) {
1199                                         WARN_ON(!npage);
1200                                         ret = (int)npage;
1201                                         return ret;
1202                                 }
1203
1204                                 phys = pfn << PAGE_SHIFT;
1205                                 size = npage << PAGE_SHIFT;
1206                         }
1207
1208                         ret = iommu_map(domain->domain, iova, phys,
1209                                         size, dma->prot | domain->prot);
1210                         if (ret)
1211                                 return ret;
1212
1213                         iova += size;
1214                 }
1215                 dma->iommu_mapped = true;
1216         }
1217         return 0;
1218 }
1219
1220 /*
1221  * We change our unmap behavior slightly depending on whether the IOMMU
1222  * supports fine-grained superpages.  IOMMUs like AMD-Vi will use a superpage
1223  * for practically any contiguous power-of-two mapping we give it.  This means
1224  * we don't need to look for contiguous chunks ourselves to make unmapping
1225  * more efficient.  On IOMMUs with coarse-grained super pages, like Intel VT-d
1226  * with discrete 2M/1G/512G/1T superpages, identifying contiguous chunks
1227  * significantly boosts non-hugetlbfs mappings and doesn't seem to hurt when
1228  * hugetlbfs is in use.
1229  */
1230 static void vfio_test_domain_fgsp(struct vfio_domain *domain)
1231 {
1232         struct page *pages;
1233         int ret, order = get_order(PAGE_SIZE * 2);
1234
1235         pages = alloc_pages(GFP_KERNEL | __GFP_ZERO, order);
1236         if (!pages)
1237                 return;
1238
1239         ret = iommu_map(domain->domain, 0, page_to_phys(pages), PAGE_SIZE * 2,
1240                         IOMMU_READ | IOMMU_WRITE | domain->prot);
1241         if (!ret) {
1242                 size_t unmapped = iommu_unmap(domain->domain, 0, PAGE_SIZE);
1243
1244                 if (unmapped == PAGE_SIZE)
1245                         iommu_unmap(domain->domain, PAGE_SIZE, PAGE_SIZE);
1246                 else
1247                         domain->fgsp = true;
1248         }
1249
1250         __free_pages(pages, order);
1251 }
1252
1253 static struct vfio_group *find_iommu_group(struct vfio_domain *domain,
1254                                            struct iommu_group *iommu_group)
1255 {
1256         struct vfio_group *g;
1257
1258         list_for_each_entry(g, &domain->group_list, next) {
1259                 if (g->iommu_group == iommu_group)
1260                         return g;
1261         }
1262
1263         return NULL;
1264 }
1265
1266 static bool vfio_iommu_has_sw_msi(struct iommu_group *group, phys_addr_t *base)
1267 {
1268         struct list_head group_resv_regions;
1269         struct iommu_resv_region *region, *next;
1270         bool ret = false;
1271
1272         INIT_LIST_HEAD(&group_resv_regions);
1273         iommu_get_group_resv_regions(group, &group_resv_regions);
1274         list_for_each_entry(region, &group_resv_regions, list) {
1275                 /*
1276                  * The presence of any 'real' MSI regions should take
1277                  * precedence over the software-managed one if the
1278                  * IOMMU driver happens to advertise both types.
1279                  */
1280                 if (region->type == IOMMU_RESV_MSI) {
1281                         ret = false;
1282                         break;
1283                 }
1284
1285                 if (region->type == IOMMU_RESV_SW_MSI) {
1286                         *base = region->start;
1287                         ret = true;
1288                 }
1289         }
1290         list_for_each_entry_safe(region, next, &group_resv_regions, list)
1291                 kfree(region);
1292         return ret;
1293 }
1294
1295 static struct device *vfio_mdev_get_iommu_device(struct device *dev)
1296 {
1297         struct device *(*fn)(struct device *dev);
1298         struct device *iommu_device;
1299
1300         fn = symbol_get(mdev_get_iommu_device);
1301         if (fn) {
1302                 iommu_device = fn(dev);
1303                 symbol_put(mdev_get_iommu_device);
1304
1305                 return iommu_device;
1306         }
1307
1308         return NULL;
1309 }
1310
1311 static int vfio_mdev_attach_domain(struct device *dev, void *data)
1312 {
1313         struct iommu_domain *domain = data;
1314         struct device *iommu_device;
1315
1316         iommu_device = vfio_mdev_get_iommu_device(dev);
1317         if (iommu_device) {
1318                 if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
1319                         return iommu_aux_attach_device(domain, iommu_device);
1320                 else
1321                         return iommu_attach_device(domain, iommu_device);
1322         }
1323
1324         return -EINVAL;
1325 }
1326
1327 static int vfio_mdev_detach_domain(struct device *dev, void *data)
1328 {
1329         struct iommu_domain *domain = data;
1330         struct device *iommu_device;
1331
1332         iommu_device = vfio_mdev_get_iommu_device(dev);
1333         if (iommu_device) {
1334                 if (iommu_dev_feature_enabled(iommu_device, IOMMU_DEV_FEAT_AUX))
1335                         iommu_aux_detach_device(domain, iommu_device);
1336                 else
1337                         iommu_detach_device(domain, iommu_device);
1338         }
1339
1340         return 0;
1341 }
1342
1343 static int vfio_iommu_attach_group(struct vfio_domain *domain,
1344                                    struct vfio_group *group)
1345 {
1346         if (group->mdev_group)
1347                 return iommu_group_for_each_dev(group->iommu_group,
1348                                                 domain->domain,
1349                                                 vfio_mdev_attach_domain);
1350         else
1351                 return iommu_attach_group(domain->domain, group->iommu_group);
1352 }
1353
1354 static void vfio_iommu_detach_group(struct vfio_domain *domain,
1355                                     struct vfio_group *group)
1356 {
1357         if (group->mdev_group)
1358                 iommu_group_for_each_dev(group->iommu_group, domain->domain,
1359                                          vfio_mdev_detach_domain);
1360         else
1361                 iommu_detach_group(domain->domain, group->iommu_group);
1362 }
1363
1364 static bool vfio_bus_is_mdev(struct bus_type *bus)
1365 {
1366         struct bus_type *mdev_bus;
1367         bool ret = false;
1368
1369         mdev_bus = symbol_get(mdev_bus_type);
1370         if (mdev_bus) {
1371                 ret = (bus == mdev_bus);
1372                 symbol_put(mdev_bus_type);
1373         }
1374
1375         return ret;
1376 }
1377
1378 static int vfio_mdev_iommu_device(struct device *dev, void *data)
1379 {
1380         struct device **old = data, *new;
1381
1382         new = vfio_mdev_get_iommu_device(dev);
1383         if (!new || (*old && *old != new))
1384                 return -EINVAL;
1385
1386         *old = new;
1387
1388         return 0;
1389 }
1390
1391 static int vfio_iommu_type1_attach_group(void *iommu_data,
1392                                          struct iommu_group *iommu_group)
1393 {
1394         struct vfio_iommu *iommu = iommu_data;
1395         struct vfio_group *group;
1396         struct vfio_domain *domain, *d;
1397         struct bus_type *bus = NULL;
1398         int ret;
1399         bool resv_msi, msi_remap;
1400         phys_addr_t resv_msi_base;
1401
1402         mutex_lock(&iommu->lock);
1403
1404         list_for_each_entry(d, &iommu->domain_list, next) {
1405                 if (find_iommu_group(d, iommu_group)) {
1406                         mutex_unlock(&iommu->lock);
1407                         return -EINVAL;
1408                 }
1409         }
1410
1411         if (iommu->external_domain) {
1412                 if (find_iommu_group(iommu->external_domain, iommu_group)) {
1413                         mutex_unlock(&iommu->lock);
1414                         return -EINVAL;
1415                 }
1416         }
1417
1418         group = kzalloc(sizeof(*group), GFP_KERNEL);
1419         domain = kzalloc(sizeof(*domain), GFP_KERNEL);
1420         if (!group || !domain) {
1421                 ret = -ENOMEM;
1422                 goto out_free;
1423         }
1424
1425         group->iommu_group = iommu_group;
1426
1427         /* Determine bus_type in order to allocate a domain */
1428         ret = iommu_group_for_each_dev(iommu_group, &bus, vfio_bus_type);
1429         if (ret)
1430                 goto out_free;
1431
1432         if (vfio_bus_is_mdev(bus)) {
1433                 struct device *iommu_device = NULL;
1434
1435                 group->mdev_group = true;
1436
1437                 /* Determine the isolation type */
1438                 ret = iommu_group_for_each_dev(iommu_group, &iommu_device,
1439                                                vfio_mdev_iommu_device);
1440                 if (ret || !iommu_device) {
1441                         if (!iommu->external_domain) {
1442                                 INIT_LIST_HEAD(&domain->group_list);
1443                                 iommu->external_domain = domain;
1444                         } else {
1445                                 kfree(domain);
1446                         }
1447
1448                         list_add(&group->next,
1449                                  &iommu->external_domain->group_list);
1450                         mutex_unlock(&iommu->lock);
1451
1452                         return 0;
1453                 }
1454
1455                 bus = iommu_device->bus;
1456         }
1457
1458         domain->domain = iommu_domain_alloc(bus);
1459         if (!domain->domain) {
1460                 ret = -EIO;
1461                 goto out_free;
1462         }
1463
1464         if (iommu->nesting) {
1465                 int attr = 1;
1466
1467                 ret = iommu_domain_set_attr(domain->domain, DOMAIN_ATTR_NESTING,
1468                                             &attr);
1469                 if (ret)
1470                         goto out_domain;
1471         }
1472
1473         ret = vfio_iommu_attach_group(domain, group);
1474         if (ret)
1475                 goto out_domain;
1476
1477         resv_msi = vfio_iommu_has_sw_msi(iommu_group, &resv_msi_base);
1478
1479         INIT_LIST_HEAD(&domain->group_list);
1480         list_add(&group->next, &domain->group_list);
1481
1482         msi_remap = irq_domain_check_msi_remap() ||
1483                     iommu_capable(bus, IOMMU_CAP_INTR_REMAP);
1484
1485         if (!allow_unsafe_interrupts && !msi_remap) {
1486                 pr_warn("%s: No interrupt remapping support.  Use the module param \"allow_unsafe_interrupts\" to enable VFIO IOMMU support on this platform\n",
1487                        __func__);
1488                 ret = -EPERM;
1489                 goto out_detach;
1490         }
1491
1492         if (iommu_capable(bus, IOMMU_CAP_CACHE_COHERENCY))
1493                 domain->prot |= IOMMU_CACHE;
1494
1495         /*
1496          * Try to match an existing compatible domain.  We don't want to
1497          * preclude an IOMMU driver supporting multiple bus_types and being
1498          * able to include different bus_types in the same IOMMU domain, so
1499          * we test whether the domains use the same iommu_ops rather than
1500          * testing if they're on the same bus_type.
1501          */
1502         list_for_each_entry(d, &iommu->domain_list, next) {
1503                 if (d->domain->ops == domain->domain->ops &&
1504                     d->prot == domain->prot) {
1505                         vfio_iommu_detach_group(domain, group);
1506                         if (!vfio_iommu_attach_group(d, group)) {
1507                                 list_add(&group->next, &d->group_list);
1508                                 iommu_domain_free(domain->domain);
1509                                 kfree(domain);
1510                                 mutex_unlock(&iommu->lock);
1511                                 return 0;
1512                         }
1513
1514                         ret = vfio_iommu_attach_group(domain, group);
1515                         if (ret)
1516                                 goto out_domain;
1517                 }
1518         }
1519
1520         vfio_test_domain_fgsp(domain);
1521
1522         /* replay mappings on new domains */
1523         ret = vfio_iommu_replay(iommu, domain);
1524         if (ret)
1525                 goto out_detach;
1526
1527         if (resv_msi) {
1528                 ret = iommu_get_msi_cookie(domain->domain, resv_msi_base);
1529                 if (ret)
1530                         goto out_detach;
1531         }
1532
1533         list_add(&domain->next, &iommu->domain_list);
1534
1535         mutex_unlock(&iommu->lock);
1536
1537         return 0;
1538
1539 out_detach:
1540         vfio_iommu_detach_group(domain, group);
1541 out_domain:
1542         iommu_domain_free(domain->domain);
1543 out_free:
1544         kfree(domain);
1545         kfree(group);
1546         mutex_unlock(&iommu->lock);
1547         return ret;
1548 }
1549
1550 static void vfio_iommu_unmap_unpin_all(struct vfio_iommu *iommu)
1551 {
1552         struct rb_node *node;
1553
1554         while ((node = rb_first(&iommu->dma_list)))
1555                 vfio_remove_dma(iommu, rb_entry(node, struct vfio_dma, node));
1556 }
1557
1558 static void vfio_iommu_unmap_unpin_reaccount(struct vfio_iommu *iommu)
1559 {
1560         struct rb_node *n, *p;
1561
1562         n = rb_first(&iommu->dma_list);
1563         for (; n; n = rb_next(n)) {
1564                 struct vfio_dma *dma;
1565                 long locked = 0, unlocked = 0;
1566
1567                 dma = rb_entry(n, struct vfio_dma, node);
1568                 unlocked += vfio_unmap_unpin(iommu, dma, false);
1569                 p = rb_first(&dma->pfn_list);
1570                 for (; p; p = rb_next(p)) {
1571                         struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn,
1572                                                          node);
1573
1574                         if (!is_invalid_reserved_pfn(vpfn->pfn))
1575                                 locked++;
1576                 }
1577                 vfio_lock_acct(dma, locked - unlocked, true);
1578         }
1579 }
1580
1581 static void vfio_sanity_check_pfn_list(struct vfio_iommu *iommu)
1582 {
1583         struct rb_node *n;
1584
1585         n = rb_first(&iommu->dma_list);
1586         for (; n; n = rb_next(n)) {
1587                 struct vfio_dma *dma;
1588
1589                 dma = rb_entry(n, struct vfio_dma, node);
1590
1591                 if (WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list)))
1592                         break;
1593         }
1594         /* mdev vendor driver must unregister notifier */
1595         WARN_ON(iommu->notifier.head);
1596 }
1597
1598 static void vfio_iommu_type1_detach_group(void *iommu_data,
1599                                           struct iommu_group *iommu_group)
1600 {
1601         struct vfio_iommu *iommu = iommu_data;
1602         struct vfio_domain *domain;
1603         struct vfio_group *group;
1604
1605         mutex_lock(&iommu->lock);
1606
1607         if (iommu->external_domain) {
1608                 group = find_iommu_group(iommu->external_domain, iommu_group);
1609                 if (group) {
1610                         list_del(&group->next);
1611                         kfree(group);
1612
1613                         if (list_empty(&iommu->external_domain->group_list)) {
1614                                 vfio_sanity_check_pfn_list(iommu);
1615
1616                                 if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1617                                         vfio_iommu_unmap_unpin_all(iommu);
1618
1619                                 kfree(iommu->external_domain);
1620                                 iommu->external_domain = NULL;
1621                         }
1622                         goto detach_group_done;
1623                 }
1624         }
1625
1626         list_for_each_entry(domain, &iommu->domain_list, next) {
1627                 group = find_iommu_group(domain, iommu_group);
1628                 if (!group)
1629                         continue;
1630
1631                 vfio_iommu_detach_group(domain, group);
1632                 list_del(&group->next);
1633                 kfree(group);
1634                 /*
1635                  * Group ownership provides privilege, if the group list is
1636                  * empty, the domain goes away. If it's the last domain with
1637                  * iommu and external domain doesn't exist, then all the
1638                  * mappings go away too. If it's the last domain with iommu and
1639                  * external domain exist, update accounting
1640                  */
1641                 if (list_empty(&domain->group_list)) {
1642                         if (list_is_singular(&iommu->domain_list)) {
1643                                 if (!iommu->external_domain)
1644                                         vfio_iommu_unmap_unpin_all(iommu);
1645                                 else
1646                                         vfio_iommu_unmap_unpin_reaccount(iommu);
1647                         }
1648                         iommu_domain_free(domain->domain);
1649                         list_del(&domain->next);
1650                         kfree(domain);
1651                 }
1652                 break;
1653         }
1654
1655 detach_group_done:
1656         mutex_unlock(&iommu->lock);
1657 }
1658
1659 static void *vfio_iommu_type1_open(unsigned long arg)
1660 {
1661         struct vfio_iommu *iommu;
1662
1663         iommu = kzalloc(sizeof(*iommu), GFP_KERNEL);
1664         if (!iommu)
1665                 return ERR_PTR(-ENOMEM);
1666
1667         switch (arg) {
1668         case VFIO_TYPE1_IOMMU:
1669                 break;
1670         case VFIO_TYPE1_NESTING_IOMMU:
1671                 iommu->nesting = true;
1672                 /* fall through */
1673         case VFIO_TYPE1v2_IOMMU:
1674                 iommu->v2 = true;
1675                 break;
1676         default:
1677                 kfree(iommu);
1678                 return ERR_PTR(-EINVAL);
1679         }
1680
1681         INIT_LIST_HEAD(&iommu->domain_list);
1682         iommu->dma_list = RB_ROOT;
1683         iommu->dma_avail = dma_entry_limit;
1684         mutex_init(&iommu->lock);
1685         BLOCKING_INIT_NOTIFIER_HEAD(&iommu->notifier);
1686
1687         return iommu;
1688 }
1689
1690 static void vfio_release_domain(struct vfio_domain *domain, bool external)
1691 {
1692         struct vfio_group *group, *group_tmp;
1693
1694         list_for_each_entry_safe(group, group_tmp,
1695                                  &domain->group_list, next) {
1696                 if (!external)
1697                         vfio_iommu_detach_group(domain, group);
1698                 list_del(&group->next);
1699                 kfree(group);
1700         }
1701
1702         if (!external)
1703                 iommu_domain_free(domain->domain);
1704 }
1705
1706 static void vfio_iommu_type1_release(void *iommu_data)
1707 {
1708         struct vfio_iommu *iommu = iommu_data;
1709         struct vfio_domain *domain, *domain_tmp;
1710
1711         if (iommu->external_domain) {
1712                 vfio_release_domain(iommu->external_domain, true);
1713                 vfio_sanity_check_pfn_list(iommu);
1714                 kfree(iommu->external_domain);
1715         }
1716
1717         vfio_iommu_unmap_unpin_all(iommu);
1718
1719         list_for_each_entry_safe(domain, domain_tmp,
1720                                  &iommu->domain_list, next) {
1721                 vfio_release_domain(domain, false);
1722                 list_del(&domain->next);
1723                 kfree(domain);
1724         }
1725         kfree(iommu);
1726 }
1727
1728 static int vfio_domains_have_iommu_cache(struct vfio_iommu *iommu)
1729 {
1730         struct vfio_domain *domain;
1731         int ret = 1;
1732
1733         mutex_lock(&iommu->lock);
1734         list_for_each_entry(domain, &iommu->domain_list, next) {
1735                 if (!(domain->prot & IOMMU_CACHE)) {
1736                         ret = 0;
1737                         break;
1738                 }
1739         }
1740         mutex_unlock(&iommu->lock);
1741
1742         return ret;
1743 }
1744
1745 static long vfio_iommu_type1_ioctl(void *iommu_data,
1746                                    unsigned int cmd, unsigned long arg)
1747 {
1748         struct vfio_iommu *iommu = iommu_data;
1749         unsigned long minsz;
1750
1751         if (cmd == VFIO_CHECK_EXTENSION) {
1752                 switch (arg) {
1753                 case VFIO_TYPE1_IOMMU:
1754                 case VFIO_TYPE1v2_IOMMU:
1755                 case VFIO_TYPE1_NESTING_IOMMU:
1756                         return 1;
1757                 case VFIO_DMA_CC_IOMMU:
1758                         if (!iommu)
1759                                 return 0;
1760                         return vfio_domains_have_iommu_cache(iommu);
1761                 default:
1762                         return 0;
1763                 }
1764         } else if (cmd == VFIO_IOMMU_GET_INFO) {
1765                 struct vfio_iommu_type1_info info;
1766
1767                 minsz = offsetofend(struct vfio_iommu_type1_info, iova_pgsizes);
1768
1769                 if (copy_from_user(&info, (void __user *)arg, minsz))
1770                         return -EFAULT;
1771
1772                 if (info.argsz < minsz)
1773                         return -EINVAL;
1774
1775                 info.flags = VFIO_IOMMU_INFO_PGSIZES;
1776
1777                 info.iova_pgsizes = vfio_pgsize_bitmap(iommu);
1778
1779                 return copy_to_user((void __user *)arg, &info, minsz) ?
1780                         -EFAULT : 0;
1781
1782         } else if (cmd == VFIO_IOMMU_MAP_DMA) {
1783                 struct vfio_iommu_type1_dma_map map;
1784                 uint32_t mask = VFIO_DMA_MAP_FLAG_READ |
1785                                 VFIO_DMA_MAP_FLAG_WRITE;
1786
1787                 minsz = offsetofend(struct vfio_iommu_type1_dma_map, size);
1788
1789                 if (copy_from_user(&map, (void __user *)arg, minsz))
1790                         return -EFAULT;
1791
1792                 if (map.argsz < minsz || map.flags & ~mask)
1793                         return -EINVAL;
1794
1795                 return vfio_dma_do_map(iommu, &map);
1796
1797         } else if (cmd == VFIO_IOMMU_UNMAP_DMA) {
1798                 struct vfio_iommu_type1_dma_unmap unmap;
1799                 long ret;
1800
1801                 minsz = offsetofend(struct vfio_iommu_type1_dma_unmap, size);
1802
1803                 if (copy_from_user(&unmap, (void __user *)arg, minsz))
1804                         return -EFAULT;
1805
1806                 if (unmap.argsz < minsz || unmap.flags)
1807                         return -EINVAL;
1808
1809                 ret = vfio_dma_do_unmap(iommu, &unmap);
1810                 if (ret)
1811                         return ret;
1812
1813                 return copy_to_user((void __user *)arg, &unmap, minsz) ?
1814                         -EFAULT : 0;
1815         }
1816
1817         return -ENOTTY;
1818 }
1819
1820 static int vfio_iommu_type1_register_notifier(void *iommu_data,
1821                                               unsigned long *events,
1822                                               struct notifier_block *nb)
1823 {
1824         struct vfio_iommu *iommu = iommu_data;
1825
1826         /* clear known events */
1827         *events &= ~VFIO_IOMMU_NOTIFY_DMA_UNMAP;
1828
1829         /* refuse to register if still events remaining */
1830         if (*events)
1831                 return -EINVAL;
1832
1833         return blocking_notifier_chain_register(&iommu->notifier, nb);
1834 }
1835
1836 static int vfio_iommu_type1_unregister_notifier(void *iommu_data,
1837                                                 struct notifier_block *nb)
1838 {
1839         struct vfio_iommu *iommu = iommu_data;
1840
1841         return blocking_notifier_chain_unregister(&iommu->notifier, nb);
1842 }
1843
1844 static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
1845         .name                   = "vfio-iommu-type1",
1846         .owner                  = THIS_MODULE,
1847         .open                   = vfio_iommu_type1_open,
1848         .release                = vfio_iommu_type1_release,
1849         .ioctl                  = vfio_iommu_type1_ioctl,
1850         .attach_group           = vfio_iommu_type1_attach_group,
1851         .detach_group           = vfio_iommu_type1_detach_group,
1852         .pin_pages              = vfio_iommu_type1_pin_pages,
1853         .unpin_pages            = vfio_iommu_type1_unpin_pages,
1854         .register_notifier      = vfio_iommu_type1_register_notifier,
1855         .unregister_notifier    = vfio_iommu_type1_unregister_notifier,
1856 };
1857
1858 static int __init vfio_iommu_type1_init(void)
1859 {
1860         return vfio_register_iommu_driver(&vfio_iommu_driver_ops_type1);
1861 }
1862
1863 static void __exit vfio_iommu_type1_cleanup(void)
1864 {
1865         vfio_unregister_iommu_driver(&vfio_iommu_driver_ops_type1);
1866 }
1867
1868 module_init(vfio_iommu_type1_init);
1869 module_exit(vfio_iommu_type1_cleanup);
1870
1871 MODULE_VERSION(DRIVER_VERSION);
1872 MODULE_LICENSE("GPL v2");
1873 MODULE_AUTHOR(DRIVER_AUTHOR);
1874 MODULE_DESCRIPTION(DRIVER_DESC);