/* * Copyright (c) 2008 Red Hat, Inc. * * Author(s): Gleb Natapov */ #include #include #include #include #include #include #include #include #include "vmchannel_connector.h" static struct vmchannel_dev vmc_dev; static int add_recq_buf(struct vmchannel_dev *vmc, struct vmchannel_hdr *hdr) { struct scatterlist sg[2]; sg_init_table(sg, 2); sg_init_one(&sg[0], hdr, sizeof(struct vmchannel_desc)); sg_init_one(&sg[1], hdr->msg.data, MAX_PACKET_LEN); if (!vmc->rq->vq_ops->add_buf(vmc->rq, sg, 0, 2, hdr)) return 1; kfree(hdr); return 0; } static int try_fill_recvq(struct vmchannel_dev *vmc) { int num = 0; for (;;) { struct vmchannel_hdr *hdr; hdr = kmalloc(sizeof(*hdr) + MAX_PACKET_LEN, GFP_KERNEL); if (unlikely(!hdr)) break; if (!add_recq_buf(vmc, hdr)) break; num++; } if (num) vmc->rq->vq_ops->kick(vmc->rq); return num; } static void vmchannel_recv(unsigned long data) { struct vmchannel_dev *vmc = (struct vmchannel_dev *)data; struct vmchannel_hdr *hdr; unsigned int len; int posted = 0; while ((hdr = vmc->rq->vq_ops->get_buf(vmc->rq, &len))) { hdr->msg.len = le32_to_cpu(hdr->desc.len); len -= sizeof(struct vmchannel_desc); if (hdr->msg.len == len) { hdr->msg.id.idx = VMCHANNEL_CONNECTOR_IDX; hdr->msg.id.val = le32_to_cpu(hdr->desc.id); hdr->msg.seq = vmc->seq++; hdr->msg.ack = random32(); cn_netlink_send(&hdr->msg, VMCHANNEL_CONNECTOR_IDX, GFP_ATOMIC); } else dev_printk(KERN_ERR, &vmc->vdev->dev, "wrong length in received descriptor" " (%d instead of %d)\n", hdr->msg.len, len); posted += add_recq_buf(vmc, hdr); } if (posted) vmc->rq->vq_ops->kick(vmc->rq); } static void recvq_notify(struct virtqueue *recvq) { struct vmchannel_dev *vmc = recvq->vdev->priv; tasklet_schedule(&vmc->tasklet); } static void cleanup_sendq(struct vmchannel_dev *vmc) { char *buf; unsigned int len; spin_lock(&vmc->sq_lock); while ((buf = vmc->sq->vq_ops->get_buf(vmc->sq, &len))) kfree(buf); spin_unlock(&vmc->sq_lock); } static void sendq_notify(struct virtqueue *sendq) { struct vmchannel_dev *vmc = sendq->vdev->priv; cleanup_sendq(vmc); } static void vmchannel_cn_callback(void *data) { struct vmchannel_desc *desc; struct cn_msg *msg = data; struct scatterlist sg; char *buf; int err; unsigned long flags; desc = kmalloc(msg->len + sizeof(*desc), GFP_KERNEL); if (!desc) return; desc->id = cpu_to_le32(msg->id.val); desc->len = cpu_to_le32(msg->len); buf = (char *)(desc + 1); memcpy(buf, msg->data, msg->len); sg_init_one(&sg, desc, msg->len + sizeof(*desc)); spin_lock_irqsave(&vmc_dev.sq_lock, flags); err = vmc_dev.sq->vq_ops->add_buf(vmc_dev.sq, &sg, 1, 0, desc); if (err) kfree(desc); else vmc_dev.sq->vq_ops->kick(vmc_dev.sq); spin_unlock_irqrestore(&vmc_dev.sq_lock, flags); } static int vmchannel_probe(struct virtio_device *vdev) { struct vmchannel_dev *vmc = &vmc_dev; struct cb_id cn_id; int r, i; __le32 count; unsigned offset; cn_id.idx = VMCHANNEL_CONNECTOR_IDX; vdev->priv = vmc; vmc->vdev = vdev; vdev->config->get(vdev, 0, &count, sizeof(count)); vmc->channel_count = le32_to_cpu(count); if (vmc->channel_count == 0) { dev_printk(KERN_ERR, &vdev->dev, "No channels present\n"); return -ENODEV; } pr_debug("vmchannel: %d channel detected\n", vmc->channel_count); vmc->channels = kzalloc(vmc->channel_count * sizeof(struct vmchannel_info), GFP_KERNEL); if (!vmc->channels) return -ENOMEM; offset = sizeof(count); for (i = 0; i < vmc->channel_count; i++) { __u32 len; __le32 tmp; vdev->config->get(vdev, offset, &tmp, 4); vmc->channels[i].id = le32_to_cpu(tmp); offset += 4; vdev->config->get(vdev, offset, &tmp, 4); len = le32_to_cpu(tmp); if (len > VMCHANNEL_NAME_MAX) { dev_printk(KERN_ERR, &vdev->dev, "Wrong device configuration. " "Channel name is too long"); r = -ENODEV; goto out; } vmc->channels[i].name = kmalloc(len, GFP_KERNEL); if (!vmc->channels[i].name) { r = -ENOMEM; goto out; } offset += 4; vdev->config->get(vdev, offset, vmc->channels[i].name, len); offset += len; pr_debug("vmhannel: found channel '%s' id %d\n", vmc->channels[i].name, vmc->channels[i].id); } vmc->rq = vdev->config->find_vq(vdev, 0, recvq_notify); if (IS_ERR(vmc->rq)) { r = PTR_ERR(vmc->rq); goto out; } vmc->sq = vdev->config->find_vq(vdev, 1, sendq_notify); if (IS_ERR(vmc->sq)) { r = PTR_ERR(vmc->sq); goto out; } spin_lock_init(&vmc->sq_lock); for (i = 0; i < vmc->channel_count; i++) { cn_id.val = vmc->channels[i].id; r = cn_add_callback(&cn_id, "vmchannel", vmchannel_cn_callback); if (r) goto cn_unreg; } tasklet_init(&vmc->tasklet, vmchannel_recv, (unsigned long)vmc); if (!try_fill_recvq(vmc)) { r = -ENOMEM; goto kill_task; } return 0; kill_task: tasklet_kill(&vmc->tasklet); cn_unreg: for (i = 0; i < vmc->channel_count; i++) { cn_id.val = vmc->channels[i].id; cn_del_callback(&cn_id); } out: if (vmc->sq) vdev->config->del_vq(vmc->sq); if (vmc->rq) vdev->config->del_vq(vmc->rq); for (i = 0; i < vmc->channel_count; i++) { if (!vmc->channels[i].name) break; kfree(vmc->channels[i].name); } kfree(vmc->channels); return r; } static void vmchannel_remove(struct virtio_device *vdev) { struct vmchannel_dev *vmc = vdev->priv; struct cb_id cn_id; int i; /* Stop all the virtqueues. */ vdev->config->reset(vdev); tasklet_kill(&vmc->tasklet); cn_id.idx = VMCHANNEL_CONNECTOR_IDX; for (i = 0; i < vmc->channel_count; i++) { cn_id.val = vmc->channels[i].id; cn_del_callback(&cn_id); } vdev->config->del_vq(vmc->rq); vdev->config->del_vq(vmc->sq); for (i = 0; i < vmc_dev.channel_count; i++) kfree(vmc_dev.channels[i].name); kfree(vmc_dev.channels); } static struct virtio_device_id id_table[] = { { VIRTIO_ID_VMCHANNEL, VIRTIO_DEV_ANY_ID }, { 0 }, }; static struct virtio_driver virtio_vmchannel = { .driver.name = "virtio-vmchannel", .driver.owner = THIS_MODULE, .id_table = id_table, .probe = vmchannel_probe, .remove = __devexit_p(vmchannel_remove), }; static int __init init(void) { return register_virtio_driver(&virtio_vmchannel); } static void __exit fini(void) { unregister_virtio_driver(&virtio_vmchannel); } module_init(init); module_exit(fini); MODULE_AUTHOR("Gleb Natapov"); MODULE_DEVICE_TABLE(virtio, id_table); MODULE_DESCRIPTION("Virtio vmchannel driver"); MODULE_LICENSE("GPL");