--- a/net/l2tp/l2tp_core.c +++ b/net/l2tp/l2tp_core.c @@ -136,51 +136,6 @@ static inline struct l2tp_net *l2tp_pernet(const struct net *net) } -/* Lookup the tunnel socket, possibly involving the fs code if the socket is - * owned by userspace. A struct sock returned from this function must be - * released using l2tp_tunnel_sock_put once you're done with it. - */ -static struct sock *l2tp_tunnel_sock_lookup(struct l2tp_tunnel *tunnel) -{ - int err = 0; - struct socket *sock = NULL; - struct sock *sk = NULL; - - if (!tunnel) - goto out; - - if (tunnel->fd >= 0) { - /* Socket is owned by userspace, who might be in the process - * of closing it. Look the socket up using the fd to ensure - * consistency. - */ - sock = sockfd_lookup(tunnel->fd, &err); - if (sock) - sk = sock->sk; - } else { - /* Socket is owned by kernelspace */ - sk = tunnel->sock; - sock_hold(sk); - } - -out: - return sk; -} - -/* Drop a reference to a tunnel socket obtained via. l2tp_tunnel_sock_put */ -static void l2tp_tunnel_sock_put(struct sock *sk) -{ - struct l2tp_tunnel *tunnel = l2tp_sock_to_tunnel(sk); - if (tunnel) { - if (tunnel->fd >= 0) { - /* Socket is owned by userspace */ - sockfd_put(sk->sk_socket); - } - sock_put(sk); - } - sock_put(sk); -} - /* Session hash list. * The session_id SHOULD be random according to RFC2661, but several * L2TP implementations (Cisco and Microsoft) use incrementing @@ -193,6 +148,12 @@ static void l2tp_tunnel_sock_put(struct sock *sk) return &tunnel->session_hlist[hash_32(session_id, L2TP_HASH_BITS)]; } +void l2tp_tunnel_free(struct l2tp_tunnel *tunnel) +{ + sock_put(tunnel->sock); + /* drops the sk ref taken in l2tp_tunnel_create(); the tunnel itself is freed by l2tp_tunnel_destruct() */ +} + /* Lookup a tunnel. A new reference is held on the returned tunnel.
*/ struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id) { @@ -202,6 +163,13 @@ struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id) rcu_read_lock_bh(); list_for_each_entry_rcu(tunnel, &pn->l2tp_tunnel_list, list) { if (tunnel->tunnel_id == tunnel_id) { + spin_lock_bh(&tunnel->lock); /* held across the ref inc, as in l2tp_session_get() */ + if (tunnel->closing) { + spin_unlock_bh(&tunnel->lock); + rcu_read_unlock_bh(); + return NULL; + } l2tp_tunnel_inc_refcount(tunnel); + spin_unlock_bh(&tunnel->lock); rcu_read_unlock_bh(); @@ -230,7 +198,14 @@ struct l2tp_session *l2tp_session_get(const struct net *net, rcu_read_lock_bh(); hlist_for_each_entry_rcu(session, session_list, global_hlist) { if (session->session_id == session_id) { + spin_lock_bh(&session->lock); + if (session->closing) { + spin_unlock_bh(&session->lock); + rcu_read_unlock_bh(); + return NULL; + } l2tp_session_inc_refcount(session); + spin_unlock_bh(&session->lock); rcu_read_unlock_bh(); return session; @@ -245,7 +220,14 @@ struct l2tp_session *l2tp_session_get(const struct net *net, read_lock_bh(&tunnel->hlist_lock); hlist_for_each_entry(session, session_list, hlist) { if (session->session_id == session_id) { + spin_lock_bh(&session->lock); + if (session->closing) { + spin_unlock_bh(&session->lock); + read_unlock_bh(&tunnel->hlist_lock); + return NULL; + } l2tp_session_inc_refcount(session); + spin_unlock_bh(&session->lock); read_unlock_bh(&tunnel->hlist_lock); return session; @@ -266,6 +248,12 @@ struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth) read_lock_bh(&tunnel->hlist_lock); for (hash = 0; hash < L2TP_HASH_SIZE; hash++) { hlist_for_each_entry(session, &tunnel->session_hlist[hash], hlist) { + spin_lock_bh(&session->lock); + if (session->closing) { + spin_unlock_bh(&session->lock); + continue; + } + spin_unlock_bh(&session->lock); if (++count > nth) { l2tp_session_inc_refcount(session); read_unlock_bh(&tunnel->hlist_lock); @@ -293,6 +281,12 @@ struct l2tp_session
*l2tp_session_get_by_ifname(const struct net *net, rcu_read_lock_bh(); for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) { hlist_for_each_entry_rcu(session, &pn->l2tp_session_hlist[hash], global_hlist) { if (!strcmp(session->ifname, ifname)) { + spin_lock_bh(&session->lock); /* held across the ref inc, as in l2tp_session_get() */ + if (session->closing) { + spin_unlock_bh(&session->lock); + continue; + } l2tp_session_inc_refcount(session); + spin_unlock_bh(&session->lock); rcu_read_unlock_bh(); @@ -317,13 +311,17 @@ int l2tp_session_register(struct l2tp_session *session, struct l2tp_net *pn; int err; head = l2tp_session_id_hash(tunnel, session->session_id); write_lock_bh(&tunnel->hlist_lock); - if (!tunnel->acpt_newsess) { - err = -ENODEV; - goto err_tlock; - } + spin_lock_bh(&tunnel->lock); /* checked under hlist_lock so l2tp_tunnel_closeall() cannot miss this session */ + if (tunnel->closing) { + spin_unlock_bh(&tunnel->lock); + write_unlock_bh(&tunnel->hlist_lock); + return -ENODEV; + } + l2tp_tunnel_inc_refcount(tunnel); + spin_unlock_bh(&tunnel->lock); hlist_for_each_entry(session_walk, head, hlist) if (session_walk->session_id == session->session_id) { @@ -344,14 +342,9 @@ int l2tp_session_register(struct l2tp_session *session, goto err_tlock_pnlock; } - l2tp_tunnel_inc_refcount(tunnel); - sock_hold(tunnel->sock); hlist_add_head_rcu(&session->global_hlist, g_head); spin_unlock_bh(&pn->l2tp_session_hlist_lock); - } else { - l2tp_tunnel_inc_refcount(tunnel); - sock_hold(tunnel->sock); } hlist_add_head(&session->hlist, head); @@ -363,6 +356,7 @@ int l2tp_session_register(struct l2tp_session *session, spin_unlock_bh(&pn->l2tp_session_hlist_lock); err_tlock: write_unlock_bh(&tunnel->hlist_lock); + l2tp_tunnel_dec_refcount(tunnel); return err; } @@ -969,7 +963,7 @@ int l2tp_udp_encap_recv(struct sock *sk, struct sk_buff *skb) { struct l2tp_tunnel *tunnel; - tunnel = l2tp_sock_to_tunnel(sk); + tunnel = l2tp_tunnel(sk); if (tunnel == NULL) goto pass_up; @@ -977,13 +971,10 @@ int l2tp_udp_encap_recv(struct sock *sk, struct sk_buff *skb) tunnel->name, skb->len); if (l2tp_udp_recv_core(tunnel, skb, tunnel->recv_payload_hook)) - goto
pass_up_put; + goto pass_up; - sock_put(sk); return 0; -pass_up_put: - sock_put(sk); pass_up: return 1; } @@ -1214,8 +1205,8 @@ static void l2tp_tunnel_destruct(struct sock *sk) l2tp_info(tunnel, L2TP_MSG_CONTROL, "%s: closing...\n", tunnel->name); - /* Disable udp encapsulation */ + write_lock_bh(&sk->sk_callback_lock); switch (tunnel->encap) { case L2TP_ENCAPTYPE_UDP: /* No longer an encapsulation socket. See net/ipv4/udp.c */ @@ -1229,7 +1220,8 @@ static void l2tp_tunnel_destruct(struct sock *sk) /* Remove hooks into tunnel socket */ sk->sk_destruct = tunnel->old_sk_destruct; - sk->sk_user_data = NULL; + rcu_assign_sk_user_data(sk, NULL); + write_unlock_bh(&sk->sk_callback_lock); /* Remove the tunnel struct from the tunnel list */ pn = l2tp_pernet(tunnel->l2tp_net); @@ -1237,12 +1229,11 @@ static void l2tp_tunnel_destruct(struct sock *sk) list_del_rcu(&tunnel->list); spin_unlock_bh(&pn->l2tp_tunnel_list_lock); - tunnel->sock = NULL; - l2tp_tunnel_dec_refcount(tunnel); - /* Call the original destructor */ if (sk->sk_destruct) (*sk->sk_destruct)(sk); + + kfree_rcu(tunnel, rcu); end: return; } @@ -1262,38 +1253,10 @@ void l2tp_tunnel_closeall(struct l2tp_tunnel *tunnel) tunnel->name); write_lock_bh(&tunnel->hlist_lock); - tunnel->acpt_newsess = false; for (hash = 0; hash < L2TP_HASH_SIZE; hash++) { -again: hlist_for_each_safe(walk, tmp, &tunnel->session_hlist[hash]) { session = hlist_entry(walk, struct l2tp_session, hlist); - - l2tp_info(session, L2TP_MSG_CONTROL, - "%s: closing session\n", session->name); - - hlist_del_init(&session->hlist); - - if (test_and_set_bit(0, &session->dead)) - goto again; - - write_unlock_bh(&tunnel->hlist_lock); - - __l2tp_session_unhash(session); - l2tp_session_queue_purge(session); - - if (session->session_close != NULL) - (*session->session_close)(session); - - l2tp_session_dec_refcount(session); - - write_lock_bh(&tunnel->hlist_lock); - - /* Now restart from the beginning of this hash - * chain.
We always remove a session from the - * list so we are guaranteed to make forward - * progress. - */ - goto again; + l2tp_session_delete(session); } } write_unlock_bh(&tunnel->hlist_lock); @@ -1303,30 +1266,21 @@ void l2tp_tunnel_closeall(struct l2tp_tunnel *tunnel) /* Tunnel socket destroy hook for UDP encapsulation */ static void l2tp_udp_encap_destroy(struct sock *sk) { - struct l2tp_tunnel *tunnel = l2tp_sock_to_tunnel(sk); + struct l2tp_tunnel *tunnel = l2tp_tunnel(sk); if (tunnel) { - l2tp_tunnel_closeall(tunnel); - sock_put(sk); + l2tp_tunnel_delete(tunnel); } } /* Workqueue tunnel deletion function */ static void l2tp_tunnel_del_work(struct work_struct *work) { - struct l2tp_tunnel *tunnel = NULL; - struct socket *sock = NULL; - struct sock *sk = NULL; - - tunnel = container_of(work, struct l2tp_tunnel, del_work); + struct l2tp_tunnel *tunnel = container_of(work, struct l2tp_tunnel, del_work); + struct sock *sk = tunnel->sock; + struct socket *sock = sk->sk_socket; l2tp_tunnel_closeall(tunnel); - sk = l2tp_tunnel_sock_lookup(tunnel); - if (!sk) - goto out; - - sock = sk->sk_socket; - /* If the tunnel socket was created by userspace, then go through the * inet layer to shut the socket down, and let userspace close it. * Otherwise, if we created the socket directly within the kernel, use @@ -1335,7 +1289,7 @@ static void l2tp_tunnel_del_work(struct work_struct *work) * destructor when the tunnel socket goes away.
*/ if (tunnel->fd >= 0) { - if (sock) + if (sock && sock->sk) inet_shutdown(sock, 2); } else { if (sock) { @@ -1344,8 +1298,10 @@ static void l2tp_tunnel_del_work(struct work_struct *work) } } - l2tp_tunnel_sock_put(sk); -out: + /* drop initial ref, set to 1 by l2tp_tunnel_create() */ + l2tp_tunnel_dec_refcount(tunnel); + + /* drop workqueue ref */ l2tp_tunnel_dec_refcount(tunnel); } @@ -1495,8 +1451,6 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32 } else { sock = sockfd_lookup(fd, &err); if (!sock) { - pr_err("tunl %u: sockfd_lookup(fd=%d) returned %d\n", - tunnel_id, fd, err); err = -EBADF; goto err; } @@ -1534,14 +1488,6 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32 break; } - /* Check if this socket has already been prepped */ - tunnel = l2tp_tunnel(sk); - if (tunnel != NULL) { - /* This socket has already been prepped */ - err = -EBUSY; - goto err; - } - tunnel = kzalloc(sizeof(struct l2tp_tunnel), GFP_KERNEL); if (tunnel == NULL) { err = -ENOMEM; @@ -1555,8 +1501,8 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32 tunnel->magic = L2TP_TUNNEL_MAGIC; sprintf(&tunnel->name[0], "tunl %u", tunnel_id); + spin_lock_init(&tunnel->lock); rwlock_init(&tunnel->hlist_lock); - tunnel->acpt_newsess = true; /* The net we belong to */ tunnel->l2tp_net = net; @@ -1583,6 +1529,20 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32 } #endif + /* Assign socket sk_user_data. Must be done with + * sk_callback_lock. Bail if sk_user_data is already assigned. + */ + write_lock_bh(&sk->sk_callback_lock); + if (sk->sk_user_data) { + err = -EALREADY; + write_unlock_bh(&sk->sk_callback_lock); + kfree(tunnel); + tunnel = NULL; + goto err; + } + rcu_assign_sk_user_data(sk, tunnel); + write_unlock_bh(&sk->sk_callback_lock); + /* Mark socket as an encapsulation socket.
See net/ipv4/udp.c */ tunnel->encap = encap; if (encap == L2TP_ENCAPTYPE_UDP) { @@ -1594,8 +1554,6 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32 udp_cfg.encap_destroy = l2tp_udp_encap_destroy; setup_udp_tunnel_sock(net, sock, &udp_cfg); - } else { - sk->sk_user_data = tunnel; } /* Hook on the tunnel socket destructor so that we can cleanup @@ -1603,6 +1561,7 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32 */ tunnel->old_sk_destruct = sk->sk_destruct; sk->sk_destruct = &l2tp_tunnel_destruct; + tunnel->sock = sk; tunnel->fd = fd; lockdep_set_class_and_name(&sk->sk_lock.slock, &l2tp_socket_class, "l2tp_sock"); @@ -1616,9 +1575,12 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32 INIT_LIST_HEAD(&tunnel->list); /* Bump the reference count. The tunnel context is deleted - * only when this drops to zero. Must be done before list insertion + * only when this drops to zero. A reference is also held on + * the tunnel socket to ensure that it is not released while + * the tunnel is extant.
Must be done before list insertion */ refcount_set(&tunnel->ref_count, 1); + sock_hold(sk); spin_lock_bh(&pn->l2tp_tunnel_list_lock); list_add_rcu(&tunnel->list, &pn->l2tp_tunnel_list); spin_unlock_bh(&pn->l2tp_tunnel_list_lock); @@ -1642,10 +1604,17 @@ int l2tp_tunnel_create(struct net *net, int fd, int version, u32 tunnel_id, u32 */ void l2tp_tunnel_delete(struct l2tp_tunnel *tunnel) { - if (!test_and_set_bit(0, &tunnel->dead)) { - l2tp_tunnel_inc_refcount(tunnel); - queue_work(l2tp_wq, &tunnel->del_work); + spin_lock_bh(&tunnel->lock); + if (tunnel->closing) { + spin_unlock_bh(&tunnel->lock); + return; } + tunnel->closing = true; + spin_unlock_bh(&tunnel->lock); + + /* Hold tunnel ref while queued work item is pending */ + l2tp_tunnel_inc_refcount(tunnel); + queue_work(l2tp_wq, &tunnel->del_work); } EXPORT_SYMBOL_GPL(l2tp_tunnel_delete); @@ -1657,14 +1626,15 @@ void l2tp_session_free(struct l2tp_session *session) BUG_ON(refcount_read(&session->ref_count) != 0); + if (session->session_free) + session->session_free(session); + else + kfree(session); + if (tunnel) { BUG_ON(tunnel->magic != L2TP_TUNNEL_MAGIC); - sock_put(tunnel->sock); - session->tunnel = NULL; l2tp_tunnel_dec_refcount(tunnel); } - - kfree(session); } EXPORT_SYMBOL_GPL(l2tp_session_free); @@ -1673,7 +1643,7 @@ void l2tp_session_free(struct l2tp_session *session) * shutdown via. l2tp_session_delete and a pseudowire-specific session_close * callback. */ -void __l2tp_session_unhash(struct l2tp_session *session) +static void l2tp_session_unhash(struct l2tp_session *session) { struct l2tp_tunnel *tunnel = session->tunnel; @@ -1694,23 +1664,43 @@ void __l2tp_session_unhash(struct l2tp_session *session) } } } -EXPORT_SYMBOL_GPL(__l2tp_session_unhash); -/* This function is used by the netlink SESSION_DELETE command and by - pseudowire modules.
- */ -int l2tp_session_delete(struct l2tp_session *session) +/* Workqueue session deletion function */ +static void l2tp_session_del_work(struct work_struct *work) { - if (test_and_set_bit(0, &session->dead)) - return 0; + struct l2tp_session *session = container_of(work, struct l2tp_session, del_work); - __l2tp_session_unhash(session); + l2tp_info(session, L2TP_MSG_CONTROL, + "%s: closing session\n", session->name); + + l2tp_session_unhash(session); l2tp_session_queue_purge(session); if (session->session_close != NULL) (*session->session_close)(session); + /* drop initial ref */ + l2tp_session_dec_refcount(session); + + /* drop workqueue ref */ l2tp_session_dec_refcount(session); +} + +/* This function is used by the netlink SESSION_DELETE command and by + pseudowire modules. + */ +int l2tp_session_delete(struct l2tp_session *session) +{ + spin_lock_bh(&session->lock); + if (session->closing) { + spin_unlock_bh(&session->lock); + return 0; + } + session->closing = true; + spin_unlock_bh(&session->lock); + /* Hold session ref while queued work item is pending */ + l2tp_session_inc_refcount(session); + queue_work(l2tp_wq, &session->del_work); return 0; } EXPORT_SYMBOL_GPL(l2tp_session_delete); @@ -1738,6 +1728,13 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn { struct l2tp_session *session; + spin_lock_bh(&tunnel->lock); + if (tunnel->closing) { + spin_unlock_bh(&tunnel->lock); + return ERR_PTR(-ENODEV); + } + spin_unlock_bh(&tunnel->lock); + session = kzalloc(sizeof(struct l2tp_session) + priv_size, GFP_KERNEL); if (session != NULL) { session->magic = L2TP_SESSION_MAGIC; @@ -1763,6 +1760,9 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn INIT_HLIST_NODE(&session->hlist); INIT_HLIST_NODE(&session->global_hlist); + spin_lock_init(&session->lock); + + INIT_WORK(&session->del_work, l2tp_session_del_work); /* Inherit debug options from tunnel */ session->debug = tunnel->debug; diff --git
a/net/l2tp/l2tp_core.h b/net/l2tp/l2tp_core.h index 9bbee90e9963..9a194f316751 100644 --- a/net/l2tp/l2tp_core.h +++ b/net/l2tp/l2tp_core.h @@ -74,7 +74,8 @@ struct l2tp_session_cfg { struct l2tp_session { int magic; /* should be * L2TP_SESSION_MAGIC */ - long dead; + bool closing; /* set once deletion is queued; protected by lock */ + spinlock_t lock; struct l2tp_tunnel *tunnel; /* back pointer to tunnel * context */ @@ -121,9 +122,12 @@ struct l2tp_session { struct l2tp_stats stats; struct hlist_node global_hlist; /* Global hash list node */ + struct work_struct del_work; + int (*build_header)(struct l2tp_session *session, void *buf); void (*recv_skb)(struct l2tp_session *session, struct sk_buff *skb, int data_len); void (*session_close)(struct l2tp_session *session); + void (*session_free)(struct l2tp_session *session); #if IS_ENABLED(CONFIG_L2TP_DEBUGFS) void (*show)(struct seq_file *m, void *priv); #endif @@ -155,14 +159,11 @@ struct l2tp_tunnel_cfg { struct l2tp_tunnel { int magic; /* Should be L2TP_TUNNEL_MAGIC */ - unsigned long dead; + bool closing; /* set once deletion is queued; protected by lock */ + spinlock_t lock; struct rcu_head rcu; rwlock_t hlist_lock; /* protect session_hlist */ - bool acpt_newsess; /* Indicates whether this - * tunnel accepts new sessions. - * Protected by hlist_lock.
- */ struct hlist_head session_hlist[L2TP_HASH_SIZE]; /* hashed list of sessions, * hashed by id */ @@ -214,27 +215,8 @@ static inline void *l2tp_session_priv(struct l2tp_session *session) return &session->priv[0]; } -static inline struct l2tp_tunnel *l2tp_sock_to_tunnel(struct sock *sk) -{ - struct l2tp_tunnel *tunnel; - - if (sk == NULL) - return NULL; - - sock_hold(sk); - tunnel = (struct l2tp_tunnel *)(sk->sk_user_data); - if (tunnel == NULL) { - sock_put(sk); - goto out; - } - - BUG_ON(tunnel->magic != L2TP_TUNNEL_MAGIC); - -out: - return tunnel; -} - struct l2tp_tunnel *l2tp_tunnel_get(const struct net *net, u32 tunnel_id); +void l2tp_tunnel_free(struct l2tp_tunnel *tunnel); struct l2tp_session *l2tp_session_get(const struct net *net, struct l2tp_tunnel *tunnel, @@ -257,7 +239,6 @@ struct l2tp_session *l2tp_session_create(int priv_size, int l2tp_session_register(struct l2tp_session *session, struct l2tp_tunnel *tunnel); -void __l2tp_session_unhash(struct l2tp_session *session); int l2tp_session_delete(struct l2tp_session *session); void l2tp_session_free(struct l2tp_session *session); void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb, @@ -283,7 +264,7 @@ static inline void l2tp_tunnel_inc_refcount(struct l2tp_tunnel *tunnel) static inline void l2tp_tunnel_dec_refcount(struct l2tp_tunnel *tunnel) { if (refcount_dec_and_test(&tunnel->ref_count)) - kfree_rcu(tunnel, rcu); + l2tp_tunnel_free(tunnel); } /* Session reference counts.
Incremented when code obtains a reference diff --git a/net/l2tp/l2tp_ip.c b/net/l2tp/l2tp_ip.c index ff61124fdf59..a5591bd2fa24 100644 --- a/net/l2tp/l2tp_ip.c +++ b/net/l2tp/l2tp_ip.c @@ -234,17 +234,19 @@ static void l2tp_ip_close(struct sock *sk, long timeout) static void l2tp_ip_destroy_sock(struct sock *sk) { struct sk_buff *skb; - struct l2tp_tunnel *tunnel = l2tp_sock_to_tunnel(sk); + struct l2tp_tunnel *tunnel; while ((skb = __skb_dequeue_tail(&sk->sk_write_queue)) != NULL) kfree_skb(skb); + rcu_read_lock(); + tunnel = rcu_dereference_sk_user_data(sk); if (tunnel) { - l2tp_tunnel_closeall(tunnel); - sock_put(sk); + l2tp_tunnel_delete(tunnel); } + rcu_read_unlock(); sk_refcnt_debug_dec(sk); } static int l2tp_ip_bind(struct sock *sk, struct sockaddr *uaddr, int addr_len) diff --git a/net/l2tp/l2tp_ip6.c b/net/l2tp/l2tp_ip6.c index 192344688c06..de8e7eb7a638 100644 --- a/net/l2tp/l2tp_ip6.c +++ b/net/l2tp/l2tp_ip6.c @@ -248,16 +248,18 @@ static void l2tp_ip6_close(struct sock *sk, long timeout) static void l2tp_ip6_destroy_sock(struct sock *sk) { - struct l2tp_tunnel *tunnel = l2tp_sock_to_tunnel(sk); + struct l2tp_tunnel *tunnel; lock_sock(sk); ip6_flush_pending_frames(sk); release_sock(sk); + rcu_read_lock(); + tunnel = rcu_dereference_sk_user_data(sk); if (tunnel) { - l2tp_tunnel_closeall(tunnel); - sock_put(sk); + l2tp_tunnel_delete(tunnel); } + rcu_read_unlock(); inet6_destroy_sock(sk); } diff --git a/net/l2tp/l2tp_ppp.c b/net/l2tp/l2tp_ppp.c index 59f246d7b290..1fb4f5264917 100644 --- a/net/l2tp/l2tp_ppp.c +++ b/net/l2tp/l2tp_ppp.c @@ -166,16 +166,25 @@ static inline struct l2tp_session *pppol2tp_sock_to_session(struct sock *sk) if (sk == NULL) return NULL; - sock_hold(sk); - session = (struct l2tp_session *)(sk->sk_user_data); + rcu_read_lock_bh(); + session = rcu_dereference_bh(__sk_user_data((sk))); if (session == NULL) { - sock_put(sk); - goto out; + rcu_read_unlock_bh(); + return NULL; } + spin_lock_bh(&session->lock); + if (session->closing) {
+ spin_unlock_bh(&session->lock); + rcu_read_unlock_bh(); + return NULL; + } + l2tp_session_inc_refcount(session); + spin_unlock_bh(&session->lock); + rcu_read_unlock_bh(); + BUG_ON(session->magic != L2TP_SESSION_MAGIC); -out: return session; } @@ -243,8 +252,8 @@ static void pppol2tp_recv(struct l2tp_session *session, struct sk_buff *skb, int /* If the socket is bound, send it in to PPP's input queue. Otherwise * queue it on the session socket. */ - rcu_read_lock(); - sk = rcu_dereference(ps->sk); + rcu_read_lock_bh(); + sk = rcu_dereference_bh(ps->sk); if (sk == NULL) goto no_sock; @@ -267,12 +276,12 @@ static void pppol2tp_recv(struct l2tp_session *session, struct sk_buff *skb, int kfree_skb(skb); } } - rcu_read_unlock(); + rcu_read_unlock_bh(); return; no_sock: - rcu_read_unlock(); + rcu_read_unlock_bh(); l2tp_info(session, L2TP_MSG_DATA, "%s: no socket\n", session->name); kfree_skb(skb); } @@ -341,12 +350,12 @@ static int pppol2tp_sendmsg(struct socket *sock, struct msghdr *m, l2tp_xmit_skb(session, skb, session->hdr_len); local_bh_enable(); - sock_put(sk); + l2tp_session_dec_refcount(session); return total_len; error_put_sess: - sock_put(sk); + l2tp_session_dec_refcount(session); error: return error; } @@ -400,12 +409,12 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb) l2tp_xmit_skb(session, skb, session->hdr_len); local_bh_enable(); - sock_put(sk); + l2tp_session_dec_refcount(session); return 1; abort_put_sess: - sock_put(sk); + l2tp_session_dec_refcount(session); abort: /* Free the original skb */ kfree_skb(skb); @@ -416,18 +425,73 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb) * Session (and tunnel control) socket create/destroy. *****************************************************************************/ -/* Called by l2tp_core when a session socket is being closed.
+/* called with ps->sk_lock */ +static void pppol2tp_attach(struct l2tp_session *session, struct sock *sk) +{ + struct pppol2tp_session *ps = l2tp_session_priv(session); + + write_lock_bh(&sk->sk_callback_lock); + rcu_assign_sk_user_data(sk, session); + write_unlock_bh(&sk->sk_callback_lock); + rcu_assign_pointer(ps->sk, sk); +} + +/* called with ps->sk_lock */ +static void pppol2tp_detach(struct l2tp_session *session, struct sock *sk) +{ + struct pppol2tp_session *ps = l2tp_session_priv(session); + + rcu_assign_pointer(ps->sk, NULL); + write_lock_bh(&sk->sk_callback_lock); + rcu_assign_sk_user_data(sk, NULL); + write_unlock_bh(&sk->sk_callback_lock); +} + +static void pppol2tp_put_sk(struct rcu_head *head) +{ + struct pppol2tp_session *ps = container_of(head, typeof(*ps), rcu); + struct l2tp_session *session = container_of((void *)ps, typeof(*session), priv); + + BUG_ON(session->magic != L2TP_SESSION_MAGIC); + if (ps->__sk) + sock_put(ps->__sk); + kfree(session); +} + +/* Called by l2tp_core when the last session reference is dropped. + * Callers such as pppol2tp_xmit() may run in atomic context, so this + * must not sleep (no synchronize_rcu() or inet_shutdown() here). The + * session, and the socket ref held in ps->__sk if any, are released by + * pppol2tp_put_sk() after an RCU grace period. + */ +static void pppol2tp_session_free(struct l2tp_session *session) +{ + struct pppol2tp_session *ps = l2tp_session_priv(session); + + BUG_ON(session->magic != L2TP_SESSION_MAGIC); + call_rcu(&ps->rcu, pppol2tp_put_sk); +} + +/* Called by l2tp_core when a session is being closed. Shuts down and + * detaches the PPPoX socket; the ref taken by pppol2tp_connect() is + * parked in ps->__sk and dropped by pppol2tp_put_sk().
*/ static void pppol2tp_session_close(struct l2tp_session *session) { struct sock *sk; + struct pppol2tp_session *ps = l2tp_session_priv(session); BUG_ON(session->magic != L2TP_SESSION_MAGIC); sk = pppol2tp_session_get_sock(session); if (sk) { if (sk->sk_socket) inet_shutdown(sk->sk_socket, SEND_SHUTDOWN); + mutex_lock(&ps->sk_lock); + ps->__sk = rcu_dereference_protected(ps->sk, + lockdep_is_held(&ps->sk_lock)); + pppol2tp_detach(session, sk); + mutex_unlock(&ps->sk_lock); sock_put(sk); } } @@ -437,24 +501,8 @@ static void pppol2tp_session_close(struct l2tp_session *session) */ static void pppol2tp_session_destruct(struct sock *sk) { - struct l2tp_session *session = sk->sk_user_data; - skb_queue_purge(&sk->sk_receive_queue); skb_queue_purge(&sk->sk_write_queue); - - if (session) { - sk->sk_user_data = NULL; - BUG_ON(session->magic != L2TP_SESSION_MAGIC); - l2tp_session_dec_refcount(session); - } -} - -static void pppol2tp_put_sk(struct rcu_head *head) -{ - struct pppol2tp_session *ps; - - ps = container_of(head, typeof(*ps), rcu); - sock_put(ps->__sk); } /* Called when the PPPoX socket (session) is closed. @@ -479,28 +527,14 @@ static int pppol2tp_release(struct socket *sock) sk->sk_state = PPPOX_DEAD; sock_orphan(sk); sock->sk = NULL; + release_sock(sk); - session = pppol2tp_sock_to_session(sk); - - if (session != NULL) { - struct pppol2tp_session *ps; - + rcu_read_lock_bh(); + session = rcu_dereference_bh(__sk_user_data((sk))); + if (session) { l2tp_session_delete(session); - - ps = l2tp_session_priv(session); - mutex_lock(&ps->sk_lock); - ps->__sk = rcu_dereference_protected(ps->sk, - lockdep_is_held(&ps->sk_lock)); - RCU_INIT_POINTER(ps->sk, NULL); - mutex_unlock(&ps->sk_lock); - call_rcu(&ps->rcu, pppol2tp_put_sk); - - /* Rely on the sock_put() call at the end of the function for - * dropping the reference held by pppol2tp_sock_to_session(). - * The last reference will be dropped by pppol2tp_put_sk().
- */ - } - release_sock(sk); + rcu_read_unlock_bh(); /* This will delete the session context via * pppol2tp_session_destruct() if the socket's refcnt drops to @@ -584,6 +618,7 @@ static void pppol2tp_session_init(struct l2tp_session *session) session->recv_skb = pppol2tp_recv; session->session_close = pppol2tp_session_close; + session->session_free = pppol2tp_session_free; #if IS_ENABLED(CONFIG_L2TP_DEBUGFS) session->show = pppol2tp_show; #endif @@ -605,25 +640,142 @@ static void pppol2tp_session_init(struct l2tp_session *session) } } -/* connect() handler. Attach a PPPoX socket to a tunnel UDP socket +/* Look up tunnel_id, creating the tunnel if it is not found and + * can_create is set. On success *tunnelp holds a ref for the caller. + */ +static int pppol2tp_tunnel_prep(struct net *net, int fd, int ver, + u32 tunnel_id, u32 peer_tunnel_id, + bool can_create, struct l2tp_tunnel **tunnelp) +{ + struct l2tp_tunnel *tunnel; + int error; + + tunnel = l2tp_tunnel_get(net, tunnel_id); + if (!tunnel && can_create) { + struct l2tp_tunnel_cfg tcfg = { + .encap = L2TP_ENCAPTYPE_UDP, + .debug = 0, + }; + error = l2tp_tunnel_create(net, fd, ver, tunnel_id, + peer_tunnel_id, &tcfg, &tunnel); + if (error < 0) + return error; + /* take a ref for the caller, as l2tp_tunnel_get() does */ + l2tp_tunnel_inc_refcount(tunnel); + } + + /* Error if we can't find the tunnel */ + if (tunnel == NULL) + return -ENOENT; + + if (tunnel->recv_payload_hook == NULL) + tunnel->recv_payload_hook = pppol2tp_recv_payload_hook; + + if (tunnel->peer_tunnel_id == 0) + tunnel->peer_tunnel_id = peer_tunnel_id; + + *tunnelp = tunnel; + return 0; +} + +/* Prepare a session in a tunnel. If the session doesn't already + * exist, create it and add it to the tunnel's session list. Return + * with a ref on the session instance and its sk_lock held.
+ */ +static int pppol2tp_session_prep(struct sock *sk, struct l2tp_tunnel *tunnel, u32 session_id, u32 peer_session_id, struct l2tp_session **sessionp) +{ + struct l2tp_session *session; + struct pppol2tp_session *ps; + int error; + struct l2tp_session_cfg cfg = {}; + + session = l2tp_session_get(sock_net(sk), tunnel, session_id); + if (session) { + ps = l2tp_session_priv(session); + + /* Using a pre-existing session is fine as long as it hasn't + * been connected yet. + */ + mutex_lock(&ps->sk_lock); + if (rcu_dereference_protected(ps->sk, + lockdep_is_held(&ps->sk_lock))) { + mutex_unlock(&ps->sk_lock); + l2tp_session_dec_refcount(session); + return -EEXIST; + } + } else { + /* Default MTU must allow space for UDP/L2TP/PPP headers */ + cfg.mtu = 1500 - PPPOL2TP_HEADER_OVERHEAD; + cfg.mru = cfg.mtu; + + session = l2tp_session_create(sizeof(struct pppol2tp_session), + tunnel, session_id, + peer_session_id, &cfg); + if (IS_ERR(session)) { + error = PTR_ERR(session); + return error; + } + + pppol2tp_session_init(session); + ps = l2tp_session_priv(session); + + mutex_lock(&ps->sk_lock); + error = l2tp_session_register(session, tunnel); + if (error < 0) { + mutex_unlock(&ps->sk_lock); + kfree(session); + return error; + } + l2tp_session_inc_refcount(session); + } + + *sessionp = session; + return 0; +} + +static int pppol2tp_setup_ppp(struct l2tp_session *session, struct sock *sk) +{ + struct pppox_sock *po = pppox_sk(sk); + + /* The only header we need to worry about is the L2TP + * header. This size is different depending on whether + * sequence numbers are enabled for the data channel. + */ + po->chan.hdrlen = PPPOL2TP_L2TP_HDR_SIZE_NOSEQ; + + po->chan.private = sk; + po->chan.ops = &pppol2tp_chan_ops; + po->chan.mtu = session->mtu; + + return ppp_register_net_channel(sock_net(sk), &po->chan); +} + +/* connect() handler. Attach a PPPoX socket to a tunnel socket.
+ * The PPPoX socket is associated with an l2tp_session and the tunnel + * socket is associated with an l2tp_tunnel. The l2tp_tunnel and + * l2tp_session are usually created by netlink before the PPPoX socket + * is connected. However, for L2TPv2 we support a legacy mode where + * netlink is not used and we create the l2tp_tunnel and l2tp_session + * when the PPPoX sockets are connected. In legacy mode, a per-tunnel + * PPPoX socket is used as a control socket for the tunnel and is + * identified by session_id 0. An l2tp_session is created to manage + * the control socket and an l2tp_tunnel is created for the tunnel if + * it doesn't already exist. */ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr, int sockaddr_len, int flags) { struct sock *sk = sock->sk; struct sockaddr_pppol2tp *sp = (struct sockaddr_pppol2tp *) uservaddr; - struct pppox_sock *po = pppox_sk(sk); struct l2tp_session *session = NULL; - struct l2tp_tunnel *tunnel; + struct l2tp_tunnel *tunnel = NULL; struct pppol2tp_session *ps; - struct l2tp_session_cfg cfg = { 0, }; int error = 0; u32 tunnel_id, peer_tunnel_id; u32 session_id, peer_session_id; - bool drop_refcnt = false; - bool drop_tunnel = false; int ver = 2; int fd; + bool is_ctrl_skt; lock_sock(sk); @@ -685,135 +837,54 @@ static int pppol2tp_connect(struct socket *sock, struct sockaddr *uservaddr, goto end; /* bad socket address */ } - /* Don't bind if tunnel_id is 0 */ error = -EINVAL; - if (tunnel_id == 0) + if (tunnel_id == 0 || peer_tunnel_id == 0) goto end; - tunnel = l2tp_tunnel_get(sock_net(sk), tunnel_id); - if (tunnel) - drop_tunnel = true; - - /* Special case: create tunnel context if session_id and - * peer_session_id is 0. Otherwise look up tunnel using supplied - * tunnel id. + /* The socket is a control socket if session_id is 0. There is + * one control socket per tunnel. Control sockets get no PPP channel.
*/ - if ((session_id == 0) && (peer_session_id == 0)) { - if (tunnel == NULL) { - struct l2tp_tunnel_cfg tcfg = { - .encap = L2TP_ENCAPTYPE_UDP, - .debug = 0, - }; - error = l2tp_tunnel_create(sock_net(sk), fd, ver, tunnel_id, peer_tunnel_id, &tcfg, &tunnel); - if (error < 0) - goto end; - } - } else { - /* Error if we can't find the tunnel */ - error = -ENOENT; - if (tunnel == NULL) - goto end; - - /* Error if socket is not prepped */ - if (tunnel->sock == NULL) - goto end; - } - - if (tunnel->recv_payload_hook == NULL) - tunnel->recv_payload_hook = pppol2tp_recv_payload_hook; - - if (tunnel->peer_tunnel_id == 0) - tunnel->peer_tunnel_id = peer_tunnel_id; - - session = l2tp_session_get(sock_net(sk), tunnel, session_id); - if (session) { - drop_refcnt = true; - ps = l2tp_session_priv(session); - - /* Using a pre-existing session is fine as long as it hasn't - * been connected yet. - */ - mutex_lock(&ps->sk_lock); - if (rcu_dereference_protected(ps->sk, - lockdep_is_held(&ps->sk_lock))) { - mutex_unlock(&ps->sk_lock); - error = -EEXIST; - goto end; - } - } else { - /* Default MTU must allow space for UDP/L2TP/PPP headers */ - cfg.mtu = 1500 - PPPOL2TP_HEADER_OVERHEAD; - cfg.mru = cfg.mtu; + is_ctrl_skt = (session_id == 0 && peer_session_id == 0); - session = l2tp_session_create(sizeof(struct pppol2tp_session), - tunnel, session_id, - peer_session_id, &cfg); - if (IS_ERR(session)) { - error = PTR_ERR(session); - goto end; - } + /* prep and possibly create the l2tp tunnel instance */ + error = pppol2tp_tunnel_prep(sock_net(sk), fd, ver, tunnel_id, + peer_tunnel_id, is_ctrl_skt, &tunnel); + if (error) + goto end; - pppol2tp_session_init(session); - ps = l2tp_session_priv(session); - l2tp_session_inc_refcount(session); + /* prep and possibly create the l2tp session instance */ + error = pppol2tp_session_prep(sk, tunnel, session_id, + peer_session_id, &session); + if (error) + goto end; - mutex_lock(&ps->sk_lock); - error = l2tp_session_register(session, tunnel); - if
(error < 0) { + /* setup ppp unless it's a control socket */ + ps = l2tp_session_priv(session); + if (!is_ctrl_skt) { + error = pppol2tp_setup_ppp(session, sk); + if (error) { mutex_unlock(&ps->sk_lock); - kfree(session); goto end; } - drop_refcnt = true; } - /* Special case: if source & dest session_id == 0x0000, this - * socket is being created to manage the tunnel. Just set up - * the internal context for use by ioctl() and sockopt() - * handlers. + /* The session has now been added to the tunnel. Hold the + * socket to prevent it going away until the session is + * destroyed and attach it to the session such that we can get + * the session instance from the socket and vice versa. */ - if ((session->session_id == 0) && - (session->peer_session_id == 0)) { - error = 0; - goto out_no_ppp; - } - - /* The only header we need to worry about is the L2TP - * header. This size is different depending on whether - * sequence numbers are enabled for the data channel. - */ - po->chan.hdrlen = PPPOL2TP_L2TP_HDR_SIZE_NOSEQ; - - po->chan.private = sk; - po->chan.ops = &pppol2tp_chan_ops; - po->chan.mtu = session->mtu; - - error = ppp_register_net_channel(sock_net(sk), &po->chan); - if (error) { - mutex_unlock(&ps->sk_lock); - goto end; - } - -out_no_ppp: - /* This is how we get the session context from the socket. */ - sk->sk_user_data = session; - rcu_assign_pointer(ps->sk, sk); + sock_hold(sk); + pppol2tp_attach(session, sk); mutex_unlock(&ps->sk_lock); - /* Keep the reference we've grabbed on the session: sk doesn't expect - * the session to disappear. pppol2tp_session_destruct() is responsible - * for dropping it.
- */ - drop_refcnt = false; - sk->sk_state = PPPOX_CONNECTED; l2tp_info(session, L2TP_MSG_CONTROL, "%s: created\n", session->name); end: - if (drop_refcnt) + if (session) l2tp_session_dec_refcount(session); - if (drop_tunnel) + if (tunnel) l2tp_tunnel_dec_refcount(tunnel); release_sock(sk); @@ -829,6 +900,7 @@ static int pppol2tp_session_create(struct net *net, struct l2tp_tunnel *tunnel, { int error; struct l2tp_session *session; + struct pppol2tp_session *ps; /* Error if tunnel socket is not prepped */ if (!tunnel->sock) { @@ -852,10 +924,14 @@ static int pppol2tp_session_create(struct net *net, struct l2tp_tunnel *tunnel, } pppol2tp_session_init(session); - + ps = l2tp_session_priv(session); + mutex_lock(&ps->sk_lock); error = l2tp_session_register(session, tunnel); - if (error < 0) + if (error < 0) { + mutex_unlock(&ps->sk_lock); goto err_sess; + } + mutex_unlock(&ps->sk_lock); return 0; @@ -972,7 +1048,7 @@ static int pppol2tp_getname(struct socket *sock, struct sockaddr *uaddr, *usockaddr_len = len; error = 0; - sock_put(sk); + l2tp_session_dec_refcount(session); end: return error; } @@ -1243,7 +1319,7 @@ static int pppol2tp_ioctl(struct socket *sock, unsigned int cmd, err = pppol2tp_session_ioctl(session, cmd, arg); end_put_sess: - sock_put(sk); + l2tp_session_dec_refcount(session); end: return err; } @@ -1394,7 +1470,7 @@ static int pppol2tp_setsockopt(struct socket *sock, int level, int optname, err = pppol2tp_session_setsockopt(sk, session, optname, val); } - sock_put(sk); + l2tp_session_dec_refcount(session); end: return err; } @@ -1526,7 +1602,7 @@ static int pppol2tp_getsockopt(struct socket *sock, int level, int optname, err = 0; end_put_sess: - sock_put(sk); + l2tp_session_dec_refcount(session); end: return err; }