Commit c571feca authored by Jason Gunthorpe

RDMA/odp: use mmu_notifier_get/put for 'struct ib_ucontext_per_mm'

This is a significant simplification: no extra list is kept per FD, and
the interval tree is now shared between all ucontexts, reducing
overhead when multiple ucontexts are active.
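
For reference, mmu_notifier_get() keeps one refcounted notifier per
(ops, mm) pair: the first caller triggers ops->alloc_notifier(), later
callers for the same mm get the same notifier back, and each
mmu_notifier_put() drops a reference, with ops->free_notifier() running
(via SRCU) once the last reference is gone. A rough sketch of the
pattern, using hypothetical my_notifier/my_mn_ops names that are not
part of this patch:

	#include <linux/err.h>
	#include <linux/mmu_notifier.h>
	#include <linux/sched.h>
	#include <linux/slab.h>

	/* Illustrative only; not code from this patch. */
	struct my_notifier {
		struct mmu_notifier mn;		/* must be embedded */
		/* per-mm driver state goes here */
	};

	static struct mmu_notifier *my_alloc_notifier(struct mm_struct *mm)
	{
		struct my_notifier *p = kzalloc(sizeof(*p), GFP_KERNEL);

		if (!p)
			return ERR_PTR(-ENOMEM);
		/* initialize per-mm state; the core links &p->mn to the mm */
		return &p->mn;
	}

	static void my_free_notifier(struct mmu_notifier *mn)
	{
		/* runs via SRCU only after the last mmu_notifier_put() */
		kfree(container_of(mn, struct my_notifier, mn));
	}

	static const struct mmu_notifier_ops my_mn_ops = {
		.alloc_notifier	= my_alloc_notifier,
		.free_notifier	= my_free_notifier,
		/* .release/.invalidate_range_start/_end as the driver needs */
	};

	static struct my_notifier *my_get(void)
	{
		/* first call for current->mm allocates, later calls share it */
		struct mmu_notifier *mn = mmu_notifier_get(&my_mn_ops, current->mm);

		if (IS_ERR(mn))
			return ERR_CAST(mn);
		return container_of(mn, struct my_notifier, mn);
	}

	static void my_put(struct my_notifier *p)
	{
		mmu_notifier_put(&p->mn);	/* drop one reference */
	}

Because free_notifier() runs asynchronously after the final put, a
module using this interface must also call mmu_notifier_synchronize()
on unload before its ops structure can go away, which is why
ib_uverbs_cleanup() below gains such a call.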

Link: https://lore.kernel.org/r/20190806231548.25242-7-jgg@ziepe.ca


Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
parent daa138a5
drivers/infiniband/core/umem_odp.c  +53 −115
@@ -82,7 +82,7 @@ static void ib_umem_notifier_release(struct mmu_notifier *mn,
	struct rb_node *node;

	down_read(&per_mm->umem_rwsem);
-	if (!per_mm->active)
+	if (!per_mm->mn.users)
		goto out;

	for (node = rb_first_cached(&per_mm->umem_tree); node;
@@ -125,10 +125,10 @@ static int ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
	else if (!down_read_trylock(&per_mm->umem_rwsem))
		return -EAGAIN;

-	if (!per_mm->active) {
+	if (!per_mm->mn.users) {
		up_read(&per_mm->umem_rwsem);
		/*
-		 * At this point active is permanently set and visible to this
+		 * At this point users is permanently zero and visible to this
		 * CPU without a lock, that fact is relied on to skip the unlock
		 * in range_end.
		 */
@@ -158,7 +158,7 @@ static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
	struct ib_ucontext_per_mm *per_mm =
		container_of(mn, struct ib_ucontext_per_mm, mn);

-	if (unlikely(!per_mm->active))
+	if (unlikely(!per_mm->mn.users))
		return;

	rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, range->start,
@@ -167,122 +167,47 @@ static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
	up_read(&per_mm->umem_rwsem);
}

-static const struct mmu_notifier_ops ib_umem_notifiers = {
-	.release                    = ib_umem_notifier_release,
-	.invalidate_range_start     = ib_umem_notifier_invalidate_range_start,
-	.invalidate_range_end       = ib_umem_notifier_invalidate_range_end,
-};
-
-static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
-{
-	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
-
-	down_write(&per_mm->umem_rwsem);
-	interval_tree_remove(&umem_odp->interval_tree, &per_mm->umem_tree);
-	complete_all(&umem_odp->notifier_completion);
-	up_write(&per_mm->umem_rwsem);
-}
-
-static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
-					       struct mm_struct *mm)
+static struct mmu_notifier *ib_umem_alloc_notifier(struct mm_struct *mm)
{
	struct ib_ucontext_per_mm *per_mm;
-	int ret;

	per_mm = kzalloc(sizeof(*per_mm), GFP_KERNEL);
	if (!per_mm)
		return ERR_PTR(-ENOMEM);

-	per_mm->context = ctx;
-	per_mm->mm = mm;
	per_mm->umem_tree = RB_ROOT_CACHED;
	init_rwsem(&per_mm->umem_rwsem);
-	per_mm->active = true;

+	WARN_ON(mm != current->mm);
	rcu_read_lock();
	per_mm->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
	rcu_read_unlock();
-
-	WARN_ON(mm != current->mm);
-
-	per_mm->mn.ops = &ib_umem_notifiers;
-	ret = mmu_notifier_register(&per_mm->mn, per_mm->mm);
-	if (ret) {
-		dev_err(&ctx->device->dev,
-			"Failed to register mmu_notifier %d\n", ret);
-		goto out_pid;
-	}
-
-	list_add(&per_mm->ucontext_list, &ctx->per_mm_list);
-	return per_mm;
-
-out_pid:
-	put_pid(per_mm->tgid);
-	kfree(per_mm);
-	return ERR_PTR(ret);
-}
-
-static struct ib_ucontext_per_mm *get_per_mm(struct ib_umem_odp *umem_odp)
-{
-	struct ib_ucontext *ctx = umem_odp->umem.context;
-	struct ib_ucontext_per_mm *per_mm;
-
-	lockdep_assert_held(&ctx->per_mm_list_lock);
-
-	/*
-	 * Generally speaking we expect only one or two per_mm in this list,
-	 * so no reason to optimize this search today.
-	 */
-	list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) {
-		if (per_mm->mm == umem_odp->umem.owning_mm)
-			return per_mm;
-	}
-
-	return alloc_per_mm(ctx, umem_odp->umem.owning_mm);
-}
-
-static void free_per_mm(struct rcu_head *rcu)
-{
-	kfree(container_of(rcu, struct ib_ucontext_per_mm, rcu));
+	return &per_mm->mn;
}

-static void put_per_mm(struct ib_umem_odp *umem_odp)
+static void ib_umem_free_notifier(struct mmu_notifier *mn)
{
-	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
-	struct ib_ucontext *ctx = umem_odp->umem.context;
-	bool need_free;
-
-	mutex_lock(&ctx->per_mm_list_lock);
-	umem_odp->per_mm = NULL;
-	per_mm->odp_mrs_count--;
-	need_free = per_mm->odp_mrs_count == 0;
-	if (need_free)
-		list_del(&per_mm->ucontext_list);
-	mutex_unlock(&ctx->per_mm_list_lock);
-
-	if (!need_free)
-		return;
-
-	/*
-	 * NOTE! mmu_notifier_unregister() can happen between a start/end
-	 * callback, resulting in an start/end, and thus an unbalanced
-	 * lock. This doesn't really matter to us since we are about to kfree
-	 * the memory that holds the lock, however LOCKDEP doesn't like this.
-	 */
-	down_write(&per_mm->umem_rwsem);
-	per_mm->active = false;
-	up_write(&per_mm->umem_rwsem);
+	struct ib_ucontext_per_mm *per_mm =
+		container_of(mn, struct ib_ucontext_per_mm, mn);

	WARN_ON(!RB_EMPTY_ROOT(&per_mm->umem_tree.rb_root));
-	mmu_notifier_unregister_no_release(&per_mm->mn, per_mm->mm);

	put_pid(per_mm->tgid);
-	mmu_notifier_call_srcu(&per_mm->rcu, free_per_mm);
+	kfree(per_mm);
}

-static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp,
-				   struct ib_ucontext_per_mm *per_mm)
+static const struct mmu_notifier_ops ib_umem_notifiers = {
+	.release                    = ib_umem_notifier_release,
+	.invalidate_range_start     = ib_umem_notifier_invalidate_range_start,
+	.invalidate_range_end       = ib_umem_notifier_invalidate_range_end,
+	.alloc_notifier		    = ib_umem_alloc_notifier,
+	.free_notifier		    = ib_umem_free_notifier,
+};
+
+static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp)
{
-	struct ib_ucontext *ctx = umem_odp->umem.context;
+	struct ib_ucontext_per_mm *per_mm;
+	struct mmu_notifier *mn;
	int ret;

	umem_odp->umem.is_odp = 1;
@@ -327,17 +252,13 @@ static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp,
		}
	}

-	mutex_lock(&ctx->per_mm_list_lock);
-	if (!per_mm) {
-		per_mm = get_per_mm(umem_odp);
-		if (IS_ERR(per_mm)) {
-			ret = PTR_ERR(per_mm);
-			goto out_unlock;
-		}
+	mn = mmu_notifier_get(&ib_umem_notifiers, umem_odp->umem.owning_mm);
+	if (IS_ERR(mn)) {
+		ret = PTR_ERR(mn);
+		goto out_dma_list;
	}
-	umem_odp->per_mm = per_mm;
-	per_mm->odp_mrs_count++;
-	mutex_unlock(&ctx->per_mm_list_lock);
+	umem_odp->per_mm = per_mm =
+		container_of(mn, struct ib_ucontext_per_mm, mn);

	mutex_init(&umem_odp->umem_mutex);
	init_completion(&umem_odp->notifier_completion);
@@ -352,8 +273,7 @@ static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp,

	return 0;

-out_unlock:
-	mutex_unlock(&ctx->per_mm_list_lock);
+out_dma_list:
	kvfree(umem_odp->dma_list);
out_page_list:
	kvfree(umem_odp->page_list);
@@ -398,7 +318,7 @@ struct ib_umem_odp *ib_umem_odp_alloc_implicit(struct ib_udata *udata,
	umem_odp->is_implicit_odp = 1;
	umem_odp->page_shift = PAGE_SHIFT;

-	ret = ib_init_umem_odp(umem_odp, NULL);
+	ret = ib_init_umem_odp(umem_odp);
	if (ret) {
		kfree(umem_odp);
		return ERR_PTR(ret);
@@ -441,7 +361,7 @@ struct ib_umem_odp *ib_umem_odp_alloc_child(struct ib_umem_odp *root,
	umem->owning_mm  = root->umem.owning_mm;
	odp_data->page_shift = PAGE_SHIFT;

-	ret = ib_init_umem_odp(odp_data, root->per_mm);
+	ret = ib_init_umem_odp(odp_data);
	if (ret) {
		kfree(odp_data);
		return ERR_PTR(ret);
@@ -509,7 +429,7 @@ struct ib_umem_odp *ib_umem_odp_get(struct ib_udata *udata, unsigned long addr,
		up_read(&mm->mmap_sem);
	}

-	ret = ib_init_umem_odp(umem_odp, NULL);
+	ret = ib_init_umem_odp(umem_odp);
	if (ret)
		goto err_free;
	return umem_odp;
@@ -522,6 +442,8 @@ EXPORT_SYMBOL(ib_umem_odp_get);

void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
{
+	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
+
	/*
	 * Ensure that no more pages are mapped in the umem.
	 *
@@ -531,11 +453,27 @@ void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
	if (!umem_odp->is_implicit_odp) {
		ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem_odp),
					    ib_umem_end(umem_odp));
-		remove_umem_from_per_mm(umem_odp);
		kvfree(umem_odp->dma_list);
		kvfree(umem_odp->page_list);
	}
-	put_per_mm(umem_odp);
+
+	down_write(&per_mm->umem_rwsem);
+	if (!umem_odp->is_implicit_odp) {
+		interval_tree_remove(&umem_odp->interval_tree,
+				     &per_mm->umem_tree);
+		complete_all(&umem_odp->notifier_completion);
+	}
+	/*
+	 * NOTE! mmu_notifier_unregister() can happen between a start/end
+	 * callback, resulting in a missing end, and thus an unbalanced
+	 * lock. This doesn't really matter to us since we are about to kfree
+	 * the memory that holds the lock, however LOCKDEP doesn't like this.
+	 * Thus we call the mmu_notifier_put under the rwsem and test the
+	 * internal users count to reliably see if we are past this point.
+	 */
+	mmu_notifier_put(&per_mm->mn);
+	up_write(&per_mm->umem_rwsem);
+
	mmdrop(umem_odp->umem.owning_mm);
	kfree(umem_odp);
}
drivers/infiniband/core/uverbs_cmd.c  +0 −3
@@ -252,9 +252,6 @@ static int ib_uverbs_get_context(struct uverbs_attr_bundle *attrs)
	ucontext->closing = false;
	ucontext->cleanup_retryable = false;

-	mutex_init(&ucontext->per_mm_list_lock);
-	INIT_LIST_HEAD(&ucontext->per_mm_list);

	ret = get_unused_fd_flags(O_CLOEXEC);
	if (ret < 0)
		goto err_free;
drivers/infiniband/core/uverbs_main.c  +1 −0
@@ -1487,6 +1487,7 @@ static void __exit ib_uverbs_cleanup(void)
				 IB_UVERBS_NUM_FIXED_MINOR);
	unregister_chrdev_region(dynamic_uverbs_dev,
				 IB_UVERBS_NUM_DYNAMIC_MINOR);
+	mmu_notifier_synchronize();
}

module_init(ib_uverbs_init);
drivers/infiniband/hw/mlx5/main.c  +0 −5
@@ -1995,11 +1995,6 @@ static void mlx5_ib_dealloc_ucontext(struct ib_ucontext *ibcontext)
	struct mlx5_ib_dev *dev = to_mdev(ibcontext->device);
	struct mlx5_bfreg_info *bfregi;

-	/* All umem's must be destroyed before destroying the ucontext. */
-	mutex_lock(&ibcontext->per_mm_list_lock);
-	WARN_ON(!list_empty(&ibcontext->per_mm_list));
-	mutex_unlock(&ibcontext->per_mm_list_lock);

	bfregi = &context->bfregi;
	mlx5_ib_dealloc_transport_domain(dev, context->tdn, context->devx_uid);

include/rdma/ib_umem_odp.h  +1 −9
@@ -122,20 +122,12 @@ static inline size_t ib_umem_odp_num_pages(struct ib_umem_odp *umem_odp)
#ifdef CONFIG_INFINIBAND_ON_DEMAND_PAGING

struct ib_ucontext_per_mm {
-	struct ib_ucontext *context;
-	struct mm_struct *mm;
+	struct mmu_notifier mn;
	struct pid *tgid;
-	bool active;

	struct rb_root_cached umem_tree;
	/* Protects umem_tree */
	struct rw_semaphore umem_rwsem;
-
-	struct mmu_notifier mn;
-	unsigned int odp_mrs_count;
-
-	struct list_head ucontext_list;
-	struct rcu_head rcu;
};

struct ib_umem_odp *ib_umem_odp_get(struct ib_udata *udata, unsigned long addr,