]> asedeno.scripts.mit.edu Git - linux.git/blobdiff - mm/hmm.c
mm/hmm: improve and rename hmm_vma_get_pfns() to hmm_range_snapshot()
[linux.git] / mm / hmm.c
index fe1cd87e49acc94641eaf7178dc07e5c4306e408..bd957a9f10d1391268b4f3be88df572645cb23cc 100644 (file)
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -50,6 +50,7 @@ static const struct mmu_notifier_ops hmm_mmu_notifier_ops;
  */
 struct hmm {
        struct mm_struct        *mm;
+       struct kref             kref;
        spinlock_t              lock;
        struct list_head        ranges;
        struct list_head        mirrors;
@@ -57,24 +58,33 @@ struct hmm {
        struct rw_semaphore     mirrors_sem;
 };
 
-/*
- * hmm_register - register HMM against an mm (HMM internal)
+static inline struct hmm *mm_get_hmm(struct mm_struct *mm)
+{
+       struct hmm *hmm = READ_ONCE(mm->hmm);
+
+       if (hmm && kref_get_unless_zero(&hmm->kref))
+               return hmm;
+
+       return NULL;
+}
+
+/**
+ * hmm_get_or_create - register HMM against an mm (HMM internal)
  *
  * @mm: mm struct to attach to
+ * Returns: returns an HMM object, either by referencing the existing
+ *          (per-process) object, or by creating a new one.
  *
- * This is not intended to be used directly by device drivers. It allocates an
- * HMM struct if mm does not have one, and initializes it.
+ * This is not intended to be used directly by device drivers. If mm already
+ * has an HMM struct then it get a reference on it and returns it. Otherwise
+ * it allocates an HMM struct, initializes it, associate it with the mm and
+ * returns it.
  */
-static struct hmm *hmm_register(struct mm_struct *mm)
+static struct hmm *hmm_get_or_create(struct mm_struct *mm)
 {
-       struct hmm *hmm = READ_ONCE(mm->hmm);
+       struct hmm *hmm = mm_get_hmm(mm);
        bool cleanup = false;
 
-       /*
-        * The hmm struct can only be freed once the mm_struct goes away,
-        * hence we should always have pre-allocated an new hmm struct
-        * above.
-        */
        if (hmm)
                return hmm;
 
@@ -86,6 +96,7 @@ static struct hmm *hmm_register(struct mm_struct *mm)
        hmm->mmu_notifier.ops = NULL;
        INIT_LIST_HEAD(&hmm->ranges);
        spin_lock_init(&hmm->lock);
+       kref_init(&hmm->kref);
        hmm->mm = mm;
 
        spin_lock(&mm->page_table_lock);
@@ -106,7 +117,7 @@ static struct hmm *hmm_register(struct mm_struct *mm)
        if (__mmu_notifier_register(&hmm->mmu_notifier, mm))
                goto error_mm;
 
-       return mm->hmm;
+       return hmm;
 
 error_mm:
        spin_lock(&mm->page_table_lock);
@@ -118,9 +129,41 @@ static struct hmm *hmm_register(struct mm_struct *mm)
        return NULL;
 }
 
+static void hmm_free(struct kref *kref)
+{
+       struct hmm *hmm = container_of(kref, struct hmm, kref);
+       struct mm_struct *mm = hmm->mm;
+
+       mmu_notifier_unregister_no_release(&hmm->mmu_notifier, mm);
+
+       spin_lock(&mm->page_table_lock);
+       if (mm->hmm == hmm)
+               mm->hmm = NULL;
+       spin_unlock(&mm->page_table_lock);
+
+       kfree(hmm);
+}
+
+static inline void hmm_put(struct hmm *hmm)
+{
+       kref_put(&hmm->kref, hmm_free);
+}
+
 void hmm_mm_destroy(struct mm_struct *mm)
 {
-       kfree(mm->hmm);
+       struct hmm *hmm;
+
+       spin_lock(&mm->page_table_lock);
+       hmm = mm_get_hmm(mm);
+       mm->hmm = NULL;
+       if (hmm) {
+               hmm->mm = NULL;
+               spin_unlock(&mm->page_table_lock);
+               hmm_put(hmm);
+               return;
+       }
+
+       spin_unlock(&mm->page_table_lock);
 }
 
 static int hmm_invalidate_range(struct hmm *hmm, bool device,
@@ -131,16 +174,10 @@ static int hmm_invalidate_range(struct hmm *hmm, bool device,
 
        spin_lock(&hmm->lock);
        list_for_each_entry(range, &hmm->ranges, list) {
-               unsigned long addr, idx, npages;
-
                if (update->end < range->start || update->start >= range->end)
                        continue;
 
                range->valid = false;
-               addr = max(update->start, range->start);
-               idx = (addr - range->start) >> PAGE_SHIFT;
-               npages = (min(range->end, update->end) - addr) >> PAGE_SHIFT;
-               memset(&range->pfns[idx], 0, sizeof(*range->pfns) * npages);
        }
        spin_unlock(&hmm->lock);
 
@@ -165,7 +202,7 @@ static int hmm_invalidate_range(struct hmm *hmm, bool device,
 static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
 {
        struct hmm_mirror *mirror;
-       struct hmm *hmm = mm->hmm;
+       struct hmm *hmm = mm_get_hmm(mm);
 
        down_write(&hmm->mirrors_sem);
        mirror = list_first_entry_or_null(&hmm->mirrors, struct hmm_mirror,
@@ -186,13 +223,16 @@ static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
                                                  struct hmm_mirror, list);
        }
        up_write(&hmm->mirrors_sem);
+
+       hmm_put(hmm);
 }
 
 static int hmm_invalidate_range_start(struct mmu_notifier *mn,
                        const struct mmu_notifier_range *range)
 {
+       struct hmm *hmm = mm_get_hmm(range->mm);
        struct hmm_update update;
-       struct hmm *hmm = range->mm->hmm;
+       int ret;
 
        VM_BUG_ON(!hmm);
 
@@ -200,14 +240,16 @@ static int hmm_invalidate_range_start(struct mmu_notifier *mn,
        update.end = range->end;
        update.event = HMM_UPDATE_INVALIDATE;
        update.blockable = range->blockable;
-       return hmm_invalidate_range(hmm, true, &update);
+       ret = hmm_invalidate_range(hmm, true, &update);
+       hmm_put(hmm);
+       return ret;
 }
 
 static void hmm_invalidate_range_end(struct mmu_notifier *mn,
                        const struct mmu_notifier_range *range)
 {
+       struct hmm *hmm = mm_get_hmm(range->mm);
        struct hmm_update update;
-       struct hmm *hmm = range->mm->hmm;
 
        VM_BUG_ON(!hmm);
 
@@ -216,6 +258,7 @@ static void hmm_invalidate_range_end(struct mmu_notifier *mn,
        update.event = HMM_UPDATE_INVALIDATE;
        update.blockable = true;
        hmm_invalidate_range(hmm, false, &update);
+       hmm_put(hmm);
 }
 
 static const struct mmu_notifier_ops hmm_mmu_notifier_ops = {
@@ -241,24 +284,13 @@ int hmm_mirror_register(struct hmm_mirror *mirror, struct mm_struct *mm)
        if (!mm || !mirror || !mirror->ops)
                return -EINVAL;
 
-again:
-       mirror->hmm = hmm_register(mm);
+       mirror->hmm = hmm_get_or_create(mm);
        if (!mirror->hmm)
                return -ENOMEM;
 
        down_write(&mirror->hmm->mirrors_sem);
-       if (mirror->hmm->mm == NULL) {
-               /*
-                * A racing hmm_mirror_unregister() is about to destroy the hmm
-                * struct. Try again to allocate a new one.
-                */
-               up_write(&mirror->hmm->mirrors_sem);
-               mirror->hmm = NULL;
-               goto again;
-       } else {
-               list_add(&mirror->list, &mirror->hmm->mirrors);
-               up_write(&mirror->hmm->mirrors_sem);
-       }
+       list_add(&mirror->list, &mirror->hmm->mirrors);
+       up_write(&mirror->hmm->mirrors_sem);
 
        return 0;
 }
@@ -273,33 +305,18 @@ EXPORT_SYMBOL(hmm_mirror_register);
  */
 void hmm_mirror_unregister(struct hmm_mirror *mirror)
 {
-       bool should_unregister = false;
-       struct mm_struct *mm;
-       struct hmm *hmm;
+       struct hmm *hmm = READ_ONCE(mirror->hmm);
 
-       if (mirror->hmm == NULL)
+       if (hmm == NULL)
                return;
 
-       hmm = mirror->hmm;
        down_write(&hmm->mirrors_sem);
        list_del_init(&mirror->list);
-       should_unregister = list_empty(&hmm->mirrors);
+       /* To protect us against double unregister ... */
        mirror->hmm = NULL;
-       mm = hmm->mm;
-       hmm->mm = NULL;
        up_write(&hmm->mirrors_sem);
 
-       if (!should_unregister || mm == NULL)
-               return;
-
-       mmu_notifier_unregister_no_release(&hmm->mmu_notifier, mm);
-
-       spin_lock(&mm->page_table_lock);
-       if (mm->hmm == hmm)
-               mm->hmm = NULL;
-       spin_unlock(&mm->page_table_lock);
-
-       kfree(hmm);
+       hmm_put(hmm);
 }
 EXPORT_SYMBOL(hmm_mirror_unregister);
 
@@ -685,46 +702,54 @@ static void hmm_pfns_special(struct hmm_range *range)
 }
 
 /*
- * hmm_vma_get_pfns() - snapshot CPU page table for a range of virtual addresses
- * @range: range being snapshotted
- * Returns: -EINVAL if invalid argument, -ENOMEM out of memory, -EPERM invalid
- *          vma permission, 0 success
+ * hmm_range_snapshot() - snapshot CPU page table for a range
+ * @range: range
+ * Returns: number of valid pages in range->pfns[] (from range start
+ *          address). This may be zero. If the return value is negative,
+ *          then one of the following values may be returned:
+ *
+ *           -EINVAL  invalid arguments or mm or virtual address are in an
+ *                    invalid vma (ie either hugetlbfs or device file vma).
+ *           -EPERM   For example, asking for write, when the range is
+ *                    read-only
+ *           -EAGAIN  Caller needs to retry
+ *           -EFAULT  Either no valid vma exists for this range, or it is
+ *                    illegal to access the range
  *
  * This snapshots the CPU page table for a range of virtual addresses. Snapshot
  * validity is tracked by range struct. See hmm_vma_range_done() for further
  * information.
- *
- * The range struct is initialized here. It tracks the CPU page table, but only
- * if the function returns success (0), in which case the caller must then call
- * hmm_vma_range_done() to stop CPU page table update tracking on this range.
- *
- * NOT CALLING hmm_vma_range_done() IF FUNCTION RETURNS 0 WILL LEAD TO SERIOUS
- * MEMORY CORRUPTION ! YOU HAVE BEEN WARNED !
  */
-int hmm_vma_get_pfns(struct hmm_range *range)
+long hmm_range_snapshot(struct hmm_range *range)
 {
        struct vm_area_struct *vma = range->vma;
        struct hmm_vma_walk hmm_vma_walk;
        struct mm_walk mm_walk;
        struct hmm *hmm;
 
+       range->hmm = NULL;
+
        /* Sanity check, this really should not happen ! */
        if (range->start < vma->vm_start || range->start >= vma->vm_end)
                return -EINVAL;
        if (range->end < vma->vm_start || range->end > vma->vm_end)
                return -EINVAL;
 
-       hmm = hmm_register(vma->vm_mm);
+       hmm = hmm_get_or_create(vma->vm_mm);
        if (!hmm)
                return -ENOMEM;
-       /* Caller must have registered a mirror, via hmm_mirror_register() ! */
-       if (!hmm->mmu_notifier.ops)
+
+       /* Check if hmm_mm_destroy() was call. */
+       if (hmm->mm == NULL) {
+               hmm_put(hmm);
                return -EINVAL;
+       }
 
        /* FIXME support hugetlb fs */
        if (is_vm_hugetlb_page(vma) || (vma->vm_flags & VM_SPECIAL) ||
                        vma_is_dax(vma)) {
                hmm_pfns_special(range);
+               hmm_put(hmm);
                return -EINVAL;
        }
 
@@ -736,6 +761,7 @@ int hmm_vma_get_pfns(struct hmm_range *range)
                 * operations such has atomic access would not work.
                 */
                hmm_pfns_clear(range, range->pfns, range->start, range->end);
+               hmm_put(hmm);
                return -EPERM;
        }
 
@@ -748,6 +774,7 @@ int hmm_vma_get_pfns(struct hmm_range *range)
        hmm_vma_walk.fault = false;
        hmm_vma_walk.range = range;
        mm_walk.private = &hmm_vma_walk;
+       hmm_vma_walk.last = range->start;
 
        mm_walk.vma = vma;
        mm_walk.mm = vma->vm_mm;
@@ -758,9 +785,15 @@ int hmm_vma_get_pfns(struct hmm_range *range)
        mm_walk.pte_hole = hmm_vma_walk_hole;
 
        walk_page_range(range->start, range->end, &mm_walk);
-       return 0;
+       /*
+        * Transfer hmm reference to the range struct it will be drop inside
+        * the hmm_vma_range_done() function (which _must_ be call if this
+        * function return 0).
+        */
+       range->hmm = hmm;
+       return (hmm_vma_walk.last - range->start) >> PAGE_SHIFT;
 }
-EXPORT_SYMBOL(hmm_vma_get_pfns);
+EXPORT_SYMBOL(hmm_range_snapshot);
 
 /*
  * hmm_vma_range_done() - stop tracking change to CPU page table over a range
@@ -802,25 +835,27 @@ EXPORT_SYMBOL(hmm_vma_get_pfns);
  */
 bool hmm_vma_range_done(struct hmm_range *range)
 {
-       unsigned long npages = (range->end - range->start) >> PAGE_SHIFT;
-       struct hmm *hmm;
+       bool ret = false;
 
-       if (range->end <= range->start) {
+       /* Sanity check this really should not happen. */
+       if (range->hmm == NULL || range->end <= range->start) {
                BUG();
                return false;
        }
 
-       hmm = hmm_register(range->vma->vm_mm);
-       if (!hmm) {
-               memset(range->pfns, 0, sizeof(*range->pfns) * npages);
-               return false;
-       }
-
-       spin_lock(&hmm->lock);
+       spin_lock(&range->hmm->lock);
        list_del_rcu(&range->list);
-       spin_unlock(&hmm->lock);
+       ret = range->valid;
+       spin_unlock(&range->hmm->lock);
+
+       /* Is the mm still alive ? */
+       if (range->hmm->mm == NULL)
+               ret = false;
 
-       return range->valid;
+       /* Drop reference taken by hmm_vma_fault() or hmm_vma_get_pfns() */
+       hmm_put(range->hmm);
+       range->hmm = NULL;
+       return ret;
 }
 EXPORT_SYMBOL(hmm_vma_range_done);
 
@@ -880,25 +915,31 @@ int hmm_vma_fault(struct hmm_range *range, bool block)
        struct hmm *hmm;
        int ret;
 
+       range->hmm = NULL;
+
        /* Sanity check, this really should not happen ! */
        if (range->start < vma->vm_start || range->start >= vma->vm_end)
                return -EINVAL;
        if (range->end < vma->vm_start || range->end > vma->vm_end)
                return -EINVAL;
 
-       hmm = hmm_register(vma->vm_mm);
+       hmm = hmm_get_or_create(vma->vm_mm);
        if (!hmm) {
                hmm_pfns_clear(range, range->pfns, range->start, range->end);
                return -ENOMEM;
        }
-       /* Caller must have registered a mirror using hmm_mirror_register() */
-       if (!hmm->mmu_notifier.ops)
+
+       /* Check if hmm_mm_destroy() was call. */
+       if (hmm->mm == NULL) {
+               hmm_put(hmm);
                return -EINVAL;
+       }
 
        /* FIXME support hugetlb fs */
        if (is_vm_hugetlb_page(vma) || (vma->vm_flags & VM_SPECIAL) ||
                        vma_is_dax(vma)) {
                hmm_pfns_special(range);
+               hmm_put(hmm);
                return -EINVAL;
        }
 
@@ -910,6 +951,7 @@ int hmm_vma_fault(struct hmm_range *range, bool block)
                 * operations such has atomic access would not work.
                 */
                hmm_pfns_clear(range, range->pfns, range->start, range->end);
+               hmm_put(hmm);
                return -EPERM;
        }
 
@@ -945,7 +987,16 @@ int hmm_vma_fault(struct hmm_range *range, bool block)
                hmm_pfns_clear(range, &range->pfns[i], hmm_vma_walk.last,
                               range->end);
                hmm_vma_range_done(range);
+               hmm_put(hmm);
+       } else {
+               /*
+                * Transfer hmm reference to the range struct it will be drop
+                * inside the hmm_vma_range_done() function (which _must_ be
+                * call if this function return 0).
+                */
+               range->hmm = hmm;
        }
+
        return ret;
 }
 EXPORT_SYMBOL(hmm_vma_fault);