Message-Id: <20250310132222.58378-4-luxu.kernel@bytedance.com>
Date: Mon, 10 Mar 2025 21:22:21 +0800
From: Xu Lu <luxu.kernel@...edance.com>
To: akpm@...ux-foundation.org,
tjeznach@...osinc.com,
joro@...tes.org,
will@...nel.org,
robin.murphy@....com
Cc: lihangjing@...edance.com,
xieyongji@...edance.com,
linux-riscv@...ts.infradead.org,
linux-kernel@...r.kernel.org,
Xu Lu <luxu.kernel@...edance.com>
Subject: [PATCH 3/4] iommu/riscv: Introduce IOMMU page table lock

Introduce a page table lock to address races when multiple PTEs are
modified together, for example when applying Svnapot. We use
fine-grained page table locks to minimize lock contention.
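
The mapping paths now follow a lock / recheck / update / unlock
pattern instead of relying on cmpxchg_relaxed(). A simplified excerpt
of the non-leaf population path in riscv_iommu_pte_alloc() below
(allocation and error handling trimmed):

	ptl = riscv_iommu_ptlock(domain, ptr, level);
	pte = ptep_get(ptr);
	if (!_io_pte_none(pte)) {
		/* Lost the race: another caller already populated
		 * this entry; drop the lock, free our page, retry.
		 */
		spin_unlock(ptl);
		riscv_iommu_free_pagetable(addr, level - 1);
		goto pte_retry;
	}
	pte = _io_pte_entry(virt_to_pfn(addr), _PAGE_TABLE);
	set_pte(ptr, pte);
	spin_unlock(ptl);

riscv_iommu_ptlock() takes the per-page split PTE lock for the PTE
and PMD levels when CONFIG_SPLIT_PTE_PTLOCKS is enabled; upper
levels, and all levels otherwise, serialize on the new per-domain
page_table_lock, so mappings that touch different leaf tables do not
contend.
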
Signed-off-by: Xu Lu <luxu.kernel@...edance.com>
---
 drivers/iommu/riscv/iommu.c | 126 ++++++++++++++++++++++++++++++------
 1 file changed, 108 insertions(+), 18 deletions(-)
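
For reviewers: the level arithmetic in the new pgsize_to_level() can
be checked standalone. A minimal user-space sketch, assuming the
satp-style mode encodings (Sv39 = 8, Sv57 = 10), PAGE_SHIFT = 12 and
PT_SHIFT = 9; the names mirror the driver's constants, but the values
here are assumptions of this sketch rather than taken from the
headers:

	#include <stdio.h>

	#define PAGE_SHIFT 12
	#define PT_SHIFT   9    /* 512 entries per 4 KiB table */
	#define MODE_SV39  8    /* stand-in for ..._IOSATP_MODE_SV39 */
	#define MODE_SV57  10   /* stand-in for ..._IOSATP_MODE_SV57 */

	static int pgsize_to_level(size_t pgsize)
	{
		int level = MODE_SV57 - MODE_SV39 + 2;  /* Sv57 root: 4 */
		int shift = PAGE_SHIFT + PT_SHIFT * level;

		while (pgsize < ((size_t)1 << shift)) {
			shift -= PT_SHIFT;
			level--;
		}
		return level;
	}

	int main(void)
	{
		/* Expect "0 1 2": 4 KiB -> 0, 2 MiB -> 1 (PMD), 1 GiB -> 2 */
		printf("%d %d %d\n",
		       pgsize_to_level((size_t)1 << 12),
		       pgsize_to_level((size_t)1 << 21),
		       pgsize_to_level((size_t)1 << 30));
		return 0;
	}

Level 1 (RISCV_IOMMU_PMD_LEVEL) is thus the 2 MiB level, the largest
level still covered by a split PTE lock; everything above it
serializes on the per-domain lock.
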
diff --git a/drivers/iommu/riscv/iommu.c b/drivers/iommu/riscv/iommu.c
index f752096989a79..ffc474987a075 100644
--- a/drivers/iommu/riscv/iommu.c
+++ b/drivers/iommu/riscv/iommu.c
@@ -808,6 +808,7 @@ struct riscv_iommu_domain {
 	struct iommu_domain domain;
 	struct list_head bonds;
 	spinlock_t lock; /* protect bonds list updates. */
+	spinlock_t page_table_lock; /* protect page table updates. */
 	int pscid;
 	bool amo_enabled;
 	int numa_node;
@@ -1086,8 +1087,80 @@ static void riscv_iommu_iotlb_sync(struct iommu_domain *iommu_domain,
 #define _io_pte_none(pte) (pte_val(pte) == 0)
 #define _io_pte_entry(pn, prot) (__pte((_PAGE_PFN_MASK & ((pn) << _PAGE_PFN_SHIFT)) | (prot)))
 
+#define RISCV_IOMMU_PMD_LEVEL 1
+
+static bool riscv_iommu_ptlock_init(struct ptdesc *ptdesc, int level)
+{
+	if (level <= RISCV_IOMMU_PMD_LEVEL)
+		return ptlock_init(ptdesc);
+	return true;
+}
+
+static void riscv_iommu_ptlock_free(struct ptdesc *ptdesc, int level)
+{
+	if (level <= RISCV_IOMMU_PMD_LEVEL)
+		ptlock_free(ptdesc);
+}
+
+static spinlock_t *riscv_iommu_ptlock(struct riscv_iommu_domain *domain,
+				      pte_t *pte, int level)
+{
+	spinlock_t *ptl;
+
+#ifdef CONFIG_SPLIT_PTE_PTLOCKS
+	if (level <= RISCV_IOMMU_PMD_LEVEL)
+		ptl = ptlock_ptr(page_ptdesc(virt_to_page(pte)));
+	else
+#endif
+		ptl = &domain->page_table_lock;
+	spin_lock(ptl);
+
+	return ptl;
+}
+
+static void *riscv_iommu_alloc_pagetable_node(int numa_node, gfp_t gfp, int level)
+{
+	struct ptdesc *ptdesc;
+	void *addr;
+
+	addr = iommu_alloc_page_node(numa_node, gfp);
+	if (!addr)
+		return NULL;
+
+	ptdesc = page_ptdesc(virt_to_page(addr));
+	if (!riscv_iommu_ptlock_init(ptdesc, level)) {
+		iommu_free_page(addr);
+		addr = NULL;
+	}
+
+	return addr;
+}
+
+static void riscv_iommu_free_pagetable(void *addr, int level)
+{
+	struct ptdesc *ptdesc = page_ptdesc(virt_to_page(addr));
+
+	riscv_iommu_ptlock_free(ptdesc, level);
+	iommu_free_page(addr);
+}
+
+static int pgsize_to_level(size_t pgsize)
+{
+	int level = RISCV_IOMMU_DC_FSC_IOSATP_MODE_SV57 -
+		    RISCV_IOMMU_DC_FSC_IOSATP_MODE_SV39 + 2;
+	int shift = PAGE_SHIFT + PT_SHIFT * level;
+
+	while (pgsize < ((size_t)1 << shift)) {
+		shift -= PT_SHIFT;
+		level--;
+	}
+
+	return level;
+}
+
 static void riscv_iommu_pte_free(struct riscv_iommu_domain *domain,
-				 pte_t pte, struct list_head *freelist)
+				 pte_t pte, int level,
+				 struct list_head *freelist)
 {
 	pte_t *ptr;
 	int i;
@@ -1102,10 +1175,11 @@ static void riscv_iommu_pte_free(struct riscv_iommu_domain *domain,
 		pte = ptr[i];
 		if (!_io_pte_none(pte)) {
 			ptr[i] = __pte(0);
-			riscv_iommu_pte_free(domain, pte, freelist);
+			riscv_iommu_pte_free(domain, pte, level - 1, freelist);
 		}
 	}
 
+	riscv_iommu_ptlock_free(page_ptdesc(virt_to_page(ptr)), level);
 	if (freelist)
 		list_add_tail(&virt_to_page(ptr)->lru, freelist);
 	else
@@ -1117,8 +1191,9 @@ static pte_t *riscv_iommu_pte_alloc(struct riscv_iommu_domain *domain,
 				    gfp_t gfp)
 {
 	pte_t *ptr = domain->pgd_root;
-	pte_t pte, old;
+	pte_t pte;
 	int level = domain->pgd_mode - RISCV_IOMMU_DC_FSC_IOSATP_MODE_SV39 + 2;
+	spinlock_t *ptl;
 	void *addr;
 
 	do {
@@ -1146,15 +1221,21 @@ static pte_t *riscv_iommu_pte_alloc(struct riscv_iommu_domain *domain,
 		 * page table. This might race with other mappings, retry.
 		 */
 		if (_io_pte_none(pte)) {
-			addr = iommu_alloc_page_node(domain->numa_node, gfp);
+			addr = riscv_iommu_alloc_pagetable_node(domain->numa_node, gfp,
+								level - 1);
 			if (!addr)
 				return NULL;
-			old = pte;
-			pte = _io_pte_entry(virt_to_pfn(addr), _PAGE_TABLE);
-			if (cmpxchg_relaxed(ptr, old, pte) != old) {
-				iommu_free_page(addr);
+
+			ptl = riscv_iommu_ptlock(domain, ptr, level);
+			pte = ptep_get(ptr);
+			if (!_io_pte_none(pte)) {
+				spin_unlock(ptl);
+				riscv_iommu_free_pagetable(addr, level - 1);
 				goto pte_retry;
 			}
+			pte = _io_pte_entry(virt_to_pfn(addr), _PAGE_TABLE);
+			set_pte(ptr, pte);
+			spin_unlock(ptl);
 		}
 		ptr = (pte_t *)pfn_to_virt(pte_pfn(pte));
 	} while (level-- > 0);
@@ -1194,9 +1275,10 @@ static int riscv_iommu_map_pages(struct iommu_domain *iommu_domain,
 	struct riscv_iommu_domain *domain = iommu_domain_to_riscv(iommu_domain);
 	size_t size = 0;
 	pte_t *ptr;
-	pte_t pte, old;
+	pte_t pte;
 	unsigned long pte_prot;
-	int rc = 0;
+	int rc = 0, level;
+	spinlock_t *ptl;
 	LIST_HEAD(freelist);
 
 	if (!(prot & IOMMU_WRITE))
@@ -1213,11 +1295,12 @@ static int riscv_iommu_map_pages(struct iommu_domain *iommu_domain,
 			break;
 		}
 
-		old = ptep_get(ptr);
+		level = pgsize_to_level(pgsize);
+		ptl = riscv_iommu_ptlock(domain, ptr, level);
+		riscv_iommu_pte_free(domain, ptep_get(ptr), level, &freelist);
 		pte = _io_pte_entry(phys_to_pfn(phys), pte_prot);
 		set_pte(ptr, pte);
-
-		riscv_iommu_pte_free(domain, old, &freelist);
+		spin_unlock(ptl);
 
 		size += pgsize;
 		iova += pgsize;
@@ -1252,6 +1335,7 @@ static size_t riscv_iommu_unmap_pages(struct iommu_domain *iommu_domain,
 	pte_t *ptr;
 	size_t unmapped = 0;
 	size_t pte_size;
+	spinlock_t *ptl;
 
 	while (unmapped < size) {
 		ptr = riscv_iommu_pte_fetch(domain, iova, &pte_size);
@@ -1262,7 +1346,9 @@ static size_t riscv_iommu_unmap_pages(struct iommu_domain *iommu_domain,
 		if (iova & (pte_size - 1))
 			return unmapped;
 
+		ptl = riscv_iommu_ptlock(domain, ptr, pgsize_to_level(pte_size));
 		set_pte(ptr, __pte(0));
+		spin_unlock(ptl);
 
 		iommu_iotlb_gather_add_page(&domain->domain, gather, iova,
 					    pte_size);
@@ -1292,13 +1378,14 @@ static void riscv_iommu_free_paging_domain(struct iommu_domain *iommu_domain)
 {
 	struct riscv_iommu_domain *domain = iommu_domain_to_riscv(iommu_domain);
 	const unsigned long pfn = virt_to_pfn(domain->pgd_root);
+	int level = domain->pgd_mode - RISCV_IOMMU_DC_FSC_IOSATP_MODE_SV39 + 2;
 
 	WARN_ON(!list_empty(&domain->bonds));
 
 	if ((int)domain->pscid > 0)
 		ida_free(&riscv_iommu_pscids, domain->pscid);
 
-	riscv_iommu_pte_free(domain, _io_pte_entry(pfn, _PAGE_TABLE), NULL);
+	riscv_iommu_pte_free(domain, _io_pte_entry(pfn, _PAGE_TABLE), level, NULL);
 	kfree(domain);
 }
@@ -1359,7 +1446,7 @@ static struct iommu_domain *riscv_iommu_alloc_paging_domain(struct device *dev)
 	struct riscv_iommu_device *iommu;
 	unsigned int pgd_mode;
 	dma_addr_t va_mask;
-	int va_bits;
+	int va_bits, level;
 
 	iommu = dev_to_iommu(dev);
 	if (iommu->caps & RISCV_IOMMU_CAPABILITIES_SV57) {
@@ -1382,11 +1469,14 @@ static struct iommu_domain *riscv_iommu_alloc_paging_domain(struct device *dev)
 	INIT_LIST_HEAD_RCU(&domain->bonds);
 	spin_lock_init(&domain->lock);
+	spin_lock_init(&domain->page_table_lock);
 	domain->numa_node = dev_to_node(iommu->dev);
 	domain->amo_enabled = !!(iommu->caps & RISCV_IOMMU_CAPABILITIES_AMO_HWAD);
 	domain->pgd_mode = pgd_mode;
-	domain->pgd_root = iommu_alloc_page_node(domain->numa_node,
-						 GFP_KERNEL_ACCOUNT);
+	level = domain->pgd_mode - RISCV_IOMMU_DC_FSC_IOSATP_MODE_SV39 + 2;
+	domain->pgd_root = riscv_iommu_alloc_pagetable_node(domain->numa_node,
+							    GFP_KERNEL_ACCOUNT,
+							    level);
 	if (!domain->pgd_root) {
 		kfree(domain);
 		return ERR_PTR(-ENOMEM);
@@ -1395,7 +1485,7 @@ static struct iommu_domain *riscv_iommu_alloc_paging_domain(struct device *dev)
 	domain->pscid = ida_alloc_range(&riscv_iommu_pscids, 1,
 					RISCV_IOMMU_MAX_PSCID, GFP_KERNEL);
 	if (domain->pscid < 0) {
-		iommu_free_page(domain->pgd_root);
+		riscv_iommu_free_pagetable(domain->pgd_root, level);
 		kfree(domain);
 		return ERR_PTR(-ENOMEM);
 	}
--
2.20.1