Commit d2f77c53 authored by Paolo Abeni's avatar Paolo Abeni Committed by David S. Miller
Browse files

mptcp: check for plain TCP sock at accept time



This cleanup the code a bit and avoid corrupted states
on weird syscall sequence (accept(), connect()).

Signed-off-by: default avatarPaolo Abeni <pabeni@redhat.com>
Signed-off-by: default avatarDavide Caratti <dcaratti@redhat.com>
Reviewed-by: default avatarMat Martineau <mathew.j.martineau@linux.intel.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 8fd73804
Loading
Loading
Loading
Loading
+7 −62
Original line number Diff line number Diff line
@@ -52,13 +52,10 @@ static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
	return msk->subflow;
}

static struct socket *mptcp_is_tcpsk(struct sock *sk)
static bool mptcp_is_tcpsk(struct sock *sk)
{
	struct socket *sock = sk->sk_socket;

	if (sock->sk != sk)
		return NULL;

	if (unlikely(sk->sk_prot == &tcp_prot)) {
		/* we are being invoked after mptcp_accept() has
		 * accepted a non-mp-capable flow: sk is a tcp_sk,
@@ -68,27 +65,21 @@ static struct socket *mptcp_is_tcpsk(struct sock *sk)
		 * bypass mptcp.
		 */
		sock->ops = &inet_stream_ops;
		return sock;
		return true;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	} else if (unlikely(sk->sk_prot == &tcpv6_prot)) {
		sock->ops = &inet6_stream_ops;
		return sock;
		return true;
#endif
	}

	return NULL;
	return false;
}

static struct socket *__mptcp_tcp_fallback(struct mptcp_sock *msk)
{
	struct socket *sock;

	sock_owned_by_me((const struct sock *)msk);

	sock = mptcp_is_tcpsk((struct sock *)msk);
	if (unlikely(sock))
		return sock;

	if (likely(!__mptcp_check_fallback(msk)))
		return NULL;

@@ -1466,7 +1457,6 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
		return NULL;

	pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk));

	if (sk_is_mptcp(newsk)) {
		struct mptcp_subflow_context *subflow;
		struct sock *new_mptcp_sock;
@@ -1821,42 +1811,6 @@ unlock:
	return err;
}

static int mptcp_v4_getname(struct socket *sock, struct sockaddr *uaddr,
			    int peer)
{
	if (sock->sk->sk_prot == &tcp_prot) {
		/* we are being invoked from __sys_accept4, after
		 * mptcp_accept() has just accepted a non-mp-capable
		 * flow: sk is a tcp_sk, not an mptcp one.
		 *
		 * Hand the socket over to tcp so all further socket ops
		 * bypass mptcp.
		 */
		sock->ops = &inet_stream_ops;
	}

	return inet_getname(sock, uaddr, peer);
}

#if IS_ENABLED(CONFIG_MPTCP_IPV6)
static int mptcp_v6_getname(struct socket *sock, struct sockaddr *uaddr,
			    int peer)
{
	if (sock->sk->sk_prot == &tcpv6_prot) {
		/* we are being invoked from __sys_accept4 after
		 * mptcp_accept() has accepted a non-mp-capable
		 * subflow: sk is a tcp_sk, not mptcp.
		 *
		 * Hand the socket over to tcp so all further
		 * socket ops bypass mptcp.
		 */
		sock->ops = &inet6_stream_ops;
	}

	return inet6_getname(sock, uaddr, peer);
}
#endif

static int mptcp_listen(struct socket *sock, int backlog)
{
	struct mptcp_sock *msk = mptcp_sk(sock->sk);
@@ -1885,15 +1839,6 @@ unlock:
	return err;
}

static bool is_tcp_proto(const struct proto *p)
{
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	return p == &tcp_prot || p == &tcpv6_prot;
#else
	return p == &tcp_prot;
#endif
}

static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
			       int flags, bool kern)
{
@@ -1915,7 +1860,7 @@ static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
	release_sock(sock->sk);

	err = ssock->ops->accept(sock, newsock, flags, kern);
	if (err == 0 && !is_tcp_proto(newsock->sk->sk_prot)) {
	if (err == 0 && !mptcp_is_tcpsk(newsock->sk)) {
		struct mptcp_sock *msk = mptcp_sk(newsock->sk);
		struct mptcp_subflow_context *subflow;

@@ -2011,7 +1956,7 @@ static const struct proto_ops mptcp_stream_ops = {
	.connect	   = mptcp_stream_connect,
	.socketpair	   = sock_no_socketpair,
	.accept		   = mptcp_stream_accept,
	.getname	   = mptcp_v4_getname,
	.getname	   = inet_getname,
	.poll		   = mptcp_poll,
	.ioctl		   = inet_ioctl,
	.gettstamp	   = sock_gettstamp,
@@ -2065,7 +2010,7 @@ static const struct proto_ops mptcp_v6_stream_ops = {
	.connect	   = mptcp_stream_connect,
	.socketpair	   = sock_no_socketpair,
	.accept		   = mptcp_stream_accept,
	.getname	   = mptcp_v6_getname,
	.getname	   = inet6_getname,
	.poll		   = mptcp_poll,
	.ioctl		   = inet6_ioctl,
	.gettstamp	   = sock_gettstamp,