Commit a3eb13c1 authored by Jason Gunthorpe

mm/hmm: return the fault type from hmm_pte_need_fault()

Using two bools instead of a flags return is unnecessary and leads to
bugs. Returning a value is easier for the compiler to check and easier to
pass around the code flow.

Convert the two bools into flags and push the change to all callers.
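
To illustrate the pattern (a minimal, hypothetical userspace sketch with
made-up names, not the hmm code itself):

	/* Hypothetical sketch: two bool out-params vs. a single flags return. */
	#include <stdio.h>

	enum {
		NEED_FAULT	 = 1 << 0,
		NEED_WRITE_FAULT = 1 << 1,
	};

	/* Before: two output parameters the caller must pre-initialize. */
	static void need_fault_old(int valid, int writable,
				   int *fault, int *write_fault)
	{
		if (!writable)
			*fault = *write_fault = 1;
		else if (!valid)
			*fault = 1;
	}

	/* After: one flags value the compiler can track and callers can OR. */
	static unsigned int need_fault_new(int valid, int writable)
	{
		if (!writable)
			return NEED_FAULT | NEED_WRITE_FAULT;
		if (!valid)
			return NEED_FAULT;
		return 0;
	}

	int main(void)
	{
		int fault = 0, write_fault = 0;	/* easy to forget this init */
		unsigned int required_fault;

		need_fault_old(0, 1, &fault, &write_fault);
		required_fault = need_fault_new(0, 1);	/* self-contained */

		if (required_fault)
			printf("need fault: %#x\n", required_fault);
		return 0;
	}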

Link: https://lore.kernel.org/r/20200327200021.29372-3-jgg@ziepe.ca


Reviewed-by: Christoph Hellwig <hch@lst.de>
Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
parent 068354ad
mm/hmm.c: +81 −102
@@ -32,6 +32,12 @@ struct hmm_vma_walk {
	unsigned int		flags;
};

enum {
	HMM_NEED_FAULT = 1 << 0,
	HMM_NEED_WRITE_FAULT = 1 << 1,
	HMM_NEED_ALL_BITS = HMM_NEED_FAULT | HMM_NEED_WRITE_FAULT,
};

static int hmm_pfns_fill(unsigned long addr, unsigned long end,
		struct hmm_range *range, enum hmm_pfn_value_e value)
{
@@ -49,8 +55,7 @@ static int hmm_pfns_fill(unsigned long addr, unsigned long end,
 * hmm_vma_fault() - fault in a range lacking valid pmd or pte(s)
 * @addr: range virtual start address (inclusive)
 * @end: range virtual end address (exclusive)
 * @fault: should we fault or not ?
 * @write_fault: write fault ?
 * @required_fault: HMM_NEED_* flags
 * @walk: mm_walk structure
 * Return: -EBUSY after page fault, or page fault error
 *
@@ -58,8 +63,7 @@ static int hmm_pfns_fill(unsigned long addr, unsigned long end,
 * or whenever there is no page directory covering the virtual address range.
 */
static int hmm_vma_fault(unsigned long addr, unsigned long end,
			      bool fault, bool write_fault,
			      struct mm_walk *walk)
			 unsigned int required_fault, struct mm_walk *walk)
{
	struct hmm_vma_walk *hmm_vma_walk = walk->private;
	struct hmm_range *range = hmm_vma_walk->range;
@@ -68,13 +72,13 @@ static int hmm_vma_fault(unsigned long addr, unsigned long end,
	unsigned long i = (addr - range->start) >> PAGE_SHIFT;
	unsigned int fault_flags = FAULT_FLAG_REMOTE;

	WARN_ON_ONCE(!fault && !write_fault);
	WARN_ON_ONCE(!required_fault);
	hmm_vma_walk->last = addr;

	if (!vma)
		goto out_error;

	if (write_fault) {
	if (required_fault & HMM_NEED_WRITE_FAULT) {
		if (!(vma->vm_flags & VM_WRITE))
			return -EPERM;
		fault_flags |= FAULT_FLAG_WRITE;
@@ -91,14 +95,13 @@ out_error:
	return -EFAULT;
}

static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
				      uint64_t pfns, uint64_t cpu_flags,
				      bool *fault, bool *write_fault)
static unsigned int hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
				       uint64_t pfns, uint64_t cpu_flags)
{
	struct hmm_range *range = hmm_vma_walk->range;

	if (hmm_vma_walk->flags & HMM_FAULT_SNAPSHOT)
		return;
		return 0;

	/*
	 * So we not only consider the individual per page request we also
@@ -114,37 +117,37 @@ static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,

	/* We aren't asked to do anything ... */
	if (!(pfns & range->flags[HMM_PFN_VALID]))
		return;
		return 0;

	/* If CPU page table is not valid then we need to fault */
	*fault = !(cpu_flags & range->flags[HMM_PFN_VALID]);
	/* Need to write fault ? */
	if ((pfns & range->flags[HMM_PFN_WRITE]) &&
	    !(cpu_flags & range->flags[HMM_PFN_WRITE])) {
		*write_fault = true;
		*fault = true;
	}
	    !(cpu_flags & range->flags[HMM_PFN_WRITE]))
		return HMM_NEED_FAULT | HMM_NEED_WRITE_FAULT;

	/* If CPU page table is not valid then we need to fault */
	if (!(cpu_flags & range->flags[HMM_PFN_VALID]))
		return HMM_NEED_FAULT;
	return 0;
}

static void hmm_range_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
static unsigned int
hmm_range_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
		     const uint64_t *pfns, unsigned long npages,
				 uint64_t cpu_flags, bool *fault,
				 bool *write_fault)
		     uint64_t cpu_flags)
{
	unsigned int required_fault = 0;
	unsigned long i;

	if (hmm_vma_walk->flags & HMM_FAULT_SNAPSHOT) {
		*fault = *write_fault = false;
		return;
	}
	if (hmm_vma_walk->flags & HMM_FAULT_SNAPSHOT)
		return 0;

	*fault = *write_fault = false;
	for (i = 0; i < npages; ++i) {
		hmm_pte_need_fault(hmm_vma_walk, pfns[i], cpu_flags,
				   fault, write_fault);
		if ((*write_fault))
			return;
		required_fault |=
			hmm_pte_need_fault(hmm_vma_walk, pfns[i], cpu_flags);
		if (required_fault == HMM_NEED_ALL_BITS)
			return required_fault;
	}
	return required_fault;
}

static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
@@ -152,17 +155,16 @@ static int hmm_vma_walk_hole(unsigned long addr, unsigned long end,
{
	struct hmm_vma_walk *hmm_vma_walk = walk->private;
	struct hmm_range *range = hmm_vma_walk->range;
	bool fault, write_fault;
	unsigned int required_fault;
	unsigned long i, npages;
	uint64_t *pfns;

	i = (addr - range->start) >> PAGE_SHIFT;
	npages = (end - addr) >> PAGE_SHIFT;
	pfns = &range->pfns[i];
	hmm_range_need_fault(hmm_vma_walk, pfns, npages,
			     0, &fault, &write_fault);
	if (fault || write_fault)
		return hmm_vma_fault(addr, end, fault, write_fault, walk);
	required_fault = hmm_range_need_fault(hmm_vma_walk, pfns, npages, 0);
	if (required_fault)
		return hmm_vma_fault(addr, end, required_fault, walk);
	hmm_vma_walk->last = addr;
	return hmm_pfns_fill(addr, end, range, HMM_PFN_NONE);
}
@@ -183,16 +185,15 @@ static int hmm_vma_handle_pmd(struct mm_walk *walk, unsigned long addr,
	struct hmm_vma_walk *hmm_vma_walk = walk->private;
	struct hmm_range *range = hmm_vma_walk->range;
	unsigned long pfn, npages, i;
	bool fault, write_fault;
	unsigned int required_fault;
	uint64_t cpu_flags;

	npages = (end - addr) >> PAGE_SHIFT;
	cpu_flags = pmd_to_hmm_pfn_flags(range, pmd);
	hmm_range_need_fault(hmm_vma_walk, pfns, npages, cpu_flags,
			     &fault, &write_fault);

	if (fault || write_fault)
		return hmm_vma_fault(addr, end, fault, write_fault, walk);
	required_fault =
		hmm_range_need_fault(hmm_vma_walk, pfns, npages, cpu_flags);
	if (required_fault)
		return hmm_vma_fault(addr, end, required_fault, walk);

	pfn = pmd_pfn(pmd) + ((addr & ~PMD_MASK) >> PAGE_SHIFT);
	for (i = 0; addr < end; addr += PAGE_SIZE, i++, pfn++)
@@ -229,18 +230,15 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
{
	struct hmm_vma_walk *hmm_vma_walk = walk->private;
	struct hmm_range *range = hmm_vma_walk->range;
	bool fault, write_fault;
	unsigned int required_fault;
	uint64_t cpu_flags;
	pte_t pte = *ptep;
	uint64_t orig_pfn = *pfn;

	*pfn = range->values[HMM_PFN_NONE];
	fault = write_fault = false;

	if (pte_none(pte)) {
		hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0,
				   &fault, &write_fault);
		if (fault || write_fault)
		required_fault = hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0);
		if (required_fault)
			goto fault;
		return 0;
	}
@@ -261,9 +259,8 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
			return 0;
		}

		hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0, &fault,
				   &write_fault);
		if (!fault && !write_fault)
		required_fault = hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0);
		if (!required_fault)
			return 0;

		if (!non_swap_entry(entry))
@@ -283,9 +280,8 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
	}

	cpu_flags = pte_to_hmm_pfn_flags(range, pte);
	hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags, &fault,
			   &write_fault);
	if (fault || write_fault)
	required_fault = hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags);
	if (required_fault)
		goto fault;

	/*
@@ -293,9 +289,7 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
	 * fall through and treat it like a normal page.
	 */
	if (pte_special(pte) && !is_zero_pfn(pte_pfn(pte))) {
		hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0, &fault,
				   &write_fault);
		if (fault || write_fault) {
		if (hmm_pte_need_fault(hmm_vma_walk, orig_pfn, 0)) {
			pte_unmap(ptep);
			return -EFAULT;
		}
@@ -309,7 +303,7 @@ static int hmm_vma_handle_pte(struct mm_walk *walk, unsigned long addr,
fault:
	pte_unmap(ptep);
	/* Fault any virtual address we were asked to fault */
	return hmm_vma_fault(addr, end, fault, write_fault, walk);
	return hmm_vma_fault(addr, end, required_fault, walk);
}

static int hmm_vma_walk_pmd(pmd_t *pmdp,
@@ -322,7 +316,6 @@ static int hmm_vma_walk_pmd(pmd_t *pmdp,
	uint64_t *pfns = &range->pfns[(start - range->start) >> PAGE_SHIFT];
	unsigned long npages = (end - start) >> PAGE_SHIFT;
	unsigned long addr = start;
	bool fault, write_fault;
	pte_t *ptep;
	pmd_t pmd;

@@ -332,9 +325,7 @@ again:
		return hmm_vma_walk_hole(start, end, -1, walk);

	if (thp_migration_supported() && is_pmd_migration_entry(pmd)) {
		hmm_range_need_fault(hmm_vma_walk, pfns, npages,
				     0, &fault, &write_fault);
		if (fault || write_fault) {
		if (hmm_range_need_fault(hmm_vma_walk, pfns, npages, 0)) {
			hmm_vma_walk->last = addr;
			pmd_migration_entry_wait(walk->mm, pmdp);
			return -EBUSY;
@@ -343,9 +334,7 @@ again:
	}

	if (!pmd_present(pmd)) {
		hmm_range_need_fault(hmm_vma_walk, pfns, npages, 0, &fault,
				     &write_fault);
		if (fault || write_fault)
		if (hmm_range_need_fault(hmm_vma_walk, pfns, npages, 0))
			return -EFAULT;
		return hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
	}
@@ -375,9 +364,7 @@ again:
	 * recover.
	 */
	if (pmd_bad(pmd)) {
		hmm_range_need_fault(hmm_vma_walk, pfns, npages, 0, &fault,
				     &write_fault);
		if (fault || write_fault)
		if (hmm_range_need_fault(hmm_vma_walk, pfns, npages, 0))
			return -EFAULT;
		return hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
	}
@@ -434,8 +421,8 @@ static int hmm_vma_walk_pud(pud_t *pudp, unsigned long start, unsigned long end,

	if (pud_huge(pud) && pud_devmap(pud)) {
		unsigned long i, npages, pfn;
		unsigned int required_fault;
		uint64_t *pfns, cpu_flags;
		bool fault, write_fault;

		if (!pud_present(pud)) {
			spin_unlock(ptl);
@@ -447,12 +434,11 @@ static int hmm_vma_walk_pud(pud_t *pudp, unsigned long start, unsigned long end,
		pfns = &range->pfns[i];

		cpu_flags = pud_to_hmm_pfn_flags(range, pud);
		hmm_range_need_fault(hmm_vma_walk, pfns, npages,
				     cpu_flags, &fault, &write_fault);
		if (fault || write_fault) {
		required_fault = hmm_range_need_fault(hmm_vma_walk, pfns,
						      npages, cpu_flags);
		if (required_fault) {
			spin_unlock(ptl);
			return hmm_vma_fault(addr, end, fault, write_fault,
						  walk);
			return hmm_vma_fault(addr, end, required_fault, walk);
		}

		pfn = pud_pfn(pud) + ((addr & ~PUD_MASK) >> PAGE_SHIFT);
@@ -484,7 +470,7 @@ static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
	struct hmm_range *range = hmm_vma_walk->range;
	struct vm_area_struct *vma = walk->vma;
	uint64_t orig_pfn, cpu_flags;
	bool fault, write_fault;
	unsigned int required_fault;
	spinlock_t *ptl;
	pte_t entry;

@@ -495,12 +481,10 @@ static int hmm_vma_walk_hugetlb_entry(pte_t *pte, unsigned long hmask,
	orig_pfn = range->pfns[i];
	range->pfns[i] = range->values[HMM_PFN_NONE];
	cpu_flags = pte_to_hmm_pfn_flags(range, entry);
	fault = write_fault = false;
	hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags,
			   &fault, &write_fault);
	if (fault || write_fault) {
	required_fault = hmm_pte_need_fault(hmm_vma_walk, orig_pfn, cpu_flags);
	if (required_fault) {
		spin_unlock(ptl);
		return hmm_vma_fault(addr, end, fault, write_fault, walk);
		return hmm_vma_fault(addr, end, required_fault, walk);
	}

	pfn = pte_pfn(entry) + ((start & ~hmask) >> PAGE_SHIFT);
@@ -522,27 +506,25 @@ static int hmm_vma_walk_test(unsigned long start, unsigned long end,
	struct hmm_range *range = hmm_vma_walk->range;
	struct vm_area_struct *vma = walk->vma;

	if (!(vma->vm_flags & (VM_IO | VM_PFNMAP | VM_MIXEDMAP)) &&
	    vma->vm_flags & VM_READ)
		return 0;

	/*
	 * Skip vma ranges that don't have struct page backing them or map I/O
	 * devices directly.
	 * vma ranges that don't have struct page backing them or map I/O
	 * devices directly cannot be handled by hmm_range_fault().
	 *
	 * If the vma does not allow read access, then assume that it does not
	 * allow write access either. HMM does not support architectures that
	 * allow write without read.
	 *
	 * If a fault is requested for an unsupported range then it is a hard
	 * failure.
	 */
	if ((vma->vm_flags & (VM_IO | VM_PFNMAP | VM_MIXEDMAP)) ||
	    !(vma->vm_flags & VM_READ)) {
		bool fault, write_fault;

		/*
		 * Check to see if a fault is requested for any page in the
		 * range.
		 */
		hmm_range_need_fault(hmm_vma_walk, range->pfns +
					((start - range->start) >> PAGE_SHIFT),
					(end - start) >> PAGE_SHIFT,
					0, &fault, &write_fault);
		if (fault || write_fault)
			return -EFAULT;
	if (hmm_range_need_fault(hmm_vma_walk,
				 range->pfns +
					 ((start - range->start) >> PAGE_SHIFT),
				 (end - start) >> PAGE_SHIFT, 0))
		return -EFAULT;

	hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
@@ -552,9 +534,6 @@ static int hmm_vma_walk_test(unsigned long start, unsigned long end,
	return 1;
}

	return 0;
}

static const struct mm_walk_ops hmm_walk_ops = {
	.pud_entry	= hmm_vma_walk_pud,
	.pmd_entry	= hmm_vma_walk_pmd,