tls: Fix recvmsg() to be able to peek across multiple records
authorVakul Garg <vakul.garg@nxp.com>
Wed, 16 Jan 2019 10:40:16 +0000 (10:40 +0000)
committerDavid S. Miller <davem@davemloft.net>
Thu, 17 Jan 2019 22:20:40 +0000 (14:20 -0800)
This fixes recvmsg() to be able to peek across multiple tls records.
Without this patch, the tls's selftests test case
'recv_peek_large_buf_mult_recs' fails. Each tls receive context now
maintains a 'rx_list' to retain incoming skb carrying tls records. If a
tls record needs to be retained e.g. for peek case or for the case when
the buffer passed to recvmsg() has a length smaller than decrypted
record length, then it is added to 'rx_list'. Additionally, records are
added in 'rx_list' if the crypto operation runs in async mode. The
records are dequeued from 'rx_list' after the decrypted data is consumed
by copying into the buffer passed to recvmsg(). In case, the MSG_PEEK
flag is used in recvmsg(), then records are not consumed or removed
from the 'rx_list'.

Signed-off-by: Vakul Garg <vakul.garg@nxp.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/net/tls.h
net/tls/tls_sw.c

index 2a6ac8d642afa0d358d6d3633eb099f47647626f..90bf52db573e9c3ea702147f64adce1cd5143732 100644 (file)
@@ -145,12 +145,13 @@ struct tls_sw_context_tx {
 struct tls_sw_context_rx {
        struct crypto_aead *aead_recv;
        struct crypto_wait async_wait;
-
        struct strparser strp;
+       struct sk_buff_head rx_list;    /* list of decrypted 'data' records */
        void (*saved_data_ready)(struct sock *sk);
 
        struct sk_buff *recv_pkt;
        u8 control;
+       int async_capable;
        bool decrypted;
        atomic_t decrypt_pending;
        bool async_notify;
index b8e50e22b777200ae9b556f112007d3e812d24bc..86b9527c4826b04ad9c0853e90556e38d5ea8291 100644 (file)
@@ -124,6 +124,7 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
 {
        struct aead_request *aead_req = (struct aead_request *)req;
        struct scatterlist *sgout = aead_req->dst;
+       struct scatterlist *sgin = aead_req->src;
        struct tls_sw_context_rx *ctx;
        struct tls_context *tls_ctx;
        struct scatterlist *sg;
@@ -134,12 +135,16 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
        skb = (struct sk_buff *)req->data;
        tls_ctx = tls_get_ctx(skb->sk);
        ctx = tls_sw_ctx_rx(tls_ctx);
-       pending = atomic_dec_return(&ctx->decrypt_pending);
 
        /* Propagate if there was an err */
        if (err) {
                ctx->async_wait.err = err;
                tls_err_abort(skb->sk, err);
+       } else {
+               struct strp_msg *rxm = strp_msg(skb);
+
+               rxm->offset += tls_ctx->rx.prepend_size;
+               rxm->full_len -= tls_ctx->rx.overhead_size;
        }
 
        /* After using skb->sk to propagate sk through crypto async callback
@@ -147,18 +152,21 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
         */
        skb->sk = NULL;
 
-       /* Release the skb, pages and memory allocated for crypto req */
-       kfree_skb(skb);
 
-       /* Skip the first S/G entry as it points to AAD */
-       for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
-               if (!sg)
-                       break;
-               put_page(sg_page(sg));
+       /* Free the destination pages if skb was not decrypted inplace */
+       if (sgout != sgin) {
+               /* Skip the first S/G entry as it points to AAD */
+               for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
+                       if (!sg)
+                               break;
+                       put_page(sg_page(sg));
+               }
        }
 
        kfree(aead_req);
 
+       pending = atomic_dec_return(&ctx->decrypt_pending);
+
        if (!pending && READ_ONCE(ctx->async_notify))
                complete(&ctx->async_wait.completion);
 }
@@ -1271,7 +1279,7 @@ out:
 static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
                            struct iov_iter *out_iov,
                            struct scatterlist *out_sg,
-                           int *chunk, bool *zc)
+                           int *chunk, bool *zc, bool async)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@@ -1371,13 +1379,13 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 fallback_to_reg_recv:
                sgout = sgin;
                pages = 0;
-               *chunk = 0;
+               *chunk = data_len;
                *zc = false;
        }
 
        /* Prepare and submit AEAD request */
        err = tls_do_decryption(sk, skb, sgin, sgout, iv,
-                               data_len, aead_req, *zc);
+                               data_len, aead_req, async);
        if (err == -EINPROGRESS)
                return err;
 
@@ -1390,7 +1398,8 @@ fallback_to_reg_recv:
 }
 
 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
-                             struct iov_iter *dest, int *chunk, bool *zc)
+                             struct iov_iter *dest, int *chunk, bool *zc,
+                             bool async)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@@ -1403,7 +1412,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
                return err;
 #endif
        if (!ctx->decrypted) {
-               err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
+               err = decrypt_internal(sk, skb, dest, NULL, chunk, zc, async);
                if (err < 0) {
                        if (err == -EINPROGRESS)
                                tls_advance_record_sn(sk, &tls_ctx->rx);
@@ -1429,7 +1438,7 @@ int decrypt_skb(struct sock *sk, struct sk_buff *skb,
        bool zc = true;
        int chunk;
 
-       return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
+       return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false);
 }
 
 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
@@ -1456,6 +1465,77 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
        return true;
 }
 
+/* This function traverses the rx_list in tls receive context to copies the
+ * decrypted data records into the buffer provided by caller zero copy is not
+ * true. Further, the records are removed from the rx_list if it is not a peek
+ * case and the record has been consumed completely.
+ */
+static int process_rx_list(struct tls_sw_context_rx *ctx,
+                          struct msghdr *msg,
+                          size_t skip,
+                          size_t len,
+                          bool zc,
+                          bool is_peek)
+{
+       struct sk_buff *skb = skb_peek(&ctx->rx_list);
+       ssize_t copied = 0;
+
+       while (skip && skb) {
+               struct strp_msg *rxm = strp_msg(skb);
+
+               if (skip < rxm->full_len)
+                       break;
+
+               skip = skip - rxm->full_len;
+               skb = skb_peek_next(skb, &ctx->rx_list);
+       }
+
+       while (len && skb) {
+               struct sk_buff *next_skb;
+               struct strp_msg *rxm = strp_msg(skb);
+               int chunk = min_t(unsigned int, rxm->full_len - skip, len);
+
+               if (!zc || (rxm->full_len - skip) > len) {
+                       int err = skb_copy_datagram_msg(skb, rxm->offset + skip,
+                                                   msg, chunk);
+                       if (err < 0)
+                               return err;
+               }
+
+               len = len - chunk;
+               copied = copied + chunk;
+
+               /* Consume the data from record if it is non-peek case*/
+               if (!is_peek) {
+                       rxm->offset = rxm->offset + chunk;
+                       rxm->full_len = rxm->full_len - chunk;
+
+                       /* Return if there is unconsumed data in the record */
+                       if (rxm->full_len - skip)
+                               break;
+               }
+
+               /* The remaining skip-bytes must lie in 1st record in rx_list.
+                * So from the 2nd record, 'skip' should be 0.
+                */
+               skip = 0;
+
+               if (msg)
+                       msg->msg_flags |= MSG_EOR;
+
+               next_skb = skb_peek_next(skb, &ctx->rx_list);
+
+               if (!is_peek) {
+                       skb_unlink(skb, &ctx->rx_list);
+                       kfree_skb(skb);
+               }
+
+               skb = next_skb;
+       }
+
+       return copied;
+}
+
 int tls_sw_recvmsg(struct sock *sk,
                   struct msghdr *msg,
                   size_t len,
@@ -1466,7 +1546,8 @@ int tls_sw_recvmsg(struct sock *sk,
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
        struct sk_psock *psock;
-       unsigned char control;
+       unsigned char control = 0;
+       ssize_t decrypted = 0;
        struct strp_msg *rxm;
        struct sk_buff *skb;
        ssize_t copied = 0;
@@ -1474,6 +1555,7 @@ int tls_sw_recvmsg(struct sock *sk,
        int target, err = 0;
        long timeo;
        bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
+       bool is_peek = flags & MSG_PEEK;
        int num_async = 0;
 
        flags |= nonblock;
@@ -1484,11 +1566,28 @@ int tls_sw_recvmsg(struct sock *sk,
        psock = sk_psock_get(sk);
        lock_sock(sk);
 
-       target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
-       timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+       /* Process pending decrypted records. It must be non-zero-copy */
+       err = process_rx_list(ctx, msg, 0, len, false, is_peek);
+       if (err < 0) {
+               tls_err_abort(sk, err);
+               goto end;
+       } else {
+               copied = err;
+       }
+
+       len = len - copied;
+       if (len) {
+               target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
+               timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+       } else {
+               goto recv_end;
+       }
+
        do {
-               bool zc = false;
+               bool retain_skb = false;
                bool async = false;
+               bool zc = false;
+               int to_decrypt;
                int chunk = 0;
 
                skb = tls_wait_data(sk, psock, flags, timeo, &err);
@@ -1498,7 +1597,7 @@ int tls_sw_recvmsg(struct sock *sk,
                                                            msg, len, flags);
 
                                if (ret > 0) {
-                                       copied += ret;
+                                       decrypted += ret;
                                        len -= ret;
                                        continue;
                                }
@@ -1525,70 +1624,70 @@ int tls_sw_recvmsg(struct sock *sk,
                        goto recv_end;
                }
 
-               if (!ctx->decrypted) {
-                       int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
+               to_decrypt = rxm->full_len - tls_ctx->rx.overhead_size;
 
-                       if (!is_kvec && to_copy <= len &&
-                           likely(!(flags & MSG_PEEK)))
-                               zc = true;
+               if (to_decrypt <= len && !is_kvec && !is_peek)
+                       zc = true;
 
-                       err = decrypt_skb_update(sk, skb, &msg->msg_iter,
-                                                &chunk, &zc);
-                       if (err < 0 && err != -EINPROGRESS) {
-                               tls_err_abort(sk, EBADMSG);
-                               goto recv_end;
-                       }
-
-                       if (err == -EINPROGRESS) {
-                               async = true;
-                               num_async++;
-                               goto pick_next_record;
-                       }
-
-                       ctx->decrypted = true;
+               err = decrypt_skb_update(sk, skb, &msg->msg_iter,
+                                        &chunk, &zc, ctx->async_capable);
+               if (err < 0 && err != -EINPROGRESS) {
+                       tls_err_abort(sk, EBADMSG);
+                       goto recv_end;
                }
 
-               if (!zc) {
-                       chunk = min_t(unsigned int, rxm->full_len, len);
+               if (err == -EINPROGRESS) {
+                       async = true;
+                       num_async++;
+                       goto pick_next_record;
+               } else {
+                       if (!zc) {
+                               if (rxm->full_len > len) {
+                                       retain_skb = true;
+                                       chunk = len;
+                               } else {
+                                       chunk = rxm->full_len;
+                               }
+
+                               err = skb_copy_datagram_msg(skb, rxm->offset,
+                                                           msg, chunk);
+                               if (err < 0)
+                                       goto recv_end;
 
-                       err = skb_copy_datagram_msg(skb, rxm->offset, msg,
-                                                   chunk);
-                       if (err < 0)
-                               goto recv_end;
+                               if (!is_peek) {
+                                       rxm->offset = rxm->offset + chunk;
+                                       rxm->full_len = rxm->full_len - chunk;
+                               }
+                       }
                }
 
 pick_next_record:
-               copied += chunk;
+               if (chunk > len)
+                       chunk = len;
+
+               decrypted += chunk;
                len -= chunk;
-               if (likely(!(flags & MSG_PEEK))) {
-                       u8 control = ctx->control;
-
-                       /* For async, drop current skb reference */
-                       if (async)
-                               skb = NULL;
-
-                       if (tls_sw_advance_skb(sk, skb, chunk)) {
-                               /* Return full control message to
-                                * userspace before trying to parse
-                                * another message type
-                                */
-                               msg->msg_flags |= MSG_EOR;
-                               if (control != TLS_RECORD_TYPE_DATA)
-                                       goto recv_end;
-                       } else {
-                               break;
-                       }
-               } else {
-                       /* MSG_PEEK right now cannot look beyond current skb
-                        * from strparser, meaning we cannot advance skb here
-                        * and thus unpause strparser since we'd loose original
-                        * one.
+
+               /* For async or peek case, queue the current skb */
+               if (async || is_peek || retain_skb) {
+                       skb_queue_tail(&ctx->rx_list, skb);
+                       skb = NULL;
+               }
+
+               if (tls_sw_advance_skb(sk, skb, chunk)) {
+                       /* Return full control message to
+                        * userspace before trying to parse
+                        * another message type
                         */
+                       msg->msg_flags |= MSG_EOR;
+                       if (ctx->control != TLS_RECORD_TYPE_DATA)
+                               goto recv_end;
+               } else {
                        break;
                }
 
                /* If we have a new message from strparser, continue now. */
-               if (copied >= target && !ctx->recv_pkt)
+               if (decrypted >= target && !ctx->recv_pkt)
                        break;
        } while (len);
 
@@ -1602,13 +1701,33 @@ recv_end:
                                /* one of async decrypt failed */
                                tls_err_abort(sk, err);
                                copied = 0;
+                               decrypted = 0;
+                               goto end;
                        }
                } else {
                        reinit_completion(&ctx->async_wait.completion);
                }
                WRITE_ONCE(ctx->async_notify, false);
+
+               /* Drain records from the rx_list & copy if required */
+               if (is_peek || is_kvec)
+                       err = process_rx_list(ctx, msg, copied,
+                                             decrypted, false, is_peek);
+               else
+                       err = process_rx_list(ctx, msg, 0,
+                                             decrypted, true, is_peek);
+               if (err < 0) {
+                       tls_err_abort(sk, err);
+                       copied = 0;
+                       goto end;
+               }
+
+               WARN_ON(decrypted != err);
        }
 
+       copied += decrypted;
+
+end:
        release_sock(sk);
        if (psock)
                sk_psock_put(sk, psock);
@@ -1645,7 +1764,7 @@ ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
        }
 
        if (!ctx->decrypted) {
-               err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);
+               err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
 
                if (err < 0) {
                        tls_err_abort(sk, EBADMSG);
@@ -1832,6 +1951,7 @@ void tls_sw_release_resources_rx(struct sock *sk)
        if (ctx->aead_recv) {
                kfree_skb(ctx->recv_pkt);
                ctx->recv_pkt = NULL;
+               skb_queue_purge(&ctx->rx_list);
                crypto_free_aead(ctx->aead_recv);
                strp_stop(&ctx->strp);
                write_lock_bh(&sk->sk_callback_lock);
@@ -1881,6 +2001,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
        struct crypto_aead **aead;
        struct strp_callbacks cb;
        u16 nonce_size, tag_size, iv_size, rec_seq_size;
+       struct crypto_tfm *tfm;
        char *iv, *rec_seq;
        int rc = 0;
 
@@ -1927,6 +2048,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
                crypto_init_wait(&sw_ctx_rx->async_wait);
                crypto_info = &ctx->crypto_recv.info;
                cctx = &ctx->rx;
+               skb_queue_head_init(&sw_ctx_rx->rx_list);
                aead = &sw_ctx_rx->aead_recv;
        }
 
@@ -1994,6 +2116,10 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
                goto free_aead;
 
        if (sw_ctx_rx) {
+               tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
+               sw_ctx_rx->async_capable =
+                       tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC;
+
                /* Set up strparser */
                memset(&cb, 0, sizeof(cb));
                cb.rcv_msg = tls_queue;