lists.openwall.net   lists  /  announce  owl-users  owl-dev  john-users  john-dev  passwdqc-users  yescrypt  popa3d-users  /  oss-security  kernel-hardening  musl  sabotage  tlsify  passwords  /  crypt-dev  xvendor  /  Bugtraq  Full-Disclosure  linux-kernel  linux-netdev  linux-ext4  linux-hardening  linux-cve-announce  PHC 
Open Source and information security mailing list archives
 
Hash Suite: Windows password security audit tool. GUI, reports in PDF.
[<prev] [next>] [<thread-prev] [thread-next>] [day] [month] [year] [list]
Date:   Sun, 3 Jan 2021 23:03:45 +0300
From:   Arseny Krasnov <arseny.krasnov@...persky.com>
To:     Stefan Hajnoczi <stefanha@...hat.com>,
        Stefano Garzarella <sgarzare@...hat.com>,
        "Michael S. Tsirkin" <mst@...hat.com>,
        Jason Wang <jasowang@...hat.com>,
        "David S. Miller" <davem@...emloft.net>,
        Jakub Kicinski <kuba@...nel.org>,
        Jorgen Hansen <jhansen@...are.com>,
        Colin Ian King <colin.king@...onical.com>,
        Arseniy Krasnov <oxffffaa@...il.com>,
        Andra Paraschiv <andraprs@...zon.com>,
        Jeff Vander Stoep <jeffv@...gle.com>
CC:     <kvm@...r.kernel.org>, <virtualization@...ts.linux-foundation.org>,
        <netdev@...r.kernel.org>, <linux-kernel@...r.kernel.org>,
        <stsp2@...dex.ru>, <arseny.krasnov@...persky.com>
Subject: [PATCH 3/5] af_vsock: send/receive loops for SOCK_SEQPACKET.

From: Arseniy Krasnov <oxffffaa@...il.com>

  For send, this patch adds:
  1) Sending of a record-begin marker that carries the record length.
  2) Returning an error if the whole record could not be sent.

  For receive, this patch adds another loop. It looks like the
  stream loop, but:
  1) It doesn't call notify callbacks.
  2) It doesn't care about 'SO_SNDLOWAT' and 'SO_RCVLOWAT'
     values.
  3) It waits until the whole record is received or an error
     occurs during receiving.
  4) It processes and sets the 'MSG_TRUNC' flag.
---
 net/vmw_vsock/af_vsock.c | 319 +++++++++++++++++++++++++++++++--------
 1 file changed, 256 insertions(+), 63 deletions(-)

diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c
index b12d3a322242..7ff00449a9a2 100644
--- a/net/vmw_vsock/af_vsock.c
+++ b/net/vmw_vsock/af_vsock.c
@@ -1683,8 +1683,8 @@ static int vsock_stream_getsockopt(struct socket *sock,
 	return 0;
 }
 
-static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
-				size_t len)
+static int vsock_connectible_sendmsg(struct socket *sock, struct msghdr *msg,
+				     size_t len)
 {
 	struct sock *sk;
 	struct vsock_sock *vsk;
@@ -1737,6 +1737,12 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
 	if (err < 0)
 		goto out;
 
+	if (sk->sk_type == SOCK_SEQPACKET) {
+		err = transport->seqpacket_seq_send_len(vsk, len);
+		if (err < 0)
+			goto out;
+	}
+
 	while (total_written < len) {
 		ssize_t written;
 
@@ -1796,10 +1802,8 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
 		 * smaller than the queue size.  It is the caller's
 		 * responsibility to check how many bytes we were able to send.
 		 */
-
-		written = transport->stream_enqueue(
-				vsk, msg,
-				len - total_written);
+		written = transport->stream_enqueue(vsk, msg,
+						    len - total_written);
 		if (written < 0) {
 			err = -ENOMEM;
 			goto out_err;
@@ -1815,36 +1819,96 @@ static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
 	}
 
 out_err:
-	if (total_written > 0)
-		err = total_written;
+	if (total_written > 0) {
+		/* Return number of written bytes only if:
+		 * 1) SOCK_STREAM socket.
+		 * 2) SOCK_SEQPACKET socket when whole buffer is sent.
+		 */
+		if (sk->sk_type == SOCK_STREAM || total_written == len)
+			err = total_written;
+	}
 out:
 	release_sock(sk);
 	return err;
 }
 
+static int vsock_stream_sendmsg(struct socket *sock, struct msghdr *msg,
+				size_t len)
+{
+	return vsock_connectible_sendmsg(sock, msg, len);
+}
 
-static int
-vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
-		     int flags)
+static int vsock_seqpacket_sendmsg(struct socket *sock, struct msghdr *msg,
+				   size_t len)
 {
-	struct sock *sk;
+	return vsock_connectible_sendmsg(sock, msg, len);
+}
+
+static int vsock_wait_data(struct sock *sk, struct wait_queue_entry *wait,
+			   long timeout,
+			   struct vsock_transport_recv_notify_data *recv_data,
+			   size_t target)
+{
+	int err = 0;
 	struct vsock_sock *vsk;
 	const struct vsock_transport *transport;
-	int err;
-	size_t target;
-	ssize_t copied;
-	long timeout;
-	struct vsock_transport_recv_notify_data recv_data;
-
-	DEFINE_WAIT(wait);
 
-	sk = sock->sk;
 	vsk = vsock_sk(sk);
 	transport = vsk->transport;
-	err = 0;
 
+	if (sk->sk_err != 0 ||
+	    (sk->sk_shutdown & RCV_SHUTDOWN) ||
+	    (vsk->peer_shutdown & SEND_SHUTDOWN)) {
+		finish_wait(sk_sleep(sk), wait);
+		return -1;
+	}
+	/* Don't wait for non-blocking sockets. */
+	if (timeout == 0) {
+		err = -EAGAIN;
+		finish_wait(sk_sleep(sk), wait);
+		return err;
+	}
+
+	if (sk->sk_type == SOCK_STREAM) {
+		err = transport->notify_recv_pre_block(vsk, target, recv_data);
+		if (err < 0) {
+			finish_wait(sk_sleep(sk), wait);
+			return err;
+		}
+	}
+
+	release_sock(sk);
+	timeout = schedule_timeout(timeout);
 	lock_sock(sk);
 
+	if (signal_pending(current)) {
+		err = sock_intr_errno(timeout);
+		finish_wait(sk_sleep(sk), wait);
+	} else if (timeout == 0) {
+		err = -EAGAIN;
+		finish_wait(sk_sleep(sk), wait);
+	}
+
+	return err;
+}
+
+static int vsock_wait_data_seqpacket(struct sock *sk, struct wait_queue_entry *wait,
+				     long timeout)
+{
+	return vsock_wait_data(sk, wait, timeout, NULL, 0);
+}
+
+static int vsock_pre_recv_check(struct socket *sock,
+				int flags, size_t len, int *err)
+{
+	struct sock *sk;
+	struct vsock_sock *vsk;
+	const struct vsock_transport *transport;
+
+	sk = sock->sk;
+	vsk = vsock_sk(sk);
+	transport = vsk->transport;
+
 	if (!transport || sk->sk_state != TCP_ESTABLISHED) {
 		/* Recvmsg is supposed to return 0 if a peer performs an
 		 * orderly shutdown. Differentiate between that case and when a
@@ -1852,16 +1916,16 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 		 * SOCK_DONE flag.
 		 */
 		if (sock_flag(sk, SOCK_DONE))
-			err = 0;
+			*err = 0;
 		else
-			err = -ENOTCONN;
+			*err = -ENOTCONN;
 
-		goto out;
+		return false;
 	}
 
 	if (flags & MSG_OOB) {
-		err = -EOPNOTSUPP;
-		goto out;
+		*err = -EOPNOTSUPP;
+		return false;
 	}
 
 	/* We don't check peer_shutdown flag here since peer may actually shut
@@ -1869,17 +1933,143 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 	 * receive.
 	 */
 	if (sk->sk_shutdown & RCV_SHUTDOWN) {
-		err = 0;
-		goto out;
+		*err = 0;
+		return false;
 	}
 
 	/* It is valid on Linux to pass in a zero-length receive buffer.  This
 	 * is not an error.  We may as well bail out now.
 	 */
 	if (!len) {
+		*err = 0;
+		return false;
+	}
+
+	return true;
+}
+
+static int __vsock_seqpacket_recvmsg(struct sock *sk, struct msghdr *msg,
+				     size_t len, int flags)
+{
+	int err = 0;
+	size_t record_len;
+	struct vsock_sock *vsk;
+	const struct vsock_transport *transport;
+	long timeout;
+	ssize_t dequeued_total = 0;
+	unsigned long orig_nr_segs;
+	const struct iovec *orig_iov;
+	DEFINE_WAIT(wait);
+
+	vsk = vsock_sk(sk);
+	transport = vsk->transport;
+
+	timeout = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+	msg->msg_flags &= ~MSG_EOR;
+	orig_nr_segs = msg->msg_iter.nr_segs;
+	orig_iov = msg->msg_iter.iov;
+
+	while (1) {
+		s64 ready;
+
+		prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
+		ready = vsock_stream_has_data(vsk);
+
+		if (ready == 0) {
+			if (vsock_wait_data_seqpacket(sk, &wait, timeout)) {
+				/* In case of any loop break(timeout, signal
+				 * interrupt or shutdown), we report user that
+				 * nothing was copied.
+				 */
+				dequeued_total = 0;
+				break;
+			}
+		} else {
+			ssize_t dequeued;
+
+			finish_wait(sk_sleep(sk), &wait);
+
+			if (ready < 0) {
+				err = -ENOMEM;
+				goto out;
+			}
+
+			if (dequeued_total == 0) {
+				record_len =
+					transport->seqpacket_seq_get_len(vsk);
+
+				if (record_len == 0)
+					continue;
+			}
+
+			/* 'msg_iter.count' is the number of unused bytes in
+			 * iov. On every copy to the iov iterator it is
+			 * decremented by the size of the copied data.
+			 */
+			dequeued = transport->stream_dequeue(vsk, msg,
+						msg->msg_iter.count, flags);
+
+			if (dequeued < 0) {
+				dequeued_total = 0;
+
+				if (dequeued == -EAGAIN) {
+					iov_iter_init(&msg->msg_iter, READ,
+						      orig_iov, orig_nr_segs,
+						      len);
+					msg->msg_flags &= ~MSG_EOR;
+					continue;
+				}
+
+				err = -ENOMEM;
+				break;
+			}
+
+			dequeued_total += dequeued;
+
+			if (dequeued_total >= record_len)
+				break;
+		}
+	}
+
+	if (sk->sk_err)
+		err = -sk->sk_err;
+	else if (sk->sk_shutdown & RCV_SHUTDOWN)
 		err = 0;
-		goto out;
+
+	if (dequeued_total > 0) {
+		/* User sets MSG_TRUNC, so return real length of
+		 * packet.
+		 */
+		if (flags & MSG_TRUNC)
+			err = record_len;
+		else
+			err = len - msg->msg_iter.count;
+
+		/* Always set MSG_TRUNC if real length of packet is
+		 * bigger than user buffer.
+		 */
+		if (record_len > len)
+			msg->msg_flags |= MSG_TRUNC;
 	}
+out:
+	return err;
+}
+
+static int __vsock_stream_recvmsg(struct sock *sk, struct msghdr *msg,
+				  size_t len, int flags)
+{
+	int err;
+	const struct vsock_transport *transport;
+	struct vsock_sock *vsk;
+	size_t target;
+	struct vsock_transport_recv_notify_data recv_data;
+	long timeout;
+	ssize_t copied;
+
+	DEFINE_WAIT(wait);
+
+	vsk = vsock_sk(sk);
+	transport = vsk->transport;
 
 	/* We must not copy less than target bytes into the user's buffer
 	 * before returning successfully, so we wait for the consume queue to
@@ -1907,38 +2097,8 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 		ready = vsock_stream_has_data(vsk);
 
 		if (ready == 0) {
-			if (sk->sk_err != 0 ||
-			    (sk->sk_shutdown & RCV_SHUTDOWN) ||
-			    (vsk->peer_shutdown & SEND_SHUTDOWN)) {
-				finish_wait(sk_sleep(sk), &wait);
-				break;
-			}
-			/* Don't wait for non-blocking sockets. */
-			if (timeout == 0) {
-				err = -EAGAIN;
-				finish_wait(sk_sleep(sk), &wait);
-				break;
-			}
-
-			err = transport->notify_recv_pre_block(
-					vsk, target, &recv_data);
-			if (err < 0) {
-				finish_wait(sk_sleep(sk), &wait);
+			if (vsock_wait_data(sk, &wait, timeout, &recv_data, target))
 				break;
-			}
-			release_sock(sk);
-			timeout = schedule_timeout(timeout);
-			lock_sock(sk);
-
-			if (signal_pending(current)) {
-				err = sock_intr_errno(timeout);
-				finish_wait(sk_sleep(sk), &wait);
-				break;
-			} else if (timeout == 0) {
-				err = -EAGAIN;
-				finish_wait(sk_sleep(sk), &wait);
-				break;
-			}
 		} else {
 			ssize_t read;
 
@@ -1959,9 +2119,8 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 			if (err < 0)
 				break;
 
-			read = transport->stream_dequeue(
-					vsk, msg,
-					len - copied, flags);
+			read = transport->stream_dequeue(vsk, msg, len - copied, flags);
+
 			if (read < 0) {
 				err = -ENOMEM;
 				break;
@@ -1990,11 +2149,45 @@ vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 	if (copied > 0)
 		err = copied;
 
+out:
+	return err;
+}
+
+static int vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg,
+				     size_t len, int flags)
+{
+	struct sock *sk;
+	int err;
+
+	sk = sock->sk;
+
+	lock_sock(sk);
+
+	if (!vsock_pre_recv_check(sock, flags,  len, &err))
+		goto out;
+
+	if (sk->sk_type == SOCK_STREAM)
+		err = __vsock_stream_recvmsg(sk, msg, len, flags);
+	else
+		err = __vsock_seqpacket_recvmsg(sk, msg, len, flags);
+
 out:
 	release_sock(sk);
 	return err;
 }
 
+static int vsock_seqpacket_recvmsg(struct socket *sock, struct msghdr *msg,
+				   size_t len, int flags)
+{
+	return vsock_connectible_recvmsg(sock, msg, len, flags);
+}
+
+static int vsock_stream_recvmsg(struct socket *sock, struct msghdr *msg,
+				size_t len, int flags)
+{
+	return vsock_connectible_recvmsg(sock, msg, len, flags);
+}
+
 static const struct proto_ops vsock_stream_ops = {
 	.family = PF_VSOCK,
 	.owner = THIS_MODULE,
-- 
2.25.1

Powered by blists - more mailing lists

Powered by Openwall GNU/*/Linux Powered by OpenVZ