[<prev] [next>] [<thread-prev] [thread-next>] [day] [month] [year] [list]
Message-ID: <5f5f5fc791d23c83211e48f2e74676b14a0a0c71.camel@oracle.com>
Date: Fri, 13 Dec 2024 23:29:49 +0000
From: Allison Henderson <allison.henderson@...cle.com>
To: "horms@...nel.org" <horms@...nel.org>,
"edumazet@...gle.com"
<edumazet@...gle.com>,
"kuniyu@...zon.com" <kuniyu@...zon.com>,
"davem@...emloft.net" <davem@...emloft.net>,
"pabeni@...hat.com"
<pabeni@...hat.com>,
"kuba@...nel.org" <kuba@...nel.org>
CC: Chuck Lever III <chuck.lever@...cle.com>,
"kuni1840@...il.com"
<kuni1840@...il.com>,
"wenjia@...ux.ibm.com" <wenjia@...ux.ibm.com>,
"jaka@...ux.ibm.com" <jaka@...ux.ibm.com>,
"sfrench@...ba.org"
<sfrench@...ba.org>,
"jlayton@...nel.org" <jlayton@...nel.org>,
"netdev@...r.kernel.org" <netdev@...r.kernel.org>,
"matttbe@...nel.org"
<matttbe@...nel.org>
Subject: Re: [PATCH v3 net-next 11/15] socket: Remove kernel socket
conversion.
On Fri, 2024-12-13 at 18:21 +0900, Kuniyuki Iwashima wrote:
> Since commit 26abe14379f8 ("net: Modify sk_alloc to not reference count
> the netns of kernel sockets."), TCP kernel sockets have caused many
> use-after-free (UAF) bugs.
>
> We have converted such sockets to hold a netns refcount, and the same
> conversion sequence is duplicated in cifs, mptcp, rds, smc, and sunrpc.
>
> Let's drop the conversion and use sock_create_net() instead.
>
> The changes for cifs, mptcp, and smc are straightforward.
>
> For rds, we need to move maybe_get_net() before sock_create_net() and
> sock->ops->accept().
>
> For sunrpc, we call sock_create_net() for IPPROTO_TCP only and still
> call sock_create_kern() for others.
>
> Signed-off-by: Kuniyuki Iwashima <kuniyu@...zon.com>
> Acked-by: Matthieu Baerts (NGI0) <matttbe@...nel.org>
> Acked-by: Allison Henderson <allison.henderson@...cle.com>
> ---
> v3: Add missing mutex_unlock in rds_tcp_conn_path_connect().
> v2: Collect Acked-by from MPTCP and RDS maintainers
>
> Cc: Steve French <sfrench@...ba.org>
> Cc: Wenjia Zhang <wenjia@...ux.ibm.com>
> Cc: Jan Karcher <jaka@...ux.ibm.com>
> Cc: Chuck Lever <chuck.lever@...cle.com>
> Cc: Jeff Layton <jlayton@...nel.org>
> ---
> fs/smb/client/connect.c | 13 ++-----------
> net/mptcp/subflow.c | 10 +---------
> net/rds/tcp.c | 14 --------------
> net/rds/tcp_connect.c | 21 +++++++++++++++------
> net/rds/tcp_listen.c | 14 ++++++++++++--
> net/smc/af_smc.c | 21 ++-------------------
> net/sunrpc/svcsock.c | 12 ++++++------
> net/sunrpc/xprtsock.c | 12 ++++--------
> 8 files changed, 42 insertions(+), 75 deletions(-)
>
> diff --git a/fs/smb/client/connect.c b/fs/smb/client/connect.c
> index c36c1b4ffe6e..7a67b86c0423 100644
> --- a/fs/smb/client/connect.c
> +++ b/fs/smb/client/connect.c
> @@ -3130,22 +3130,13 @@ generic_ip_connect(struct TCP_Server_Info *server)
> if (server->ssocket) {
> socket = server->ssocket;
> } else {
> - struct net *net = cifs_net_ns(server);
> - struct sock *sk;
> -
> - rc = sock_create_kern(net, sfamily, SOCK_STREAM,
> - IPPROTO_TCP, &server->ssocket);
> + rc = sock_create_net(cifs_net_ns(server), sfamily, SOCK_STREAM,
> + IPPROTO_TCP, &server->ssocket);
> if (rc < 0) {
> cifs_server_dbg(VFS, "Error %d creating socket\n", rc);
> return rc;
> }
>
> - sk = server->ssocket->sk;
> - __netns_tracker_free(net, &sk->ns_tracker, false);
> - sk->sk_net_refcnt = 1;
> - get_net_track(net, &sk->ns_tracker, GFP_KERNEL);
> - sock_inuse_add(net, 1);
> -
> /* BB other socket options to set KEEPALIVE, NODELAY? */
> cifs_dbg(FYI, "Socket created\n");
> socket = server->ssocket;
> diff --git a/net/mptcp/subflow.c b/net/mptcp/subflow.c
> index fd021cf8286e..e7e8972bdfca 100644
> --- a/net/mptcp/subflow.c
> +++ b/net/mptcp/subflow.c
> @@ -1755,7 +1755,7 @@ int mptcp_subflow_create_socket(struct sock *sk, unsigned short family,
> if (unlikely(!sk->sk_socket))
> return -EINVAL;
>
> - err = sock_create_kern(net, family, SOCK_STREAM, IPPROTO_TCP, &sf);
> + err = sock_create_net(net, family, SOCK_STREAM, IPPROTO_TCP, &sf);
> if (err)
> return err;
>
> @@ -1768,14 +1768,6 @@ int mptcp_subflow_create_socket(struct sock *sk, unsigned short family,
> /* the newly created socket has to be in the same cgroup as its parent */
> mptcp_attach_cgroup(sk, sf->sk);
>
> - /* kernel sockets do not by default acquire net ref, but TCP timer
> - * needs it.
> - * Update ns_tracker to current stack trace and refcounted tracker.
> - */
> - __netns_tracker_free(net, &sf->sk->ns_tracker, false);
> - sf->sk->sk_net_refcnt = 1;
> - get_net_track(net, &sf->sk->ns_tracker, GFP_KERNEL);
> - sock_inuse_add(net, 1);
> err = tcp_set_ulp(sf->sk, "mptcp");
> if (err)
> goto err_free;
> diff --git a/net/rds/tcp.c b/net/rds/tcp.c
> index 351ac1747224..4509900476f7 100644
> --- a/net/rds/tcp.c
> +++ b/net/rds/tcp.c
> @@ -494,21 +494,7 @@ bool rds_tcp_tune(struct socket *sock)
>
> tcp_sock_set_nodelay(sock->sk);
> lock_sock(sk);
> - /* TCP timer functions might access net namespace even after
> - * a process which created this net namespace terminated.
> - */
> - if (!sk->sk_net_refcnt) {
> - if (!maybe_get_net(net)) {
> - release_sock(sk);
> - return false;
> - }
> - /* Update ns_tracker to current stack trace and refcounted tracker */
> - __netns_tracker_free(net, &sk->ns_tracker, false);
>
> - sk->sk_net_refcnt = 1;
> - netns_tracker_alloc(net, &sk->ns_tracker, GFP_KERNEL);
> - sock_inuse_add(net, 1);
> - }
> rtn = net_generic(net, rds_tcp_netid);
> if (rtn->sndbuf_size > 0) {
> sk->sk_sndbuf = rtn->sndbuf_size;
> diff --git a/net/rds/tcp_connect.c b/net/rds/tcp_connect.c
> index a0046e99d6df..c9449780f952 100644
> --- a/net/rds/tcp_connect.c
> +++ b/net/rds/tcp_connect.c
> @@ -93,6 +93,7 @@ int rds_tcp_conn_path_connect(struct rds_conn_path *cp)
> struct sockaddr_in6 sin6;
> struct sockaddr_in sin;
> struct sockaddr *addr;
> + struct net *net;
> int addrlen;
> bool isv6;
> int ret;
> @@ -107,20 +108,28 @@ int rds_tcp_conn_path_connect(struct rds_conn_path *cp)
>
> mutex_lock(&tc->t_conn_path_lock);
>
> + net = rds_conn_net(conn);
> +
> if (rds_conn_path_up(cp)) {
> - mutex_unlock(&tc->t_conn_path_lock);
> - return 0;
> + ret = 0;
> + goto out;
> }
> +
> + if (!maybe_get_net(net)) {
> + ret = -EINVAL;
> + goto out;
> + }
Ok, this looks much better. Thank you!
Allison
> +
> if (ipv6_addr_v4mapped(&conn->c_laddr)) {
> - ret = sock_create_kern(rds_conn_net(conn), PF_INET,
> - SOCK_STREAM, IPPROTO_TCP, &sock);
> + ret = sock_create_net(net, PF_INET, SOCK_STREAM, IPPROTO_TCP, &sock);
> isv6 = false;
> } else {
> - ret = sock_create_kern(rds_conn_net(conn), PF_INET6,
> - SOCK_STREAM, IPPROTO_TCP, &sock);
> + ret = sock_create_net(net, PF_INET6, SOCK_STREAM, IPPROTO_TCP, &sock);
> isv6 = true;
> }
>
> + put_net(net);
> +
> if (ret < 0)
> goto out;
>
> diff --git a/net/rds/tcp_listen.c b/net/rds/tcp_listen.c
> index 69aaf03ab93e..440ac9057148 100644
> --- a/net/rds/tcp_listen.c
> +++ b/net/rds/tcp_listen.c
> @@ -101,6 +101,7 @@ int rds_tcp_accept_one(struct socket *sock)
> struct rds_connection *conn;
> int ret;
> struct inet_sock *inet;
> + struct net *net;
> struct rds_tcp_connection *rs_tcp = NULL;
> int conn_state;
> struct rds_conn_path *cp;
> @@ -108,7 +109,7 @@ int rds_tcp_accept_one(struct socket *sock)
> struct proto_accept_arg arg = {
> .flags = O_NONBLOCK,
> .kern = true,
> - .hold_net = false,
> + .hold_net = true,
> };
> #if !IS_ENABLED(CONFIG_IPV6)
> struct in6_addr saddr, daddr;
> @@ -118,13 +119,22 @@ int rds_tcp_accept_one(struct socket *sock)
> if (!sock) /* module unload or netns delete in progress */
> return -ENETUNREACH;
>
> + net = sock_net(sock->sk);
> +
> + if (!maybe_get_net(net))
> + return -EINVAL;
> +
> ret = sock_create_lite(sock->sk->sk_family,
> sock->sk->sk_type, sock->sk->sk_protocol,
> &new_sock);
> - if (ret)
> + if (ret) {
> + put_net(net);
> goto out;
> + }
>
> ret = sock->ops->accept(sock, new_sock, &arg);
> + put_net(net);
> +
> if (ret < 0)
> goto out;
>
> diff --git a/net/smc/af_smc.c b/net/smc/af_smc.c
> index 6e93f188a908..7b0de80b3aca 100644
> --- a/net/smc/af_smc.c
> +++ b/net/smc/af_smc.c
> @@ -3310,25 +3310,8 @@ static const struct proto_ops smc_sock_ops = {
>
> int smc_create_clcsk(struct net *net, struct sock *sk, int family)
> {
> - struct smc_sock *smc = smc_sk(sk);
> - int rc;
> -
> - rc = sock_create_kern(net, family, SOCK_STREAM, IPPROTO_TCP,
> - &smc->clcsock);
> - if (rc)
> - return rc;
> -
> - /* smc_clcsock_release() does not wait smc->clcsock->sk's
> - * destruction; its sk_state might not be TCP_CLOSE after
> - * smc->sk is close()d, and TCP timers can be fired later,
> - * which need net ref.
> - */
> - sk = smc->clcsock->sk;
> - __netns_tracker_free(net, &sk->ns_tracker, false);
> - sk->sk_net_refcnt = 1;
> - get_net_track(net, &sk->ns_tracker, GFP_KERNEL);
> - sock_inuse_add(net, 1);
> - return 0;
> + return sock_create_net(net, family, SOCK_STREAM, IPPROTO_TCP,
> + &smc_sk(sk)->clcsock);
> }
>
> static int __smc_create(struct net *net, struct socket *sock, int protocol,
> diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c
> index 9583bad3d150..cde5765f6f81 100644
> --- a/net/sunrpc/svcsock.c
> +++ b/net/sunrpc/svcsock.c
> @@ -1526,7 +1526,10 @@ static struct svc_xprt *svc_create_socket(struct svc_serv *serv,
> return ERR_PTR(-EINVAL);
> }
>
> - error = sock_create_kern(net, family, type, protocol, &sock);
> + if (protocol == IPPROTO_TCP)
> + error = sock_create_net(net, family, type, protocol, &sock);
> + else
> + error = sock_create_kern(net, family, type, protocol, &sock);
> if (error < 0)
> return ERR_PTR(error);
>
> @@ -1551,11 +1554,8 @@ static struct svc_xprt *svc_create_socket(struct svc_serv *serv,
> newlen = error;
>
> if (protocol == IPPROTO_TCP) {
> - __netns_tracker_free(net, &sock->sk->ns_tracker, false);
> - sock->sk->sk_net_refcnt = 1;
> - get_net_track(net, &sock->sk->ns_tracker, GFP_KERNEL);
> - sock_inuse_add(net, 1);
> - if ((error = kernel_listen(sock, 64)) < 0)
> + error = kernel_listen(sock, 64);
> + if (error < 0)
> goto bummer;
> }
>
> diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c
> index feb1768e8a57..f3e139c30442 100644
> --- a/net/sunrpc/xprtsock.c
> +++ b/net/sunrpc/xprtsock.c
> @@ -1924,7 +1924,10 @@ static struct socket *xs_create_sock(struct rpc_xprt *xprt,
> struct socket *sock;
> int err;
>
> - err = sock_create_kern(xprt->xprt_net, family, type, protocol, &sock);
> + if (protocol == IPPROTO_TCP)
> + err = sock_create_net(xprt->xprt_net, family, type, protocol, &sock);
> + else
> + err = sock_create_kern(xprt->xprt_net, family, type, protocol, &sock);
> if (err < 0) {
> dprintk("RPC: can't create %d transport socket (%d).\n",
> protocol, -err);
> @@ -1941,13 +1944,6 @@ static struct socket *xs_create_sock(struct rpc_xprt *xprt,
> goto out;
> }
>
> - if (protocol == IPPROTO_TCP) {
> - __netns_tracker_free(xprt->xprt_net, &sock->sk->ns_tracker, false);
> - sock->sk->sk_net_refcnt = 1;
> - get_net_track(xprt->xprt_net, &sock->sk->ns_tracker, GFP_KERNEL);
> - sock_inuse_add(xprt->xprt_net, 1);
> - }
> -
> filp = sock_alloc_file(sock, O_NONBLOCK, NULL);
> if (IS_ERR(filp))
> return ERR_CAST(filp);
Powered by blists - more mailing lists