bpf: sockmap redirect ingress support
author: John Fastabend <john.fastabend@gmail.com>
Wed, 28 Mar 2018 19:49:15 +0000 (12:49 -0700)
committer: Daniel Borkmann <daniel@iogearbox.net>
Thu, 29 Mar 2018 22:09:43 +0000 (00:09 +0200)
Add support for the BPF_F_INGRESS flag in sk_msg redirect helper.
To do this add a scatterlist ring for receiving socks to check
before calling into regular recvmsg call path. Additionally, because
the poll wakeup logic only checked the skb recv queue we need to
add a hook in TCP stack (similar to write side) so that we have
a way to wake up polling socks when a scatterlist is redirected
to that sock.

After this all that is needed is for the redirect helper to
push the scatterlist into the psock receive queue.

Signed-off-by: John Fastabend <john.fastabend@gmail.com>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
include/linux/filter.h
include/net/sock.h
kernel/bpf/sockmap.c
net/core/filter.c
net/ipv4/tcp.c

index c2f167db8bd5ccace957e97c49c9a6f048859743..961cc5d53956d76560c94ef4e7137b7f936a6624 100644 (file)
@@ -521,6 +521,7 @@ struct sk_msg_buff {
        __u32 key;
        __u32 flags;
        struct bpf_map *map;
+       struct list_head list;
 };
 
 /* Compute the linear packet data range [data, data_end) which
index 709311132d4c1d575abfe82542429ce016fdaef7..b8ff435fa96e4f8f8228322883cb0394811f817e 100644 (file)
@@ -1085,6 +1085,7 @@ struct proto {
 #endif
 
        bool                    (*stream_memory_free)(const struct sock *sk);
+       bool                    (*stream_memory_read)(const struct sock *sk);
        /* Memory pressure */
        void                    (*enter_memory_pressure)(struct sock *sk);
        void                    (*leave_memory_pressure)(struct sock *sk);
index 69c5bccabd229f801537f6137e311d075b79db72..402e15466e9f4aae6aec9338c236e87cec1a4259 100644 (file)
@@ -41,6 +41,8 @@
 #include <linux/mm.h>
 #include <net/strparser.h>
 #include <net/tcp.h>
+#include <linux/ptr_ring.h>
+#include <net/inet_common.h>
 
 #define SOCK_CREATE_FLAG_MASK \
        (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
@@ -82,6 +84,7 @@ struct smap_psock {
        int sg_size;
        int eval;
        struct sk_msg_buff *cork;
+       struct list_head ingress;
 
        struct strparser strp;
        struct bpf_prog *bpf_tx_msg;
@@ -103,6 +106,8 @@ struct smap_psock {
 };
 
 static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
+static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
+                          int nonblock, int flags, int *addr_len);
 static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
 static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
                            int offset, size_t size, int flags);
@@ -112,6 +117,21 @@ static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
        return rcu_dereference_sk_user_data(sk);
 }
 
+static bool bpf_tcp_stream_read(const struct sock *sk)
+{
+       struct smap_psock *psock;
+       bool empty = true;
+
+       rcu_read_lock();
+       psock = smap_psock_sk(sk);
+       if (unlikely(!psock))
+               goto out;
+       empty = list_empty(&psock->ingress);
+out:
+       rcu_read_unlock();
+       return !empty;
+}
+
 static struct proto tcp_bpf_proto;
 static int bpf_tcp_init(struct sock *sk)
 {
@@ -135,6 +155,8 @@ static int bpf_tcp_init(struct sock *sk)
        if (psock->bpf_tx_msg) {
                tcp_bpf_proto.sendmsg = bpf_tcp_sendmsg;
                tcp_bpf_proto.sendpage = bpf_tcp_sendpage;
+               tcp_bpf_proto.recvmsg = bpf_tcp_recvmsg;
+               tcp_bpf_proto.stream_memory_read = bpf_tcp_stream_read;
        }
 
        sk->sk_prot = &tcp_bpf_proto;
@@ -170,6 +192,7 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
 {
        void (*close_fun)(struct sock *sk, long timeout);
        struct smap_psock_map_entry *e, *tmp;
+       struct sk_msg_buff *md, *mtmp;
        struct smap_psock *psock;
        struct sock *osk;
 
@@ -188,6 +211,12 @@ static void bpf_tcp_close(struct sock *sk, long timeout)
        close_fun = psock->save_close;
 
        write_lock_bh(&sk->sk_callback_lock);
+       list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
+               list_del(&md->list);
+               free_start_sg(psock->sock, md);
+               kfree(md);
+       }
+
        list_for_each_entry_safe(e, tmp, &psock->maps, list) {
                osk = cmpxchg(e->entry, sk, NULL);
                if (osk == sk) {
@@ -468,6 +497,72 @@ verdict:
        return _rc;
 }
 
+static int bpf_tcp_ingress(struct sock *sk, int apply_bytes,
+                          struct smap_psock *psock,
+                          struct sk_msg_buff *md, int flags)
+{
+       bool apply = apply_bytes;
+       size_t size, copied = 0;
+       struct sk_msg_buff *r;
+       int err = 0, i;
+
+       r = kzalloc(sizeof(struct sk_msg_buff), __GFP_NOWARN | GFP_KERNEL);
+       if (unlikely(!r))
+               return -ENOMEM;
+
+       lock_sock(sk);
+       r->sg_start = md->sg_start;
+       i = md->sg_start;
+
+       do {
+               r->sg_data[i] = md->sg_data[i];
+
+               size = (apply && apply_bytes < md->sg_data[i].length) ?
+                       apply_bytes : md->sg_data[i].length;
+
+               if (!sk_wmem_schedule(sk, size)) {
+                       if (!copied)
+                               err = -ENOMEM;
+                       break;
+               }
+
+               sk_mem_charge(sk, size);
+               r->sg_data[i].length = size;
+               md->sg_data[i].length -= size;
+               md->sg_data[i].offset += size;
+               copied += size;
+
+               if (md->sg_data[i].length) {
+                       get_page(sg_page(&r->sg_data[i]));
+                       r->sg_end = (i + 1) == MAX_SKB_FRAGS ? 0 : i + 1;
+               } else {
+                       i++;
+                       if (i == MAX_SKB_FRAGS)
+                               i = 0;
+                       r->sg_end = i;
+               }
+
+               if (apply) {
+                       apply_bytes -= size;
+                       if (!apply_bytes)
+                               break;
+               }
+       } while (i != md->sg_end);
+
+       md->sg_start = i;
+
+       if (!err) {
+               list_add_tail(&r->list, &psock->ingress);
+               sk->sk_data_ready(sk);
+       } else {
+               free_start_sg(sk, r);
+               kfree(r);
+       }
+
+       release_sock(sk);
+       return err;
+}
+
 static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
                                       struct sk_msg_buff *md,
                                       int flags)
@@ -475,6 +570,7 @@ static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
        struct smap_psock *psock;
        struct scatterlist *sg;
        int i, err, free = 0;
+       bool ingress = !!(md->flags & BPF_F_INGRESS);
 
        sg = md->sg_data;
 
@@ -487,9 +583,14 @@ static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
                goto out_rcu;
 
        rcu_read_unlock();
-       lock_sock(sk);
-       err = bpf_tcp_push(sk, send, md, flags, false);
-       release_sock(sk);
+
+       if (ingress) {
+               err = bpf_tcp_ingress(sk, send, psock, md, flags);
+       } else {
+               lock_sock(sk);
+               err = bpf_tcp_push(sk, send, md, flags, false);
+               release_sock(sk);
+       }
        smap_release_sock(psock, sk);
        if (unlikely(err))
                goto out;
@@ -623,6 +724,89 @@ out_err:
        return err;
 }
 
+static int bpf_tcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
+                          int nonblock, int flags, int *addr_len)
+{
+       struct iov_iter *iter = &msg->msg_iter;
+       struct smap_psock *psock;
+       int copied = 0;
+
+       if (unlikely(flags & MSG_ERRQUEUE))
+               return inet_recv_error(sk, msg, len, addr_len);
+
+       rcu_read_lock();
+       psock = smap_psock_sk(sk);
+       if (unlikely(!psock))
+               goto out;
+
+       if (unlikely(!refcount_inc_not_zero(&psock->refcnt)))
+               goto out;
+       rcu_read_unlock();
+
+       if (!skb_queue_empty(&sk->sk_receive_queue))
+               return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
+
+       lock_sock(sk);
+       while (copied != len) {
+               struct scatterlist *sg;
+               struct sk_msg_buff *md;
+               int i;
+
+               md = list_first_entry_or_null(&psock->ingress,
+                                             struct sk_msg_buff, list);
+               if (unlikely(!md))
+                       break;
+               i = md->sg_start;
+               do {
+                       struct page *page;
+                       int n, copy;
+
+                       sg = &md->sg_data[i];
+                       copy = sg->length;
+                       page = sg_page(sg);
+
+                       if (copied + copy > len)
+                               copy = len - copied;
+
+                       n = copy_page_to_iter(page, sg->offset, copy, iter);
+                       if (n != copy) {
+                               md->sg_start = i;
+                               release_sock(sk);
+                               smap_release_sock(psock, sk);
+                               return -EFAULT;
+                       }
+
+                       copied += copy;
+                       sg->offset += copy;
+                       sg->length -= copy;
+                       sk_mem_uncharge(sk, copy);
+
+                       if (!sg->length) {
+                               i++;
+                               if (i == MAX_SKB_FRAGS)
+                                       i = 0;
+                               put_page(page);
+                       }
+                       if (copied == len)
+                               break;
+               } while (i != md->sg_end);
+               md->sg_start = i;
+
+               if (!sg->length && md->sg_start == md->sg_end) {
+                       list_del(&md->list);
+                       kfree(md);
+               }
+       }
+
+       release_sock(sk);
+       smap_release_sock(psock, sk);
+       return copied;
+out:
+       rcu_read_unlock();
+       return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
+}
+
+
 static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 {
        int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
@@ -1107,6 +1291,7 @@ static void sock_map_remove_complete(struct bpf_stab *stab)
 static void smap_gc_work(struct work_struct *w)
 {
        struct smap_psock_map_entry *e, *tmp;
+       struct sk_msg_buff *md, *mtmp;
        struct smap_psock *psock;
 
        psock = container_of(w, struct smap_psock, gc_work);
@@ -1131,6 +1316,12 @@ static void smap_gc_work(struct work_struct *w)
                kfree(psock->cork);
        }
 
+       list_for_each_entry_safe(md, mtmp, &psock->ingress, list) {
+               list_del(&md->list);
+               free_start_sg(psock->sock, md);
+               kfree(md);
+       }
+
        list_for_each_entry_safe(e, tmp, &psock->maps, list) {
                list_del(&e->list);
                kfree(e);
@@ -1160,6 +1351,7 @@ static struct smap_psock *smap_init_psock(struct sock *sock,
        INIT_WORK(&psock->tx_work, smap_tx_work);
        INIT_WORK(&psock->gc_work, smap_gc_work);
        INIT_LIST_HEAD(&psock->maps);
+       INIT_LIST_HEAD(&psock->ingress);
        refcount_set(&psock->refcnt, 1);
 
        rcu_assign_sk_user_data(sock, psock);
index afd825534ac465f09023c2324740093a98b48609..a5a995e5b38032e90d2dbecdc7a4a6d8d38ae036 100644 (file)
@@ -1894,7 +1894,7 @@ BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg_buff *, msg,
           struct bpf_map *, map, u32, key, u64, flags)
 {
        /* If user passes invalid input drop the packet. */
-       if (unlikely(flags))
+       if (unlikely(flags & ~(BPF_F_INGRESS)))
                return SK_DROP;
 
        msg->key = key;
index 0c31be306572acdecaf45cdb0357bb0f7f9eca8b..bccc4c2700870b8c7ff592a6bd27acebd9bc6471 100644 (file)
@@ -485,6 +485,14 @@ static void tcp_tx_timestamp(struct sock *sk, u16 tsflags)
        }
 }
 
+static inline bool tcp_stream_is_readable(const struct tcp_sock *tp,
+                                         int target, struct sock *sk)
+{
+       return (tp->rcv_nxt - tp->copied_seq >= target) ||
+               (sk->sk_prot->stream_memory_read ?
+               sk->sk_prot->stream_memory_read(sk) : false);
+}
+
 /*
  *     Wait for a TCP event.
  *
@@ -554,7 +562,7 @@ __poll_t tcp_poll(struct file *file, struct socket *sock, poll_table *wait)
                    tp->urg_data)
                        target++;
 
-               if (tp->rcv_nxt - tp->copied_seq >= target)
+               if (tcp_stream_is_readable(tp, target, sk))
                        mask |= EPOLLIN | EPOLLRDNORM;
 
                if (!(sk->sk_shutdown & SEND_SHUTDOWN)) {