--- a/net/l2tp/l2tp_core.c +++ b/net/l2tp/l2tp_core.c @@ -202,6 +202,10 @@ struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id) rcu_read_lock_bh(); list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) { if (tunnel->tunnel_id == tunnel_id) { + if (test_bit(0, &tunnel->dead)) { + rcu_read_unlock_bh(); + return NULL; + } l2tp_tunnel_inc_refcount(tunnel); rcu_read_unlock_bh(); @@ -378,6 +382,10 @@ struct l2tp_tunnel *l2tp_tunnel_find(const struct net *net, u32 tunnel_id) rcu_read_lock_bh(); list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) { if (tunnel->tunnel_id == tunnel_id) { + if (test_bit(0, &tunnel->dead)) { + rcu_read_unlock_bh(); + return NULL; + } rcu_read_unlock_bh(); return tunnel; } @@ -396,6 +404,8 @@ struct l2tp_tunnel *l2tp_tunnel_find_nth(const struct net *net, int nth) rcu_read_lock_bh(); list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) { + if (test_bit(0, &tunnel->dead)) + continue; if (++count > nth) { rcu_read_unlock_bh(); return tunnel; @@ -1214,7 +1224,6 @@ static void l2tp_tunnel_destruct(struct sock *sk) l2tp_info(tunnel, L2TP_MSG_CONTROL, "%s: closing...\n", tunnel->name); - /* Disable udp encapsulation */ switch (tunnel->encap) { case L2TP_ENCAPTYPE_UDP: @@ -1235,6 +1244,7 @@ static void l2tp_tunnel_destruct(struct sock *sk) pn = l2tp_pernet(tunnel->l2tp_net); spin_lock_bh(&pn->l2tp_tunnel_list_lock); list_del_rcu(&tunnel->list); + set_bit(0, &tunnel->dead); spin_unlock_bh(&pn->l2tp_tunnel_list_lock); tunnel->sock = NULL; @@ -1327,17 +1337,10 @@ static void l2tp_tunnel_del_work(struct work_struct *work) sock = sk->sk_socket; - /* If the tunnel socket was created by userspace, then go through the - * inet layer to shut the socket down, and let userspace close it. - * Otherwise, if we created the socket directly within the kernel, use + /* If the tunnel socket was created within the kernel, use * the sk API to release it here. - * In either case the tunnel resources are freed in the socket - * destructor when the tunnel socket goes away. */ - if (tunnel->fd >= 0) { - if (sock) - inet_shutdown(sock, 2); - } else { + if (tunnel->fd < 0) { if (sock) { kernel_sock_shutdown(sock, SHUT_RDWR); sock_release(sock); diff --git a/net/l2tp/l2tp_ppp.c b/net/l2tp/l2tp_ppp.c index 59f246d7b290..6c71f360828d 100644 --- a/net/l2tp/l2tp_ppp.c +++ b/net/l2tp/l2tp_ppp.c @@ -416,22 +416,6 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb) * Session (and tunnel control) socket create/destroy. *****************************************************************************/ -/* Called by l2tp_core when a session socket is being closed. - */ -static void pppol2tp_session_close(struct l2tp_session *session) -{ - struct sock *sk; - - BUG_ON(session->magic != L2TP_SESSION_MAGIC); - - sk = pppol2tp_session_get_sock(session); - if (sk) { - if (sk->sk_socket) - inet_shutdown(sk->sk_socket, SEND_SHUTDOWN); - sock_put(sk); - } -} - /* Really kill the session socket. (Called from sock_put() if * refcnt == 0.) */ @@ -489,16 +473,19 @@ static int pppol2tp_release(struct socket *sock) ps = l2tp_session_priv(session); mutex_lock(&ps->sk_lock); - ps->__sk = rcu_dereference_protected(ps->sk, - lockdep_is_held(&ps->sk_lock)); - RCU_INIT_POINTER(ps->sk, NULL); - mutex_unlock(&ps->sk_lock); - call_rcu(&ps->rcu, pppol2tp_put_sk); + if (!ps->__sk) { + ps->__sk = rcu_dereference_protected(ps->sk, + lockdep_is_held(&ps->sk_lock)); + RCU_INIT_POINTER(ps->sk, NULL); + mutex_unlock(&ps->sk_lock); + call_rcu(&ps->rcu, pppol2tp_put_sk); - /* Rely on the sock_put() call at the end of the function for - * dropping the reference held by pppol2tp_sock_to_session(). - * The last reference will be dropped by pppol2tp_put_sk(). - */ + /* Rely on the sock_put() call at the end of the function for + * dropping the reference held by pppol2tp_sock_to_session(). + * The last reference will be dropped by pppol2tp_put_sk(). + */ + } else + mutex_unlock(&ps->sk_lock); } release_sock(sk); @@ -583,7 +570,6 @@ static void pppol2tp_session_init(struct l2tp_session *session) struct dst_entry *dst; session->recv_skb = pppol2tp_recv; - session->session_close = pppol2tp_session_close; #if IS_ENABLED(CONFIG_L2TP_DEBUGFS) session->show = pppol2tp_show; #endif @@ -610,6 +596,7 @@ static void pppol2tp_session_init(struct l2tp_session *session) static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr, int sockaddr_len, int flags) { + struct socket *tsock = NULL; struct sock *sk = sock->sk; struct sockaddr_pppol2tp *sp = (struct sockaddr_pppol2tp *) uservaddr; struct pppox_sock *po = pppox_sk(sk); @@ -690,6 +677,14 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr, if (tunnel_id == 0) goto end; + /* Check that the fd is a valid socket and prevent it closing + * while we are processing this connect. + */ + tsock = sockfd_lookup(fd, &error); + if (!tsock) + goto end; + sock_hold(tsock->sk); + tunnel = l2tp_tunnel_get(sock_net(sk), tunnel_id); if (tunnel) drop_tunnel = true; @@ -719,6 +714,14 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr, goto end; } + read_lock_bh(&tunnel->hlist_lock); + if (!tunnel->acpt_newsess) { + error = -EBUSY; + read_unlock_bh(&tunnel->hlist_lock); + goto end; + } + read_unlock_bh(&tunnel->hlist_lock); + if (tunnel->recv_payload_hook == NULL) tunnel->recv_payload_hook = pppol2tp_recv_payload_hook; @@ -815,6 +818,11 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr, l2tp_session_dec_refcount(session); if (drop_tunnel) l2tp_tunnel_dec_refcount(tunnel); + if (tsock) { + sock_put(tsock->sk); + sockfd_put(tsock); + } + release_sock(sk); return error;