Commit 15a7dea7 authored by Jakub Kicinski's avatar Jakub Kicinski Committed by David S. Miller
Browse files

net/tls: use RCU protection on icsk->icsk_ulp_data



We need to make sure context does not get freed while diag
code is interrogating it. Free struct tls_context with
kfree_rcu().

We add the __rcu annotation directly in icsk, and cast it
away in the datapath accessor. Presumably all ULPs will
do a similar thing.

Signed-off-by: default avatarJakub Kicinski <jakub.kicinski@netronome.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent ed6e8103
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -97,7 +97,7 @@ struct inet_connection_sock {
	const struct tcp_congestion_ops *icsk_ca_ops;
	const struct inet_connection_sock_af_ops *icsk_af_ops;
	const struct tcp_ulp_ops  *icsk_ulp_ops;
	void			  *icsk_ulp_data;
	void __rcu		  *icsk_ulp_data;
	void (*icsk_clean_acked)(struct sock *sk, u32 acked_seq);
	struct hlist_node         icsk_listen_portaddr_node;
	unsigned int		  (*icsk_sync_mss)(struct sock *sk, u32 pmtu);
+7 −2
Original line number Diff line number Diff line
@@ -41,6 +41,7 @@
#include <linux/tcp.h>
#include <linux/skmsg.h>
#include <linux/netdevice.h>
#include <linux/rcupdate.h>

#include <net/tcp.h>
#include <net/strparser.h>
@@ -290,6 +291,7 @@ struct tls_context {

	struct list_head list;
	refcount_t refcount;
	struct rcu_head rcu;
};

enum tls_offload_ctx_dir {
@@ -348,7 +350,7 @@ struct tls_offload_context_rx {
#define TLS_OFFLOAD_CONTEXT_SIZE_RX					\
	(sizeof(struct tls_offload_context_rx) + TLS_DRIVER_STATE_SIZE_RX)

void tls_ctx_free(struct tls_context *ctx);
void tls_ctx_free(struct sock *sk, struct tls_context *ctx);
int wait_on_pending_writer(struct sock *sk, long *timeo);
int tls_sk_query(struct sock *sk, int optname, char __user *optval,
		int __user *optlen);
@@ -467,7 +469,10 @@ static inline struct tls_context *tls_get_ctx(const struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);

	return icsk->icsk_ulp_data;
	/* Use RCU on icsk_ulp_data only for sock diag code,
	 * TLS data path doesn't need rcu_dereference().
	 */
	return (__force void *)icsk->icsk_ulp_data;
}

static inline void tls_advance_record_sn(struct sock *sk,
+1 −1
Original line number Diff line number Diff line
@@ -345,7 +345,7 @@ static int sock_map_update_common(struct bpf_map *map, u32 idx,
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;
	if (unlikely(icsk->icsk_ulp_data))
	if (unlikely(rcu_access_pointer(icsk->icsk_ulp_data)))
		return -EINVAL;

	link = sk_psock_init_link();
+1 −1
Original line number Diff line number Diff line
@@ -61,7 +61,7 @@ static void tls_device_free_ctx(struct tls_context *ctx)
	if (ctx->rx_conf == TLS_HW)
		kfree(tls_offload_ctx_rx(ctx));

	tls_ctx_free(ctx);
	tls_ctx_free(NULL, ctx);
}

static void tls_device_gc_task(struct work_struct *work)
+19 −7
Original line number Diff line number Diff line
@@ -251,13 +251,25 @@ static void tls_write_space(struct sock *sk)
	ctx->sk_write_space(sk);
}

void tls_ctx_free(struct tls_context *ctx)
/**
 * tls_ctx_free() - free TLS ULP context
 * @sk:  socket to with @ctx is attached
 * @ctx: TLS context structure
 *
 * Free TLS context. If @sk is %NULL caller guarantees that the socket
 * to which @ctx was attached has no outstanding references.
 */
void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
{
	if (!ctx)
		return;

	memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
	memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));

	if (sk)
		kfree_rcu(ctx, rcu);
	else
		kfree(ctx);
}

@@ -306,7 +318,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)

	write_lock_bh(&sk->sk_callback_lock);
	if (free_ctx)
		icsk->icsk_ulp_data = NULL;
		rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
	sk->sk_prot = ctx->sk_proto;
	if (sk->sk_write_space == tls_write_space)
		sk->sk_write_space = ctx->sk_write_space;
@@ -321,7 +333,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
	ctx->sk_proto_close(sk, timeout);

	if (free_ctx)
		tls_ctx_free(ctx);
		tls_ctx_free(sk, ctx);
}

static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
@@ -610,7 +622,7 @@ static struct tls_context *create_ctx(struct sock *sk)
	if (!ctx)
		return NULL;

	icsk->icsk_ulp_data = ctx;
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	ctx->setsockopt = sk->sk_prot->setsockopt;
	ctx->getsockopt = sk->sk_prot->getsockopt;
	ctx->sk_proto_close = sk->sk_prot->close;
@@ -651,8 +663,8 @@ static void tls_hw_sk_destruct(struct sock *sk)

	ctx->sk_destruct(sk);
	/* Free ctx */
	tls_ctx_free(ctx);
	icsk->icsk_ulp_data = NULL;
	rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
	tls_ctx_free(sk, ctx);
}

static int tls_hw_prot(struct sock *sk)