]> asedeno.scripts.mit.edu Git - linux.git/blobdiff - drivers/vhost/vhost.c
vhost: log dirty page correctly
[linux.git] / drivers / vhost / vhost.c
index 55e5aa662ad59d4b72c44db743198876654f2d2d..babbb32b9bf0cde1f517ca779fdd47f1642d2e52 100644 (file)
@@ -655,7 +655,7 @@ static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
            a + (unsigned long)log_base > ULONG_MAX)
                return false;
 
-       return access_ok(VERIFY_WRITE, log_base + a,
+       return access_ok(log_base + a,
                         (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
 }
 
@@ -681,7 +681,7 @@ static bool vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
                        return false;
 
 
-               if (!access_ok(VERIFY_WRITE, (void __user *)a,
+               if (!access_ok((void __user *)a,
                                    node->size))
                        return false;
                else if (log_all && !log_access_ok(log_base,
@@ -973,10 +973,10 @@ static bool umem_access_ok(u64 uaddr, u64 size, int access)
                return false;
 
        if ((access & VHOST_ACCESS_RO) &&
-           !access_ok(VERIFY_READ, (void __user *)a, size))
+           !access_ok((void __user *)a, size))
                return false;
        if ((access & VHOST_ACCESS_WO) &&
-           !access_ok(VERIFY_WRITE, (void __user *)a, size))
+           !access_ok((void __user *)a, size))
                return false;
        return true;
 }
@@ -1185,10 +1185,10 @@ static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
 {
        size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
 
-       return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
-              access_ok(VERIFY_READ, avail,
+       return access_ok(desc, num * sizeof *desc) &&
+              access_ok(avail,
                         sizeof *avail + num * sizeof *avail->ring + s) &&
-              access_ok(VERIFY_WRITE, used,
+              access_ok(used,
                        sizeof *used + num * sizeof *used->ring + s);
 }
 
@@ -1733,13 +1733,87 @@ static int log_write(void __user *log_base,
        return r;
 }
 
+static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
+{
+       struct vhost_umem *umem = vq->umem;
+       struct vhost_umem_node *u;
+       u64 start, end, l, min;
+       int r;
+       bool hit = false;
+
+       while (len) {
+               min = len;
+               /* More than one GPAs can be mapped into a single HVA. So
+                * iterate all possible umems here to be safe.
+                */
+               list_for_each_entry(u, &umem->umem_list, link) {
+                       if (u->userspace_addr > hva - 1 + len ||
+                           u->userspace_addr - 1 + u->size < hva)
+                               continue;
+                       start = max(u->userspace_addr, hva);
+                       end = min(u->userspace_addr - 1 + u->size,
+                                 hva - 1 + len);
+                       l = end - start + 1;
+                       r = log_write(vq->log_base,
+                                     u->start + start - u->userspace_addr,
+                                     l);
+                       if (r < 0)
+                               return r;
+                       hit = true;
+                       min = min(l, min);
+               }
+
+               if (!hit)
+                       return -EFAULT;
+
+               len -= min;
+               hva += min;
+       }
+
+       return 0;
+}
+
+static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
+{
+       struct iovec iov[64];
+       int i, ret;
+
+       if (!vq->iotlb)
+               return log_write(vq->log_base, vq->log_addr + used_offset, len);
+
+       ret = translate_desc(vq, (uintptr_t)vq->used + used_offset,
+                            len, iov, 64, VHOST_ACCESS_WO);
+       if (ret)
+               return ret;
+
+       for (i = 0; i < ret; i++) {
+               ret = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
+                                   iov[i].iov_len);
+               if (ret)
+                       return ret;
+       }
+
+       return 0;
+}
+
 int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
-                   unsigned int log_num, u64 len)
+                   unsigned int log_num, u64 len, struct iovec *iov, int count)
 {
        int i, r;
 
        /* Make sure data written is seen before log. */
        smp_wmb();
+
+       if (vq->iotlb) {
+               for (i = 0; i < count; i++) {
+                       r = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
+                                         iov[i].iov_len);
+                       if (r < 0)
+                               return r;
+               }
+               return 0;
+       }
+
        for (i = 0; i < log_num; ++i) {
                u64 l = min(log[i].len, len);
                r = log_write(vq->log_base, log[i].addr, l);
@@ -1769,9 +1843,8 @@ static int vhost_update_used_flags(struct vhost_virtqueue *vq)
                smp_wmb();
                /* Log used flag write. */
                used = &vq->used->flags;
-               log_write(vq->log_base, vq->log_addr +
-                         (used - (void __user *)vq->used),
-                         sizeof vq->used->flags);
+               log_used(vq, (used - (void __user *)vq->used),
+                        sizeof vq->used->flags);
                if (vq->log_ctx)
                        eventfd_signal(vq->log_ctx, 1);
        }
@@ -1789,9 +1862,8 @@ static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
                smp_wmb();
                /* Log avail event write */
                used = vhost_avail_event(vq);
-               log_write(vq->log_base, vq->log_addr +
-                         (used - (void __user *)vq->used),
-                         sizeof *vhost_avail_event(vq));
+               log_used(vq, (used - (void __user *)vq->used),
+                        sizeof *vhost_avail_event(vq));
                if (vq->log_ctx)
                        eventfd_signal(vq->log_ctx, 1);
        }
@@ -1814,7 +1886,7 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq)
                goto err;
        vq->signalled_used_valid = false;
        if (!vq->iotlb &&
-           !access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) {
+           !access_ok(&vq->used->idx, sizeof vq->used->idx)) {
                r = -EFAULT;
                goto err;
        }
@@ -2191,10 +2263,8 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
                /* Make sure data is seen before log. */
                smp_wmb();
                /* Log used ring entry write. */
-               log_write(vq->log_base,
-                         vq->log_addr +
-                          ((void __user *)used - (void __user *)vq->used),
-                         count * sizeof *used);
+               log_used(vq, ((void __user *)used - (void __user *)vq->used),
+                        count * sizeof *used);
        }
        old = vq->last_used_idx;
        new = (vq->last_used_idx += count);
@@ -2236,9 +2306,8 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
                /* Make sure used idx is seen before log. */
                smp_wmb();
                /* Log used index update. */
-               log_write(vq->log_base,
-                         vq->log_addr + offsetof(struct vring_used, idx),
-                         sizeof vq->used->idx);
+               log_used(vq, offsetof(struct vring_used, idx),
+                        sizeof vq->used->idx);
                if (vq->log_ctx)
                        eventfd_signal(vq->log_ctx, 1);
        }