RDMA/odp: Consolidate umem_odp initialization
author Jason Gunthorpe <jgg@mellanox.com>
Mon, 19 Aug 2019 11:17:02 +0000 (14:17 +0300)
committer Jason Gunthorpe <jgg@mellanox.com>
Wed, 21 Aug 2019 17:08:42 +0000 (14:08 -0300)
This is done in two different places; consolidate all the post-allocation
initialization into a single function.

Link: https://lore.kernel.org/r/20190819111710.18440-5-leon@kernel.org
Signed-off-by: Leon Romanovsky <leonro@mellanox.com>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
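
For readers skimming the diff below: both ib_alloc_odp_umem() and ib_umem_odp_get()
now delegate their post-allocation setup to the new ib_init_umem_odp() helper. The
following is a minimal, self-contained sketch of the same refactoring pattern in
plain userspace C; the struct and function names (buf, buf_init, buf_get,
buf_alloc_child) are hypothetical illustrations, not kernel code.

/*
 * Hypothetical, self-contained illustration of the pattern -- not kernel
 * code. Two allocation paths share one post-allocation init helper that
 * undoes its own work on failure.
 */
#include <stdio.h>
#include <stdlib.h>

struct buf {
	size_t len;
	unsigned char *data;	/* allocated by the shared init helper */
};

/* Consolidated post-allocation initialization (cf. ib_init_umem_odp()). */
static int buf_init(struct buf *b, size_t len)
{
	if (!len)
		return -1;	/* mirrors the -EINVAL check */

	b->data = calloc(len, 1);
	if (!b->data)
		return -1;	/* mirrors the -ENOMEM path */

	b->len = len;
	return 0;
}

/* First caller (cf. ib_umem_odp_get()). */
static struct buf *buf_get(size_t len)
{
	struct buf *b = calloc(1, sizeof(*b));

	if (!b)
		return NULL;
	if (buf_init(b, len)) {
		free(b);	/* the helper already cleaned up after itself */
		return NULL;
	}
	return b;
}

/* Second caller (cf. ib_alloc_odp_umem()); no duplicated init code here. */
static struct buf *buf_alloc_child(const struct buf *parent, size_t len)
{
	struct buf *b;

	if (!parent)
		return NULL;
	b = calloc(1, sizeof(*b));
	if (!b)
		return NULL;
	if (buf_init(b, len)) {
		free(b);
		return NULL;
	}
	return b;
}

int main(void)
{
	struct buf *a = buf_get(64);
	struct buf *c = a ? buf_alloc_child(a, 32) : NULL;

	printf("a=%zu c=%zu\n", a ? a->len : 0, c ? c->len : 0);
	return 0;
}

The point mirrored here is the error handling: the shared init helper frees
everything it allocated itself, so each caller only has to free the object it
allocated before calling it, just as kfree(odp_data) is the only cleanup left in
ib_alloc_odp_umem() after ib_init_umem_odp() fails in the patch below.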
diff --git a/drivers/infiniband/core/umem_odp.c b/drivers/infiniband/core/umem_odp.c
index 7300d0a10d1ed34c8422ed396bc19660c2b61e02..1643ff374b586a27b75039fc9b90b00a60ed9d93 100644
--- a/drivers/infiniband/core/umem_odp.c
+++ b/drivers/infiniband/core/umem_odp.c
@@ -171,23 +171,6 @@ static const struct mmu_notifier_ops ib_umem_notifiers = {
        .invalidate_range_end       = ib_umem_notifier_invalidate_range_end,
 };
 
-static void add_umem_to_per_mm(struct ib_umem_odp *umem_odp)
-{
-       struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
-
-       down_write(&per_mm->umem_rwsem);
-       /*
-        * Note that the representation of the intervals in the interval tree
-        * considers the ending point as contained in the interval, while the
-        * function ib_umem_end returns the first address which is not
-        * contained in the umem.
-        */
-       umem_odp->interval_tree.start = ib_umem_start(umem_odp);
-       umem_odp->interval_tree.last = ib_umem_end(umem_odp) - 1;
-       interval_tree_insert(&umem_odp->interval_tree, &per_mm->umem_tree);
-       up_write(&per_mm->umem_rwsem);
-}
-
 static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
 {
        struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
@@ -237,33 +220,23 @@ static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
        return ERR_PTR(ret);
 }
 
-static int get_per_mm(struct ib_umem_odp *umem_odp)
+static struct ib_ucontext_per_mm *get_per_mm(struct ib_umem_odp *umem_odp)
 {
        struct ib_ucontext *ctx = umem_odp->umem.context;
        struct ib_ucontext_per_mm *per_mm;
 
+       lockdep_assert_held(&ctx->per_mm_list_lock);
+
        /*
         * Generally speaking we expect only one or two per_mm in this list,
         * so no reason to optimize this search today.
         */
-       mutex_lock(&ctx->per_mm_list_lock);
        list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) {
                if (per_mm->mm == umem_odp->umem.owning_mm)
-                       goto found;
-       }
-
-       per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm);
-       if (IS_ERR(per_mm)) {
-               mutex_unlock(&ctx->per_mm_list_lock);
-               return PTR_ERR(per_mm);
+                       return per_mm;
        }
 
-found:
-       umem_odp->per_mm = per_mm;
-       per_mm->odp_mrs_count++;
-       mutex_unlock(&ctx->per_mm_list_lock);
-
-       return 0;
+       return alloc_per_mm(ctx, umem_odp->umem.owning_mm);
 }
 
 static void free_per_mm(struct rcu_head *rcu)
@@ -304,79 +277,114 @@ static void put_per_mm(struct ib_umem_odp *umem_odp)
        mmu_notifier_call_srcu(&per_mm->rcu, free_per_mm);
 }
 
+static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp,
+                                  struct ib_ucontext_per_mm *per_mm)
+{
+       struct ib_ucontext *ctx = umem_odp->umem.context;
+       int ret;
+
+       umem_odp->umem.is_odp = 1;
+       if (!umem_odp->is_implicit_odp) {
+               size_t pages = ib_umem_odp_num_pages(umem_odp);
+
+               if (!pages)
+                       return -EINVAL;
+
+               /*
+                * Note that the representation of the intervals in the
+                * interval tree considers the ending point as contained in
+                * the interval, while the function ib_umem_end returns the
+                * first address which is not contained in the umem.
+                */
+               umem_odp->interval_tree.start = ib_umem_start(umem_odp);
+               umem_odp->interval_tree.last = ib_umem_end(umem_odp) - 1;
+
+               umem_odp->page_list = vzalloc(
+                       array_size(sizeof(*umem_odp->page_list), pages));
+               if (!umem_odp->page_list)
+                       return -ENOMEM;
+
+               umem_odp->dma_list =
+                       vzalloc(array_size(sizeof(*umem_odp->dma_list), pages));
+               if (!umem_odp->dma_list) {
+                       ret = -ENOMEM;
+                       goto out_page_list;
+               }
+       }
+
+       mutex_lock(&ctx->per_mm_list_lock);
+       if (!per_mm) {
+               per_mm = get_per_mm(umem_odp);
+               if (IS_ERR(per_mm)) {
+                       ret = PTR_ERR(per_mm);
+                       goto out_unlock;
+               }
+       }
+       umem_odp->per_mm = per_mm;
+       per_mm->odp_mrs_count++;
+       mutex_unlock(&ctx->per_mm_list_lock);
+
+       mutex_init(&umem_odp->umem_mutex);
+       init_completion(&umem_odp->notifier_completion);
+
+       if (!umem_odp->is_implicit_odp) {
+               down_write(&per_mm->umem_rwsem);
+               interval_tree_insert(&umem_odp->interval_tree,
+                                    &per_mm->umem_tree);
+               up_write(&per_mm->umem_rwsem);
+       }
+
+       return 0;
+
+out_unlock:
+       mutex_unlock(&ctx->per_mm_list_lock);
+       vfree(umem_odp->dma_list);
+out_page_list:
+       vfree(umem_odp->page_list);
+       return ret;
+}
+
 struct ib_umem_odp *ib_alloc_odp_umem(struct ib_umem_odp *root,
                                      unsigned long addr, size_t size)
 {
-       struct ib_ucontext_per_mm *per_mm = root->per_mm;
-       struct ib_ucontext *ctx = per_mm->context;
+       /*
+        * Caller must ensure that root cannot be freed during the call to
+        * ib_alloc_odp_umem.
+        */
        struct ib_umem_odp *odp_data;
        struct ib_umem *umem;
-       int pages = size >> PAGE_SHIFT;
        int ret;
 
-       if (!size)
-               return ERR_PTR(-EINVAL);
-
        odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
        if (!odp_data)
                return ERR_PTR(-ENOMEM);
        umem = &odp_data->umem;
-       umem->context    = ctx;
+       umem->context    = root->umem.context;
        umem->length     = size;
        umem->address    = addr;
-       odp_data->page_shift = PAGE_SHIFT;
        umem->writable   = root->umem.writable;
-       umem->is_odp = 1;
-       odp_data->per_mm = per_mm;
-       umem->owning_mm  = per_mm->mm;
-       mmgrab(umem->owning_mm);
-
-       mutex_init(&odp_data->umem_mutex);
-       init_completion(&odp_data->notifier_completion);
-
-       odp_data->page_list =
-               vzalloc(array_size(pages, sizeof(*odp_data->page_list)));
-       if (!odp_data->page_list) {
-               ret = -ENOMEM;
-               goto out_odp_data;
-       }
+       umem->owning_mm  = root->umem.owning_mm;
+       odp_data->page_shift = PAGE_SHIFT;
 
-       odp_data->dma_list =
-               vzalloc(array_size(pages, sizeof(*odp_data->dma_list)));
-       if (!odp_data->dma_list) {
-               ret = -ENOMEM;
-               goto out_page_list;
+       ret = ib_init_umem_odp(odp_data, root->per_mm);
+       if (ret) {
+               kfree(odp_data);
+               return ERR_PTR(ret);
        }
 
-       /*
-        * Caller must ensure that the umem_odp that the per_mm came from
-        * cannot be freed during the call to ib_alloc_odp_umem.
-        */
-       mutex_lock(&ctx->per_mm_list_lock);
-       per_mm->odp_mrs_count++;
-       mutex_unlock(&ctx->per_mm_list_lock);
-       add_umem_to_per_mm(odp_data);
+       mmgrab(umem->owning_mm);
 
        return odp_data;
-
-out_page_list:
-       vfree(odp_data->page_list);
-out_odp_data:
-       mmdrop(umem->owning_mm);
-       kfree(odp_data);
-       return ERR_PTR(ret);
 }
 EXPORT_SYMBOL(ib_alloc_odp_umem);
 
 int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
 {
-       struct ib_umem *umem = &umem_odp->umem;
        /*
         * NOTE: This must called in a process context where umem->owning_mm
         * == current->mm
         */
-       struct mm_struct *mm = umem->owning_mm;
-       int ret_val;
+       struct mm_struct *mm = umem_odp->umem.owning_mm;
 
        if (umem_odp->umem.address == 0 && umem_odp->umem.length == 0)
                umem_odp->is_implicit_odp = 1;
@@ -397,43 +405,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
                up_read(&mm->mmap_sem);
        }
 
-       mutex_init(&umem_odp->umem_mutex);
-
-       init_completion(&umem_odp->notifier_completion);
-
-       if (!umem_odp->is_implicit_odp) {
-               if (!ib_umem_odp_num_pages(umem_odp))
-                       return -EINVAL;
-
-               umem_odp->page_list =
-                       vzalloc(array_size(sizeof(*umem_odp->page_list),
-                                          ib_umem_odp_num_pages(umem_odp)));
-               if (!umem_odp->page_list)
-                       return -ENOMEM;
-
-               umem_odp->dma_list =
-                       vzalloc(array_size(sizeof(*umem_odp->dma_list),
-                                          ib_umem_odp_num_pages(umem_odp)));
-               if (!umem_odp->dma_list) {
-                       ret_val = -ENOMEM;
-                       goto out_page_list;
-               }
-       }
-
-       ret_val = get_per_mm(umem_odp);
-       if (ret_val)
-               goto out_dma_list;
-
-       if (!umem_odp->is_implicit_odp)
-               add_umem_to_per_mm(umem_odp);
-
-       return 0;
-
-out_dma_list:
-       vfree(umem_odp->dma_list);
-out_page_list:
-       vfree(umem_odp->page_list);
-       return ret_val;
+       return ib_init_umem_odp(umem_odp, NULL);
 }
 
 void ib_umem_odp_release(struct ib_umem_odp *umem_odp)