]> asedeno.scripts.mit.edu Git - linux.git/blobdiff - drivers/vhost/vsock.c
Merge branches 'pm-core', 'pm-qos', 'pm-domains' and 'pm-opp'
[linux.git] / drivers / vhost / vsock.c
index a504e2e003da58181e6fbf4e8276c8f4615e18f2..ce5e63d2c66aac7d019c422ec294cab025e94e5e 100644 (file)
@@ -50,11 +50,10 @@ static u32 vhost_transport_get_local_cid(void)
        return VHOST_VSOCK_DEFAULT_HOST_CID;
 }
 
-static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
+static struct vhost_vsock *__vhost_vsock_get(u32 guest_cid)
 {
        struct vhost_vsock *vsock;
 
-       spin_lock_bh(&vhost_vsock_lock);
        list_for_each_entry(vsock, &vhost_vsock_list, list) {
                u32 other_cid = vsock->guest_cid;
 
@@ -63,15 +62,24 @@ static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
                        continue;
 
                if (other_cid == guest_cid) {
-                       spin_unlock_bh(&vhost_vsock_lock);
                        return vsock;
                }
        }
-       spin_unlock_bh(&vhost_vsock_lock);
 
        return NULL;
 }
 
+static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
+{
+       struct vhost_vsock *vsock;
+
+       spin_lock_bh(&vhost_vsock_lock);
+       vsock = __vhost_vsock_get(guest_cid);
+       spin_unlock_bh(&vhost_vsock_lock);
+
+       return vsock;
+}
+
 static void
 vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
                            struct vhost_virtqueue *vq)
@@ -195,7 +203,6 @@ static int
 vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
 {
        struct vhost_vsock *vsock;
-       struct vhost_virtqueue *vq;
        int len = pkt->len;
 
        /* Find the vhost_vsock according to guest context id  */
@@ -205,8 +212,6 @@ vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
                return -ENODEV;
        }
 
-       vq = &vsock->vqs[VSOCK_VQ_RX];
-
        if (pkt->reply)
                atomic_inc(&vsock->queued_replies);
 
@@ -368,6 +373,7 @@ static void vhost_vsock_handle_rx_kick(struct vhost_work *work)
 
 static int vhost_vsock_start(struct vhost_vsock *vsock)
 {
+       struct vhost_virtqueue *vq;
        size_t i;
        int ret;
 
@@ -378,19 +384,20 @@ static int vhost_vsock_start(struct vhost_vsock *vsock)
                goto err;
 
        for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
-               struct vhost_virtqueue *vq = &vsock->vqs[i];
+               vq = &vsock->vqs[i];
 
                mutex_lock(&vq->mutex);
 
                if (!vhost_vq_access_ok(vq)) {
                        ret = -EFAULT;
-                       mutex_unlock(&vq->mutex);
                        goto err_vq;
                }
 
                if (!vq->private_data) {
                        vq->private_data = vsock;
-                       vhost_vq_init_access(vq);
+                       ret = vhost_vq_init_access(vq);
+                       if (ret)
+                               goto err_vq;
                }
 
                mutex_unlock(&vq->mutex);
@@ -400,8 +407,11 @@ static int vhost_vsock_start(struct vhost_vsock *vsock)
        return 0;
 
 err_vq:
+       vq->private_data = NULL;
+       mutex_unlock(&vq->mutex);
+
        for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
-               struct vhost_virtqueue *vq = &vsock->vqs[i];
+               vq = &vsock->vqs[i];
 
                mutex_lock(&vq->mutex);
                vq->private_data = NULL;
@@ -562,11 +572,12 @@ static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
                return -EINVAL;
 
        /* Refuse if CID is already in use */
-       other = vhost_vsock_get(guest_cid);
-       if (other && other != vsock)
-               return -EADDRINUSE;
-
        spin_lock_bh(&vhost_vsock_lock);
+       other = __vhost_vsock_get(guest_cid);
+       if (other && other != vsock) {
+               spin_unlock_bh(&vhost_vsock_lock);
+               return -EADDRINUSE;
+       }
        vsock->guest_cid = guest_cid;
        spin_unlock_bh(&vhost_vsock_lock);