lists.openwall.net   lists  /  announce  owl-users  owl-dev  john-users  john-dev  passwdqc-users  yescrypt  popa3d-users  /  oss-security  kernel-hardening  musl  sabotage  tlsify  passwords  /  crypt-dev  xvendor  /  Bugtraq  Full-Disclosure  linux-kernel  linux-netdev  linux-ext4  linux-hardening  linux-cve-announce  PHC 
Open Source and information security mailing list archives
 
Hash Suite: Windows password security audit tool. GUI, reports in PDF.
[<prev] [next>] [<thread-prev] [thread-next>] [day] [month] [year] [list]
Message-ID: <9c3c87d5-e64e-f13f-ef36-b438e4de1e66@nextfour.com>
Date:   Sun, 18 Jul 2021 09:19:46 +0300
From:   Mika Penttilä <mika.penttila@...tfour.com>
To:     Qi Zheng <zhengqi.arch@...edance.com>, akpm@...ux-foundation.org,
        tglx@...utronix.de, hannes@...xchg.org, mhocko@...nel.org,
        vdavydov.dev@...il.com
Cc:     linux-doc@...r.kernel.org, linux-kernel@...r.kernel.org,
        linux-mm@...ck.org, songmuchun@...edance.com
Subject: Re: [PATCH 5/7] mm: free user PTE page table pages

Hi,


On 18.7.2021 7.30, Qi Zheng wrote:
> Some malloc libraries (e.g. jemalloc or tcmalloc) usually
> reserve a large amount of virtual address space with mmap()
> and never unmap it. Instead, they use madvise(MADV_DONTNEED)
> to free the physical memory when they want to. However,
> madvise() does not free the page tables, so a process that
> touches an enormous virtual address space can accumulate
> many page tables.
>
> The following figures are a memory usage snapshot of one
> process which actually happened on our server:
>
>          VIRT:  55t
>          RES:   590g
>          VmPTE: 110g
>
> As we can see, the PTE page tables take up 110g, while RES is
> 590g. In theory, the process needs only about 1.2g of PTE page
> tables to map that much physical memory (each 4K PTE page maps
> 2M, so 590g / 2M * 4K ~= 1.2g). The PTE page tables occupy so
> much memory because madvise(MADV_DONTNEED) only clears the PTEs
> and frees the physical memory; it does not free the PTE page
> table pages themselves. Freeing those empty PTE page tables
> would save about 108g in the above case (best case), and the
> larger the gap between VIRT and RES, the more memory we save.
>
> In this patch series, we add a pte_refcount field to the
> struct page of each PTE page table page to track how many users
> the PTE page table has. Similar to the page refcount mechanism,
> a user of a PTE page table must hold a reference to it before
> accessing it. The PTE page table page is freed when the last
> reference is dropped.
>
> Signed-off-by: Qi Zheng <zhengqi.arch@...edance.com>
> ---
>   Documentation/vm/split_page_table_lock.rst |   2 +-
>   arch/arm/mm/pgd.c                          |   2 +-
>   arch/arm64/mm/hugetlbpage.c                |   4 +-
>   arch/ia64/mm/hugetlbpage.c                 |   2 +-
>   arch/parisc/mm/hugetlbpage.c               |   2 +-
>   arch/powerpc/mm/hugetlbpage.c              |   2 +-
>   arch/s390/mm/gmap.c                        |   8 +-
>   arch/s390/mm/pgtable.c                     |   6 +-
>   arch/sh/mm/hugetlbpage.c                   |   2 +-
>   arch/sparc/mm/hugetlbpage.c                |   2 +-
>   arch/x86/kernel/tboot.c                    |   2 +-
>   fs/proc/task_mmu.c                         |  23 ++-
>   fs/userfaultfd.c                           |   2 +
>   include/linux/mm.h                         |  13 +-
>   include/linux/mm_types.h                   |   8 +-
>   include/linux/pgtable.h                    |   3 +-
>   include/linux/pte_ref.h                    | 217 +++++++++++++++++++++++
>   include/linux/rmap.h                       |   3 +
>   kernel/events/uprobes.c                    |   3 +
>   mm/Kconfig                                 |   4 +
>   mm/Makefile                                |   3 +-
>   mm/debug_vm_pgtable.c                      |   3 +-
>   mm/filemap.c                               |  45 ++---
>   mm/gup.c                                   |  10 +-
>   mm/hmm.c                                   |   4 +
>   mm/internal.h                              |   2 +
>   mm/khugepaged.c                            |  10 ++
>   mm/ksm.c                                   |   4 +
>   mm/madvise.c                               |  20 ++-
>   mm/memcontrol.c                            |  11 +-
>   mm/memory.c                                | 271 +++++++++++++++++++----------
>   mm/mempolicy.c                             |   5 +-
>   mm/migrate.c                               |  21 ++-
>   mm/mincore.c                               |   6 +-
>   mm/mlock.c                                 |   1 +
>   mm/mprotect.c                              |  10 +-
>   mm/mremap.c                                |  12 +-
>   mm/page_vma_mapped.c                       |   4 +
>   mm/pagewalk.c                              |  19 +-
>   mm/pgtable-generic.c                       |   2 +
>   mm/pte_ref.c                               | 132 ++++++++++++++
>   mm/rmap.c                                  |  13 +-
>   mm/swapfile.c                              |   6 +-
>   mm/userfaultfd.c                           |  15 +-
>   44 files changed, 758 insertions(+), 181 deletions(-)
>   create mode 100644 include/linux/pte_ref.h
>   create mode 100644 mm/pte_ref.c
>
> diff --git a/Documentation/vm/split_page_table_lock.rst b/Documentation/vm/split_page_table_lock.rst
> index c08919662704..98eb7ba0d2ab 100644
> --- a/Documentation/vm/split_page_table_lock.rst
> +++ b/Documentation/vm/split_page_table_lock.rst
> @@ -20,7 +20,7 @@ There are helpers to lock/unlock a table and other accessor functions:
>   	lock;
>    - pte_unmap_unlock()
>   	unlocks and unmaps PTE table;
> - - pte_alloc_map_lock()
> + - pte_alloc_get_map_lock()
>   	allocates PTE table if needed and take the lock, returns pointer
>   	to taken lock or NULL if allocation failed;
>    - pte_lockptr()
> diff --git a/arch/arm/mm/pgd.c b/arch/arm/mm/pgd.c
> index f8e9bc58a84f..b2408ad9dcf5 100644
> --- a/arch/arm/mm/pgd.c
> +++ b/arch/arm/mm/pgd.c
> @@ -100,7 +100,7 @@ pgd_t *pgd_alloc(struct mm_struct *mm)
>   		if (!new_pmd)
>   			goto no_pmd;
>   
> -		new_pte = pte_alloc_map(mm, new_pmd, 0);
> +		new_pte = pte_alloc_get_map(mm, new_pmd, 0);
>   		if (!new_pte)
>   			goto no_pte;
>   
> diff --git a/arch/arm64/mm/hugetlbpage.c b/arch/arm64/mm/hugetlbpage.c
> index 23505fc35324..54f6beb3eb6b 100644
> --- a/arch/arm64/mm/hugetlbpage.c
> +++ b/arch/arm64/mm/hugetlbpage.c
> @@ -280,9 +280,9 @@ pte_t *huge_pte_alloc(struct mm_struct *mm, struct vm_area_struct *vma,
>   		 * 32-bit arm platform then it will cause trouble in
>   		 * the case where CONFIG_HIGHPTE is set, since there
>   		 * will be no pte_unmap() to correspond with this
> -		 * pte_alloc_map().
> +		 * pte_alloc_get_map().
>   		 */
> -		ptep = pte_alloc_map(mm, pmdp, addr);
> +		ptep = pte_alloc_get_map(mm, pmdp, addr);
>   	} else if (sz == PMD_SIZE) {
>   		if (want_pmd_share(vma, addr) && pud_none(READ_ONCE(*pudp)))
>   			ptep = huge_pmd_share(mm, vma, addr, pudp);
> diff --git a/arch/ia64/mm/hugetlbpage.c b/arch/ia64/mm/hugetlbpage.c
> index f993cb36c062..cb230005e7dd 100644
> --- a/arch/ia64/mm/hugetlbpage.c
> +++ b/arch/ia64/mm/hugetlbpage.c
> @@ -41,7 +41,7 @@ huge_pte_alloc(struct mm_struct *mm, struct vm_area_struct *vma,
>   	if (pud) {
>   		pmd = pmd_alloc(mm, pud, taddr);
>   		if (pmd)
> -			pte = pte_alloc_map(mm, pmd, taddr);
> +			pte = pte_alloc_get_map(mm, pmd, taddr);
>   	}
>   	return pte;
>   }
> diff --git a/arch/parisc/mm/hugetlbpage.c b/arch/parisc/mm/hugetlbpage.c
> index d1d3990b83f6..ff16db9a44a5 100644
> --- a/arch/parisc/mm/hugetlbpage.c
> +++ b/arch/parisc/mm/hugetlbpage.c
> @@ -66,7 +66,7 @@ pte_t *huge_pte_alloc(struct mm_struct *mm, struct vm_area_struct *vma,
>   	if (pud) {
>   		pmd = pmd_alloc(mm, pud, addr);
>   		if (pmd)
> -			pte = pte_alloc_map(mm, pmd, addr);
> +			pte = pte_alloc_get_map(mm, pmd, addr);
>   	}
>   	return pte;
>   }
> diff --git a/arch/powerpc/mm/hugetlbpage.c b/arch/powerpc/mm/hugetlbpage.c
> index 9a75ba078e1b..20af2db18d08 100644
> --- a/arch/powerpc/mm/hugetlbpage.c
> +++ b/arch/powerpc/mm/hugetlbpage.c
> @@ -182,7 +182,7 @@ pte_t *huge_pte_alloc(struct mm_struct *mm, struct vm_area_struct *vma,
>   		return NULL;
>   
>   	if (IS_ENABLED(CONFIG_PPC_8xx) && pshift < PMD_SHIFT)
> -		return pte_alloc_map(mm, (pmd_t *)hpdp, addr);
> +		return pte_alloc_get_map(mm, (pmd_t *)hpdp, addr);
>   
>   	BUG_ON(!hugepd_none(*hpdp) && !hugepd_ok(*hpdp));
>   
> diff --git a/arch/s390/mm/gmap.c b/arch/s390/mm/gmap.c
> index 9bb2c7512cd5..b243b276d9b6 100644
> --- a/arch/s390/mm/gmap.c
> +++ b/arch/s390/mm/gmap.c
> @@ -856,7 +856,7 @@ static pte_t *gmap_pte_op_walk(struct gmap *gmap, unsigned long gaddr,
>   	table = gmap_table_walk(gmap, gaddr, 1); /* get segment pointer */
>   	if (!table || *table & _SEGMENT_ENTRY_INVALID)
>   		return NULL;
> -	return pte_alloc_map_lock(gmap->mm, (pmd_t *) table, gaddr, ptl);
> +	return pte_alloc_get_map_lock(gmap->mm, (pmd_t *) table, gaddr, ptl);
>   }
>   
>   /**
> @@ -925,7 +925,7 @@ static inline pmd_t *gmap_pmd_op_walk(struct gmap *gmap, unsigned long gaddr)
>   		return NULL;
>   	}
>   
> -	/* 4k page table entries are locked via the pte (pte_alloc_map_lock). */
> +	/* 4k page table entries are locked via the pte (pte_alloc_get_map_lock). */
>   	if (!pmd_large(*pmdp))
>   		spin_unlock(&gmap->guest_table_lock);
>   	return pmdp;
> @@ -1012,7 +1012,7 @@ static int gmap_protect_pte(struct gmap *gmap, unsigned long gaddr,
>   	if (pmd_val(*pmdp) & _SEGMENT_ENTRY_INVALID)
>   		return -EAGAIN;
>   
> -	ptep = pte_alloc_map_lock(gmap->mm, pmdp, gaddr, &ptl);
> +	ptep = pte_alloc_get_map_lock(gmap->mm, pmdp, gaddr, &ptl);
>   	if (!ptep)
>   		return -ENOMEM;
>   
> @@ -2473,7 +2473,7 @@ void gmap_sync_dirty_log_pmd(struct gmap *gmap, unsigned long bitmap[4],
>   			bitmap_fill(bitmap, _PAGE_ENTRIES);
>   	} else {
>   		for (i = 0; i < _PAGE_ENTRIES; i++, vmaddr += PAGE_SIZE) {
> -			ptep = pte_alloc_map_lock(gmap->mm, pmdp, vmaddr, &ptl);
> +			ptep = pte_alloc_get_map_lock(gmap->mm, pmdp, vmaddr, &ptl);
>   			if (!ptep)
>   				continue;
>   			if (ptep_test_and_clear_uc(gmap->mm, vmaddr, ptep))
> diff --git a/arch/s390/mm/pgtable.c b/arch/s390/mm/pgtable.c
> index eec3a9d7176e..82217a753751 100644
> --- a/arch/s390/mm/pgtable.c
> +++ b/arch/s390/mm/pgtable.c
> @@ -801,7 +801,7 @@ int set_guest_storage_key(struct mm_struct *mm, unsigned long addr,
>   	}
>   	spin_unlock(ptl);
>   
> -	ptep = pte_alloc_map_lock(mm, pmdp, addr, &ptl);
> +	ptep = pte_alloc_get_map_lock(mm, pmdp, addr, &ptl);
>   	if (unlikely(!ptep))
>   		return -EFAULT;
>   
> @@ -900,7 +900,7 @@ int reset_guest_reference_bit(struct mm_struct *mm, unsigned long addr)
>   	}
>   	spin_unlock(ptl);
>   
> -	ptep = pte_alloc_map_lock(mm, pmdp, addr, &ptl);
> +	ptep = pte_alloc_get_map_lock(mm, pmdp, addr, &ptl);
>   	if (unlikely(!ptep))
>   		return -EFAULT;
>   
> @@ -956,7 +956,7 @@ int get_guest_storage_key(struct mm_struct *mm, unsigned long addr,
>   	}
>   	spin_unlock(ptl);
>   
> -	ptep = pte_alloc_map_lock(mm, pmdp, addr, &ptl);
> +	ptep = pte_alloc_get_map_lock(mm, pmdp, addr, &ptl);
>   	if (unlikely(!ptep))
>   		return -EFAULT;
>   
> diff --git a/arch/sh/mm/hugetlbpage.c b/arch/sh/mm/hugetlbpage.c
> index 999ab5916e69..ea7fa277952b 100644
> --- a/arch/sh/mm/hugetlbpage.c
> +++ b/arch/sh/mm/hugetlbpage.c
> @@ -38,7 +38,7 @@ pte_t *huge_pte_alloc(struct mm_struct *mm, struct vm_area_struct *vma,
>   			if (pud) {
>   				pmd = pmd_alloc(mm, pud, addr);
>   				if (pmd)
> -					pte = pte_alloc_map(mm, pmd, addr);
> +					pte = pte_alloc_get_map(mm, pmd, addr);
>   			}
>   		}
>   	}
> diff --git a/arch/sparc/mm/hugetlbpage.c b/arch/sparc/mm/hugetlbpage.c
> index 0f49fada2093..599c04b54205 100644
> --- a/arch/sparc/mm/hugetlbpage.c
> +++ b/arch/sparc/mm/hugetlbpage.c
> @@ -297,7 +297,7 @@ pte_t *huge_pte_alloc(struct mm_struct *mm, struct vm_area_struct *vma,
>   		return NULL;
>   	if (sz >= PMD_SIZE)
>   		return (pte_t *)pmd;
> -	return pte_alloc_map(mm, pmd, addr);
> +	return pte_alloc_get_map(mm, pmd, addr);
>   }
>   
>   pte_t *huge_pte_offset(struct mm_struct *mm,
> diff --git a/arch/x86/kernel/tboot.c b/arch/x86/kernel/tboot.c
> index f9af561c3cd4..f2210bf3d357 100644
> --- a/arch/x86/kernel/tboot.c
> +++ b/arch/x86/kernel/tboot.c
> @@ -131,7 +131,7 @@ static int map_tboot_page(unsigned long vaddr, unsigned long pfn,
>   	pmd = pmd_alloc(&tboot_mm, pud, vaddr);
>   	if (!pmd)
>   		return -1;
> -	pte = pte_alloc_map(&tboot_mm, pmd, vaddr);
> +	pte = pte_alloc_get_map(&tboot_mm, pmd, vaddr);
>   	if (!pte)
>   		return -1;
>   	set_pte_at(&tboot_mm, vaddr, pte, pfn_pte(pfn, prot));
> diff --git a/fs/proc/task_mmu.c b/fs/proc/task_mmu.c
> index eb97468dfe4c..b3cf4b8a91d6 100644
> --- a/fs/proc/task_mmu.c
> +++ b/fs/proc/task_mmu.c
> @@ -574,6 +574,7 @@ static int smaps_pte_range(pmd_t *pmd, unsigned long addr, unsigned long end,
>   	struct vm_area_struct *vma = walk->vma;
>   	pte_t *pte;
>   	spinlock_t *ptl;
> +	unsigned long start = addr;
>   
>   	ptl = pmd_trans_huge_lock(pmd, vma);
>   	if (ptl) {
> @@ -582,7 +583,8 @@ static int smaps_pte_range(pmd_t *pmd, unsigned long addr, unsigned long end,
>   		goto out;
>   	}
>   
> -	if (pmd_trans_unstable(pmd))
> +	if ((!IS_ENABLED(CONFIG_FREE_USER_PTE) && pmd_trans_unstable(pmd)) ||
> +	    !pte_try_get(vma->vm_mm, pmd))
>   		goto out;
>   	/*
>   	 * The mmap_lock held all the way back in m_start() is what
> @@ -593,6 +595,7 @@ static int smaps_pte_range(pmd_t *pmd, unsigned long addr, unsigned long end,
>   	for (; addr != end; pte++, addr += PAGE_SIZE)
>   		smaps_pte_entry(pte, addr, walk);
>   	pte_unmap_unlock(pte - 1, ptl);
> +	pte_put(vma->vm_mm, pmd, start);
>   out:
>   	cond_resched();
>   	return 0;
> @@ -1121,6 +1124,7 @@ static int clear_refs_pte_range(pmd_t *pmd, unsigned long addr,
>   	pte_t *pte, ptent;
>   	spinlock_t *ptl;
>   	struct page *page;
> +	unsigned long start = addr;
>   
>   	ptl = pmd_trans_huge_lock(pmd, vma);
>   	if (ptl) {
> @@ -1143,7 +1147,8 @@ static int clear_refs_pte_range(pmd_t *pmd, unsigned long addr,
>   		return 0;
>   	}
>   
> -	if (pmd_trans_unstable(pmd))
> +	if ((!IS_ENABLED(CONFIG_FREE_USER_PTE) && pmd_trans_unstable(pmd)) ||
> +	    !pte_try_get(vma->vm_mm, pmd))
>   		return 0;
>   
>   	pte = pte_offset_map_lock(vma->vm_mm, pmd, addr, &ptl);
> @@ -1168,6 +1173,7 @@ static int clear_refs_pte_range(pmd_t *pmd, unsigned long addr,
>   		ClearPageReferenced(page);
>   	}
>   	pte_unmap_unlock(pte - 1, ptl);
> +	pte_put(vma->vm_mm, pmd, start);
>   	cond_resched();
>   	return 0;
>   }
> @@ -1407,6 +1413,7 @@ static int pagemap_pmd_range(pmd_t *pmdp, unsigned long addr, unsigned long end,
>   	spinlock_t *ptl;
>   	pte_t *pte, *orig_pte;
>   	int err = 0;
> +	unsigned long start = addr;
>   
>   #ifdef CONFIG_TRANSPARENT_HUGEPAGE
>   	ptl = pmd_trans_huge_lock(pmdp, vma);
> @@ -1471,10 +1478,13 @@ static int pagemap_pmd_range(pmd_t *pmdp, unsigned long addr, unsigned long end,
>   		return err;
>   	}
>   
> -	if (pmd_trans_unstable(pmdp))
> +	if (!IS_ENABLED(CONFIG_FREE_USER_PTE) && pmd_trans_unstable(pmdp))
>   		return 0;
>   #endif /* CONFIG_TRANSPARENT_HUGEPAGE */
>   
> +	if (!pte_try_get(walk->mm, pmdp))
> +		return 0;
> +
>   	/*
>   	 * We can assume that @vma always points to a valid one and @end never
>   	 * goes beyond vma->vm_end.
> @@ -1489,6 +1499,7 @@ static int pagemap_pmd_range(pmd_t *pmdp, unsigned long addr, unsigned long end,
>   			break;
>   	}
>   	pte_unmap_unlock(orig_pte, ptl);
> +	pte_put(walk->mm, pmdp, start);
>   
>   	cond_resched();
>   
> @@ -1795,6 +1806,7 @@ static int gather_pte_stats(pmd_t *pmd, unsigned long addr,
>   	spinlock_t *ptl;
>   	pte_t *orig_pte;
>   	pte_t *pte;
> +	unsigned long start = addr;
>   
>   #ifdef CONFIG_TRANSPARENT_HUGEPAGE
>   	ptl = pmd_trans_huge_lock(pmd, vma);
> @@ -1809,9 +1821,11 @@ static int gather_pte_stats(pmd_t *pmd, unsigned long addr,
>   		return 0;
>   	}
>   
> -	if (pmd_trans_unstable(pmd))
> +	if (!IS_ENABLED(CONFIG_FREE_USER_PTE) && pmd_trans_unstable(pmd))
>   		return 0;
>   #endif
> +	if (!pte_try_get(walk->mm, pmd))
> +		return 0;
>   	orig_pte = pte = pte_offset_map_lock(walk->mm, pmd, addr, &ptl);
>   	do {
>   		struct page *page = can_gather_numa_stats(*pte, vma, addr);
> @@ -1821,6 +1835,7 @@ static int gather_pte_stats(pmd_t *pmd, unsigned long addr,
>   
>   	} while (pte++, addr += PAGE_SIZE, addr != end);
>   	pte_unmap_unlock(orig_pte, ptl);
> +	pte_put(walk->mm, pmd, start);
>   	cond_resched();
>   	return 0;
>   }
> diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
> index f6e0f0c0d0e5..4fc6c3bafd70 100644
> --- a/fs/userfaultfd.c
> +++ b/fs/userfaultfd.c
> @@ -28,6 +28,7 @@
>   #include <linux/ioctl.h>
>   #include <linux/security.h>
>   #include <linux/hugetlb.h>
> +#include <linux/pte_ref.h>
>   
>   int sysctl_unprivileged_userfaultfd __read_mostly;
>   
> @@ -508,6 +509,7 @@ vm_fault_t handle_userfault(struct vm_fault *vmf, unsigned long reason)
>   		must_wait = userfaultfd_huge_must_wait(ctx, vmf->vma,
>   						       vmf->address,
>   						       vmf->flags, reason);
> +	pte_put_vmf(vmf);
>   	mmap_read_unlock(mm);
>   
>   	if (likely(must_wait && !READ_ONCE(ctx->released))) {
> diff --git a/include/linux/mm.h b/include/linux/mm.h
> index 5efd63a20d72..befe823b4918 100644
> --- a/include/linux/mm.h
> +++ b/include/linux/mm.h
> @@ -447,6 +447,7 @@ extern pgprot_t protection_map[16];
>    * @FAULT_FLAG_REMOTE: The fault is not for current task/mm.
>    * @FAULT_FLAG_INSTRUCTION: The fault was during an instruction fetch.
>    * @FAULT_FLAG_INTERRUPTIBLE: The fault can be interrupted by non-fatal signals.
> + * @FAULT_FLAG_PTE_GET: A reference to the PTE page table has been taken.
>    *
>    * About @FAULT_FLAG_ALLOW_RETRY and @FAULT_FLAG_TRIED: we can specify
>    * whether we would allow page faults to retry by specifying these two
> @@ -478,6 +479,7 @@ enum fault_flag {
>   	FAULT_FLAG_REMOTE =		1 << 7,
>   	FAULT_FLAG_INSTRUCTION =	1 << 8,
>   	FAULT_FLAG_INTERRUPTIBLE =	1 << 9,
> +	FAULT_FLAG_PTE_GET =		1 << 10,
>   };
>   
>   /*
> @@ -2148,7 +2150,6 @@ static inline void mm_inc_nr_ptes(struct mm_struct *mm) {}
>   static inline void mm_dec_nr_ptes(struct mm_struct *mm) {}
>   #endif
>   
> -int __pte_alloc(struct mm_struct *mm, pmd_t *pmd);
>   int __pte_alloc_kernel(pmd_t *pmd);
>   
>   #if defined(CONFIG_MMU)
> @@ -2274,15 +2275,6 @@ static inline void pgtable_pte_page_dtor(struct page *page)
>   	pte_unmap(pte);					\
>   } while (0)
>   
> -#define pte_alloc(mm, pmd) (unlikely(pmd_none(*(pmd))) && __pte_alloc(mm, pmd))
> -
> -#define pte_alloc_map(mm, pmd, address)			\
> -	(pte_alloc(mm, pmd) ? NULL : pte_offset_map(pmd, address))
> -
> -#define pte_alloc_map_lock(mm, pmd, address, ptlp)	\
> -	(pte_alloc(mm, pmd) ?			\
> -		 NULL : pte_offset_map_lock(mm, pmd, address, ptlp))
> -
>   #define pte_alloc_kernel(pmd, address)			\
>   	((unlikely(pmd_none(*(pmd))) && __pte_alloc_kernel(pmd))? \
>   		NULL: pte_offset_kernel(pmd, address))
> @@ -2374,7 +2366,6 @@ static inline spinlock_t *pud_lock(struct mm_struct *mm, pud_t *pud)
>   	return ptl;
>   }
>   
> -extern void pte_install(struct mm_struct *mm, pmd_t *pmd, pgtable_t *pte);
>   extern void __init pagecache_init(void);
>   extern void __init free_area_init_memoryless_node(int nid);
>   extern void free_initmem(void);
> diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
> index f37abb2d222e..eed4a5db59ea 100644
> --- a/include/linux/mm_types.h
> +++ b/include/linux/mm_types.h
> @@ -153,11 +153,17 @@ struct page {
>   		};
>   		struct {	/* Page table pages */
>   			unsigned long _pt_pad_1;	/* compound_head */
> -			pgtable_t pmd_huge_pte; /* protected by page->ptl */
> +			union {
> +				pgtable_t pmd_huge_pte; /* protected by page->ptl */
> +				pmd_t *pmd;             /* PTE page only */
> +			};
>   			unsigned long _pt_pad_2;	/* mapping */
>   			union {
>   				struct mm_struct *pt_mm; /* x86 pgds only */
>   				atomic_t pt_frag_refcount; /* powerpc */
> +#ifdef CONFIG_FREE_USER_PTE
> +				atomic_t pte_refcount;  /* PTE page only */
> +#endif
>   			};
>   #if USE_SPLIT_PTE_PTLOCKS
>   #if ALLOC_SPLIT_PTLOCKS
> diff --git a/include/linux/pgtable.h b/include/linux/pgtable.h
> index d147480cdefc..172bb63b7ed9 100644
> --- a/include/linux/pgtable.h
> +++ b/include/linux/pgtable.h
> @@ -331,7 +331,6 @@ static inline pte_t ptep_get_lockless(pte_t *ptep)
>   }
>   #endif /* CONFIG_GUP_GET_PTE_LOW_HIGH */
>   
> -#ifdef CONFIG_TRANSPARENT_HUGEPAGE
>   #ifndef __HAVE_ARCH_PMDP_HUGE_GET_AND_CLEAR
>   static inline pmd_t pmdp_huge_get_and_clear(struct mm_struct *mm,
>   					    unsigned long address,
> @@ -342,6 +341,8 @@ static inline pmd_t pmdp_huge_get_and_clear(struct mm_struct *mm,
>   	return pmd;
>   }
>   #endif /* __HAVE_ARCH_PMDP_HUGE_GET_AND_CLEAR */
> +
> +#ifdef CONFIG_TRANSPARENT_HUGEPAGE
>   #ifndef __HAVE_ARCH_PUDP_HUGE_GET_AND_CLEAR
>   static inline pud_t pudp_huge_get_and_clear(struct mm_struct *mm,
>   					    unsigned long address,
> diff --git a/include/linux/pte_ref.h b/include/linux/pte_ref.h
> new file mode 100644
> index 000000000000..695fbe8b991b
> --- /dev/null
> +++ b/include/linux/pte_ref.h
> @@ -0,0 +1,217 @@
> +/* SPDX-License-Identifier: GPL-2.0 */
> +/*
> + * Free user PTE page table pages
> + *
> + * Copyright (c) 2021, ByteDance. All rights reserved.
> + *
> + * Author: Qi Zheng <zhengqi.arch@...edance.com>
> + */
> +#ifndef _LINUX_PTE_REF_H
> +#define _LINUX_PTE_REF_H
> +
> +#include <linux/mm.h>
> +#include <linux/pgtable.h>
> +#include <asm/pgalloc.h>
> +
> +bool pte_install_try_get(struct mm_struct *mm, pmd_t *pmd, pgtable_t *pte);
> +int __pte_alloc_try_get(struct mm_struct *mm, pmd_t *pmd);
> +int __pte_alloc_get(struct mm_struct *mm, pmd_t *pmd);
> +
> +#ifdef CONFIG_FREE_USER_PTE
> +void free_pte_table(struct mm_struct *mm, pmd_t *pmdp, unsigned long addr);
> +
> +static inline void pte_ref_init(pgtable_t pte, pmd_t *pmd, int count)
> +{
> +	pte->pmd = pmd;
> +	atomic_set(&pte->pte_refcount, count);
> +}
> +
> +static inline pmd_t *pte_to_pmd(pte_t *pte)
> +{
> +	return virt_to_page(pte)->pmd;
> +}
> +
> +static inline void pte_migrate_pmd(pmd_t old_pmd, pmd_t *new_pmd)
> +{
> +	pmd_pgtable(old_pmd)->pmd = new_pmd;
> +}
> +
> +/*
> + * Take @nr more references to the PTE page table to prevent it from
> + * being released.
> + *
> + * The caller should already hold a reference to PTE page table by
> + * calling pte_try_get(), and then this function is safe to use under
> + * mmap_lock or anon_vma lock or i_mmap_lock or when parallel threads are
> + * excluded by other means which can make @pmdp entry stable.
> + */
> +static inline void pte_get_many(pmd_t *pmdp, unsigned int nr)
> +{
> +	pgtable_t pte = pmd_pgtable(*pmdp);
> +
> +	VM_BUG_ON(pte->pmd != pmdp);
> +	atomic_add(nr, &pte->pte_refcount);
> +}
> +
> +static inline void pte_get(pmd_t *pmdp)
> +{
> +	pte_get_many(pmdp, 1);
> +}
> +
> +static inline bool pte_get_unless_zero(pmd_t *pmdp)
> +{
> +	pgtable_t pte = pmd_pgtable(*pmdp);
> +
> +	VM_BUG_ON(!PageTable(pte));
> +	return atomic_inc_not_zero(&pte->pte_refcount);
> +}
> +
> +/*
> + * Try to get a reference to the PTE page table to prevent it from
> + * being released.
> + *
> + * This function is safe to use under mmap_lock or anon_vma lock or
> + * i_mmap_lock or when parallel threads are excluded by other means
> + * which can make @pmdp entry stable.
> + */
> +static inline bool pte_try_get(struct mm_struct *mm, pmd_t *pmdp)
> +{
> +	bool retval = true;
> +	spinlock_t *ptl;
> +
> +	ptl = pmd_lock(mm, pmdp);
> +	if (pmd_leaf(*pmdp) || !pmd_present(*pmdp) ||
> +	    !pte_get_unless_zero(pmdp))
> +		retval = false;
> +	spin_unlock(ptl);
> +
> +	return retval;
> +}
> +
> +/*
> + * Drop @nr references to the PTE page table; the PTE page table page
> + * is freed when the reference count drops to 0.
> + *
> + * This function is safe to use under mmap_lock or anon_vma lock or
> + * i_mmap_lock or when parallel threads are excluded by other means
> + * which can make @pmdp entry stable.
> + */
> +static inline void pte_put_many(struct mm_struct *mm, pmd_t *pmdp,
> +				unsigned long addr, unsigned int nr)
> +{
> +	pgtable_t pte = pmd_pgtable(*pmdp);
> +
> +	VM_BUG_ON(mm == &init_mm);
> +	VM_BUG_ON(pmd_devmap_trans_unstable(pmdp));
> +	VM_BUG_ON(pte->pmd != pmdp);
> +	if (atomic_sub_and_test(nr, &pte->pte_refcount))
> +		free_pte_table(mm, pmdp, addr & PMD_MASK);
> +}
> +
> +static inline void pte_put(struct mm_struct *mm, pmd_t *pmdp, unsigned long addr)
> +{
> +	pte_put_many(mm, pmdp, addr, 1);
> +}
> +
> +/*
> + * The mmap_lock may be unlocked in advance in some cases in
> + * handle_pte_fault(), so we should ensure the pte_put() is performed
> + * in the critical section of the mmap_lock.
> + */
> +static inline void pte_put_vmf(struct vm_fault *vmf)
> +{
> +	if (!(vmf->flags & FAULT_FLAG_PTE_GET))
> +		return;
> +	vmf->flags &= ~FAULT_FLAG_PTE_GET;
> +
> +	pte_put(vmf->vma->vm_mm, vmf->pmd, vmf->address);
> +}
> +
> +static inline int pte_alloc_try_get(struct mm_struct *mm, pmd_t *pmdp)
> +{
> +	if (!pte_try_get(mm, pmdp))
> +		return __pte_alloc_try_get(mm, pmdp);
> +	return 1;
> +}
> +
> +static inline int pte_alloc_get(struct mm_struct *mm, pmd_t *pmdp)
> +{
> +	spinlock_t *ptl;
> +
> +	ptl = pmd_lock(mm, pmdp);
> +	if (pmd_none(*pmdp) || !pte_get_unless_zero(pmdp)) {
> +		spin_unlock(ptl);
> +		return __pte_alloc_get(mm, pmdp);
> +	}
> +	spin_unlock(ptl);
> +	return 0;
> +}
> +#else
> +static inline void pte_ref_init(pgtable_t pte, pmd_t *pmd, int count)
> +{
> +}
> +
> +static inline pmd_t *pte_to_pmd(pte_t *pte)
> +{
> +	return NULL;
> +}
> +
> +static inline void pte_migrate_pmd(pmd_t old_pmd, pmd_t *new_pmd)
> +{
> +}
> +
> +static inline void pte_get_many(pmd_t *pmdp, unsigned int nr)
> +{
> +}
> +
> +static inline void pte_get(pmd_t *pmdp)
> +{
> +}
> +
> +static inline bool pte_get_unless_zero(pmd_t *pmdp)
> +{
> +	return true;
> +}
> +
> +static inline bool pte_try_get(struct mm_struct *mm, pmd_t *pmdp)
> +{
> +	return true;
> +}
> +
> +static inline void pte_put_many(struct mm_struct *mm, pmd_t *pmdp,
> +				unsigned long addr, unsigned int nr)
> +{
> +}
> +
> +static inline void pte_put(struct mm_struct *mm, pmd_t *pmdp, unsigned long addr)
> +{
> +}
> +
> +static inline void pte_put_vmf(struct vm_fault *vmf)
> +{
> +}
> +
> +static inline int pte_alloc_try_get(struct mm_struct *mm, pmd_t *pmdp)
> +{
> +	if (unlikely(pmd_none(*pmdp)))
> +		return __pte_alloc_try_get(mm, pmdp);
> +	if (unlikely(pmd_devmap_trans_unstable(pmdp)))
> +		return 0;
> +	return 1;
> +}
> +
> +static inline int pte_alloc_get(struct mm_struct *mm, pmd_t *pmdp)
> +{
> +	if (unlikely(pmd_none(*pmdp)))
> +		return __pte_alloc_get(mm, pmdp);
> +	return 0;
> +}
> +#endif /* CONFIG_FREE_USER_PTE */
> +
> +#define pte_alloc_get_map(mm, pmd, address)		\
> +	(pte_alloc_get(mm, pmd) ? NULL : pte_offset_map(pmd, address))
> +
> +#define pte_alloc_get_map_lock(mm, pmd, address, ptlp)	\
> +	(pte_alloc_get(mm, pmd) ?			\
> +		 NULL : pte_offset_map_lock(mm, pmd, address, ptlp))
> +#endif /* _LINUX_PTE_REF_H */
> diff --git a/include/linux/rmap.h b/include/linux/rmap.h
> index 83fb86133fe1..886411eccb55 100644
> --- a/include/linux/rmap.h
> +++ b/include/linux/rmap.h
> @@ -11,6 +11,7 @@
>   #include <linux/rwsem.h>
>   #include <linux/memcontrol.h>
>   #include <linux/highmem.h>
> +#include <linux/pte_ref.h>
>   
>   /*
>    * The anon_vma heads a list of private "related" vmas, to scan if
> @@ -220,6 +221,8 @@ static inline void page_vma_mapped_walk_done(struct page_vma_mapped_walk *pvmw)
>   		pte_unmap(pvmw->pte);
>   	if (pvmw->ptl)
>   		spin_unlock(pvmw->ptl);
> +	if (pvmw->pte && !PageHuge(pvmw->page))
> +		pte_put(pvmw->vma->vm_mm, pvmw->pmd, pvmw->address);
>   }
>   
>   bool page_vma_mapped_walk(struct page_vma_mapped_walk *pvmw);
> diff --git a/kernel/events/uprobes.c b/kernel/events/uprobes.c
> index af24dc3febbe..2791190e1a01 100644
> --- a/kernel/events/uprobes.c
> +++ b/kernel/events/uprobes.c
> @@ -205,6 +205,9 @@ static int __replace_page(struct vm_area_struct *vma, unsigned long addr,
>   		try_to_free_swap(old_page);
>   	page_vma_mapped_walk_done(&pvmw);
>   
> +	if (!new_page)
> +		pte_put(mm, pte_to_pmd(pvmw.pte), addr);
> +
>   	if ((vma->vm_flags & VM_LOCKED) && !PageCompound(old_page))
>   		munlock_vma_page(old_page);
>   	put_page(old_page);
> diff --git a/mm/Kconfig b/mm/Kconfig
> index 5dc28e9205e0..745f6cdc5e9b 100644
> --- a/mm/Kconfig
> +++ b/mm/Kconfig
> @@ -889,4 +889,8 @@ config IO_MAPPING
>   config SECRETMEM
>   	def_bool ARCH_HAS_SET_DIRECT_MAP && !EMBEDDED
>   
> +config FREE_USER_PTE
> +	def_bool y
> +	depends on X86_64
> +
>   endmenu
> diff --git a/mm/Makefile b/mm/Makefile
> index e3436741d539..1ab513342d54 100644
> --- a/mm/Makefile
> +++ b/mm/Makefile
> @@ -38,7 +38,8 @@ mmu-y			:= nommu.o
>   mmu-$(CONFIG_MMU)	:= highmem.o memory.o mincore.o \
>   			   mlock.o mmap.o mmu_gather.o mprotect.o mremap.o \
>   			   msync.o page_vma_mapped.o pagewalk.o \
> -			   pgtable-generic.o rmap.o vmalloc.o ioremap.o
> +			   pgtable-generic.o rmap.o vmalloc.o ioremap.o \
> +			   pte_ref.o
>   
>   
>   ifdef CONFIG_CROSS_MEMORY_ATTACH
> diff --git a/mm/debug_vm_pgtable.c b/mm/debug_vm_pgtable.c
> index 1c922691aa61..8cae3b3329dc 100644
> --- a/mm/debug_vm_pgtable.c
> +++ b/mm/debug_vm_pgtable.c
> @@ -31,6 +31,7 @@
>   #include <linux/io.h>
>   #include <asm/pgalloc.h>
>   #include <asm/tlbflush.h>
> +#include <linux/pte_ref.h>
>   
>   /*
>    * Please refer Documentation/vm/arch_pgtable_helpers.rst for the semantics
> @@ -1018,7 +1019,7 @@ static int __init debug_vm_pgtable(void)
>   	/*
>   	 * Allocate pgtable_t
>   	 */
> -	if (pte_alloc(mm, pmdp)) {
> +	if (pte_alloc_try_get(mm, pmdp) < 0) {
>   		pr_err("pgtable allocation failed\n");
>   		return 1;
>   	}
> diff --git a/mm/filemap.c b/mm/filemap.c
> index db0184884890..024ca645c3a2 100644
> --- a/mm/filemap.c
> +++ b/mm/filemap.c
> @@ -1699,6 +1699,7 @@ int __lock_page_or_retry(struct page *page, struct vm_fault *vmf)
>   		if (flags & FAULT_FLAG_RETRY_NOWAIT)
>   			return 0;
>   
> +		pte_put_vmf(vmf);
>   		mmap_read_unlock(mm);
>   		if (flags & FAULT_FLAG_KILLABLE)
>   			wait_on_page_locked_killable(page);
> @@ -1711,6 +1712,7 @@ int __lock_page_or_retry(struct page *page, struct vm_fault *vmf)
>   
>   		ret = __lock_page_killable(page);
>   		if (ret) {
> +			pte_put_vmf(vmf);
>   			mmap_read_unlock(mm);
>   			return 0;
>   		}
> @@ -3160,32 +3162,30 @@ static bool filemap_map_pmd(struct vm_fault *vmf, struct page *page)
>   	struct mm_struct *mm = vmf->vma->vm_mm;
>   
>   	/* Huge page is mapped? No need to proceed. */
> -	if (pmd_trans_huge(*vmf->pmd)) {
> -		unlock_page(page);
> -		put_page(page);
> -		return true;
> -	}
> +	if (pmd_trans_huge(*vmf->pmd))
> +		goto out;
>   
>   	if (pmd_none(*vmf->pmd) && PageTransHuge(page)) {
> -	    vm_fault_t ret = do_set_pmd(vmf, page);
> -	    if (!ret) {
> -		    /* The page is mapped successfully, reference consumed. */
> -		    unlock_page(page);
> -		    return true;
> -	    }
> +		vm_fault_t ret = do_set_pmd(vmf, page);
> +		if (!ret) {
> +			/* The page is mapped successfully, reference consumed. */
> +			unlock_page(page);
> +			return true;
> +		}
>   	}
>   
> -	if (pmd_none(*vmf->pmd))
> -		pte_install(mm, vmf->pmd, &vmf->prealloc_pte);
> -
> -	/* See comment in handle_pte_fault() */
> -	if (pmd_devmap_trans_unstable(vmf->pmd)) {
> -		unlock_page(page);
> -		put_page(page);
> -		return true;
> +	if (IS_ENABLED(CONFIG_FREE_USER_PTE) || pmd_none(*vmf->pmd)) {
> +		if (!pte_install_try_get(mm, vmf->pmd, &vmf->prealloc_pte))
> +			goto out;
> +	} else if (pmd_devmap_trans_unstable(vmf->pmd)) { /* See comment in handle_pte_fault() */
> +		goto out;
>   	}
>   
>   	return false;
> +out:
> +	unlock_page(page);
> +	put_page(page);
> +	return true;
>   }
>   
>   static struct page *next_uptodate_page(struct page *page,
> @@ -3259,6 +3259,7 @@ vm_fault_t filemap_map_pages(struct vm_fault *vmf,
>   	struct page *head, *page;
>   	unsigned int mmap_miss = READ_ONCE(file->f_ra.mmap_miss);
>   	vm_fault_t ret = 0;
> +	unsigned int nr_get = 0;
>   
>   	rcu_read_lock();
>   	head = first_map_page(mapping, &xas, end_pgoff);
> @@ -3292,6 +3293,7 @@ vm_fault_t filemap_map_pages(struct vm_fault *vmf,
>   			ret = VM_FAULT_NOPAGE;
>   
>   		do_set_pte(vmf, page, addr);
> +		nr_get++;
>   		/* no need to invalidate: a not-present page won't be cached */
>   		update_mmu_cache(vma, addr, vmf->pte);
>   		unlock_page(head);
> @@ -3301,6 +3303,9 @@ vm_fault_t filemap_map_pages(struct vm_fault *vmf,
>   		put_page(head);
>   	} while ((head = next_map_page(mapping, &xas, end_pgoff)) != NULL);
>   	pte_unmap_unlock(vmf->pte, vmf->ptl);
> +	pte_get_many(vmf->pmd, nr_get);
> +	/* Drop the reference taken when filemap_map_pmd() returned false. */
> +	pte_put(vma->vm_mm, vmf->pmd, addr);
>   out:
>   	rcu_read_unlock();
>   	WRITE_ONCE(file->f_ra.mmap_miss, mmap_miss);
> diff --git a/mm/gup.c b/mm/gup.c
> index 42b8b1fa6521..3e2a153cb18e 100644
> --- a/mm/gup.c
> +++ b/mm/gup.c
> @@ -498,10 +498,14 @@ static struct page *follow_page_pte(struct vm_area_struct *vma,
>   	if (WARN_ON_ONCE((flags & (FOLL_PIN | FOLL_GET)) ==
>   			 (FOLL_PIN | FOLL_GET)))
>   		return ERR_PTR(-EINVAL);
> +
>   retry:
>   	if (unlikely(pmd_bad(*pmd)))
>   		return no_page_table(vma, flags);
>   
> +	if (!pte_try_get(mm, pmd))
> +		return no_page_table(vma, flags);
> +
>   	ptep = pte_offset_map_lock(mm, pmd, address, &ptl);
>   	pte = *ptep;
>   	if (!pte_present(pte)) {
> @@ -519,6 +523,7 @@ static struct page *follow_page_pte(struct vm_area_struct *vma,
>   		if (!is_migration_entry(entry))
>   			goto no_page;
>   		pte_unmap_unlock(ptep, ptl);
> +		pte_put(mm, pmd, address);
>   		migration_entry_wait(mm, pmd, address);
>   		goto retry;
>   	}
> @@ -526,6 +531,7 @@ static struct page *follow_page_pte(struct vm_area_struct *vma,
>   		goto no_page;
>   	if ((flags & FOLL_WRITE) && !can_follow_write_pte(pte, flags)) {
>   		pte_unmap_unlock(ptep, ptl);
> +		pte_put(mm, pmd, address);
>   		return NULL;
>   	}
>   
> @@ -614,9 +620,11 @@ static struct page *follow_page_pte(struct vm_area_struct *vma,
>   	}
>   out:
>   	pte_unmap_unlock(ptep, ptl);
> +	pte_put(mm, pmd, address);
>   	return page;
>   no_page:
>   	pte_unmap_unlock(ptep, ptl);
> +	pte_put(mm, pmd, address);
>   	if (!pte_none(pte))
>   		return NULL;
>   	return no_page_table(vma, flags);
> @@ -713,7 +721,7 @@ static struct page *follow_pmd_mask(struct vm_area_struct *vma,
>   		} else {
>   			spin_unlock(ptl);
>   			split_huge_pmd(vma, pmd, address);
> -			ret = pte_alloc(mm, pmd) ? -ENOMEM : 0;
> +			ret = pte_alloc_get(mm, pmd) ? -ENOMEM : 0;
>   		}
>   
>   		return ret ? ERR_PTR(ret) :
> diff --git a/mm/hmm.c b/mm/hmm.c
> index fad6be2bf072..29bb379510cc 100644
> --- a/mm/hmm.c
> +++ b/mm/hmm.c
> @@ -380,6 +380,9 @@ static int hmm_vma_walk_pmd(pmd_t *pmdp,
>   		return hmm_pfns_fill(start, end, range, HMM_PFN_ERROR);
>   	}
>   
> +	if (!pte_try_get(walk->mm, pmdp))
> +		goto again;
> +
>   	ptep = pte_offset_map(pmdp, addr);
>   	for (; addr < end; addr += PAGE_SIZE, ptep++, hmm_pfns++) {
>   		int r;
> @@ -391,6 +394,7 @@ static int hmm_vma_walk_pmd(pmd_t *pmdp,
>   		}
>   	}
>   	pte_unmap(ptep - 1);
> +	pte_put(walk->mm, pmdp, start);
>   	return 0;
>   }
>   
> diff --git a/mm/internal.h b/mm/internal.h
> index 31ff935b2547..642a7e0af740 100644
> --- a/mm/internal.h
> +++ b/mm/internal.h
> @@ -11,6 +11,7 @@
>   #include <linux/mm.h>
>   #include <linux/pagemap.h>
>   #include <linux/tracepoint-defs.h>
> +#include <linux/pte_ref.h>
>   
>   /*
>    * The set of flags that only affect watermark checking and reclaim
> @@ -441,6 +442,7 @@ static inline struct file *maybe_unlock_mmap_for_io(struct vm_fault *vmf,
>   	if (fault_flag_allow_retry_first(flags) &&
>   	    !(flags & FAULT_FLAG_RETRY_NOWAIT)) {
>   		fpin = get_file(vmf->vma->vm_file);
> +		pte_put_vmf(vmf);
>   		mmap_read_unlock(vmf->vma->vm_mm);
>   	}
>   	return fpin;
> diff --git a/mm/khugepaged.c b/mm/khugepaged.c
> index b0412be08fa2..e6c4d1b7a12a 100644
> --- a/mm/khugepaged.c
> +++ b/mm/khugepaged.c
> @@ -741,6 +741,7 @@ static void __collapse_huge_page_copy(pte_t *pte, struct page *page,
>   {
>   	struct page *src_page, *tmp;
>   	pte_t *_pte;
> +
>   	for (_pte = pte; _pte < pte + HPAGE_PMD_NR;
>   				_pte++, page++, address += PAGE_SIZE) {
>   		pte_t pteval = *_pte;
> @@ -1239,6 +1240,10 @@ static int khugepaged_scan_pmd(struct mm_struct *mm,
>   		goto out;
>   	}
>   
> +	if (!pte_try_get(mm, pmd)) {
> +		result = SCAN_PMD_NULL;
> +		goto out;
> +	}
>   	memset(khugepaged_node_load, 0, sizeof(khugepaged_node_load));
>   	pte = pte_offset_map_lock(mm, pmd, address, &ptl);
>   	for (_address = address, _pte = pte; _pte < pte+HPAGE_PMD_NR;
> @@ -1361,6 +1366,7 @@ static int khugepaged_scan_pmd(struct mm_struct *mm,
>   	}
>   out_unmap:
>   	pte_unmap_unlock(pte, ptl);
> +	pte_put(mm, pmd, address);
>   	if (ret) {
>   		node = khugepaged_find_target_node();
>   		/* collapse_huge_page will return with the mmap_lock released */
> @@ -1463,6 +1469,8 @@ void collapse_pte_mapped_thp(struct mm_struct *mm, unsigned long addr)
>   	if (!pmd)
>   		goto drop_hpage;
>   
> +	if (!pte_try_get(mm, pmd))
> +		goto drop_hpage;
>   	start_pte = pte_offset_map_lock(mm, pmd, haddr, &ptl);
>   
>   	/* step 1: check all mapped PTEs are to the right huge page */
> @@ -1501,6 +1509,7 @@ void collapse_pte_mapped_thp(struct mm_struct *mm, unsigned long addr)
>   	}
>   
>   	pte_unmap_unlock(start_pte, ptl);
> +	pte_put(mm, pmd, haddr);
>   
>   	/* step 3: set proper refcount and mm_counters. */
>   	if (count) {
> @@ -1522,6 +1531,7 @@ void collapse_pte_mapped_thp(struct mm_struct *mm, unsigned long addr)
>   
>   abort:
>   	pte_unmap_unlock(start_pte, ptl);
> +	pte_put(mm, pmd, haddr);
>   	goto drop_hpage;
>   }
>   
> diff --git a/mm/ksm.c b/mm/ksm.c
> index 3fa9bc8a67cf..2e106f58dad0 100644
> --- a/mm/ksm.c
> +++ b/mm/ksm.c
> @@ -1133,6 +1133,9 @@ static int replace_page(struct vm_area_struct *vma, struct page *page,
>   	if (!pmd)
>   		goto out;
>   
> +	if (!pte_try_get(mm, pmd))
> +		goto out;
> +
>   	mmu_notifier_range_init(&range, MMU_NOTIFY_CLEAR, 0, vma, mm, addr,
>   				addr + PAGE_SIZE);
>   	mmu_notifier_invalidate_range_start(&range);
> @@ -1182,6 +1185,7 @@ static int replace_page(struct vm_area_struct *vma, struct page *page,
>   	err = 0;
>   out_mn:
>   	mmu_notifier_invalidate_range_end(&range);
> +	pte_put(mm, pmd, addr);
>   out:
>   	return err;
>   }
> diff --git a/mm/madvise.c b/mm/madvise.c
> index 012129fbfaf8..4c4b35292212 100644
> --- a/mm/madvise.c
> +++ b/mm/madvise.c
> @@ -191,7 +191,9 @@ static int swapin_walk_pmd_entry(pmd_t *pmd, unsigned long start,
>   	struct vm_area_struct *vma = walk->private;
>   	unsigned long index;
>   
> -	if (pmd_none_or_trans_huge_or_clear_bad(pmd))
> +	if ((!IS_ENABLED(CONFIG_FREE_USER_PTE) &&
> +	    pmd_none_or_trans_huge_or_clear_bad(pmd)) ||
> +	    !pte_try_get(vma->vm_mm, pmd))
>   		return 0;
>   
>   	for (index = start; index != end; index += PAGE_SIZE) {
> @@ -215,6 +217,7 @@ static int swapin_walk_pmd_entry(pmd_t *pmd, unsigned long start,
>   		if (page)
>   			put_page(page);
>   	}
> +	pte_put(vma->vm_mm, pmd, start);
>   
>   	return 0;
>   }
> @@ -318,6 +321,7 @@ static int madvise_cold_or_pageout_pte_range(pmd_t *pmd,
>   	spinlock_t *ptl;
>   	struct page *page = NULL;
>   	LIST_HEAD(page_list);
> +	unsigned long start = addr;
>   
>   	if (fatal_signal_pending(current))
>   		return -EINTR;
> @@ -389,9 +393,11 @@ static int madvise_cold_or_pageout_pte_range(pmd_t *pmd,
>   	}
>   
>   regular_page:
> -	if (pmd_trans_unstable(pmd))
> +	if (!IS_ENABLED(CONFIG_FREE_USER_PTE) && pmd_trans_unstable(pmd))
>   		return 0;
>   #endif
> +	if (!pte_try_get(vma->vm_mm, pmd))
> +		return 0;
>   	tlb_change_page_size(tlb, PAGE_SIZE);
>   	orig_pte = pte = pte_offset_map_lock(vma->vm_mm, pmd, addr, &ptl);
>   	flush_tlb_batched_pending(mm);
> @@ -471,6 +477,7 @@ static int madvise_cold_or_pageout_pte_range(pmd_t *pmd,
>   
>   	arch_leave_lazy_mmu_mode();
>   	pte_unmap_unlock(orig_pte, ptl);
> +	pte_put(vma->vm_mm, pmd, start);
>   	if (pageout)
>   		reclaim_pages(&page_list);
>   	cond_resched();
> @@ -580,14 +587,18 @@ static int madvise_free_pte_range(pmd_t *pmd, unsigned long addr,
>   	struct page *page;
>   	int nr_swap = 0;
>   	unsigned long next;
> +	unsigned int nr_put = 0;
> +	unsigned long start = addr;
>   
>   	next = pmd_addr_end(addr, end);
>   	if (pmd_trans_huge(*pmd))
>   		if (madvise_free_huge_pmd(tlb, vma, pmd, addr, next))
>   			goto next;
>   
> -	if (pmd_trans_unstable(pmd))
> +	if ((!IS_ENABLED(CONFIG_FREE_USER_PTE) &&
> +	     pmd_trans_unstable(pmd)) || !pte_try_get(mm, pmd))
>   		return 0;
> +	nr_put++;
>   
>   	tlb_change_page_size(tlb, PAGE_SIZE);
>   	orig_pte = pte = pte_offset_map_lock(mm, pmd, addr, &ptl);
> @@ -612,6 +623,7 @@ static int madvise_free_pte_range(pmd_t *pmd, unsigned long addr,
>   			nr_swap--;
>   			free_swap_and_cache(entry);
>   			pte_clear_not_present_full(mm, addr, pte, tlb->fullmm);
> +			nr_put++;
>   			continue;
>   		}
>   
> @@ -696,6 +708,8 @@ static int madvise_free_pte_range(pmd_t *pmd, unsigned long addr,
>   	}
>   	arch_leave_lazy_mmu_mode();
>   	pte_unmap_unlock(orig_pte, ptl);
> +	if (nr_put)
> +		pte_put_many(mm, pmd, start, nr_put);
>   	cond_resched();
>   next:
>   	return 0;
> diff --git a/mm/memcontrol.c b/mm/memcontrol.c
> index ae1f5d0cb581..4f19e5f2cd18 100644
> --- a/mm/memcontrol.c
> +++ b/mm/memcontrol.c
> @@ -5819,6 +5819,7 @@ static int mem_cgroup_count_precharge_pte_range(pmd_t *pmd,
>   	struct vm_area_struct *vma = walk->vma;
>   	pte_t *pte;
>   	spinlock_t *ptl;
> +	unsigned long start = addr;
>   
>   	ptl = pmd_trans_huge_lock(pmd, vma);
>   	if (ptl) {
> @@ -5833,13 +5834,15 @@ static int mem_cgroup_count_precharge_pte_range(pmd_t *pmd,
>   		return 0;
>   	}
>   
> -	if (pmd_trans_unstable(pmd))
> +	if ((!IS_ENABLED(CONFIG_FREE_USER_PTE) && pmd_trans_unstable(pmd)) ||
> +	    !pte_try_get(vma->vm_mm, pmd))
>   		return 0;
>   	pte = pte_offset_map_lock(vma->vm_mm, pmd, addr, &ptl);
>   	for (; addr != end; pte++, addr += PAGE_SIZE)
>   		if (get_mctgt_type(vma, addr, *pte, NULL))
>   			mc.precharge++;	/* increment precharge temporarily */
>   	pte_unmap_unlock(pte - 1, ptl);
> +	pte_put(vma->vm_mm, pmd, start);
>   	cond_resched();
>   
>   	return 0;
> @@ -6019,6 +6022,7 @@ static int mem_cgroup_move_charge_pte_range(pmd_t *pmd,
>   	enum mc_target_type target_type;
>   	union mc_target target;
>   	struct page *page;
> +	unsigned long start = addr;
>   
>   	ptl = pmd_trans_huge_lock(pmd, vma);
>   	if (ptl) {
> @@ -6051,9 +6055,11 @@ static int mem_cgroup_move_charge_pte_range(pmd_t *pmd,
>   		return 0;
>   	}
>   
> -	if (pmd_trans_unstable(pmd))
> +	if (!IS_ENABLED(CONFIG_FREE_USER_PTE) && pmd_trans_unstable(pmd))
>   		return 0;
>   retry:
> +	if (!pte_try_get(vma->vm_mm, pmd))
> +		return 0;
>   	pte = pte_offset_map_lock(vma->vm_mm, pmd, addr, &ptl);
>   	for (; addr != end; addr += PAGE_SIZE) {
>   		pte_t ptent = *(pte++);
> @@ -6104,6 +6110,7 @@ static int mem_cgroup_move_charge_pte_range(pmd_t *pmd,
>   		}
>   	}
>   	pte_unmap_unlock(pte - 1, ptl);
> +	pte_put(vma->vm_mm, pmd, start);
>   	cond_resched();
>   
>   	if (addr != end) {
> diff --git a/mm/memory.c b/mm/memory.c
> index 3bf2636413ee..242ed135bde4 100644
> --- a/mm/memory.c
> +++ b/mm/memory.c
> @@ -219,6 +219,17 @@ static void check_sync_rss_stat(struct task_struct *task)
>   
>   #endif /* SPLIT_RSS_COUNTING */
>   
> +#ifdef CONFIG_FREE_USER_PTE
> +static inline void free_pte_range(struct mmu_gather *tlb, pmd_t *pmd,
> +				  unsigned long addr)
> +{
> +	/*
> +	 * We should never reach here since the PTE page tables are
> +	 * dynamically freed.
> +	 */
> +	BUG();
> +}
> +#else
>   /*
>    * Note: this doesn't free the actual pages themselves. That
>    * has been handled earlier when unmapping all the memory regions.
> @@ -231,6 +242,7 @@ static void free_pte_range(struct mmu_gather *tlb, pmd_t *pmd,
>   	pte_free_tlb(tlb, token, addr);
>   	mm_dec_nr_ptes(tlb->mm);
>   }
> +#endif /* CONFIG_FREE_USER_PTE */
>   
>   static inline void free_pmd_range(struct mmu_gather *tlb, pud_t *pud,
>   				unsigned long addr, unsigned long end,
> @@ -433,44 +445,6 @@ void free_pgtables(struct mmu_gather *tlb, struct vm_area_struct *vma,
>   	}
>   }
>   
> -void pte_install(struct mm_struct *mm, pmd_t *pmd, pgtable_t *pte)
> -{
> -	spinlock_t *ptl = pmd_lock(mm, pmd);
> -
> -	if (likely(pmd_none(*pmd))) {	/* Has another populated it ? */
> -		mm_inc_nr_ptes(mm);
> -		/*
> -		 * Ensure all pte setup (eg. pte page lock and page clearing) are
> -		 * visible before the pte is made visible to other CPUs by being
> -		 * put into page tables.
> -		 *
> -		 * The other side of the story is the pointer chasing in the page
> -		 * table walking code (when walking the page table without locking;
> -		 * ie. most of the time). Fortunately, these data accesses consist
> -		 * of a chain of data-dependent loads, meaning most CPUs (alpha
> -		 * being the notable exception) will already guarantee loads are
> -		 * seen in-order. See the alpha page table accessors for the
> -		 * smp_rmb() barriers in page table walking code.
> -		 */
> -		smp_wmb(); /* Could be smp_wmb__xxx(before|after)_spin_lock */
> -		pmd_populate(mm, pmd, *pte);
> -		*pte = NULL;
> -	}
> -	spin_unlock(ptl);
> -}
> -
> -int __pte_alloc(struct mm_struct *mm, pmd_t *pmd)
> -{
> -	pgtable_t new = pte_alloc_one(mm);
> -	if (!new)
> -		return -ENOMEM;
> -
> -	pte_install(mm, pmd, &new);
> -	if (new)
> -		pte_free(mm, new);
> -	return 0;
> -}
> -
>   int __pte_alloc_kernel(pmd_t *pmd)
>   {
>   	pte_t *new = pte_alloc_one_kernel(&init_mm);
> @@ -479,7 +453,7 @@ int __pte_alloc_kernel(pmd_t *pmd)
>   
>   	spin_lock(&init_mm.page_table_lock);
>   	if (likely(pmd_none(*pmd))) {	/* Has another populated it ? */
> -		smp_wmb(); /* See comment in pte_install() */
> +		smp_wmb(); /* See comment in __pte_install() */
>   		pmd_populate_kernel(&init_mm, pmd, new);
>   		new = NULL;
>   	}
> @@ -860,6 +834,7 @@ copy_nonpresent_pte(struct mm_struct *dst_mm, struct mm_struct *src_mm,
>   	if (!userfaultfd_wp(dst_vma))
>   		pte = pte_swp_clear_uffd_wp(pte);
>   	set_pte_at(dst_mm, addr, dst_pte, pte);
> +	pte_get(pte_to_pmd(dst_pte));
>   	return 0;
>   }
>   
> @@ -928,6 +903,7 @@ copy_present_page(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma
>   		/* Uffd-wp needs to be delivered to dest pte as well */
>   		pte = pte_wrprotect(pte_mkuffd_wp(pte));
>   	set_pte_at(dst_vma->vm_mm, addr, dst_pte, pte);
> +	pte_get(pte_to_pmd(dst_pte));
>   	return 0;
>   }
>   
> @@ -980,6 +956,7 @@ copy_present_pte(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
>   		pte = pte_clear_uffd_wp(pte);
>   
>   	set_pte_at(dst_vma->vm_mm, addr, dst_pte, pte);
> +	pte_get(pte_to_pmd(dst_pte));
>   	return 0;
>   }
>   
> @@ -1021,7 +998,7 @@ copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
>   	progress = 0;
>   	init_rss_vec(rss);
>   
> -	dst_pte = pte_alloc_map_lock(dst_mm, dst_pmd, addr, &dst_ptl);
> +	dst_pte = pte_alloc_get_map_lock(dst_mm, dst_pmd, addr, &dst_ptl);
>   	if (!dst_pte) {
>   		ret = -ENOMEM;
>   		goto out;
> @@ -1109,8 +1086,10 @@ copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
>   		goto out;
>   	} else if (ret ==  -EAGAIN) {
>   		prealloc = page_copy_prealloc(src_mm, src_vma, addr);
> -		if (!prealloc)
> -			return -ENOMEM;
> +		if (!prealloc) {
> +			ret = -ENOMEM;
> +			goto out;
> +		}
>   	} else if (ret) {
>   		VM_WARN_ON_ONCE(1);
>   	}
> @@ -1118,11 +1097,14 @@ copy_pte_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
>   	/* We've captured and resolved the error. Reset, try again. */
>   	ret = 0;
>   
> -	if (addr != end)
> +	if (addr != end) {
> +		pte_put(dst_mm, dst_pmd, addr);
>   		goto again;
> +	}
>   out:
>   	if (unlikely(prealloc))
>   		put_page(prealloc);
> +	pte_put(dst_mm, dst_pmd, addr);
>   	return ret;
>   }
>   
> @@ -1141,9 +1123,13 @@ copy_pmd_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
>   		return -ENOMEM;
>   	src_pmd = pmd_offset(src_pud, addr);
>   	do {
> +		pmd_t pmdval;
> +
>   		next = pmd_addr_end(addr, end);
> -		if (is_swap_pmd(*src_pmd) || pmd_trans_huge(*src_pmd)
> -			|| pmd_devmap(*src_pmd)) {
> +retry:
> +		pmdval = READ_ONCE(*src_pmd);
> +		if (is_swap_pmd(pmdval) || pmd_trans_huge(pmdval)
> +			|| pmd_devmap(pmdval)) {
>   			int err;
>   			VM_BUG_ON_VMA(next-addr != HPAGE_PMD_SIZE, src_vma);
>   			err = copy_huge_pmd(dst_mm, src_mm, dst_pmd, src_pmd,
> @@ -1156,9 +1142,15 @@ copy_pmd_range(struct vm_area_struct *dst_vma, struct vm_area_struct *src_vma,
>   		}
>   		if (pmd_none_or_clear_bad(src_pmd))
>   			continue;
> +
> +		if (!pte_try_get(src_mm, src_pmd))
> +			goto retry;
>   		if (copy_pte_range(dst_vma, src_vma, dst_pmd, src_pmd,
> -				   addr, next))
> +				   addr, next)) {
> +			pte_put(src_mm, src_pmd, addr);
>   			return -ENOMEM;
> +		}
> +		pte_put(src_mm, src_pmd, addr);
>   	} while (dst_pmd++, src_pmd++, addr = next, addr != end);
>   	return 0;
>   }
> @@ -1316,6 +1308,8 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
>   	pte_t *start_pte;
>   	pte_t *pte;
>   	swp_entry_t entry;
> +	unsigned int nr_put = 0;
> +	unsigned long start = addr;
>   
>   	tlb_change_page_size(tlb, PAGE_SIZE);
>   again:
> @@ -1348,6 +1342,7 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
>   			}
>   			ptent = ptep_get_and_clear_full(mm, addr, pte,
>   							tlb->fullmm);
> +			nr_put++;
>   			tlb_remove_tlb_entry(tlb, pte, addr);
>   			if (unlikely(!page))
>   				continue;
> @@ -1390,6 +1385,7 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
>   			}
>   
>   			pte_clear_not_present_full(mm, addr, pte, tlb->fullmm);
> +			nr_put++;
>   			rss[mm_counter(page)]--;
>   
>   			if (is_device_private_entry(entry))
> @@ -1414,6 +1410,7 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
>   		if (unlikely(!free_swap_and_cache(entry)))
>   			print_bad_pte(vma, addr, ptent, NULL);
>   		pte_clear_not_present_full(mm, addr, pte, tlb->fullmm);
> +		nr_put++;
>   	} while (pte++, addr += PAGE_SIZE, addr != end);
>   
>   	add_mm_rss_vec(mm, rss);
> @@ -1440,6 +1437,9 @@ static unsigned long zap_pte_range(struct mmu_gather *tlb,
>   		goto again;
>   	}
>   
> +	if (nr_put)
> +		pte_put_many(mm, pmd, start, nr_put);
> +
>   	return addr;
>   }
>   
> @@ -1479,9 +1479,13 @@ static inline unsigned long zap_pmd_range(struct mmu_gather *tlb,
>   		 * because MADV_DONTNEED holds the mmap_lock in read
>   		 * mode.
>   		 */
> -		if (pmd_none_or_trans_huge_or_clear_bad(pmd))
> +		if ((!IS_ENABLED(CONFIG_FREE_USER_PTE) &&
> +		     pmd_none_or_trans_huge_or_clear_bad(pmd)) ||
> +		     !pte_try_get(tlb->mm, pmd))
>   			goto next;
> +
>   		next = zap_pte_range(tlb, vma, pmd, addr, next, details);
> +		pte_put(tlb->mm, pmd, addr);
>   next:
>   		cond_resched();
>   	} while (pmd++, addr = next, addr != end);
> @@ -1736,7 +1740,7 @@ pte_t *__get_locked_pte(struct mm_struct *mm, unsigned long addr,
>   
>   	if (!pmd)
>   		return NULL;
> -	return pte_alloc_map_lock(mm, pmd, addr, ptl);
> +	return pte_alloc_get_map_lock(mm, pmd, addr, ptl);
>   }
>   
>   static int validate_page_before_insert(struct page *page)
> @@ -1757,6 +1761,7 @@ static int insert_page_into_pte_locked(struct mm_struct *mm, pte_t *pte,
>   	inc_mm_counter_fast(mm, mm_counter_file(page));
>   	page_add_file_rmap(page, false);
>   	set_pte_at(mm, addr, pte, mk_pte(page, prot));
> +	pte_get(pte_to_pmd(pte));
>   	return 0;
>   }
>   
> @@ -1784,6 +1789,7 @@ static int insert_page(struct vm_area_struct *vma, unsigned long addr,
>   		goto out;
>   	retval = insert_page_into_pte_locked(mm, pte, addr, page, prot);
>   	pte_unmap_unlock(pte, ptl);
> +	pte_put(mm, pte_to_pmd(pte), addr);
>   out:
>   	return retval;
>   }
> @@ -1827,7 +1833,7 @@ static int insert_pages(struct vm_area_struct *vma, unsigned long addr,
>   
>   	/* Allocate the PTE if necessary; takes PMD lock once only. */
>   	ret = -ENOMEM;
> -	if (pte_alloc(mm, pmd))
> +	if (pte_alloc_try_get(mm, pmd) < 0)
>   		goto out;
>   
>   	while (pages_to_write_in_pmd) {
> @@ -1854,6 +1860,7 @@ static int insert_pages(struct vm_area_struct *vma, unsigned long addr,
>   	if (remaining_pages_total)
>   		goto more;
>   	ret = 0;
> +	pte_put(mm, pmd, addr);
>   out:
>   	*num = remaining_pages_total;
>   	return ret;
> @@ -2077,10 +2084,12 @@ static vm_fault_t insert_pfn(struct vm_area_struct *vma, unsigned long addr,
>   	}
>   
>   	set_pte_at(mm, addr, pte, entry);
> +	pte_get(pte_to_pmd(pte));
>   	update_mmu_cache(vma, addr, pte); /* XXX: why not for insert_page? */
>   
>   out_unlock:
>   	pte_unmap_unlock(pte, ptl);
> +	pte_put(mm, pte_to_pmd(pte), addr);
>   	return VM_FAULT_NOPAGE;
>   }
>   
> @@ -2284,8 +2293,10 @@ static int remap_pte_range(struct mm_struct *mm, pmd_t *pmd,
>   	pte_t *pte, *mapped_pte;
>   	spinlock_t *ptl;
>   	int err = 0;
> +	unsigned int nr_get = 0;
> +	unsigned long start_addr = addr;
>   
> -	mapped_pte = pte = pte_alloc_map_lock(mm, pmd, addr, &ptl);
> +	mapped_pte = pte = pte_alloc_get_map_lock(mm, pmd, addr, &ptl);
>   	if (!pte)
>   		return -ENOMEM;
>   	arch_enter_lazy_mmu_mode();
> @@ -2296,10 +2307,13 @@ static int remap_pte_range(struct mm_struct *mm, pmd_t *pmd,
>   			break;
>   		}
>   		set_pte_at(mm, addr, pte, pte_mkspecial(pfn_pte(pfn, prot)));
> +		nr_get++;
>   		pfn++;
>   	} while (pte++, addr += PAGE_SIZE, addr != end);
> +	pte_get_many(pmd, nr_get);
>   	arch_leave_lazy_mmu_mode();
>   	pte_unmap_unlock(mapped_pte, ptl);
> +	pte_put(mm, pmd, start_addr);
>   	return err;
>   }
>   
> @@ -2512,13 +2526,17 @@ static int apply_to_pte_range(struct mm_struct *mm, pmd_t *pmd,
>   	pte_t *pte, *mapped_pte;
>   	int err = 0;
>   	spinlock_t *ptl;
> +	unsigned int nr_put = 0;
> +	unsigned int nr_get = 0;
> +	unsigned long start = addr;
>   
>   	if (create) {
>   		mapped_pte = pte = (mm == &init_mm) ?
>   			pte_alloc_kernel_track(pmd, addr, mask) :
> -			pte_alloc_map_lock(mm, pmd, addr, &ptl);
> +			pte_alloc_get_map_lock(mm, pmd, addr, &ptl);
>   		if (!pte)
>   			return -ENOMEM;
> +		nr_put++;
>   	} else {
>   		mapped_pte = pte = (mm == &init_mm) ?
>   			pte_offset_kernel(pmd, addr) :
> @@ -2531,19 +2549,32 @@ static int apply_to_pte_range(struct mm_struct *mm, pmd_t *pmd,
>   
>   	if (fn) {
>   		do {
> -			if (create || !pte_none(*pte)) {
> +			if (create) {
>   				err = fn(pte++, addr, data);
> -				if (err)
> -					break;
> +				if (IS_ENABLED(CONFIG_FREE_USER_PTE) &&
> +					mm != &init_mm && !pte_none(*(pte-1)))
> +					nr_get++;
> +			} else if (!pte_none(*pte)) {
> +				err = fn(pte++, addr, data);
> +				if (IS_ENABLED(CONFIG_FREE_USER_PTE) &&
> +					mm != &init_mm && pte_none(*(pte-1)))
> +					nr_put++;
>   			}
> +			if (err)
> +				break;
>   		} while (addr += PAGE_SIZE, addr != end);
>   	}
>   	*mask |= PGTBL_PTE_MODIFIED;
>   
>   	arch_leave_lazy_mmu_mode();
>   
> -	if (mm != &init_mm)
> +	if (mm != &init_mm) {
>   		pte_unmap_unlock(mapped_pte, ptl);
> +		pte_get_many(pmd, nr_get);
> +		if (nr_put)
> +			pte_put_many(mm, pmd, start, nr_put);
> +	}
> +
>   	return err;
>   }
>   
> @@ -2567,6 +2598,7 @@ static int apply_to_pmd_range(struct mm_struct *mm, pud_t *pud,
>   	}
>   	do {
>   		next = pmd_addr_end(addr, end);
> +retry:
>   		if (pmd_none(*pmd) && !create)
>   			continue;
>   		if (WARN_ON_ONCE(pmd_leaf(*pmd)))
> @@ -2576,8 +2608,12 @@ static int apply_to_pmd_range(struct mm_struct *mm, pud_t *pud,
>   				continue;
>   			pmd_clear_bad(pmd);
>   		}
> +		if (!create && !pte_try_get(mm, pmd))
> +			goto retry;
>   		err = apply_to_pte_range(mm, pmd, addr, next,
>   					 fn, data, create, mask);
> +		if (!create)
> +			pte_put(mm, pmd, addr);
>   		if (err)
>   			break;
>   	} while (pmd++, addr = next, addr != end);
> @@ -3726,21 +3762,19 @@ static vm_fault_t do_anonymous_page(struct vm_fault *vmf)
>   		return VM_FAULT_SIGBUS;
>   
>   	/*
> -	 * Use pte_alloc() instead of pte_alloc_map().  We can't run
> +	 * Use pte_alloc_try_get() instead of pte_alloc_get_map().  We can't run
>   	 * pte_offset_map() on pmds where a huge pmd might be created
>   	 * from a different thread.
>   	 *
> -	 * pte_alloc_map() is safe to use under mmap_write_lock(mm) or when
> +	 * pte_alloc_get_map() is safe to use under mmap_write_lock(mm) or when
>   	 * parallel threads are excluded by other means.
>   	 *
>   	 * Here we only have mmap_read_lock(mm).
>   	 */
> -	if (pte_alloc(vma->vm_mm, vmf->pmd))
> -		return VM_FAULT_OOM;
> -
> -	/* See comment in handle_pte_fault() */
> -	if (unlikely(pmd_trans_unstable(vmf->pmd)))
> -		return 0;
> +	ret = pte_alloc_try_get(vma->vm_mm, vmf->pmd);
> +	if (ret <= 0)
> +		return ret < 0 ? VM_FAULT_OOM : 0;
> +	ret = 0;
>   
>   	/* Use the zero-page for reads */
>   	if (!(vmf->flags & FAULT_FLAG_WRITE) &&
> @@ -3759,7 +3793,8 @@ static vm_fault_t do_anonymous_page(struct vm_fault *vmf)
>   		/* Deliver the page fault to userland, check inside PT lock */
>   		if (userfaultfd_missing(vma)) {
>   			pte_unmap_unlock(vmf->pte, vmf->ptl);
> -			return handle_userfault(vmf, VM_UFFD_MISSING);
> +			ret = handle_userfault(vmf, VM_UFFD_MISSING);
> +			goto put;
>   		}
>   		goto setpte;
>   	}
> @@ -3802,7 +3837,8 @@ static vm_fault_t do_anonymous_page(struct vm_fault *vmf)
>   	if (userfaultfd_missing(vma)) {
>   		pte_unmap_unlock(vmf->pte, vmf->ptl);
>   		put_page(page);
> -		return handle_userfault(vmf, VM_UFFD_MISSING);
> +		ret = handle_userfault(vmf, VM_UFFD_MISSING);
> +		goto put;
>   	}
>   
>   	inc_mm_counter_fast(vma->vm_mm, MM_ANONPAGES);
> @@ -3810,19 +3846,23 @@ static vm_fault_t do_anonymous_page(struct vm_fault *vmf)
>   	lru_cache_add_inactive_or_unevictable(page, vma);
>   setpte:
>   	set_pte_at(vma->vm_mm, vmf->address, vmf->pte, entry);
> +	pte_get(vmf->pmd);
>   
>   	/* No need to invalidate - it was non-present before */
>   	update_mmu_cache(vma, vmf->address, vmf->pte);
>   unlock:
>   	pte_unmap_unlock(vmf->pte, vmf->ptl);
> -	return ret;
> +	goto put;
>   release:
>   	put_page(page);
>   	goto unlock;
>   oom_free_page:
>   	put_page(page);
>   oom:
> -	return VM_FAULT_OOM;
> +	ret = VM_FAULT_OOM;
> +put:
> +	pte_put(vma->vm_mm, vmf->pmd, vmf->address);
> +	return ret;
>   }
>   
>   /*
> @@ -3850,7 +3890,7 @@ static vm_fault_t __do_fault(struct vm_fault *vmf)
>   	 *				unlock_page(B)
>   	 *				# flush A, B to clear the writeback
>   	 */
> -	if (pmd_none(*vmf->pmd) && !vmf->prealloc_pte) {
> +	if (!vmf->prealloc_pte) {
>   		vmf->prealloc_pte = pte_alloc_one(vma->vm_mm);
>   		if (!vmf->prealloc_pte)
>   			return VM_FAULT_OOM;
> @@ -4020,6 +4060,7 @@ vm_fault_t finish_fault(struct vm_fault *vmf)
>   			return ret;
>   	}
>   
> +retry:
>   	if (pmd_none(*vmf->pmd)) {
>   		if (PageTransCompound(page)) {
>   			ret = do_set_pmd(vmf, page);
> @@ -4027,27 +4068,33 @@ vm_fault_t finish_fault(struct vm_fault *vmf)
>   				return ret;
>   		}
>   
> -		if (vmf->prealloc_pte)
> -			pte_install(vma->vm_mm, vmf->pmd, &vmf->prealloc_pte);
> -		else if (unlikely(pte_alloc(vma->vm_mm, vmf->pmd)))
> -			return VM_FAULT_OOM;
> -	}
> -
> -	/* See comment in handle_pte_fault() */
> -	if (pmd_devmap_trans_unstable(vmf->pmd))
> +		if (vmf->prealloc_pte) {
> +			if (!pte_install_try_get(vma->vm_mm, vmf->pmd, &vmf->prealloc_pte))
> +				return 0;
> +		} else {
> +			ret = pte_alloc_try_get(vma->vm_mm, vmf->pmd);
> +			if (ret <= 0)
> +				return ret < 0 ? VM_FAULT_OOM : 0;
> +		}
> +	} else if (pmd_devmap_trans_unstable(vmf->pmd)) { /* See comment in handle_pte_fault() */
>   		return 0;
> +	} else if (!pte_try_get(vma->vm_mm, vmf->pmd)) {
> +		goto retry;
> +	}
>   
>   	vmf->pte = pte_offset_map_lock(vma->vm_mm, vmf->pmd,
>   				      vmf->address, &vmf->ptl);
>   	ret = 0;
>   	/* Re-check under ptl */
> -	if (likely(pte_none(*vmf->pte)))
> +	if (likely(pte_none(*vmf->pte))) {
>   		do_set_pte(vmf, page, vmf->address);
> -	else
> +		pte_get(vmf->pmd);
> +	} else
>   		ret = VM_FAULT_NOPAGE;
>   
>   	update_mmu_tlb(vma, vmf->address, vmf->pte);
>   	pte_unmap_unlock(vmf->pte, vmf->ptl);
> +	pte_put(vma->vm_mm, vmf->pmd, vmf->address);
>   	return ret;
>   }
>   
> @@ -4268,9 +4315,14 @@ static vm_fault_t do_fault(struct vm_fault *vmf)
>   		 * If we find a migration pmd entry or a none pmd entry, which
>   		 * should never happen, return SIGBUS
>   		 */
> -		if (unlikely(!pmd_present(*vmf->pmd)))
> +		if (unlikely(!pmd_present(*vmf->pmd))) {
>   			ret = VM_FAULT_SIGBUS;
> -		else {
> +			goto out;
> +		} else {
> +			if (!pte_try_get(vma->vm_mm, vmf->pmd)) {
> +				ret = VM_FAULT_SIGBUS;
> +				goto out;
> +			}
>   			vmf->pte = pte_offset_map_lock(vmf->vma->vm_mm,
>   						       vmf->pmd,
>   						       vmf->address,
> @@ -4288,6 +4340,7 @@ static vm_fault_t do_fault(struct vm_fault *vmf)
>   				ret = VM_FAULT_NOPAGE;
>   
>   			pte_unmap_unlock(vmf->pte, vmf->ptl);
> +			pte_put(vma->vm_mm, vmf->pmd, vmf->address);
>   		}
>   	} else if (!(vmf->flags & FAULT_FLAG_WRITE))
>   		ret = do_read_fault(vmf);
> @@ -4301,6 +4354,7 @@ static vm_fault_t do_fault(struct vm_fault *vmf)
>   		pte_free(vm_mm, vmf->prealloc_pte);
>   		vmf->prealloc_pte = NULL;
>   	}
> +out:
>   	return ret;
>   }
>   
> @@ -4496,11 +4550,13 @@ static vm_fault_t wp_huge_pud(struct vm_fault *vmf, pud_t orig_pud)
>   static vm_fault_t handle_pte_fault(struct vm_fault *vmf)
>   {
>   	pte_t entry;
> +	vm_fault_t ret;
>   
> -	if (unlikely(pmd_none(*vmf->pmd))) {
> +retry:
> +	if (unlikely(pmd_none(READ_ONCE(*vmf->pmd)))) {
>   		/*
> -		 * Leave __pte_alloc() until later: because vm_ops->fault may
> -		 * want to allocate huge page, and if we expose page table
> +		 * Leave __pte_alloc_try_get() until later: because vm_ops->fault
> +		 * may want to allocate huge page, and if we expose page table
>   		 * for an instant, it will be difficult to retract from
>   		 * concurrent faults and from rmap lookups.
>   		 */
> @@ -4517,9 +4573,18 @@ static vm_fault_t handle_pte_fault(struct vm_fault *vmf)
>   		 * that it is a regular pmd that we can walk with
>   		 * pte_offset_map() and we can do that through an atomic read
>   		 * in C, which is what pmd_trans_unstable() provides.
> +		 *
> +		 * Note: we do this in pte_try_get() when CONFIG_FREE_USER_PTE
>   		 */
>   		if (pmd_devmap_trans_unstable(vmf->pmd))
>   			return 0;
> +
> +		if (!pte_try_get(vmf->vma->vm_mm, vmf->pmd))
> +			goto retry;
> +
> +		if (IS_ENABLED(CONFIG_FREE_USER_PTE))
> +			vmf->flags |= FAULT_FLAG_PTE_GET;
> +
>   		/*
>   		 * A regular pmd is established and it can't morph into a huge
>   		 * pmd from under us anymore at this point because we hold the
> @@ -4541,6 +4606,7 @@ static vm_fault_t handle_pte_fault(struct vm_fault *vmf)
>   		if (pte_none(vmf->orig_pte)) {
>   			pte_unmap(vmf->pte);
>   			vmf->pte = NULL;
> +			pte_put_vmf(vmf);
>   		}
>   	}
>   
> @@ -4551,11 +4617,15 @@ static vm_fault_t handle_pte_fault(struct vm_fault *vmf)
>   			return do_fault(vmf);
>   	}
>   
> -	if (!pte_present(vmf->orig_pte))
> -		return do_swap_page(vmf);
> +	if (!pte_present(vmf->orig_pte)) {
> +		ret = do_swap_page(vmf);
> +		goto put;
> +	}
>   
> -	if (pte_protnone(vmf->orig_pte) && vma_is_accessible(vmf->vma))
> -		return do_numa_page(vmf);
> +	if (pte_protnone(vmf->orig_pte) && vma_is_accessible(vmf->vma)) {
> +		ret = do_numa_page(vmf);
> +		goto put;
> +	}
>   
>   	vmf->ptl = pte_lockptr(vmf->vma->vm_mm, vmf->pmd);
>   	spin_lock(vmf->ptl);
> @@ -4565,8 +4635,10 @@ static vm_fault_t handle_pte_fault(struct vm_fault *vmf)
>   		goto unlock;
>   	}
>   	if (vmf->flags & FAULT_FLAG_WRITE) {
> -		if (!pte_write(entry))
> -			return do_wp_page(vmf);
> +		if (!pte_write(entry)) {
> +			ret = do_wp_page(vmf);
> +			goto put;
> +		}
>   		entry = pte_mkdirty(entry);
>   	}
>   	entry = pte_mkyoung(entry);
> @@ -4588,7 +4660,10 @@ static vm_fault_t handle_pte_fault(struct vm_fault *vmf)
>   	}
>   unlock:
>   	pte_unmap_unlock(vmf->pte, vmf->ptl);
> -	return 0;
> +	ret = 0;
> +put:
> +	pte_put_vmf(vmf);
> +	return ret;
>   }
>   
>   /*
> @@ -4816,7 +4891,7 @@ int __p4d_alloc(struct mm_struct *mm, pgd_t *pgd, unsigned long address)
>   	if (pgd_present(*pgd))		/* Another has populated it */
>   		p4d_free(mm, new);
>   	else {
> -		smp_wmb(); /* See comment in pte_install() */
> +		smp_wmb(); /* See comment in __pte_install() */
>   		pgd_populate(mm, pgd, new);
>   	}
>   	spin_unlock(&mm->page_table_lock);
> @@ -4838,7 +4913,7 @@ int __pud_alloc(struct mm_struct *mm, p4d_t *p4d, unsigned long address)
>   	spin_lock(&mm->page_table_lock);
>   	if (!p4d_present(*p4d)) {
>   		mm_inc_nr_puds(mm);
> -		smp_wmb(); /* See comment in pte_install() */
> +		smp_wmb(); /* See comment in __pte_install() */
>   		p4d_populate(mm, p4d, new);
>   	} else	/* Another has populated it */
>   		pud_free(mm, new);
> @@ -4862,7 +4937,7 @@ int __pmd_alloc(struct mm_struct *mm, pud_t *pud, unsigned long address)
>   	ptl = pud_lock(mm, pud);
>   	if (!pud_present(*pud)) {
>   		mm_inc_nr_pmds(mm);
> -		smp_wmb(); /* See comment in pte_install() */
> +		smp_wmb(); /* See comment in __pte_install() */
>   		pud_populate(mm, pud, new);
>   	} else	/* Another has populated it */
>   		pmd_free(mm, new);
> @@ -4925,13 +5000,22 @@ int follow_invalidate_pte(struct mm_struct *mm, unsigned long address,
>   					(address & PAGE_MASK) + PAGE_SIZE);
>   		mmu_notifier_invalidate_range_start(range);
>   	}
> +	if (!pte_try_get(mm, pmd))
> +		goto out;
>   	ptep = pte_offset_map_lock(mm, pmd, address, ptlp);
>   	if (!pte_present(*ptep))
>   		goto unlock;
> +	/*
> +	 * when we reach here, it means that the ->pte_refcount is at least
> +	 * one and the contents of the PTE page table are stable until @ptlp is
> +	 * released, so we can put pte safely.
> +	 */
> +	pte_put(mm, pmd, address);
>   	*ptepp = ptep;
>   	return 0;
>   unlock:
>   	pte_unmap_unlock(ptep, *ptlp);
> +	pte_put(mm, pmd, address);
>   	if (range)
>   		mmu_notifier_invalidate_range_end(range);
>   out:
> @@ -5058,6 +5142,7 @@ int generic_access_phys(struct vm_area_struct *vma, unsigned long addr,
>   		return -EINVAL;
>   	pte = *ptep;
>   	pte_unmap_unlock(ptep, ptl);
> +	pte_put(vma->vm_mm, pte_to_pmd(ptep), addr);
>   
>   	prot = pgprot_val(pte_pgprot(pte));
>   	phys_addr = (resource_size_t)pte_pfn(pte) << PAGE_SHIFT;
> diff --git a/mm/mempolicy.c b/mm/mempolicy.c
> index e32360e90274..cbb3640717ff 100644
> --- a/mm/mempolicy.c
> +++ b/mm/mempolicy.c
> @@ -509,6 +509,7 @@ static int queue_pages_pte_range(pmd_t *pmd, unsigned long addr,
>   	bool has_unmovable = false;
>   	pte_t *pte, *mapped_pte;
>   	spinlock_t *ptl;
> +	unsigned long start = addr;
>   
>   	ptl = pmd_trans_huge_lock(pmd, vma);
>   	if (ptl) {
> @@ -518,7 +519,8 @@ static int queue_pages_pte_range(pmd_t *pmd, unsigned long addr,
>   	}
>   	/* THP was split, fall through to pte walk */
>   
> -	if (pmd_trans_unstable(pmd))
> +	if ((!IS_ENABLED(CONFIG_FREE_USER_PTE) && pmd_trans_unstable(pmd)) ||
> +	    !pte_try_get(walk->mm, pmd))
>   		return 0;
>   
>   	mapped_pte = pte = pte_offset_map_lock(walk->mm, pmd, addr, &ptl);
> @@ -554,6 +556,7 @@ static int queue_pages_pte_range(pmd_t *pmd, unsigned long addr,
>   			break;
>   	}
>   	pte_unmap_unlock(mapped_pte, ptl);
> +	pte_put(walk->mm, pmd, start);
>   	cond_resched();
>   
>   	if (has_unmovable)
> diff --git a/mm/migrate.c b/mm/migrate.c
> index 23cbd9de030b..6a94e8558b2c 100644
> --- a/mm/migrate.c
> +++ b/mm/migrate.c
> @@ -2265,6 +2265,8 @@ static int migrate_vma_collect_pmd(pmd_t *pmdp,
>   	if (unlikely(pmd_bad(*pmdp)))
>   		return migrate_vma_collect_skip(start, end, walk);
>   
> +	if (!pte_try_get(mm, pmdp))
> +		goto again;
>   	ptep = pte_offset_map_lock(mm, pmdp, addr, &ptl);
>   	arch_enter_lazy_mmu_mode();
>   
> @@ -2386,6 +2388,7 @@ static int migrate_vma_collect_pmd(pmd_t *pmdp,
>   	}
>   	arch_leave_lazy_mmu_mode();
>   	pte_unmap_unlock(ptep - 1, ptl);
> +	pte_put(mm, pmdp, start);
>   
>   	/* Only flush the TLB if we actually modified any entries */
>   	if (unmapped)
> @@ -2793,26 +2796,22 @@ static void migrate_vma_insert_page(struct migrate_vma *migrate,
>   		goto abort;
>   
>   	/*
> -	 * Use pte_alloc() instead of pte_alloc_map().  We can't run
> +	 * Use pte_alloc_try_get() instead of pte_alloc_get_map().  We can't run
>   	 * pte_offset_map() on pmds where a huge pmd might be created
>   	 * from a different thread.
>   	 *
> -	 * pte_alloc_map() is safe to use under mmap_write_lock(mm) or when
> +	 * pte_alloc_get_map() is safe to use under mmap_write_lock(mm) or when
>   	 * parallel threads are excluded by other means.
>   	 *
>   	 * Here we only have mmap_read_lock(mm).
>   	 */
> -	if (pte_alloc(mm, pmdp))
> -		goto abort;
> -
> -	/* See the comment in pte_alloc_one_map() */
> -	if (unlikely(pmd_trans_unstable(pmdp)))
> +	if (pte_alloc_try_get(mm, pmdp) <= 0)
>   		goto abort;
>   
>   	if (unlikely(anon_vma_prepare(vma)))
> -		goto abort;
> +		goto put;
>   	if (mem_cgroup_charge(page, vma->vm_mm, GFP_KERNEL))
> -		goto abort;
> +		goto put;
>   
>   	/*
>   	 * The memory barrier inside __SetPageUptodate makes sure that
> @@ -2881,15 +2880,19 @@ static void migrate_vma_insert_page(struct migrate_vma *migrate,
>   	} else {
>   		/* No need to invalidate - it was non-present before */
>   		set_pte_at(mm, addr, ptep, entry);
> +		pte_get(pmdp);
>   		update_mmu_cache(vma, addr, ptep);
>   	}
>   
>   	pte_unmap_unlock(ptep, ptl);
> +	pte_put(mm, pmdp, addr);
>   	*src = MIGRATE_PFN_MIGRATE;
>   	return;
>   
>   unlock_abort:
>   	pte_unmap_unlock(ptep, ptl);
> +put:
> +	pte_put(mm, pmdp, addr);
>   abort:
>   	*src &= ~MIGRATE_PFN_MIGRATE;
>   }
> diff --git a/mm/mincore.c b/mm/mincore.c
> index 9122676b54d6..e21e271a7657 100644
> --- a/mm/mincore.c
> +++ b/mm/mincore.c
> @@ -18,6 +18,7 @@
>   #include <linux/shmem_fs.h>
>   #include <linux/hugetlb.h>
>   #include <linux/pgtable.h>
> +#include <linux/pte_ref.h>
>   
>   #include <linux/uaccess.h>
>   
> @@ -104,6 +105,7 @@ static int mincore_pte_range(pmd_t *pmd, unsigned long addr, unsigned long end,
>   	pte_t *ptep;
>   	unsigned char *vec = walk->private;
>   	int nr = (end - addr) >> PAGE_SHIFT;
> +	unsigned long start = addr;
>   
>   	ptl = pmd_trans_huge_lock(pmd, vma);
>   	if (ptl) {
> @@ -112,7 +114,8 @@ static int mincore_pte_range(pmd_t *pmd, unsigned long addr, unsigned long end,
>   		goto out;
>   	}
>   
> -	if (pmd_trans_unstable(pmd)) {
> +	if ((!IS_ENABLED(CONFIG_FREE_USER_PTE) && pmd_trans_unstable(pmd)) ||
> +	    !pte_try_get(walk->mm, pmd)) {
>   		__mincore_unmapped_range(addr, end, vma, vec);
>   		goto out;
>   	}
> @@ -148,6 +151,7 @@ static int mincore_pte_range(pmd_t *pmd, unsigned long addr, unsigned long end,
>   		vec++;
>   	}
>   	pte_unmap_unlock(ptep - 1, ptl);
> +	pte_put(walk->mm, pmd, start);
>   out:
>   	walk->private += nr;
>   	cond_resched();
> diff --git a/mm/mlock.c b/mm/mlock.c
> index 16d2ee160d43..ead14abb016a 100644
> --- a/mm/mlock.c
> +++ b/mm/mlock.c
> @@ -397,6 +397,7 @@ static unsigned long __munlock_pagevec_fill(struct pagevec *pvec,
>   			break;
>   	}
>   	pte_unmap_unlock(pte, ptl);
> +	pte_put(vma->vm_mm, pte_to_pmd(pte), start);
>   	return start;
>   }
>   
> diff --git a/mm/mprotect.c b/mm/mprotect.c
> index 4cb240fd9936..9cbd0848c5c5 100644
> --- a/mm/mprotect.c
> +++ b/mm/mprotect.c
> @@ -274,9 +274,12 @@ static inline unsigned long change_pmd_range(struct vm_area_struct *vma,
>   	pmd = pmd_offset(pud, addr);
>   	do {
>   		unsigned long this_pages;
> +		pmd_t pmdval;
>   
>   		next = pmd_addr_end(addr, end);
>   
> +retry:
> +		pmdval = READ_ONCE(*pmd);
>   		/*
>   		 * Automatic NUMA balancing walks the tables with mmap_lock
>   		 * held for read. It's possible a parallel update to occur
> @@ -285,7 +288,7 @@ static inline unsigned long change_pmd_range(struct vm_area_struct *vma,
>   		 * Hence, it's necessary to atomically read the PMD value
>   		 * for all the checks.
>   		 */
> -		if (!is_swap_pmd(*pmd) && !pmd_devmap(*pmd) &&
> +		if (!is_swap_pmd(pmdval) && !pmd_devmap(pmdval) &&
>   		     pmd_none_or_clear_bad_unless_trans_huge(pmd))
>   			goto next;
>   
> @@ -297,7 +300,7 @@ static inline unsigned long change_pmd_range(struct vm_area_struct *vma,
>   			mmu_notifier_invalidate_range_start(&range);
>   		}
>   
> -		if (is_swap_pmd(*pmd) || pmd_trans_huge(*pmd) || pmd_devmap(*pmd)) {
> +		if (is_swap_pmd(pmdval) || pmd_trans_huge(pmdval) || pmd_devmap(pmdval)) {
>   			if (next - addr != HPAGE_PMD_SIZE) {
>   				__split_huge_pmd(vma, pmd, addr, false, NULL);
>   			} else {
> @@ -316,8 +319,11 @@ static inline unsigned long change_pmd_range(struct vm_area_struct *vma,
>   			}
>   			/* fall through, the trans huge pmd just split */
>   		}
> +		if (!pte_try_get(vma->vm_mm, pmd))
> +			goto retry;
>   		this_pages = change_pte_range(vma, pmd, addr, next, newprot,
>   					      cp_flags);
> +		pte_put(vma->vm_mm, pmd, addr);
>   		pages += this_pages;
>   next:
>   		cond_resched();
> diff --git a/mm/mremap.c b/mm/mremap.c
> index 5989d3990020..776c6ea7bd06 100644
> --- a/mm/mremap.c
> +++ b/mm/mremap.c
> @@ -141,6 +141,9 @@ static void move_ptes(struct vm_area_struct *vma, pmd_t *old_pmd,
>   	spinlock_t *old_ptl, *new_ptl;
>   	bool force_flush = false;
>   	unsigned long len = old_end - old_addr;
> +	unsigned long old_start = old_addr;
> +	unsigned int nr_put = 0;
> +	unsigned int nr_get = 0;
>   
>   	/*
>   	 * When need_rmap_locks is true, we take the i_mmap_rwsem and anon_vma
> @@ -181,6 +184,7 @@ static void move_ptes(struct vm_area_struct *vma, pmd_t *old_pmd,
>   			continue;
>   
>   		pte = ptep_get_and_clear(mm, old_addr, old_pte);
> +		nr_put++;
>   		/*
>   		 * If we are remapping a valid PTE, make sure
>   		 * to flush TLB before we drop the PTL for the
> @@ -197,7 +201,9 @@ static void move_ptes(struct vm_area_struct *vma, pmd_t *old_pmd,
>   		pte = move_pte(pte, new_vma->vm_page_prot, old_addr, new_addr);
>   		pte = move_soft_dirty_pte(pte);
>   		set_pte_at(mm, new_addr, new_pte, pte);
> +		nr_get++;
>   	}
> +	pte_get_many(new_pmd, nr_get);
>   
>   	arch_leave_lazy_mmu_mode();
>   	if (force_flush)
> @@ -206,6 +212,8 @@ static void move_ptes(struct vm_area_struct *vma, pmd_t *old_pmd,
>   		spin_unlock(new_ptl);
>   	pte_unmap(new_pte - 1);
>   	pte_unmap_unlock(old_pte - 1, old_ptl);
> +	if (nr_put)
> +		pte_put_many(mm, old_pmd, old_start, nr_put);
>   	if (need_rmap_locks)
>   		drop_rmap_locks(vma);
>   }
> @@ -271,6 +279,7 @@ static bool move_normal_pmd(struct vm_area_struct *vma, unsigned long old_addr,
>   	VM_BUG_ON(!pmd_none(*new_pmd));
>   
>   	pmd_populate(mm, new_pmd, pmd_pgtable(pmd));
> +	pte_migrate_pmd(pmd, new_pmd);
>   	flush_tlb_range(vma, old_addr, old_addr + PMD_SIZE);
>   	if (new_ptl != old_ptl)
>   		spin_unlock(new_ptl);
> @@ -548,10 +557,11 @@ unsigned long move_page_tables(struct vm_area_struct *vma,
>   				continue;
>   		}
>   
> -		if (pte_alloc(new_vma->vm_mm, new_pmd))
> +		if (pte_alloc_get(new_vma->vm_mm, new_pmd))
>   			break;
>   		move_ptes(vma, old_pmd, old_addr, old_addr + extent, new_vma,
>   			  new_pmd, new_addr, need_rmap_locks);
> +		pte_put(new_vma->vm_mm, new_pmd, new_addr);
>   	}
>   
>   	mmu_notifier_invalidate_range_end(&range);
> diff --git a/mm/page_vma_mapped.c b/mm/page_vma_mapped.c
> index f7b331081791..eb84fa5825c0 100644
> --- a/mm/page_vma_mapped.c
> +++ b/mm/page_vma_mapped.c
> @@ -211,6 +211,7 @@ bool page_vma_mapped_walk(struct page_vma_mapped_walk *pvmw)
>   		}
>   
>   		pvmw->pmd = pmd_offset(pud, pvmw->address);
> +retry:
>   		/*
>   		 * Make sure the pmd value isn't cached in a register by the
>   		 * compiler and used as a stale value after we've observed a
> @@ -258,6 +259,8 @@ bool page_vma_mapped_walk(struct page_vma_mapped_walk *pvmw)
>   			step_forward(pvmw, PMD_SIZE);
>   			continue;
>   		}
> +		if (!pte_try_get(pvmw->vma->vm_mm, pvmw->pmd))
> +			goto retry;
>   		if (!map_pte(pvmw))
>   			goto next_pte;
>   this_pte:
> @@ -275,6 +278,7 @@ bool page_vma_mapped_walk(struct page_vma_mapped_walk *pvmw)
>   					pvmw->ptl = NULL;
>   				}
>   				pte_unmap(pvmw->pte);
> +				pte_put(pvmw->vma->vm_mm, pvmw->pmd, pvmw->address);
>   				pvmw->pte = NULL;
>   				goto restart;
>   			}
> diff --git a/mm/pagewalk.c b/mm/pagewalk.c
> index 9b3db11a4d1d..4080a88d7852 100644
> --- a/mm/pagewalk.c
> +++ b/mm/pagewalk.c
> @@ -3,6 +3,7 @@
>   #include <linux/highmem.h>
>   #include <linux/sched.h>
>   #include <linux/hugetlb.h>
> +#include <linux/pte_ref.h>
>   
>   /*
>    * We want to know the real level where a entry is located ignoring any
> @@ -108,9 +109,9 @@ static int walk_pmd_range(pud_t *pud, unsigned long addr, unsigned long end,
>   
>   	pmd = pmd_offset(pud, addr);
>   	do {
> -again:
>   		next = pmd_addr_end(addr, end);
> -		if (pmd_none(*pmd) || (!walk->vma && !walk->no_vma)) {
> +again:
> +		if (pmd_none(READ_ONCE(*pmd)) || (!walk->vma && !walk->no_vma)) {
>   			if (ops->pte_hole)
>   				err = ops->pte_hole(addr, next, depth, walk);
>   			if (err)
> @@ -147,10 +148,18 @@ static int walk_pmd_range(pud_t *pud, unsigned long addr, unsigned long end,
>   				goto again;
>   		}
>   
> -		if (is_hugepd(__hugepd(pmd_val(*pmd))))
> +		if (is_hugepd(__hugepd(pmd_val(*pmd)))) {
>   			err = walk_hugepd_range((hugepd_t *)pmd, addr, next, walk, PMD_SHIFT);
> -		else
> -			err = walk_pte_range(pmd, addr, next, walk);
> +		} else {
> +			if (!walk->no_vma) {
> +				if (!pte_try_get(walk->mm, pmd))
> +					goto again;
> +				err = walk_pte_range(pmd, addr, next, walk);
> +				pte_put(walk->mm, pmd, addr);
> +			} else {
> +				err = walk_pte_range(pmd, addr, next, walk);
> +			}
> +		}
>   		if (err)
>   			break;
>   	} while (pmd++, addr = next, addr != end);
> diff --git a/mm/pgtable-generic.c b/mm/pgtable-generic.c
> index 4e640baf9794..f935779a0967 100644
> --- a/mm/pgtable-generic.c
> +++ b/mm/pgtable-generic.c
> @@ -11,6 +11,7 @@
>   #include <linux/hugetlb.h>
>   #include <linux/pgtable.h>
>   #include <asm/tlb.h>
> +#include <linux/pte_ref.h>
>   
>   /*
>    * If a p?d_bad entry is found while walking page tables, report
> @@ -186,6 +187,7 @@ pgtable_t pgtable_trans_huge_withdraw(struct mm_struct *mm, pmd_t *pmdp)
>   							  struct page, lru);
>   	if (pmd_huge_pte(mm, pmdp))
>   		list_del(&pgtable->lru);
> +	pte_ref_init(pgtable, pmdp, HPAGE_PMD_NR);
>   	return pgtable;
>   }
>   #endif
> diff --git a/mm/pte_ref.c b/mm/pte_ref.c
> new file mode 100644
> index 000000000000..1b8d9828d513
> --- /dev/null
> +++ b/mm/pte_ref.c
> @@ -0,0 +1,132 @@
> +// SPDX-License-Identifier: GPL-2.0
> +/*
> + * Free user PTE page table pages
> + *
> + * Copyright (c) 2021, ByteDance. All rights reserved.
> + *
> + * 	Author: Qi Zheng <zhengqi.arch@...edance.com>
> + */
> +
> +#include <linux/pte_ref.h>
> +#include <linux/hugetlb.h>
> +#include <asm/tlbflush.h>
> +
> +#ifdef CONFIG_DEBUG_VM
> +static void pte_free_debug(pmd_t pmd)
> +{
> +	pte_t *ptep = (pte_t *)pmd_page_vaddr(pmd);
> +	int i = 0;
> +
> +	for (i = 0; i < PTRS_PER_PTE; i++)
> +		BUG_ON(!pte_none(*ptep++));
> +}
> +#else
> +static inline void pte_free_debug(pmd_t pmd)
> +{
> +}
> +#endif
> +
> +void free_pte_table(struct mm_struct *mm, pmd_t *pmdp, unsigned long addr)
> +{
> +	struct vm_area_struct vma = TLB_FLUSH_VMA(mm, 0);
> +	spinlock_t *ptl;
> +	pmd_t pmd;
> +
> +	ptl = pmd_lock(mm, pmdp);
> +	pmd = pmdp_huge_get_and_clear(mm, addr, pmdp);
> +	spin_unlock(ptl);
> +
> +	pte_free_debug(pmd);
> +	flush_tlb_range(&vma, addr, addr + PMD_SIZE);
> +	mm_dec_nr_ptes(mm);
> +	pte_free(mm, pmd_pgtable(pmd));
> +}
> +
> +static inline void __pte_install(struct mm_struct *mm, pmd_t *pmd, pgtable_t *pte)
> +{
> +	mm_inc_nr_ptes(mm);
> +	/*
> +	 * Ensure all pte setup (eg. pte page lock and page clearing) are
> +	 * visible before the pte is made visible to other CPUs by being
> +	 * put into page tables.
> +	 *
> +	 * The other side of the story is the pointer chasing in the page
> +	 * table walking code (when walking the page table without locking;
> +	 * ie. most of the time). Fortunately, these data accesses consist
> +	 * of a chain of data-dependent loads, meaning most CPUs (alpha
> +	 * being the notable exception) will already guarantee loads are
> +	 * seen in-order. See the alpha page table accessors for the
> +	 * smp_rmb() barriers in page table walking code.
> +	 */
> +	smp_wmb(); /* Could be smp_wmb__xxx(before|after)_spin_lock */
> +	pmd_populate(mm, pmd, *pte);
> +	pte_ref_init(*pte, pmd, 1);
> +	*pte = NULL;
> +}
> +
> +/*
> + * returns true if the pmd has been populated with PTE page table,
> + * or false for all other cases.
> + */
> +bool pte_install_try_get(struct mm_struct *mm, pmd_t *pmd, pgtable_t *pte)
> +{
> +	spinlock_t *ptl;
> +	bool retval = true;
> +
> +retry:
> +	ptl = pmd_lock(mm, pmd);
> +	if (likely(pmd_none(*pmd))) {
> +		__pte_install(mm, pmd, pte);
> +	} else if (pmd_leaf(*pmd) || !pmd_present(*pmd)) {
> +		retval = false;
> +	} else if (!pte_get_unless_zero(pmd)) {
> +		spin_unlock(ptl);
> +		goto retry;
> +	}
> +	spin_unlock(ptl);
> +	return retval;
> +}
> +

Can pte_get_unless_zero() return true above? Could the pmd have been populated by someone else in the meantime? In that case the refcount would be wrongly incremented.



> +static void pte_install_get(struct mm_struct *mm, pmd_t *pmd, pgtable_t *pte)
> +{
> +	spinlock_t *ptl;
> +
> +retry:
> +	ptl = pmd_lock(mm, pmd);
> +	if (likely(pmd_none(*pmd))) {
> +		__pte_install(mm, pmd, pte);
> +	} else if (!pte_get_unless_zero(pmd)) {
> +		spin_unlock(ptl);
> +		goto retry;
> +	}
> +	spin_unlock(ptl);
> +}
> +
> +/*
> + * returns -ENOMEM if memory allocation failed, or 1 if the pmd
> + * has been populated with PTE page table, or 0 for all other cases.
> + */
> +int __pte_alloc_try_get(struct mm_struct *mm, pmd_t *pmd)
> +{
> +	int retval;
> +	pgtable_t new = pte_alloc_one(mm);
> +	if (!new)
> +		return -ENOMEM;
> +
> +	retval = pte_install_try_get(mm, pmd, &new);
> +	if (new)
> +		pte_free(mm, new);
> +	return retval;
> +}
> +
> +int __pte_alloc_get(struct mm_struct *mm, pmd_t *pmd)
> +{
> +	pgtable_t new = pte_alloc_one(mm);
> +	if (!new)
> +		return -ENOMEM;
> +
> +	pte_install_get(mm, pmd, &new);
> +	if (new)
> +		pte_free(mm, new);
> +	return 0;
> +}
> diff --git a/mm/rmap.c b/mm/rmap.c
> index fed7c4df25f2..8c10dbca02d4 100644
> --- a/mm/rmap.c
> +++ b/mm/rmap.c
> @@ -1402,6 +1402,7 @@ static bool try_to_unmap_one(struct page *page, struct vm_area_struct *vma,
>   	bool ret = true;
>   	struct mmu_notifier_range range;
>   	enum ttu_flags flags = (enum ttu_flags)(long)arg;
> +	unsigned int nr_put = 0;
>   
>   	/*
>   	 * When racing against e.g. zap_pte_range() on another cpu,
> @@ -1551,6 +1552,7 @@ static bool try_to_unmap_one(struct page *page, struct vm_area_struct *vma,
>   			/* We have to invalidate as we cleared the pte */
>   			mmu_notifier_invalidate_range(mm, address,
>   						      address + PAGE_SIZE);
> +			nr_put++;
>   		} else if (PageAnon(page)) {
>   			swp_entry_t entry = { .val = page_private(subpage) };
>   			pte_t swp_pte;
> @@ -1564,6 +1566,7 @@ static bool try_to_unmap_one(struct page *page, struct vm_area_struct *vma,
>   				/* We have to invalidate as we cleared the pte */
>   				mmu_notifier_invalidate_range(mm, address,
>   							address + PAGE_SIZE);
> +				nr_put++;
>   				page_vma_mapped_walk_done(&pvmw);
>   				break;
>   			}
> @@ -1575,6 +1578,7 @@ static bool try_to_unmap_one(struct page *page, struct vm_area_struct *vma,
>   					mmu_notifier_invalidate_range(mm,
>   						address, address + PAGE_SIZE);
>   					dec_mm_counter(mm, MM_ANONPAGES);
> +					nr_put++;
>   					goto discard;
>   				}
>   
> @@ -1630,6 +1634,7 @@ static bool try_to_unmap_one(struct page *page, struct vm_area_struct *vma,
>   			 * See Documentation/vm/mmu_notifier.rst
>   			 */
>   			dec_mm_counter(mm, mm_counter_file(page));
> +			nr_put++;
>   		}
>   discard:
>   		/*
> @@ -1644,7 +1649,8 @@ static bool try_to_unmap_one(struct page *page, struct vm_area_struct *vma,
>   	}
>   
>   	mmu_notifier_invalidate_range_end(&range);
> -
> +	if (nr_put)
> +		pte_put_many(mm, pvmw.pmd, address, nr_put);
>   	return ret;
>   }
>   
> @@ -1705,6 +1711,7 @@ static bool try_to_migrate_one(struct page *page, struct vm_area_struct *vma,
>   	bool ret = true;
>   	struct mmu_notifier_range range;
>   	enum ttu_flags flags = (enum ttu_flags)(long)arg;
> +	unsigned int nr_put = 0;
>   
>   	if (is_zone_device_page(page) && !is_device_private_page(page))
>   		return true;
> @@ -1871,6 +1878,7 @@ static bool try_to_migrate_one(struct page *page, struct vm_area_struct *vma,
>   			/* We have to invalidate as we cleared the pte */
>   			mmu_notifier_invalidate_range(mm, address,
>   						      address + PAGE_SIZE);
> +			nr_put++;
>   		} else {
>   			swp_entry_t entry;
>   			pte_t swp_pte;
> @@ -1919,6 +1927,9 @@ static bool try_to_migrate_one(struct page *page, struct vm_area_struct *vma,
>   
>   	mmu_notifier_invalidate_range_end(&range);
>   
> +	if (nr_put)
> +		pte_put_many(mm, pvmw.pmd, address, nr_put);
> +
>   	return ret;
>   }
>   
> diff --git a/mm/swapfile.c b/mm/swapfile.c
> index 1e07d1c776f2..6153283be500 100644
> --- a/mm/swapfile.c
> +++ b/mm/swapfile.c
> @@ -40,6 +40,7 @@
>   #include <linux/swap_slots.h>
>   #include <linux/sort.h>
>   #include <linux/completion.h>
> +#include <linux/pte_ref.h>
>   
>   #include <asm/tlbflush.h>
>   #include <linux/swapops.h>
> @@ -2021,10 +2022,13 @@ static inline int unuse_pmd_range(struct vm_area_struct *vma, pud_t *pud,
>   	do {
>   		cond_resched();
>   		next = pmd_addr_end(addr, end);
> -		if (pmd_none_or_trans_huge_or_clear_bad(pmd))
> +		if ((!IS_ENABLED(CONFIG_FREE_USER_PTE) &&
> +		    pmd_none_or_trans_huge_or_clear_bad(pmd)) ||
> +		    !pte_try_get(vma->vm_mm, pmd))
>   			continue;
>   		ret = unuse_pte_range(vma, pmd, addr, next, type,
>   				      frontswap, fs_pages_to_unuse);
> +		pte_put(vma->vm_mm, pmd, addr);
>   		if (ret)
>   			return ret;
>   	} while (pmd++, addr = next, addr != end);
> diff --git a/mm/userfaultfd.c b/mm/userfaultfd.c
> index 0e2132834bc7..7ebf4fb09a85 100644
> --- a/mm/userfaultfd.c
> +++ b/mm/userfaultfd.c
> @@ -111,6 +111,7 @@ int mfill_atomic_install_pte(struct mm_struct *dst_mm, pmd_t *dst_pmd,
>   		lru_cache_add_inactive_or_unevictable(page, dst_vma);
>   
>   	set_pte_at(dst_mm, dst_addr, dst_pte, _dst_pte);
> +	pte_get(dst_pmd);
>   
>   	/* No need to invalidate - it was non-present before */
>   	update_mmu_cache(dst_vma, dst_addr, dst_pte);
> @@ -205,6 +206,7 @@ static int mfill_zeropage_pte(struct mm_struct *dst_mm,
>   	if (!pte_none(*dst_pte))
>   		goto out_unlock;
>   	set_pte_at(dst_mm, dst_addr, dst_pte, _dst_pte);
> +	pte_get(dst_pmd);
>   	/* No need to invalidate - it was non-present before */
>   	update_mmu_cache(dst_vma, dst_addr, dst_pte);
>   	ret = 0;
> @@ -570,6 +572,7 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
>   
>   	while (src_addr < src_start + len) {
>   		pmd_t dst_pmdval;
> +		int ret = 1;
>   
>   		BUG_ON(dst_addr >= dst_start + len);
>   
> @@ -588,13 +591,14 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
>   			err = -EEXIST;
>   			break;
>   		}
> -		if (unlikely(pmd_none(dst_pmdval)) &&
> -		    unlikely(__pte_alloc(dst_mm, dst_pmd))) {
> +
> +		if ((IS_ENABLED(CONFIG_FREE_USER_PTE) &&
> +		     unlikely((ret = pte_alloc_try_get(dst_mm, dst_pmd)) < 0)) ||
> +		    (unlikely(pmd_none(dst_pmdval)) &&
> +		     unlikely((ret = __pte_alloc_try_get(dst_mm, dst_pmd)) < 0))) {
>   			err = -ENOMEM;
>   			break;
> -		}
> -		/* If an huge pmd materialized from under us fail */
> -		if (unlikely(pmd_trans_huge(*dst_pmd))) {
> +		} else if (!ret || unlikely(pmd_trans_huge(*dst_pmd))) {
>   			err = -EFAULT;
>   			break;
>   		}
> @@ -604,6 +608,7 @@ static __always_inline ssize_t __mcopy_atomic(struct mm_struct *dst_mm,
>   
>   		err = mfill_atomic_pte(dst_mm, dst_pmd, dst_vma, dst_addr,
>   				       src_addr, &page, mcopy_mode, wp_copy);
> +		pte_put(dst_mm, dst_pmd, dst_addr);
>   		cond_resched();
>   
>   		if (unlikely(err == -ENOENT)) {

Powered by blists - more mailing lists

Powered by Openwall GNU/*/Linux Powered by OpenVZ