]> asedeno.scripts.mit.edu Git - linux.git/blobdiff - drivers/vhost/vhost.c
Merge tag 'for-linus' of git://git.kernel.org/pub/scm/virt/kvm/kvm
[linux.git] / drivers / vhost / vhost.c
index ed3114556fdaf96eb130832c54a80cc4c41973b7..96c1d8400822a3d852553e5baddaaed9907d7f71 100644 (file)
@@ -315,6 +315,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
        vq->log_addr = -1ull;
        vq->private_data = NULL;
        vq->acked_features = 0;
+       vq->acked_backend_features = 0;
        vq->log_base = NULL;
        vq->error_ctx = NULL;
        vq->kick = NULL;
@@ -1027,28 +1028,40 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev,
 ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
                             struct iov_iter *from)
 {
-       struct vhost_msg_node node;
-       unsigned size = sizeof(struct vhost_msg);
-       size_t ret;
-       int err;
+       struct vhost_iotlb_msg msg;
+       size_t offset;
+       int type, ret;
 
-       if (iov_iter_count(from) < size)
-               return 0;
-       ret = copy_from_iter(&node.msg, size, from);
-       if (ret != size)
+       ret = copy_from_iter(&type, sizeof(type), from);
+       if (ret != sizeof(type))
                goto done;
 
-       switch (node.msg.type) {
+       switch (type) {
        case VHOST_IOTLB_MSG:
-               err = vhost_process_iotlb_msg(dev, &node.msg.iotlb);
-               if (err)
-                       ret = err;
+               /* There maybe a hole after type for V1 message type,
+                * so skip it here.
+                */
+               offset = offsetof(struct vhost_msg, iotlb) - sizeof(int);
+               break;
+       case VHOST_IOTLB_MSG_V2:
+               offset = sizeof(__u32);
                break;
        default:
                ret = -EINVAL;
-               break;
+               goto done;
+       }
+
+       iov_iter_advance(from, offset);
+       ret = copy_from_iter(&msg, sizeof(msg), from);
+       if (ret != sizeof(msg))
+               goto done;
+       if (vhost_process_iotlb_msg(dev, &msg)) {
+               ret = -EFAULT;
+               goto done;
        }
 
+       ret = (type == VHOST_IOTLB_MSG) ? sizeof(struct vhost_msg) :
+             sizeof(struct vhost_msg_v2);
 done:
        return ret;
 }
@@ -1107,13 +1120,28 @@ ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
                finish_wait(&dev->wait, &wait);
 
        if (node) {
-               ret = copy_to_iter(&node->msg, size, to);
+               struct vhost_iotlb_msg *msg;
+               void *start = &node->msg;
+
+               switch (node->msg.type) {
+               case VHOST_IOTLB_MSG:
+                       size = sizeof(node->msg);
+                       msg = &node->msg.iotlb;
+                       break;
+               case VHOST_IOTLB_MSG_V2:
+                       size = sizeof(node->msg_v2);
+                       msg = &node->msg_v2.iotlb;
+                       break;
+               default:
+                       BUG();
+                       break;
+               }
 
-               if (ret != size || node->msg.type != VHOST_IOTLB_MISS) {
+               ret = copy_to_iter(start, size, to);
+               if (ret != size || msg->type != VHOST_IOTLB_MISS) {
                        kfree(node);
                        return ret;
                }
-
                vhost_enqueue_msg(dev, &dev->pending_list, node);
        }
 
@@ -1126,12 +1154,19 @@ static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
        struct vhost_dev *dev = vq->dev;
        struct vhost_msg_node *node;
        struct vhost_iotlb_msg *msg;
+       bool v2 = vhost_backend_has_feature(vq, VHOST_BACKEND_F_IOTLB_MSG_V2);
 
-       node = vhost_new_msg(vq, VHOST_IOTLB_MISS);
+       node = vhost_new_msg(vq, v2 ? VHOST_IOTLB_MSG_V2 : VHOST_IOTLB_MSG);
        if (!node)
                return -ENOMEM;
 
-       msg = &node->msg.iotlb;
+       if (v2) {
+               node->msg_v2.type = VHOST_IOTLB_MSG_V2;
+               msg = &node->msg_v2.iotlb;
+       } else {
+               msg = &node->msg.iotlb;
+       }
+
        msg->type = VHOST_IOTLB_MISS;
        msg->iova = iova;
        msg->perm = access;