lists.openwall.net   lists  /  announce  owl-users  owl-dev  john-users  john-dev  passwdqc-users  yescrypt  popa3d-users  /  oss-security  kernel-hardening  musl  sabotage  tlsify  passwords  /  crypt-dev  xvendor  /  Bugtraq  Full-Disclosure  linux-kernel  linux-netdev  linux-ext4  linux-hardening  PHC 
Open Source and information security mailing list archives
 
Hash Suite: Windows password security audit tool. GUI, reports in PDF.
[<prev] [next>] [day] [month] [year] [list]
Date:   Tue, 29 Dec 2020 14:07:17 +0300
From:   Arseny Krasnov <arseny.krasnov@...persky.com>
To:     "David S. Miller" <davem@...emloft.net>,
        Jakub Kicinski <kuba@...nel.org>
CC:     <arseny.krasnov@...persky.com>, <netdev@...r.kernel.org>,
        <linux-kernel@...r.kernel.org>
Subject: [PATCH 3/3] vsock: support for SOCK_SEQPACKET socket.

	1) Add socket ops for SOCK_SEQPACKET type.
	2) For receive, create another loop. It looks like the
	   stream receive loop, but it doesn't call notify
	   callbacks, it doesn't care about 'SO_SNDLOWAT' and
	   'SO_RCVLOWAT' values, it waits until the whole record is
	   received or an error is found during receiving, and it
	   takes care of the 'MSG_TRUNC' flag.
	3) Update some comments ('stream' -> 'connect oriented').

Signed-off-by: Arseny Krasnov <arseny.krasnov@...persky.com>
---
 net/vmw_vsock/af_vsock.c | 460 +++++++++++++++++++++++++++++++--------
 1 file changed, 366 insertions(+), 94 deletions(-)

diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index b12d3a322242..be488d1a1fc7 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -415,8 +415,8 @@ static void vsock_deassign_transport(struct vsock_sock *vsk)
 
 /* Assign a transport to a socket and call the .init transport callback.
  *
- * Note: for stream socket this must be called when vsk->remote_addr is set
- * (e.g. during the connect() or when a connection request on a listener
+ * Note: for connect oriented socket this must be called when vsk->remote_addr
+ * is set (e.g. during the connect() or when a connection request on a listener
  * socket is received).
  * The vsk->remote_addr is used to decide which transport to use:
  *  - remote CID == VMADDR_CID_LOCAL or g2h->local_cid or VMADDR_CID_HOST if
@@ -452,6 +452,7 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
 		new_transport = transport_dgram;
 		break;
 	case SOCK_STREAM:
+	case SOCK_SEQPACKET:
 		if (vsock_use_local_transport(remote_cid))
 			new_transport = transport_local;
 		else if (remote_cid <= VMADDR_CID_HOST || !transport_h2g ||
@@ -459,6 +460,12 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
 			new_transport = transport_g2h;
 		else
 			new_transport = transport_h2g;
+
+		if (sk->sk_type == SOCK_SEQPACKET) {
+			if (!new_transport->seqpacket_seq_send_len ||
+			    !new_transport->seqpacket_seq_get_len)
+				return -ENODEV;
+		}
 		break;
 	default:
 		return -ESOCKTNOSUPPORT;
@@ -469,10 +476,10 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
 			return 0;
 
 		/* transport->release() must be called with sock lock acquired.
-		 * This path can only be taken during vsock_stream_connect(),
-		 * where we have already held the sock lock.
-		 * In the other cases, this function is called on a new socket
-		 * which is not assigned to any transport.
+		 * This path can only be taken during vsock_connect(), where we
+		 * have already held the sock lock. In the other cases, this
+		 * function is called on a new socket which is not assigned to
+		 * any transport.
 		 */
 		vsk->transport->release(vsk);
 		vsock_deassign_transport(vsk);
@@ -604,8 +611,8 @@ static void vsock_pending_work(struct work_struct *work)
 
 /**** SOCKET OPERATIONS ****/
 
-static int __vsock_bind_stream(struct vsock_sock *vsk,
-			       struct sockaddr_vm *addr)
+static int __vsock_bind_connectible(struct vsock_sock *vsk,
+				    struct sockaddr_vm *addr)
 {
 	static u32 port;
 	struct sockaddr_vm new_addr;
@@ -649,9 +656,10 @@ static int __vsock_bind_stream(struct vsock_sock *vsk,
 
 	vsock_addr_init(&vsk->local_addr, new_addr.svm_cid, new_addr.svm_port);
 
-	/* Remove stream sockets from the unbound list and add them to the hash
-	 * table for easy lookup by its address.  The unbound list is simply an
-	 * extra entry at the end of the hash table, a trick used by AF_UNIX.
+	/* Remove connect oriented sockets from the unbound list and add them
+	 * to the hash table for easy lookup by its address.  The unbound list
+	 * is simply an extra entry at the end of the hash table, a trick used
+	 * by AF_UNIX.
 	 */
 	__vsock_remove_bound(vsk);
 	__vsock_insert_bound(vsock_bound_sockets(&vsk->local_addr), vsk);
@@ -684,8 +692,9 @@ static int __vsock_bind(struct sock *sk, struct sockaddr_vm *addr)
 
 	switch (sk->sk_socket->type) {
 	case SOCK_STREAM:
+	case SOCK_SEQPACKET:
 		spin_lock_bh(&vsock_table_lock);
-		retval = __vsock_bind_stream(vsk, addr);
+		retval = __vsock_bind_connectible(vsk, addr);
 		spin_unlock_bh(&vsock_table_lock);
 		break;
 
@@ -767,6 +776,11 @@ static struct sock *__vsock_create(struct net *net,
 	return sk;
 }
 
+static bool sock_type_connectible(u16 type)
+{
+	return (type == SOCK_STREAM || type == SOCK_SEQPACKET);
+}
+
 static void __vsock_release(struct sock *sk, int level)
 {
 	if (sk) {
@@ -785,7 +799,7 @@ static void __vsock_release(struct sock *sk, int level)
 
 		if (vsk->transport)
 			vsk->transport->release(vsk);
-		else if (sk->sk_type == SOCK_STREAM)
+		else if (sock_type_connectible(sk->sk_type))
 			vsock_remove_sock(vsk);
 
 		sock_orphan(sk);
@@ -936,16 +950,16 @@ static int vsock_shutdown(struct socket *sock, int mode)
 	if ((mode & ~SHUTDOWN_MASK) || !mode)
 		return -EINVAL;
 
-	/* If this is a STREAM socket and it is not connected then bail out
-	 * immediately.  If it is a DGRAM socket then we must first kick the
-	 * socket so that it wakes up from any sleeping calls, for example
-	 * recv(), and then afterwards return the error.
+	/* If this is a connect oriented socket and it is not connected then
+	 * bail out immediately.  If it is a DGRAM socket then we must first
+	 * kick the socket so that it wakes up from any sleeping calls, for
+	 * example recv(), and then afterwards return the error.
 	 */
 
 	sk = sock->sk;
 	if (sock->state == SS_UNCONNECTED) {
 		err = -ENOTCONN;
-		if (sk->sk_type == SOCK_STREAM)
+		if (sock_type_connectible(sk->sk_type))
 			return err;
 	} else {
 		sock->state = SS_DISCONNECTING;
@@ -960,7 +974,7 @@ static int vsock_shutdown(struct socket *sock, int mode)
 		sk->sk_state_change(sk);
 		release_sock(sk);
 
-		if (sk->sk_type == SOCK_STREAM) {
+		if (sock_type_connectible(sk->sk_type)) {
 			sock_reset_flag(sk, SOCK_DONE);
 			vsock_send_shutdown(sk, mode);
 		}
@@ -1013,7 +1027,7 @@ static __poll_t vsock_poll(struct file *file, struct socket *sock,
 		if (!(sk->sk_shutdown & SEND_SHUTDOWN))
 			mask |= EPOLLOUT | EPOLLWRNORM | EPOLLWRBAND;
 
-	} else if (sock->type == SOCK_STREAM) {
+	} else if (sock_type_connectible(sk->sk_type)) {
 		const struct vsock_transport *transport = vsk->transport;
 		lock_sock(sk);
 
@@ -1259,8 +1273,8 @@ static void vsock_connect_timeout(struct work_struct *work)
 	sock_put(sk);
 }
 
-static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
-				int addr_len, int flags)
+static int vsock_connect(struct socket *sock, struct sockaddr *addr,
+			 int addr_len, int flags)
 {
 	int err;
 	struct sock *sk;
@@ -1410,7 +1424,7 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags,
 
 	lock_sock(listener);
 
-	if (sock->type != SOCK_STREAM) {
+	if (!sock_type_connectible(sock->type)) {
 		err = -EOPNOTSUPP;
 		goto out;
 	}
@@ -1477,6 +1491,18 @@ static int vsock_accept(struct socket *sock, struct socket *newsock, int flags,
 	return err;
 }
 
+static int vsock_stream_connect(struct socket *sock, struct sockaddr *addr,
+				int addr_len, int flags)
+{
+	return vsock_connect(sock, addr, addr_len, flags);
+}
+
+static int vsock_seqpacket_connect(struct socket *sock, struct sockaddr *addr,
+				   int addr_len, int flags)
+{
+	return vsock_connect(sock, addr, addr_len, flags);
+}
+
 static int vsock_listen(struct socket *sock, int backlog)
 {
 	int err;
@@ -1487,7 +1513,7 @@ static int vsock_listen(struct socket *sock, int backlog)
 
 	lock_sock(sk);
 
-	if (sock->type != SOCK_STREAM) {
+	if (!sock_type_connectible(sk->sk_type)) {
 		err = -EOPNOTSUPP;
 		goto out;
 	}
@@ -1531,11 +1557,11 @@ static void vsock_update_buffer_size(struct vsock_sock *vsk,
 	vsk->buffer_size = val;
 }
 
-static int vsock_stream_setsockopt(struct socket *sock,
-				   int level,
-				   int optname,
-				   sockptr_t optval,
-				   unsigned int optlen)
+static int vsock_setsockopt(struct socket *sock,
+			    int level,
+			    int optname,
+			    sockptr_t optval,
+			    unsigned int optlen)
 {
 	int err;
 	struct sock *sk;
@@ -1612,6 +1638,24 @@ static int vsock_stream_setsockopt(struct socket *sock,
 	return err;
 }
 
+static int vsock_seqpacket_setsockopt(struct socket *sock,
+				      int level,
+				      int optname,
+				      sockptr_t optval,
+				      unsigned int optlen)
+{
+	return vsock_setsockopt(sock, level, optname, optval, optlen);
+}
+
+static int vsock_stream_setsockopt(struct socket *sock,
+				   int level,
+				   int optname,
+				   sockptr_t optval,
+				   unsigned int optlen)
+{
+	return vsock_setsockopt(sock, level, optname, optval, optlen);
+}
+
 static int vsock_stream_getsockopt(struct socket *sock,
 				   int level, int optname,
 				   char __user *optval,
@@ -1683,8 +1727,16 @@ static int vsock_stream_getsockopt(struct socket *sock,
 	return 0;
 }
 
-static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
-				size_t len)
+static int vsock_seqpacket_getsockopt(struct socket *sock,
+				      int level, int optname,
+				      char __user *optval,
+				      int __user *optlen)
+{
+	return vsock_stream_getsockopt(sock, level, optname, optval, optlen);
+}
+
+static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
+				     size_t len)
 {
 	struct sock *sk;
 	struct vsock_sock *vsk;
@@ -1706,7 +1758,9 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
 
 	lock_sock(sk);
 
-	/* Callers should not provide a destination with stream sockets. */
+	/* Callers should not provide a destination with connect oriented
+	 * sockets.
+	 */
 	if (msg->msg_namelen) {
 		err = sk->sk_state == TCP_ESTABLISHED ? -EISCONN : -EOPNOTSUPP;
 		goto out;
@@ -1737,6 +1791,12 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
 	if (err < 0)
 		goto out;
 
+	if (sk->sk_type == SOCK_SEQPACKET) {
+		err = transport->seqpacket_seq_send_len(vsk, len);
+		if (err < 0)
+			goto out;
+	}
+
 	while (total_written < len) {
 		ssize_t written;
 
@@ -1796,10 +1856,8 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
 		 * smaller than the queue size.  It is the caller's
 		 * responsibility to check how many bytes we were able to send.
 		 */
-
-		written = transport->stream_enqueue(
-				vsk, msg,
-				len - total_written);
+		written = transport->stream_enqueue(vsk, msg,
+						    len - total_written);
 		if (written < 0) {
 			err = -ENOMEM;
 			goto out_err;
@@ -1815,36 +1873,96 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
 	}
 
 out_err:
-	if (total_written > 0)
-		err = total_written;
+	if (total_written > 0) {
+		/* Return number of written bytes only if:
+		 * 1) SOCK_STREAM socket.
+		 * 2) SOCK_SEQPACKET socket when whole buffer is sent.
+		 */
+		if (sk->sk_type == SOCK_STREAM || total_written == len)
+			err = total_written;
+	}
 out:
 	release_sock(sk);
 	return err;
 }
 
+static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
+				size_t len)
+{
+	return vsock_connectible_sendmsg(sock, msg, len);
+}
 
-static int
-vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
-		     int flags)
+static int vsock_seqpacket_sendmsg(struct socket *sock, struct msghdr *msg,
+				   size_t len)
 {
-	struct sock *sk;
+	return vsock_connectible_sendmsg(sock, msg, len);
+}
+
+static int vsock_wait_data(struct sock *sk, struct wait_queue_entry *wait,
+			   long timeout,
+			   struct vsock_transport_recv_notify_data *recv_data,
+			   size_t target)
+{
+	int err = 0;
 	struct vsock_sock *vsk;
 	const struct vsock_transport *transport;
-	int err;
-	size_t target;
-	ssize_t copied;
-	long timeout;
-	struct vsock_transport_recv_notify_data recv_data;
 
-	DEFINE_WAIT(wait);
-
-	sk = sock->sk;
 	vsk = vsock_sk(sk);
 	transport = vsk->transport;
-	err = 0;
 
+	if (sk->sk_err != 0 ||
+	    (sk->sk_shutdown & RCV_SHUTDOWN) ||
+	    (vsk->peer_shutdown & SEND_SHUTDOWN)) {
+		finish_wait(sk_sleep(sk), wait);
+		return -1;
+	}
+	/* Don't wait for non-blocking sockets. */
+	if (timeout == 0) {
+		err = -EAGAIN;
+		finish_wait(sk_sleep(sk), wait);
+		return err;
+	}
+
+	if (sk->sk_type == SOCK_STREAM) {
+		err = transport->notify_recv_pre_block(vsk, target, recv_data);
+		if (err < 0) {
+			finish_wait(sk_sleep(sk), wait);
+			return err;
+		}
+	}
+
+	release_sock(sk);
+	timeout = schedule_timeout(timeout);
 	lock_sock(sk);
 
+	if (signal_pending(current)) {
+		err = sock_intr_errno(timeout);
+		finish_wait(sk_sleep(sk), wait);
+	} else if (timeout == 0) {
+		err = -EAGAIN;
+		finish_wait(sk_sleep(sk), wait);
+	}
+
+	return err;
+}
+
+static int vsock_wait_data_seqpacket(struct sock *sk, struct wait_queue_entry *wait,
+				     long timeout)
+{
+	return vsock_wait_data(sk, wait, timeout, NULL, 0);
+}
+
+static int vsock_pre_recv_check(struct socket *sock,
+				int flags, size_t len, int *err)
+{
+	struct sock *sk;
+	struct vsock_sock *vsk;
+	const struct vsock_transport *transport;
+
+	sk = sock->sk;
+	vsk = vsock_sk(sk);
+	transport = vsk->transport;
+
 	if (!transport || sk->sk_state != TCP_ESTABLISHED) {
 		/* Recvmsg is supposed to return 0 if a peer performs an
 		 * orderly shutdown. Differentiate between that case and when a
@@ -1852,16 +1970,16 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 		 * SOCK_DONE flag.
 		 */
 		if (sock_flag(sk, SOCK_DONE))
-			err = 0;
+			*err = 0;
 		else
-			err = -ENOTCONN;
+			*err = -ENOTCONN;
 
-		goto out;
+		return false;
 	}
 
 	if (flags & MSG_OOB) {
-		err = -EOPNOTSUPP;
-		goto out;
+		*err = -EOPNOTSUPP;
+		return false;
 	}
 
 	/* We don't check peer_shutdown flag here since peer may actually shut
@@ -1869,17 +1987,143 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 	 * receive.
 	 */
 	if (sk->sk_shutdown & RCV_SHUTDOWN) {
-		err = 0;
-		goto out;
+		*err = 0;
+		return false;
 	}
 
 	/* It is valid on Linux to pass in a zero-length receive buffer.  This
 	 * is not an error.  We may as well bail out now.
 	 */
 	if (!len) {
+		*err = 0;
+		return false;
+	}
+
+	return true;
+}
+
+static int __vsock_seqpacket_recvmsg(struct sock *sk, struct msghdr *msg,
+				     size_t len, int flags)
+{
+	int err = 0;
+	size_t record_len;
+	struct vsock_sock *vsk;
+	const struct vsock_transport *transport;
+	long timeout;
+	ssize_t dequeued_total = 0;
+	unsigned long orig_nr_segs;
+	const struct iovec *orig_iov;
+	DEFINE_WAIT(wait);
+
+	vsk = vsock_sk(sk);
+	transport = vsk->transport;
+
+	timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+	msg->msg_flags &= ~MSG_EOR;
+	orig_nr_segs = msg->msg_iter.nr_segs;
+	orig_iov = msg->msg_iter.iov;
+
+	while (1) {
+		s64 ready;
+
+		prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
+		ready = vsock_stream_has_data(vsk);
+
+		if (ready == 0) {
+			if (vsock_wait_data_seqpacket(sk, &wait, timeout)) {
+				/* In case of any loop break(timeout, signal
+				 * interrupt or shutdown), we report user that
+				 * nothing was copied.
+				 */
+				dequeued_total = 0;
+				break;
+			}
+		} else {
+			ssize_t dequeued;
+
+			finish_wait(sk_sleep(sk), &wait);
+
+			if (ready < 0) {
+				err = -ENOMEM;
+				goto out;
+			}
+
+			if (dequeued_total == 0) {
+				record_len =
+					transport->seqpacket_seq_get_len(vsk);
+
+				if (record_len == 0)
+					continue;
+			}
+
+			/* 'msg_iter.count' is number of unused bytes in iov.
+			 * On every copy to iov iterator it is decremented at
+			 * size of data.
+			 */
+			dequeued = transport->stream_dequeue(vsk, msg,
+						msg->msg_iter.count, flags);
+
+			if (dequeued < 0) {
+				dequeued_total = 0;
+
+				if (dequeued == -EAGAIN) {
+					iov_iter_init(&msg->msg_iter, READ,
+						      orig_iov, orig_nr_segs,
+						      len);
+					msg->msg_flags &= ~MSG_EOR;
+					continue;
+				}
+
+				err = -ENOMEM;
+				break;
+			}
+
+			dequeued_total += dequeued;
+
+			if (dequeued_total >= record_len)
+				break;
+		}
+	}
+
+	if (sk->sk_err)
+		err = -sk->sk_err;
+	else if (sk->sk_shutdown & RCV_SHUTDOWN)
 		err = 0;
-		goto out;
+
+	if (dequeued_total > 0) {
+		/* User sets MSG_TRUNC, so return real length of
+		 * packet.
+		 */
+		if (flags & MSG_TRUNC)
+			err = record_len;
+		else
+			err = len - msg->msg_iter.count;
+
+		/* Always set MSG_TRUNC if real length of packet is
+		 * bigger that user buffer.
+		 */
+		if (record_len > len)
+			msg->msg_flags |= MSG_TRUNC;
 	}
+out:
+	return err;
+}
+
+static int __vsock_stream_recvmsg(struct sock *sk, struct msghdr *msg,
+				  size_t len, int flags)
+{
+	int err;
+	const struct vsock_transport *transport;
+	struct vsock_sock *vsk;
+	size_t target;
+	struct vsock_transport_recv_notify_data recv_data;
+	long timeout;
+	ssize_t copied;
+
+	DEFINE_WAIT(wait);
+
+	vsk = vsock_sk(sk);
+	transport = vsk->transport;
 
 	/* We must not copy less than target bytes into the user's buffer
 	 * before returning successfully, so we wait for the consume queue to
@@ -1888,10 +2132,12 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 	 * queue size.
 	 */
 	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
+
 	if (target >= transport->stream_rcvhiwat(vsk)) {
 		err = -ENOMEM;
 		goto out;
 	}
+
 	timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
 	copied = 0;
 
@@ -1899,7 +2145,6 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 	if (err < 0)
 		goto out;
 
-
 	while (1) {
 		s64 ready;
 
@@ -1907,38 +2152,8 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 		ready = vsock_stream_has_data(vsk);
 
 		if (ready == 0) {
-			if (sk->sk_err != 0 ||
-			    (sk->sk_shutdown & RCV_SHUTDOWN) ||
-			    (vsk->peer_shutdown & SEND_SHUTDOWN)) {
-				finish_wait(sk_sleep(sk), &wait);
+			if (vsock_wait_data(sk, &wait, timeout, &recv_data, target))
 				break;
-			}
-			/* Don't wait for non-blocking sockets. */
-			if (timeout == 0) {
-				err = -EAGAIN;
-				finish_wait(sk_sleep(sk), &wait);
-				break;
-			}
-
-			err = transport->notify_recv_pre_block(
-					vsk, target, &recv_data);
-			if (err < 0) {
-				finish_wait(sk_sleep(sk), &wait);
-				break;
-			}
-			release_sock(sk);
-			timeout = schedule_timeout(timeout);
-			lock_sock(sk);
-
-			if (signal_pending(current)) {
-				err = sock_intr_errno(timeout);
-				finish_wait(sk_sleep(sk), &wait);
-				break;
-			} else if (timeout == 0) {
-				err = -EAGAIN;
-				finish_wait(sk_sleep(sk), &wait);
-				break;
-			}
 		} else {
 			ssize_t read;
 
@@ -1959,9 +2174,8 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 			if (err < 0)
 				break;
 
-			read = transport->stream_dequeue(
-					vsk, msg,
-					len - copied, flags);
+			read = transport->stream_dequeue(vsk, msg, len - copied, flags);
+
 			if (read < 0) {
 				err = -ENOMEM;
 				break;
@@ -1990,11 +2204,45 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 	if (copied > 0)
 		err = copied;
 
+out:
+	return err;
+}
+
+static int vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg,
+				     size_t len, int flags)
+{
+	struct sock *sk;
+	int err;
+
+	sk = sock->sk;
+
+	lock_sock(sk);
+
+	if (!vsock_pre_recv_check(sock, flags,  len, &err))
+		goto out;
+
+	if (sk->sk_type == SOCK_STREAM)
+		err = __vsock_stream_recvmsg(sk, msg, len, flags);
+	else
+		err = __vsock_seqpacket_recvmsg(sk, msg, len, flags);
+
 out:
 	release_sock(sk);
 	return err;
 }
 
+static int vsock_seqpacket_recvmsg(struct socket *sock, struct msghdr *msg,
+				   size_t len, int flags)
+{
+	return vsock_connectible_recvmsg(sock, msg, len, flags);
+}
+
+static int vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg,
+				size_t len, int flags)
+{
+	return vsock_connectible_recvmsg(sock, msg, len, flags);
+}
+
 static const struct proto_ops vsock_stream_ops = {
 	.family = PF_VSOCK,
 	.owner = THIS_MODULE,
@@ -2016,6 +2264,27 @@ static const struct proto_ops vsock_stream_ops = {
 	.sendpage = sock_no_sendpage,
 };
 
+static const struct proto_ops vsock_seqpacket_ops = {
+	.family = PF_VSOCK,
+	.owner = THIS_MODULE,
+	.release = vsock_release,
+	.bind = vsock_bind,
+	.connect = vsock_seqpacket_connect,
+	.socketpair = sock_no_socketpair,
+	.accept = vsock_accept,
+	.getname = vsock_getname,
+	.poll = vsock_poll,
+	.ioctl = sock_no_ioctl,
+	.listen = vsock_listen,
+	.shutdown = vsock_shutdown,
+	.setsockopt = vsock_seqpacket_setsockopt,
+	.getsockopt = vsock_seqpacket_getsockopt,
+	.sendmsg = vsock_seqpacket_sendmsg,
+	.recvmsg = vsock_seqpacket_recvmsg,
+	.mmap = sock_no_mmap,
+	.sendpage = sock_no_sendpage,
+};
+
 static int vsock_create(struct net *net, struct socket *sock,
 			int protocol, int kern)
 {
@@ -2036,6 +2305,9 @@ static int vsock_create(struct net *net, struct socket *sock,
 	case SOCK_STREAM:
 		sock->ops = &vsock_stream_ops;
 		break;
+	case SOCK_SEQPACKET:
+		sock->ops = &vsock_seqpacket_ops;
+		break;
 	default:
 		return -ESOCKTNOSUPPORT;
 	}
-- 
2.25.1

Powered by blists - more mailing lists