Message-ID: <87h6rhw4i0.fsf@nvidia.com>
Date:   Fri, 09 Jun 2023 16:05:05 +1000
From:   Alistair Popple <apopple@...dia.com>
To:     Jason Gunthorpe <jgg@...dia.com>
Cc:     Sean Christopherson <seanjc@...gle.com>,
        Robin Murphy <robin.murphy@....com>,
        Andrew Morton <akpm@...ux-foundation.org>, will@...nel.org,
        catalin.marinas@....com, linux-mm@...ck.org,
        linux-kernel@...r.kernel.org, nicolinc@...dia.com,
        linux-arm-kernel@...ts.infradead.org, kvm@...r.kernel.org,
        John Hubbard <jhubbard@...dia.com>, zhi.wang.linux@...il.com
Subject: Re: [PATCH 2/2] arm64: Notify on pte permission upgrades


Alistair Popple <apopple@...dia.com> writes:

> Alistair Popple <apopple@...dia.com> writes:
>
>> Alistair Popple <apopple@...dia.com> writes:
>>
>>>> On Tue, May 30, 2023 at 02:44:40PM -0700, Sean Christopherson wrote:
>>>>> > KVM already has locking for invalidate_start/end - it has to check
>>>>> > mmu_notifier_retry_cache() with the sequence numbers/etc around when
>>>>> > it does hva_to_pfn()
>>>>> > 
>>>>> > The bug is that the kvm_vcpu_reload_apic_access_page() path is
>>>>> > ignoring this locking so it ignores in-progress range
>>>>> > invalidations. It should spin until the invalidation clears like other
>>>>> > places in KVM.
>>>>> > 
>>>>> > The comment is kind of misleading because drivers shouldn't be abusing
>>>>> > the iommu-centric invalidate_range() thing to fix missing locking in
>>>>> > start/end users. :\
>>>>> > 
>>>>> > So if KVM could be fixed up we could make invalidate_range defined to
>>>>> > be an arch specific callback to synchronize the iommu TLB.
>>>>> 
>>>>> And maybe rename invalidate_range() and/or invalidate_range_{start,end}() to make
>>>>> it super obvious that they are intended for two different purposes?  E.g. instead
>>>>> of invalidate_range(), something like invalidate_secondary_tlbs().
>>>>
>>>> Yeah, I think I would call it invalidate_arch_secondary_tlb() and
>>>> document it as being an arch specific set of invalidations that match
>>>> the architected TLB maintenance requirements. And maybe we can check it
>>>> more carefully to make it be called in fewer places. Like I'm not sure
>>>> it is right to call it from invalidate_range_end under this new
>>>> definition..
>>>
>>> I'd be happy to look at that, although it sounds like Sean already is.
>
> Thanks Sean for getting the KVM fix posted so quickly. I'm looking into
> doing the rename now.
>
> Do we want to do more than a simple rename and tidy-up of callers,
> though? What I'm thinking is introducing something like an IOMMU/TLB-
> specific variant of notifiers (e.g. struct tlb_notifier) which has the
> invalidate_secondary_tlbs() callback in, say, struct tlb_notifier_ops
> rather than leaving it in mmu_notifier_ops.
>
> Implementation-wise we'd reuse most of the mmu_notifier code, but this
> would help make the two different uses of notifiers clearer. Thoughts?

So, something like the incomplete patch below (against v6.2), which
introduces a new struct tlb_notifier and associated ops. The change isn't
huge, but it does result in some churn and another layer of indirection in
mmu_notifier.c. Otherwise we can just rename the callback.
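For illustration only (not part of the patch), a driver-side user of the
proposed interface might look roughly like the sketch below. The example_*
names and callback bodies are placeholders; only tlb_notifier,
tlb_notifier_ops, invalidate_secondary_tlbs(), tlb_notifier_register() and
tlb_notifier_put() come from the patch:

/* Hypothetical driver-side usage; example_* names are placeholders. */
static void example_invalidate_secondary_tlbs(struct tlb_notifier *tn,
					      struct mm_struct *mm,
					      unsigned long start,
					      unsigned long end)
{
	/* Flush the device/IOMMU TLB entries covering [start, end) here. */
}

static void example_free_notifier(struct tlb_notifier *tn)
{
	/* Called once the notifier may be freed; free the container here. */
}

static const struct tlb_notifier_ops example_tlb_ops = {
	.invalidate_secondary_tlbs	= example_invalidate_secondary_tlbs,
	.free_notifier			= example_free_notifier,
};

static int example_bind(struct tlb_notifier *tn, struct mm_struct *mm)
{
	tn->ops = &example_tlb_ops;
	/* Takes mmap_lock internally, unlike __mmu_notifier_register(). */
	return tlb_notifier_register(tn, mm);
}

static void example_unbind(struct tlb_notifier *tn)
{
	/* Drops the reference; freeing happens via SRCU -> free_notifier. */
	tlb_notifier_put(tn);
}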

--- 

diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
index a5a63b1c947e..c300cd435609 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-sva.c
@@ -14,7 +14,7 @@
 #include "../../io-pgtable-arm.h"
 
 struct arm_smmu_mmu_notifier {
-	struct mmu_notifier		mn;
+	struct tlb_notifier		mn;
 	struct arm_smmu_ctx_desc	*cd;
 	bool				cleared;
 	refcount_t			refs;
@@ -186,7 +186,7 @@ static void arm_smmu_free_shared_cd(struct arm_smmu_ctx_desc *cd)
 	}
 }
 
-static void arm_smmu_mm_invalidate_range(struct mmu_notifier *mn,
+static void arm_smmu_mm_invalidate_range(struct tlb_notifier *mn,
 					 struct mm_struct *mm,
 					 unsigned long start, unsigned long end)
 {
@@ -207,7 +207,7 @@ static void arm_smmu_mm_invalidate_range(struct mmu_notifier *mn,
 	arm_smmu_atc_inv_domain(smmu_domain, mm->pasid, start, size);
 }
 
-static void arm_smmu_mm_release(struct mmu_notifier *mn, struct mm_struct *mm)
+static void arm_smmu_mm_release(struct tlb_notifier *mn, struct mm_struct *mm)
 {
 	struct arm_smmu_mmu_notifier *smmu_mn = mn_to_smmu(mn);
 	struct arm_smmu_domain *smmu_domain = smmu_mn->domain;
@@ -231,15 +231,15 @@ static void arm_smmu_mm_release(struct mmu_notifier *mn, struct mm_struct *mm)
 	mutex_unlock(&sva_lock);
 }
 
-static void arm_smmu_mmu_notifier_free(struct mmu_notifier *mn)
+static void arm_smmu_mmu_notifier_free(struct tlb_notifier *mn)
 {
 	kfree(mn_to_smmu(mn));
 }
 
-static const struct mmu_notifier_ops arm_smmu_mmu_notifier_ops = {
-	.invalidate_range	= arm_smmu_mm_invalidate_range,
-	.release		= arm_smmu_mm_release,
-	.free_notifier		= arm_smmu_mmu_notifier_free,
+static const struct tlb_notifier_ops arm_smmu_tlb_notifier_ops = {
+	.invalidate_secondary_tlbs = arm_smmu_mm_invalidate_range,
+	.release		   = arm_smmu_mm_release,
+	.free_notifier		   = arm_smmu_mmu_notifier_free,
 };
 
 /* Allocate or get existing MMU notifier for this {domain, mm} pair */
@@ -252,7 +252,7 @@ arm_smmu_mmu_notifier_get(struct arm_smmu_domain *smmu_domain,
 	struct arm_smmu_mmu_notifier *smmu_mn;
 
 	list_for_each_entry(smmu_mn, &smmu_domain->mmu_notifiers, list) {
-		if (smmu_mn->mn.mm == mm) {
+		if (smmu_mn->mn.mm_notifier_chain.mm == mm) {
 			refcount_inc(&smmu_mn->refs);
 			return smmu_mn;
 		}
@@ -271,9 +271,9 @@ arm_smmu_mmu_notifier_get(struct arm_smmu_domain *smmu_domain,
 	refcount_set(&smmu_mn->refs, 1);
 	smmu_mn->cd = cd;
 	smmu_mn->domain = smmu_domain;
-	smmu_mn->mn.ops = &arm_smmu_mmu_notifier_ops;
+	smmu_mn->mn.ops = &arm_smmu_tlb_notifier_ops;
 
-	ret = mmu_notifier_register(&smmu_mn->mn, mm);
+	ret = tlb_notifier_register(&smmu_mn->mn, mm);
 	if (ret) {
 		kfree(smmu_mn);
 		goto err_free_cd;
@@ -288,7 +288,7 @@ arm_smmu_mmu_notifier_get(struct arm_smmu_domain *smmu_domain,
 
 err_put_notifier:
 	/* Frees smmu_mn */
-	mmu_notifier_put(&smmu_mn->mn);
+	tlb_notifier_put(&smmu_mn->mn);
 err_free_cd:
 	arm_smmu_free_shared_cd(cd);
 	return ERR_PTR(ret);
@@ -296,7 +296,7 @@ arm_smmu_mmu_notifier_get(struct arm_smmu_domain *smmu_domain,
 
 static void arm_smmu_mmu_notifier_put(struct arm_smmu_mmu_notifier *smmu_mn)
 {
-	struct mm_struct *mm = smmu_mn->mn.mm;
+	struct mm_struct *mm = smmu_mn->mn.mm_notifier_chain.mm;
 	struct arm_smmu_ctx_desc *cd = smmu_mn->cd;
 	struct arm_smmu_domain *smmu_domain = smmu_mn->domain;
 
@@ -316,7 +316,7 @@ static void arm_smmu_mmu_notifier_put(struct arm_smmu_mmu_notifier *smmu_mn)
 	}
 
 	/* Frees smmu_mn */
-	mmu_notifier_put(&smmu_mn->mn);
+	tlb_notifier_put(&smmu_mn->mn);
 	arm_smmu_free_shared_cd(cd);
 }
 
diff --git a/include/linux/mmu_notifier.h b/include/linux/mmu_notifier.h
index d6c06e140277..157571497b28 100644
--- a/include/linux/mmu_notifier.h
+++ b/include/linux/mmu_notifier.h
@@ -13,6 +13,7 @@ struct mmu_notifier_subscriptions;
 struct mmu_notifier;
 struct mmu_notifier_range;
 struct mmu_interval_notifier;
+struct mm_notifier_chain;
 
 /**
  * enum mmu_notifier_event - reason for the mmu notifier callback
@@ -61,6 +62,18 @@ enum mmu_notifier_event {
 
 #define MMU_NOTIFIER_RANGE_BLOCKABLE (1 << 0)
 
+struct tlb_notifier;
+struct tlb_notifier_ops {
+	void (*invalidate_secondary_tlbs)(struct tlb_notifier *subscription,
+					struct mm_struct *mm,
+					unsigned long start,
+					unsigned long end);
+
+	void (*free_notifier)(struct tlb_notifier *subscription);
+	void (*release)(struct tlb_notifier *subscription,
+			struct mm_struct *mm);
+};
+
 struct mmu_notifier_ops {
 	/*
 	 * Called either by mmu_notifier_unregister or when the mm is
@@ -186,29 +199,6 @@ struct mmu_notifier_ops {
 	void (*invalidate_range_end)(struct mmu_notifier *subscription,
 				     const struct mmu_notifier_range *range);
 
-	/*
-	 * invalidate_range() is either called between
-	 * invalidate_range_start() and invalidate_range_end() when the
-	 * VM has to free pages that where unmapped, but before the
-	 * pages are actually freed, or outside of _start()/_end() when
-	 * a (remote) TLB is necessary.
-	 *
-	 * If invalidate_range() is used to manage a non-CPU TLB with
-	 * shared page-tables, it not necessary to implement the
-	 * invalidate_range_start()/end() notifiers, as
-	 * invalidate_range() already catches the points in time when an
-	 * external TLB range needs to be flushed. For more in depth
-	 * discussion on this see Documentation/mm/mmu_notifier.rst
-	 *
-	 * Note that this function might be called with just a sub-range
-	 * of what was passed to invalidate_range_start()/end(), if
-	 * called between those functions.
-	 */
-	void (*invalidate_range)(struct mmu_notifier *subscription,
-				 struct mm_struct *mm,
-				 unsigned long start,
-				 unsigned long end);
-
 	/*
 	 * These callbacks are used with the get/put interface to manage the
 	 * lifetime of the mmu_notifier memory. alloc_notifier() returns a new
@@ -234,14 +224,23 @@ struct mmu_notifier_ops {
  * 2. One of the reverse map locks is held (i_mmap_rwsem or anon_vma->rwsem).
  * 3. No other concurrent thread can access the list (release)
  */
-struct mmu_notifier {
+struct mm_notifier_chain {
 	struct hlist_node hlist;
-	const struct mmu_notifier_ops *ops;
 	struct mm_struct *mm;
 	struct rcu_head rcu;
 	unsigned int users;
 };
 
+struct mmu_notifier {
+	const struct mmu_notifier_ops *ops;
+	struct mm_notifier_chain mm_notifier_chain;
+};
+
+struct tlb_notifier {
+	const struct tlb_notifier_ops *ops;
+	struct mm_notifier_chain mm_notifier_chain;
+};
+
 /**
  * struct mmu_interval_notifier_ops
  * @invalidate: Upon return the caller must stop using any SPTEs within this
@@ -283,6 +282,10 @@ static inline int mm_has_notifiers(struct mm_struct *mm)
 	return unlikely(mm->notifier_subscriptions);
 }
 
+int tlb_notifier_register(struct tlb_notifier *subscription,
+			struct mm_struct *mm);
+void tlb_notifier_put(struct tlb_notifier *subscription);
+
 struct mmu_notifier *mmu_notifier_get_locked(const struct mmu_notifier_ops *ops,
 					     struct mm_struct *mm);
 static inline struct mmu_notifier *
diff --git a/mm/mmu_notifier.c b/mm/mmu_notifier.c
index f45ff1b7626a..cdc3a373a225 100644
--- a/mm/mmu_notifier.c
+++ b/mm/mmu_notifier.c
@@ -47,6 +47,16 @@ struct mmu_notifier_subscriptions {
 	struct hlist_head deferred_list;
 };
 
+struct mmu_notifier *mmu_notifier_from_chain(struct mm_notifier_chain *chain)
+{
+	return container_of(chain, struct mmu_notifier, mm_notifier_chain);
+}
+
+struct tlb_notifier *tlb_notifier_from_chain(struct mm_notifier_chain *chain)
+{
+	return container_of(chain, struct tlb_notifier, mm_notifier_chain);
+}
+
 /*
  * This is a collision-retry read-side/write-side 'lock', a lot like a
  * seqcount, however this allows multiple write-sides to hold it at
@@ -299,7 +309,7 @@ static void mn_itree_release(struct mmu_notifier_subscriptions *subscriptions,
 static void mn_hlist_release(struct mmu_notifier_subscriptions *subscriptions,
 			     struct mm_struct *mm)
 {
-	struct mmu_notifier *subscription;
+	struct mm_notifier_chain *chain;
 	int id;
 
 	/*
@@ -307,8 +317,10 @@ static void mn_hlist_release(struct mmu_notifier_subscriptions *subscriptions,
 	 * ->release returns.
 	 */
 	id = srcu_read_lock(&srcu);
-	hlist_for_each_entry_rcu(subscription, &subscriptions->list, hlist,
-				 srcu_read_lock_held(&srcu))
+	hlist_for_each_entry_rcu(chain, &subscriptions->list, hlist,
+				srcu_read_lock_held(&srcu)) {
+		struct mmu_notifier *subscription =
+			mmu_notifier_from_chain(chain);
 		/*
 		 * If ->release runs before mmu_notifier_unregister it must be
 		 * handled, as it's the only way for the driver to flush all
@@ -317,18 +329,19 @@ static void mn_hlist_release(struct mmu_notifier_subscriptions *subscriptions,
 		 */
 		if (subscription->ops->release)
 			subscription->ops->release(subscription, mm);
+	}
 
 	spin_lock(&subscriptions->lock);
 	while (unlikely(!hlist_empty(&subscriptions->list))) {
-		subscription = hlist_entry(subscriptions->list.first,
-					   struct mmu_notifier, hlist);
+		chain = hlist_entry(subscriptions->list.first,
+					   struct mm_notifier_chain, hlist);
 		/*
 		 * We arrived before mmu_notifier_unregister so
 		 * mmu_notifier_unregister will do nothing other than to wait
 		 * for ->release to finish and for mmu_notifier_unregister to
 		 * return.
 		 */
-		hlist_del_init_rcu(&subscription->hlist);
+		hlist_del_init_rcu(&chain->hlist);
 	}
 	spin_unlock(&subscriptions->lock);
 	srcu_read_unlock(&srcu, id);
@@ -366,13 +379,15 @@ int __mmu_notifier_clear_flush_young(struct mm_struct *mm,
 					unsigned long start,
 					unsigned long end)
 {
-	struct mmu_notifier *subscription;
+	struct mm_notifier_chain *chain;
 	int young = 0, id;
 
 	id = srcu_read_lock(&srcu);
-	hlist_for_each_entry_rcu(subscription,
+	hlist_for_each_entry_rcu(chain,
 				 &mm->notifier_subscriptions->list, hlist,
 				 srcu_read_lock_held(&srcu)) {
+		struct mmu_notifier *subscription =
+			mmu_notifier_from_chain(chain);
 		if (subscription->ops->clear_flush_young)
 			young |= subscription->ops->clear_flush_young(
 				subscription, mm, start, end);
@@ -386,13 +401,15 @@ int __mmu_notifier_clear_young(struct mm_struct *mm,
 			       unsigned long start,
 			       unsigned long end)
 {
-	struct mmu_notifier *subscription;
+	struct mm_notifier_chain *chain;
 	int young = 0, id;
 
 	id = srcu_read_lock(&srcu);
-	hlist_for_each_entry_rcu(subscription,
+	hlist_for_each_entry_rcu(chain,
 				 &mm->notifier_subscriptions->list, hlist,
 				 srcu_read_lock_held(&srcu)) {
+		struct mmu_notifier *subscription =
+			mmu_notifier_from_chain(chain);
 		if (subscription->ops->clear_young)
 			young |= subscription->ops->clear_young(subscription,
 								mm, start, end);
@@ -405,13 +422,15 @@ int __mmu_notifier_clear_young(struct mm_struct *mm,
 int __mmu_notifier_test_young(struct mm_struct *mm,
 			      unsigned long address)
 {
-	struct mmu_notifier *subscription;
+	struct mm_notifier_chain *chain;
 	int young = 0, id;
 
 	id = srcu_read_lock(&srcu);
-	hlist_for_each_entry_rcu(subscription,
+	hlist_for_each_entry_rcu(chain,
 				 &mm->notifier_subscriptions->list, hlist,
 				 srcu_read_lock_held(&srcu)) {
+		struct mmu_notifier *subscription =
+			mmu_notifier_from_chain(chain);
 		if (subscription->ops->test_young) {
 			young = subscription->ops->test_young(subscription, mm,
 							      address);
@@ -427,13 +446,15 @@ int __mmu_notifier_test_young(struct mm_struct *mm,
 void __mmu_notifier_change_pte(struct mm_struct *mm, unsigned long address,
 			       pte_t pte)
 {
-	struct mmu_notifier *subscription;
+	struct mm_notifier_chain *chain;
 	int id;
 
 	id = srcu_read_lock(&srcu);
-	hlist_for_each_entry_rcu(subscription,
+	hlist_for_each_entry_rcu(chain,
 				 &mm->notifier_subscriptions->list, hlist,
 				 srcu_read_lock_held(&srcu)) {
+		struct mmu_notifier *subscription =
+			mmu_notifier_from_chain(chain);
 		if (subscription->ops->change_pte)
 			subscription->ops->change_pte(subscription, mm, address,
 						      pte);
@@ -476,13 +497,15 @@ static int mn_hlist_invalidate_range_start(
 	struct mmu_notifier_subscriptions *subscriptions,
 	struct mmu_notifier_range *range)
 {
-	struct mmu_notifier *subscription;
+	struct mm_notifier_chain *chain;
 	int ret = 0;
 	int id;
 
 	id = srcu_read_lock(&srcu);
-	hlist_for_each_entry_rcu(subscription, &subscriptions->list, hlist,
+	hlist_for_each_entry_rcu(chain, &subscriptions->list, hlist,
 				 srcu_read_lock_held(&srcu)) {
+		struct mmu_notifier *subscription =
+			mmu_notifier_from_chain(chain);
 		const struct mmu_notifier_ops *ops = subscription->ops;
 
 		if (ops->invalidate_range_start) {
@@ -519,8 +542,10 @@ static int mn_hlist_invalidate_range_start(
 		 * notifiers and one or more failed start, any that succeeded
 		 * start are expecting their end to be called.  Do so now.
 		 */
-		hlist_for_each_entry_rcu(subscription, &subscriptions->list,
+		hlist_for_each_entry_rcu(chain, &subscriptions->list,
 					 hlist, srcu_read_lock_held(&srcu)) {
+			struct mmu_notifier *subscription =
+				mmu_notifier_from_chain(chain);
 			if (!subscription->ops->invalidate_range_end)
 				continue;
 
@@ -553,35 +578,20 @@ static void
 mn_hlist_invalidate_end(struct mmu_notifier_subscriptions *subscriptions,
 			struct mmu_notifier_range *range, bool only_end)
 {
-	struct mmu_notifier *subscription;
+	struct mm_notifier_chain *chain;
 	int id;
 
 	id = srcu_read_lock(&srcu);
-	hlist_for_each_entry_rcu(subscription, &subscriptions->list, hlist,
+	hlist_for_each_entry_rcu(chain, &subscriptions->list, hlist,
 				 srcu_read_lock_held(&srcu)) {
-		/*
-		 * Call invalidate_range here too to avoid the need for the
-		 * subsystem of having to register an invalidate_range_end
-		 * call-back when there is invalidate_range already. Usually a
-		 * subsystem registers either invalidate_range_start()/end() or
-		 * invalidate_range(), so this will be no additional overhead
-		 * (besides the pointer check).
-		 *
-		 * We skip call to invalidate_range() if we know it is safe ie
-		 * call site use mmu_notifier_invalidate_range_only_end() which
-		 * is safe to do when we know that a call to invalidate_range()
-		 * already happen under page table lock.
-		 */
-		if (!only_end && subscription->ops->invalidate_range)
-			subscription->ops->invalidate_range(subscription,
-							    range->mm,
-							    range->start,
-							    range->end);
-		if (subscription->ops->invalidate_range_end) {
+		struct mmu_notifier *subscription =
+			mmu_notifier_from_chain(chain);
+		const struct mmu_notifier_ops *ops = subscription->ops;
+
+		if (ops->invalidate_range_end) {
 			if (!mmu_notifier_range_blockable(range))
 				non_block_start();
-			subscription->ops->invalidate_range_end(subscription,
-								range);
+			ops->invalidate_range_end(subscription, range);
 			if (!mmu_notifier_range_blockable(range))
 				non_block_end();
 		}
@@ -607,27 +617,24 @@ void __mmu_notifier_invalidate_range_end(struct mmu_notifier_range *range,
 void __mmu_notifier_invalidate_range(struct mm_struct *mm,
 				  unsigned long start, unsigned long end)
 {
-	struct mmu_notifier *subscription;
+	struct mm_notifier_chain *chain;
 	int id;
 
 	id = srcu_read_lock(&srcu);
-	hlist_for_each_entry_rcu(subscription,
+	hlist_for_each_entry_rcu(chain,
 				 &mm->notifier_subscriptions->list, hlist,
 				 srcu_read_lock_held(&srcu)) {
-		if (subscription->ops->invalidate_range)
-			subscription->ops->invalidate_range(subscription, mm,
-							    start, end);
+		struct tlb_notifier *subscription =
+			tlb_notifier_from_chain(chain);
+		if (subscription->ops->invalidate_secondary_tlbs)
+			subscription->ops->invalidate_secondary_tlbs(subscription,
+								mm, start, end);
 	}
 	srcu_read_unlock(&srcu, id);
 }
 
-/*
- * Same as mmu_notifier_register but here the caller must hold the mmap_lock in
- * write mode. A NULL mn signals the notifier is being registered for itree
- * mode.
- */
-int __mmu_notifier_register(struct mmu_notifier *subscription,
-			    struct mm_struct *mm)
+int __mm_notifier_chain_register(struct mm_notifier_chain *chain,
+				struct mm_struct *mm)
 {
 	struct mmu_notifier_subscriptions *subscriptions = NULL;
 	int ret;
@@ -677,14 +684,14 @@ int __mmu_notifier_register(struct mmu_notifier *subscription,
 	if (subscriptions)
 		smp_store_release(&mm->notifier_subscriptions, subscriptions);
 
-	if (subscription) {
+	if (chain) {
 		/* Pairs with the mmdrop in mmu_notifier_unregister_* */
 		mmgrab(mm);
-		subscription->mm = mm;
-		subscription->users = 1;
+		chain->mm = mm;
+		chain->users = 1;
 
 		spin_lock(&mm->notifier_subscriptions->lock);
-		hlist_add_head_rcu(&subscription->hlist,
+		hlist_add_head_rcu(&chain->hlist,
 				   &mm->notifier_subscriptions->list);
 		spin_unlock(&mm->notifier_subscriptions->lock);
 	} else
@@ -698,6 +705,18 @@ int __mmu_notifier_register(struct mmu_notifier *subscription,
 	kfree(subscriptions);
 	return ret;
 }
+
+/*
+ * Same as mmu_notifier_register but here the caller must hold the mmap_lock in
+ * write mode. A NULL mn signals the notifier is being registered for itree
+ * mode.
+ */
+int __mmu_notifier_register(struct mmu_notifier *subscription,
+			    struct mm_struct *mm)
+{
+	return __mm_notifier_chain_register(&subscription->mm_notifier_chain,
+					mm);
+}
 EXPORT_SYMBOL_GPL(__mmu_notifier_register);
 
 /**
@@ -731,20 +750,34 @@ int mmu_notifier_register(struct mmu_notifier *subscription,
 }
 EXPORT_SYMBOL_GPL(mmu_notifier_register);
 
+int tlb_notifier_register(struct tlb_notifier *subscription,
+			struct mm_struct *mm)
+{
+	int ret;
+
+	mmap_write_lock(mm);
+	ret = __mm_notifier_chain_register(&subscription->mm_notifier_chain, mm);
+	mmap_write_unlock(mm);
+	return ret;
+}
+EXPORT_SYMBOL_GPL(tlb_notifier_register);
+
 static struct mmu_notifier *
 find_get_mmu_notifier(struct mm_struct *mm, const struct mmu_notifier_ops *ops)
 {
-	struct mmu_notifier *subscription;
+	struct mm_notifier_chain *chain;
 
 	spin_lock(&mm->notifier_subscriptions->lock);
-	hlist_for_each_entry_rcu(subscription,
+	hlist_for_each_entry_rcu(chain,
 				 &mm->notifier_subscriptions->list, hlist,
 				 lockdep_is_held(&mm->notifier_subscriptions->lock)) {
+		struct mmu_notifier *subscription =
+			mmu_notifier_from_chain(chain);
 		if (subscription->ops != ops)
 			continue;
 
-		if (likely(subscription->users != UINT_MAX))
-			subscription->users++;
+		if (likely(chain->users != UINT_MAX))
+			chain->users++;
 		else
 			subscription = ERR_PTR(-EOVERFLOW);
 		spin_unlock(&mm->notifier_subscriptions->lock);
@@ -822,7 +855,7 @@ void mmu_notifier_unregister(struct mmu_notifier *subscription,
 {
 	BUG_ON(atomic_read(&mm->mm_count) <= 0);
 
-	if (!hlist_unhashed(&subscription->hlist)) {
+	if (!hlist_unhashed(&subscription->mm_notifier_chain.hlist)) {
 		/*
 		 * SRCU here will force exit_mmap to wait for ->release to
 		 * finish before freeing the pages.
@@ -843,7 +876,7 @@ void mmu_notifier_unregister(struct mmu_notifier *subscription,
 		 * Can not use list_del_rcu() since __mmu_notifier_release
 		 * can delete it before we hold the lock.
 		 */
-		hlist_del_init_rcu(&subscription->hlist);
+		hlist_del_init_rcu(&subscription->mm_notifier_chain.hlist);
 		spin_unlock(&mm->notifier_subscriptions->lock);
 	}
 
@@ -861,15 +894,34 @@ EXPORT_SYMBOL_GPL(mmu_notifier_unregister);
 
 static void mmu_notifier_free_rcu(struct rcu_head *rcu)
 {
-	struct mmu_notifier *subscription =
-		container_of(rcu, struct mmu_notifier, rcu);
-	struct mm_struct *mm = subscription->mm;
+	struct mm_notifier_chain *chain =
+		container_of(rcu, struct mm_notifier_chain, rcu);
+	struct mm_struct *mm = chain->mm;
+	struct mmu_notifier *subscription = mmu_notifier_from_chain(chain);
 
 	subscription->ops->free_notifier(subscription);
 	/* Pairs with the get in __mmu_notifier_register() */
 	mmdrop(mm);
 }
 
+void mm_notifier_chain_put(struct mm_notifier_chain *chain)
+{
+	struct mm_struct *mm = chain->mm;
+
+	spin_lock(&mm->notifier_subscriptions->lock);
+	if (WARN_ON(!chain->users) ||
+		--chain->users)
+		goto out_unlock;
+	hlist_del_init_rcu(&chain->hlist);
+	spin_unlock(&mm->notifier_subscriptions->lock);
+
+	call_srcu(&srcu, &chain->rcu, mmu_notifier_free_rcu);
+	return;
+
+out_unlock:
+	spin_unlock(&mm->notifier_subscriptions->lock);
+}
+
 /**
  * mmu_notifier_put - Release the reference on the notifier
  * @subscription: The notifier to act on
@@ -894,22 +946,16 @@ static void mmu_notifier_free_rcu(struct rcu_head *rcu)
  */
 void mmu_notifier_put(struct mmu_notifier *subscription)
 {
-	struct mm_struct *mm = subscription->mm;
-
-	spin_lock(&mm->notifier_subscriptions->lock);
-	if (WARN_ON(!subscription->users) || --subscription->users)
-		goto out_unlock;
-	hlist_del_init_rcu(&subscription->hlist);
-	spin_unlock(&mm->notifier_subscriptions->lock);
-
-	call_srcu(&srcu, &subscription->rcu, mmu_notifier_free_rcu);
-	return;
-
-out_unlock:
-	spin_unlock(&mm->notifier_subscriptions->lock);
+	mm_notifier_chain_put(&subscription->mm_notifier_chain);
 }
 EXPORT_SYMBOL_GPL(mmu_notifier_put);
 
+void tlb_notifier_put(struct tlb_notifier *subscription)
+{
+	mm_notifier_chain_put(&subscription->mm_notifier_chain);
+}
+EXPORT_SYMBOL_GPL(tlb_notifier_put);
+
 static int __mmu_interval_notifier_insert(
 	struct mmu_interval_notifier *interval_sub, struct mm_struct *mm,
 	struct mmu_notifier_subscriptions *subscriptions, unsigned long start,
