On Mon, Apr 13, 2026 at 09:08:52PM +0000, Michael Kelley wrote:
> From: Stanislav Kinsburskii <[email protected]> Sent: Monday, March 30, 2026 1:04 PM
> > 
> > Allow HMM fault handling across memory regions that span multiple VMAs
> > with different protection flags. The previous implementation assumed a
> > single VMA per region, which would fail when guest memory crosses VMA
> > boundaries.
> > 
> > Iterate through VMAs within the range and handle each separately with
> > appropriate protection flags, enabling more flexible memory region
> > configurations for partitions.
> > 
> > Signed-off-by: Stanislav Kinsburskii <[email protected]>
> > ---
> >  drivers/hv/mshv_regions.c |   72 +++++++++++++++++++++++++++++++++------------
> >  1 file changed, 52 insertions(+), 20 deletions(-)
> > 
> > diff --git a/drivers/hv/mshv_regions.c b/drivers/hv/mshv_regions.c
> > index ed9c55841140..1bb1bfe177e2 100644
> > --- a/drivers/hv/mshv_regions.c
> > +++ b/drivers/hv/mshv_regions.c
> > @@ -492,37 +492,72 @@ int mshv_region_get(struct mshv_mem_region *region)
> >  }
> > 
> >  /**
> > - * mshv_region_hmm_fault_and_lock - Handle HMM faults and lock the memory region
> > + * mshv_region_hmm_fault_and_lock - Handle HMM faults across VMAs and lock
> > + *                                  the memory region
> >   * @region: Pointer to the memory region structure
> > - * @range: Pointer to the HMM range structure
> > + * @start : Starting virtual address of the range to fault
> > + * @end   : Ending virtual address of the range to fault (exclusive)
> > + * @pfns  : Output array for page frame numbers with HMM flags
> >   *
> >   * This function performs the following steps:
> >   * 1. Reads the notifier sequence for the HMM range.
> >   * 2. Acquires a read lock on the memory map.
> > - * 3. Handles HMM faults for the specified range.
> > - * 4. Releases the read lock on the memory map.
> > - * 5. If successful, locks the memory region mutex.
> > - * 6. Verifies if the notifier sequence has changed during the operation.
> > - *    If it has, releases the mutex and returns -EBUSY to match with
> > - *    hmm_range_fault() return code for repeating.
> > + * 3. Iterates through VMAs in the specified range, handling each
> > + *    separately with appropriate protection flags (HMM_PFN_REQ_WRITE set
> > + *    based on VMA flags).
> > + * 4. Handles HMM faults for each VMA segment.
> > + * 5. Releases the read lock on the memory map.
> > + * 6. If successful, locks the memory region mutex.
> > + * 7. Verifies if the notifier sequence has changed during the operation.
> > + *    If it has, releases the mutex and returns -EBUSY to signal retry.
> > + *
> > + * The function expects the range [start, end] is backed by valid VMAs.
> 
> Use "[start, end)" to describe the range since end is exclusive.
> 

Will do

> > + * Returns -EFAULT if any address in the range is not covered by a VMA.
> >   *
> >   * Return: 0 on success, a negative error code otherwise.
> >   */
> >  static int mshv_region_hmm_fault_and_lock(struct mshv_mem_region *region,
> > -                                     struct hmm_range *range)
> > +                                     unsigned long start,
> > +                                     unsigned long end,
> > +                                     unsigned long *pfns)
> >  {
> > +   struct hmm_range range = {
> > +           .notifier = &region->mreg_mni,
> > +   };
> >     int ret;
> > 
> > -   range->notifier_seq = mmu_interval_read_begin(range->notifier);
> > +   range.notifier_seq = mmu_interval_read_begin(range.notifier);
> >     mmap_read_lock(region->mreg_mni.mm);
> > -   ret = hmm_range_fault(range);
> > +   while (start < end) {
> > +           struct vm_area_struct *vma;
> > +
> > +           vma = vma_lookup(current->mm, start);
> 
> The mmap_read_lock() was obtained on region->mreg_mni.mm, but the
> lookup is done against current->mm. Maybe these are the same, but
> it looks wrong.  (Pointed out by a Copilot AI review.)
> 

Yes, they are the same, but I'll update the code to use the same mm for clarity.

> > +           if (!vma) {
> > +                   ret = -EFAULT;
> > +                   break;
> > +           }
> > +
> > +           range.hmm_pfns = pfns;
> > +           range.start = start;
> > +           range.end = min(vma->vm_end, end);
> > +           range.default_flags = HMM_PFN_REQ_FAULT;
> > +           if (vma->vm_flags & VM_WRITE)
> > +                   range.default_flags |= HMM_PFN_REQ_WRITE;
> > +
> > +           ret = hmm_range_fault(&range);
> > +           if (ret)
> > +                   break;
> > +
> > +           start = range.end + 1;
> 
> Since range.end is exclusive, the +1 should not be done.
> 

Is it always? I'll need to check to make sure the end passed to this
function is page aligned. If it is, then I'll remove the +1.
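
A minimal userspace model of the VMA walk with half-open [start, end) ranges, assuming page-aligned bounds. `struct fake_vma`, `fake_vma_lookup()` and `walk_range()` are hypothetical stand-ins for the VMA, `vma_lookup()` and the loop in the patch; they are not kernel APIs, and the inner fill loop only imitates what `hmm_range_fault()` writes. Because each segment ends exclusively, the next iteration starts at exactly `seg_end`, with no `+ 1`.

```c
#include <assert.h>
#include <stddef.h>

#define PAGE_SIZE 4096UL

/* Hypothetical stand-in for a VMA covering [vm_start, vm_end). */
struct fake_vma {
	unsigned long vm_start;
	unsigned long vm_end;
	int writable;		/* models (vm_flags & VM_WRITE) */
};

/* Hypothetical stand-in for vma_lookup(): the VMA containing addr, or NULL. */
static const struct fake_vma *fake_vma_lookup(const struct fake_vma *vmas,
					      size_t n, unsigned long addr)
{
	for (size_t i = 0; i < n; i++)
		if (addr >= vmas[i].vm_start && addr < vmas[i].vm_end)
			return &vmas[i];
	return NULL;
}

/*
 * Walk [start, end) one VMA segment at a time. Each page gets one pfns[]
 * slot, set to the segment's writability in place of a real pfn. Returns
 * the number of slots filled, or -1 (modeling -EFAULT) if an address in
 * the range has no VMA.
 */
static long walk_range(const struct fake_vma *vmas, size_t n,
		       unsigned long start, unsigned long end, int *pfns)
{
	long filled = 0;

	/* Bounds are assumed to be page-aligned. */
	assert(start % PAGE_SIZE == 0 && end % PAGE_SIZE == 0);

	while (start < end) {
		const struct fake_vma *vma = fake_vma_lookup(vmas, n, start);
		unsigned long seg_end, npages, i;

		if (!vma)
			return -1;
		seg_end = vma->vm_end < end ? vma->vm_end : end;
		npages = (seg_end - start) / PAGE_SIZE;

		/* Stand-in for hmm_range_fault() filling this segment. */
		for (i = 0; i < npages; i++)
			pfns[i] = vma->writable;

		pfns += npages;
		filled += (long)npages;
		/* end is exclusive, so the next segment begins at seg_end. */
		start = seg_end;
	}
	return filled;
}
```

For example, with two adjacent VMAs [0x10000, 0x12000) and [0x12000, 0x15000), walking [0x10000, 0x15000) fills five slots. With `start = seg_end + 1` the second iteration would begin at 0x12001, and the segment length 0x15000 - 0x12001 would no longer be a whole number of pages, which is the only case where `DIV_ROUND_UP` would round.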

> > +           pfns += DIV_ROUND_UP(range.end - range.start, PAGE_SIZE);
> 
> Just to confirm, range.end and range.start should always be page aligned,
> right? So the ROUND_UP should never kick in.
> 

Same as above: if the end passed to this function is page aligned, then
I'll remove the DIV_ROUND_UP and just do a simple division.
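
A minimal userspace sketch of the alignment precondition. In the patch, `start` and `end` are both `region->start_uaddr` plus a multiple of `PAGE_SIZE`, so they are page-aligned whenever `start_uaddr` is; that is an assumption here, made explicit by the check. `range_npages()` is a hypothetical helper, and `PAGE_ALIGNED()` / `DIV_ROUND_UP()` are local re-definitions that mimic the kernel macros. With aligned bounds, the plain division and `DIV_ROUND_UP()` give the same result.

```c
#include <assert.h>

#define PAGE_SIZE 4096UL
/* Local copies of the kernel helpers, for illustration only. */
#define PAGE_ALIGNED(addr) (((addr) & (PAGE_SIZE - 1)) == 0)
#define DIV_ROUND_UP(n, d) (((n) + (d) - 1) / (d))

/*
 * Hypothetical helper: number of pfns[] slots covering [start, end).
 * Requires page-aligned bounds, in which case plain division is exact.
 */
static unsigned long range_npages(unsigned long start, unsigned long end)
{
	assert(PAGE_ALIGNED(start) && PAGE_ALIGNED(end));
	assert(start <= end);
	return (end - start) / PAGE_SIZE;
}
```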

Thanks,
Stanislav

> > +   }
> >     mmap_read_unlock(region->mreg_mni.mm);
> >     if (ret)
> >             return ret;
> > 
> >     mutex_lock(&region->mreg_mutex);
> > 
> > -   if (mmu_interval_read_retry(range->notifier, range->notifier_seq)) {
> > +   if (mmu_interval_read_retry(range.notifier, range.notifier_seq)) {
> >             mutex_unlock(&region->mreg_mutex);
> >             cond_resched();
> >             return -EBUSY;
> > @@ -546,10 +581,7 @@ static int mshv_region_hmm_fault_and_lock(struct mshv_mem_region *region,
> >  static int mshv_region_range_fault(struct mshv_mem_region *region,
> >                                u64 pfn_offset, u64 pfn_count)
> >  {
> > -   struct hmm_range range = {
> > -           .notifier = &region->mreg_mni,
> > -           .default_flags = HMM_PFN_REQ_FAULT | HMM_PFN_REQ_WRITE,
> > -   };
> > +   unsigned long start, end;
> >     unsigned long *pfns;
> >     int ret;
> >     u64 i;
> > @@ -558,12 +590,12 @@ static int mshv_region_range_fault(struct mshv_mem_region *region,
> >     if (!pfns)
> >             return -ENOMEM;
> > 
> > -   range.hmm_pfns = pfns;
> > -   range.start = region->start_uaddr + pfn_offset * HV_HYP_PAGE_SIZE;
> > -   range.end = range.start + pfn_count * HV_HYP_PAGE_SIZE;
> > +   start = region->start_uaddr + pfn_offset * PAGE_SIZE;
> > +   end = start + pfn_count * PAGE_SIZE;
> > 
> >     do {
> > -           ret = mshv_region_hmm_fault_and_lock(region, &range);
> > +           ret = mshv_region_hmm_fault_and_lock(region, start, end,
> > +                                                pfns);
> >     } while (ret == -EBUSY);
> > 
> >     if (ret)
> > 
> > 
> 
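
The retry protocol described in the kernel-doc above (sample the notifier sequence, fault the pages, take `mreg_mutex`, then re-check the sequence and return `-EBUSY` if it moved, with the caller looping on `-EBUSY`) can be modeled in userspace. This is a hypothetical sketch: `struct fake_notifier`, `fake_read_begin()`, `fake_read_retry()` and the other `fake_*` names only imitate `mmu_interval_read_begin()` / `mmu_interval_read_retry()` and are not the kernel API. The `locked` flag stands in for the mutex, and the injected invalidations stand in for a concurrent MMU notifier callback.

```c
#include <assert.h>
#include <errno.h>

/* Hypothetical model of an mmu_interval_notifier's sequence count. */
struct fake_notifier {
	unsigned long seq;		/* bumped by every invalidation */
	int pending_invalidations;	/* invalidations to inject while faulting */
};

struct fake_region {
	struct fake_notifier mni;
	int locked;			/* stands in for mreg_mutex */
};

static unsigned long fake_read_begin(const struct fake_notifier *mni)
{
	return mni->seq;
}

static int fake_read_retry(const struct fake_notifier *mni, unsigned long seq)
{
	return mni->seq != seq;
}

/* Models the fault-then-lock-then-validate shape of the patched function. */
static int fake_fault_and_lock(struct fake_region *region)
{
	unsigned long seq = fake_read_begin(&region->mni);

	/* Faulting happens here; a concurrent invalidation may race with it. */
	if (region->mni.pending_invalidations > 0) {
		region->mni.pending_invalidations--;
		region->mni.seq++;
	}

	region->locked = 1;
	if (fake_read_retry(&region->mni, seq)) {
		region->locked = 0;
		return -EBUSY;	/* pfns may be stale: caller must retry */
	}
	return 0;		/* success: return with the "mutex" held */
}

/* Models the caller's do { } while (ret == -EBUSY) loop. */
static int fake_range_fault(struct fake_region *region, int *attempts)
{
	int ret;

	*attempts = 0;
	do {
		(*attempts)++;
		ret = fake_fault_and_lock(region);
	} while (ret == -EBUSY);
	return ret;
}
```

With two injected invalidations, the first two attempts return `-EBUSY` and the third succeeds with the lock held.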
