Commit d45d464b authored by Christoph Hellwig, committed by Jason Gunthorpe

mm/hmm: merge hmm_range_snapshot into hmm_range_fault

Add an HMM_FAULT_SNAPSHOT flag so that hmm_range_snapshot() can be merged
into the almost identical hmm_range_fault() function.
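
For callers the conversion is mechanical. A sketch of the before and
after (the caller code here is illustrative, not taken from this patch):

	/* before: separate entry point for the non-faulting path */
	ret = hmm_range_snapshot(&range);

	/* after: same entry point, snapshot behavior selected by the flag */
	ret = hmm_range_fault(&range, HMM_FAULT_SNAPSHOT);

Faulting callers are unchanged: they keep passing their existing flags
(0 or HMM_FAULT_ALLOW_RETRY).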

Link: https://lore.kernel.org/r/20190726005650.2566-5-rcampbell@nvidia.com
Signed-off-by: Christoph Hellwig <hch@lst.de>
Signed-off-by: Ralph Campbell <rcampbell@nvidia.com>
Reviewed-by: Jason Gunthorpe <jgg@mellanox.com>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
parent 9a4903e4
Documentation/vm/hmm.rst +8 −9
@@ -192,15 +192,14 @@ read only, or fully unmap, etc.). The device must complete the update before
 the driver callback returns.
 
 When the device driver wants to populate a range of virtual addresses, it can
-use either::
+use::
 
-  long hmm_range_snapshot(struct hmm_range *range);
-  long hmm_range_fault(struct hmm_range *range, bool block);
+  long hmm_range_fault(struct hmm_range *range, unsigned int flags);
 
-The first one (hmm_range_snapshot()) will only fetch present CPU page table
+With the HMM_FAULT_SNAPSHOT flag, it will only fetch present CPU page table
 entries and will not trigger a page fault on missing or non-present entries.
-The second one does trigger a page fault on missing or read-only entries if
-write access is requested (see below). Page faults use the generic mm page
+Without that flag, it does trigger a page fault on missing or read-only entries
+if write access is requested (see below). Page faults use the generic mm page
 fault code path just like a CPU page fault.
 
 Both functions copy CPU page table entries into their pfns array argument. Each
@@ -227,20 +226,20 @@ The usage pattern is::
 
       /*
        * Just wait for range to be valid, safe to ignore return value as we
-       * will use the return value of hmm_range_snapshot() below under the
+       * will use the return value of hmm_range_fault() below under the
        * mmap_sem to ascertain the validity of the range.
        */
       hmm_range_wait_until_valid(&range, TIMEOUT_IN_MSEC);
 
  again:
       down_read(&mm->mmap_sem);
-      ret = hmm_range_snapshot(&range);
+      ret = hmm_range_fault(&range, HMM_FAULT_SNAPSHOT);
       if (ret) {
           up_read(&mm->mmap_sem);
           if (ret == -EBUSY) {
             /*
              * No need to check hmm_range_wait_until_valid() return value
-             * on retry we will get proper error with hmm_range_snapshot()
+             * on retry we will get proper error with hmm_range_fault()
              */
             hmm_range_wait_until_valid(&range, TIMEOUT_IN_MSEC);
             goto again;
include/linux/hmm.h +3 −1
@@ -413,7 +413,9 @@ void hmm_range_unregister(struct hmm_range *range);
  */
 #define HMM_FAULT_ALLOW_RETRY		(1 << 0)
 
-long hmm_range_snapshot(struct hmm_range *range);
+/* Don't fault in missing PTEs, just snapshot the current state. */
+#define HMM_FAULT_SNAPSHOT		(1 << 1)
+
 long hmm_range_fault(struct hmm_range *range, unsigned int flags);
 
 long hmm_range_dma_map(struct hmm_range *range,
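
With the snapshot behavior folded into hmm_range_fault(), a driver can now
select either behavior at a single call site via the flags argument. A
minimal sketch, assuming a hypothetical driver helper that is not part of
this patch:

	/* Hypothetical helper: one entry point for both behaviors. */
	static long driver_populate_range(struct hmm_range *range, bool fault)
	{
		/*
		 * HMM_FAULT_SNAPSHOT only reads present CPU page table
		 * entries; 0 faults missing entries in through the normal
		 * mm page fault path.
		 */
		return hmm_range_fault(range, fault ? 0 : HMM_FAULT_SNAPSHOT);
	}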
mm/hmm.c +2 −83
@@ -280,7 +280,6 @@ struct hmm_vma_walk {
 	struct hmm_range	*range;
 	struct dev_pagemap	*pgmap;
 	unsigned long		last;
-	bool			fault;
 	unsigned int		flags;
 };
 
@@ -373,7 +372,7 @@ static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
 {
 	struct hmm_range *range = hmm_vma_walk->range;
 
-	if (!hmm_vma_walk->fault)
+	if (hmm_vma_walk->flags & HMM_FAULT_SNAPSHOT)
 		return;
 
 	/*
@@ -418,7 +417,7 @@ static void hmm_range_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
 {
 	unsigned long i;
 
-	if (!hmm_vma_walk->fault) {
+	if (hmm_vma_walk->flags & HMM_FAULT_SNAPSHOT) {
 		*fault = *write_fault = false;
 		return;
 	}
@@ -936,85 +935,6 @@ void hmm_range_unregister(struct hmm_range *range)
 }
 EXPORT_SYMBOL(hmm_range_unregister);
 
-/*
- * hmm_range_snapshot() - snapshot CPU page table for a range
- * @range: range
- * Return: -EINVAL if invalid argument, -ENOMEM out of memory, -EPERM invalid
- *          permission (for instance asking for write and range is read only),
- *          -EBUSY if you need to retry, -EFAULT invalid (ie either no valid
- *          vma or it is illegal to access that range), number of valid pages
- *          in range->pfns[] (from range start address).
- *
- * This snapshots the CPU page table for a range of virtual addresses. Snapshot
- * validity is tracked by range struct. See in include/linux/hmm.h for example
- * on how to use.
- */
-long hmm_range_snapshot(struct hmm_range *range)
-{
-	const unsigned long device_vma = VM_IO | VM_PFNMAP | VM_MIXEDMAP;
-	unsigned long start = range->start, end;
-	struct hmm_vma_walk hmm_vma_walk;
-	struct hmm *hmm = range->hmm;
-	struct vm_area_struct *vma;
-	struct mm_walk mm_walk;
-
-	lockdep_assert_held(&hmm->mm->mmap_sem);
-	do {
-		/* If range is no longer valid force retry. */
-		if (!range->valid)
-			return -EBUSY;
-
-		vma = find_vma(hmm->mm, start);
-		if (vma == NULL || (vma->vm_flags & device_vma))
-			return -EFAULT;
-
-		if (is_vm_hugetlb_page(vma)) {
-			if (huge_page_shift(hstate_vma(vma)) !=
-				    range->page_shift &&
-			    range->page_shift != PAGE_SHIFT)
-				return -EINVAL;
-		} else {
-			if (range->page_shift != PAGE_SHIFT)
-				return -EINVAL;
-		}
-
-		if (!(vma->vm_flags & VM_READ)) {
-			/*
-			 * If vma do not allow read access, then assume that it
-			 * does not allow write access, either. HMM does not
-			 * support architecture that allow write without read.
-			 */
-			hmm_pfns_clear(range, range->pfns,
-				range->start, range->end);
-			return -EPERM;
-		}
-
-		range->vma = vma;
-		hmm_vma_walk.pgmap = NULL;
-		hmm_vma_walk.last = start;
-		hmm_vma_walk.fault = false;
-		hmm_vma_walk.range = range;
-		mm_walk.private = &hmm_vma_walk;
-		end = min(range->end, vma->vm_end);
-
-		mm_walk.vma = vma;
-		mm_walk.mm = vma->vm_mm;
-		mm_walk.pte_entry = NULL;
-		mm_walk.test_walk = NULL;
-		mm_walk.hugetlb_entry = NULL;
-		mm_walk.pud_entry = hmm_vma_walk_pud;
-		mm_walk.pmd_entry = hmm_vma_walk_pmd;
-		mm_walk.pte_hole = hmm_vma_walk_hole;
-		mm_walk.hugetlb_entry = hmm_vma_walk_hugetlb_entry;
-
-		walk_page_range(start, end, &mm_walk);
-		start = end;
-	} while (start < range->end);
-
-	return (hmm_vma_walk.last - range->start) >> PAGE_SHIFT;
-}
-EXPORT_SYMBOL(hmm_range_snapshot);
-
 /**
  * hmm_range_fault - try to fault some address in a virtual address range
  * @range:	range being faulted
@@ -1088,7 +1008,6 @@ long hmm_range_fault(struct hmm_range *range, unsigned int flags)
 		range->vma = vma;
 		hmm_vma_walk.pgmap = NULL;
 		hmm_vma_walk.last = start;
-		hmm_vma_walk.fault = true;
 		hmm_vma_walk.flags = flags;
 		hmm_vma_walk.range = range;
 		mm_walk.private = &hmm_vma_walk;
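
Note that the return-value contract carries over from the deleted function:
on success hmm_range_fault() returns the number of valid pages placed in
range->pfns[], and on failure the same error codes the hmm_range_snapshot()
comment documented above (-EBUSY to retry, -EPERM, -EFAULT, -EINVAL,
-ENOMEM). A sketch of the shared result handling (the labels and the npages
variable are illustrative):

	ret = hmm_range_fault(&range, HMM_FAULT_SNAPSHOT);
	if (ret == -EBUSY)
		goto again;	/* range invalidated, wait and retry */
	if (ret < 0)
		goto out;	/* -EPERM, -EFAULT, -EINVAL, -ENOMEM, ... */
	npages = ret;		/* valid entries in range->pfns[] */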