[<prev] [next>] [<thread-prev] [thread-next>] [day] [month] [year] [list]
Message-ID: <ZNeKHZmSYDIejGKx@vergenet.net>
Date: Sat, 12 Aug 2023 15:33:17 +0200
From: Simon Horman <horms@...nel.org>
To: Eric Dumazet <edumazet@...gle.com>
Cc: "David S. Miller" <davem@...emloft.net>,
Jakub Kicinski <kuba@...nel.org>, Paolo Abeni <pabeni@...hat.com>,
netdev@...r.kernel.org, eric.dumazet@...il.com,
syzbot <syzkaller@...glegroups.com>,
Anjali Kulkarni <anjali.k.kulkarni@...cle.com>,
Florian Westphal <fw@...len.de>,
Kuniyuki Iwashima <kuniyu@...zon.com>,
Liam Howlett <liam.howlett@...cle.com>
Subject: Re: [PATCH net-next] netlink: convert nlk->flags to atomic flags
On Fri, Aug 11, 2023 at 07:22:26AM +0000, Eric Dumazet wrote:
+ Anjali Kulkarni <anjali.k.kulkarni@...cle.com>,
  Florian Westphal <fw@...len.de>,
  Kuniyuki Iwashima <kuniyu@...zon.com>,
  Liam Howlett <liam.howlett@...cle.com>
> sk_diag_put_flags(), netlink_setsockopt(), netlink_getsockopt()
> and others use nlk->flags without correct locking.
>
> Use set_bit(), clear_bit(), test_bit(), assign_bit() to remove
> data-races.
>
> Reported-by: syzbot <syzkaller@...glegroups.com>
> Signed-off-by: Eric Dumazet <edumazet@...gle.com>
Reviewed-by: Simon Horman <horms@...nel.org>
> ---
> net/netlink/af_netlink.c | 90 ++++++++++++++--------------------------
> net/netlink/af_netlink.h | 22 ++++++----
> net/netlink/diag.c | 10 ++---
> 3 files changed, 48 insertions(+), 74 deletions(-)
>
> diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
> index 96c605e45235815a273784302d45fa7ff88e6d62..642b9d382fb46ddbc3523584c98e07da6860951a 100644
> --- a/net/netlink/af_netlink.c
> +++ b/net/netlink/af_netlink.c
> @@ -84,7 +84,7 @@ struct listeners {
>
> static inline int netlink_is_kernel(struct sock *sk)
> {
> - return nlk_sk(sk)->flags & NETLINK_F_KERNEL_SOCKET;
> + return nlk_test_bit(KERNEL_SOCKET, sk);
> }
>
> struct netlink_table *nl_table __read_mostly;
> @@ -349,9 +349,7 @@ static void netlink_deliver_tap_kernel(struct sock *dst, struct sock *src,
>
> static void netlink_overrun(struct sock *sk)
> {
> - struct netlink_sock *nlk = nlk_sk(sk);
> -
> - if (!(nlk->flags & NETLINK_F_RECV_NO_ENOBUFS)) {
> + if (!nlk_test_bit(RECV_NO_ENOBUFS, sk)) {
> if (!test_and_set_bit(NETLINK_S_CONGESTED,
> &nlk_sk(sk)->state)) {
> sk->sk_err = ENOBUFS;
> @@ -1407,9 +1405,7 @@ EXPORT_SYMBOL_GPL(netlink_has_listeners);
>
> bool netlink_strict_get_check(struct sk_buff *skb)
> {
> - const struct netlink_sock *nlk = nlk_sk(NETLINK_CB(skb).sk);
> -
> - return nlk->flags & NETLINK_F_STRICT_CHK;
> + return nlk_test_bit(STRICT_CHK, NETLINK_CB(skb).sk);
> }
> EXPORT_SYMBOL_GPL(netlink_strict_get_check);
>
> @@ -1455,7 +1451,7 @@ static void do_one_broadcast(struct sock *sk,
> return;
>
> if (!net_eq(sock_net(sk), p->net)) {
> - if (!(nlk->flags & NETLINK_F_LISTEN_ALL_NSID))
> + if (!nlk_test_bit(LISTEN_ALL_NSID, sk))
> return;
>
> if (!peernet_has_id(sock_net(sk), p->net))
> @@ -1488,7 +1484,7 @@ static void do_one_broadcast(struct sock *sk,
> netlink_overrun(sk);
> /* Clone failed. Notify ALL listeners. */
> p->failure = 1;
> - if (nlk->flags & NETLINK_F_BROADCAST_SEND_ERROR)
> + if (nlk_test_bit(BROADCAST_SEND_ERROR, sk))
> p->delivery_failure = 1;
> goto out;
> }
> @@ -1510,7 +1506,7 @@ static void do_one_broadcast(struct sock *sk,
> val = netlink_broadcast_deliver(sk, p->skb2);
> if (val < 0) {
> netlink_overrun(sk);
> - if (nlk->flags & NETLINK_F_BROADCAST_SEND_ERROR)
> + if (nlk_test_bit(BROADCAST_SEND_ERROR, sk))
> p->delivery_failure = 1;
> } else {
> p->congested |= val;
> @@ -1604,7 +1600,7 @@ static int do_one_set_err(struct sock *sk, struct netlink_set_err_data *p)
> !test_bit(p->group - 1, nlk->groups))
> goto out;
>
> - if (p->code == ENOBUFS && nlk->flags & NETLINK_F_RECV_NO_ENOBUFS) {
> + if (p->code == ENOBUFS && nlk_test_bit(RECV_NO_ENOBUFS, sk)) {
> ret = 1;
> goto out;
> }
> @@ -1668,7 +1664,7 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname,
> struct sock *sk = sock->sk;
> struct netlink_sock *nlk = nlk_sk(sk);
> unsigned int val = 0;
> - int err;
> + int nr = -1;
>
> if (level != SOL_NETLINK)
> return -ENOPROTOOPT;
> @@ -1679,14 +1675,12 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname,
>
> switch (optname) {
> case NETLINK_PKTINFO:
> - if (val)
> - nlk->flags |= NETLINK_F_RECV_PKTINFO;
> - else
> - nlk->flags &= ~NETLINK_F_RECV_PKTINFO;
> - err = 0;
> + nr = NETLINK_F_RECV_PKTINFO;
> break;
> case NETLINK_ADD_MEMBERSHIP:
> case NETLINK_DROP_MEMBERSHIP: {
> + int err;
> +
> if (!netlink_allowed(sock, NL_CFG_F_NONROOT_RECV))
> return -EPERM;
> err = netlink_realloc_groups(sk);
> @@ -1706,61 +1700,38 @@ static int netlink_setsockopt(struct socket *sock, int level, int optname,
> if (optname == NETLINK_DROP_MEMBERSHIP && nlk->netlink_unbind)
> nlk->netlink_unbind(sock_net(sk), val);
>
> - err = 0;
> break;
> }
> case NETLINK_BROADCAST_ERROR:
> - if (val)
> - nlk->flags |= NETLINK_F_BROADCAST_SEND_ERROR;
> - else
> - nlk->flags &= ~NETLINK_F_BROADCAST_SEND_ERROR;
> - err = 0;
> + nr = NETLINK_F_BROADCAST_SEND_ERROR;
> break;
> case NETLINK_NO_ENOBUFS:
> + assign_bit(NETLINK_F_RECV_NO_ENOBUFS, &nlk->flags, val);
> if (val) {
> - nlk->flags |= NETLINK_F_RECV_NO_ENOBUFS;
> clear_bit(NETLINK_S_CONGESTED, &nlk->state);
> wake_up_interruptible(&nlk->wait);
> - } else {
> - nlk->flags &= ~NETLINK_F_RECV_NO_ENOBUFS;
> }
> - err = 0;
> break;
> case NETLINK_LISTEN_ALL_NSID:
> if (!ns_capable(sock_net(sk)->user_ns, CAP_NET_BROADCAST))
> return -EPERM;
> -
> - if (val)
> - nlk->flags |= NETLINK_F_LISTEN_ALL_NSID;
> - else
> - nlk->flags &= ~NETLINK_F_LISTEN_ALL_NSID;
> - err = 0;
> + nr = NETLINK_F_LISTEN_ALL_NSID;
> break;
> case NETLINK_CAP_ACK:
> - if (val)
> - nlk->flags |= NETLINK_F_CAP_ACK;
> - else
> - nlk->flags &= ~NETLINK_F_CAP_ACK;
> - err = 0;
> + nr = NETLINK_F_CAP_ACK;
> break;
> case NETLINK_EXT_ACK:
> - if (val)
> - nlk->flags |= NETLINK_F_EXT_ACK;
> - else
> - nlk->flags &= ~NETLINK_F_EXT_ACK;
> - err = 0;
> + nr = NETLINK_F_EXT_ACK;
> break;
> case NETLINK_GET_STRICT_CHK:
> - if (val)
> - nlk->flags |= NETLINK_F_STRICT_CHK;
> - else
> - nlk->flags &= ~NETLINK_F_STRICT_CHK;
> - err = 0;
> + nr = NETLINK_F_STRICT_CHK;
> break;
> default:
> - err = -ENOPROTOOPT;
> + return -ENOPROTOOPT;
> }
> - return err;
> + if (nr >= 0)
> + assign_bit(nr, &nlk->flags, val);
> + return 0;
> }
>
> static int netlink_getsockopt(struct socket *sock, int level, int optname,
> @@ -1827,7 +1798,7 @@ static int netlink_getsockopt(struct socket *sock, int level, int optname,
> return -EINVAL;
>
> len = sizeof(int);
> - val = nlk->flags & flag ? 1 : 0;
> + val = test_bit(flag, &nlk->flags);
>
> if (put_user(len, optlen) ||
> copy_to_user(optval, &val, len))
> @@ -2004,9 +1975,9 @@ static int netlink_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
> msg->msg_namelen = sizeof(*addr);
> }
>
> - if (nlk->flags & NETLINK_F_RECV_PKTINFO)
> + if (nlk_test_bit(RECV_PKTINFO, sk))
> netlink_cmsg_recv_pktinfo(msg, skb);
> - if (nlk->flags & NETLINK_F_LISTEN_ALL_NSID)
> + if (nlk_test_bit(LISTEN_ALL_NSID, sk))
> netlink_cmsg_listen_all_nsid(sk, msg, skb);
>
> memset(&scm, 0, sizeof(scm));
> @@ -2083,7 +2054,7 @@ __netlink_kernel_create(struct net *net, int unit, struct module *module,
> goto out_sock_release;
>
> nlk = nlk_sk(sk);
> - nlk->flags |= NETLINK_F_KERNEL_SOCKET;
> + set_bit(NETLINK_F_KERNEL_SOCKET, &nlk->flags);
>
> netlink_table_grab();
> if (!nl_table[unit].registered) {
> @@ -2218,7 +2189,7 @@ static int netlink_dump_done(struct netlink_sock *nlk, struct sk_buff *skb,
> nl_dump_check_consistent(cb, nlh);
> memcpy(nlmsg_data(nlh), &nlk->dump_done_errno, sizeof(nlk->dump_done_errno));
>
> - if (extack->_msg && nlk->flags & NETLINK_F_EXT_ACK) {
> + if (extack->_msg && test_bit(NETLINK_F_EXT_ACK, &nlk->flags)) {
> nlh->nlmsg_flags |= NLM_F_ACK_TLVS;
> if (!nla_put_string(skb, NLMSGERR_ATTR_MSG, extack->_msg))
> nlmsg_end(skb, nlh);
> @@ -2347,8 +2318,8 @@ int __netlink_dump_start(struct sock *ssk, struct sk_buff *skb,
> const struct nlmsghdr *nlh,
> struct netlink_dump_control *control)
> {
> - struct netlink_sock *nlk, *nlk2;
> struct netlink_callback *cb;
> + struct netlink_sock *nlk;
> struct sock *sk;
> int ret;
>
> @@ -2383,8 +2354,7 @@ int __netlink_dump_start(struct sock *ssk, struct sk_buff *skb,
> cb->min_dump_alloc = control->min_dump_alloc;
> cb->skb = skb;
>
> - nlk2 = nlk_sk(NETLINK_CB(skb).sk);
> - cb->strict_check = !!(nlk2->flags & NETLINK_F_STRICT_CHK);
> + cb->strict_check = nlk_test_bit(STRICT_CHK, NETLINK_CB(skb).sk);
>
> if (control->start) {
> cb->extack = control->extack;
> @@ -2428,7 +2398,7 @@ netlink_ack_tlv_len(struct netlink_sock *nlk, int err,
> {
> size_t tlvlen;
>
> - if (!extack || !(nlk->flags & NETLINK_F_EXT_ACK))
> + if (!extack || !test_bit(NETLINK_F_EXT_ACK, &nlk->flags))
> return 0;
>
> tlvlen = 0;
> @@ -2500,7 +2470,7 @@ void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err,
> * requests to cap the error message, and get extra error data if
> * requested.
> */
> - if (err && !(nlk->flags & NETLINK_F_CAP_ACK))
> + if (err && !test_bit(NETLINK_F_CAP_ACK, &nlk->flags))
> payload += nlmsg_len(nlh);
> else
> flags |= NLM_F_CAPPED;
> diff --git a/net/netlink/af_netlink.h b/net/netlink/af_netlink.h
> index fd424cd63f31cf09b00398a1ca92c0e0600ac7bb..2145979b9986a0331b34b6ba2fda867f23d0d71c 100644
> --- a/net/netlink/af_netlink.h
> +++ b/net/netlink/af_netlink.h
> @@ -8,14 +8,16 @@
> #include <net/sock.h>
>
> /* flags */
> -#define NETLINK_F_KERNEL_SOCKET 0x1
> -#define NETLINK_F_RECV_PKTINFO 0x2
> -#define NETLINK_F_BROADCAST_SEND_ERROR 0x4
> -#define NETLINK_F_RECV_NO_ENOBUFS 0x8
> -#define NETLINK_F_LISTEN_ALL_NSID 0x10
> -#define NETLINK_F_CAP_ACK 0x20
> -#define NETLINK_F_EXT_ACK 0x40
> -#define NETLINK_F_STRICT_CHK 0x80
> +enum {
> + NETLINK_F_KERNEL_SOCKET,
> + NETLINK_F_RECV_PKTINFO,
> + NETLINK_F_BROADCAST_SEND_ERROR,
> + NETLINK_F_RECV_NO_ENOBUFS,
> + NETLINK_F_LISTEN_ALL_NSID,
> + NETLINK_F_CAP_ACK,
> + NETLINK_F_EXT_ACK,
> + NETLINK_F_STRICT_CHK,
> +};
>
> #define NLGRPSZ(x) (ALIGN(x, sizeof(unsigned long) * 8) / 8)
> #define NLGRPLONGS(x) (NLGRPSZ(x)/sizeof(unsigned long))
> @@ -23,10 +25,10 @@
> struct netlink_sock {
> /* struct sock has to be the first member of netlink_sock */
> struct sock sk;
> + unsigned long flags;
> u32 portid;
> u32 dst_portid;
> u32 dst_group;
> - u32 flags;
> u32 subscriptions;
> u32 ngroups;
> unsigned long *groups;
> @@ -56,6 +58,8 @@ static inline struct netlink_sock *nlk_sk(struct sock *sk)
> return container_of(sk, struct netlink_sock, sk);
> }
>
> +#define nlk_test_bit(nr, sk) test_bit(NETLINK_F_##nr, &nlk_sk(sk)->flags)
> +
> struct netlink_table {
> struct rhashtable hash;
> struct hlist_head mc_list;
> diff --git a/net/netlink/diag.c b/net/netlink/diag.c
> index e4f21b1067bccacc86811bd240056de65c470ad9..9c4f231be27572f9d889248c33a04868f22e44de 100644
> --- a/net/netlink/diag.c
> +++ b/net/netlink/diag.c
> @@ -27,15 +27,15 @@ static int sk_diag_put_flags(struct sock *sk, struct sk_buff *skb)
>
> if (nlk->cb_running)
> flags |= NDIAG_FLAG_CB_RUNNING;
> - if (nlk->flags & NETLINK_F_RECV_PKTINFO)
> + if (nlk_test_bit(RECV_PKTINFO, sk))
> flags |= NDIAG_FLAG_PKTINFO;
> - if (nlk->flags & NETLINK_F_BROADCAST_SEND_ERROR)
> + if (nlk_test_bit(BROADCAST_SEND_ERROR, sk))
> flags |= NDIAG_FLAG_BROADCAST_ERROR;
> - if (nlk->flags & NETLINK_F_RECV_NO_ENOBUFS)
> + if (nlk_test_bit(RECV_NO_ENOBUFS, sk))
> flags |= NDIAG_FLAG_NO_ENOBUFS;
> - if (nlk->flags & NETLINK_F_LISTEN_ALL_NSID)
> + if (nlk_test_bit(LISTEN_ALL_NSID, sk))
> flags |= NDIAG_FLAG_LISTEN_ALL_NSID;
> - if (nlk->flags & NETLINK_F_CAP_ACK)
> + if (nlk_test_bit(CAP_ACK, sk))
> flags |= NDIAG_FLAG_CAP_ACK;
>
> return nla_put_u32(skb, NETLINK_DIAG_FLAGS, flags);
> --
> 2.41.0.640.ga95def55d0-goog
>
>
Powered by blists - more mailing lists