]> asedeno.scripts.mit.edu Git - linux.git/blob - arch/um/drivers/virtio_uml.c
Merge tag 'compat-ioctl-fix' of git://git.kernel.org:/pub/scm/linux/kernel/git/arnd...
[linux.git] / arch / um / drivers / virtio_uml.c
1 // SPDX-License-Identifier: GPL-2.0-or-later
2 /*
3  * Virtio vhost-user driver
4  *
5  * Copyright(c) 2019 Intel Corporation
6  *
7  * This driver allows virtio devices to be used over a vhost-user socket.
8  *
9  * Guest devices can be instantiated by kernel module or command line
10  * parameters. One device will be created for each parameter. Syntax:
11  *
12  *              virtio_uml.device=<socket>:<virtio_id>[:<platform_id>]
13  * where:
14  *              <socket>        := vhost-user socket path to connect
15  *              <virtio_id>     := virtio device id (as in virtio_ids.h)
16  *              <platform_id>   := (optional) platform device id
17  *
18  * example:
19  *              virtio_uml.device=/var/uml.socket:1
20  *
21  * Based on Virtio MMIO driver by Pawel Moll, copyright 2011-2014, ARM Ltd.
22  */
23 #include <linux/module.h>
24 #include <linux/platform_device.h>
25 #include <linux/slab.h>
26 #include <linux/virtio.h>
27 #include <linux/virtio_config.h>
28 #include <linux/virtio_ring.h>
29 #include <shared/as-layout.h>
30 #include <irq_kern.h>
31 #include <init.h>
32 #include <os.h>
33 #include "vhost_user.h"
34
/*
 * irq_user.h and irqreturn.h conflict over the name IRQ_NONE. If it
 * reached this point as a macro, drop it so it cannot collide with the
 * other definition.
 */
#ifdef IRQ_NONE
#undef IRQ_NONE
#endif

/* Ring size requested from vring_create_virtqueue() in vu_setup_vq(). */
#define MAX_SUPPORTED_QUEUE_SIZE	256

/* Recover the enclosing virtio_uml_device from its embedded virtio_device. */
#define to_virtio_uml_device(_vdev) \
	container_of(_vdev, struct virtio_uml_device, vdev)
44
45 struct virtio_uml_platform_data {
46         u32 virtio_device_id;
47         const char *socket_path;
48         struct work_struct conn_broken_wk;
49         struct platform_device *pdev;
50 };
51
52 struct virtio_uml_device {
53         struct virtio_device vdev;
54         struct platform_device *pdev;
55
56         int sock, req_fd;
57         u64 features;
58         u64 protocol_features;
59         u8 status;
60         u8 registered:1;
61 };
62
63 struct virtio_uml_vq_info {
64         int kick_fd, call_fd;
65         char name[32];
66 };
67
/* Defined elsewhere; describe guest memory for VHOST_USER_SET_MEM_TABLE. */
extern unsigned long long physmem_size;
extern unsigned long long highmem;

/* dev_err() on the platform device that backs @vu_dev. */
#define vu_err(vu_dev, ...)	dev_err(&(vu_dev)->pdev->dev, __VA_ARGS__)
71
72 /* Vhost-user protocol */
73
/*
 * Push all @len bytes of @buf out on @fd.
 *
 * @fds (if any) are attached only to the first chunk that goes through;
 * later chunks are sent without ancillary data. Short sends and -EINTR
 * are retried. Returns 0 on success or the negative error from
 * os_sendmsg_fds().
 */
static int full_sendmsg_fds(int fd, const void *buf, unsigned int len,
			    const int *fds, unsigned int fds_num)
{
	int ret;

	for (;;) {
		ret = os_sendmsg_fds(fd, buf, len, fds, fds_num);
		if (ret > 0) {
			buf += ret;
			len -= ret;
			/* The descriptors went out with this chunk. */
			fds = NULL;
			fds_num = 0;
		}
		if (!len)
			break;
		if (ret < 0 && ret != -EINTR)
			break;
	}

	return ret < 0 ? ret : 0;
}
93
/*
 * full_read() - read exactly @len bytes from @fd into @buf
 * @fd: descriptor to read from
 * @buf: destination, at least @len bytes
 * @len: number of bytes wanted; 0 succeeds immediately
 * @abortable: when true, -EAGAIN is returned to the caller instead of
 *             being retried
 *
 * Short reads and -EINTR are retried. Returns 0 once all @len bytes have
 * arrived, -ECONNRESET if a read returned 0 (peer closed) first, or the
 * negative error code from os_read_file().
 */
static int full_read(int fd, void *buf, int len, bool abortable)
{
	int rc;

	/*
	 * A zero-byte read returns 0, which the loop below would take for
	 * end-of-file and turn into -ECONNRESET. Messages with an empty
	 * payload (header.size == 0) do reach this path, so answer them
	 * without touching the fd.
	 */
	if (!len)
		return 0;

	do {
		rc = os_read_file(fd, buf, len);
		if (rc > 0) {
			buf += rc;
			len -= rc;
		}
	} while (len && (rc > 0 || rc == -EINTR || (!abortable && rc == -EAGAIN)));

	if (rc < 0)
		return rc;
	if (rc == 0)
		return -ECONNRESET;
	return 0;
}
112
113 static int vhost_user_recv_header(int fd, struct vhost_user_msg *msg)
114 {
115         return full_read(fd, msg, sizeof(msg->header), true);
116 }
117
118 static int vhost_user_recv(struct virtio_uml_device *vu_dev,
119                            int fd, struct vhost_user_msg *msg,
120                            size_t max_payload_size)
121 {
122         size_t size;
123         int rc = vhost_user_recv_header(fd, msg);
124
125         if (rc == -ECONNRESET && vu_dev->registered) {
126                 struct virtio_uml_platform_data *pdata;
127
128                 pdata = vu_dev->pdev->dev.platform_data;
129
130                 virtio_break_device(&vu_dev->vdev);
131                 schedule_work(&pdata->conn_broken_wk);
132         }
133         if (rc)
134                 return rc;
135         size = msg->header.size;
136         if (size > max_payload_size)
137                 return -EPROTO;
138         return full_read(fd, &msg->payload, size, false);
139 }
140
141 static int vhost_user_recv_resp(struct virtio_uml_device *vu_dev,
142                                 struct vhost_user_msg *msg,
143                                 size_t max_payload_size)
144 {
145         int rc = vhost_user_recv(vu_dev, vu_dev->sock, msg, max_payload_size);
146
147         if (rc)
148                 return rc;
149
150         if (msg->header.flags != (VHOST_USER_FLAG_REPLY | VHOST_USER_VERSION))
151                 return -EPROTO;
152
153         return 0;
154 }
155
156 static int vhost_user_recv_u64(struct virtio_uml_device *vu_dev,
157                                u64 *value)
158 {
159         struct vhost_user_msg msg;
160         int rc = vhost_user_recv_resp(vu_dev, &msg,
161                                       sizeof(msg.payload.integer));
162
163         if (rc)
164                 return rc;
165         if (msg.header.size != sizeof(msg.payload.integer))
166                 return -EPROTO;
167         *value = msg.payload.integer;
168         return 0;
169 }
170
171 static int vhost_user_recv_req(struct virtio_uml_device *vu_dev,
172                                struct vhost_user_msg *msg,
173                                size_t max_payload_size)
174 {
175         int rc = vhost_user_recv(vu_dev, vu_dev->req_fd, msg, max_payload_size);
176
177         if (rc)
178                 return rc;
179
180         if ((msg->header.flags & ~VHOST_USER_FLAG_NEED_REPLY) !=
181                         VHOST_USER_VERSION)
182                 return -EPROTO;
183
184         return 0;
185 }
186
187 static int vhost_user_send(struct virtio_uml_device *vu_dev,
188                            bool need_response, struct vhost_user_msg *msg,
189                            int *fds, size_t num_fds)
190 {
191         size_t size = sizeof(msg->header) + msg->header.size;
192         bool request_ack;
193         int rc;
194
195         msg->header.flags |= VHOST_USER_VERSION;
196
197         /*
198          * The need_response flag indicates that we already need a response,
199          * e.g. to read the features. In these cases, don't request an ACK as
200          * it is meaningless. Also request an ACK only if supported.
201          */
202         request_ack = !need_response;
203         if (!(vu_dev->protocol_features &
204                         BIT_ULL(VHOST_USER_PROTOCOL_F_REPLY_ACK)))
205                 request_ack = false;
206
207         if (request_ack)
208                 msg->header.flags |= VHOST_USER_FLAG_NEED_REPLY;
209
210         rc = full_sendmsg_fds(vu_dev->sock, msg, size, fds, num_fds);
211         if (rc < 0)
212                 return rc;
213
214         if (request_ack) {
215                 uint64_t status;
216
217                 rc = vhost_user_recv_u64(vu_dev, &status);
218                 if (rc)
219                         return rc;
220
221                 if (status) {
222                         vu_err(vu_dev, "slave reports error: %llu\n", status);
223                         return -EIO;
224                 }
225         }
226
227         return 0;
228 }
229
230 static int vhost_user_send_no_payload(struct virtio_uml_device *vu_dev,
231                                       bool need_response, u32 request)
232 {
233         struct vhost_user_msg msg = {
234                 .header.request = request,
235         };
236
237         return vhost_user_send(vu_dev, need_response, &msg, NULL, 0);
238 }
239
240 static int vhost_user_send_no_payload_fd(struct virtio_uml_device *vu_dev,
241                                          u32 request, int fd)
242 {
243         struct vhost_user_msg msg = {
244                 .header.request = request,
245         };
246
247         return vhost_user_send(vu_dev, false, &msg, &fd, 1);
248 }
249
250 static int vhost_user_send_u64(struct virtio_uml_device *vu_dev,
251                                u32 request, u64 value)
252 {
253         struct vhost_user_msg msg = {
254                 .header.request = request,
255                 .header.size = sizeof(msg.payload.integer),
256                 .payload.integer = value,
257         };
258
259         return vhost_user_send(vu_dev, false, &msg, NULL, 0);
260 }
261
262 static int vhost_user_set_owner(struct virtio_uml_device *vu_dev)
263 {
264         return vhost_user_send_no_payload(vu_dev, false, VHOST_USER_SET_OWNER);
265 }
266
267 static int vhost_user_get_features(struct virtio_uml_device *vu_dev,
268                                    u64 *features)
269 {
270         int rc = vhost_user_send_no_payload(vu_dev, true,
271                                             VHOST_USER_GET_FEATURES);
272
273         if (rc)
274                 return rc;
275         return vhost_user_recv_u64(vu_dev, features);
276 }
277
278 static int vhost_user_set_features(struct virtio_uml_device *vu_dev,
279                                    u64 features)
280 {
281         return vhost_user_send_u64(vu_dev, VHOST_USER_SET_FEATURES, features);
282 }
283
284 static int vhost_user_get_protocol_features(struct virtio_uml_device *vu_dev,
285                                             u64 *protocol_features)
286 {
287         int rc = vhost_user_send_no_payload(vu_dev, true,
288                         VHOST_USER_GET_PROTOCOL_FEATURES);
289
290         if (rc)
291                 return rc;
292         return vhost_user_recv_u64(vu_dev, protocol_features);
293 }
294
295 static int vhost_user_set_protocol_features(struct virtio_uml_device *vu_dev,
296                                             u64 protocol_features)
297 {
298         return vhost_user_send_u64(vu_dev, VHOST_USER_SET_PROTOCOL_FEATURES,
299                                    protocol_features);
300 }
301
302 static void vhost_user_reply(struct virtio_uml_device *vu_dev,
303                              struct vhost_user_msg *msg, int response)
304 {
305         struct vhost_user_msg reply = {
306                 .payload.integer = response,
307         };
308         size_t size = sizeof(reply.header) + sizeof(reply.payload.integer);
309         int rc;
310
311         reply.header = msg->header;
312         reply.header.flags &= ~VHOST_USER_FLAG_NEED_REPLY;
313         reply.header.flags |= VHOST_USER_FLAG_REPLY;
314         reply.header.size = sizeof(reply.payload.integer);
315
316         rc = full_sendmsg_fds(vu_dev->req_fd, &reply, size, NULL, 0);
317
318         if (rc)
319                 vu_err(vu_dev,
320                        "sending reply to slave request failed: %d (size %zu)\n",
321                        rc, size);
322 }
323
324 static irqreturn_t vu_req_interrupt(int irq, void *data)
325 {
326         struct virtio_uml_device *vu_dev = data;
327         int response = 1;
328         struct {
329                 struct vhost_user_msg msg;
330                 u8 extra_payload[512];
331         } msg;
332         int rc;
333
334         rc = vhost_user_recv_req(vu_dev, &msg.msg,
335                                  sizeof(msg.msg.payload) +
336                                  sizeof(msg.extra_payload));
337
338         if (rc)
339                 return IRQ_NONE;
340
341         switch (msg.msg.header.request) {
342         case VHOST_USER_SLAVE_CONFIG_CHANGE_MSG:
343                 virtio_config_changed(&vu_dev->vdev);
344                 response = 0;
345                 break;
346         case VHOST_USER_SLAVE_IOTLB_MSG:
347                 /* not supported - VIRTIO_F_IOMMU_PLATFORM */
348         case VHOST_USER_SLAVE_VRING_HOST_NOTIFIER_MSG:
349                 /* not supported - VHOST_USER_PROTOCOL_F_HOST_NOTIFIER */
350         default:
351                 vu_err(vu_dev, "unexpected slave request %d\n",
352                        msg.msg.header.request);
353         }
354
355         if (msg.msg.header.flags & VHOST_USER_FLAG_NEED_REPLY)
356                 vhost_user_reply(vu_dev, &msg.msg, response);
357
358         return IRQ_HANDLED;
359 }
360
361 static int vhost_user_init_slave_req(struct virtio_uml_device *vu_dev)
362 {
363         int rc, req_fds[2];
364
365         /* Use a pipe for slave req fd, SIGIO is not supported for eventfd */
366         rc = os_pipe(req_fds, true, true);
367         if (rc < 0)
368                 return rc;
369         vu_dev->req_fd = req_fds[0];
370
371         rc = um_request_irq(VIRTIO_IRQ, vu_dev->req_fd, IRQ_READ,
372                             vu_req_interrupt, IRQF_SHARED,
373                             vu_dev->pdev->name, vu_dev);
374         if (rc)
375                 goto err_close;
376
377         rc = vhost_user_send_no_payload_fd(vu_dev, VHOST_USER_SET_SLAVE_REQ_FD,
378                                            req_fds[1]);
379         if (rc)
380                 goto err_free_irq;
381
382         goto out;
383
384 err_free_irq:
385         um_free_irq(VIRTIO_IRQ, vu_dev);
386 err_close:
387         os_close_file(req_fds[0]);
388 out:
389         /* Close unused write end of request fds */
390         os_close_file(req_fds[1]);
391         return rc;
392 }
393
394 static int vhost_user_init(struct virtio_uml_device *vu_dev)
395 {
396         int rc = vhost_user_set_owner(vu_dev);
397
398         if (rc)
399                 return rc;
400         rc = vhost_user_get_features(vu_dev, &vu_dev->features);
401         if (rc)
402                 return rc;
403
404         if (vu_dev->features & BIT_ULL(VHOST_USER_F_PROTOCOL_FEATURES)) {
405                 rc = vhost_user_get_protocol_features(vu_dev,
406                                 &vu_dev->protocol_features);
407                 if (rc)
408                         return rc;
409                 vu_dev->protocol_features &= VHOST_USER_SUPPORTED_PROTOCOL_F;
410                 rc = vhost_user_set_protocol_features(vu_dev,
411                                 vu_dev->protocol_features);
412                 if (rc)
413                         return rc;
414         }
415
416         if (vu_dev->protocol_features &
417                         BIT_ULL(VHOST_USER_PROTOCOL_F_SLAVE_REQ)) {
418                 rc = vhost_user_init_slave_req(vu_dev);
419                 if (rc)
420                         return rc;
421         }
422
423         return 0;
424 }
425
426 static void vhost_user_get_config(struct virtio_uml_device *vu_dev,
427                                   u32 offset, void *buf, u32 len)
428 {
429         u32 cfg_size = offset + len;
430         struct vhost_user_msg *msg;
431         size_t payload_size = sizeof(msg->payload.config) + cfg_size;
432         size_t msg_size = sizeof(msg->header) + payload_size;
433         int rc;
434
435         if (!(vu_dev->protocol_features &
436               BIT_ULL(VHOST_USER_PROTOCOL_F_CONFIG)))
437                 return;
438
439         msg = kzalloc(msg_size, GFP_KERNEL);
440         if (!msg)
441                 return;
442         msg->header.request = VHOST_USER_GET_CONFIG;
443         msg->header.size = payload_size;
444         msg->payload.config.offset = 0;
445         msg->payload.config.size = cfg_size;
446
447         rc = vhost_user_send(vu_dev, true, msg, NULL, 0);
448         if (rc) {
449                 vu_err(vu_dev, "sending VHOST_USER_GET_CONFIG failed: %d\n",
450                        rc);
451                 goto free;
452         }
453
454         rc = vhost_user_recv_resp(vu_dev, msg, msg_size);
455         if (rc) {
456                 vu_err(vu_dev,
457                        "receiving VHOST_USER_GET_CONFIG response failed: %d\n",
458                        rc);
459                 goto free;
460         }
461
462         if (msg->header.size != payload_size ||
463             msg->payload.config.size != cfg_size) {
464                 rc = -EPROTO;
465                 vu_err(vu_dev,
466                        "Invalid VHOST_USER_GET_CONFIG sizes (payload %d expected %zu, config %u expected %u)\n",
467                        msg->header.size, payload_size,
468                        msg->payload.config.size, cfg_size);
469                 goto free;
470         }
471         memcpy(buf, msg->payload.config.payload + offset, len);
472
473 free:
474         kfree(msg);
475 }
476
477 static void vhost_user_set_config(struct virtio_uml_device *vu_dev,
478                                   u32 offset, const void *buf, u32 len)
479 {
480         struct vhost_user_msg *msg;
481         size_t payload_size = sizeof(msg->payload.config) + len;
482         size_t msg_size = sizeof(msg->header) + payload_size;
483         int rc;
484
485         if (!(vu_dev->protocol_features &
486               BIT_ULL(VHOST_USER_PROTOCOL_F_CONFIG)))
487                 return;
488
489         msg = kzalloc(msg_size, GFP_KERNEL);
490         if (!msg)
491                 return;
492         msg->header.request = VHOST_USER_SET_CONFIG;
493         msg->header.size = payload_size;
494         msg->payload.config.offset = offset;
495         msg->payload.config.size = len;
496         memcpy(msg->payload.config.payload, buf, len);
497
498         rc = vhost_user_send(vu_dev, false, msg, NULL, 0);
499         if (rc)
500                 vu_err(vu_dev, "sending VHOST_USER_SET_CONFIG failed: %d\n",
501                        rc);
502
503         kfree(msg);
504 }
505
506 static int vhost_user_init_mem_region(u64 addr, u64 size, int *fd_out,
507                                       struct vhost_user_mem_region *region_out)
508 {
509         unsigned long long mem_offset;
510         int rc = phys_mapping(addr, &mem_offset);
511
512         if (WARN(rc < 0, "phys_mapping of 0x%llx returned %d\n", addr, rc))
513                 return -EFAULT;
514         *fd_out = rc;
515         region_out->guest_addr = addr;
516         region_out->user_addr = addr;
517         region_out->size = size;
518         region_out->mmap_offset = mem_offset;
519
520         /* Ensure mapping is valid for the entire region */
521         rc = phys_mapping(addr + size - 1, &mem_offset);
522         if (WARN(rc != *fd_out, "phys_mapping of 0x%llx failed: %d != %d\n",
523                  addr + size - 1, rc, *fd_out))
524                 return -EFAULT;
525         return 0;
526 }
527
528 static int vhost_user_set_mem_table(struct virtio_uml_device *vu_dev)
529 {
530         struct vhost_user_msg msg = {
531                 .header.request = VHOST_USER_SET_MEM_TABLE,
532                 .header.size = sizeof(msg.payload.mem_regions),
533                 .payload.mem_regions.num = 1,
534         };
535         unsigned long reserved = uml_reserved - uml_physmem;
536         int fds[2];
537         int rc;
538
539         /*
540          * This is a bit tricky, see also the comment with setup_physmem().
541          *
542          * Essentially, setup_physmem() uses a file to mmap() our physmem,
543          * but the code and data we *already* have is omitted. To us, this
544          * is no difference, since they both become part of our address
545          * space and memory consumption. To somebody looking in from the
546          * outside, however, it is different because the part of our memory
547          * consumption that's already part of the binary (code/data) is not
548          * mapped from the file, so it's not visible to another mmap from
549          * the file descriptor.
550          *
551          * Thus, don't advertise this space to the vhost-user slave. This
552          * means that the slave will likely abort or similar when we give
553          * it an address from the hidden range, since it's not marked as
554          * a valid address, but at least that way we detect the issue and
555          * don't just have the slave read an all-zeroes buffer from the
556          * shared memory file, or write something there that we can never
557          * see (depending on the direction of the virtqueue traffic.)
558          *
559          * Since we usually don't want to use .text for virtio buffers,
560          * this effectively means that you cannot use
561          *  1) global variables, which are in the .bss and not in the shm
562          *     file-backed memory
563          *  2) the stack in some processes, depending on where they have
564          *     their stack (or maybe only no interrupt stack?)
565          *
566          * The stack is already not typically valid for DMA, so this isn't
567          * much of a restriction, but global variables might be encountered.
568          *
569          * It might be possible to fix it by copying around the data that's
570          * between bss_start and where we map the file now, but it's not
571          * something that you typically encounter with virtio drivers, so
572          * it didn't seem worthwhile.
573          */
574         rc = vhost_user_init_mem_region(reserved, physmem_size - reserved,
575                                         &fds[0],
576                                         &msg.payload.mem_regions.regions[0]);
577
578         if (rc < 0)
579                 return rc;
580         if (highmem) {
581                 msg.payload.mem_regions.num++;
582                 rc = vhost_user_init_mem_region(__pa(end_iomem), highmem,
583                                 &fds[1], &msg.payload.mem_regions.regions[1]);
584                 if (rc < 0)
585                         return rc;
586         }
587
588         return vhost_user_send(vu_dev, false, &msg, fds,
589                                msg.payload.mem_regions.num);
590 }
591
592 static int vhost_user_set_vring_state(struct virtio_uml_device *vu_dev,
593                                       u32 request, u32 index, u32 num)
594 {
595         struct vhost_user_msg msg = {
596                 .header.request = request,
597                 .header.size = sizeof(msg.payload.vring_state),
598                 .payload.vring_state.index = index,
599                 .payload.vring_state.num = num,
600         };
601
602         return vhost_user_send(vu_dev, false, &msg, NULL, 0);
603 }
604
605 static int vhost_user_set_vring_num(struct virtio_uml_device *vu_dev,
606                                     u32 index, u32 num)
607 {
608         return vhost_user_set_vring_state(vu_dev, VHOST_USER_SET_VRING_NUM,
609                                           index, num);
610 }
611
612 static int vhost_user_set_vring_base(struct virtio_uml_device *vu_dev,
613                                      u32 index, u32 offset)
614 {
615         return vhost_user_set_vring_state(vu_dev, VHOST_USER_SET_VRING_BASE,
616                                           index, offset);
617 }
618
619 static int vhost_user_set_vring_addr(struct virtio_uml_device *vu_dev,
620                                      u32 index, u64 desc, u64 used, u64 avail,
621                                      u64 log)
622 {
623         struct vhost_user_msg msg = {
624                 .header.request = VHOST_USER_SET_VRING_ADDR,
625                 .header.size = sizeof(msg.payload.vring_addr),
626                 .payload.vring_addr.index = index,
627                 .payload.vring_addr.desc = desc,
628                 .payload.vring_addr.used = used,
629                 .payload.vring_addr.avail = avail,
630                 .payload.vring_addr.log = log,
631         };
632
633         return vhost_user_send(vu_dev, false, &msg, NULL, 0);
634 }
635
636 static int vhost_user_set_vring_fd(struct virtio_uml_device *vu_dev,
637                                    u32 request, int index, int fd)
638 {
639         struct vhost_user_msg msg = {
640                 .header.request = request,
641                 .header.size = sizeof(msg.payload.integer),
642                 .payload.integer = index,
643         };
644
645         if (index & ~VHOST_USER_VRING_INDEX_MASK)
646                 return -EINVAL;
647         if (fd < 0) {
648                 msg.payload.integer |= VHOST_USER_VRING_POLL_MASK;
649                 return vhost_user_send(vu_dev, false, &msg, NULL, 0);
650         }
651         return vhost_user_send(vu_dev, false, &msg, &fd, 1);
652 }
653
654 static int vhost_user_set_vring_call(struct virtio_uml_device *vu_dev,
655                                      int index, int fd)
656 {
657         return vhost_user_set_vring_fd(vu_dev, VHOST_USER_SET_VRING_CALL,
658                                        index, fd);
659 }
660
661 static int vhost_user_set_vring_kick(struct virtio_uml_device *vu_dev,
662                                      int index, int fd)
663 {
664         return vhost_user_set_vring_fd(vu_dev, VHOST_USER_SET_VRING_KICK,
665                                        index, fd);
666 }
667
668 static int vhost_user_set_vring_enable(struct virtio_uml_device *vu_dev,
669                                        u32 index, bool enable)
670 {
671         if (!(vu_dev->features & BIT_ULL(VHOST_USER_F_PROTOCOL_FEATURES)))
672                 return 0;
673
674         return vhost_user_set_vring_state(vu_dev, VHOST_USER_SET_VRING_ENABLE,
675                                           index, enable);
676 }
677
678
679 /* Virtio interface */
680
681 static bool vu_notify(struct virtqueue *vq)
682 {
683         struct virtio_uml_vq_info *info = vq->priv;
684         const uint64_t n = 1;
685         int rc;
686
687         do {
688                 rc = os_write_file(info->kick_fd, &n, sizeof(n));
689         } while (rc == -EINTR);
690         return !WARN(rc != sizeof(n), "write returned %d\n", rc);
691 }
692
693 static irqreturn_t vu_interrupt(int irq, void *opaque)
694 {
695         struct virtqueue *vq = opaque;
696         struct virtio_uml_vq_info *info = vq->priv;
697         uint64_t n;
698         int rc;
699         irqreturn_t ret = IRQ_NONE;
700
701         do {
702                 rc = os_read_file(info->call_fd, &n, sizeof(n));
703                 if (rc == sizeof(n))
704                         ret |= vring_interrupt(irq, vq);
705         } while (rc == sizeof(n) || rc == -EINTR);
706         WARN(rc != -EAGAIN, "read returned %d\n", rc);
707         return ret;
708 }
709
710
711 static void vu_get(struct virtio_device *vdev, unsigned offset,
712                    void *buf, unsigned len)
713 {
714         struct virtio_uml_device *vu_dev = to_virtio_uml_device(vdev);
715
716         vhost_user_get_config(vu_dev, offset, buf, len);
717 }
718
719 static void vu_set(struct virtio_device *vdev, unsigned offset,
720                    const void *buf, unsigned len)
721 {
722         struct virtio_uml_device *vu_dev = to_virtio_uml_device(vdev);
723
724         vhost_user_set_config(vu_dev, offset, buf, len);
725 }
726
727 static u8 vu_get_status(struct virtio_device *vdev)
728 {
729         struct virtio_uml_device *vu_dev = to_virtio_uml_device(vdev);
730
731         return vu_dev->status;
732 }
733
734 static void vu_set_status(struct virtio_device *vdev, u8 status)
735 {
736         struct virtio_uml_device *vu_dev = to_virtio_uml_device(vdev);
737
738         vu_dev->status = status;
739 }
740
741 static void vu_reset(struct virtio_device *vdev)
742 {
743         struct virtio_uml_device *vu_dev = to_virtio_uml_device(vdev);
744
745         vu_dev->status = 0;
746 }
747
748 static void vu_del_vq(struct virtqueue *vq)
749 {
750         struct virtio_uml_vq_info *info = vq->priv;
751
752         um_free_irq(VIRTIO_IRQ, vq);
753
754         os_close_file(info->call_fd);
755         os_close_file(info->kick_fd);
756
757         vring_del_virtqueue(vq);
758         kfree(info);
759 }
760
761 static void vu_del_vqs(struct virtio_device *vdev)
762 {
763         struct virtio_uml_device *vu_dev = to_virtio_uml_device(vdev);
764         struct virtqueue *vq, *n;
765         u64 features;
766
767         /* Note: reverse order as a workaround to a decoding bug in snabb */
768         list_for_each_entry_reverse(vq, &vdev->vqs, list)
769                 WARN_ON(vhost_user_set_vring_enable(vu_dev, vq->index, false));
770
771         /* Ensure previous messages have been processed */
772         WARN_ON(vhost_user_get_features(vu_dev, &features));
773
774         list_for_each_entry_safe(vq, n, &vdev->vqs, list)
775                 vu_del_vq(vq);
776 }
777
778 static int vu_setup_vq_call_fd(struct virtio_uml_device *vu_dev,
779                                struct virtqueue *vq)
780 {
781         struct virtio_uml_vq_info *info = vq->priv;
782         int call_fds[2];
783         int rc;
784
785         /* Use a pipe for call fd, since SIGIO is not supported for eventfd */
786         rc = os_pipe(call_fds, true, true);
787         if (rc < 0)
788                 return rc;
789
790         info->call_fd = call_fds[0];
791         rc = um_request_irq(VIRTIO_IRQ, info->call_fd, IRQ_READ,
792                             vu_interrupt, IRQF_SHARED, info->name, vq);
793         if (rc)
794                 goto close_both;
795
796         rc = vhost_user_set_vring_call(vu_dev, vq->index, call_fds[1]);
797         if (rc)
798                 goto release_irq;
799
800         goto out;
801
802 release_irq:
803         um_free_irq(VIRTIO_IRQ, vq);
804 close_both:
805         os_close_file(call_fds[0]);
806 out:
807         /* Close (unused) write end of call fds */
808         os_close_file(call_fds[1]);
809
810         return rc;
811 }
812
813 static struct virtqueue *vu_setup_vq(struct virtio_device *vdev,
814                                      unsigned index, vq_callback_t *callback,
815                                      const char *name, bool ctx)
816 {
817         struct virtio_uml_device *vu_dev = to_virtio_uml_device(vdev);
818         struct platform_device *pdev = vu_dev->pdev;
819         struct virtio_uml_vq_info *info;
820         struct virtqueue *vq;
821         int num = MAX_SUPPORTED_QUEUE_SIZE;
822         int rc;
823
824         info = kzalloc(sizeof(*info), GFP_KERNEL);
825         if (!info) {
826                 rc = -ENOMEM;
827                 goto error_kzalloc;
828         }
829         snprintf(info->name, sizeof(info->name), "%s.%d-%s", pdev->name,
830                  pdev->id, name);
831
832         vq = vring_create_virtqueue(index, num, PAGE_SIZE, vdev, true, true,
833                                     ctx, vu_notify, callback, info->name);
834         if (!vq) {
835                 rc = -ENOMEM;
836                 goto error_create;
837         }
838         vq->priv = info;
839         num = virtqueue_get_vring_size(vq);
840
841         rc = os_eventfd(0, 0);
842         if (rc < 0)
843                 goto error_kick;
844         info->kick_fd = rc;
845
846         rc = vu_setup_vq_call_fd(vu_dev, vq);
847         if (rc)
848                 goto error_call;
849
850         rc = vhost_user_set_vring_num(vu_dev, index, num);
851         if (rc)
852                 goto error_setup;
853
854         rc = vhost_user_set_vring_base(vu_dev, index, 0);
855         if (rc)
856                 goto error_setup;
857
858         rc = vhost_user_set_vring_addr(vu_dev, index,
859                                        virtqueue_get_desc_addr(vq),
860                                        virtqueue_get_used_addr(vq),
861                                        virtqueue_get_avail_addr(vq),
862                                        (u64) -1);
863         if (rc)
864                 goto error_setup;
865
866         return vq;
867
868 error_setup:
869         um_free_irq(VIRTIO_IRQ, vq);
870         os_close_file(info->call_fd);
871 error_call:
872         os_close_file(info->kick_fd);
873 error_kick:
874         vring_del_virtqueue(vq);
875 error_create:
876         kfree(info);
877 error_kzalloc:
878         return ERR_PTR(rc);
879 }
880
881 static int vu_find_vqs(struct virtio_device *vdev, unsigned nvqs,
882                        struct virtqueue *vqs[], vq_callback_t *callbacks[],
883                        const char * const names[], const bool *ctx,
884                        struct irq_affinity *desc)
885 {
886         struct virtio_uml_device *vu_dev = to_virtio_uml_device(vdev);
887         int i, queue_idx = 0, rc;
888         struct virtqueue *vq;
889
890         rc = vhost_user_set_mem_table(vu_dev);
891         if (rc)
892                 return rc;
893
894         for (i = 0; i < nvqs; ++i) {
895                 if (!names[i]) {
896                         vqs[i] = NULL;
897                         continue;
898                 }
899
900                 vqs[i] = vu_setup_vq(vdev, queue_idx++, callbacks[i], names[i],
901                                      ctx ? ctx[i] : false);
902                 if (IS_ERR(vqs[i])) {
903                         rc = PTR_ERR(vqs[i]);
904                         goto error_setup;
905                 }
906         }
907
908         list_for_each_entry(vq, &vdev->vqs, list) {
909                 struct virtio_uml_vq_info *info = vq->priv;
910
911                 rc = vhost_user_set_vring_kick(vu_dev, vq->index,
912                                                info->kick_fd);
913                 if (rc)
914                         goto error_setup;
915
916                 rc = vhost_user_set_vring_enable(vu_dev, vq->index, true);
917                 if (rc)
918                         goto error_setup;
919         }
920
921         return 0;
922
923 error_setup:
924         vu_del_vqs(vdev);
925         return rc;
926 }
927
928 static u64 vu_get_features(struct virtio_device *vdev)
929 {
930         struct virtio_uml_device *vu_dev = to_virtio_uml_device(vdev);
931
932         return vu_dev->features;
933 }
934
935 static int vu_finalize_features(struct virtio_device *vdev)
936 {
937         struct virtio_uml_device *vu_dev = to_virtio_uml_device(vdev);
938         u64 supported = vdev->features & VHOST_USER_SUPPORTED_F;
939
940         vring_transport_features(vdev);
941         vu_dev->features = vdev->features | supported;
942
943         return vhost_user_set_features(vu_dev, vu_dev->features);
944 }
945
946 static const char *vu_bus_name(struct virtio_device *vdev)
947 {
948         struct virtio_uml_device *vu_dev = to_virtio_uml_device(vdev);
949
950         return vu_dev->pdev->name;
951 }
952
953 static const struct virtio_config_ops virtio_uml_config_ops = {
954         .get = vu_get,
955         .set = vu_set,
956         .get_status = vu_get_status,
957         .set_status = vu_set_status,
958         .reset = vu_reset,
959         .find_vqs = vu_find_vqs,
960         .del_vqs = vu_del_vqs,
961         .get_features = vu_get_features,
962         .finalize_features = vu_finalize_features,
963         .bus_name = vu_bus_name,
964 };
965
966 static void virtio_uml_release_dev(struct device *d)
967 {
968         struct virtio_device *vdev =
969                         container_of(d, struct virtio_device, dev);
970         struct virtio_uml_device *vu_dev = to_virtio_uml_device(vdev);
971
972         /* might not have been opened due to not negotiating the feature */
973         if (vu_dev->req_fd >= 0) {
974                 um_free_irq(VIRTIO_IRQ, vu_dev);
975                 os_close_file(vu_dev->req_fd);
976         }
977
978         os_close_file(vu_dev->sock);
979 }
980
981 /* Platform device */
982
983 static int virtio_uml_probe(struct platform_device *pdev)
984 {
985         struct virtio_uml_platform_data *pdata = pdev->dev.platform_data;
986         struct virtio_uml_device *vu_dev;
987         int rc;
988
989         if (!pdata)
990                 return -EINVAL;
991
992         vu_dev = devm_kzalloc(&pdev->dev, sizeof(*vu_dev), GFP_KERNEL);
993         if (!vu_dev)
994                 return -ENOMEM;
995
996         vu_dev->vdev.dev.parent = &pdev->dev;
997         vu_dev->vdev.dev.release = virtio_uml_release_dev;
998         vu_dev->vdev.config = &virtio_uml_config_ops;
999         vu_dev->vdev.id.device = pdata->virtio_device_id;
1000         vu_dev->vdev.id.vendor = VIRTIO_DEV_ANY_ID;
1001         vu_dev->pdev = pdev;
1002         vu_dev->req_fd = -1;
1003
1004         do {
1005                 rc = os_connect_socket(pdata->socket_path);
1006         } while (rc == -EINTR);
1007         if (rc < 0)
1008                 return rc;
1009         vu_dev->sock = rc;
1010
1011         rc = vhost_user_init(vu_dev);
1012         if (rc)
1013                 goto error_init;
1014
1015         platform_set_drvdata(pdev, vu_dev);
1016
1017         rc = register_virtio_device(&vu_dev->vdev);
1018         if (rc)
1019                 put_device(&vu_dev->vdev.dev);
1020         vu_dev->registered = 1;
1021         return rc;
1022
1023 error_init:
1024         os_close_file(vu_dev->sock);
1025         return rc;
1026 }
1027
1028 static int virtio_uml_remove(struct platform_device *pdev)
1029 {
1030         struct virtio_uml_device *vu_dev = platform_get_drvdata(pdev);
1031
1032         unregister_virtio_device(&vu_dev->vdev);
1033         return 0;
1034 }
1035
1036 /* Command line device list */
1037
/*
 * Release callback for the statically defined vu_cmdline_parent device.
 * There is nothing to free. The callback presumably exists only so the
 * driver core does not complain about a device without a release function.
 */
static void vu_cmdline_release_dev(struct device *d)
{
}
1041
1042 static struct device vu_cmdline_parent = {
1043         .init_name = "virtio-uml-cmdline",
1044         .release = vu_cmdline_release_dev,
1045 };
1046
1047 static bool vu_cmdline_parent_registered;
1048 static int vu_cmdline_id;
1049
1050 static int vu_unregister_cmdline_device(struct device *dev, void *data)
1051 {
1052         struct platform_device *pdev = to_platform_device(dev);
1053         struct virtio_uml_platform_data *pdata = pdev->dev.platform_data;
1054
1055         kfree(pdata->socket_path);
1056         platform_device_unregister(pdev);
1057         return 0;
1058 }
1059
1060 static void vu_conn_broken(struct work_struct *wk)
1061 {
1062         struct virtio_uml_platform_data *pdata;
1063
1064         pdata = container_of(wk, struct virtio_uml_platform_data, conn_broken_wk);
1065         vu_unregister_cmdline_device(&pdata->pdev->dev, NULL);
1066 }
1067
1068 static int vu_cmdline_set(const char *device, const struct kernel_param *kp)
1069 {
1070         const char *ids = strchr(device, ':');
1071         unsigned int virtio_device_id;
1072         int processed, consumed, err;
1073         char *socket_path;
1074         struct virtio_uml_platform_data pdata, *ppdata;
1075         struct platform_device *pdev;
1076
1077         if (!ids || ids == device)
1078                 return -EINVAL;
1079
1080         processed = sscanf(ids, ":%u%n:%d%n",
1081                            &virtio_device_id, &consumed,
1082                            &vu_cmdline_id, &consumed);
1083
1084         if (processed < 1 || ids[consumed])
1085                 return -EINVAL;
1086
1087         if (!vu_cmdline_parent_registered) {
1088                 err = device_register(&vu_cmdline_parent);
1089                 if (err) {
1090                         pr_err("Failed to register parent device!\n");
1091                         put_device(&vu_cmdline_parent);
1092                         return err;
1093                 }
1094                 vu_cmdline_parent_registered = true;
1095         }
1096
1097         socket_path = kmemdup_nul(device, ids - device, GFP_KERNEL);
1098         if (!socket_path)
1099                 return -ENOMEM;
1100
1101         pdata.virtio_device_id = (u32) virtio_device_id;
1102         pdata.socket_path = socket_path;
1103
1104         pr_info("Registering device virtio-uml.%d id=%d at %s\n",
1105                 vu_cmdline_id, virtio_device_id, socket_path);
1106
1107         pdev = platform_device_register_data(&vu_cmdline_parent, "virtio-uml",
1108                                              vu_cmdline_id++, &pdata,
1109                                              sizeof(pdata));
1110         err = PTR_ERR_OR_ZERO(pdev);
1111         if (err)
1112                 goto free;
1113
1114         ppdata = pdev->dev.platform_data;
1115         ppdata->pdev = pdev;
1116         INIT_WORK(&ppdata->conn_broken_wk, vu_conn_broken);
1117
1118         return 0;
1119
1120 free:
1121         kfree(socket_path);
1122         return err;
1123 }
1124
1125 static int vu_cmdline_get_device(struct device *dev, void *data)
1126 {
1127         struct platform_device *pdev = to_platform_device(dev);
1128         struct virtio_uml_platform_data *pdata = pdev->dev.platform_data;
1129         char *buffer = data;
1130         unsigned int len = strlen(buffer);
1131
1132         snprintf(buffer + len, PAGE_SIZE - len, "%s:%d:%d\n",
1133                  pdata->socket_path, pdata->virtio_device_id, pdev->id);
1134         return 0;
1135 }
1136
1137 static int vu_cmdline_get(char *buffer, const struct kernel_param *kp)
1138 {
1139         buffer[0] = '\0';
1140         if (vu_cmdline_parent_registered)
1141                 device_for_each_child(&vu_cmdline_parent, buffer,
1142                                       vu_cmdline_get_device);
1143         return strlen(buffer) + 1;
1144 }
1145
1146 static const struct kernel_param_ops vu_cmdline_param_ops = {
1147         .set = vu_cmdline_set,
1148         .get = vu_cmdline_get,
1149 };
1150
1151 device_param_cb(device, &vu_cmdline_param_ops, NULL, S_IRUSR);
1152 __uml_help(vu_cmdline_param_ops,
1153 "virtio_uml.device=<socket>:<virtio_id>[:<platform_id>]\n"
1154 "    Configure a virtio device over a vhost-user socket.\n"
1155 "    See virtio_ids.h for a list of possible virtio device id values.\n"
1156 "    Optionally use a specific platform_device id.\n\n"
1157 );
1158
1159
1160 static void vu_unregister_cmdline_devices(void)
1161 {
1162         if (vu_cmdline_parent_registered) {
1163                 device_for_each_child(&vu_cmdline_parent, NULL,
1164                                       vu_unregister_cmdline_device);
1165                 device_unregister(&vu_cmdline_parent);
1166                 vu_cmdline_parent_registered = false;
1167         }
1168 }
1169
1170 /* Platform driver */
1171
/* Device-tree match table: bind to nodes with compatible = "virtio,uml". */
static const struct of_device_id virtio_uml_match[] = {
	{ .compatible = "virtio,uml", },
	{ }	/* sentinel */
};
MODULE_DEVICE_TABLE(of, virtio_uml_match);
1177
1178 static struct platform_driver virtio_uml_driver = {
1179         .probe = virtio_uml_probe,
1180         .remove = virtio_uml_remove,
1181         .driver = {
1182                 .name = "virtio-uml",
1183                 .of_match_table = virtio_uml_match,
1184         },
1185 };
1186
1187 static int __init virtio_uml_init(void)
1188 {
1189         return platform_driver_register(&virtio_uml_driver);
1190 }
1191
/*
 * Exit path: unregister the platform driver first, which should unbind
 * (remove) any device it is bound to. Then delete the devices created from
 * the command line, along with their parent.
 */
static void __exit virtio_uml_exit(void)
{
	platform_driver_unregister(&virtio_uml_driver);
	vu_unregister_cmdline_devices();
}
1197
module_init(virtio_uml_init);
module_exit(virtio_uml_exit);
/*
 * Also register virtio_uml_exit() as a UML exit call, presumably so the
 * cleanup runs at UML shutdown when the driver is built in.
 */
__uml_exitcall(virtio_uml_exit);

MODULE_DESCRIPTION("UML driver for vhost-user virtio devices");
MODULE_LICENSE("GPL");