netlink: diag: walk the socket hash table under RCU

In __netlink_diag_dump(), stop dereferencing the rhashtable bucket table
at declaration time. Instead, take rcu_read_lock(), fetch the table with
rht_dereference_rcu(), and iterate each bucket with
rht_for_each_entry_rcu(). The RCU read-side section is left on both exits
from the bucket walk: on the sk_diag_fill() failure path (before
"goto done") and after the bucket loop completes, before the mc_list walk.

diff --git a/net/netlink/diag.c b/net/netlink/diag.c
index 7301850..2d94c49 100644
--- a/net/netlink/diag.c
+++ b/net/netlink/diag.c
@@ -103,7 +103,7 @@ static int __netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
 {
 	struct netlink_table *tbl = &nl_table[protocol];
 	struct rhashtable *ht = &tbl->hash;
-	const struct bucket_table *htbl = rht_dereference(ht->tbl, ht);
+	const struct bucket_table *htbl;
 	struct net *net = sock_net(skb->sk);
 	struct netlink_diag_req *req;
 	struct netlink_sock *nlsk;
@@ -112,8 +112,11 @@ static int __netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
 
 	req = nlmsg_data(cb->nlh);
 
+	rcu_read_lock();
+	htbl = rht_dereference_rcu(ht->tbl, ht);
+
 	for (i = 0; i < htbl->size; i++) {
-		rht_for_each_entry(nlsk, htbl->buckets[i], ht, node) {
+		rht_for_each_entry_rcu(nlsk, htbl->buckets[i], node) {
 			sk = (struct sock *)nlsk;
 
 			if (!net_eq(sock_net(sk), net))
@@ -129,12 +132,14 @@ static int __netlink_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
 					 NLM_F_MULTI,
 					 sock_i_ino(sk)) < 0) {
 				ret = 1;
+				rcu_read_unlock();
 				goto done;
 			}
 
 			num++;
 		}
 	}
+	rcu_read_unlock();
 
 	sk_for_each_bound(sk, &tbl->mc_list) {
 		if (sk_hashed(sk))