Message-ID: <9c3c87d5-e64e-f13f-ef36-b438e4de1e66@nextfour.com>
Date: Sun, 18 Jul 2021 09:19:46 +0300
From: Mika Penttilä <mika.penttila@...tfour.com>
To: Qi Zheng <zhengqi.arch@...edance.com>, akpm@...ux-foundation.org,
tglx@...utronix.de, hannes@...xchg.org, mhocko@...nel.org,
vdavydov.dev@...il.com
Cc: linux-doc@...r.kernel.org, linux-kernel@...r.kernel.org,
linux-mm@...ck.org, songmuchun@...edance.com
Subject: Re: [PATCH 5/7] mm: free user PTE page table pages
Hi,
On 18.7.2021 7.30, Qi Zheng wrote:
> Some malloc libraries (e.g. jemalloc or tcmalloc) usually
> allocate a large amount of virtual address space with mmap()
> and do not unmap it; when they want to release physical
> memory they call madvise(MADV_DONTNEED). But madvise() does
> not free the page tables, so a process that touches an
> enormous virtual address space can accumulate many PTE page
> tables.
>
> The following figures are a memory usage snapshot of a
> process that was actually running on one of our servers:
>
> VIRT: 55t
> RES: 590g
> VmPTE: 110g
>
> As we can see, the PTE page tables take 110g while the RES
> is 590g. In theory, the process needs only about 1.2g of PTE
> page tables to map that physical memory. The PTE page tables
> occupy so much memory because madvise(MADV_DONTNEED) only
> clears the PTEs and frees the physical memory, but does not
> free the PTE page table pages themselves. So we can free
> those empty PTE page tables to save memory. In the case
> above we can save about 108g (best case), and the larger the
> gap between VIRT and RES, the more memory we save.
>
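
For reference, the behaviour described above can be reproduced with a
minimal userspace sketch along these lines (illustrative only, not part
of the patch): it maps and touches a large anonymous area, drops the
physical memory with MADV_DONTNEED, and prints VmRSS/VmPTE from
/proc/self/status. VmRSS drops back while VmPTE stays roughly unchanged:

#include <stdio.h>
#include <string.h>
#include <sys/mman.h>

static void show_status(const char *tag)
{
	char line[256];
	FILE *f = fopen("/proc/self/status", "r");

	if (!f)
		return;
	printf("--- %s ---\n", tag);
	while (fgets(line, sizeof(line), f))
		if (!strncmp(line, "VmRSS", 5) || !strncmp(line, "VmPTE", 5))
			fputs(line, stdout);
	fclose(f);
}

int main(void)
{
	size_t size = 4UL << 30;	/* 4 GiB of anonymous VA; adjust to taste */
	size_t i;
	char *p = mmap(NULL, size, PROT_READ | PROT_WRITE,
		       MAP_PRIVATE | MAP_ANONYMOUS | MAP_NORESERVE, -1, 0);

	if (p == MAP_FAILED)
		return 1;
	for (i = 0; i < size; i += 4096)	/* populate one PTE per page */
		p[i] = 1;
	show_status("after touch");

	madvise(p, size, MADV_DONTNEED);	/* frees pages, not page tables */
	show_status("after MADV_DONTNEED");
	return 0;
}
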
> In this patch series, we add a pte_refcount field to the
> struct page of the PTE page table to track how many users
> the PTE page table has. Similar to the page refcount
> mechanism, a user of a PTE page table must hold a reference
> to it before accessing it. The PTE page table page is freed
> when the last reference is dropped.
>
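
With the helpers added below (pte_try_get()/pte_put() and friends), the
expected caller pattern for a PTE walker looks roughly like this sketch
(illustrative only; walk_example() is a made-up name, the shape follows
what the patch does in e.g. smaps_pte_range()):

/*
 * Illustrative only.  Called with mmap_lock (or another lock that
 * keeps the pmd entry stable) held.
 */
static int walk_example(struct mm_struct *mm, pmd_t *pmd,
			unsigned long addr, unsigned long end)
{
	unsigned long start = addr;
	spinlock_t *ptl;
	pte_t *pte;

	if (!pte_try_get(mm, pmd))	/* PTE table gone or being freed */
		return 0;

	pte = pte_offset_map_lock(mm, pmd, addr, &ptl);
	for (; addr != end; pte++, addr += PAGE_SIZE)
		;	/* examine *pte */
	pte_unmap_unlock(pte - 1, ptl);

	pte_put(mm, pmd, start);	/* drop the reference taken above */
	return 0;
}
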
> Signed-off-by: Qi Zheng <zhengqi.arch@...edance.com>
> ---
> Documentation/vm/split_page_table_lock.rst | 2 +-
> arch/arm/mm/pgd.c | 2 +-
> arch/arm64/mm/hugetlbpage.c | 4 +-
> arch/ia64/mm/hugetlbpage.c | 2 +-
> arch/parisc/mm/hugetlbpage.c | 2 +-
> arch/powerpc/mm/hugetlbpage.c | 2 +-
> arch/s390/mm/gmap.c | 8 +-
> arch/s390/mm/pgtable.c | 6 +-
> arch/sh/mm/hugetlbpage.c | 2 +-
> arch/sparc/mm/hugetlbpage.c | 2 +-
> arch/x86/kernel/tboot.c | 2 +-
> fs/proc/task_mmu.c | 23 ++-
> fs/userfaultfd.c | 2 +
> include/linux/mm.h | 13 +-
> include/linux/mm_types.h | 8 +-
> include/linux/pgtable.h | 3 +-
> include/linux/pte_ref.h | 217 +++++++++++++++++++++++
> include/linux/rmap.h | 3 +
> kernel/events/uprobes.c | 3 +
> mm/Kconfig | 4 +
> mm/Makefile | 3 +-
> mm/debug_vm_pgtable.c | 3 +-
> mm/filemap.c | 45 ++---
> mm/gup.c | 10 +-
> mm/hmm.c | 4 +
> mm/internal.h | 2 +
> mm/khugepaged.c | 10 ++
> mm/ksm.c | 4 +
> mm/madvise.c | 20 ++-
> mm/memcontrol.c | 11 +-
> mm/memory.c | 271 +++++++++++++++++++----------
> mm/mempolicy.c | 5 +-
> mm/migrate.c | 21 ++-
> mm/mincore.c | 6 +-
> mm/mlock.c | 1 +
> mm/mprotect.c | 10 +-
> mm/mremap.c | 12 +-
> mm/page_vma_mapped.c | 4 +
> mm/pagewalk.c | 19 +-
> mm/pgtable-generic.c | 2 +
> mm/pte_ref.c | 132 ++++++++++++++
> mm/rmap.c | 13 +-
> mm/swapfile.c | 6 +-
> mm/userfaultfd.c | 15 +-
> 44 files changed, 758 insertions(+), 181 deletions(-)
> create mode 100644 include/linux/pte_ref.h
> create mode 100644 mm/pte_ref.c
>
> diff --git a/Documentation/vm/split_page_table_lock.rst b/Documentation/vm/split_page_table_lock.rst
> index c08919662704..98eb7ba0d2ab 100644
> --- a/Documentation/vm/split_page_table_lock.rst
> +++ b/Documentation/vm/split_page_table_lock.rst
> @@ -20,7 +20,7 @@ There are helpers to lock/unlock a table and other accessor functions:
> lock;
> - pte_unmap_unlock()
> unlocks and unmaps PTE table;
> - - pte_alloc_map_lock()
> + - pte_alloc_get_map_lock()
> allocates PTE table if needed and take the lock, returns pointer
> to taken lock or NULL if allocation failed;
> - pte_lockptr()
> diff --git a/arch/arm/mm/pgd.c b/arch/arm/mm/pgd.c
> index f8e9bc58a84f..b2408ad9dcf5 100644
> --- a/arch/arm/mm/pgd.c
> +++ b/arch/arm/mm/pgd.c
> @@ -100,7 +100,7 @@ pgd_t *pgd_alloc(struct mm_struct *mm)
> if (!new_pmd)
> goto no_pmd;
>
> - new_pte = pte_alloc_map(mm, new_pmd, 0);
> + new_pte = pte_alloc_get_map(mm, new_pmd, 0);
> if (!new_pte)
> goto no_pte;
>
> diff --git a/arch/arm64/mm/hugetlbpage.c b/arch/arm64/mm/hugetlbpage.c
> index 23505fc35324..54f6beb3eb6b 100644
> --- a/arch/arm64/mm/hugetlbpage.c
> +++ b/arch/arm64/mm/hugetlbpage.c
> @@ -280,9 +280,9 @@ pte_t *huge_pte_alloc(struct mm_struct *mm, struct vm_area_struct *vma,
> * 32-bit arm platform then it will cause trouble in
> * the case where CONFIG_HIGHPTE is set, since there
> * will be no pte_unmap() to correspond with this
> - * pte_alloc_map().
> + * pte_alloc_get_map().
> */
> - ptep = pte_alloc_map(mm, pmdp, addr);
> + ptep = pte_alloc_get_map(mm, pmdp, addr);
> } else if (sz == PMD_SIZE) {
> if (want_pmd_share(vma, addr) && pud_none(READ_ONCE(*pudp)))
> ptep = huge_pmd_share(mm, vma, addr, pudp);
> diff --git a/arch/ia64/mm/hugetlbpage.c b/arch/ia64/mm/hugetlbpage.c
> index f993cb36c062..cb230005e7dd 100644
> --- a/arch/ia64/mm/hugetlbpage.c
> +++ b/arch/ia64/mm/hugetlbpage.c
> @@ -41,7 +41,7 @@ huge_pte_alloc(struct mm_struct *mm, struct vm_area_struct *vma,
> if (pud) {
> pmd = pmd_alloc(mm, pud, taddr);
> if (pmd)
> - pte = pte_alloc_map(mm, pmd, taddr);
> + pte = pte_alloc_get_map(mm, pmd, taddr);
> }
> return pte;
> }
> diff --git a/arch/parisc/mm/hugetlbpage.c b/arch/parisc/mm/hugetlbpage.c
> index d1d3990b83f6..ff16db9a44a5 100644
> --- a/arch/parisc/mm/hugetlbpage.c
> +++ b/arch/parisc/mm/hugetlbpage.c
> @@ -66,7 +66,7 @@ pte_t *huge_pte_alloc(struct mm_struct *mm, struct vm_area_struct *vma,
> if (pud) {
> pmd = pmd_alloc(mm, pud, addr);
> if (pmd)
> - pte = pte_alloc_map(mm, pmd, addr);
> + pte = pte_alloc_get_map(mm, pmd, addr);
> }
> return pte;
> }
> diff --git a/arch/powerpc/mm/hugetlbpage.c b/arch/powerpc/mm/hugetlbpage.c
> index 9a75ba078e1b..20af2db18d08 100644
> --- a/arch/powerpc/mm/hugetlbpage.c
> +++ b/arch/powerpc/mm/hugetlbpage.c
> @@ -182,7 +182,7 @@ pte_t *huge_pte_alloc(struct mm_struct *mm, struct vm_area_struct *vma,
> return NULL;
>
> if (IS_ENABLED(CONFIG_PPC_8xx) && pshift < PMD_SHIFT)
> - return pte_alloc_map(mm, (pmd_t *)hpdp, addr);
> + return pte_alloc_get_map(mm, (pmd_t *)hpdp, addr);
>
> BUG_ON(!hugepd_none(*hpdp) && !hugepd_ok(*hpdp));
>
> diff --git a/arch/s390/mm/gmap.c b/arch/s390/mm/gmap.c
> index 9bb2c7512cd5..b243b276d9b6 100644
> --- a/arch/s390/mm/gmap.c
> +++ b/arch/s390/mm/gmap.c
> @@ -856,7 +856,7 @@ static pte_t *gmap_pte_op_walk(struct gmap *gmap, unsigned long gaddr,
> table = gmap_table_walk(gmap, gaddr, 1); /* get segment pointer */
> if (!table || *table & _SEGMENT_ENTRY_INVALID)
> return NULL;
> - return pte_alloc_map_lock(gmap->mm, (pmd_t *) table, gaddr, ptl);
> + return pte_alloc_get_map_lock(gmap->mm, (pmd_t *) table, gaddr, ptl);
> }
>
> /**
> @@ -925,7 +925,7 @@ static inline pmd_t *gmap_pmd_op_walk(struct gmap *gmap, unsigned long gaddr)
> return NULL;
> }
>
> - /* 4k page table entries are locked via the pte (pte_alloc_map_lock). */
> + /* 4k page table entries are locked via the pte (pte_alloc_get_map_lock). */
> if (!pmd_large(*pmdp))
> spin_unlock(&gmap->guest_table_lock);
> return pmdp;
> @@ -1012,7 +1012,7 @@ static int gmap_protect_pte(struct gmap *gmap, unsigned long gaddr,
> if (pmd_val(*pmdp) & _SEGMENT_ENTRY_INVALID)
> return -EAGAIN;
>
> - ptep = pte_alloc_map_lock(gmap->mm, pmdp, gaddr, &ptl);
> + ptep = pte_alloc_get_map_lock(gmap->mm, pmdp, gaddr, &ptl);
> if (!ptep)
> return -ENOMEM;
>
> @@ -2473,7 +2473,7 @@ void gmap_sync_dirty_log_pmd(struct gmap *gmap, unsigned long bitmap[4],
> bitmap_fill(bitmap, _PAGE_ENTRIES);
> } else {
> for (i = 0; i < _PAGE_ENTRIES; i++, vmaddr += PAGE_SIZE) {
> - ptep = pte_alloc_map_lock(gmap->mm, pmdp, vmaddr, &ptl);
> + ptep = pte_alloc_get_map_lock(gmap->mm, pmdp, vmaddr, &ptl);
> if (!ptep)
> continue;
> if (ptep_test_and_clear_uc(gmap->mm, vmaddr, ptep))
> diff --git a/arch/s390/mm/pgtable.c b/arch/s390/mm/pgtable.c
> index eec3a9d7176e..82217a753751 100644
> --- a/arch/s390/mm/pgtable.c
> +++ b/arch/s390/mm/pgtable.c
> @@ -801,7 +801,7 @@ int set_guest_storage_key(struct mm_struct *mm, unsigned long addr,
> }
> spin_unlock(ptl);
>
> - ptep = pte_alloc_map_lock(mm, pmdp, addr, &ptl);
> + ptep = pte_alloc_get_map_lock(mm, pmdp, addr, &ptl);
> if (unlikely(!ptep))
> return -EFAULT;
>
> @@ -900,7 +900,7 @@ int reset_guest_reference_bit(struct mm_struct *mm, unsigned long addr)
> }
> spin_unlock(ptl);
>
> - ptep = pte_alloc_map_lock(mm, pmdp, addr, &ptl);
> + ptep = pte_alloc_get_map_lock(mm, pmdp, addr, &ptl);
> if (unlikely(!ptep))
> return -EFAULT;
>
> @@ -956,7 +956,7 @@ int get_guest_storage_key(struct mm_struct *mm, unsigned long addr,
> }
> spin_unlock(ptl);
>
> - ptep = pte_alloc_map_lock(mm, pmdp, addr, &ptl);
> + ptep = pte_alloc_get_map_lock(mm, pmdp, addr, &ptl);
> if (unlikely(!ptep))
> return -EFAULT;
>
> diff --git a/arch/sh/mm/hugetlbpage.c b/arch/sh/mm/hugetlbpage.c
> index 999ab5916e69..ea7fa277952b 100644
> --- a/arch/sh/mm/hugetlbpage.c
> +++ b/arch/sh/mm/hugetlbpage.c
> @@ -38,7 +38,7 @@ pte_t *huge_pte_alloc(struct mm_struct *mm, struct vm_area_struct *vma,
> if (pud) {
> pmd = pmd_alloc(mm, pud, addr);
> if (pmd)
> - pte = pte_alloc_map(mm, pmd, addr);
> + pte = pte_alloc_get_map(mm, pmd, addr);
> }
> }
> }
> diff --git a/arch/sparc/mm/hugetlbpage.c b/arch/sparc/mm/hugetlbpage.c
> index 0f49fada2093..599c04b54205 100644
> --- a/arch/sparc/mm/hugetlbpage.c
> +++ b/arch/sparc/mm/hugetlbpage.c
> @@ -297,7 +297,7 @@ pte_t *huge_pte_alloc(struct mm_struct *mm, struct vm_area_struct *vma,
> return NULL;
> if (sz >= PMD_SIZE)
> return (pte_t *)pmd;
> - return pte_alloc_map(mm, pmd, addr);
> + return pte_alloc_get_map(mm, pmd, addr);
> }
>
> pte_t *huge_pte_offset(struct mm_struct *mm,
> diff --git a/arch/x86/kernel/tboot.c b/arch/x86/kernel/tboot.c
> index f9af561c3cd4..f2210bf3d357 100644
> --- a/arch/x86/kernel/tboot.c
> +++ b/arch/x86/kernel/tboot.c
> @@ -131,7 +131,7 @@ static int map_tboot_page(unsigned long vaddr, unsigned long pfn,
> pmd = pmd_alloc(&tboot_mm, pud, vaddr);
> if (!pmd)
> return -1;
> - pte = pte_alloc_map(&tboot_mm, pmd, vaddr);
> + pte = pte_alloc_get_map(&tboot_mm, pmd, vaddr);
> if (!pte)
> return -1;
> set_pte_at(&tboot_mm, vaddr, pte, pfn_pte(pfn, prot));
> diff --git a/fs/proc/task_mmu.c b/fs/proc/task_mmu.c
> index eb97468dfe4c..b3cf4b8a91d6 100644
> --- a/fs/proc/task_mmu.c
> +++ b/fs/proc/task_mmu.c
> @@ -574,6 +574,7 @@ static int smaps_pte_range(pmd_t *pmd, unsigned long addr, unsigned long end,
> struct vm_area_struct *vma = walk->vma;
> pte_t *pte;
> spinlock_t *ptl;
> + unsigned long start = addr;
>
> ptl = pmd_trans_huge_lock(pmd, vma);
> if (ptl) {
> @@ -582,7 +583,8 @@ static int smaps_pte_range(pmd_t *pmd, unsigned long addr, unsigned long end,
> goto out;
> }
>
> - if (pmd_trans_unstable(pmd))
> + if ((!IS_ENABLED(CONFIG_FREE_USER_PTE) && pmd_trans_unstable(pmd)) ||
> + !pte_try_get(vma->vm_mm, pmd))
> goto out;
> /*
> * The mmap_lock held all the way back in m_start() is what
> @@ -593,6 +595,7 @@ static int smaps_pte_range(pmd_t *pmd, unsigned long addr, unsigned long end,
> for (; addr != end; pte++, addr += PAGE_SIZE)
> smaps_pte_entry(pte, addr, walk);
> pte_unmap_unlock(pte - 1, ptl);
> + pte_put(vma->vm_mm, pmd, start);
> out:
> cond_resched();
> return 0;
> @@ -1121,6 +1124,7 @@ static int clear_refs_pte_range(pmd_t *pmd, unsigned long addr,
> pte_t *pte, ptent;
> spinlock_t *ptl;
> struct page *page;
> + unsigned long start = addr;
>
> ptl = pmd_trans_huge_lock(pmd, vma);
> if (ptl) {
> @@ -1143,7 +1147,8 @@ static int clear_refs_pte_range(pmd_t *pmd, unsigned long addr,
> return 0;
> }
>
> - if (pmd_trans_unstable(pmd))
> + if ((!IS_ENABLED(CONFIG_FREE_USER_PTE) && pmd_trans_unstable(pmd)) ||
> + !pte_try_get(vma->vm_mm, pmd))
> return 0;
>
> pte = pte_offset_map_lock(vma->vm_mm, pmd, addr, &ptl);
> @@ -1168,6 +1173,7 @@ static int clear_refs_pte_range(pmd_t *pmd, unsigned long addr,
> ClearPageReferenced(page);
> }
> pte_unmap_unlock(pte - 1, ptl);
> + pte_put(vma->vm_mm, pmd, start);
> cond_resched();
> return 0;
> }
> @@ -1407,6 +1413,7 @@ static int pagemap_pmd_range(pmd_t *pmdp, unsigned long addr, unsigned long end,
> spinlock_t *ptl;
> pte_t *pte, *orig_pte;
> int err = 0;
> + unsigned long start = addr;
>
> #ifdef CONFIG_TRANSPARENT_HUGEPAGE
> ptl = pmd_trans_huge_lock(pmdp, vma);
> @@ -1471,10 +1478,13 @@ static int pagemap_pmd_range(pmd_t *pmdp, unsigned long addr, unsigned long end,
> return err;
> }
>
> - if (pmd_trans_unstable(pmdp))
> + if (!IS_ENABLED(CONFIG_FREE_USER_PTE) && pmd_trans_unstable(pmdp))
> return 0;
> #endif /* CONFIG_TRANSPARENT_HUGEPAGE */
>
> + if (!pte_try_get(walk->mm, pmdp))
> + return 0;
> +
> /*
> * We can assume that @vma always points to a valid one and @end never
> * goes beyond vma->vm_end.
> @@ -1489,6 +1499,7 @@ static int pagemap_pmd_range(pmd_t *pmdp, unsigned long addr, unsigned long end,
> break;
> }
> pte_unmap_unlock(orig_pte, ptl);
> + pte_put(walk->mm, pmdp, start);
>
> cond_resched();
>
> @@ -1795,6 +1806,7 @@ static int gather_pte_stats(pmd_t *pmd, unsigned long addr,
> spinlock_t *ptl;
> pte_t *orig_pte;
> pte_t *pte;
> + unsigned long start = addr;
>
> #ifdef CONFIG_TRANSPARENT_HUGEPAGE
> ptl = pmd_trans_huge_lock(pmd, vma);
> @@ -1809,9 +1821,11 @@ static int gather_pte_stats(pmd_t *pmd, unsigned long addr,
> return 0;
> }
>
> - if (pmd_trans_unstable(pmd))
> + if (!IS_ENABLED(CONFIG_FREE_USER_PTE) && pmd_trans_unstable(pmd))
> return 0;
> #endif
> + if (!pte_try_get(walk->mm, pmd))
> + return 0;
> orig_pte = pte = pte_offset_map_lock(walk->mm, pmd, addr, &ptl);
> do {
> struct page *page = can_gather_numa_stats(*pte, vma, addr);
> @@ -1821,6 +1835,7 @@ static int gather_pte_stats(pmd_t *pmd, unsigned long addr,
>
> } while (pte++, addr += PAGE_SIZE, addr != end);
> pte_unmap_unlock(orig_pte, ptl);
> + pte_put(walk->mm, pmd, start);
> cond_resched();
> return 0;
> }
> diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
> index f6e0f0c0d0e5..4fc6c3bafd70 100644
> --- a/fs/userfaultfd.c
> +++ b/fs/userfaultfd.c
> @@ -28,6 +28,7 @@
> #include <linux/ioctl.h>
> #include <linux/security.h>
> #include <linux/hugetlb.h>
> +#include <linux/pte_ref.h>
>
> int sysctl_unprivileged_userfaultfd __read_mostly;
>
> @@ -508,6 +509,7 @@ vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
> must_wait = userfaultfd_huge_must_wait(ctx, vmf->vma,
> vmf->address,
> vmf->flags, reason);
> + pte_put_vmf(vmf);
> mmap_read_unlock(mm);
>
> if (likely(must_wait && !READ_ONCE(ctx->released))) {
> diff --git a/include/linux/mm.h b/include/linux/mm.h
> index 5efd63a20d72..befe823b4918 100644
> --- a/include/linux/mm.h
> +++ b/include/linux/mm.h
> @@ -447,6 +447,7 @@ extern pgprot_t protection_map[16];
> * @FAULT_FLAG_REMOTE: The fault is not for current task/mm.
> * @FAULT_FLAG_INSTRUCTION: The fault was during an instruction fetch.
> * @FAULT_FLAG_INTERRUPTIBLE: The fault can be interrupted by non-fatal signals.
> + * @FAULT_FLAG_PTE_GET: Indicates that a reference to the PTE page table has been taken.
> *
> * About @FAULT_FLAG_ALLOW_RETRY and @FAULT_FLAG_TRIED: we can specify
> * whether we would allow page faults to retry by specifying these two
> @@ -478,6 +479,7 @@ enum fault_flag {
> FAULT_FLAG_REMOTE = 1 << 7,
> FAULT_FLAG_INSTRUCTION = 1 << 8,
> FAULT_FLAG_INTERRUPTIBLE = 1 << 9,
> + FAULT_FLAG_PTE_GET = 1 << 10,
> };
>
> /*
> @@ -2148,7 +2150,6 @@ static inline void mm_inc_nr_ptes(struct mm_struct *mm) {}
> static inline void mm_dec_nr_ptes(struct mm_struct *mm) {}
> #endif
>
> -int __pte_alloc(struct mm_struct *mm, pmd_t *pmd);
> int __pte_alloc_kernel(pmd_t *pmd);
>
> #if defined(CONFIG_MMU)
> @@ -2274,15 +2275,6 @@ static inline void pgtable_pte_page_dtor(struct page *page)
> pte_unmap(pte); \
> } while (0)
>
> -#define pte_alloc(mm, pmd) (unlikely(pmd_none(*(pmd))) && __pte_alloc(mm, pmd))
> -
> -#define pte_alloc_map(mm, pmd, address) \
> - (pte_alloc(mm, pmd) ? NULL : pte_offset_map(pmd, address))
> -
> -#define pte_alloc_map_lock(mm, pmd, address, ptlp) \
> - (pte_alloc(mm, pmd) ? \
> - NULL : pte_offset_map_lock(mm, pmd, address, ptlp))
> -
> #define pte_alloc_kernel(pmd, address) \
> ((unlikely(pmd_none(*(pmd))) && __pte_alloc_kernel(pmd))? \
> NULL: pte_offset_kernel(pmd, address))
> @@ -2374,7 +2366,6 @@ static inline spinlock_t *pud_lock(struct mm_struct *mm, pud_t *pud)
> return ptl;
> }
>
> -extern void pte_install(struct mm_struct *mm, pmd_t *pmd, pgtable_t *pte);
> extern void __init pagecache_init(void);
> extern void __init free_area_init_memoryless_node(int nid);
> extern void free_initmem(void);
> diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
> index f37abb2d222e..eed4a5db59ea 100644
> --- a/include/linux/mm_types.h
> +++ b/include/linux/mm_types.h
> @@ -153,11 +153,17 @@ struct page {
> };
> struct { /* Page table pages */
> unsigned long _pt_pad_1; /* compound_head */
> - pgtable_t pmd_huge_pte; /* protected by page->ptl */
> + union {
> + pgtable_t pmd_huge_pte; /* protected by page->ptl */
> + pmd_t *pmd; /* PTE page only */
> + };
> unsigned long _pt_pad_2; /* mapping */
> union {
> struct mm_struct *pt_mm; /* x86 pgds only */
> atomic_t pt_frag_refcount; /* powerpc */
> +#ifdef CONFIG_FREE_USER_PTE
> + atomic_t pte_refcount; /* PTE page only */
> +#endif
> };
> #if USE_SPLIT_PTE_PTLOCKS
> #if ALLOC_SPLIT_PTLOCKS
> diff --git a/include/linux/pgtable.h b/include/linux/pgtable.h
> index d147480cdefc..172bb63b7ed9 100644
> --- a/include/linux/pgtable.h
> +++ b/include/linux/pgtable.h
> @@ -331,7 +331,6 @@ static inline pte_t ptep_get_lockless(pte_t *ptep)
> }
> #endif /* CONFIG_GUP_GET_PTE_LOW_HIGH */
>
> -#ifdef CONFIG_TRANSPARENT_HUGEPAGE
> #ifndef __HAVE_ARCH_PMDP_HUGE_GET_AND_CLEAR
> static inline pmd_t pmdp_huge_get_and_clear(struct mm_struct *mm,
> unsigned long address,
> @@ -342,6 +341,8 @@ static inline pmd_t pmdp_huge_get_and_clear(struct mm_struct *mm,
> return pmd;
> }
> #endif /* __HAVE_ARCH_PMDP_HUGE_GET_AND_CLEAR */
> +
> +#ifdef CONFIG_TRANSPARENT_HUGEPAGE
> #ifndef __HAVE_ARCH_PUDP_HUGE_GET_AND_CLEAR
> static inline pud_t pudp_huge_get_and_clear(struct mm_struct *mm,
> unsigned long address,
> diff --git a/include/linux/pte_ref.h b/include/linux/pte_ref.h
> new file mode 100644
> index 000000000000..695fbe8b991b
> --- /dev/null
> +++ b/include/linux/pte_ref.h
> @@ -0,0 +1,217 @@
> +// SPDX-License-Identifier: GPL-2.0
> +/*
> + * Free user PTE page table pages
> + *
> + * Copyright (c) 2021, ByteDance. All rights reserved.
> + *
> + * Author: Qi Zheng <zhengqi.arch@...edance.com>
> + */
> +#ifndef _LINUX_PTE_REF_H
> +#define _LINUX_PTE_REF_H
> +
> +#include <linux/mm.h>
> +#include <linux/pgtable.h>
> +#include <asm/pgalloc.h>
> +
> +bool pte_install_try_get(struct mm_struct *mm, pmd_t *pmd, pgtable_t *pte);
> +int __pte_alloc_try_get(struct mm_struct *mm, pmd_t *pmd);
> +int __pte_alloc_get(struct mm_struct *mm, pmd_t *pmd);
> +
> +#ifdef CONFIG_FREE_USER_PTE
> +void free_pte_table(struct mm_struct *mm, pmd_t *pmdp, unsigned long addr);
> +
> +static inline void pte_ref_init(pgtable_t pte, pmd_t *pmd, int count)
> +{
> + pte->pmd = pmd;
> + atomic_set(&pte->pte_refcount, count);
> +}
> +
> +static inline pmd_t *pte_to_pmd(pte_t *pte)
> +{
> + return virt_to_page(pte)->pmd;
> +}
> +
> +static inline void pte_migrate_pmd(pmd_t old_pmd, pmd_t *new_pmd)
> +{
> + pmd_pgtable(old_pmd)->pmd = new_pmd;
> +}
> +
> +/*
> + * Get a reference to the PTE page table to prevent it from being
> + * released.
> + *
> + * The caller should already hold a reference to the PTE page table
> + * (taken via pte_try_get()); this function is then safe to use under
> + * mmap_lock or anon_lock or i_mmap_lock, or when parallel threads are
> + * excluded by other means that keep the @pmdp entry stable.
> + */
> +static inline void pte_get_many(pmd_t *pmdp, unsigned int nr)
> +{
> + pgtable_t pte = pmd_pgtable(*pmdp);
> +
> + VM_BUG_ON(pte->pmd != pmdp);
> + atomic_add(nr, &pte->pte_refcount);
> +}
> +
> +static inline void pte_get(pmd_t *pmdp)
> +{
> + pte_get_many(pmdp, 1);
> +}
> +
> +static inline bool pte_get_unless_zero(pmd_t *pmdp)
> +{
> + pgtable_t pte = pmd_pgtable(*pmdp);
> +
> + VM_BUG_ON(!PageTable(pte));
> + return atomic_inc_not_zero(&pte->pte_refcount);
> +}
> +
> +/*
> + * Try to get a reference to the PTE page table to prevent it from
> + * being released.
> + *
> + * This function is safe to use under mmap_lock or anon_lock or
> + * i_mmap_lock, or when parallel threads are excluded by other means
> + * that keep the @pmdp entry stable.
> + */
> +static inline bool pte_try_get(struct mm_struct *mm, pmd_t *pmdp)
> +{
> + bool retval = true;
> + spinlock_t *ptl;
> +
> + ptl = pmd_lock(mm, pmdp);
> + if (pmd_leaf(*pmdp) || !pmd_present(*pmdp) ||
> + !pte_get_unless_zero(pmdp))
> + retval = false;
> + spin_unlock(ptl);
> +
> + return retval;
> +}
> +
> +/*
> + * Put a reference to the PTE page table; the PTE page table page
> + * is released when the reference count drops to 0.
> + *
> + * This function is safe to use under mmap_lock or anon_lock or
> + * i_mmap_lock, or when parallel threads are excluded by other means
> + * that keep the @pmdp entry stable.
> + */
> +static inline void pte_put_many(struct mm_struct *mm, pmd_t *pmdp,
> + unsigned long addr, unsigned int nr)
> +{
> + pgtable_t pte = pmd_pgtable(*pmdp);
> +
> + VM_BUG_ON(mm == &init_mm);
> + VM_BUG_ON(pmd_devmap_trans_unstable(pmdp));
> + VM_BUG_ON(pte->pmd != pmdp);
> + if (atomic_sub_and_test(nr, &pte->pte_refcount))
> + free_pte_table(mm, pmdp, addr & PMD_MASK);
> +}
> +
> +static inline void pte_put(struct mm_struct *mm, pmd_t *pmdp, unsigned long addr)
> +{
> + pte_put_many(mm, pmdp, addr, 1);
> +}
> +
> +/*
> + * The mmap_lock may be unlocked early in some cases in
> + * handle_pte_fault(), so we should ensure that pte_put() is
> + * performed within the mmap_lock critical section.
> + */
> +static inline void pte_put_vmf(struct vm_fault *vmf)
> +{
> + if (!(vmf->flags & FAULT_FLAG_PTE_GET))
> + return;
> + vmf->flags &= ~FAULT_FLAG_PTE_GET;
> +
> + pte_put(vmf->vma->vm_mm, vmf->pmd, vmf->address);
> +}
> +
> +static inline int pte_alloc_try_get(struct mm_struct *mm, pmd_t *pmdp)
> +{
> + if (!pte_try_get(mm, pmdp))
> + return __pte_alloc_try_get(mm, pmdp);
> + return 1;
> +}
> +
> +static inline int pte_alloc_get(struct mm_struct *mm, pmd_t *pmdp)
> +{
> + spinlock_t *ptl;
> +
> + ptl = pmd_lock(mm, pmdp);
> + if (pmd_none(*pmdp) || !pte_get_unless_zero(pmdp)) {
> + spin_unlock(ptl);
> + return __pte_alloc_get(mm, pmdp);
> + }
> + spin_unlock(ptl);
> + return 0;
> +}
> +#else
> +static inline void pte_ref_init(pgtable_t pte, pmd_t *pmd, int count)
> +{
> +}
> +
> +static inline pmd_t *pte_to_pmd(pte_t *pte)
> +{
> + return NULL;
> +}
> +
> +static inline void pte_migrate_pmd(pmd_t old_pmd, pmd_t *new_pmd)
> +{
> +}
> +
> +static inline void pte_get_many(pmd_t *pmdp, unsigned int nr)
> +{
> +}
> +
> +static inline void pte_get(pmd_t *pmdp)
> +{
> +}
> +
> +static inline bool pte_get_unless_zero(pmd_t *pmdp)
> +{
> + return true;
> +}
> +
> +static inline bool pte_try_get(struct mm_struct *mm, pmd_t *pmdp)
> +{
> + return true;
> +}
> +
> +static inline void pte_put_many(struct mm_struct *mm, pmd_t *pmdp,
> + unsigned long addr, unsigned int value)
> +{
> +}
> +
> +static inline void pte_put(struct mm_struct *mm, pmd_t *pmdp, unsigned long addr)
> +{
> +}
> +
> +static inline void pte_put_vmf(struct vm_fault *vmf)
> +{
> +}
> +
> +static inline int pte_alloc_try_get(struct mm_struct *mm, pmd_t *pmdp)
> +{
> + if (unlikely(pmd_none(*pmdp)))
> + return __pte_alloc_try_get(mm, pmdp);
> + if (unlikely(pmd_devmap_trans_unstable(pmdp)))
> + return 0;
> + return 1;
> +}
> +
> +static inline int pte_alloc_get(struct mm_struct *mm, pmd_t *pmdp)
> +{
> + if (unlikely(pmd_none(*pmdp)))
> + return __pte_alloc_get(mm, pmdp);
> + return 0;
> +}
> +#endif /* CONFIG_FREE_USER_PTE */
> +
> +#define pte_alloc_get_map(mm, pmd, address) \
> + (pte_alloc_get(mm, pmd) ? NULL : pte_offset_map(pmd, address))
> +
> +#define pte_alloc_get_map_lock(mm, pmd, address, ptlp) \
> + (pte_alloc_get(mm, pmd) ? \
> + NULL : pte_offset_map_lock(mm, pmd, address, ptlp))
> +#endif
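
As a companion to the header above, a rough sketch of the
allocation-side pairing (illustrative only; fill_example() is a made-up
name, the shape follows remap_pte_range() further down):
pte_alloc_get_map_lock() returns with a reference held on the PTE page
table, each installed PTE pins the table via pte_get_many(), and the
allocation-time reference is dropped with pte_put() when the walk is
finished:

static int fill_example(struct mm_struct *mm, pmd_t *pmd,
			unsigned long addr, unsigned long end)
{
	unsigned long start = addr;
	unsigned int nr_get = 0;
	spinlock_t *ptl;
	pte_t *pte, *mapped_pte;

	mapped_pte = pte = pte_alloc_get_map_lock(mm, pmd, addr, &ptl);
	if (!pte)
		return -ENOMEM;
	for (; addr != end; pte++, addr += PAGE_SIZE) {
		/* set_pte_at(mm, addr, pte, ...); each entry pins the table */
		nr_get++;
	}
	pte_get_many(pmd, nr_get);	/* one reference per installed PTE */
	pte_unmap_unlock(mapped_pte, ptl);
	pte_put(mm, pmd, start);	/* drop the allocation-time reference */
	return 0;
}
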
> diff --git a/include/linux/rmap.h b/include/linux/rmap.h
> index 83fb86133fe1..886411eccb55 100644
> --- a/include/linux/rmap.h
> +++ b/include/linux/rmap.h
> @@ -11,6 +11,7 @@
> #include <linux/rwsem.h>
> #include <linux/memcontrol.h>
> #include <linux/highmem.h>
> +#include <linux/pte_ref.h>
>
> /*
> * The anon_vma heads a list of private "related" vmas, to scan if
> @@ -220,6 +221,8 @@ static inline void page_vma_mapped_walk_done(struct page_vma_mapped_walk *pvmw)
> pte_unmap(pvmw->pte);
> if (pvmw->ptl)
> spin_unlock(pvmw->ptl);
> + if (pvmw->pte && !PageHuge(pvmw->page))
> + pte_put(pvmw->vma->vm_mm, pvmw->pmd, pvmw->address);
> }
>
> bool page_vma_mapped_walk(struct page_vma_mapped_walk *pvmw);
> diff --git a/kernel/events/uprobes.c b/kernel/events/uprobes.c
> index af24dc3febbe..2791190e1a01 100644
> --- a/kernel/events/uprobes.c
> +++ b/kernel/events/uprobes.c
> @@ -205,6 +205,9 @@ static int __replace_page(struct vm_area_struct *vma, unsigned long addr,
> try_to_free_swap(old_page);
> page_vma_mapped_walk_done(&pvmw);
>
> + if (!new_page)
> + pte_put(mm, pte_to_pmd(pvmw.pte), addr);
> +
> if ((vma->vm_flags & VM_LOCKED) && !PageCompound(old_page))
> munlock_vma_page(old_page);
> put_page(old_page);
> diff --git a/mm/Kconfig b/mm/Kconfig
> index 5dc28e9205e0..745f6cdc5e9b 100644
> --- a/mm/Kconfig
> +++ b/mm/Kconfig
> @@ -889,4 +889,8 @@ config IO_MAPPING
> config SECRETMEM
> def_bool ARCH_HAS_SET_DIRECT_MAP && !EMBEDDED
>
> +config FREE_USER_PTE
> + def_bool y
> + depends on X86_64
> +
> endmenu
> diff --git a/mm/Makefile b/mm/Makefile
> index e3436741d539..1ab513342d54 100644
> --- a/mm/Makefile
> +++ b/mm/Makefile
> @@ -38,7 +38,8 @@ mmu-y := nommu.o
> mmu-$(CONFIG_MMU) := highmem.o memory.o mincore.o \
> mlock.o mmap.o mmu_gather.o mprotect.o mremap.o \
> msync.o page_vma_mapped.o pagewalk.o \
> - pgtable-generic.o rmap.o vmalloc.o ioremap.o
> + pgtable-generic.o rmap.o vmalloc.o ioremap.o \
> + pte_ref.o
>
>
> ifdef CONFIG_CROSS_MEMORY_ATTACH
> diff --git a/mm/debug_vm_pgtable.c b/mm/debug_vm_pgtable.c
> index 1c922691aa61..8cae3b3329dc 100644
> --- a/mm/debug_vm_pgtable.c
> +++ b/mm/debug_vm_pgtable.c
> @@ -31,6 +31,7 @@
> #include <linux/io.h>
> #include <asm/pgalloc.h>
> #include <asm/tlbflush.h>
> +#include <linux/pte_ref.h>
>
> /*
> * Please refer Documentation/vm/arch_pgtable_helpers.rst for the semantics
> @@ -1018,7 +1019,7 @@ static int __init debug_vm_pgtable(void)
> /*
> * Allocate pgtable_t
> */
> - if (pte_alloc(mm, pmdp)) {
> + if (pte_alloc_try_get(mm, pmdp) < 0) {
> pr_err("pgtable allocation failed\n");
> return 1;
> }
> diff --git a/mm/filemap.c b/mm/filemap.c
> index db0184884890..024ca645c3a2 100644
> --- a/mm/filemap.c
> +++ b/mm/filemap.c
> @@ -1699,6 +1699,7 @@ int __lock_page_or_retry(struct page *page, struct vm_fault *vmf)
> if (flags & FAULT_FLAG_RETRY_NOWAIT)
> return 0;
>
> + pte_put_vmf(vmf);
> mmap_read_unlock(mm);
> if (flags & FAULT_FLAG_KILLABLE)
> wait_on_page_locked_killable(page);
> @@ -1711,6 +1712,7 @@ int __lock_page_or_retry(struct page *page, struct vm_fault *vmf)
>
> ret = __lock_page_killable(page);
> if (ret) {
> + pte_put_vmf(vmf);
> mmap_read_unlock(mm);
> return 0;
> }
> @@ -3160,32 +3162,30 @@ static bool filemap_map_pmd(struct vm_fault *vmf, struct page *page)
> struct mm_struct *mm = vmf->vma->vm_mm;
>
> /* Huge page is mapped? No need to proceed. */
> - if (pmd_trans_huge(*vmf->pmd)) {
> - unlock_page(page);
> - put_page(page);
> - return true;
> - }
> + if (pmd_trans_huge(*vmf->pmd))
> + goto out;
>
> if (pmd_none(*vmf->pmd) && PageTransHuge(page)) {
> - vm_fault_t ret = do_set_pmd(vmf, page);
> - if (!ret) {
> - /* The page is mapped successfully, reference consumed. */
> - unlock_page(page);
> - return true;
> - }
> + vm_fault_t ret = do_set_pmd(vmf, page);
> + if (!ret) {
> + /* The page is mapped successfully, reference consumed. */
> + unlock_page(page);
> + return true;
> + }
> }
>
> - if (pmd_none(*vmf->pmd))
> - pte_install(mm, vmf->pmd, &vmf->prealloc_pte);
> -
> - /* See comment in handle_pte_fault() */
> - if (pmd_devmap_trans_unstable(vmf->pmd)) {
> - unlock_page(page);
> - put_page(page);
> - return true;
> + if (IS_ENABLED(CONFIG_FREE_USER_PTE) || pmd_none(*vmf->pmd)) {
> + if (!pte_install_try_get(mm, vmf->pmd, &vmf->prealloc_pte))
> + goto out;
> + } else if (pmd_devmap_trans_unstable(vmf->pmd)) { /* See comment in handle_pte_fault() */
> + goto out;
> }
>
> return false;
> +out:
> + unlock_page(page);
> + put_page(page);
> + return true;
> }
>
> static struct page *next_uptodate_page(struct page *page,
> @@ -3259,6 +3259,7 @@ vm_fault_t filemap_map_pages(struct vm_fault *vmf,
> struct page *head, *page;
> unsigned int mmap_miss = READ_ONCE(file->f_ra.mmap_miss);
> vm_fault_t ret = 0;
> + unsigned int nr_get = 0;
>
> rcu_read_lock();
> head = first_map_page(mapping, &xas, end_pgoff);
> @@ -3267,7 +3268,7 @@ vm_fault_t filemap_map_pages(struct vm_fault *vmf,
>
> if (filemap_map_pmd(vmf, head)) {
> ret = VM_FAULT_NOPAGE;
> - goto out;
> + goto put;
> }
>
> addr = vma->vm_start + ((start_pgoff - vma->vm_pgoff) << PAGE_SHIFT);
> @@ -3292,6 +3293,7 @@ vm_fault_t filemap_map_pages(struct vm_fault *vmf,
> ret = VM_FAULT_NOPAGE;
>
> do_set_pte(vmf, page, addr);
> + nr_get++;
> /* no need to invalidate: a not-present page won't be cached */
> update_mmu_cache(vma, addr, vmf->pte);
> unlock_page(head);
> @@ -3301,6 +3303,9 @@ vm_fault_t filemap_map_pages(struct vm_fault *vmf,
> put_page(head);
> } while ((head = next_map_page(mapping, &xas, end_pgoff)) != NULL);
> pte_unmap_unlock(vmf->pte, vmf->ptl);
> + pte_get_many(vmf->pmd, nr_get);
> +put:
> + pte_put(vma->vm_mm, vmf->pmd, addr);
> out:
> rcu_read_unlock();
> WRITE_ONCE(file->f_ra.mmap_miss, mmap_miss);
> diff --git a/mm/gup.c b/mm/gup.c
> index 42b8b1fa6521..3e2a153cb18e 100644
> --- a/mm/gup.c
> +++ b/mm/gup.c
> @@ -498,10 +498,14 @@ static struct page *follow_page_pte(struct vm_area_struct *vma,
> if (WARN_ON_ONCE((flags & (FOLL_PIN | FOLL_GET)) ==
> (FOLL_PIN | FOLL_GET)))
> return ERR_PTR(-EINVAL);
> +
> retry:
> if (unlikely(pmd_bad(*pmd)))
> return no_page_table(vma, flags);
>
> + if (!pte_try_get(mm, pmd))
> + return no_page_table(vma, flags);
> +
> ptep = pte_offset_map_lock(mm, pmd, address, &ptl);
> pte = *ptep;
> if (!pte_present(pte)) {
> @@ -519,6 +523,7 @@ static struct page *follow_page_pte(struct vm_area_struct *vma,
> if (!is_migration_entry(entry))
> goto no_page;
> pte_unmap_unlock(ptep, ptl);
> + pte_put(mm, pmd, address);
> migration_entry_wait(mm, pmd, address);
> goto retry;
> }
> @@ -526,6 +531,7 @@ static struct page *follow_page_pte(struct vm_area_struct *vma,
> goto no_page;
> if ((flags & FOLL_WRITE) && !can_follow_write_pte(pte, flags)) {
> pte_unmap_unlock(ptep, ptl);
> + pte_put(mm, pmd, address);
> return NULL;
> }
>
> @@ -614,9 +620,11 @@ static struct page *follow_page_pte(struct vm_area_struct *vma,
> }
> out:
> pte_unmap_unlock(ptep, ptl);
> + pte_put(mm, pmd, address);
> return page;
> no_page:
> pte_unmap_unlock(ptep, ptl);
> + pte_put(mm, pmd, address);
> if (!pte_none(pte))
> return NULL;
> return no_page_table(vma, flags);
> @@ -713,7 +721,7 @@ static struct page *follow_pmd_mask(struct vm_area_struct *vma,
> } else {
> spin_unlock(ptl);
> split_huge_pmd(vma, pmd, address);
> - ret = pte_alloc(mm, pmd) ? -ENOMEM : 0;
> + ret = pte_alloc_get(mm, pmd) ? -ENOMEM : 0;
> }
>
> return ret ? ERR_PTR(ret) :
> diff --git a/mm/hmm.c b/mm/hmm.c
> index fad6be2bf072..29bb379510cc 100644
> --- a/mm/hmm.c
> +++ b/mm/hmm.c
> @@ -380,6 +380,9 @@ static int hmm_vma_walk_pmd(pmd_t *pmdp,
> return hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
> }
>
> + if (!pte_try_get(walk->mm, pmdp))
> + goto again;
> +
> ptep = pte_offset_map(pmdp, addr);
> for (; addr < end; addr += PAGE_SIZE, ptep++, hmm_pfns++) {
> int r;
> @@ -391,6 +394,7 @@ static int hmm_vma_walk_pmd(pmd_t *pmdp,
> }
> }
> pte_unmap(ptep - 1);
> + pte_put(walk->mm, pmdp, start);
> return 0;
> }
>
> diff --git a/mm/internal.h b/mm/internal.h
> index 31ff935b2547..642a7e0af740 100644
> --- a/mm/internal.h
> +++ b/mm/internal.h
> @@ -11,6 +11,7 @@
> #include <linux/mm.h>
> #include <linux/pagemap.h>
> #include <linux/tracepoint-defs.h>
> +#include <linux/pte_ref.h>
>
> /*
> * The set of flags that only affect watermark checking and reclaim
> @@ -441,6 +442,7 @@ static inline struct file *maybe_unlock_mmap_for_io(struct vm_fault *vmf,
> if (fault_flag_allow_retry_first(flags) &&
> !(flags & FAULT_FLAG_RETRY_NOWAIT)) {
> fpin = get_file(vmf->vma->vm_file);
> + pte_put_vmf(vmf);
> mmap_read_unlock(vmf->vma->vm_mm);
> }
> return fpin;
> diff --git a/mm/khugepaged.c b/mm/khugepaged.c
> index b0412be08fa2..e6c4d1b7a12a 100644
> --- a/mm/khugepaged.c
> +++ b/mm/khugepaged.c
> @@ -741,6 +741,7 @@ static void __collapse_huge_page_copy(pte_t *pte, struct page *page,
> {
> struct page *src_page, *tmp;
> pte_t *_pte;
> +
> for (_pte = pte; _pte < pte + HPAGE_PMD_NR;
> _pte++, page++, address += PAGE_SIZE) {
> pte_t pteval = *_pte;
> @@ -1239,6 +1240,10 @@ static int khugepaged_scan_pmd(struct mm_struct *mm,
> goto out;
> }
>
> + if (!pte_try_get(mm, pmd)) {
> + result = SCAN_PMD_NULL;
> + goto out;
> + }
> memset(khugepaged_node_load, 0, sizeof(khugepaged_node_load));
> pte = pte_offset_map_lock(mm, pmd, address, &ptl);
> for (_address = address, _pte = pte; _pte < pte+HPAGE_PMD_NR;
> @@ -1361,6 +1366,7 @@ static int khugepaged_scan_pmd(struct mm_struct *mm,
> }
> out_unmap:
> pte_unmap_unlock(pte, ptl);
> + pte_put(mm, pmd, address);
> if (ret) {
> node = khugepaged_find_target_node();
> /* collapse_huge_page will return with the mmap_lock released */
> @@ -1463,6 +1469,8 @@ void collapse_pte_mapped_thp(struct mm_struct *mm, unsigned long addr)
> if (!pmd)
> goto drop_hpage;
>
> + if (!pte_try_get(mm, pmd))
> + goto drop_hpage;
> start_pte = pte_offset_map_lock(mm, pmd, haddr, &ptl);
>
> /* step 1: check all mapped PTEs are to the right huge page */
> @@ -1501,6 +1509,7 @@ void collapse_pte_mapped_thp(struct mm_struct *mm, unsigned long addr)
> }
>
> pte_unmap_unlock(start_pte, ptl);
> + pte_put(mm, pmd, haddr);
>
> /* step 3: set proper refcount and mm_counters. */
> if (count) {
> @@ -1522,6 +1531,7 @@ void collapse_pte_mapped_thp(struct mm_struct *mm, unsigned long addr)
>
> abort:
> pte_unmap_unlock(start_pte, ptl);
> + pte_put(mm, pmd, haddr);
> goto drop_hpage;
> }
>
> diff --git a/mm/ksm.c b/mm/ksm.c
> index 3fa9bc8a67cf..2e106f58dad0 100644
> --- a/mm/ksm.c
> +++ b/mm/ksm.c
> @@ -1133,6 +1133,9 @@ static int replace_page(struct vm_area_struct *vma, struct page *page,
> if (!pmd)
> goto out;
>
> + if (!pte_try_get(mm, pmd))
> + goto out;
> +
> mmu_notifier_range_init(&range, MMU_NOTIFY_CLEAR, 0, vma, mm, addr,
> addr + PAGE_SIZE);
> mmu_notifier_invalidate_range_start(&range);
> @@ -1182,6 +1185,7 @@ static int replace_page(struct vm_area_struct *vma, struct page *page,
> err = 0;
> out_mn:
> mmu_notifier_invalidate_range_end(&range);
> + pte_put(mm, pmd, addr);
> out:
> return err;
> }
> diff --git a/mm/madvise.c b/mm/madvise.c
> index 012129fbfaf8..4c4b35292212 100644
> --- a/mm/madvise.c
> +++ b/mm/madvise.c
> @@ -191,7 +191,9 @@ static int swapin_walk_pmd_entry(pmd_t *pmd, unsigned long start,
> struct vm_area_struct *vma = walk->private;
> unsigned long index;
>
> - if (pmd_none_or_trans_huge_or_clear_bad(pmd))
> + if ((!IS_ENABLED(CONFIG_FREE_USER_PTE) &&
> + pmd_none_or_trans_huge_or_clear_bad(pmd)) ||
> + !pte_try_get(vma->vm_mm, pmd))
> return 0;
>
> for (index = start; index != end; index += PAGE_SIZE) {
> @@ -215,6 +217,7 @@ static int swapin_walk_pmd_entry(pmd_t *pmd, unsigned long start,
> if (page)
> put_page(page);
> }
> + pte_put(vma->vm_mm, pmd, start);
>
> return 0;
> }
> @@ -318,6 +321,7 @@ static int madvise_cold_or_pageout_pte_range(pmd_t *pmd,
> spinlock_t *ptl;
> struct page *page = NULL;
> LIST_HEAD(page_list);
> + unsigned long start = addr;
>
> if (fatal_signal_pending(current))
> return -EINTR;
> @@ -389,9 +393,11 @@ static int madvise_cold_or_pageout_pte_range(pmd_t *pmd,
> }
>
> regular_page:
> - if (pmd_trans_unstable(pmd))
> + if (!IS_ENABLED(CONFIG_FREE_USER_PTE) && pmd_trans_unstable(pmd))
> return 0;
> #endif
> + if (!pte_try_get(vma->vm_mm, pmd))
> + return 0;
> tlb_change_page_size(tlb, PAGE_SIZE);
> orig_pte = pte = pte_offset_map_lock(vma->vm_mm, pmd, addr, &ptl);
> flush_tlb_batched_pending(mm);
> @@ -471,6 +477,7 @@ static int madvise_cold_or_pageout_pte_range(pmd_t *pmd,
>
> arch_leave_lazy_mmu_mode();
> pte_unmap_unlock(orig_pte, ptl);
> + pte_put(vma->vm_mm, pmd, start);
> if (pageout)
> reclaim_pages(&page_list);
> cond_resched();
> @@ -580,14 +587,18 @@ static int madvise_free_pte_range(pmd_t *pmd, unsigned long addr,
> struct page *page;
> int nr_swap = 0;
> unsigned long next;
> + unsigned int nr_put = 0;
> + unsigned long start = addr;
>
> next = pmd_addr_end(addr, end);
> if (pmd_trans_huge(*pmd))
> if (madvise_free_huge_pmd(tlb, vma, pmd, addr, next))
> goto next;
>
> - if (pmd_trans_unstable(pmd))
> + if ((!IS_ENABLED(CONFIG_FREE_USER_PTE) &&
> + pmd_trans_unstable(pmd)) || !pte_try_get(mm, pmd))
> return 0;
> + nr_put++;
>
> tlb_change_page_size(tlb, PAGE_SIZE);
> orig_pte = pte = pte_offset_map_lock(mm, pmd, addr, &ptl);
> @@ -612,6 +623,7 @@ static int madvise_free_pte_range(pmd_t *pmd, unsigned long addr,
> nr_swap--;
> free_swap_and_cache(entry);
> pte_clear_not_present_full(mm, addr, pte, tlb->fullmm);
> + nr_put++;
> continue;
> }
>
> @@ -696,6 +708,8 @@ static int madvise_free_pte_range(pmd_t *pmd, unsigned long addr,
> }
> arch_leave_lazy_mmu_mode();
> pte_unmap_unlock(orig_pte, ptl);
> + if (nr_put)
> + pte_put_many(mm, pmd, start, nr_put);
> cond_resched();
> next:
> return 0;
> diff --git a/mm/memcontrol.c b/mm/memcontrol.c
> index ae1f5d0cb581..4f19e5f2cd18 100644
> --- a/mm/memcontrol.c
> +++ b/mm/memcontrol.c
> @@ -5819,6 +5819,7 @@ static int mem_cgroup_count_precharge_pte_range(pmd_t *pmd,
> struct vm_area_struct *vma = walk->vma;
> pte_t *pte;
> spinlock_t *ptl;
> + unsigned long start = addr;
>
> ptl = pmd_trans_huge_lock(pmd, vma);
> if (ptl) {
> @@ -5833,13 +5834,15 @@ static int mem_cgroup_count_precharge_pte_range(pmd_t *pmd,
> return 0;
> }
>
> - if (pmd_trans_unstable(pmd))
> + if ((!IS_ENABLED(CONFIG_FREE_USER_PTE) && pmd_trans_unstable(pmd)) ||
> + !pte_try_get(vma->vm_mm, pmd))
> return 0;
> pte = pte_offset_map_lock(vma->vm_mm, pmd, addr, &ptl);
> for (; addr != end; pte++, addr += PAGE_SIZE)
> if (get_mctgt_type(vma, addr, *pte, NULL))
> mc.precharge++; /* increment precharge temporarily */
> pte_unmap_unlock(pte - 1, ptl);
> + pte_put(vma->vm_mm, pmd, start);
> cond_resched();
>
> return 0;
> @@ -6019,6 +6022,7 @@ static int mem_cgroup_move_charge_pte_range(pmd_t *pmd,
> enum mc_target_type target_type;
> union mc_target target;
> struct page *page;
> + unsigned long start = addr;
>
> ptl = pmd_trans_huge_lock(pmd, vma);
> if (ptl) {
> @@ -6051,9 +6055,11 @@ static int mem_cgroup_move_charge_pte_range(pmd_t *pmd,
> return 0;
> }
>
> - if (pmd_trans_unstable(pmd))
> + if (!IS_ENABLED(CONFIG_FREE_USER_PTE) && pmd_trans_unstable(pmd))
> return 0;
> retry:
> + if (!pte_try_get(vma->vm_mm, pmd))
> + return 0;
> pte = pte_offset_map_lock(vma->vm_mm, pmd, addr, &ptl);
> for (; addr != end; addr += PAGE_SIZE) {
> pte_t ptent = *(pte++);
> @@ -6104,6 +6110,7 @@ static int mem_cgroup_move_charge_pte_range(pmd_t *pmd,
> }
> }
> pte_unmap_unlock(pte - 1, ptl);
> + pte_put(vma->vm_mm, pmd, start);
> cond_resched();
>
> if (addr != end) {
> diff --git a/mm/memory.c b/mm/memory.c
> index 3bf2636413ee..242ed135bde4 100644
> --- a/mm/memory.c
> +++ b/mm/memory.c
> @@ -219,6 +219,17 @@ static void check_sync_rss_stat(struct task_struct *task)
>
> #endif /* SPLIT_RSS_COUNTING */
>
> +#ifdef CONFIG_FREE_USER_PTE
> +static inline void free_pte_range(struct mmu_gather *tlb, pmd_t *pmd,
> + unsigned long addr)
> +{
> + /*
> + * We should never reach here since the PTE page tables are
> + * dynamically freed.
> + */
> + BUG();
> +}
> +#else
> /*
> * Note: this doesn't free the actual pages themselves. That
> * has been handled earlier when unmapping all the memory regions.
> @@ -231,6 +242,7 @@ static void free_pte_range(struct mmu_gather *tlb, pmd_t *pmd,
> pte_free_tlb(tlb, token, addr);
> mm_dec_nr_ptes(tlb->mm);
> }
> +#endif /* CONFIG_FREE_USER_PTE */
>
> static inline void free_pmd_range(struct mmu_gather *tlb, pud_t *pud,
> unsigned long addr, unsigned long end,
> @@ -433,44 +445,6 @@ void free_pgtables(struct mmu_gather *tlb, struct vm_area_struct *vma,
> }
> }
>
> -void pte_install(struct mm_struct *mm, pmd_t *pmd, pgtable_t *pte)
> -{
> - spinlock_t *ptl = pmd_lock(mm, pmd);
> -
> - if (likely(pmd_none(*pmd))) { /* Has another populated it ? */
> - mm_inc_nr_ptes(mm);
> - /*
> - * Ensure all pte setup (eg. pte page lock and page clearing) are
> - * visible before the pte is made visible to other CPUs by being
> - * put into page tables.
> - *
> - * The other side of the story is the pointer chasing in the page
> - * table walking code (when walking the page table without locking;
> - * ie. most of the time). Fortunately, these data accesses consist
> - * of a chain of data-dependent loads, meaning most CPUs (alpha
> - * being the notable exception) will already guarantee loads are
> - * seen in-order. See the alpha page table accessors for the
> - * smp_rmb() barriers in page table walking code.
> - */
> - smp_wmb(); /* Could be smp_wmb__xxx(before|after)_spin_lock */
> - pmd_populate(mm, pmd, *pte);
> - *pte = NULL;
> - }
> - spin_unlock(ptl);
> -}
> -
> -int __pte_alloc(struct mm_struct *mm, pmd_t *pmd)
> -{
> - pgtable_t new = pte_alloc_one(mm);
> - if (!new)
> - return -ENOMEM;
> -
> - pte_install(mm, pmd, &new);
> - if (new)
> - pte_free(mm, new);
> - return 0;
> -}
> -
> int __pte_alloc_kernel(pmd_t *pmd)
> {
> pte_t *new = pte_alloc_one_kernel(&init_mm);
> @@ -479,7 +453,7 @@ int __pte_alloc_kernel(pmd_t *pmd)
>
> spin_lock(&init_mm.page_table_lock);
> if (likely(pmd_none(*pmd))) { /* Has another populated it ? */
> - smp_wmb(); /* See comment in pte_install() */
> + smp_wmb(); /* See comment in __pte_install() */
> pmd_populate_kernel(&init_mm, pmd, new);
> new = NULL;
> }
> @@ -860,6 +834,7 @@ copy_nonpresent_pte(struct mm_struct *dst_mm, struct mm_struct *src_mm,
> if (!userfaultfd_wp(dst_vma))
> pte = pte_swp_clear_uffd_wp(pte);
> set_pte_at(dst_mm, addr, dst_pte, pte);
> + pte_get(pte_to_pmd(dst_pte));
> return 0;
> }
>
> @@ -928,6 +903,7 @@ copy_present_page(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma
> /* Uffd-wp needs to be delivered to dest pte as well */
> pte = pte_wrprotect(pte_mkuffd_wp(pte));
> set_pte_at(dst_vma->vm_mm, addr, dst_pte, pte);
> + pte_get(pte_to_pmd(dst_pte));
> return 0;
> }
>
> @@ -980,6 +956,7 @@ copy_present_pte(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
> pte = pte_clear_uffd_wp(pte);
>
> set_pte_at(dst_vma->vm_mm, addr, dst_pte, pte);
> + pte_get(pte_to_pmd(dst_pte));
> return 0;
> }
>
> @@ -1021,7 +998,7 @@ copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
> progress = 0;
> init_rss_vec(rss);
>
> - dst_pte = pte_alloc_map_lock(dst_mm, dst_pmd, addr, &dst_ptl);
> + dst_pte = pte_alloc_get_map_lock(dst_mm, dst_pmd, addr, &dst_ptl);
> if (!dst_pte) {
> ret = -ENOMEM;
> goto out;
> @@ -1109,8 +1086,10 @@ copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
> goto out;
> } else if (ret == -EAGAIN) {
> prealloc = page_copy_prealloc(src_mm, src_vma, addr);
> - if (!prealloc)
> - return -ENOMEM;
> + if (!prealloc) {
> + ret = -ENOMEM;
> + goto out;
> + }
> } else if (ret) {
> VM_WARN_ON_ONCE(1);
> }
> @@ -1118,11 +1097,14 @@ copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
> /* We've captured and resolved the error. Reset, try again. */
> ret = 0;
>
> - if (addr != end)
> + if (addr != end) {
> + pte_put(dst_mm, dst_pmd, addr);
> goto again;
> + }
> out:
> if (unlikely(prealloc))
> put_page(prealloc);
> + pte_put(dst_mm, dst_pmd, addr);
> return ret;
> }
>
> @@ -1141,9 +1123,13 @@ copy_pmd_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
> return -ENOMEM;
> src_pmd = pmd_offset(src_pud, addr);
> do {
> + pmd_t pmdval;
> +
> next = pmd_addr_end(addr, end);
> - if (is_swap_pmd(*src_pmd) || pmd_trans_huge(*src_pmd)
> - || pmd_devmap(*src_pmd)) {
> +retry:
> + pmdval = READ_ONCE(*src_pmd);
> + if (is_swap_pmd(pmdval) || pmd_trans_huge(pmdval)
> + || pmd_devmap(pmdval)) {
> int err;
> VM_BUG_ON_VMA(next-addr != HPAGE_PMD_SIZE, src_vma);
> err = copy_huge_pmd(dst_mm, src_mm, dst_pmd, src_pmd,
> @@ -1156,9 +1142,15 @@ copy_pmd_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
> }
> if (pmd_none_or_clear_bad(src_pmd))
> continue;
> +
> + if (!pte_try_get(src_mm, src_pmd))
> + goto retry;
> if (copy_pte_range(dst_vma, src_vma, dst_pmd, src_pmd,
> - addr, next))
> + addr, next)) {
> + pte_put(src_mm, src_pmd, addr);
> return -ENOMEM;
> + }
> + pte_put(src_mm, src_pmd, addr);
> } while (dst_pmd++, src_pmd++, addr = next, addr != end);
> return 0;
> }
> @@ -1316,6 +1308,8 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
> pte_t *start_pte;
> pte_t *pte;
> swp_entry_t entry;
> + unsigned int nr_put = 0;
> + unsigned long start = addr;
>
> tlb_change_page_size(tlb, PAGE_SIZE);
> again:
> @@ -1348,6 +1342,7 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
> }
> ptent = ptep_get_and_clear_full(mm, addr, pte,
> tlb->fullmm);
> + nr_put++;
> tlb_remove_tlb_entry(tlb, pte, addr);
> if (unlikely(!page))
> continue;
> @@ -1390,6 +1385,7 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
> }
>
> pte_clear_not_present_full(mm, addr, pte, tlb->fullmm);
> + nr_put++;
> rss[mm_counter(page)]--;
>
> if (is_device_private_entry(entry))
> @@ -1414,6 +1410,7 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
> if (unlikely(!free_swap_and_cache(entry)))
> print_bad_pte(vma, addr, ptent, NULL);
> pte_clear_not_present_full(mm, addr, pte, tlb->fullmm);
> + nr_put++;
> } while (pte++, addr += PAGE_SIZE, addr != end);
>
> add_mm_rss_vec(mm, rss);
> @@ -1440,6 +1437,9 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
> goto again;
> }
>
> + if (nr_put)
> + pte_put_many(mm, pmd, start, nr_put);
> +
> return addr;
> }
>
> @@ -1479,9 +1479,13 @@ static inline unsigned long zap_pmd_range(struct mmu_gather *tlb,
> * because MADV_DONTNEED holds the mmap_lock in read
> * mode.
> */
> - if (pmd_none_or_trans_huge_or_clear_bad(pmd))
> + if ((!IS_ENABLED(CONFIG_FREE_USER_PTE) &&
> + pmd_none_or_trans_huge_or_clear_bad(pmd)) ||
> + !pte_try_get(tlb->mm, pmd))
> goto next;
> +
> next = zap_pte_range(tlb, vma, pmd, addr, next, details);
> + pte_put(tlb->mm, pmd, addr);
> next:
> cond_resched();
> } while (pmd++, addr = next, addr != end);
> @@ -1736,7 +1740,7 @@ pte_t *__get_locked_pte(struct mm_struct *mm, unsigned long addr,
>
> if (!pmd)
> return NULL;
> - return pte_alloc_map_lock(mm, pmd, addr, ptl);
> + return pte_alloc_get_map_lock(mm, pmd, addr, ptl);
> }
>
> static int validate_page_before_insert(struct page *page)
> @@ -1757,6 +1761,7 @@ static int insert_page_into_pte_locked(struct mm_struct *mm, pte_t *pte,
> inc_mm_counter_fast(mm, mm_counter_file(page));
> page_add_file_rmap(page, false);
> set_pte_at(mm, addr, pte, mk_pte(page, prot));
> + pte_get(pte_to_pmd(pte));
> return 0;
> }
>
> @@ -1784,6 +1789,7 @@ static int insert_page(struct vm_area_struct *vma, unsigned long addr,
> goto out;
> retval = insert_page_into_pte_locked(mm, pte, addr, page, prot);
> pte_unmap_unlock(pte, ptl);
> + pte_put(mm, pte_to_pmd(pte), addr);
> out:
> return retval;
> }
> @@ -1827,7 +1833,7 @@ static int insert_pages(struct vm_area_struct *vma, unsigned long addr,
>
> /* Allocate the PTE if necessary; takes PMD lock once only. */
> ret = -ENOMEM;
> - if (pte_alloc(mm, pmd))
> + if (pte_alloc_try_get(mm, pmd) < 0)
> goto out;
>
> while (pages_to_write_in_pmd) {
> @@ -1854,6 +1860,7 @@ static int insert_pages(struct vm_area_struct *vma, unsigned long addr,
> if (remaining_pages_total)
> goto more;
> ret = 0;
> + pte_put(mm, pmd, addr);
> out:
> *num = remaining_pages_total;
> return ret;
> @@ -2077,10 +2084,12 @@ static vm_fault_t insert_pfn(struct vm_area_struct *vma, unsigned long addr,
> }
>
> set_pte_at(mm, addr, pte, entry);
> + pte_get(pte_to_pmd(pte));
> update_mmu_cache(vma, addr, pte); /* XXX: why not for insert_page? */
>
> out_unlock:
> pte_unmap_unlock(pte, ptl);
> + pte_put(mm, pte_to_pmd(pte), addr);
> return VM_FAULT_NOPAGE;
> }
>
> @@ -2284,8 +2293,10 @@ static int remap_pte_range(struct mm_struct *mm, pmd_t *pmd,
> pte_t *pte, *mapped_pte;
> spinlock_t *ptl;
> int err = 0;
> + unsigned int nr_get = 0;
> + unsigned long start_addr = addr;
>
> - mapped_pte = pte = pte_alloc_map_lock(mm, pmd, addr, &ptl);
> + mapped_pte = pte = pte_alloc_get_map_lock(mm, pmd, addr, &ptl);
> if (!pte)
> return -ENOMEM;
> arch_enter_lazy_mmu_mode();
> @@ -2296,10 +2307,13 @@ static int remap_pte_range(struct mm_struct *mm, pmd_t *pmd,
> break;
> }
> set_pte_at(mm, addr, pte, pte_mkspecial(pfn_pte(pfn, prot)));
> + nr_get++;
> pfn++;
> } while (pte++, addr += PAGE_SIZE, addr != end);
> + pte_get_many(pmd, nr_get);
> arch_leave_lazy_mmu_mode();
> pte_unmap_unlock(mapped_pte, ptl);
> + pte_put(mm, pmd, start_addr);
> return err;
> }
>
> @@ -2512,13 +2526,17 @@ static int apply_to_pte_range(struct mm_struct *mm, pmd_t *pmd,
> pte_t *pte, *mapped_pte;
> int err = 0;
> spinlock_t *ptl;
> + unsigned int nr_put = 0;
> + unsigned int nr_get = 0;
> + unsigned long start = addr;
>
> if (create) {
> mapped_pte = pte = (mm == &init_mm) ?
> pte_alloc_kernel_track(pmd, addr, mask) :
> - pte_alloc_map_lock(mm, pmd, addr, &ptl);
> + pte_alloc_get_map_lock(mm, pmd, addr, &ptl);
> if (!pte)
> return -ENOMEM;
> + nr_put++;
> } else {
> mapped_pte = pte = (mm == &init_mm) ?
> pte_offset_kernel(pmd, addr) :
> @@ -2531,19 +2549,32 @@ static int apply_to_pte_range(struct mm_struct *mm, pmd_t *pmd,
>
> if (fn) {
> do {
> - if (create || !pte_none(*pte)) {
> + if (create) {
> err = fn(pte++, addr, data);
> - if (err)
> - break;
> + if (IS_ENABLED(CONFIG_FREE_USER_PTE) &&
> + mm != &init_mm && !pte_none(*(pte-1)))
> + nr_get++;
> + } else if (!pte_none(*pte)) {
> + err = fn(pte++, addr, data);
> + if (IS_ENABLED(CONFIG_FREE_USER_PTE) &&
> + mm != &init_mm && pte_none(*(pte-1)))
> + nr_put++;
> }
> + if (err)
> + break;
> } while (addr += PAGE_SIZE, addr != end);
> }
> *mask |= PGTBL_PTE_MODIFIED;
>
> arch_leave_lazy_mmu_mode();
>
> - if (mm != &init_mm)
> + if (mm != &init_mm) {
> pte_unmap_unlock(mapped_pte, ptl);
> + pte_get_many(pmd, nr_get);
> + if (nr_put)
> + pte_put_many(mm, pmd, start, nr_put);
> + }
> +
> return err;
> }
>
> @@ -2567,6 +2598,7 @@ static int apply_to_pmd_range(struct mm_struct *mm, pud_t *pud,
> }
> do {
> next = pmd_addr_end(addr, end);
> +retry:
> if (pmd_none(*pmd) && !create)
> continue;
> if (WARN_ON_ONCE(pmd_leaf(*pmd)))
> @@ -2576,8 +2608,12 @@ static int apply_to_pmd_range(struct mm_struct *mm, pud_t *pud,
> continue;
> pmd_clear_bad(pmd);
> }
> + if (!create && !pte_try_get(mm, pmd))
> + goto retry;
> err = apply_to_pte_range(mm, pmd, addr, next,
> fn, data, create, mask);
> + if (!create)
> + pte_put(mm, pmd, addr);
> if (err)
> break;
> } while (pmd++, addr = next, addr != end);
> @@ -3726,21 +3762,19 @@ static vm_fault_t do_anonymous_page(struct vm_fault *vmf)
> return VM_FAULT_SIGBUS;
>
> /*
> - * Use pte_alloc() instead of pte_alloc_map(). We can't run
> + * Use pte_alloc_try_get() instead of pte_alloc_get_map(). We can't run
> * pte_offset_map() on pmds where a huge pmd might be created
> * from a different thread.
> *
> - * pte_alloc_map() is safe to use under mmap_write_lock(mm) or when
> + * pte_alloc_get_map() is safe to use under mmap_write_lock(mm) or when
> * parallel threads are excluded by other means.
> *
> * Here we only have mmap_read_lock(mm).
> */
> - if (pte_alloc(vma->vm_mm, vmf->pmd))
> - return VM_FAULT_OOM;
> -
> - /* See comment in handle_pte_fault() */
> - if (unlikely(pmd_trans_unstable(vmf->pmd)))
> - return 0;
> + ret = pte_alloc_try_get(vma->vm_mm, vmf->pmd);
> + if (ret <= 0)
> + return ret < 0 ? VM_FAULT_OOM : 0;
> + ret = 0;
>
> /* Use the zero-page for reads */
> if (!(vmf->flags & FAULT_FLAG_WRITE) &&
> @@ -3759,7 +3793,8 @@ static vm_fault_t do_anonymous_page(struct vm_fault *vmf)
> /* Deliver the page fault to userland, check inside PT lock */
> if (userfaultfd_missing(vma)) {
> pte_unmap_unlock(vmf->pte, vmf->ptl);
> - return handle_userfault(vmf, VM_UFFD_MISSING);
> + ret = handle_userfault(vmf, VM_UFFD_MISSING);
> + goto put;
> }
> goto setpte;
> }
> @@ -3802,7 +3837,8 @@ static vm_fault_t do_anonymous_page(struct vm_fault *vmf)
> if (userfaultfd_missing(vma)) {
> pte_unmap_unlock(vmf->pte, vmf->ptl);
> put_page(page);
> - return handle_userfault(vmf, VM_UFFD_MISSING);
> + ret = handle_userfault(vmf, VM_UFFD_MISSING);
> + goto put;
> }
>
> inc_mm_counter_fast(vma->vm_mm, MM_ANONPAGES);
> @@ -3810,19 +3846,23 @@ static vm_fault_t do_anonymous_page(struct vm_fault *vmf)
> lru_cache_add_inactive_or_unevictable(page, vma);
> setpte:
> set_pte_at(vma->vm_mm, vmf->address, vmf->pte, entry);
> + pte_get(vmf->pmd);
>
> /* No need to invalidate - it was non-present before */
> update_mmu_cache(vma, vmf->address, vmf->pte);
> unlock:
> pte_unmap_unlock(vmf->pte, vmf->ptl);
> - return ret;
> + goto put;
> release:
> put_page(page);
> goto unlock;
> oom_free_page:
> put_page(page);
> oom:
> - return VM_FAULT_OOM;
> + ret = VM_FAULT_OOM;
> +put:
> + pte_put(vma->vm_mm, vmf->pmd, vmf->address);
> + return ret;
> }
>
> /*
> @@ -3850,7 +3890,7 @@ static vm_fault_t __do_fault(struct vm_fault *vmf)
> * unlock_page(B)
> * # flush A, B to clear the writeback
> */
> - if (pmd_none(*vmf->pmd) && !vmf->prealloc_pte) {
> + if (!vmf->prealloc_pte) {
> vmf->prealloc_pte = pte_alloc_one(vma->vm_mm);
> if (!vmf->prealloc_pte)
> return VM_FAULT_OOM;
> @@ -4020,6 +4060,7 @@ vm_fault_t finish_fault(struct vm_fault *vmf)
> return ret;
> }
>
> +retry:
> if (pmd_none(*vmf->pmd)) {
> if (PageTransCompound(page)) {
> ret = do_set_pmd(vmf, page);
> @@ -4027,27 +4068,33 @@ vm_fault_t finish_fault(struct vm_fault *vmf)
> return ret;
> }
>
> - if (vmf->prealloc_pte)
> - pte_install(vma->vm_mm, vmf->pmd, &vmf->prealloc_pte);
> - else if (unlikely(pte_alloc(vma->vm_mm, vmf->pmd)))
> - return VM_FAULT_OOM;
> - }
> -
> - /* See comment in handle_pte_fault() */
> - if (pmd_devmap_trans_unstable(vmf->pmd))
> + if (vmf->prealloc_pte) {
> + if (!pte_install_try_get(vma->vm_mm, vmf->pmd, &vmf->prealloc_pte))
> + return 0;
> + } else {
> + ret = pte_alloc_try_get(vma->vm_mm, vmf->pmd);
> + if (ret <= 0)
> + return ret < 0 ? VM_FAULT_OOM : 0;
> + }
> + } else if (pmd_devmap_trans_unstable(vmf->pmd)) { /* See comment in handle_pte_fault() */
> return 0;
> + } else if (!pte_try_get(vma->vm_mm, vmf->pmd)) {
> + goto retry;
> + }
>
> vmf->pte = pte_offset_map_lock(vma->vm_mm, vmf->pmd,
> vmf->address, &vmf->ptl);
> ret = 0;
> /* Re-check under ptl */
> - if (likely(pte_none(*vmf->pte)))
> + if (likely(pte_none(*vmf->pte))) {
> do_set_pte(vmf, page, vmf->address);
> - else
> + pte_get(vmf->pmd);
> + } else
> ret = VM_FAULT_NOPAGE;
>
> update_mmu_tlb(vma, vmf->address, vmf->pte);
> pte_unmap_unlock(vmf->pte, vmf->ptl);
> + pte_put(vma->vm_mm, vmf->pmd, vmf->address);
> return ret;
> }
>
> @@ -4268,9 +4315,14 @@ static vm_fault_t do_fault(struct vm_fault *vmf)
> * If we find a migration pmd entry or a none pmd entry, which
> * should never happen, return SIGBUS
> */
> - if (unlikely(!pmd_present(*vmf->pmd)))
> + if (unlikely(!pmd_present(*vmf->pmd))) {
> ret = VM_FAULT_SIGBUS;
> - else {
> + goto out;
> + } else {
> + if (!pte_try_get(vma->vm_mm, vmf->pmd)) {
> + ret = VM_FAULT_SIGBUS;
> + goto out;
> + }
> vmf->pte = pte_offset_map_lock(vmf->vma->vm_mm,
> vmf->pmd,
> vmf->address,
> @@ -4288,6 +4340,7 @@ static vm_fault_t do_fault(struct vm_fault *vmf)
> ret = VM_FAULT_NOPAGE;
>
> pte_unmap_unlock(vmf->pte, vmf->ptl);
> + pte_put(vma->vm_mm, vmf->pmd, vmf->address);
> }
> } else if (!(vmf->flags & FAULT_FLAG_WRITE))
> ret = do_read_fault(vmf);
> @@ -4301,6 +4354,7 @@ static vm_fault_t do_fault(struct vm_fault *vmf)
> pte_free(vm_mm, vmf->prealloc_pte);
> vmf->prealloc_pte = NULL;
> }
> +out:
> return ret;
> }
>
> @@ -4496,11 +4550,13 @@ static vm_fault_t wp_huge_pud(struct vm_fault *vmf, pud_t orig_pud)
> static vm_fault_t handle_pte_fault(struct vm_fault *vmf)
> {
> pte_t entry;
> + vm_fault_t ret;
>
> - if (unlikely(pmd_none(*vmf->pmd))) {
> +retry:
> + if (unlikely(pmd_none(READ_ONCE(*vmf->pmd)))) {
> /*
> - * Leave __pte_alloc() until later: because vm_ops->fault may
> - * want to allocate huge page, and if we expose page table
> + * Leave __pte_alloc_try_get() until later: because vm_ops->fault
> + * may want to allocate huge page, and if we expose page table
> * for an instant, it will be difficult to retract from
> * concurrent faults and from rmap lookups.
> */
> @@ -4517,9 +4573,18 @@ static vm_fault_t handle_pte_fault(struct vm_fault *vmf)
> * that it is a regular pmd that we can walk with
> * pte_offset_map() and we can do that through an atomic read
> * in C, which is what pmd_trans_unstable() provides.
> + *
> + * Note: we do this in pte_try_get() when CONFIG_FREE_USER_PTE
> */
> if (pmd_devmap_trans_unstable(vmf->pmd))
> return 0;
> +
> + if (!pte_try_get(vmf->vma->vm_mm, vmf->pmd))
> + goto retry;
> +
> + if (IS_ENABLED(CONFIG_FREE_USER_PTE))
> + vmf->flags |= FAULT_FLAG_PTE_GET;
> +
> /*
> * A regular pmd is established and it can't morph into a huge
> * pmd from under us anymore at this point because we hold the
> @@ -4541,6 +4606,7 @@ static vm_fault_t handle_pte_fault(struct vm_fault *vmf)
> if (pte_none(vmf->orig_pte)) {
> pte_unmap(vmf->pte);
> vmf->pte = NULL;
> + pte_put_vmf(vmf);
> }
> }
>
> @@ -4551,11 +4617,15 @@ static vm_fault_t handle_pte_fault(struct vm_fault *vmf)
> return do_fault(vmf);
> }
>
> - if (!pte_present(vmf->orig_pte))
> - return do_swap_page(vmf);
> + if (!pte_present(vmf->orig_pte)) {
> + ret = do_swap_page(vmf);
> + goto put;
> + }
>
> - if (pte_protnone(vmf->orig_pte) && vma_is_accessible(vmf->vma))
> - return do_numa_page(vmf);
> + if (pte_protnone(vmf->orig_pte) && vma_is_accessible(vmf->vma)) {
> + ret = do_numa_page(vmf);
> + goto put;
> + }
>
> vmf->ptl = pte_lockptr(vmf->vma->vm_mm, vmf->pmd);
> spin_lock(vmf->ptl);
> @@ -4565,8 +4635,10 @@ static vm_fault_t handle_pte_fault(struct vm_fault *vmf)
> goto unlock;
> }
> if (vmf->flags & FAULT_FLAG_WRITE) {
> - if (!pte_write(entry))
> - return do_wp_page(vmf);
> + if (!pte_write(entry)) {
> + ret = do_wp_page(vmf);
> + goto put;
> + }
> entry = pte_mkdirty(entry);
> }
> entry = pte_mkyoung(entry);
> @@ -4588,7 +4660,10 @@ static vm_fault_t handle_pte_fault(struct vm_fault *vmf)
> }
> unlock:
> pte_unmap_unlock(vmf->pte, vmf->ptl);
> - return 0;
> + ret = 0;
> +put:
> + pte_put_vmf(vmf);
> + return ret;
> }
>
> /*
> @@ -4816,7 +4891,7 @@ int __p4d_alloc(struct mm_struct *mm, pgd_t *pgd, unsigned long address)
> if (pgd_present(*pgd)) /* Another has populated it */
> p4d_free(mm, new);
> else {
> - smp_wmb(); /* See comment in pte_install() */
> + smp_wmb(); /* See comment in __pte_install() */
> pgd_populate(mm, pgd, new);
> }
> spin_unlock(&mm->page_table_lock);
> @@ -4838,7 +4913,7 @@ int __pud_alloc(struct mm_struct *mm, p4d_t *p4d, unsigned long address)
> spin_lock(&mm->page_table_lock);
> if (!p4d_present(*p4d)) {
> mm_inc_nr_puds(mm);
> - smp_wmb(); /* See comment in pte_install() */
> + smp_wmb(); /* See comment in __pte_install() */
> p4d_populate(mm, p4d, new);
> } else /* Another has populated it */
> pud_free(mm, new);
> @@ -4862,7 +4937,7 @@ int __pmd_alloc(struct mm_struct *mm, pud_t *pud, unsigned long address)
> ptl = pud_lock(mm, pud);
> if (!pud_present(*pud)) {
> mm_inc_nr_pmds(mm);
> - smp_wmb(); /* See comment in pte_install() */
> + smp_wmb(); /* See comment in __pte_install() */
> pud_populate(mm, pud, new);
> } else /* Another has populated it */
> pmd_free(mm, new);
> @@ -4925,13 +5000,22 @@ int follow_invalidate_pte(struct mm_struct *mm, unsigned long address,
> (address & PAGE_MASK) + PAGE_SIZE);
> mmu_notifier_invalidate_range_start(range);
> }
> + if (!pte_try_get(mm, pmd))
> + goto out;
> ptep = pte_offset_map_lock(mm, pmd, address, ptlp);
> if (!pte_present(*ptep))
> goto unlock;
> + /*
> + * when we reach here, it means that the ->pte_refcount is at least
> + * one and the contents of the PTE page table are stable until @ptlp is
> + * released, so we can put pte safely.
> + */
> + pte_put(mm, pmd, address);
> *ptepp = ptep;
> return 0;
> unlock:
> pte_unmap_unlock(ptep, *ptlp);
> + pte_put(mm, pmd, address);
> if (range)
> mmu_notifier_invalidate_range_end(range);
> out:
> @@ -5058,6 +5142,7 @@ int generic_access_phys(struct vm_area_struct *vma, unsigned long addr,
> return -EINVAL;
> pte = *ptep;
> pte_unmap_unlock(ptep, ptl);
> + pte_put(vma->vm_mm, pte_to_pmd(ptep), addr);
>
> prot = pgprot_val(pte_pgprot(pte));
> phys_addr = (resource_size_t)pte_pfn(pte) << PAGE_SHIFT;
> diff --git a/mm/mempolicy.c b/mm/mempolicy.c
> index e32360e90274..cbb3640717ff 100644
> --- a/mm/mempolicy.c
> +++ b/mm/mempolicy.c
> @@ -509,6 +509,7 @@ static int queue_pages_pte_range(pmd_t *pmd, unsigned long addr,
> bool has_unmovable = false;
> pte_t *pte, *mapped_pte;
> spinlock_t *ptl;
> + unsigned long start = addr;
>
> ptl = pmd_trans_huge_lock(pmd, vma);
> if (ptl) {
> @@ -518,7 +519,8 @@ static int queue_pages_pte_range(pmd_t *pmd, unsigned long addr,
> }
> /* THP was split, fall through to pte walk */
>
> - if (pmd_trans_unstable(pmd))
> + if ((!IS_ENABLED(CONFIG_FREE_USER_PTE) && pmd_trans_unstable(pmd)) ||
> + !pte_try_get(walk->mm, pmd))
> return 0;
>
> mapped_pte = pte = pte_offset_map_lock(walk->mm, pmd, addr, &ptl);
> @@ -554,6 +556,7 @@ static int queue_pages_pte_range(pmd_t *pmd, unsigned long addr,
> break;
> }
> pte_unmap_unlock(mapped_pte, ptl);
> + pte_put(walk->mm, pmd, start);
> cond_resched();
>
> if (has_unmovable)
> diff --git a/mm/migrate.c b/mm/migrate.c
> index 23cbd9de030b..6a94e8558b2c 100644
> --- a/mm/migrate.c
> +++ b/mm/migrate.c
> @@ -2265,6 +2265,8 @@ static int migrate_vma_collect_pmd(pmd_t *pmdp,
> if (unlikely(pmd_bad(*pmdp)))
> return migrate_vma_collect_skip(start, end, walk);
>
> + if (!pte_try_get(mm, pmdp))
> + goto again;
> ptep = pte_offset_map_lock(mm, pmdp, addr, &ptl);
> arch_enter_lazy_mmu_mode();
>
> @@ -2386,6 +2388,7 @@ static int migrate_vma_collect_pmd(pmd_t *pmdp,
> }
> arch_leave_lazy_mmu_mode();
> pte_unmap_unlock(ptep - 1, ptl);
> + pte_put(mm, pmdp, start);
>
> /* Only flush the TLB if we actually modified any entries */
> if (unmapped)
> @@ -2793,26 +2796,22 @@ static void migrate_vma_insert_page(struct migrate_vma *migrate,
> goto abort;
>
> /*
> - * Use pte_alloc() instead of pte_alloc_map(). We can't run
> + * Use pte_alloc_try_get() instead of pte_alloc_get_map(). We can't run
> * pte_offset_map() on pmds where a huge pmd might be created
> * from a different thread.
> *
> - * pte_alloc_map() is safe to use under mmap_write_lock(mm) or when
> + * pte_alloc_get_map() is safe to use under mmap_write_lock(mm) or when
> * parallel threads are excluded by other means.
> *
> * Here we only have mmap_read_lock(mm).
> */
> - if (pte_alloc(mm, pmdp))
> - goto abort;
> -
> - /* See the comment in pte_alloc_one_map() */
> - if (unlikely(pmd_trans_unstable(pmdp)))
> + if (pte_alloc_try_get(mm, pmdp) <= 0)
> goto abort;
>
> if (unlikely(anon_vma_prepare(vma)))
> - goto abort;
> + goto put;
> if (mem_cgroup_charge(page, vma->vm_mm, GFP_KERNEL))
> - goto abort;
> + goto put;
>
> /*
> * The memory barrier inside __SetPageUptodate makes sure that
> @@ -2881,15 +2880,19 @@ static void migrate_vma_insert_page(struct migrate_vma *migrate,
> } else {
> /* No need to invalidate - it was non-present before */
> set_pte_at(mm, addr, ptep, entry);
> + pte_get(pmdp);
> update_mmu_cache(vma, addr, ptep);
> }
>
> pte_unmap_unlock(ptep, ptl);
> + pte_put(mm, pmdp, addr);
> *src = MIGRATE_PFN_MIGRATE;
> return;
>
> unlock_abort:
> pte_unmap_unlock(ptep, ptl);
> +put:
> + pte_put(mm, pmdp, addr);
> abort:
> *src &= ~MIGRATE_PFN_MIGRATE;
> }
> diff --git a/mm/mincore.c b/mm/mincore.c
> index 9122676b54d6..e21e271a7657 100644
> --- a/mm/mincore.c
> +++ b/mm/mincore.c
> @@ -18,6 +18,7 @@
> #include <linux/shmem_fs.h>
> #include <linux/hugetlb.h>
> #include <linux/pgtable.h>
> +#include <linux/pte_ref.h>
>
> #include <linux/uaccess.h>
>
> @@ -104,6 +105,7 @@ static int mincore_pte_range(pmd_t *pmd, unsigned long addr, unsigned long end,
> pte_t *ptep;
> unsigned char *vec = walk->private;
> int nr = (end - addr) >> PAGE_SHIFT;
> + unsigned long start = addr;
>
> ptl = pmd_trans_huge_lock(pmd, vma);
> if (ptl) {
> @@ -112,7 +114,8 @@ static int mincore_pte_range(pmd_t *pmd, unsigned long addr, unsigned long end,
> goto out;
> }
>
> - if (pmd_trans_unstable(pmd)) {
> + if ((!IS_ENABLED(CONFIG_FREE_USER_PTE) && pmd_trans_unstable(pmd)) ||
> + !pte_try_get(walk->mm, pmd)) {
> __mincore_unmapped_range(addr, end, vma, vec);
> goto out;
> }
> @@ -148,6 +151,7 @@ static int mincore_pte_range(pmd_t *pmd, unsigned long addr, unsigned long end,
> vec++;
> }
> pte_unmap_unlock(ptep - 1, ptl);
> + pte_put(walk->mm, pmd, start);
> out:
> walk->private += nr;
> cond_resched();
> diff --git a/mm/mlock.c b/mm/mlock.c
> index 16d2ee160d43..ead14abb016a 100644
> --- a/mm/mlock.c
> +++ b/mm/mlock.c
> @@ -397,6 +397,7 @@ static unsigned long __munlock_pagevec_fill(struct pagevec *pvec,
> break;
> }
> pte_unmap_unlock(pte, ptl);
> + pte_put(vma->vm_mm, pte_to_pmd(pte), start);
> return start;
> }
>
> diff --git a/mm/mprotect.c b/mm/mprotect.c
> index 4cb240fd9936..9cbd0848c5c5 100644
> --- a/mm/mprotect.c
> +++ b/mm/mprotect.c
> @@ -274,9 +274,12 @@ static inline unsigned long change_pmd_range(struct vm_area_struct *vma,
> pmd = pmd_offset(pud, addr);
> do {
> unsigned long this_pages;
> + pmd_t pmdval;
>
> next = pmd_addr_end(addr, end);
>
> +retry:
> + pmdval = READ_ONCE(*pmd);
> /*
> * Automatic NUMA balancing walks the tables with mmap_lock
> * held for read. It's possible a parallel update to occur
> @@ -285,7 +288,7 @@ static inline unsigned long change_pmd_range(struct vm_area_struct *vma,
> * Hence, it's necessary to atomically read the PMD value
> * for all the checks.
> */
> - if (!is_swap_pmd(*pmd) && !pmd_devmap(*pmd) &&
> + if (!is_swap_pmd(pmdval) && !pmd_devmap(pmdval) &&
> pmd_none_or_clear_bad_unless_trans_huge(pmd))
> goto next;
>
> @@ -297,7 +300,7 @@ static inline unsigned long change_pmd_range(struct vm_area_struct *vma,
> mmu_notifier_invalidate_range_start(&range);
> }
>
> - if (is_swap_pmd(*pmd) || pmd_trans_huge(*pmd) || pmd_devmap(*pmd)) {
> + if (is_swap_pmd(pmdval) || pmd_trans_huge(pmdval) || pmd_devmap(pmdval)) {
> if (next - addr != HPAGE_PMD_SIZE) {
> __split_huge_pmd(vma, pmd, addr, false, NULL);
> } else {
> @@ -316,8 +319,11 @@ static inline unsigned long change_pmd_range(struct vm_area_struct *vma,
> }
> /* fall through, the trans huge pmd just split */
> }
> + if (!pte_try_get(vma->vm_mm, pmd))
> + goto retry;
> this_pages = change_pte_range(vma, pmd, addr, next, newprot,
> cp_flags);
> + pte_put(vma->vm_mm, pmd, addr);
> pages += this_pages;
> next:
> cond_resched();
> diff --git a/mm/mremap.c b/mm/mremap.c
> index 5989d3990020..776c6ea7bd06 100644
> --- a/mm/mremap.c
> +++ b/mm/mremap.c
> @@ -141,6 +141,9 @@ static void move_ptes(struct vm_area_struct *vma, pmd_t *old_pmd,
> spinlock_t *old_ptl, *new_ptl;
> bool force_flush = false;
> unsigned long len = old_end - old_addr;
> + unsigned long old_start = old_addr;
> + unsigned int nr_put = 0;
> + unsigned int nr_get = 0;
>
> /*
> * When need_rmap_locks is true, we take the i_mmap_rwsem and anon_vma
> @@ -181,6 +184,7 @@ static void move_ptes(struct vm_area_struct *vma, pmd_t *old_pmd,
> continue;
>
> pte = ptep_get_and_clear(mm, old_addr, old_pte);
> + nr_put++;
> /*
> * If we are remapping a valid PTE, make sure
> * to flush TLB before we drop the PTL for the
> @@ -197,7 +201,9 @@ static void move_ptes(struct vm_area_struct *vma, pmd_t *old_pmd,
> pte = move_pte(pte, new_vma->vm_page_prot, old_addr, new_addr);
> pte = move_soft_dirty_pte(pte);
> set_pte_at(mm, new_addr, new_pte, pte);
> + nr_get++;
> }
> + pte_get_many(new_pmd, nr_get);
>
> arch_leave_lazy_mmu_mode();
> if (force_flush)
> @@ -206,6 +212,8 @@ static void move_ptes(struct vm_area_struct *vma, pmd_t *old_pmd,
> spin_unlock(new_ptl);
> pte_unmap(new_pte - 1);
> pte_unmap_unlock(old_pte - 1, old_ptl);
> + if (nr_put)
> + pte_put_many(mm, old_pmd, old_start, nr_put);
> if (need_rmap_locks)
> drop_rmap_locks(vma);
> }
> @@ -271,6 +279,7 @@ static bool move_normal_pmd(struct vm_area_struct *vma, unsigned long old_addr,
> VM_BUG_ON(!pmd_none(*new_pmd));
>
> pmd_populate(mm, new_pmd, pmd_pgtable(pmd));
> + pte_migrate_pmd(pmd, new_pmd);
> flush_tlb_range(vma, old_addr, old_addr + PMD_SIZE);
> if (new_ptl != old_ptl)
> spin_unlock(new_ptl);
> @@ -548,10 +557,11 @@ unsigned long move_page_tables(struct vm_area_struct *vma,
> continue;
> }
>
> - if (pte_alloc(new_vma->vm_mm, new_pmd))
> + if (pte_alloc_get(new_vma->vm_mm, new_pmd))
> break;
> move_ptes(vma, old_pmd, old_addr, old_addr + extent, new_vma,
> new_pmd, new_addr, need_rmap_locks);
> + pte_put(new_vma->vm_mm, new_pmd, new_addr);
> }
>
> mmu_notifier_invalidate_range_end(&range);
> diff --git a/mm/page_vma_mapped.c b/mm/page_vma_mapped.c
> index f7b331081791..eb84fa5825c0 100644
> --- a/mm/page_vma_mapped.c
> +++ b/mm/page_vma_mapped.c
> @@ -211,6 +211,7 @@ bool page_vma_mapped_walk(struct page_vma_mapped_walk *pvmw)
> }
>
> pvmw->pmd = pmd_offset(pud, pvmw->address);
> +retry:
> /*
> * Make sure the pmd value isn't cached in a register by the
> * compiler and used as a stale value after we've observed a
> @@ -258,6 +259,8 @@ bool page_vma_mapped_walk(struct page_vma_mapped_walk *pvmw)
> step_forward(pvmw, PMD_SIZE);
> continue;
> }
> + if (!pte_try_get(pvmw->vma->vm_mm, pvmw->pmd))
> + goto retry;
> if (!map_pte(pvmw))
> goto next_pte;
> this_pte:
> @@ -275,6 +278,7 @@ bool page_vma_mapped_walk(struct page_vma_mapped_walk *pvmw)
> pvmw->ptl = NULL;
> }
> pte_unmap(pvmw->pte);
> + pte_put(pvmw->vma->vm_mm, pvmw->pmd, pvmw->address);
> pvmw->pte = NULL;
> goto restart;
> }
> diff --git a/mm/pagewalk.c b/mm/pagewalk.c
> index 9b3db11a4d1d..4080a88d7852 100644
> --- a/mm/pagewalk.c
> +++ b/mm/pagewalk.c
> @@ -3,6 +3,7 @@
> #include <linux/highmem.h>
> #include <linux/sched.h>
> #include <linux/hugetlb.h>
> +#include <linux/pte_ref.h>
>
> /*
> * We want to know the real level where a entry is located ignoring any
> @@ -108,9 +109,9 @@ static int walk_pmd_range(pud_t *pud, unsigned long addr, unsigned long end,
>
> pmd = pmd_offset(pud, addr);
> do {
> -again:
> next = pmd_addr_end(addr, end);
> - if (pmd_none(*pmd) || (!walk->vma && !walk->no_vma)) {
> +again:
> + if (pmd_none(READ_ONCE(*pmd)) || (!walk->vma && !walk->no_vma)) {
> if (ops->pte_hole)
> err = ops->pte_hole(addr, next, depth, walk);
> if (err)
> @@ -147,10 +148,18 @@ static int walk_pmd_range(pud_t *pud, unsigned long addr, unsigned long end,
> goto again;
> }
>
> - if (is_hugepd(__hugepd(pmd_val(*pmd))))
> + if (is_hugepd(__hugepd(pmd_val(*pmd)))) {
> err = walk_hugepd_range((hugepd_t *)pmd, addr, next, walk, PMD_SHIFT);
> - else
> - err = walk_pte_range(pmd, addr, next, walk);
> + } else {
> + if (!walk->no_vma) {
> + if (!pte_try_get(walk->mm, pmd))
> + goto again;
> + err = walk_pte_range(pmd, addr, next, walk);
> + pte_put(walk->mm, pmd, addr);
> + } else {
> + err = walk_pte_range(pmd, addr, next, walk);
> + }
> + }
> if (err)
> break;
> } while (pmd++, addr = next, addr != end);
> diff --git a/mm/pgtable-generic.c b/mm/pgtable-generic.c
> index 4e640baf9794..f935779a0967 100644
> --- a/mm/pgtable-generic.c
> +++ b/mm/pgtable-generic.c
> @@ -11,6 +11,7 @@
> #include <linux/hugetlb.h>
> #include <linux/pgtable.h>
> #include <asm/tlb.h>
> +#include <linux/pte_ref.h>
>
> /*
> * If a p?d_bad entry is found while walking page tables, report
> @@ -186,6 +187,7 @@ pgtable_t pgtable_trans_huge_withdraw(struct mm_struct *mm, pmd_t *pmdp)
> struct page, lru);
> if (pmd_huge_pte(mm, pmdp))
> list_del(&pgtable->lru);
> + pte_ref_init(pgtable, pmdp, HPAGE_PMD_NR);
> return pgtable;
> }
> #endif
> diff --git a/mm/pte_ref.c b/mm/pte_ref.c
> new file mode 100644
> index 000000000000..1b8d9828d513
> --- /dev/null
> +++ b/mm/pte_ref.c
> @@ -0,0 +1,132 @@
> +// SPDX-License-Identifier: GPL-2.0
> +/*
> + * Free user PTE page table pages
> + *
> + * Copyright (c) 2021, ByteDance. All rights reserved.
> + *
> + * Author: Qi Zheng <zhengqi.arch@...edance.com>
> + */
> +
> +#include <linux/pte_ref.h>
> +#include <linux/hugetlb.h>
> +#include <asm/tlbflush.h>
> +
> +#ifdef CONFIG_DEBUG_VM
> +static void pte_free_debug(pmd_t pmd)
> +{
> + pte_t *ptep = (pte_t *)pmd_page_vaddr(pmd);
> + int i = 0;
> +
> + for (i = 0; i < PTRS_PER_PTE; i++)
> + BUG_ON(!pte_none(*ptep++));
> +}
> +#else
> +static inline void pte_free_debug(pmd_t pmd)
> +{
> +}
> +#endif
> +
> +void free_pte_table(struct mm_struct *mm, pmd_t *pmdp, unsigned long addr)
> +{
> + struct vm_area_struct vma = TLB_FLUSH_VMA(mm, 0);
> + spinlock_t *ptl;
> + pmd_t pmd;
> +
> + ptl = pmd_lock(mm, pmdp);
> + pmd = pmdp_huge_get_and_clear(mm, addr, pmdp);
> + spin_unlock(ptl);
> +
> + pte_free_debug(pmd);
> + flush_tlb_range(&vma, addr, addr + PMD_SIZE);
> + mm_dec_nr_ptes(mm);
> + pte_free(mm, pmd_pgtable(pmd));
> +}
> +
> +static inline void __pte_install(struct mm_struct *mm, pmd_t *pmd, pgtable_t *pte)
> +{
> + mm_inc_nr_ptes(mm);
> + /*
> + * Ensure all pte setup (eg. pte page lock and page clearing) are
> + * visible before the pte is made visible to other CPUs by being
> + * put into page tables.
> + *
> + * The other side of the story is the pointer chasing in the page
> + * table walking code (when walking the page table without locking;
> + * ie. most of the time). Fortunately, these data accesses consist
> + * of a chain of data-dependent loads, meaning most CPUs (alpha
> + * being the notable exception) will already guarantee loads are
> + * seen in-order. See the alpha page table accessors for the
> + * smp_rmb() barriers in page table walking code.
> + */
> + smp_wmb(); /* Could be smp_wmb__xxx(before|after)_spin_lock */
> + pmd_populate(mm, pmd, *pte);
> + pte_ref_init(*pte, pmd, 1);
> + *pte = NULL;
> +}
> +
> +/*
> + * returns true if the pmd has been populated with PTE page table,
> + * or false for all other cases.
> + */
> +bool pte_install_try_get(struct mm_struct *mm, pmd_t *pmd, pgtable_t *pte)
> +{
> + spinlock_t *ptl;
> + bool retval = true;
> +
> +retry:
> + ptl = pmd_lock(mm, pmd);
> + if (likely(pmd_none(*pmd))) {
> + __pte_install(mm, pmd, pte);
> + } else if (pmd_leaf(*pmd) || !pmd_present(*pmd)) {
> + retval = false;
> + } else if (!pte_get_unless_zero(pmd)) {
> + spin_unlock(ptl);
> + goto retry;
> + }
> + spin_unlock(ptl);
> + return retval;
> +}
> +
Can pte_get_unless_zero() return true above? Can the pmd have been populated by others in the meantime? In that case the refcount is wrongly incremented.
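
To make the question concrete, this is the interleaving I have in mind (just a
sketch of the pte_install_try_get() path quoted above, assuming a parallel
fault on another CPU; not a suggested change):

	CPU 0 (finish_fault(), carries vmf->prealloc_pte):
	  pmd_none(*vmf->pmd)             -> true, takes the prealloc path

	CPU 1 (parallel fault on the same pmd):
	  pte_install_try_get()
	    pmd_lock(); pmd_none(*pmd)    -> true
	    __pte_install()               /* installs its own table, pte_refcount = 1 */

	CPU 0:
	  pte_install_try_get(mm, pmd, &vmf->prealloc_pte)
	    pmd_lock(); pmd_none(*pmd)    -> false (CPU 1 populated it)
	    pmd_leaf() / !pmd_present()   -> false (regular PTE table)
	    pte_get_unless_zero(pmd)      -> true, refcount bumped on CPU 1's table
	    returns true, vmf->prealloc_pte left untouched
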
> +static void pte_install_get(struct mm_struct *mm, pmd_t *pmd, pgtable_t *pte)
> +{
> + spinlock_t *ptl;
> +
> +retry:
> + ptl = pmd_lock(mm, pmd);
> + if (likely(pmd_none(*pmd))) {
> + __pte_install(mm, pmd, pte);
> + } else if (!pte_get_unless_zero(pmd)) {
> + spin_unlock(ptl);
> + goto retry;
> + }
> + spin_unlock(ptl);
> +}
> +
> +/*
> + * returns -ENOMEM if memory allocation failed, or 1 if the pmd
> + * has been populated with PTE page table, or 0 for all other cases.
> + */
> +int __pte_alloc_try_get(struct mm_struct *mm, pmd_t *pmd)
> +{
> + int retval;
> + pgtable_t new = pte_alloc_one(mm);
> + if (!new)
> + return -ENOMEM;
> +
> + retval = pte_install_try_get(mm, pmd, &new);
> + if (new)
> + pte_free(mm, new);
> + return retval;
> +}
> +
> +int __pte_alloc_get(struct mm_struct *mm, pmd_t *pmd)
> +{
> + pgtable_t new = pte_alloc_one(mm);
> + if (!new)
> + return -ENOMEM;
> +
> + pte_install_get(mm, pmd, &new);
> + if (new)
> + pte_free(mm, new);
> + return 0;
> +}
> diff --git a/mm/rmap.c b/mm/rmap.c
> index fed7c4df25f2..8c10dbca02d4 100644
> --- a/mm/rmap.c
> +++ b/mm/rmap.c
> @@ -1402,6 +1402,7 @@ static bool try_to_unmap_one(struct page *page, struct vm_area_struct *vma,
> bool ret = true;
> struct mmu_notifier_range range;
> enum ttu_flags flags = (enum ttu_flags)(long)arg;
> + unsigned int nr_put = 0;
>
> /*
> * When racing against e.g. zap_pte_range() on another cpu,
> @@ -1551,6 +1552,7 @@ static bool try_to_unmap_one(struct page *page, struct vm_area_struct *vma,
> /* We have to invalidate as we cleared the pte */
> mmu_notifier_invalidate_range(mm, address,
> address + PAGE_SIZE);
> + nr_put++;
> } else if (PageAnon(page)) {
> swp_entry_t entry = { .val = page_private(subpage) };
> pte_t swp_pte;
> @@ -1564,6 +1566,7 @@ static bool try_to_unmap_one(struct page *page, struct vm_area_struct *vma,
> /* We have to invalidate as we cleared the pte */
> mmu_notifier_invalidate_range(mm, address,
> address + PAGE_SIZE);
> + nr_put++;
> page_vma_mapped_walk_done(&pvmw);
> break;
> }
> @@ -1575,6 +1578,7 @@ static bool try_to_unmap_one(struct page *page, struct vm_area_struct *vma,
> mmu_notifier_invalidate_range(mm,
> address, address + PAGE_SIZE);
> dec_mm_counter(mm, MM_ANONPAGES);
> + nr_put++;
> goto discard;
> }
>
> @@ -1630,6 +1634,7 @@ static bool try_to_unmap_one(struct page *page, struct vm_area_struct *vma,
> * See Documentation/vm/mmu_notifier.rst
> */
> dec_mm_counter(mm, mm_counter_file(page));
> + nr_put++;
> }
> discard:
> /*
> @@ -1644,7 +1649,8 @@ static bool try_to_unmap_one(struct page *page, struct vm_area_struct *vma,
> }
>
> mmu_notifier_invalidate_range_end(&range);
> -
> + if (nr_put)
> + pte_put_many(mm, pvmw.pmd, address, nr_put);
> return ret;
> }
>
> @@ -1705,6 +1711,7 @@ static bool try_to_migrate_one(struct page *page, struct vm_area_struct *vma,
> bool ret = true;
> struct mmu_notifier_range range;
> enum ttu_flags flags = (enum ttu_flags)(long)arg;
> + unsigned int nr_put = 0;
>
> if (is_zone_device_page(page) && !is_device_private_page(page))
> return true;
> @@ -1871,6 +1878,7 @@ static bool try_to_migrate_one(struct page *page, struct vm_area_struct *vma,
> /* We have to invalidate as we cleared the pte */
> mmu_notifier_invalidate_range(mm, address,
> address + PAGE_SIZE);
> + nr_put++;
> } else {
> swp_entry_t entry;
> pte_t swp_pte;
> @@ -1919,6 +1927,9 @@ static bool try_to_migrate_one(struct page *page, struct vm_area_struct *vma,
>
> mmu_notifier_invalidate_range_end(&range);
>
> + if (nr_put)
> + pte_put_many(mm, pvmw.pmd, address, nr_put);
> +
> return ret;
> }
>
> diff --git a/mm/swapfile.c b/mm/swapfile.c
> index 1e07d1c776f2..6153283be500 100644
> --- a/mm/swapfile.c
> +++ b/mm/swapfile.c
> @@ -40,6 +40,7 @@
> #include <linux/swap_slots.h>
> #include <linux/sort.h>
> #include <linux/completion.h>
> +#include <linux/pte_ref.h>
>
> #include <asm/tlbflush.h>
> #include <linux/swapops.h>
> @@ -2021,10 +2022,13 @@ static inline int unuse_pmd_range(struct vm_area_struct *vma, pud_t *pud,
> do {
> cond_resched();
> next = pmd_addr_end(addr, end);
> - if (pmd_none_or_trans_huge_or_clear_bad(pmd))
> + if ((!IS_ENABLED(CONFIG_FREE_USER_PTE) &&
> + pmd_none_or_trans_huge_or_clear_bad(pmd)) ||
> + !pte_try_get(vma->vm_mm, pmd))
> continue;
> ret = unuse_pte_range(vma, pmd, addr, next, type,
> frontswap, fs_pages_to_unuse);
> + pte_put(vma->vm_mm, pmd, addr);
> if (ret)
> return ret;
> } while (pmd++, addr = next, addr != end);
> diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c
> index 0e2132834bc7..7ebf4fb09a85 100644
> --- a/mm/userfaultfd.c
> +++ b/mm/userfaultfd.c
> @@ -111,6 +111,7 @@ int mfill_atomic_install_pte(struct mm_struct *dst_mm, pmd_t *dst_pmd,
> lru_cache_add_inactive_or_unevictable(page, dst_vma);
>
> set_pte_at(dst_mm, dst_addr, dst_pte, _dst_pte);
> + pte_get(dst_pmd);
>
> /* No need to invalidate - it was non-present before */
> update_mmu_cache(dst_vma, dst_addr, dst_pte);
> @@ -205,6 +206,7 @@ static int mfill_zeropage_pte(struct mm_struct *dst_mm,
> if (!pte_none(*dst_pte))
> goto out_unlock;
> set_pte_at(dst_mm, dst_addr, dst_pte, _dst_pte);
> + pte_get(dst_pmd);
> /* No need to invalidate - it was non-present before */
> update_mmu_cache(dst_vma, dst_addr, dst_pte);
> ret = 0;
> @@ -570,6 +572,7 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
>
> while (src_addr < src_start + len) {
> pmd_t dst_pmdval;
> + int ret = 1;
>
> BUG_ON(dst_addr >= dst_start + len);
>
> @@ -588,13 +591,14 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
> err = -EEXIST;
> break;
> }
> - if (unlikely(pmd_none(dst_pmdval)) &&
> - unlikely(__pte_alloc(dst_mm, dst_pmd))) {
> +
> + if ((IS_ENABLED(CONFIG_FREE_USER_PTE) &&
> + unlikely((ret = pte_alloc_try_get(dst_mm, dst_pmd)) < 0)) ||
> + (unlikely(pmd_none(dst_pmdval)) &&
> + unlikely((ret = __pte_alloc_try_get(dst_mm, dst_pmd)) < 0))) {
> err = -ENOMEM;
> break;
> - }
> - /* If an huge pmd materialized from under us fail */
> - if (unlikely(pmd_trans_huge(*dst_pmd))) {
> + } else if (!ret || unlikely(pmd_trans_huge(*dst_pmd))) {
> err = -EFAULT;
> break;
> }
> @@ -604,6 +608,7 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
>
> err = mfill_atomic_pte(dst_mm, dst_pmd, dst_vma, dst_addr,
> src_addr, &page, mcopy_mode, wp_copy);
> + pte_put(dst_mm, dst_pmd, dst_addr);
> cond_resched();
>
> if (unlikely(err == -ENOENT)) {