On Wed, 24 Apr 2019 12:21:03 -0700, John Fastabend wrote:
> It is possible (via shutdown()) for TCP socks to go through TCP_CLOSE
> state via tcp_disconnect() without calling into close callback. This
> would allow a kTLS enabled socket to exist outside of ESTABLISHED
> state which is not supported.
> 
> Solve this the same way we solved the sock{map|hash} case by adding
> an unhash hook to remove tear down the TLS state.
> 
> In the process we also make the close hook more robust. We add a put
> call into the close path, also in the unhash path, to remove the
> reference to ulp data after free. Its no longer valid and may confuse
> things later if the socket (re)enters kTLS code paths. Second we add
> an 'if(ctx)' check to ensure the ctx is still valid and not released
> from a previous unhash/close path.
> 
> Fixes: d91c3e17f75f2 ("net/tls: Only attach to sockets in ESTABLISHED state")
> Reported-by: Eric Dumazet <[email protected]>
> Signed-off-by: John Fastabend <[email protected]>

Ah, EDOESNTBUILD, now I get to nitpick too? :)

> diff --git a/include/net/tls.h b/include/net/tls.h
> index d9d0ac66f040..ae13ea19b375 100644
> --- a/include/net/tls.h
> +++ b/include/net/tls.h
> @@ -266,6 +266,8 @@ struct tls_context {
>       void (*sk_write_space)(struct sock *sk);
>       void (*sk_destruct)(struct sock *sk);
>       void (*sk_proto_close)(struct sock *sk, long timeout);
> +     void (*sk_proto_unhash)(struct sock *sk);
> +     struct proto *sk_proto;
>  
>       int  (*setsockopt)(struct sock *sk, int level,
>                          int optname, char __user *optval,
> @@ -303,7 +305,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, 
> size_t size);
>  int tls_sw_sendpage(struct sock *sk, struct page *page,
>                   int offset, size_t size, int flags);
>  void tls_sw_close(struct sock *sk, long timeout);
> -void tls_sw_free_resources_tx(struct sock *sk);
> +void tls_sw_free_resources_tx(struct sock *sk, bool locked);
>  void tls_sw_free_resources_rx(struct sock *sk);
>  void tls_sw_release_resources_rx(struct sock *sk);
>  int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
> @@ -504,6 +506,16 @@ static inline void xor_iv_with_seq(int version, char 
> *iv, char *seq)
>       }
>  }
>  
> +static inline void tls_put_ctx(struct sock *sk)
> +{
> +     struct inet_connection_sock *icsk = inet_csk(sk);
> +     struct tls_context *ctx = icsk->icsk_ulp_data;
> +
> +     if (!ctx)
> +             return;
> +     sk->sk_prot = ctx->sk_proto;
> +     icsk->icsk_ulp_data = NULL;
> +}
>  
>  static inline struct tls_sw_context_rx *tls_sw_ctx_rx(
>               const struct tls_context *tls_ctx)
> diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
> index 7e546b8ec000..2973048957bd 100644
> --- a/net/tls/tls_main.c
> +++ b/net/tls/tls_main.c
> @@ -261,23 +261,16 @@ static void tls_ctx_free(struct tls_context *ctx)
>       kfree(ctx);
>  }
>  
> -static void tls_sk_proto_close(struct sock *sk, long timeout)
> +static bool tls_sk_proto_destroy(struct sock *sk,
> +                              struct tls_context *ctx, bool destroy)

Perhaps this `destroy` argument should rather be called `locked`? It doesn't really
control destroying AFAICT; it only tells tls_sw_free_resources_tx() whether the socket lock is held.

>  {
> -     struct tls_context *ctx = tls_get_ctx(sk);
>       long timeo = sock_sndtimeo(sk, 0);
> -     void (*sk_proto_close)(struct sock *sk, long timeout);
> -     bool free_ctx = false;
> -
> -     lock_sock(sk);
> -     sk_proto_close = ctx->sk_proto_close;
>  
>       if (ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD)
> -             goto skip_tx_cleanup;
> +             return false;
>  
> -     if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE) {
> -             free_ctx = true;
> -             goto skip_tx_cleanup;
> -     }
> +     if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE)
> +             return true;
>  
>       if (!tls_complete_pending_work(sk, ctx, 0, &timeo))
>               tls_handle_open_record(sk, 0);
> @@ -286,10 +279,10 @@ static void tls_sk_proto_close(struct sock *sk, long 
> timeout)
>       if (ctx->tx_conf == TLS_SW) {
>               kfree(ctx->tx.rec_seq);
>               kfree(ctx->tx.iv);
> -             tls_sw_free_resources_tx(sk);
> +             tls_sw_free_resources_tx(sk, destroy);
>  #ifdef CONFIG_TLS_DEVICE
>       } else if (ctx->tx_conf == TLS_HW) {
> -             tls_device_free_resources_tx(sk);
> +             tls_device_free_resources_tx(sk, destroy);

This part breaks the build: tls_device_free_resources_tx() doesn't need
changes. tls_device_offload_cleanup_rx() will, though, because it sleeps.

>  #endif
>       }
>  
> @@ -310,8 +303,39 @@ static void tls_sk_proto_close(struct sock *sk, long 
> timeout)
>               tls_ctx_free(ctx);
>               ctx = NULL;
>       }
> +     return false;
> +}
> +
> +static void tls_sk_proto_unhash(struct sock *sk)
> +{
> +     struct tls_context *ctx = tls_get_ctx(sk);
> +     void (*sk_proto_unhash)(struct sock *sk);
> +     bool free_ctx;
> +
> +     if (!ctx)
> +             return sk->sk_prot->unhash(sk);
> +     sk_proto_unhash = ctx->sk_proto_unhash;
> +     free_ctx = tls_sk_proto_destroy(sk, ctx, false);
> +     tls_put_ctx(sk);
> +     if (sk_proto_unhash)
> +             sk_proto_unhash(sk);
> +     if (free_ctx)
> +             tls_ctx_free(ctx);
> +}
>  
> -skip_tx_cleanup:
> +static void tls_sk_proto_close(struct sock *sk, long timeout)
> +{
> +     struct tls_context *ctx = tls_get_ctx(sk);
> +     void (*sk_proto_close)(struct sock *sk, long timeout);

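Nit: reverse xmas tree ordering, please. The sk_proto_close declaration is the longest line, so it should come before ctx.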

> +     bool free_ctx;
> +
> +     if (!ctx)
> +             return sk->sk_prot->destroy(sk);
> +
> +     lock_sock(sk);
> +     sk_proto_close = ctx->sk_proto_close;
> +     free_ctx = tls_sk_proto_destroy(sk, ctx, true);
> +     tls_put_ctx(sk);
>       release_sock(sk);
>       sk_proto_close(sk, timeout);
>       /* free ctx for TLS_HW_RECORD, used by tcp_set_state
> @@ -609,6 +633,8 @@ static struct tls_context *create_ctx(struct sock *sk)
>       ctx->setsockopt = sk->sk_prot->setsockopt;
>       ctx->getsockopt = sk->sk_prot->getsockopt;
>       ctx->sk_proto_close = sk->sk_prot->close;
> +     ctx->sk_proto_unhash = sk->sk_prot->unhash;
> +     ctx->sk_proto = sk->sk_prot;
>       return ctx;
>  }
>  
> @@ -732,6 +758,7 @@ static void build_protos(struct proto 
> prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
>       prot[TLS_BASE][TLS_BASE].setsockopt     = tls_setsockopt;
>       prot[TLS_BASE][TLS_BASE].getsockopt     = tls_getsockopt;
>       prot[TLS_BASE][TLS_BASE].close          = tls_sk_proto_close;
> +     prot[TLS_BASE][TLS_BASE].unhash         = tls_sk_proto_unhash;
>  
>       prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
>       prot[TLS_SW][TLS_BASE].sendmsg          = tls_sw_sendmsg;
> diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
> index f780b473827b..0577633c319b 100644
> --- a/net/tls/tls_sw.c
> +++ b/net/tls/tls_sw.c
> @@ -2044,7 +2044,7 @@ static void tls_data_ready(struct sock *sk)
>       }
>  }
>  
> -void tls_sw_free_resources_tx(struct sock *sk)
> +void tls_sw_free_resources_tx(struct sock *sk, bool locked)
>  {
>       struct tls_context *tls_ctx = tls_get_ctx(sk);
>       struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
> @@ -2055,9 +2055,11 @@ void tls_sw_free_resources_tx(struct sock *sk)
>       if (atomic_read(&ctx->encrypt_pending))
>               crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
>  
> -     release_sock(sk);
> +     if (locked)
> +             release_sock(sk);
>       cancel_delayed_work_sync(&ctx->tx_work.work);
> -     lock_sock(sk);
> +     if (locked)
> +             lock_sock(sk);
>  
>       /* Tx whatever records we can transmit and abandon the rest */
>       tls_tx_records(sk, -1);
> @@ -2080,7 +2082,10 @@ void tls_sw_free_resources_tx(struct sock *sk)
>               kfree(rec);
>       }
>  
> -     crypto_free_aead(ctx->aead_send);
> +     if (ctx->aead_send) {
> +             crypto_free_aead(ctx->aead_send);
> +             ctx->aead_send = NULL;
> +     }
>       tls_free_open_rec(sk);
>  
>       kfree(ctx);
> 

Reply via email to