---

 b/arch/x86/include/asm/pgtable.h |    6 ++
 b/include/linux/sched.h          |    1 
 b/kernel/fork.c                  |   81 +++++++++++++++++++++++++++++++++++++++
 3 files changed, 87 insertions(+), 1 deletion(-)

diff -puN kernel/fork.c~get_task_mm-locking kernel/fork.c
--- a/kernel/fork.c~get_task_mm-locking	2023-05-15 08:55:59.168918971 -0700
+++ b/kernel/fork.c	2023-05-15 14:17:34.434547231 -0700
@@ -661,6 +661,12 @@ static __latent_entropy int dup_mmap(str
 		retval = -EINTR;
 		goto fail_uprobe_end;
 	}
+
+	if ((atomic_read(&oldmm->mm_users) == 1) &&
+	    (atomic_read(&oldmm->mm_count) == 1))
+		current->task_doing_fast_fork = 1;
+
+
 	flush_cache_dup_mm(oldmm);
 	uprobe_dup_mmap(oldmm, mm);
 	/*
@@ -774,6 +780,7 @@ loop_out:
 out:
 	mmap_write_unlock(mm);
 	flush_tlb_mm(oldmm);
+	current->task_doing_fast_fork = 0;
 	mmap_write_unlock(oldmm);
 	dup_userfaultfd_complete(&uf);
 fail_uprobe_end:
@@ -1527,6 +1534,46 @@ struct file *get_task_exe_file(struct ta
 	return exe_file;
 }
 
+/*
+ * A "fast fork()" task might be taking shortcuts that make
+ * the mm's address space unstable for multithreaded access.
+ * The mm can't be considered stable until holding
+ * mmap_write_lock() to ensure that the fork() is finished.
+ *
+ * This function is a "maybe" because it can "return true"
+ * for many reasons other than an mm that is doing a fast
+ * fork.  But it should be right enough of the time to keep
+ * callers using their fast paths the majority of the time.
+ */
+static bool mm_maybe_fast_fork(struct mm_struct *mm)
+{
+	/* Fast fork() won't start if ->mm_users is elevated */
+	if (atomic_read(&mm->mm_users) != 1)
+		return false;
+
+	/*
+	 * Some users, like proc_mem_operations want to avoid pinning
+	 * the address space with ->mm_users.  They instead elevate
+	 * ->mm_count and then (temporarily) "upgrade" that ->mm_count
+	 * ref to an ->mm_users ref with mmget_not_zero().
+	 *
+	 * Check for ->mm_count==1.  This ensures that no one will
+	 * later upgrade that ref with mmget_not_zero() and access
+	 * the page tables without the "slow" mmap_lock path.
+	 */
+	if (atomic_read(&mm->mm_count) != 1)
+		return false;
+
+	/*
+	 * A task can not be doing a fork() if the lock
+	 * is not held.
+	 */
+	if (!rwsem_is_locked(&mm->mmap_lock))
+		return false;
+
+	return true;
+}
+
 /**
  * get_task_mm - acquire a reference to the task's mm
  *
@@ -1543,6 +1590,8 @@ struct mm_struct *get_task_mm(struct tas
 	task_lock(task);
 	mm = task->mm;
 	if (mm) {
+		if (mm_maybe_fast_fork(mm))
+			goto slow;
 		if (task->flags & PF_KTHREAD)
 			mm = NULL;
 		else
@@ -1550,6 +1599,31 @@ struct mm_struct *get_task_mm(struct tas
 	}
 	task_unlock(task);
 	return mm;
+slow: {
+	struct mm_struct *ret_mm = mm;
+
+	mmgrab(mm);
+	task_unlock(task);
+
+	/*
+	 * Thanks to the mmgrab(), 'mm' itself is now stable.
+	 * 'task' might exit but can not free the mm.
+	 */
+
+	/* If a fork() was happening, wait for it to complete: */
+	mmap_write_lock(mm);
+	if (!mmget_not_zero(mm)) {
+		/*
+		 * The mm's address space has gone away.  Tell
+		 * the caller that the task's mm was unavailable:
+		 */
+		ret_mm = NULL;
+	}
+	mmap_write_unlock(mm);
+	mmdrop(mm);
+
+	return ret_mm;
+	}
 }
 EXPORT_SYMBOL_GPL(get_task_mm);
 
@@ -3561,3 +3635,10 @@ int sysctl_max_threads(struct ctl_table
 
 	return 0;
 }
+
+bool current_doing_fast_fork(void)
+{
+	WARN_ON(current->task_doing_fast_fork && (atomic_read(&current->mm->mm_users) > 1));
+	return current->task_doing_fast_fork;
+}
+
diff -puN include/linux/sched.h~get_task_mm-locking include/linux/sched.h
--- a/include/linux/sched.h~get_task_mm-locking	2023-05-15 09:32:57.393721962 -0700
+++ b/include/linux/sched.h	2023-05-15 09:33:31.240599499 -0700
@@ -762,6 +762,7 @@ struct task_struct {
 	/* Per task flags (PF_*), defined further below: */
 	unsigned int			flags;
 	unsigned int			ptrace;
+	bool				task_doing_fast_fork;
 
 #ifdef CONFIG_SMP
 	int				on_cpu;
diff -puN arch/x86/include/asm/pgtable.h~get_task_mm-locking arch/x86/include/asm/pgtable.h
--- a/arch/x86/include/asm/pgtable.h~get_task_mm-locking	2023-05-15 09:36:13.030361316 -0700
+++ b/arch/x86/include/asm/pgtable.h	2023-05-15 10:27:45.112652940 -0700
@@ -1094,7 +1094,11 @@ static inline pte_t ptep_get_and_clear_f
 static inline void ptep_set_wrprotect(struct mm_struct *mm,
 				      unsigned long addr, pte_t *ptep)
 {
-	clear_bit(_PAGE_BIT_RW, (unsigned long *)&ptep->pte);
+	extern bool current_doing_fast_fork(void);
+	if (current_doing_fast_fork())
+		__clear_bit(_PAGE_BIT_RW, (unsigned long *)&ptep->pte);
+	else
+		clear_bit(_PAGE_BIT_RW, (unsigned long *)&ptep->pte);
 }
 
 #define flush_tlb_fix_spurious_fault(vma, address, ptep) do { } while (0)
_
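
For context, a sketch of how a typical get_task_mm() caller consumes the result is below.  It is illustrative only: the helper name and the use of ->map_count are made up for the example and are not part of this patch.  The point is that callers keep the existing get_task_mm()/mmput() pattern; any waiting for an in-flight fast fork() happens inside get_task_mm() itself, behind mmap_write_lock().

#include <linux/errno.h>
#include <linux/mm.h>
#include <linux/sched/mm.h>

/* Hypothetical caller: count the VMAs in another task's mm. */
static int count_task_vmas(struct task_struct *task)
{
	struct mm_struct *mm = get_task_mm(task);
	int nr;

	if (!mm)
		return -ESRCH;	/* kthread, exited, or mm already gone */

	mmap_read_lock(mm);
	nr = mm->map_count;	/* address space is stable under mmap_lock */
	mmap_read_unlock(mm);

	mmput(mm);		/* drop the ->mm_users reference */
	return nr;
}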