Commit 8a9320b7 authored by Jason Gunthorpe's avatar Jason Gunthorpe
Browse files

mm/hmm: Simplify hmm_get_or_create and make it reliable



As coded this function can false-fail in various racy situations. Make it
reliable and simpler by running under the write side of the mmap_sem and
avoiding the false-failing compare/exchange pattern. Due to the mmap_sem
this no longer has to avoid racing with a 2nd parallel
hmm_get_or_create().

Unfortunately this still has to use the page_table_lock as the
non-sleeping lock protecting mm->hmm, since the contexts where we free the
hmm are incompatible with mmap_sem.

Signed-off-by: default avatarJason Gunthorpe <jgg@mellanox.com>
Reviewed-by: default avatarJohn Hubbard <jhubbard@nvidia.com>
Reviewed-by: default avatarRalph Campbell <rcampbell@nvidia.com>
Reviewed-by: default avatarIra Weiny <ira.weiny@intel.com>
Reviewed-by: default avatarChristoph Hellwig <hch@lst.de>
Tested-by: default avatarPhilip Yang <Philip.Yang@amd.com>
parent c8a53b2d
Loading
Loading
Loading
Loading
+30 −47
Original line number Diff line number Diff line
@@ -31,16 +31,6 @@
#if IS_ENABLED(CONFIG_HMM_MIRROR)
static const struct mmu_notifier_ops hmm_mmu_notifier_ops;

static inline struct hmm *mm_get_hmm(struct mm_struct *mm)
{
	struct hmm *hmm = READ_ONCE(mm->hmm);

	if (hmm && kref_get_unless_zero(&hmm->kref))
		return hmm;

	return NULL;
}

/**
 * hmm_get_or_create - register HMM against an mm (HMM internal)
 *
@@ -55,11 +45,16 @@ static inline struct hmm *mm_get_hmm(struct mm_struct *mm)
 */
static struct hmm *hmm_get_or_create(struct mm_struct *mm)
{
	struct hmm *hmm = mm_get_hmm(mm);
	bool cleanup = false;
	struct hmm *hmm;

	if (hmm)
		return hmm;
	lockdep_assert_held_exclusive(&mm->mmap_sem);

	/* Abuse the page_table_lock to also protect mm->hmm. */
	spin_lock(&mm->page_table_lock);
	hmm = mm->hmm;
	if (mm->hmm && kref_get_unless_zero(&mm->hmm->kref))
		goto out_unlock;
	spin_unlock(&mm->page_table_lock);

	hmm = kmalloc(sizeof(*hmm), GFP_KERNEL);
	if (!hmm)
@@ -74,57 +69,45 @@ static struct hmm *hmm_get_or_create(struct mm_struct *mm)
	hmm->notifiers = 0;
	hmm->dead = false;
	hmm->mm = mm;
	mmgrab(hmm->mm);

	spin_lock(&mm->page_table_lock);
	if (!mm->hmm)
		mm->hmm = hmm;
	else
		cleanup = true;
	spin_unlock(&mm->page_table_lock);
	hmm->mmu_notifier.ops = &hmm_mmu_notifier_ops;
	if (__mmu_notifier_register(&hmm->mmu_notifier, mm)) {
		kfree(hmm);
		return NULL;
	}

	if (cleanup)
		goto error;
	mmgrab(hmm->mm);

	/*
	 * We should only get here if hold the mmap_sem in write mode ie on
	 * registration of first mirror through hmm_mirror_register()
	 * We hold the exclusive mmap_sem here so we know that mm->hmm is
	 * still NULL or 0 kref, and is safe to update.
	 */
	hmm->mmu_notifier.ops = &hmm_mmu_notifier_ops;
	if (__mmu_notifier_register(&hmm->mmu_notifier, mm))
		goto error_mm;

	return hmm;

error_mm:
	spin_lock(&mm->page_table_lock);
	if (mm->hmm == hmm)
		mm->hmm = NULL;
	mm->hmm = hmm;

out_unlock:
	spin_unlock(&mm->page_table_lock);
error:
	mmdrop(hmm->mm);
	kfree(hmm);
	return NULL;
	return hmm;
}

static void hmm_free_rcu(struct rcu_head *rcu)
{
	kfree(container_of(rcu, struct hmm, rcu));
	struct hmm *hmm = container_of(rcu, struct hmm, rcu);

	mmdrop(hmm->mm);
	kfree(hmm);
}

static void hmm_free(struct kref *kref)
{
	struct hmm *hmm = container_of(kref, struct hmm, kref);
	struct mm_struct *mm = hmm->mm;

	mmu_notifier_unregister_no_release(&hmm->mmu_notifier, mm);
	spin_lock(&hmm->mm->page_table_lock);
	if (hmm->mm->hmm == hmm)
		hmm->mm->hmm = NULL;
	spin_unlock(&hmm->mm->page_table_lock);

	spin_lock(&mm->page_table_lock);
	if (mm->hmm == hmm)
		mm->hmm = NULL;
	spin_unlock(&mm->page_table_lock);

	mmdrop(hmm->mm);
	mmu_notifier_unregister_no_release(&hmm->mmu_notifier, hmm->mm);
	mmu_notifier_call_srcu(&hmm->rcu, hmm_free_rcu);
}