[<prev] [next>] [<thread-prev] [thread-next>] [day] [month] [year] [list]
Message-ID: <20210318120930.5723-1-alobakin@pm.me>
Date: Thu, 18 Mar 2021 12:09:41 +0000
From: Alexander Lobakin <alobakin@...me>
To: Cong Wang <xiyou.wangcong@...il.com>
Cc: Alexander Lobakin <alobakin@...me>, bpf@...r.kernel.org,
duanxiongchun@...edance.com, wangdongdong.6@...edance.com,
jiang.wang@...edance.com, Cong Wang <cong.wang@...edance.com>,
John Fastabend <john.fastabend@...il.com>,
Daniel Borkmann <daniel@...earbox.net>,
Jakub Sitnicki <jakub@...udflare.com>,
Lorenz Bauer <lmb@...udflare.com>, netdev@...r.kernel.org
Subject: Re: [Patch bpf-next v5 06/11] sock: introduce sk->sk_prot->psock_update_sk_prot()
From: Cong Wang <xiyou.wangcong@...il.com>
Date: Tue, 16 Mar 2021 19:22:14 -0700
Hi,
> From: Cong Wang <cong.wang@...edance.com>
>
> Currently sockmap calls into each protocol to update the struct
> proto and replace it. This certainly won't work when the protocol
> is implemented as a module, for example, AF_UNIX.
>
> Introduce a new ops sk->sk_prot->psock_update_sk_prot(), so each
> protocol can implement its own way to replace the struct proto.
> This also helps get rid of symbol dependencies on CONFIG_INET.
>
> Cc: John Fastabend <john.fastabend@...il.com>
> Cc: Daniel Borkmann <daniel@...earbox.net>
> Cc: Jakub Sitnicki <jakub@...udflare.com>
> Cc: Lorenz Bauer <lmb@...udflare.com>
> Signed-off-by: Cong Wang <cong.wang@...edance.com>
> ---
> include/linux/skmsg.h | 18 +++---------------
> include/net/sock.h | 3 +++
> include/net/tcp.h | 1 +
> include/net/udp.h | 1 +
> net/core/skmsg.c | 5 -----
> net/core/sock_map.c | 24 ++++--------------------
> net/ipv4/tcp_bpf.c | 24 +++++++++++++++++++++---
> net/ipv4/tcp_ipv4.c | 3 +++
> net/ipv4/udp.c | 3 +++
> net/ipv4/udp_bpf.c | 15 +++++++++++++--
> net/ipv6/tcp_ipv6.c | 3 +++
> net/ipv6/udp.c | 3 +++
> 12 files changed, 58 insertions(+), 45 deletions(-)
>
> diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
> index 77e5d890ec4b..eb2757c0295d 100644
> --- a/include/linux/skmsg.h
> +++ b/include/linux/skmsg.h
> @@ -99,6 +99,7 @@ struct sk_psock {
> void (*saved_close)(struct sock *sk, long timeout);
> void (*saved_write_space)(struct sock *sk);
> void (*saved_data_ready)(struct sock *sk);
> + int (*psock_update_sk_prot)(struct sock *sk, bool restore);
> struct proto *sk_proto;
> struct sk_psock_work_state work_state;
> struct work_struct work;
> @@ -397,25 +398,12 @@ static inline void sk_psock_cork_free(struct sk_psock *psock)
> }
> }
>
> -static inline void sk_psock_update_proto(struct sock *sk,
> - struct sk_psock *psock,
> - struct proto *ops)
> -{
> - /* Pairs with lockless read in sk_clone_lock() */
> - WRITE_ONCE(sk->sk_prot, ops);
> -}
> -
> static inline void sk_psock_restore_proto(struct sock *sk,
> struct sk_psock *psock)
> {
> sk->sk_prot->unhash = psock->saved_unhash;
> - if (inet_csk_has_ulp(sk)) {
> - tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
> - } else {
> - sk->sk_write_space = psock->saved_write_space;
> - /* Pairs with lockless read in sk_clone_lock() */
> - WRITE_ONCE(sk->sk_prot, psock->sk_proto);
> - }
> + if (psock->psock_update_sk_prot)
> + psock->psock_update_sk_prot(sk, true);
> }
>
> static inline void sk_psock_set_state(struct sk_psock *psock,
> diff --git a/include/net/sock.h b/include/net/sock.h
> index 636810ddcd9b..eda64fbd5e3d 100644
> --- a/include/net/sock.h
> +++ b/include/net/sock.h
> @@ -1184,6 +1184,9 @@ struct proto {
> void (*unhash)(struct sock *sk);
> void (*rehash)(struct sock *sk);
> int (*get_port)(struct sock *sk, unsigned short snum);
> +#ifdef CONFIG_BPF_SYSCALL
> + int (*psock_update_sk_prot)(struct sock *sk, bool restore);
> +#endif
>
> /* Keeping track of sockets in use */
> #ifdef CONFIG_PROC_FS
> diff --git a/include/net/tcp.h b/include/net/tcp.h
> index 075de26f449d..2efa4e5ea23d 100644
> --- a/include/net/tcp.h
> +++ b/include/net/tcp.h
> @@ -2203,6 +2203,7 @@ struct sk_psock;
>
> #ifdef CONFIG_BPF_SYSCALL
> struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
> +int tcp_bpf_update_proto(struct sock *sk, bool restore);
> void tcp_bpf_clone(const struct sock *sk, struct sock *newsk);
> #endif /* CONFIG_BPF_SYSCALL */
>
> diff --git a/include/net/udp.h b/include/net/udp.h
> index d4d064c59232..df7cc1edc200 100644
> --- a/include/net/udp.h
> +++ b/include/net/udp.h
> @@ -518,6 +518,7 @@ static inline struct sk_buff *udp_rcv_segment(struct sock *sk,
> #ifdef CONFIG_BPF_SYSCALL
> struct sk_psock;
> struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock);
> +int udp_bpf_update_proto(struct sock *sk, bool restore);
> #endif
>
> #endif /* _UDP_H */
> diff --git a/net/core/skmsg.c b/net/core/skmsg.c
> index 5cba52862334..e93683a287a0 100644
> --- a/net/core/skmsg.c
> +++ b/net/core/skmsg.c
> @@ -559,11 +559,6 @@ struct sk_psock *sk_psock_init(struct sock *sk, int node)
>
> write_lock_bh(&sk->sk_callback_lock);
>
> - if (inet_csk_has_ulp(sk)) {
> - psock = ERR_PTR(-EINVAL);
> - goto out;
> - }
> -
> if (sk->sk_user_data) {
> psock = ERR_PTR(-EBUSY);
> goto out;
> diff --git a/net/core/sock_map.c b/net/core/sock_map.c
> index 33f8c854db4f..596cbac24091 100644
> --- a/net/core/sock_map.c
> +++ b/net/core/sock_map.c
> @@ -184,26 +184,10 @@ static void sock_map_unref(struct sock *sk, void *link_raw)
>
> static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
> {
> - struct proto *prot;
> -
> - switch (sk->sk_type) {
> - case SOCK_STREAM:
> - prot = tcp_bpf_get_proto(sk, psock);
> - break;
> -
> - case SOCK_DGRAM:
> - prot = udp_bpf_get_proto(sk, psock);
> - break;
> -
> - default:
> + if (!sk->sk_prot->psock_update_sk_prot)
> return -EINVAL;
> - }
> -
> - if (IS_ERR(prot))
> - return PTR_ERR(prot);
> -
> - sk_psock_update_proto(sk, psock, prot);
> - return 0;
> + psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
> + return sk->sk_prot->psock_update_sk_prot(sk, false);
Given that both {tcp,udp}_bpf_update_proto() are global and, for now,
are the only two implemented callbacks, wouldn't it be worth
straightening the calls here? Like:
return INDIRECT_CALL_2(sk->sk_prot->psock_update_sk_prot,
tcp_bpf_update_proto,
udp_bpf_update_proto,
sk, false);
(the same in sk_psock_restore_proto() then)
Or is this code path not performance-critical?
> }
>
> static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
> @@ -570,7 +554,7 @@ static bool sock_map_redirect_allowed(const struct sock *sk)
>
> static bool sock_map_sk_is_suitable(const struct sock *sk)
> {
> - return sk_is_tcp(sk) || sk_is_udp(sk);
> + return !!sk->sk_prot->psock_update_sk_prot;
> }
>
> static bool sock_map_sk_state_allowed(const struct sock *sk)
> diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
> index ae980716d896..ac8cfbaeacd2 100644
> --- a/net/ipv4/tcp_bpf.c
> +++ b/net/ipv4/tcp_bpf.c
> @@ -595,20 +595,38 @@ static int tcp_bpf_assert_proto_ops(struct proto *ops)
> ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
> }
>
> -struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
> +int tcp_bpf_update_proto(struct sock *sk, bool restore)
> {
> + struct sk_psock *psock = sk_psock(sk);
> int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
> int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;
>
> + if (restore) {
> + if (inet_csk_has_ulp(sk)) {
> + tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
> + } else {
> + sk->sk_write_space = psock->saved_write_space;
> + /* Pairs with lockless read in sk_clone_lock() */
> + WRITE_ONCE(sk->sk_prot, psock->sk_proto);
> + }
> + return 0;
> + }
> +
> + if (inet_csk_has_ulp(sk))
> + return -EINVAL;
> +
> if (sk->sk_family == AF_INET6) {
> if (tcp_bpf_assert_proto_ops(psock->sk_proto))
> - return ERR_PTR(-EINVAL);
> + return -EINVAL;
>
> tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
> }
>
> - return &tcp_bpf_prots[family][config];
> + /* Pairs with lockless read in sk_clone_lock() */
> + WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
> + return 0;
> }
> +EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);
>
> /* If a child got cloned from a listening socket that had tcp_bpf
> * protocol callbacks installed, we need to restore the callbacks to
> diff --git a/net/ipv4/tcp_ipv4.c b/net/ipv4/tcp_ipv4.c
> index daad4f99db32..dfc6d1c0e710 100644
> --- a/net/ipv4/tcp_ipv4.c
> +++ b/net/ipv4/tcp_ipv4.c
> @@ -2806,6 +2806,9 @@ struct proto tcp_prot = {
> .hash = inet_hash,
> .unhash = inet_unhash,
> .get_port = inet_csk_get_port,
> +#ifdef CONFIG_BPF_SYSCALL
> + .psock_update_sk_prot = tcp_bpf_update_proto,
> +#endif
> .enter_memory_pressure = tcp_enter_memory_pressure,
> .leave_memory_pressure = tcp_leave_memory_pressure,
> .stream_memory_free = tcp_stream_memory_free,
> diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
> index 4a0478b17243..38952aaee3a1 100644
> --- a/net/ipv4/udp.c
> +++ b/net/ipv4/udp.c
> @@ -2849,6 +2849,9 @@ struct proto udp_prot = {
> .unhash = udp_lib_unhash,
> .rehash = udp_v4_rehash,
> .get_port = udp_v4_get_port,
> +#ifdef CONFIG_BPF_SYSCALL
> + .psock_update_sk_prot = udp_bpf_update_proto,
> +#endif
> .memory_allocated = &udp_memory_allocated,
> .sysctl_mem = sysctl_udp_mem,
> .sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min),
> diff --git a/net/ipv4/udp_bpf.c b/net/ipv4/udp_bpf.c
> index 7a94791efc1a..6001f93cd3a0 100644
> --- a/net/ipv4/udp_bpf.c
> +++ b/net/ipv4/udp_bpf.c
> @@ -41,12 +41,23 @@ static int __init udp_bpf_v4_build_proto(void)
> }
> core_initcall(udp_bpf_v4_build_proto);
>
> -struct proto *udp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
> +int udp_bpf_update_proto(struct sock *sk, bool restore)
> {
> int family = sk->sk_family == AF_INET ? UDP_BPF_IPV4 : UDP_BPF_IPV6;
> + struct sk_psock *psock = sk_psock(sk);
> +
> + if (restore) {
> + sk->sk_write_space = psock->saved_write_space;
> + /* Pairs with lockless read in sk_clone_lock() */
> + WRITE_ONCE(sk->sk_prot, psock->sk_proto);
> + return 0;
> + }
>
> if (sk->sk_family == AF_INET6)
> udp_bpf_check_v6_needs_rebuild(psock->sk_proto);
>
> - return &udp_bpf_prots[family];
> + /* Pairs with lockless read in sk_clone_lock() */
> + WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[family]);
> + return 0;
> }
> +EXPORT_SYMBOL_GPL(udp_bpf_update_proto);
> diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c
> index bd44ded7e50c..4fdc58a9e19e 100644
> --- a/net/ipv6/tcp_ipv6.c
> +++ b/net/ipv6/tcp_ipv6.c
> @@ -2134,6 +2134,9 @@ struct proto tcpv6_prot = {
> .hash = inet6_hash,
> .unhash = inet_unhash,
> .get_port = inet_csk_get_port,
> +#ifdef CONFIG_BPF_SYSCALL
> + .psock_update_sk_prot = tcp_bpf_update_proto,
> +#endif
> .enter_memory_pressure = tcp_enter_memory_pressure,
> .leave_memory_pressure = tcp_leave_memory_pressure,
> .stream_memory_free = tcp_stream_memory_free,
> diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
> index d25e5a9252fd..ef2c75bb4771 100644
> --- a/net/ipv6/udp.c
> +++ b/net/ipv6/udp.c
> @@ -1713,6 +1713,9 @@ struct proto udpv6_prot = {
> .unhash = udp_lib_unhash,
> .rehash = udp_v6_rehash,
> .get_port = udp_v6_get_port,
> +#ifdef CONFIG_BPF_SYSCALL
> + .psock_update_sk_prot = udp_bpf_update_proto,
> +#endif
> .memory_allocated = &udp_memory_allocated,
> .sysctl_mem = sysctl_udp_mem,
> .sysctl_wmem_offset = offsetof(struct net, ipv4.sysctl_udp_wmem_min),
> --
> 2.25.1
Thanks,
Al
Powered by blists - more mailing lists