Message-ID: <20180117220434.uee6seg665loc535@kafai-mbp.dhcp.thefacebook.com>
Date: Wed, 17 Jan 2018 14:04:34 -0800
From: Martin KaFai Lau <kafai@...com>
To: John Fastabend <john.fastabend@...il.com>
CC: <borkmann@...earbox.net>, <ast@...nel.org>,
<netdev@...r.kernel.org>
Subject: Re: [bpf-next PATCH 5/7] bpf: create tcp_bpf_ulp allowing BPF to
monitor socket TX/RX data
On Fri, Jan 12, 2018 at 10:11:11AM -0800, John Fastabend wrote:
> This implements a BPF ULP layer to allow policy enforcement and
> monitoring at the socket layer. In order to support this, a new
> program type BPF_PROG_TYPE_SK_MSG is used to run the policy at
> the sendmsg/sendpage hook. To attach the policy to sockets, a
> sockmap is used with a new program attach type BPF_SK_MSG_VERDICT.
>
> Similar to previous sockmap usage, when a sock is added to a sockmap
> via a map update and the map has a BPF_SK_MSG_VERDICT program
> attached, the BPF ULP layer is created on the socket and the attached
> BPF_PROG_TYPE_SK_MSG program is run for every msg in the sendmsg case
> and for each page/offset in the sendpage case.
>
> BPF_PROG_TYPE_SK_MSG Semantics/API:
>
> BPF_PROG_TYPE_SK_MSG supports only two return codes, SK_PASS and
> SK_DROP. Returning SK_DROP frees the copied data in the sendmsg case
> and leaves the data untouched in the sendpage case. Both cases return
> -EACCES to the user. Returning SK_PASS allows the msg to be sent.
>
> In the sendmsg case data is copied into kernel space buffers before
> running the BPF program. In the sendpage case data is never copied.
> The implication is that users may change the data after the BPF
> program runs in the sendpage case. (A flag to force a copy will be
> added shortly for cases where the copy must always be performed.)
>
> The verdict from the BPF_PROG_TYPE_SK_MSG program applies to the
> entire msg in the sendmsg() case and the entire page/offset in the
> sendpage case. This avoids ambiguity on how to handle mixed return
> codes in the sendmsg case. The readable/writable data provided to the
> program in the sendmsg case may not be the entire message; in fact,
> for large sends this is likely the case. The data range that can be
> read is part of the sk_msg_md structure. This is because, similar to
> the tc cls_bpf case, the data is stored in a scatter gather list.
> Future work will address this shortcoming to allow users to pull in
> more data if needed (similar to TC BPF).
>
> The helper msg_redirect_map() can be used to select the socket to
> send the data on. This is used similarly to existing redirect use
> cases and allows policy to redirect msgs.
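For illustration, a minimal SK_MSG verdict program using the API above
could look roughly like the sketch below (not from this patch; the map
name, section name and helper declaration are illustrative and assume a
uapi header that already carries the sk_msg_md / BPF_FUNC_msg_redirect_map
additions from this series). It shows the data/data_end bounds check and
the redirect helper supplying the verdict:

#include <linux/bpf.h>

#define SEC(NAME) __attribute__((section(NAME), used))

/* Helper declared selftest-style; BPF_FUNC_msg_redirect_map is the
 * enum value added by the uapi change in this patch.
 */
static int (*bpf_msg_redirect_map)(void *msg, void *map,
				   __u32 key, __u64 flags) =
	(void *) BPF_FUNC_msg_redirect_map;

/* Map definition layout as consumed by the selftest/sample loaders. */
struct bpf_map_def {
	unsigned int type;
	unsigned int key_size;
	unsigned int value_size;
	unsigned int max_entries;
};

/* Hypothetical sockmap holding the sockets we may redirect to. */
struct bpf_map_def SEC("maps") sock_map_tx = {
	.type		= BPF_MAP_TYPE_SOCKMAP,
	.key_size	= sizeof(int),
	.value_size	= sizeof(int),
	.max_entries	= 20,
};

SEC("sk_msg")
int msg_verdict_prog(struct sk_msg_md *msg)
{
	void *data_end = (void *)(long)msg->data_end;
	void *data = (void *)(long)msg->data;

	/* Only the first scatterlist element is visible here; drop if
	 * we cannot even read one byte.
	 */
	if (data + 1 > data_end)
		return SK_DROP;

	/* Redirect the whole msg to the sock stored at key 0. The
	 * helper returns SK_PASS (or SK_DROP on bad flags), so its
	 * return value can be used directly as the verdict.
	 */
	return bpf_msg_redirect_map(msg, &sock_map_tx, 0, 0);
}

Such a program would be loaded as BPF_PROG_TYPE_SK_MSG and attached via
BPF_SK_MSG_VERDICT as in the pseudo code below.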
>
> Pseudo code simple example:
>
> The basic logic to attach a program to a socket is as follows,
>
> // load the programs
> bpf_prog_load(SOCKMAP_TCP_MSG_PROG, BPF_PROG_TYPE_SK_MSG,
> &obj, &msg_prog);
>
> // lookup the sockmap
> bpf_map_msg = bpf_object__find_map_by_name(obj, "my_sock_map");
>
> // get fd for sockmap
> map_fd_msg = bpf_map__fd(bpf_map_msg);
>
> // attach program to sockmap
> bpf_prog_attach(msg_prog, map_fd_msg, BPF_SK_MSG_VERDICT, 0);
>
> Adding sockets to the map is done in the normal way,
>
> // Add a socket 'fd' to sockmap at location 'i'
> bpf_map_update_elem(map_fd_msg, &i, fd, BPF_ANY);
>
> After the above any socket attached to "my_sock_map", in this case
> 'fd', will run the BPF msg verdict program (msg_prog) on every
> sendmsg and sendpage system call.
>
> For a complete example see BPF selftests bpf/sockmap_tcp_msg_*.c and
> test_maps.c
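As a rough sketch, the pseudo code above could be fleshed out along the
lines below; the object file name "sockmap_tcp_msg_prog.o" and the error
handling are illustrative, and the exact libbpf include paths and call
signatures depend on the libbpf version in use:

#include <errno.h>
#include <stdio.h>
#include <bpf/bpf.h>
#include <bpf/libbpf.h>

int attach_msg_prog(int sock_fd)
{
	struct bpf_object *obj;
	struct bpf_map *map;
	int prog_fd, map_fd, key = 0, err;

	/* load the SK_MSG program from an object file */
	err = bpf_prog_load("sockmap_tcp_msg_prog.o",
			    BPF_PROG_TYPE_SK_MSG, &obj, &prog_fd);
	if (err) {
		fprintf(stderr, "prog load failed: %d\n", err);
		return err;
	}

	/* lookup the sockmap defined in the object and grab its fd */
	map = bpf_object__find_map_by_name(obj, "my_sock_map");
	if (!map)
		return -ENOENT;
	map_fd = bpf_map__fd(map);

	/* attach the msg verdict program to the sockmap */
	err = bpf_prog_attach(prog_fd, map_fd, BPF_SK_MSG_VERDICT, 0);
	if (err)
		return err;

	/* add the TCP socket at slot 0; its sendmsg/sendpage calls now
	 * run the verdict program
	 */
	return bpf_map_update_elem(map_fd, &key, &sock_fd, BPF_ANY);
}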
>
> Implementation notes:
>
> It seemed simplest, to me at least, to use a refcnt to ensure the
> psock is not lost across the sendmsg copy into the sg, the bpf program
> running on the data in sg_data, and the final pass to the TCP stack.
> Some performance testing may show a better method that avoids the
> refcnt cost, but for now use the simpler method.
>
> Another item that will come after basic support is in place is
> support for the MSG_MORE flag. At the moment we call sendpages even
> if the MSG_MORE flag is set. An enhancement would be to collect the
> pages into a larger scatterlist and pass it down the stack. Notice
> that bpf_tcp_sendmsg() could support this with some additional state
> saved across sendmsg calls. I built the code to support this without
> having to do refactoring work. Other flags TBD include the ZEROCOPY
> flag.
>
> Yet another detail that needs some thought is the size of the
> scatterlist. Currently, we use MAX_SKB_FRAGS simply because this was
> already being used in the TLS case. Future work to improve the kernel
> sk APIs to tune this depending on workload may be useful. This is a
> trade-off between memory usage and B/s performance.
Some minor comments/nits below:
>
> Signed-off-by: John Fastabend <john.fastabend@...il.com>
> ---
> include/linux/bpf.h | 1
> include/linux/bpf_types.h | 1
> include/linux/filter.h | 10 +
> include/net/tcp.h | 2
> include/uapi/linux/bpf.h | 28 +++
> kernel/bpf/sockmap.c | 485 ++++++++++++++++++++++++++++++++++++++++++++-
> kernel/bpf/syscall.c | 14 +
> kernel/bpf/verifier.c | 5
> net/core/filter.c | 106 ++++++++++
> 9 files changed, 638 insertions(+), 14 deletions(-)
>
> diff --git a/include/linux/bpf.h b/include/linux/bpf.h
> index 9e03046..14cdb4d 100644
> --- a/include/linux/bpf.h
> +++ b/include/linux/bpf.h
> @@ -21,6 +21,7 @@
> struct perf_event;
> struct bpf_prog;
> struct bpf_map;
> +struct sock;
>
> /* map is generic key/value storage optionally accesible by eBPF programs */
> struct bpf_map_ops {
> diff --git a/include/linux/bpf_types.h b/include/linux/bpf_types.h
> index 19b8349..5e2e8a4 100644
> --- a/include/linux/bpf_types.h
> +++ b/include/linux/bpf_types.h
> @@ -13,6 +13,7 @@
> BPF_PROG_TYPE(BPF_PROG_TYPE_LWT_XMIT, lwt_xmit)
> BPF_PROG_TYPE(BPF_PROG_TYPE_SOCK_OPS, sock_ops)
> BPF_PROG_TYPE(BPF_PROG_TYPE_SK_SKB, sk_skb)
> +BPF_PROG_TYPE(BPF_PROG_TYPE_SK_MSG, sk_msg)
> #endif
> #ifdef CONFIG_BPF_EVENTS
> BPF_PROG_TYPE(BPF_PROG_TYPE_KPROBE, kprobe)
> diff --git a/include/linux/filter.h b/include/linux/filter.h
> index 425056c..f1e9833 100644
> --- a/include/linux/filter.h
> +++ b/include/linux/filter.h
> @@ -507,6 +507,15 @@ struct xdp_buff {
> struct xdp_rxq_info *rxq;
> };
>
> +struct sk_msg_buff {
> + void *data;
> + void *data_end;
> + struct scatterlist sg_data[MAX_SKB_FRAGS];
> + __u32 key;
> + __u32 flags;
> + struct bpf_map *map;
> +};
> +
> /* Compute the linear packet data range [data, data_end) which
> * will be accessed by various program types (cls_bpf, act_bpf,
> * lwt, ...). Subsystems allowing direct data access must (!)
> @@ -769,6 +778,7 @@ int xdp_do_redirect(struct net_device *dev,
> void bpf_warn_invalid_xdp_action(u32 act);
>
> struct sock *do_sk_redirect_map(struct sk_buff *skb);
> +struct sock *do_msg_redirect_map(struct sk_msg_buff *md);
>
> #ifdef CONFIG_BPF_JIT
> extern int bpf_jit_enable;
> diff --git a/include/net/tcp.h b/include/net/tcp.h
> index a99ceb8..7f56c3c 100644
> --- a/include/net/tcp.h
> +++ b/include/net/tcp.h
> @@ -1984,6 +1984,7 @@ static inline void tcp_listendrop(const struct sock *sk)
>
> enum {
> TCP_ULP_TLS,
> + TCP_ULP_BPF,
> };
>
> struct tcp_ulp_ops {
> @@ -2001,6 +2002,7 @@ struct tcp_ulp_ops {
> int tcp_register_ulp(struct tcp_ulp_ops *type);
> void tcp_unregister_ulp(struct tcp_ulp_ops *type);
> int tcp_set_ulp(struct sock *sk, const char *name);
> +int tcp_set_ulp_id(struct sock *sk, const int ulp);
> void tcp_get_available_ulp(char *buf, size_t len);
> void tcp_cleanup_ulp(struct sock *sk);
>
> diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
> index 405317f..bf649ae 100644
> --- a/include/uapi/linux/bpf.h
> +++ b/include/uapi/linux/bpf.h
> @@ -133,6 +133,7 @@ enum bpf_prog_type {
> BPF_PROG_TYPE_SOCK_OPS,
> BPF_PROG_TYPE_SK_SKB,
> BPF_PROG_TYPE_CGROUP_DEVICE,
> + BPF_PROG_TYPE_SK_MSG,
> };
>
> enum bpf_attach_type {
> @@ -143,6 +144,7 @@ enum bpf_attach_type {
> BPF_SK_SKB_STREAM_PARSER,
> BPF_SK_SKB_STREAM_VERDICT,
> BPF_CGROUP_DEVICE,
> + BPF_SK_MSG_VERDICT,
> __MAX_BPF_ATTACH_TYPE
> };
>
> @@ -687,6 +689,15 @@ enum bpf_attach_type {
> * int bpf_override_return(pt_regs, rc)
> * @pt_regs: pointer to struct pt_regs
> * @rc: the return value to set
> + *
> + * int bpf_msg_redirect_map(map, key, flags)
> + * Redirect msg to a sock in map using key as a lookup key for the
> + * sock in map.
> + * @map: pointer to sockmap
> + * @key: key to lookup sock in map
> + * @flags: reserved for future use
> + * Return: SK_PASS
> + *
> */
> #define __BPF_FUNC_MAPPER(FN) \
> FN(unspec), \
> @@ -747,7 +758,8 @@ enum bpf_attach_type {
> FN(perf_event_read_value), \
> FN(perf_prog_read_value), \
> FN(getsockopt), \
> - FN(override_return),
> + FN(override_return), \
> + FN(msg_redirect_map),
>
> /* integer value in 'imm' field of BPF_CALL instruction selects which helper
> * function eBPF program intends to call
> @@ -909,6 +921,20 @@ enum sk_action {
> SK_PASS,
> };
>
> +/* User return codes for SK_MSG prog type. */
> +enum sk_msg_action {
> + SK_MSG_DROP = 0,
> + SK_MSG_PASS,
> +};
> +
> +/* user accessible metadata for SK_MSG packet hook, new fields must
> + * be added to the end of this structure
> + */
> +struct sk_msg_md {
> + __u32 data;
> + __u32 data_end;
> +};
> +
> #define BPF_TAG_SIZE 8
>
> struct bpf_prog_info {
> diff --git a/kernel/bpf/sockmap.c b/kernel/bpf/sockmap.c
> index 972608f..5793f3a 100644
> --- a/kernel/bpf/sockmap.c
> +++ b/kernel/bpf/sockmap.c
> @@ -38,6 +38,7 @@
> #include <linux/skbuff.h>
> #include <linux/workqueue.h>
> #include <linux/list.h>
> +#include <linux/mm.h>
> #include <net/strparser.h>
> #include <net/tcp.h>
>
> @@ -47,6 +48,7 @@
> struct bpf_stab {
> struct bpf_map map;
> struct sock **sock_map;
> + struct bpf_prog *bpf_tx_msg;
> struct bpf_prog *bpf_parse;
> struct bpf_prog *bpf_verdict;
> };
> @@ -74,6 +76,7 @@ struct smap_psock {
> struct sk_buff *save_skb;
>
> struct strparser strp;
> + struct bpf_prog *bpf_tx_msg;
> struct bpf_prog *bpf_parse;
> struct bpf_prog *bpf_verdict;
> struct list_head maps;
> @@ -90,6 +93,8 @@ struct smap_psock {
> void (*save_state_change)(struct sock *sk);
> };
>
> +static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
> +
> static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
> {
> return rcu_dereference_sk_user_data(sk);
> @@ -99,8 +104,439 @@ enum __sk_action {
> __SK_DROP = 0,
> __SK_PASS,
> __SK_REDIRECT,
> + __SK_NONE,
> };
>
> +static int memcopy_from_iter(struct sock *sk, struct scatterlist *sg,
> + int sg_num, struct iov_iter *from, int bytes)
> +{
> + int i, rc = 0;
> +
> + for (i = 0; i < sg_num; ++i) {
> + int copy = sg[i].length;
> + char *to = sg_virt(&sg[i]);
> +
> + if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
> + rc = copy_from_iter_nocache(to, copy, from);
> + else
> + rc = copy_from_iter(to, copy, from);
> +
> + if (rc != copy) {
> + rc = -EFAULT;
> + goto out;
> + }
> +
> + bytes -= copy;
> + if (!bytes)
> + break;
> + }
> +out:
> + return rc;
> +}
> +
> +static int bpf_tcp_push(struct sock *sk, struct scatterlist *sg,
> + int *sg_end, int flags, bool charge)
> +{
> + int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST;
> + int offset, ret = 0;
> + struct page *p;
> + size_t size;
> +
> + size = sg->length;
> + offset = sg->offset;
> +
> + while (1) {
> + if (sg_is_last(sg))
> + sendpage_flags = flags;
> +
> + tcp_rate_check_app_limited(sk);
> + p = sg_page(sg);
> +retry:
> + ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags);
> + if (ret != size) {
> + if (ret > 0) {
> + offset += ret;
> + size -= ret;
> + goto retry;
> + }
> +
> + if (charge)
> + sk_mem_uncharge(sk,
> + sg->length - size - sg->offset);
> +
> + sg->offset = offset;
> + sg->length = size;
> + return ret;
> + }
> +
> + put_page(p);
> + if (charge)
> + sk_mem_uncharge(sk, sg->length);
> + *sg_end += 1;
> + sg = sg_next(sg);
> + if (!sg)
> + break;
> +
> + offset = sg->offset;
> + size = sg->length;
> + }
> +
> + return 0;
> +}
> +
> +static inline void bpf_compute_data_pointers_sg(struct sk_msg_buff *md)
> +{
> + md->data = sg_virt(md->sg_data);
> + md->data_end = md->data + md->sg_data->length;
> +}
> +
> +static void return_mem_sg(struct sock *sk, struct scatterlist *sg, int end)
> +{
> + int i;
> +
> + for (i = 0; i < end; ++i)
> + sk_mem_uncharge(sk, sg[i].length);
> +}
> +
> +static int free_sg(struct sock *sk, struct scatterlist *sg, int start, int len)
> +{
> + int i, free = 0;
> +
> + for (i = start; i < len; ++i) {
> + free += sg[i].length;
> + sk_mem_uncharge(sk, sg[i].length);
> + put_page(sg_page(&sg[i]));
> + }
> +
> + return free;
> +}
> +
> +static unsigned int smap_do_tx_msg(struct sock *sk,
> + struct smap_psock *psock,
> + struct sk_msg_buff *md)
> +{
> + struct bpf_prog *prog;
> + unsigned int rc, _rc;
> +
> + preempt_disable();
Why is preempt_disable() needed here?
> + rcu_read_lock();
> +
> + /* If the policy was removed mid-send then default to 'accept' */
> + prog = READ_ONCE(psock->bpf_tx_msg);
> + if (unlikely(!prog)) {
> + _rc = SK_PASS;
> + goto verdict;
> + }
> +
> + bpf_compute_data_pointers_sg(md);
> + _rc = (*prog->bpf_func)(md, prog->insnsi);
> +
> +verdict:
> + rcu_read_unlock();
> + preempt_enable();
> +
> + /* Moving return codes from UAPI namespace into internal namespace */
> + rc = ((_rc == SK_PASS) ?
> + (md->map ? __SK_REDIRECT : __SK_PASS) :
> + __SK_DROP);
> +
> + return rc;
> +}
> +
> +static int bpf_tcp_sendmsg_do_redirect(struct scatterlist *sg, int sg_num,
> + struct sk_msg_buff *md, int flags)
> +{
> + int i, sg_curr = 0, err, free;
> + struct smap_psock *psock;
> + struct sock *sk;
> +
> + rcu_read_lock();
> + sk = do_msg_redirect_map(md);
> + if (unlikely(!sk))
> + goto out_rcu;
> +
> + psock = smap_psock_sk(sk);
> + if (unlikely(!psock))
> + goto out_rcu;
> +
> + if (!refcount_inc_not_zero(&psock->refcnt))
> + goto out_rcu;
> +
> + rcu_read_unlock();
> + lock_sock(sk);
> + err = bpf_tcp_push(sk, sg, &sg_curr, flags, false);
> + if (unlikely(err))
> + goto out;
> + release_sock(sk);
> + smap_release_sock(psock, sk);
> + return 0;
> +out_rcu:
> + rcu_read_unlock();
> +out:
> + for (i = sg_curr; i < sg_num; ++i) {
> + free += sg[i].length;
'free' is not initialized.
> + put_page(sg_page(&sg[i]));
> + }
> + return free;
> +}
> +
> +static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
> +{
> + int err = 0, eval = __SK_NONE, sg_size = 0, sg_num = 0;
> + int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
> + struct sk_msg_buff md = {0};
> + struct smap_psock *psock;
> + size_t copy, copied = 0;
> + struct scatterlist *sg;
> + long timeo;
> +
> + sg = md.sg_data;
> + sg_init_table(sg, MAX_SKB_FRAGS);
> +
> + /* Its possible a sock event or user removed the psock _but_ the ops
> + * have not been reprogrammed yet so we get here. In this case fallback
> + * to tcp_sendmsg. Note this only works because we _only_ ever allow
> + * a single ULP there is no hierarchy here.
> + */
> + rcu_read_lock();
> + psock = smap_psock_sk(sk);
> + if (unlikely(!psock)) {
> + rcu_read_unlock();
> + return tcp_sendmsg(sk, msg, size);
> + }
> +
> + /* Increment the psock refcnt to ensure its not released while sending a
> + * message. Required because sk lookup and bpf programs are used in
> + * separate rcu critical sections. Its OK if we lose the map entry
> + * but we can't lose the sock reference, possible when the refcnt hits
> + * zero and garbage collection calls sock_put().
> + */
> + if (!refcount_inc_not_zero(&psock->refcnt)) {
> + rcu_read_unlock();
> + return tcp_sendmsg(sk, msg, size);
> + }
> +
> + rcu_read_unlock();
> +
> + lock_sock(sk);
> + timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
> +
> + while (msg_data_left(msg)) {
> + int sg_curr;
> +
> + if (sk->sk_err) {
> + err = sk->sk_err;
> + goto out_err;
> + }
> +
> + copy = msg_data_left(msg);
> + if (!sk_stream_memory_free(sk))
> + goto wait_for_sndbuf;
> +
> + /* sg_size indicates bytes already allocated and sg_num
> + * is last sg element used. This is used when alloc_sg
s/alloc_sg/sk_alloc_sg/
> + * partially allocates a scatterlist and then is sent
> + * to wait for memory. In normal case (no memory pressure)
> + * both sg_nun and sg_size are zero.
s/sg_nun/sg_num/
> + */
> + copy = copy - sg_size;
> + err = sk_alloc_sg(sk, copy, sg, &sg_num, &sg_size, 0);
> + if (err) {
> + if (err != -ENOSPC)
> + goto wait_for_memory;
> + copy = sg_size;
> + }
> +
> + err = memcopy_from_iter(sk, sg, sg_num, &msg->msg_iter, copy);
> + if (err < 0) {
> + free_sg(sk, sg, 0, sg_num);
> + goto out_err;
> + }
> +
> + copied += copy;
> +
> + /* If msg is larger than MAX_SKB_FRAGS we can send multiple
> + * scatterlists per msg. However BPF decisions apply to the
> + * entire msg.
> + */
> + if (eval == __SK_NONE)
> + eval = smap_do_tx_msg(sk, psock, &md);
> +
> + switch (eval) {
> + case __SK_PASS:
> + sg_mark_end(sg + sg_num - 1);
> + err = bpf_tcp_push(sk, sg, &sg_curr, flags, true);
> + if (unlikely(err)) {
> + copied -= free_sg(sk, sg, sg_curr, sg_num);
> + goto out_err;
> + }
> + break;
> + case __SK_REDIRECT:
> + sg_mark_end(sg + sg_num - 1);
> + goto do_redir;
> + case __SK_DROP:
> + default:
> + copied -= free_sg(sk, sg, 0, sg_num);
> + goto out_err;
> + }
> +
> + sg_num = 0;
> + sg_size = 0;
> + continue;
> +wait_for_sndbuf:
> + set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
> +wait_for_memory:
> + err = sk_stream_wait_memory(sk, &timeo);
> + if (err)
> + goto out_err;
> + }
> +out_err:
> + if (err < 0)
> + err = sk_stream_error(sk, msg->msg_flags, err);
> + release_sock(sk);
> + smap_release_sock(psock, sk);
> + return copied ? copied : err;
> +
> +do_redir:
> + /* To avoid deadlock with multiple socks all doing redirects to
> + * each other we must first drop the current sock lock and release
> + * the psock. Then get the redirect socket (assuming it still
> + * exists), take it's lock, and finally do the send here. If the
> + * redirect fails there is nothing to do, we don't want to blame
> + * the sender for remote socket failures. Instead we simply
> + * continue making forward progress.
> + */
> + return_mem_sg(sk, sg, sg_num);
> + release_sock(sk);
> + smap_release_sock(psock, sk);
> + copied -= bpf_tcp_sendmsg_do_redirect(sg, sg_num, &md, flags);
> + return copied;
For the __SK_REDIRECT case, should 'msg_data_left(msg)' be checked before
returning? Or will msg_data_left(msg) always be 0 here?
> +}
> +
> +static int bpf_tcp_sendpage_do_redirect(struct page *page, int offset,
> + size_t size, int flags,
> + struct sk_msg_buff *md)
> +{
> + struct smap_psock *psock;
> + struct sock *sk;
> + int rc;
> +
> + rcu_read_lock();
> + sk = do_msg_redirect_map(md);
> + if (unlikely(!sk))
> + goto out_rcu;
> +
> + psock = smap_psock_sk(sk);
> + if (unlikely(!psock))
> + goto out_rcu;
> +
> + if (!refcount_inc_not_zero(&psock->refcnt))
> + goto out_rcu;
> +
> + rcu_read_unlock();
> +
> + lock_sock(sk);
> + rc = tcp_sendpage_locked(sk, page, offset, size, flags);
> + release_sock(sk);
> +
> + smap_release_sock(psock, sk);
> + return rc;
> +out_rcu:
> + rcu_read_unlock();
> + return -EINVAL;
> +}
> +
> +static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
> + int offset, size_t size, int flags)
> +{
> + struct smap_psock *psock;
> + int rc, _rc = __SK_PASS;
> + struct bpf_prog *prog;
> + struct sk_msg_buff md;
> +
> + preempt_disable();
> + rcu_read_lock();
> + psock = smap_psock_sk(sk);
> + if (unlikely(!psock))
> + goto verdict;
> +
> + /* If the policy was removed mid-send then default to 'accept' */
> + prog = READ_ONCE(psock->bpf_tx_msg);
> + if (unlikely(!prog))
> + goto verdict;
> +
> + /* Calculate pkt data pointers and run BPF program */
> + md.data = page_address(page) + offset;
> + md.data_end = md.data + size;
> + _rc = (*prog->bpf_func)(&md, prog->insnsi);
> +
> +verdict:
> + rcu_read_unlock();
> + preempt_enable();
> +
> + /* Moving return codes from UAPI namespace into internal namespace */
> + rc = ((_rc == SK_PASS) ? __SK_PASS : __SK_DROP);
> +
> + switch (rc) {
> + case __SK_PASS:
> + lock_sock(sk);
> + rc = tcp_sendpage_locked(sk, page, offset, size, flags);
> + release_sock(sk);
> + break;
> + case __SK_REDIRECT:
> + smap_release_sock(psock, sk);
Is smap_release_sock() only needed in the __SK_REDIRECT case?
> + rc = bpf_tcp_sendpage_do_redirect(page, offset, size, flags,
> + &md);
> + break;
> + case __SK_DROP:
> + default:
> + rc = -EACCES;
> + }
> +
> + return rc;
> +}
> +
> +static int bpf_tcp_msg_add(struct smap_psock *psock,
> + struct sock *sk,
> + struct bpf_prog *tx_msg)
> +{
> + struct bpf_prog *orig_tx_msg;
> +
> + orig_tx_msg = xchg(&psock->bpf_tx_msg, tx_msg);
> + if (orig_tx_msg)
> + bpf_prog_put(orig_tx_msg);
> +
> + return tcp_set_ulp_id(sk, TCP_ULP_BPF);
> +}
> +
> +struct proto tcp_bpf_proto;
> +static int bpf_tcp_init(struct sock *sk)
> +{
> + sk->sk_prot = &tcp_bpf_proto;
> + return 0;
> +}
> +
> +static void bpf_tcp_release(struct sock *sk)
> +{
> + sk->sk_prot = &tcp_prot;
> +}
> +
> +static struct tcp_ulp_ops bpf_tcp_ulp_ops __read_mostly = {
> + .name = "bpf_tcp",
> + .uid = TCP_ULP_BPF,
> + .owner = NULL,
> + .init = bpf_tcp_init,
> + .release = bpf_tcp_release,
> +};
> +
> +static int bpf_tcp_ulp_register(void)
> +{
> + tcp_bpf_proto = tcp_prot;
> + tcp_bpf_proto.sendmsg = bpf_tcp_sendmsg;
> + tcp_bpf_proto.sendpage = bpf_tcp_sendpage;
> + return tcp_register_ulp(&bpf_tcp_ulp_ops);
> +}
> +
> static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb)
> {
> struct bpf_prog *prog = READ_ONCE(psock->bpf_verdict);
> @@ -165,8 +601,6 @@ static void smap_report_sk_error(struct smap_psock *psock, int err)
> sk->sk_error_report(sk);
> }
>
> -static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
> -
> /* Called with lock_sock(sk) held */
> static void smap_state_change(struct sock *sk)
> {
> @@ -317,6 +751,7 @@ static void smap_write_space(struct sock *sk)
>
> static void smap_stop_sock(struct smap_psock *psock, struct sock *sk)
> {
> + tcp_cleanup_ulp(sk);
> if (!psock->strp_enabled)
> return;
> sk->sk_data_ready = psock->save_data_ready;
> @@ -384,7 +819,6 @@ static int smap_parse_func_strparser(struct strparser *strp,
> return rc;
> }
>
> -
> static int smap_read_sock_done(struct strparser *strp, int err)
> {
> return err;
> @@ -456,6 +890,8 @@ static void smap_gc_work(struct work_struct *w)
> bpf_prog_put(psock->bpf_parse);
> if (psock->bpf_verdict)
> bpf_prog_put(psock->bpf_verdict);
> + if (psock->bpf_tx_msg)
> + bpf_prog_put(psock->bpf_tx_msg);
>
> list_for_each_entry_safe(e, tmp, &psock->maps, list) {
> list_del(&e->list);
> @@ -491,8 +927,7 @@ static struct smap_psock *smap_init_psock(struct sock *sock,
>
> static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
> {
> - struct bpf_stab *stab;
> - int err = -EINVAL;
> + struct bpf_stab *stab; int err = -EINVAL;
> u64 cost;
>
> if (!capable(CAP_NET_ADMIN))
> @@ -506,6 +941,10 @@ static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
> if (attr->value_size > KMALLOC_MAX_SIZE)
> return ERR_PTR(-E2BIG);
>
> + err = bpf_tcp_ulp_register();
> + if (err && err != -EEXIST)
> + return ERR_PTR(err);
> +
> stab = kzalloc(sizeof(*stab), GFP_USER);
> if (!stab)
> return ERR_PTR(-ENOMEM);
> @@ -590,6 +1029,8 @@ static void sock_map_free(struct bpf_map *map)
> bpf_prog_put(stab->bpf_verdict);
> if (stab->bpf_parse)
> bpf_prog_put(stab->bpf_parse);
> + if (stab->bpf_tx_msg)
> + bpf_prog_put(stab->bpf_tx_msg);
>
> sock_map_remove_complete(stab);
> }
> @@ -684,7 +1125,7 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
> {
> struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
> struct smap_psock_map_entry *e = NULL;
> - struct bpf_prog *verdict, *parse;
> + struct bpf_prog *verdict, *parse, *tx_msg;
> struct sock *osock, *sock;
> struct smap_psock *psock;
> u32 i = *(u32 *)key;
> @@ -710,6 +1151,7 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
> */
> verdict = READ_ONCE(stab->bpf_verdict);
> parse = READ_ONCE(stab->bpf_parse);
> + tx_msg = READ_ONCE(stab->bpf_tx_msg);
>
> if (parse && verdict) {
> /* bpf prog refcnt may be zero if a concurrent attach operation
> @@ -728,6 +1170,17 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
> }
> }
>
> + if (tx_msg) {
> + tx_msg = bpf_prog_inc_not_zero(stab->bpf_tx_msg);
> + if (IS_ERR(tx_msg)) {
> + if (verdict)
> + bpf_prog_put(verdict);
> + if (parse)
> + bpf_prog_put(parse);
> + return PTR_ERR(tx_msg);
> + }
> + }
> +
> write_lock_bh(&sock->sk_callback_lock);
> psock = smap_psock_sk(sock);
>
> @@ -742,7 +1195,14 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
> err = -EBUSY;
> goto out_progs;
> }
> - refcount_inc(&psock->refcnt);
> + if (READ_ONCE(psock->bpf_tx_msg) && tx_msg) {
> + err = -EBUSY;
> + goto out_progs;
> + }
> + if (!refcount_inc_not_zero(&psock->refcnt)) {
> + err = -EAGAIN;
> + goto out_progs;
> + }
> } else {
> psock = smap_init_psock(sock, stab);
> if (IS_ERR(psock)) {
> @@ -763,6 +1223,12 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
> /* 3. At this point we have a reference to a valid psock that is
> * running. Attach any BPF programs needed.
> */
> + if (tx_msg) {
> + err = bpf_tcp_msg_add(psock, sock, tx_msg);
> + if (err)
> + goto out_free;
> + }
> +
> if (parse && verdict && !psock->strp_enabled) {
> err = smap_init_sock(psock, sock);
> if (err)
> @@ -798,6 +1264,8 @@ static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
> bpf_prog_put(verdict);
> if (parse)
> bpf_prog_put(parse);
> + if (tx_msg)
> + bpf_prog_put(tx_msg);
> write_unlock_bh(&sock->sk_callback_lock);
> kfree(e);
> return err;
> @@ -812,6 +1280,9 @@ int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type)
> return -EINVAL;
>
> switch (type) {
> + case BPF_SK_MSG_VERDICT:
> + orig = xchg(&stab->bpf_tx_msg, prog);
> + break;
> case BPF_SK_SKB_STREAM_PARSER:
> orig = xchg(&stab->bpf_parse, prog);
> break;
> diff --git a/kernel/bpf/syscall.c b/kernel/bpf/syscall.c
> index ebf0fb2..d32f093 100644
> --- a/kernel/bpf/syscall.c
> +++ b/kernel/bpf/syscall.c
> @@ -1267,7 +1267,8 @@ static int bpf_obj_get(const union bpf_attr *attr)
>
> #define BPF_PROG_ATTACH_LAST_FIELD attach_flags
>
> -static int sockmap_get_from_fd(const union bpf_attr *attr, bool attach)
> +static int sockmap_get_from_fd(const union bpf_attr *attr,
> + int type, bool attach)
> {
> struct bpf_prog *prog = NULL;
> int ufd = attr->target_fd;
> @@ -1281,8 +1282,7 @@ static int sockmap_get_from_fd(const union bpf_attr *attr, bool attach)
> return PTR_ERR(map);
>
> if (attach) {
> - prog = bpf_prog_get_type(attr->attach_bpf_fd,
> - BPF_PROG_TYPE_SK_SKB);
> + prog = bpf_prog_get_type(attr->attach_bpf_fd, type);
> if (IS_ERR(prog)) {
> fdput(f);
> return PTR_ERR(prog);
> @@ -1334,9 +1334,11 @@ static int bpf_prog_attach(const union bpf_attr *attr)
> case BPF_CGROUP_DEVICE:
> ptype = BPF_PROG_TYPE_CGROUP_DEVICE;
> break;
> + case BPF_SK_MSG_VERDICT:
> + return sockmap_get_from_fd(attr, BPF_PROG_TYPE_SK_MSG, true);
> case BPF_SK_SKB_STREAM_PARSER:
> case BPF_SK_SKB_STREAM_VERDICT:
> - return sockmap_get_from_fd(attr, true);
> + return sockmap_get_from_fd(attr, BPF_PROG_TYPE_SK_SKB, true);
> default:
> return -EINVAL;
> }
> @@ -1389,9 +1391,11 @@ static int bpf_prog_detach(const union bpf_attr *attr)
> case BPF_CGROUP_DEVICE:
> ptype = BPF_PROG_TYPE_CGROUP_DEVICE;
> break;
> + case BPF_SK_MSG_VERDICT:
> + return sockmap_get_from_fd(attr, BPF_PROG_TYPE_SK_MSG, false);
> case BPF_SK_SKB_STREAM_PARSER:
> case BPF_SK_SKB_STREAM_VERDICT:
> - return sockmap_get_from_fd(attr, false);
> + return sockmap_get_from_fd(attr, BPF_PROG_TYPE_SK_SKB, false);
> default:
> return -EINVAL;
> }
> diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
> index a2b2112..15c5c2a 100644
> --- a/kernel/bpf/verifier.c
> +++ b/kernel/bpf/verifier.c
> @@ -1240,6 +1240,7 @@ static bool may_access_direct_pkt_data(struct bpf_verifier_env *env,
> case BPF_PROG_TYPE_XDP:
> case BPF_PROG_TYPE_LWT_XMIT:
> case BPF_PROG_TYPE_SK_SKB:
> + case BPF_PROG_TYPE_SK_MSG:
> if (meta)
> return meta->pkt_access;
>
> @@ -2041,7 +2042,8 @@ static int check_map_func_compatibility(struct bpf_verifier_env *env,
> case BPF_MAP_TYPE_SOCKMAP:
> if (func_id != BPF_FUNC_sk_redirect_map &&
> func_id != BPF_FUNC_sock_map_update &&
> - func_id != BPF_FUNC_map_delete_elem)
> + func_id != BPF_FUNC_map_delete_elem &&
> + func_id != BPF_FUNC_msg_redirect_map)
> goto error;
> break;
> default:
> @@ -2079,6 +2081,7 @@ static int check_map_func_compatibility(struct bpf_verifier_env *env,
> goto error;
> break;
> case BPF_FUNC_sk_redirect_map:
> + case BPF_FUNC_msg_redirect_map:
> if (map->map_type != BPF_MAP_TYPE_SOCKMAP)
> goto error;
> break;
> diff --git a/net/core/filter.c b/net/core/filter.c
> index acdb94c..ca87b8d 100644
> --- a/net/core/filter.c
> +++ b/net/core/filter.c
> @@ -1881,6 +1881,44 @@ struct sock *do_sk_redirect_map(struct sk_buff *skb)
> .arg4_type = ARG_ANYTHING,
> };
>
> +BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg_buff *, msg,
> + struct bpf_map *, map, u32, key, u64, flags)
> +{
> + /* If user passes invalid input drop the packet. */
> + if (unlikely(flags))
> + return SK_DROP;
> +
> + msg->key = key;
> + msg->flags = flags;
> + msg->map = map;
> +
> + return SK_PASS;
> +}
> +
> +struct sock *do_msg_redirect_map(struct sk_msg_buff *msg)
> +{
> + struct sock *sk = NULL;
> +
> + if (msg->map) {
> + sk = __sock_map_lookup_elem(msg->map, msg->key);
> +
> + msg->key = 0;
> + msg->map = NULL;
> + }
> +
> + return sk;
> +}
> +
> +static const struct bpf_func_proto bpf_msg_redirect_map_proto = {
> + .func = bpf_msg_redirect_map,
> + .gpl_only = false,
> + .ret_type = RET_INTEGER,
> + .arg1_type = ARG_PTR_TO_CTX,
> + .arg2_type = ARG_CONST_MAP_PTR,
> + .arg3_type = ARG_ANYTHING,
> + .arg4_type = ARG_ANYTHING,
> +};
> +
> BPF_CALL_1(bpf_get_cgroup_classid, const struct sk_buff *, skb)
> {
> return task_get_classid(skb);
> @@ -3513,6 +3551,16 @@ static unsigned long bpf_xdp_copy(void *dst_buff, const void *src_buff,
> }
> }
>
> +static const struct bpf_func_proto *sk_msg_func_proto(enum bpf_func_id func_id)
> +{
> + switch (func_id) {
> + case BPF_FUNC_msg_redirect_map:
> + return &bpf_msg_redirect_map_proto;
> + default:
> + return bpf_base_func_proto(func_id);
> + }
> +}
> +
> static const struct bpf_func_proto *sk_skb_func_proto(enum bpf_func_id func_id)
> {
> switch (func_id) {
> @@ -3892,6 +3940,32 @@ static bool sk_skb_is_valid_access(int off, int size,
> return bpf_skb_is_valid_access(off, size, type, info);
> }
>
> +static bool sk_msg_is_valid_access(int off, int size,
> + enum bpf_access_type type,
> + struct bpf_insn_access_aux *info)
> +{
> + if (type == BPF_WRITE)
> + return false;
> +
> + switch (off) {
> + case offsetof(struct sk_msg_md, data):
> + info->reg_type = PTR_TO_PACKET;
> + break;
> + case offsetof(struct sk_msg_md, data_end):
> + info->reg_type = PTR_TO_PACKET_END;
> + break;
> + }
> +
> + if (off < 0 || off >= sizeof(struct sk_msg_md))
> + return false;
> + if (off % size != 0)
> + return false;
> + if (size != sizeof(__u32))
> + return false;
> +
> + return true;
> +}
> +
> static u32 bpf_convert_ctx_access(enum bpf_access_type type,
> const struct bpf_insn *si,
> struct bpf_insn *insn_buf,
> @@ -4522,6 +4596,29 @@ static u32 sk_skb_convert_ctx_access(enum bpf_access_type type,
> return insn - insn_buf;
> }
>
> +static u32 sk_msg_convert_ctx_access(enum bpf_access_type type,
> + const struct bpf_insn *si,
> + struct bpf_insn *insn_buf,
> + struct bpf_prog *prog, u32 *target_size)
> +{
> + struct bpf_insn *insn = insn_buf;
> +
> + switch (si->off) {
> + case offsetof(struct sk_msg_md, data):
> + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg_buff, data),
> + si->dst_reg, si->src_reg,
> + offsetof(struct sk_msg_buff, data));
> + break;
> + case offsetof(struct sk_msg_md, data_end):
> + *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_msg_buff, data_end),
> + si->dst_reg, si->src_reg,
> + offsetof(struct sk_msg_buff, data_end));
> + break;
> + }
> +
> + return insn - insn_buf;
> +}
> +
> const struct bpf_verifier_ops sk_filter_verifier_ops = {
> .get_func_proto = sk_filter_func_proto,
> .is_valid_access = sk_filter_is_valid_access,
> @@ -4611,6 +4708,15 @@ static u32 sk_skb_convert_ctx_access(enum bpf_access_type type,
> const struct bpf_prog_ops sk_skb_prog_ops = {
> };
>
> +const struct bpf_verifier_ops sk_msg_verifier_ops = {
> + .get_func_proto = sk_msg_func_proto,
> + .is_valid_access = sk_msg_is_valid_access,
> + .convert_ctx_access = sk_msg_convert_ctx_access,
> +};
> +
> +const struct bpf_prog_ops sk_msg_prog_ops = {
> +};
> +
> int sk_detach_filter(struct sock *sk)
> {
> int ret = -ENOENT;
>