Merge tag 'dmaengine-4.19-rc1' of git://git.infradead.org/users/vkoul/slave-dma
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 29756d88799b630f2c73ca097b56b092a14a7d5a..4e656f89cb225c83b42f579d1bc1f988224d49bc 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -77,6 +77,10 @@ enum {
                         (1ULL << VIRTIO_F_IOMMU_PLATFORM)
 };
 
+enum {
+       VHOST_NET_BACKEND_FEATURES = (1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2)
+};
+
 enum {
        VHOST_NET_VQ_RX = 0,
        VHOST_NET_VQ_TX = 1,
@@ -94,7 +98,7 @@ struct vhost_net_ubuf_ref {
        struct vhost_virtqueue *vq;
 };
 
-#define VHOST_RX_BATCH 64
+#define VHOST_NET_BATCH 64
 struct vhost_net_buf {
        void **queue;
        int tail;
@@ -168,7 +172,7 @@ static int vhost_net_buf_produce(struct vhost_net_virtqueue *nvq)
 
        rxq->head = 0;
        rxq->tail = ptr_ring_consume_batched(nvq->rx_ring, rxq->queue,
-                                             VHOST_RX_BATCH);
+                                             VHOST_NET_BATCH);
        return rxq->tail;
 }
 
@@ -396,13 +400,10 @@ static inline unsigned long busy_clock(void)
        return local_clock() >> 10;
 }
 
-static bool vhost_can_busy_poll(struct vhost_dev *dev,
-                               unsigned long endtime)
+static bool vhost_can_busy_poll(unsigned long endtime)
 {
-       return likely(!need_resched()) &&
-              likely(!time_after(busy_clock(), endtime)) &&
-              likely(!signal_pending(current)) &&
-              !vhost_has_work(dev);
+       return likely(!need_resched() && !time_after(busy_clock(), endtime) &&
+                     !signal_pending(current));
 }
 
 static void vhost_net_disable_vq(struct vhost_net *n,
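
A condensed sketch of the busy-poll pattern the callers below adopt after this change: vhost_has_work() is no longer folded into vhost_can_busy_poll(), so each loop checks it explicitly and records why polling stopped (names taken from the TX path later in this diff; this is an illustration, not an extra hunk):

	bool busyloop_intr = false;

	preempt_disable();
	endtime = busy_clock() + vq->busyloop_timeout;
	while (vhost_can_busy_poll(endtime)) {
		if (vhost_has_work(vq->dev)) {
			/* Other vhost work pending: stop and requeue the poll. */
			busyloop_intr = true;
			break;
		}
		if (!vhost_vq_avail_empty(vq->dev, vq))
			break;	/* the guest made new buffers available */
		cpu_relax();
	}
	preempt_enable();
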
@@ -431,21 +432,42 @@ static int vhost_net_enable_vq(struct vhost_net *n,
        return vhost_poll_start(poll, sock->file);
 }
 
+static void vhost_net_signal_used(struct vhost_net_virtqueue *nvq)
+{
+       struct vhost_virtqueue *vq = &nvq->vq;
+       struct vhost_dev *dev = vq->dev;
+
+       if (!nvq->done_idx)
+               return;
+
+       vhost_add_used_and_signal_n(dev, vq, vq->heads, nvq->done_idx);
+       nvq->done_idx = 0;
+}
+
 static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
-                                   struct vhost_virtqueue *vq,
-                                   struct iovec iov[], unsigned int iov_size,
-                                   unsigned int *out_num, unsigned int *in_num)
+                                   struct vhost_net_virtqueue *nvq,
+                                   unsigned int *out_num, unsigned int *in_num,
+                                   bool *busyloop_intr)
 {
+       struct vhost_virtqueue *vq = &nvq->vq;
        unsigned long uninitialized_var(endtime);
        int r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
                                  out_num, in_num, NULL, NULL);
 
        if (r == vq->num && vq->busyloop_timeout) {
+               if (!vhost_sock_zcopy(vq->private_data))
+                       vhost_net_signal_used(nvq);
                preempt_disable();
                endtime = busy_clock() + vq->busyloop_timeout;
-               while (vhost_can_busy_poll(vq->dev, endtime) &&
-                      vhost_vq_avail_empty(vq->dev, vq))
+               while (vhost_can_busy_poll(endtime)) {
+                       if (vhost_has_work(vq->dev)) {
+                               *busyloop_intr = true;
+                               break;
+                       }
+                       if (!vhost_vq_avail_empty(vq->dev, vq))
+                               break;
                        cpu_relax();
+               }
                preempt_enable();
                r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
                                      out_num, in_num, NULL, NULL);
@@ -463,9 +485,62 @@ static bool vhost_exceeds_maxpend(struct vhost_net *net)
               min_t(unsigned int, VHOST_MAX_PEND, vq->num >> 2);
 }
 
-/* Expects to be always run from workqueue - which acts as
- * read-size critical section for our kind of RCU. */
-static void handle_tx(struct vhost_net *net)
+static size_t init_iov_iter(struct vhost_virtqueue *vq, struct iov_iter *iter,
+                           size_t hdr_size, int out)
+{
+       /* Skip header. TODO: support TSO. */
+       size_t len = iov_length(vq->iov, out);
+
+       iov_iter_init(iter, WRITE, vq->iov, out, len);
+       iov_iter_advance(iter, hdr_size);
+
+       return iov_iter_count(iter);
+}
+
+static bool vhost_exceeds_weight(int pkts, int total_len)
+{
+       return total_len >= VHOST_NET_WEIGHT ||
+              pkts >= VHOST_NET_PKT_WEIGHT;
+}
+
+static int get_tx_bufs(struct vhost_net *net,
+                      struct vhost_net_virtqueue *nvq,
+                      struct msghdr *msg,
+                      unsigned int *out, unsigned int *in,
+                      size_t *len, bool *busyloop_intr)
+{
+       struct vhost_virtqueue *vq = &nvq->vq;
+       int ret;
+
+       ret = vhost_net_tx_get_vq_desc(net, nvq, out, in, busyloop_intr);
+
+       if (ret < 0 || ret == vq->num)
+               return ret;
+
+       if (*in) {
+               vq_err(vq, "Unexpected descriptor format for TX: out %d, int %d\n",
+                       *out, *in);
+               return -EFAULT;
+       }
+
+       /* Sanity check */
+       *len = init_iov_iter(vq, &msg->msg_iter, nvq->vhost_hlen, *out);
+       if (*len == 0) {
+               vq_err(vq, "Unexpected header len for TX: %zd expected %zd\n",
+                       *len, nvq->vhost_hlen);
+               return -EFAULT;
+       }
+
+       return ret;
+}
+
+static bool tx_can_batch(struct vhost_virtqueue *vq, size_t total_len)
+{
+       return total_len < VHOST_NET_WEIGHT &&
+              !vhost_vq_avail_empty(vq->dev, vq);
+}
+
+static void handle_tx_copy(struct vhost_net *net, struct socket *sock)
 {
        struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
        struct vhost_virtqueue *vq = &nvq->vq;
@@ -480,67 +555,103 @@ static void handle_tx(struct vhost_net *net)
        };
        size_t len, total_len = 0;
        int err;
-       size_t hdr_size;
-       struct socket *sock;
-       struct vhost_net_ubuf_ref *uninitialized_var(ubufs);
-       bool zcopy, zcopy_used;
        int sent_pkts = 0;
 
-       mutex_lock(&vq->mutex);
-       sock = vq->private_data;
-       if (!sock)
-               goto out;
+       for (;;) {
+               bool busyloop_intr = false;
 
-       if (!vq_iotlb_prefetch(vq))
-               goto out;
+               head = get_tx_bufs(net, nvq, &msg, &out, &in, &len,
+                                  &busyloop_intr);
+               /* On error, stop handling until the next kick. */
+               if (unlikely(head < 0))
+                       break;
+               /* Nothing new?  Wait for eventfd to tell us they refilled. */
+               if (head == vq->num) {
+                       if (unlikely(busyloop_intr)) {
+                               vhost_poll_queue(&vq->poll);
+                       } else if (unlikely(vhost_enable_notify(&net->dev,
+                                                               vq))) {
+                               vhost_disable_notify(&net->dev, vq);
+                               continue;
+                       }
+                       break;
+               }
 
-       vhost_disable_notify(&net->dev, vq);
-       vhost_net_disable_vq(net, vq);
+               vq->heads[nvq->done_idx].id = cpu_to_vhost32(vq, head);
+               vq->heads[nvq->done_idx].len = 0;
 
-       hdr_size = nvq->vhost_hlen;
-       zcopy = nvq->ubufs;
+               total_len += len;
+               if (tx_can_batch(vq, total_len))
+                       msg.msg_flags |= MSG_MORE;
+               else
+                       msg.msg_flags &= ~MSG_MORE;
+
+               /* TODO: Check specific error and bomb out unless ENOBUFS? */
+               err = sock->ops->sendmsg(sock, &msg, len);
+               if (unlikely(err < 0)) {
+                       vhost_discard_vq_desc(vq, 1);
+                       vhost_net_enable_vq(net, vq);
+                       break;
+               }
+               if (err != len)
+                       pr_debug("Truncated TX packet: len %d != %zd\n",
+                                err, len);
+               if (++nvq->done_idx >= VHOST_NET_BATCH)
+                       vhost_net_signal_used(nvq);
+               if (vhost_exceeds_weight(++sent_pkts, total_len)) {
+                       vhost_poll_queue(&vq->poll);
+                       break;
+               }
+       }
+
+       vhost_net_signal_used(nvq);
+}
+
+static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock)
+{
+       struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
+       struct vhost_virtqueue *vq = &nvq->vq;
+       unsigned out, in;
+       int head;
+       struct msghdr msg = {
+               .msg_name = NULL,
+               .msg_namelen = 0,
+               .msg_control = NULL,
+               .msg_controllen = 0,
+               .msg_flags = MSG_DONTWAIT,
+       };
+       size_t len, total_len = 0;
+       int err;
+       struct vhost_net_ubuf_ref *uninitialized_var(ubufs);
+       bool zcopy_used;
+       int sent_pkts = 0;
 
        for (;;) {
-               /* Release DMAs done buffers first */
-               if (zcopy)
-                       vhost_zerocopy_signal_used(net, vq);
+               bool busyloop_intr;
 
+               /* Release DMAs done buffers first */
+               vhost_zerocopy_signal_used(net, vq);
 
-               head = vhost_net_tx_get_vq_desc(net, vq, vq->iov,
-                                               ARRAY_SIZE(vq->iov),
-                                               &out, &in);
+               busyloop_intr = false;
+               head = get_tx_bufs(net, nvq, &msg, &out, &in, &len,
+                                  &busyloop_intr);
                /* On error, stop handling until the next kick. */
                if (unlikely(head < 0))
                        break;
                /* Nothing new?  Wait for eventfd to tell us they refilled. */
                if (head == vq->num) {
-                       if (unlikely(vhost_enable_notify(&net->dev, vq))) {
+                       if (unlikely(busyloop_intr)) {
+                               vhost_poll_queue(&vq->poll);
+                       } else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
                                vhost_disable_notify(&net->dev, vq);
                                continue;
                        }
                        break;
                }
-               if (in) {
-                       vq_err(vq, "Unexpected descriptor format for TX: "
-                              "out %d, int %d\n", out, in);
-                       break;
-               }
-               /* Skip header. TODO: support TSO. */
-               len = iov_length(vq->iov, out);
-               iov_iter_init(&msg.msg_iter, WRITE, vq->iov, out, len);
-               iov_iter_advance(&msg.msg_iter, hdr_size);
-               /* Sanity check */
-               if (!msg_data_left(&msg)) {
-                       vq_err(vq, "Unexpected header len for TX: "
-                              "%zd expected %zd\n",
-                              len, hdr_size);
-                       break;
-               }
-               len = msg_data_left(&msg);
 
-               zcopy_used = zcopy && len >= VHOST_GOODCOPY_LEN
-                                  && !vhost_exceeds_maxpend(net)
-                                  && vhost_net_tx_select_zcopy(net);
+               zcopy_used = len >= VHOST_GOODCOPY_LEN
+                            && !vhost_exceeds_maxpend(net)
+                            && vhost_net_tx_select_zcopy(net);
 
                /* use msg_control to pass vhost zerocopy ubuf info to skb */
                if (zcopy_used) {
@@ -562,10 +673,8 @@ static void handle_tx(struct vhost_net *net)
                        msg.msg_control = NULL;
                        ubufs = NULL;
                }
-
                total_len += len;
-               if (total_len < VHOST_NET_WEIGHT &&
-                   !vhost_vq_avail_empty(&net->dev, vq) &&
+               if (tx_can_batch(vq, total_len) &&
                    likely(!vhost_exceeds_maxpend(net))) {
                        msg.msg_flags |= MSG_MORE;
                } else {
@@ -592,12 +701,37 @@ static void handle_tx(struct vhost_net *net)
                else
                        vhost_zerocopy_signal_used(net, vq);
                vhost_net_tx_packet(net);
-               if (unlikely(total_len >= VHOST_NET_WEIGHT) ||
-                   unlikely(++sent_pkts >= VHOST_NET_PKT_WEIGHT)) {
+               if (unlikely(vhost_exceeds_weight(++sent_pkts, total_len))) {
                        vhost_poll_queue(&vq->poll);
                        break;
                }
        }
+}
+
+/* Expects to be always run from workqueue - which acts as
+ * read-size critical section for our kind of RCU. */
+static void handle_tx(struct vhost_net *net)
+{
+       struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
+       struct vhost_virtqueue *vq = &nvq->vq;
+       struct socket *sock;
+
+       mutex_lock(&vq->mutex);
+       sock = vq->private_data;
+       if (!sock)
+               goto out;
+
+       if (!vq_iotlb_prefetch(vq))
+               goto out;
+
+       vhost_disable_notify(&net->dev, vq);
+       vhost_net_disable_vq(net, vq);
+
+       if (vhost_sock_zcopy(sock))
+               handle_tx_zerocopy(net, sock);
+       else
+               handle_tx_copy(net, sock);
+
 out:
        mutex_unlock(&vq->mutex);
 }
@@ -633,53 +767,50 @@ static int sk_has_rx_data(struct sock *sk)
        return skb_queue_empty(&sk->sk_receive_queue);
 }
 
-static void vhost_rx_signal_used(struct vhost_net_virtqueue *nvq)
+static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk,
+                                     bool *busyloop_intr)
 {
-       struct vhost_virtqueue *vq = &nvq->vq;
-       struct vhost_dev *dev = vq->dev;
-
-       if (!nvq->done_idx)
-               return;
-
-       vhost_add_used_and_signal_n(dev, vq, vq->heads, nvq->done_idx);
-       nvq->done_idx = 0;
-}
-
-static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk)
-{
-       struct vhost_net_virtqueue *rvq = &net->vqs[VHOST_NET_VQ_RX];
-       struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
-       struct vhost_virtqueue *vq = &nvq->vq;
+       struct vhost_net_virtqueue *rnvq = &net->vqs[VHOST_NET_VQ_RX];
+       struct vhost_net_virtqueue *tnvq = &net->vqs[VHOST_NET_VQ_TX];
+       struct vhost_virtqueue *rvq = &rnvq->vq;
+       struct vhost_virtqueue *tvq = &tnvq->vq;
        unsigned long uninitialized_var(endtime);
-       int len = peek_head_len(rvq, sk);
+       int len = peek_head_len(rnvq, sk);
 
-       if (!len && vq->busyloop_timeout) {
+       if (!len && tvq->busyloop_timeout) {
                /* Flush batched heads first */
-               vhost_rx_signal_used(rvq);
+               vhost_net_signal_used(rnvq);
                /* Both tx vq and rx socket were polled here */
-               mutex_lock_nested(&vq->mutex, 1);
-               vhost_disable_notify(&net->dev, vq);
+               mutex_lock_nested(&tvq->mutex, 1);
+               vhost_disable_notify(&net->dev, tvq);
 
                preempt_disable();
-               endtime = busy_clock() + vq->busyloop_timeout;
+               endtime = busy_clock() + tvq->busyloop_timeout;
 
-               while (vhost_can_busy_poll(&net->dev, endtime) &&
-                      !sk_has_rx_data(sk) &&
-                      vhost_vq_avail_empty(&net->dev, vq))
+               while (vhost_can_busy_poll(endtime)) {
+                       if (vhost_has_work(&net->dev)) {
+                               *busyloop_intr = true;
+                               break;
+                       }
+                       if ((sk_has_rx_data(sk) &&
+                            !vhost_vq_avail_empty(&net->dev, rvq)) ||
+                           !vhost_vq_avail_empty(&net->dev, tvq))
+                               break;
                        cpu_relax();
+               }
 
                preempt_enable();
 
-               if (!vhost_vq_avail_empty(&net->dev, vq))
-                       vhost_poll_queue(&vq->poll);
-               else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
-                       vhost_disable_notify(&net->dev, vq);
-                       vhost_poll_queue(&vq->poll);
+               if (!vhost_vq_avail_empty(&net->dev, tvq)) {
+                       vhost_poll_queue(&tvq->poll);
+               } else if (unlikely(vhost_enable_notify(&net->dev, tvq))) {
+                       vhost_disable_notify(&net->dev, tvq);
+                       vhost_poll_queue(&tvq->poll);
                }
 
-               mutex_unlock(&vq->mutex);
+               mutex_unlock(&tvq->mutex);
 
-               len = peek_head_len(rvq, sk);
+               len = peek_head_len(rnvq, sk);
        }
 
        return len;
@@ -786,6 +917,7 @@ static void handle_rx(struct vhost_net *net)
        s16 headcount;
        size_t vhost_hlen, sock_hlen;
        size_t vhost_len, sock_len;
+       bool busyloop_intr = false;
        struct socket *sock;
        struct iov_iter fixup;
        __virtio16 num_buffers;
@@ -809,7 +941,8 @@ static void handle_rx(struct vhost_net *net)
                vq->log : NULL;
        mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
 
-       while ((sock_len = vhost_net_rx_peek_head_len(net, sock->sk))) {
+       while ((sock_len = vhost_net_rx_peek_head_len(net, sock->sk,
+                                                     &busyloop_intr))) {
                sock_len += sock_hlen;
                vhost_len = sock_len + vhost_hlen;
                headcount = get_rx_bufs(vq, vq->heads + nvq->done_idx,
@@ -820,7 +953,9 @@ static void handle_rx(struct vhost_net *net)
                        goto out;
                /* OK, now we need to know about added descriptors. */
                if (!headcount) {
-                       if (unlikely(vhost_enable_notify(&net->dev, vq))) {
+                       if (unlikely(busyloop_intr)) {
+                               vhost_poll_queue(&vq->poll);
+                       } else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
                                /* They have slipped one in as we were
                                 * doing that: check again. */
                                vhost_disable_notify(&net->dev, vq);
@@ -830,6 +965,7 @@ static void handle_rx(struct vhost_net *net)
                         * they refilled. */
                        goto out;
                }
+               busyloop_intr = false;
                if (nvq->rx_ring)
                        msg.msg_control = vhost_net_buf_consume(&nvq->rxq);
                /* On overrun, truncate and discard */
@@ -885,20 +1021,22 @@ static void handle_rx(struct vhost_net *net)
                        goto out;
                }
                nvq->done_idx += headcount;
-               if (nvq->done_idx > VHOST_RX_BATCH)
-                       vhost_rx_signal_used(nvq);
+               if (nvq->done_idx > VHOST_NET_BATCH)
+                       vhost_net_signal_used(nvq);
                if (unlikely(vq_log))
                        vhost_log_write(vq, vq_log, log, vhost_len);
                total_len += vhost_len;
-               if (unlikely(total_len >= VHOST_NET_WEIGHT) ||
-                   unlikely(++recv_pkts >= VHOST_NET_PKT_WEIGHT)) {
+               if (unlikely(vhost_exceeds_weight(++recv_pkts, total_len))) {
                        vhost_poll_queue(&vq->poll);
                        goto out;
                }
        }
-       vhost_net_enable_vq(net, vq);
+       if (unlikely(busyloop_intr))
+               vhost_poll_queue(&vq->poll);
+       else
+               vhost_net_enable_vq(net, vq);
 out:
-       vhost_rx_signal_used(nvq);
+       vhost_net_signal_used(nvq);
        mutex_unlock(&vq->mutex);
 }
 
@@ -951,7 +1089,7 @@ static int vhost_net_open(struct inode *inode, struct file *f)
                return -ENOMEM;
        }
 
-       queue = kmalloc_array(VHOST_RX_BATCH, sizeof(void *),
+       queue = kmalloc_array(VHOST_NET_BATCH, sizeof(void *),
                              GFP_KERNEL);
        if (!queue) {
                kfree(vqs);
@@ -1265,6 +1403,21 @@ static long vhost_net_reset_owner(struct vhost_net *n)
        return err;
 }
 
+static int vhost_net_set_backend_features(struct vhost_net *n, u64 features)
+{
+       int i;
+
+       mutex_lock(&n->dev.mutex);
+       for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
+               mutex_lock(&n->vqs[i].vq.mutex);
+               n->vqs[i].vq.acked_backend_features = features;
+               mutex_unlock(&n->vqs[i].vq.mutex);
+       }
+       mutex_unlock(&n->dev.mutex);
+
+       return 0;
+}
+
 static int vhost_net_set_features(struct vhost_net *n, u64 features)
 {
        size_t vhost_hlen, sock_hlen, hdr_len;
@@ -1355,6 +1508,17 @@ static long vhost_net_ioctl(struct file *f, unsigned int ioctl,
                if (features & ~VHOST_NET_FEATURES)
                        return -EOPNOTSUPP;
                return vhost_net_set_features(n, features);
+       case VHOST_GET_BACKEND_FEATURES:
+               features = VHOST_NET_BACKEND_FEATURES;
+               if (copy_to_user(featurep, &features, sizeof(features)))
+                       return -EFAULT;
+               return 0;
+       case VHOST_SET_BACKEND_FEATURES:
+               if (copy_from_user(&features, featurep, sizeof(features)))
+                       return -EFAULT;
+               if (features & ~VHOST_NET_BACKEND_FEATURES)
+                       return -EOPNOTSUPP;
+               return vhost_net_set_backend_features(n, features);
        case VHOST_RESET_OWNER:
                return vhost_net_reset_owner(n);
        case VHOST_SET_OWNER:
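
A minimal userspace sketch of how a backend might negotiate the new bits over /dev/vhost-net, assuming the VHOST_GET_BACKEND_FEATURES / VHOST_SET_BACKEND_FEATURES ioctls and VHOST_BACKEND_F_IOTLB_MSG_V2 definitions from <linux/vhost.h>; the usual VHOST_SET_OWNER and ring setup that a real backend performs around this are omitted:

	#include <fcntl.h>
	#include <stdint.h>
	#include <stdio.h>
	#include <sys/ioctl.h>
	#include <linux/vhost.h>

	static int negotiate_backend_features(int vhost_fd)
	{
		uint64_t features;

		/* Ask the kernel which backend features it offers. */
		if (ioctl(vhost_fd, VHOST_GET_BACKEND_FEATURES, &features) < 0) {
			perror("VHOST_GET_BACKEND_FEATURES");
			return -1;
		}

		/* Keep only the bits we understand, e.g. IOTLB message v2. */
		features &= (1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2);

		/* Acknowledge the subset we are going to use. */
		if (ioctl(vhost_fd, VHOST_SET_BACKEND_FEATURES, &features) < 0) {
			perror("VHOST_SET_BACKEND_FEATURES");
			return -1;
		}

		return 0;
	}

	int main(void)
	{
		int fd = open("/dev/vhost-net", O_RDWR);

		if (fd < 0) {
			perror("open /dev/vhost-net");
			return 1;
		}
		return negotiate_backend_features(fd) ? 1 : 0;
	}
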