Hi Vakul,

Only minor comments; this mostly looks good to me. Thanks!

> +/* This function decrypts the input skb into either out_iov or in out_sg
> + * or in skb buffers itself. The input parameter 'zc' indicates if
> + * zero-copy mode needs to be tried or not. With zero-copy mode, either
> + * out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are
> + * NULL, then the decryption happens inside skb buffers itself, i.e.
> + * zero-copy gets disabled and 'zc' is updated.
> + */
> +
> +static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
> +                         struct iov_iter *out_iov,
> +                         struct scatterlist *out_sg,
> +                         int *chunk, bool *zc)
> +{
> +     struct tls_context *tls_ctx = tls_get_ctx(sk);
> +     struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
> +     struct strp_msg *rxm = strp_msg(skb);
> +     int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
> +     struct aead_request *aead_req;
> +     struct sk_buff *unused;
> +     u8 *aad, *iv, *mem = NULL;
> +     struct scatterlist *sgin = NULL;
> +     struct scatterlist *sgout = NULL;
> +     const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;
> +
> +     if (*zc && (out_iov || out_sg)) {
> +             if (out_iov)
> +                     n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
> +             else if (out_sg)
> +                     n_sgout = sg_nents(out_sg);
> +             else
> +                     goto no_zerocopy;

Is the last else necessary? The enclosing if already checks for
out_iov || out_sg, so the final "else goto no_zerocopy;" branch is unreachable.

>               struct scatterlist *sgout)
>  {
> -     struct tls_context *tls_ctx = tls_get_ctx(sk);
> -     struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
> -     char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + MAX_IV_SIZE];
> -     struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2];
> -     struct scatterlist *sgin = &sgin_arr[0];
> -     struct strp_msg *rxm = strp_msg(skb);
> -     int ret, nsg;
> -     struct sk_buff *unused;
> -
> -     ret = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
> -                         iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
> -                         tls_ctx->rx.iv_size);
> -     if (ret < 0)
> -             return ret;
> -
> -     memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
> -     if (!sgout) {
> -             nsg = skb_cow_data(skb, 0, &unused);
> -     } else {
> -             nsg = skb_nsg(skb,
> -                           rxm->offset + tls_ctx->rx.prepend_size,
> -                           rxm->full_len - tls_ctx->rx.prepend_size);
> -             if (nsg <= 0)
> -                     return nsg;
> -     }
> -
> -     // We need one extra for ctx->rx_aad_ciphertext
> -     nsg++;
> -
> -     if (nsg > ARRAY_SIZE(sgin_arr))
> -             sgin = kmalloc_array(nsg, sizeof(*sgin), sk->sk_allocation);
> -
> -     if (!sgout)
> -             sgout = sgin;
> -
> -     sg_init_table(sgin, nsg);
> -     sg_set_buf(&sgin[0], ctx->rx_aad_ciphertext, TLS_AAD_SPACE_SIZE);
> -
> -     nsg = skb_to_sgvec(skb, &sgin[1],
> -                        rxm->offset + tls_ctx->rx.prepend_size,
> -                        rxm->full_len - tls_ctx->rx.prepend_size);
> -     if (nsg < 0) {
> -             ret = nsg;
> -             goto out;
> -     }
> -
> -     tls_make_aad(ctx->rx_aad_ciphertext,
> -                  rxm->full_len - tls_ctx->rx.overhead_size,
> -                  tls_ctx->rx.rec_seq,
> -                  tls_ctx->rx.rec_seq_size,
> -                  ctx->control);
> -
> -     ret = tls_do_decryption(sk, sgin, sgout, iv,
> -                             rxm->full_len - tls_ctx->rx.overhead_size,
> -                             skb, sk->sk_allocation);
> -
> -out:
> -     if (sgin != &sgin_arr[0])
> -             kfree(sgin);
> +     bool zc = true;
> +     int chunk;
>  
> -     return ret;
> +     return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
>  }

Can we merge this function into its callsites? It is now just a thin wrapper around decrypt_internal().

>  
>  static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
> @@ -899,43 +964,17 @@ int tls_sw_recvmsg(struct sock *sk,
>               }
>  
>               if (!ctx->decrypted) {
> -                     int page_count;
> -                     int to_copy;

Reply via email to