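Allow HMM fault handling across memory regions that span multiple VMAs with different protection flags. The previous implementation faulted the whole range in a single hmm_range_fault() call with HMM_PFN_REQ_WRITE always set, effectively assuming one writable VMA per region; this fails when guest memory crosses into a VMA that is not writable.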
Iterate through VMAs within the range and handle each separately with appropriate protection flags, enabling more flexible memory region configurations for partitions. Signed-off-by: Stanislav Kinsburskii <[email protected]> --- drivers/hv/mshv_regions.c | 72 +++++++++++++++++++++++++++++++++------------ 1 file changed, 52 insertions(+), 20 deletions(-) diff --git a/drivers/hv/mshv_regions.c b/drivers/hv/mshv_regions.c index ed9c55841140..1bb1bfe177e2 100644 --- a/drivers/hv/mshv_regions.c +++ b/drivers/hv/mshv_regions.c @@ -492,37 +492,72 @@ int mshv_region_get(struct mshv_mem_region *region) } /** - * mshv_region_hmm_fault_and_lock - Handle HMM faults and lock the memory region + * mshv_region_hmm_fault_and_lock - Handle HMM faults across VMAs and lock + * the memory region * @region: Pointer to the memory region structure - * @range: Pointer to the HMM range structure + * @start : Starting virtual address of the range to fault + * @end : Ending virtual address of the range to fault (exclusive) + * @pfns : Output array for page frame numbers with HMM flags * * This function performs the following steps: * 1. Reads the notifier sequence for the HMM range. * 2. Acquires a read lock on the memory map. - * 3. Handles HMM faults for the specified range. - * 4. Releases the read lock on the memory map. - * 5. If successful, locks the memory region mutex. - * 6. Verifies if the notifier sequence has changed during the operation. - * If it has, releases the mutex and returns -EBUSY to match with - * hmm_range_fault() return code for repeating. + * 3. Iterates through VMAs in the specified range, handling each + * separately with appropriate protection flags (HMM_PFN_REQ_WRITE set + * based on VMA flags). + * 4. Handles HMM faults for each VMA segment. + * 5. Releases the read lock on the memory map. + * 6. If successful, locks the memory region mutex. + * 7. Verifies if the notifier sequence has changed during the operation. 
+ * If it has, releases the mutex and returns -EBUSY to signal retry. + * + * The function expects the range [start, end) to be backed by valid VMAs. + * Returns -EFAULT if any address in the range is not covered by a VMA. * * Return: 0 on success, a negative error code otherwise. */ static int mshv_region_hmm_fault_and_lock(struct mshv_mem_region *region, - struct hmm_range *range) + unsigned long start, + unsigned long end, + unsigned long *pfns) { + struct hmm_range range = { + .notifier = &region->mreg_mni, + }; - int ret; + int ret = 0; - range->notifier_seq = mmu_interval_read_begin(range->notifier); + range.notifier_seq = mmu_interval_read_begin(range.notifier); mmap_read_lock(region->mreg_mni.mm); - ret = hmm_range_fault(range); + while (start < end) { + struct vm_area_struct *vma; + + vma = vma_lookup(region->mreg_mni.mm, start); + if (!vma) { + ret = -EFAULT; + break; + } + + range.hmm_pfns = pfns; + range.start = start; + range.end = min(vma->vm_end, end); + range.default_flags = HMM_PFN_REQ_FAULT; + if (vma->vm_flags & VM_WRITE) + range.default_flags |= HMM_PFN_REQ_WRITE; + + ret = hmm_range_fault(&range); + if (ret) + break; + + start = range.end; + pfns += DIV_ROUND_UP(range.end - range.start, PAGE_SIZE); + } mmap_read_unlock(region->mreg_mni.mm); if (ret) return ret; mutex_lock(&region->mreg_mutex); - if (mmu_interval_read_retry(range->notifier, range->notifier_seq)) { + if (mmu_interval_read_retry(range.notifier, range.notifier_seq)) { mutex_unlock(&region->mreg_mutex); cond_resched(); return -EBUSY; @@ -546,10 +581,7 @@ static int mshv_region_hmm_fault_and_lock(struct mshv_mem_region *region, static int mshv_region_range_fault(struct mshv_mem_region *region, u64 pfn_offset, u64 pfn_count) { - struct hmm_range range = { - .notifier = &region->mreg_mni, - .default_flags = HMM_PFN_REQ_FAULT | HMM_PFN_REQ_WRITE, - }; + unsigned long start, end; unsigned long *pfns; int ret; u64 i; @@ -558,12 +590,12 @@ static int mshv_region_range_fault(struct mshv_mem_region *region, if (!pfns) return
-ENOMEM; - range.hmm_pfns = pfns; - range.start = region->start_uaddr + pfn_offset * HV_HYP_PAGE_SIZE; - range.end = range.start + pfn_count * HV_HYP_PAGE_SIZE; + start = region->start_uaddr + pfn_offset * PAGE_SIZE; + end = start + pfn_count * PAGE_SIZE; do { - ret = mshv_region_hmm_fault_and_lock(region, &range); + ret = mshv_region_hmm_fault_and_lock(region, start, end, + pfns); } while (ret == -EBUSY); if (ret)

