Message-ID: <20260208143441.2177372-4-lulu@redhat.com>
Date: Sun, 8 Feb 2026 22:32:24 +0800
From: Cindy Lu <lulu@...hat.com>
To: lulu@...hat.com,
mst@...hat.com,
jasowang@...hat.com,
kvm@...r.kernel.org,
virtualization@...ts.linux.dev,
netdev@...r.kernel.org,
linux-kernel@...r.kernel.org
Subject: [RFC 3/3] vhost/net: add RX netfilter offload path

Route RX packets through the userspace netfilter socket when one is
configured.

Key points:
- Add VHOST_NET_FILTER_MAX_LEN as an upper bound for the filter payload size
- Introduce vhost_net_filter_request() to send a REQUEST message to userspace
- Add a handle_rx_filter() fast path, used while a filter is active
- Hook the filter path into handle_rx() when filter_sock is set

Signed-off-by: Cindy Lu <lulu@...hat.com>
---
drivers/vhost/net.c | 229 ++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 229 insertions(+)
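
A rough sketch of the userspace side of the REQUEST flow, for reference.
It assumes the peer end of the filter socket preserves message boundaries
(e.g. a SOCK_SEQPACKET or datagram pair), and it only mirrors the header
layout as this patch uses it; the authoritative struct and constants come
from the uapi header introduced earlier in the series.

#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <sys/types.h>
#include <sys/socket.h>

#define VHOST_NET_FILTER_MAX_LEN (4096 + 65536)

/* Assumed wire layout: header immediately followed by 'len' payload bytes. */
struct vhost_net_filter_msg {
	uint16_t type;		/* VHOST_NET_FILTER_MSG_REQUEST */
	uint16_t direction;	/* RX or TX */
	uint32_t len;		/* payload length in bytes */
};

static int drain_filter_requests(int filter_fd)
{
	/* One message = header + payload, sent by the kernel in one sendmsg(). */
	static unsigned char buf[sizeof(struct vhost_net_filter_msg) +
				 VHOST_NET_FILTER_MAX_LEN];

	for (;;) {
		struct vhost_net_filter_msg hdr;
		ssize_t n = recv(filter_fd, buf, sizeof(buf), 0);

		if (n < 0)
			return -1;
		if ((size_t)n < sizeof(hdr))
			continue;	/* short or garbled message, skip it */

		memcpy(&hdr, buf, sizeof(hdr));
		if (hdr.len > (size_t)n - sizeof(hdr))
			continue;	/* truncated payload, skip it */

		/*
		 * buf[sizeof(hdr)] .. buf[sizeof(hdr) + hdr.len - 1] is the
		 * packet as vhost-net saw it (including the backend vnet
		 * header, when present, on the RX path); inspect or log it
		 * here.
		 */
		printf("filter request: dir %u, %u bytes\n",
		       (unsigned int)hdr.direction, (unsigned int)hdr.len);
	}
}

Each kernel sendmsg() carries exactly one header plus at most
VHOST_NET_FILTER_MAX_LEN payload bytes, so a single fixed-size receive
buffer is enough on the userspace side.
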
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index f02deff0e53c..aa9a5ed43eae 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -161,6 +161,13 @@ struct vhost_net {
static unsigned vhost_net_zcopy_mask __read_mostly;
+/*
+ * Upper bound for a single packet payload on the filter path.
+ * Keep this large enough for the largest expected frame plus vnet headers,
+ * while still capping the size of the temporary per-packet buffer.
+ */
+#define VHOST_NET_FILTER_MAX_LEN (4096 + 65536)
+
static void *vhost_net_buf_get_ptr(struct vhost_net_buf *rxq)
{
if (rxq->tail != rxq->head)
@@ -1227,6 +1234,222 @@ static long vhost_net_set_filter(struct vhost_net *n, int fd)
return r;
}
+/*
+ * Send a filter REQUEST message to userspace for a single packet.
+ *
+ * The caller provides a writable buffer; userspace may inspect the content and
+ * optionally modify it in place. We only accept the packet if the returned
+ * length matches the original length; otherwise the packet is dropped.
+ */
+static int vhost_net_filter_request(struct vhost_net *n, u16 direction,
+ void *buf, u32 *len)
+{
+ struct vhost_net_filter_msg msg = {
+ .type = VHOST_NET_FILTER_MSG_REQUEST,
+ .direction = direction,
+ .len = *len,
+ };
+ struct msghdr msghdr = { 0 };
+ struct kvec iov[2] = {
+ { .iov_base = &msg, .iov_len = sizeof(msg) },
+ { .iov_base = buf, .iov_len = *len },
+ };
+ struct socket *sock;
+ struct file *sock_file = NULL;
+ int ret;
+
+ /*
+ * Take a temporary file reference to guard against concurrent
+ * filter socket replacement while we send the message.
+ */
+ spin_lock(&n->filter_lock);
+ sock = n->filter_sock;
+ if (sock)
+ sock_file = get_file(sock->file);
+ spin_unlock(&n->filter_lock);
+
+ if (!sock) {
+ ret = -ENOTCONN;
+ goto out_put;
+ }
+
+ ret = kernel_sendmsg(sock, &msghdr, iov,
+ *len ? 2 : 1, sizeof(msg) + *len);
+
+out_put:
+ if (sock_file)
+ fput(sock_file);
+
+ if (ret < 0)
+ return ret;
+ return 0;
+}
+
+/*
+ * RX fast path when filter offload is active.
+ *
+ * This mirrors handle_rx() but routes each RX packet through the userspace
+ * filter. Packets are copied into a temporary buffer, sent to the filter
+ * socket as a REQUEST, and only delivered to the guest if userspace keeps the
+ * length unchanged. Any truncation or mismatch drops the packet.
+ */
+static void handle_rx_filter(struct vhost_net *net, struct socket *sock)
+{
+ struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_RX];
+ struct vhost_virtqueue *vq = &nvq->vq;
+ bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER);
+ unsigned int count = 0;
+ unsigned int in, log;
+ struct vhost_log *vq_log;
+ struct virtio_net_hdr hdr = {
+ .flags = 0,
+ .gso_type = VIRTIO_NET_HDR_GSO_NONE
+ };
+ struct msghdr msg = {
+ .msg_name = NULL,
+ .msg_namelen = 0,
+ .msg_control = NULL,
+ .msg_controllen = 0,
+ .msg_flags = MSG_DONTWAIT,
+ };
+ size_t total_len = 0;
+ int mergeable;
+ bool set_num_buffers;
+ size_t vhost_hlen, sock_hlen;
+ size_t vhost_len, sock_len;
+ bool busyloop_intr = false;
+ struct iov_iter fixup;
+ __virtio16 num_buffers;
+ int recv_pkts = 0;
+ unsigned int ndesc;
+ void *pkt;
+
+ pkt = kvmalloc(VHOST_NET_FILTER_MAX_LEN, GFP_KERNEL | __GFP_NOWARN);
+ if (!pkt) {
+ vhost_net_enable_vq(net, vq);
+ return;
+ }
+
+ vhost_hlen = nvq->vhost_hlen;
+ sock_hlen = nvq->sock_hlen;
+
+ vq_log = unlikely(vhost_has_feature(vq, VHOST_F_LOG_ALL)) ? vq->log : NULL;
+ mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
+ set_num_buffers = mergeable || vhost_has_feature(vq, VIRTIO_F_VERSION_1);
+
+ do {
+ u32 pkt_len;
+ int err;
+ s16 headcount;
+ struct kvec iov;
+
+ sock_len = vhost_net_rx_peek_head_len(net, sock->sk,
+ &busyloop_intr, &count);
+ if (!sock_len)
+ break;
+ sock_len += sock_hlen;
+ if (sock_len > VHOST_NET_FILTER_MAX_LEN) {
+ /* Consume and drop oversized packet. */
+ iov.iov_base = pkt;
+ iov.iov_len = 1;
+ kernel_recvmsg(sock, &msg, &iov, 1, 1,
+ MSG_DONTWAIT | MSG_TRUNC);
+ continue;
+ }
+
+ vhost_len = sock_len + vhost_hlen;
+ headcount = get_rx_bufs(nvq, vq->heads + count,
+ vq->nheads + count, vhost_len, &in,
+ vq_log, &log,
+ likely(mergeable) ? UIO_MAXIOV : 1,
+ &ndesc);
+ if (unlikely(headcount < 0))
+ goto out;
+
+ if (!headcount) {
+ if (unlikely(busyloop_intr)) {
+ vhost_poll_queue(&vq->poll);
+ } else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
+ vhost_disable_notify(&net->dev, vq);
+ continue;
+ }
+ goto out;
+ }
+
+ busyloop_intr = false;
+
+ if (nvq->rx_ring)
+ msg.msg_control = vhost_net_buf_consume(&nvq->rxq);
+
+ iov.iov_base = pkt;
+ iov.iov_len = sock_len;
+ err = kernel_recvmsg(sock, &msg, &iov, 1, sock_len,
+ MSG_DONTWAIT | MSG_TRUNC);
+ if (unlikely(err != sock_len)) {
+ vhost_discard_vq_desc(vq, headcount, ndesc);
+ continue;
+ }
+
+ pkt_len = sock_len;
+ err = vhost_net_filter_request(net, VHOST_NET_FILTER_DIRECTION_RX,
+ pkt, &pkt_len);
+ if (err < 0)
+ pkt_len = sock_len;
+ if (pkt_len != sock_len) {
+ vhost_discard_vq_desc(vq, headcount, ndesc);
+ continue;
+ }
+
+ iov_iter_init(&msg.msg_iter, ITER_DEST, vq->iov, in, vhost_len);
+ fixup = msg.msg_iter;
+ if (unlikely(vhost_hlen))
+ iov_iter_advance(&msg.msg_iter, vhost_hlen);
+
+ if (copy_to_iter(pkt, sock_len, &msg.msg_iter) != sock_len) {
+ vhost_discard_vq_desc(vq, headcount, ndesc);
+ goto out;
+ }
+
+ if (unlikely(vhost_hlen)) {
+ if (copy_to_iter(&hdr, sizeof(hdr),
+ &fixup) != sizeof(hdr)) {
+ vhost_discard_vq_desc(vq, headcount, ndesc);
+ goto out;
+ }
+ } else {
+ iov_iter_advance(&fixup, sizeof(hdr));
+ }
+
+ num_buffers = cpu_to_vhost16(vq, headcount);
+ if (likely(set_num_buffers) &&
+ copy_to_iter(&num_buffers, sizeof(num_buffers), &fixup) !=
+ sizeof(num_buffers)) {
+ vhost_discard_vq_desc(vq, headcount, ndesc);
+ goto out;
+ }
+
+ nvq->done_idx += headcount;
+ count += in_order ? 1 : headcount;
+ if (nvq->done_idx > VHOST_NET_BATCH) {
+ vhost_net_signal_used(nvq, count);
+ count = 0;
+ }
+
+ if (unlikely(vq_log))
+ vhost_log_write(vq, vq_log, log, vhost_len, vq->iov, in);
+
+ total_len += vhost_len;
+ } while (likely(!vhost_exceeds_weight(vq, ++recv_pkts, total_len)));
+
+ if (unlikely(busyloop_intr))
+ vhost_poll_queue(&vq->poll);
+ else if (!sock_len)
+ vhost_net_enable_vq(net, vq);
+
+out:
+ vhost_net_signal_used(nvq, count);
+ kvfree(pkt);
+}
/* Expects to be always run from workqueue - which acts as
* read-size critical section for our kind of RCU. */
static void handle_rx(struct vhost_net *net)
@@ -1281,6 +1504,11 @@ static void handle_rx(struct vhost_net *net)
set_num_buffers = mergeable ||
vhost_has_feature(vq, VIRTIO_F_VERSION_1);
+ if (READ_ONCE(net->filter_sock)) {
+ handle_rx_filter(net, sock);
+ goto out_unlock;
+ }
+
do {
sock_len = vhost_net_rx_peek_head_len(net, sock->sk,
&busyloop_intr, &count);
@@ -1383,6 +1611,7 @@ static void handle_rx(struct vhost_net *net)
vhost_net_enable_vq(net, vq);
out:
vhost_net_signal_used(nvq, count);
+out_unlock:
mutex_unlock(&vq->mutex);
}
--
2.52.0