--- a/net/l2tp/l2tp_core.c +++ b/net/l2tp/l2tp_core.c @@ -155,8 +155,10 @@ static struct sock *l2tp_tunnel_sock_lookup(struct l2tp_tunnel *tunnel) * consistency. */ sock = sockfd_lookup(tunnel->fd, &err); - if (sock) + if (sock) { sk = sock->sk; + sock_hold(sk); + } } else { /* Socket is owned by kernelspace */ sk = tunnel->sock; @@ -202,6 +204,10 @@ struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id) rcu_read_lock_bh(); list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) { if (tunnel->tunnel_id == tunnel_id) { + if (test_bit(0, &tunnel->dead)) { + rcu_read_unlock_bh(); + return NULL; + } l2tp_tunnel_inc_refcount(tunnel); rcu_read_unlock_bh(); @@ -378,6 +384,10 @@ struct l2tp_tunnel *l2tp_tunnel_find(const struct net *net, u32 tunnel_id) rcu_read_lock_bh(); list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) { if (tunnel->tunnel_id == tunnel_id) { + if (test_bit(0, &tunnel->dead)) { + rcu_read_unlock_bh(); + return NULL; + } rcu_read_unlock_bh(); return tunnel; } @@ -396,6 +406,8 @@ struct l2tp_tunnel *l2tp_tunnel_find_nth(const struct net *net, int nth) rcu_read_lock_bh(); list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) { + if (test_bit(0, &tunnel->dead)) + continue; if (++count > nth) { rcu_read_unlock_bh(); return tunnel; @@ -1214,7 +1226,6 @@ static void l2tp_tunnel_destruct(struct sock *sk) l2tp_info(tunnel, L2TP_MSG_CONTROL, "%s: closing...\n", tunnel->name); - /* Disable udp encapsulation */ switch (tunnel->encap) { case L2TP_ENCAPTYPE_UDP: @@ -1235,6 +1246,7 @@ static void l2tp_tunnel_destruct(struct sock *sk) pn = l2tp_pernet(tunnel->l2tp_net); spin_lock_bh(&pn->l2tp_tunnel_list_lock); list_del_rcu(&tunnel->list); + set_bit(0, &tunnel->dead); spin_unlock_bh(&pn->l2tp_tunnel_list_lock); tunnel->sock = NULL; @@ -1327,17 +1339,10 @@ static void l2tp_tunnel_del_work(struct work_struct *work) sock = sk->sk_socket; - /* If the tunnel socket was created by userspace, then go through the - * inet layer to shut the socket down, and let userspace close it. - * Otherwise, if we created the socket directly within the kernel, use + /* If the tunnel socket was created within the kernel, use * the sk API to release it here. - * In either case the tunnel resources are freed in the socket - * destructor when the tunnel socket goes away. */ - if (tunnel->fd >= 0) { - if (sock) - inet_shutdown(sock, 2); - } else { + if (tunnel->fd < 0) { if (sock) { kernel_sock_shutdown(sock, SHUT_RDWR); sock_release(sock); diff --git a/net/l2tp/l2tp_ppp.c b/net/l2tp/l2tp_ppp.c index 59f246d7b290..b46ecdf1dd93 100644 --- a/net/l2tp/l2tp_ppp.c +++ b/net/l2tp/l2tp_ppp.c @@ -139,6 +139,8 @@ struct pppol2tp_session { static const struct proto_ops pppol2tp_ops; +static DEFINE_MUTEX(pppol2tp_lock); + /* Retrieves the pppol2tp socket associated to a session. * A reference is held on the returned socket, so this function must be paired * with sock_put(). @@ -416,22 +418,6 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb) * Session (and tunnel control) socket create/destroy. *****************************************************************************/ -/* Called by l2tp_core when a session socket is being closed. - */ -static void pppol2tp_session_close(struct l2tp_session *session) -{ - struct sock *sk; - - BUG_ON(session->magic != L2TP_SESSION_MAGIC); - - sk = pppol2tp_session_get_sock(session); - if (sk) { - if (sk->sk_socket) - inet_shutdown(sk->sk_socket, SEND_SHUTDOWN); - sock_put(sk); - } -} - /* Really kill the session socket. (Called from sock_put() if * refcnt == 0.) */ @@ -468,6 +454,8 @@ static int pppol2tp_release(struct socket *sock) if (!sk) return 0; + mutex_lock(&pppol2tp_lock); + error = -EBADF; lock_sock(sk); if (sock_flag(sk, SOCK_DEAD) != 0) @@ -508,10 +496,12 @@ static int pppol2tp_release(struct socket *sock) */ sock_put(sk); + mutex_unlock(&pppol2tp_lock); return 0; error: release_sock(sk); + mutex_unlock(&pppol2tp_lock); return error; } @@ -583,7 +573,6 @@ static void pppol2tp_session_init(struct l2tp_session *session) struct dst_entry *dst; session->recv_skb = pppol2tp_recv; - session->session_close = pppol2tp_session_close; #if IS_ENABLED(CONFIG_L2TP_DEBUGFS) session->show = pppol2tp_show; #endif @@ -610,6 +599,7 @@ static void pppol2tp_session_init(struct l2tp_session *session) static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr, int sockaddr_len, int flags) { + struct socket *tsock = NULL; struct sock *sk = sock->sk; struct sockaddr_pppol2tp *sp = (struct sockaddr_pppol2tp *) uservaddr; struct pppox_sock *po = pppox_sk(sk); @@ -625,6 +615,8 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr, int ver = 2; int fd; + mutex_lock(&pppol2tp_lock); + lock_sock(sk); error = -EINVAL; @@ -690,6 +682,13 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr, if (tunnel_id == 0) goto end; + /* Check that the fd is a valid socket and prevent it closing + * while we are processing this connect. + */ + tsock = sockfd_lookup(fd, &error); + if (!tsock) + goto end; + tunnel = l2tp_tunnel_get(sock_net(sk), tunnel_id); if (tunnel) drop_tunnel = true; @@ -719,6 +718,14 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr, goto end; } + read_lock_bh(&tunnel->hlist_lock); + if (!tunnel->acpt_newsess) { + error = -EBUSY; + read_unlock_bh(&tunnel->hlist_lock); + goto end; + } + read_unlock_bh(&tunnel->hlist_lock); + if (tunnel->recv_payload_hook == NULL) tunnel->recv_payload_hook = pppol2tp_recv_payload_hook; @@ -815,7 +822,11 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr, l2tp_session_dec_refcount(session); if (drop_tunnel) l2tp_tunnel_dec_refcount(tunnel); + if (tsock) + sockfd_put(tsock); + release_sock(sk); + mutex_unlock(&pppol2tp_lock); return error; } @@ -1215,6 +1226,8 @@ static int pppol2tp_ioctl(struct socket *sock, unsigned int cmd, if (!sk) return 0; + mutex_lock(&pppol2tp_lock); + err = -EBADF; if (sock_flag(sk, SOCK_DEAD) != 0) goto end; @@ -1245,6 +1258,7 @@ static int pppol2tp_ioctl(struct socket *sock, unsigned int cmd, end_put_sess: sock_put(sk); end: + mutex_unlock(&pppol2tp_lock); return err; } @@ -1374,6 +1388,8 @@ static int pppol2tp_setsockopt(struct socket *sock, int level, int optname, if (get_user(val, (int __user *)optval)) return -EFAULT; + mutex_lock(&pppol2tp_lock); + err = -ENOTCONN; if (sk->sk_user_data == NULL) goto end; @@ -1396,6 +1412,7 @@ static int pppol2tp_setsockopt(struct socket *sock, int level, int optname, sock_put(sk); end: + mutex_unlock(&pppol2tp_lock); return err; }