const struct tcp_congestion_ops *icsk_ca_ops;
const struct inet_connection_sock_af_ops *icsk_af_ops;
const struct tcp_ulp_ops *icsk_ulp_ops;
- void *icsk_ulp_data;
+ void __rcu *icsk_ulp_data;
void (*icsk_clean_acked)(struct sock *sk, u32 acked_seq);
struct hlist_node icsk_listen_portaddr_node;
unsigned int (*icsk_sync_mss)(struct sock *sk, u32 pmtu);
#include <linux/tcp.h>
#include <linux/skmsg.h>
#include <linux/netdevice.h>
+#include <linux/rcupdate.h>
#include <net/tcp.h>
#include <net/strparser.h>
struct list_head list;
refcount_t refcount;
+ struct rcu_head rcu;
};
enum tls_offload_ctx_dir {
#define TLS_OFFLOAD_CONTEXT_SIZE_RX \
(sizeof(struct tls_offload_context_rx) + TLS_DRIVER_STATE_SIZE_RX)
-void tls_ctx_free(struct tls_context *ctx);
+void tls_ctx_free(struct sock *sk, struct tls_context *ctx);
int wait_on_pending_writer(struct sock *sk, long *timeo);
int tls_sk_query(struct sock *sk, int optname, char __user *optval,
int __user *optlen);
{
struct inet_connection_sock *icsk = inet_csk(sk);
- return icsk->icsk_ulp_data;
+ /* Use RCU on icsk_ulp_data only for sock diag code,
+ * TLS data path doesn't need rcu_dereference().
+ */
+ return (__force void *)icsk->icsk_ulp_data;
}
static inline void tls_advance_record_sn(struct sock *sk,
return -EINVAL;
if (unlikely(idx >= map->max_entries))
return -E2BIG;
- if (unlikely(icsk->icsk_ulp_data))
+ if (unlikely(rcu_access_pointer(icsk->icsk_ulp_data)))
return -EINVAL;
link = sk_psock_init_link();
if (ctx->rx_conf == TLS_HW)
kfree(tls_offload_ctx_rx(ctx));
- tls_ctx_free(ctx);
+ tls_ctx_free(NULL, ctx);
}
static void tls_device_gc_task(struct work_struct *work)
ctx->sk_write_space(sk);
}
-void tls_ctx_free(struct tls_context *ctx)
+/**
+ * tls_ctx_free() - free TLS ULP context
+ * @sk: socket to with @ctx is attached
+ * @ctx: TLS context structure
+ *
+ * Free TLS context. If @sk is %NULL caller guarantees that the socket
+ * to which @ctx was attached has no outstanding references.
+ */
+void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
{
if (!ctx)
return;
memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
- kfree(ctx);
+
+ if (sk)
+ kfree_rcu(ctx, rcu);
+ else
+ kfree(ctx);
}
static void tls_sk_proto_cleanup(struct sock *sk,
write_lock_bh(&sk->sk_callback_lock);
if (free_ctx)
- icsk->icsk_ulp_data = NULL;
+ rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
sk->sk_prot = ctx->sk_proto;
if (sk->sk_write_space == tls_write_space)
sk->sk_write_space = ctx->sk_write_space;
ctx->sk_proto_close(sk, timeout);
if (free_ctx)
- tls_ctx_free(ctx);
+ tls_ctx_free(sk, ctx);
}
static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
if (!ctx)
return NULL;
- icsk->icsk_ulp_data = ctx;
+ rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
ctx->setsockopt = sk->sk_prot->setsockopt;
ctx->getsockopt = sk->sk_prot->getsockopt;
ctx->sk_proto_close = sk->sk_prot->close;
ctx->sk_destruct(sk);
/* Free ctx */
- tls_ctx_free(ctx);
- icsk->icsk_ulp_data = NULL;
+ rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
+ tls_ctx_free(sk, ctx);
}
static int tls_hw_prot(struct sock *sk)