]> asedeno.scripts.mit.edu Git - linux.git/blob - net/vmw_vsock/virtio_transport_common.c
Merge branch 'pm-devfreq'
[linux.git] / net / vmw_vsock / virtio_transport_common.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * common code for virtio vsock
4  *
5  * Copyright (C) 2013-2015 Red Hat, Inc.
6  * Author: Asias He <asias@redhat.com>
7  *         Stefan Hajnoczi <stefanha@redhat.com>
8  */
9 #include <linux/spinlock.h>
10 #include <linux/module.h>
11 #include <linux/sched/signal.h>
12 #include <linux/ctype.h>
13 #include <linux/list.h>
14 #include <linux/virtio.h>
15 #include <linux/virtio_ids.h>
16 #include <linux/virtio_config.h>
17 #include <linux/virtio_vsock.h>
18 #include <uapi/linux/vsockmon.h>
19
20 #include <net/sock.h>
21 #include <net/af_vsock.h>
22
23 #define CREATE_TRACE_POINTS
24 #include <trace/events/vsock_virtio_transport_common.h>
25
/* Grace period a connection gets to complete an orderly shutdown before
 * the close timeout work resets it.
 */
#define VSOCK_CLOSE_TIMEOUT	(HZ * 8)

/* Received payloads of at most this many bytes may be merged into the
 * buffer of the previously queued packet instead of being queued alone.
 */
#define GOOD_COPY_LEN		128
31
32 static const struct virtio_transport *
33 virtio_transport_get_ops(struct vsock_sock *vsk)
34 {
35         const struct vsock_transport *t = vsock_core_get_transport(vsk);
36
37         if (WARN_ON(!t))
38                 return NULL;
39
40         return container_of(t, struct virtio_transport, transport);
41 }
42
43 static struct virtio_vsock_pkt *
44 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
45                            size_t len,
46                            u32 src_cid,
47                            u32 src_port,
48                            u32 dst_cid,
49                            u32 dst_port)
50 {
51         struct virtio_vsock_pkt *pkt;
52         int err;
53
54         pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
55         if (!pkt)
56                 return NULL;
57
58         pkt->hdr.type           = cpu_to_le16(info->type);
59         pkt->hdr.op             = cpu_to_le16(info->op);
60         pkt->hdr.src_cid        = cpu_to_le64(src_cid);
61         pkt->hdr.dst_cid        = cpu_to_le64(dst_cid);
62         pkt->hdr.src_port       = cpu_to_le32(src_port);
63         pkt->hdr.dst_port       = cpu_to_le32(dst_port);
64         pkt->hdr.flags          = cpu_to_le32(info->flags);
65         pkt->len                = len;
66         pkt->hdr.len            = cpu_to_le32(len);
67         pkt->reply              = info->reply;
68         pkt->vsk                = info->vsk;
69
70         if (info->msg && len > 0) {
71                 pkt->buf = kmalloc(len, GFP_KERNEL);
72                 if (!pkt->buf)
73                         goto out_pkt;
74
75                 pkt->buf_len = len;
76
77                 err = memcpy_from_msg(pkt->buf, info->msg, len);
78                 if (err)
79                         goto out;
80         }
81
82         trace_virtio_transport_alloc_pkt(src_cid, src_port,
83                                          dst_cid, dst_port,
84                                          len,
85                                          info->type,
86                                          info->op,
87                                          info->flags);
88
89         return pkt;
90
91 out:
92         kfree(pkt->buf);
93 out_pkt:
94         kfree(pkt);
95         return NULL;
96 }
97
98 /* Packet capture */
99 static struct sk_buff *virtio_transport_build_skb(void *opaque)
100 {
101         struct virtio_vsock_pkt *pkt = opaque;
102         struct af_vsockmon_hdr *hdr;
103         struct sk_buff *skb;
104         size_t payload_len;
105         void *payload_buf;
106
107         /* A packet could be split to fit the RX buffer, so we can retrieve
108          * the payload length from the header and the buffer pointer taking
109          * care of the offset in the original packet.
110          */
111         payload_len = le32_to_cpu(pkt->hdr.len);
112         payload_buf = pkt->buf + pkt->off;
113
114         skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + payload_len,
115                         GFP_ATOMIC);
116         if (!skb)
117                 return NULL;
118
119         hdr = skb_put(skb, sizeof(*hdr));
120
121         /* pkt->hdr is little-endian so no need to byteswap here */
122         hdr->src_cid = pkt->hdr.src_cid;
123         hdr->src_port = pkt->hdr.src_port;
124         hdr->dst_cid = pkt->hdr.dst_cid;
125         hdr->dst_port = pkt->hdr.dst_port;
126
127         hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
128         hdr->len = cpu_to_le16(sizeof(pkt->hdr));
129         memset(hdr->reserved, 0, sizeof(hdr->reserved));
130
131         switch (le16_to_cpu(pkt->hdr.op)) {
132         case VIRTIO_VSOCK_OP_REQUEST:
133         case VIRTIO_VSOCK_OP_RESPONSE:
134                 hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
135                 break;
136         case VIRTIO_VSOCK_OP_RST:
137         case VIRTIO_VSOCK_OP_SHUTDOWN:
138                 hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
139                 break;
140         case VIRTIO_VSOCK_OP_RW:
141                 hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
142                 break;
143         case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
144         case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
145                 hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
146                 break;
147         default:
148                 hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
149                 break;
150         }
151
152         skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));
153
154         if (payload_len) {
155                 skb_put_data(skb, payload_buf, payload_len);
156         }
157
158         return skb;
159 }
160
161 void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
162 {
163         vsock_deliver_tap(virtio_transport_build_skb, pkt);
164 }
165 EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
166
167 /* This function can only be used on connecting/connected sockets,
168  * since a socket assigned to a transport is required.
169  *
170  * Do not use on listener sockets!
171  */
172 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
173                                           struct virtio_vsock_pkt_info *info)
174 {
175         u32 src_cid, src_port, dst_cid, dst_port;
176         const struct virtio_transport *t_ops;
177         struct virtio_vsock_sock *vvs;
178         struct virtio_vsock_pkt *pkt;
179         u32 pkt_len = info->pkt_len;
180
181         t_ops = virtio_transport_get_ops(vsk);
182         if (unlikely(!t_ops))
183                 return -EFAULT;
184
185         src_cid = t_ops->transport.get_local_cid();
186         src_port = vsk->local_addr.svm_port;
187         if (!info->remote_cid) {
188                 dst_cid = vsk->remote_addr.svm_cid;
189                 dst_port = vsk->remote_addr.svm_port;
190         } else {
191                 dst_cid = info->remote_cid;
192                 dst_port = info->remote_port;
193         }
194
195         vvs = vsk->trans;
196
197         /* we can send less than pkt_len bytes */
198         if (pkt_len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE)
199                 pkt_len = VIRTIO_VSOCK_MAX_PKT_BUF_SIZE;
200
201         /* virtio_transport_get_credit might return less than pkt_len credit */
202         pkt_len = virtio_transport_get_credit(vvs, pkt_len);
203
204         /* Do not send zero length OP_RW pkt */
205         if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
206                 return pkt_len;
207
208         pkt = virtio_transport_alloc_pkt(info, pkt_len,
209                                          src_cid, src_port,
210                                          dst_cid, dst_port);
211         if (!pkt) {
212                 virtio_transport_put_credit(vvs, pkt_len);
213                 return -ENOMEM;
214         }
215
216         virtio_transport_inc_tx_pkt(vvs, pkt);
217
218         return t_ops->send_pkt(pkt);
219 }
220
221 static bool virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
222                                         struct virtio_vsock_pkt *pkt)
223 {
224         if (vvs->rx_bytes + pkt->len > vvs->buf_alloc)
225                 return false;
226
227         vvs->rx_bytes += pkt->len;
228         return true;
229 }
230
231 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
232                                         struct virtio_vsock_pkt *pkt)
233 {
234         vvs->rx_bytes -= pkt->len;
235         vvs->fwd_cnt += pkt->len;
236 }
237
238 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
239 {
240         spin_lock_bh(&vvs->rx_lock);
241         vvs->last_fwd_cnt = vvs->fwd_cnt;
242         pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
243         pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
244         spin_unlock_bh(&vvs->rx_lock);
245 }
246 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
247
248 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
249 {
250         u32 ret;
251
252         spin_lock_bh(&vvs->tx_lock);
253         ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
254         if (ret > credit)
255                 ret = credit;
256         vvs->tx_cnt += ret;
257         spin_unlock_bh(&vvs->tx_lock);
258
259         return ret;
260 }
261 EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
262
263 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
264 {
265         spin_lock_bh(&vvs->tx_lock);
266         vvs->tx_cnt -= credit;
267         spin_unlock_bh(&vvs->tx_lock);
268 }
269 EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
270
271 static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
272                                                int type,
273                                                struct virtio_vsock_hdr *hdr)
274 {
275         struct virtio_vsock_pkt_info info = {
276                 .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
277                 .type = type,
278                 .vsk = vsk,
279         };
280
281         return virtio_transport_send_pkt_info(vsk, &info);
282 }
283
284 static ssize_t
285 virtio_transport_stream_do_peek(struct vsock_sock *vsk,
286                                 struct msghdr *msg,
287                                 size_t len)
288 {
289         struct virtio_vsock_sock *vvs = vsk->trans;
290         struct virtio_vsock_pkt *pkt;
291         size_t bytes, total = 0, off;
292         int err = -EFAULT;
293
294         spin_lock_bh(&vvs->rx_lock);
295
296         list_for_each_entry(pkt, &vvs->rx_queue, list) {
297                 off = pkt->off;
298
299                 if (total == len)
300                         break;
301
302                 while (total < len && off < pkt->len) {
303                         bytes = len - total;
304                         if (bytes > pkt->len - off)
305                                 bytes = pkt->len - off;
306
307                         /* sk_lock is held by caller so no one else can dequeue.
308                          * Unlock rx_lock since memcpy_to_msg() may sleep.
309                          */
310                         spin_unlock_bh(&vvs->rx_lock);
311
312                         err = memcpy_to_msg(msg, pkt->buf + off, bytes);
313                         if (err)
314                                 goto out;
315
316                         spin_lock_bh(&vvs->rx_lock);
317
318                         total += bytes;
319                         off += bytes;
320                 }
321         }
322
323         spin_unlock_bh(&vvs->rx_lock);
324
325         return total;
326
327 out:
328         if (total)
329                 err = total;
330         return err;
331 }
332
333 static ssize_t
334 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
335                                    struct msghdr *msg,
336                                    size_t len)
337 {
338         struct virtio_vsock_sock *vvs = vsk->trans;
339         struct virtio_vsock_pkt *pkt;
340         size_t bytes, total = 0;
341         u32 free_space;
342         int err = -EFAULT;
343
344         spin_lock_bh(&vvs->rx_lock);
345         while (total < len && !list_empty(&vvs->rx_queue)) {
346                 pkt = list_first_entry(&vvs->rx_queue,
347                                        struct virtio_vsock_pkt, list);
348
349                 bytes = len - total;
350                 if (bytes > pkt->len - pkt->off)
351                         bytes = pkt->len - pkt->off;
352
353                 /* sk_lock is held by caller so no one else can dequeue.
354                  * Unlock rx_lock since memcpy_to_msg() may sleep.
355                  */
356                 spin_unlock_bh(&vvs->rx_lock);
357
358                 err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
359                 if (err)
360                         goto out;
361
362                 spin_lock_bh(&vvs->rx_lock);
363
364                 total += bytes;
365                 pkt->off += bytes;
366                 if (pkt->off == pkt->len) {
367                         virtio_transport_dec_rx_pkt(vvs, pkt);
368                         list_del(&pkt->list);
369                         virtio_transport_free_pkt(pkt);
370                 }
371         }
372
373         free_space = vvs->buf_alloc - (vvs->fwd_cnt - vvs->last_fwd_cnt);
374
375         spin_unlock_bh(&vvs->rx_lock);
376
377         /* To reduce the number of credit update messages,
378          * don't update credits as long as lots of space is available.
379          * Note: the limit chosen here is arbitrary. Setting the limit
380          * too high causes extra messages. Too low causes transmitter
381          * stalls. As stalls are in theory more expensive than extra
382          * messages, we set the limit to a high value. TODO: experiment
383          * with different values.
384          */
385         if (free_space < VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
386                 virtio_transport_send_credit_update(vsk,
387                                                     VIRTIO_VSOCK_TYPE_STREAM,
388                                                     NULL);
389         }
390
391         return total;
392
393 out:
394         if (total)
395                 err = total;
396         return err;
397 }
398
399 ssize_t
400 virtio_transport_stream_dequeue(struct vsock_sock *vsk,
401                                 struct msghdr *msg,
402                                 size_t len, int flags)
403 {
404         if (flags & MSG_PEEK)
405                 return virtio_transport_stream_do_peek(vsk, msg, len);
406         else
407                 return virtio_transport_stream_do_dequeue(vsk, msg, len);
408 }
409 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
410
411 int
412 virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
413                                struct msghdr *msg,
414                                size_t len, int flags)
415 {
416         return -EOPNOTSUPP;
417 }
418 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
419
420 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
421 {
422         struct virtio_vsock_sock *vvs = vsk->trans;
423         s64 bytes;
424
425         spin_lock_bh(&vvs->rx_lock);
426         bytes = vvs->rx_bytes;
427         spin_unlock_bh(&vvs->rx_lock);
428
429         return bytes;
430 }
431 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
432
433 static s64 virtio_transport_has_space(struct vsock_sock *vsk)
434 {
435         struct virtio_vsock_sock *vvs = vsk->trans;
436         s64 bytes;
437
438         bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
439         if (bytes < 0)
440                 bytes = 0;
441
442         return bytes;
443 }
444
445 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
446 {
447         struct virtio_vsock_sock *vvs = vsk->trans;
448         s64 bytes;
449
450         spin_lock_bh(&vvs->tx_lock);
451         bytes = virtio_transport_has_space(vsk);
452         spin_unlock_bh(&vvs->tx_lock);
453
454         return bytes;
455 }
456 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
457
458 int virtio_transport_do_socket_init(struct vsock_sock *vsk,
459                                     struct vsock_sock *psk)
460 {
461         struct virtio_vsock_sock *vvs;
462
463         vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
464         if (!vvs)
465                 return -ENOMEM;
466
467         vsk->trans = vvs;
468         vvs->vsk = vsk;
469         if (psk && psk->trans) {
470                 struct virtio_vsock_sock *ptrans = psk->trans;
471
472                 vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
473         }
474
475         if (vsk->buffer_size > VIRTIO_VSOCK_MAX_BUF_SIZE)
476                 vsk->buffer_size = VIRTIO_VSOCK_MAX_BUF_SIZE;
477
478         vvs->buf_alloc = vsk->buffer_size;
479
480         spin_lock_init(&vvs->rx_lock);
481         spin_lock_init(&vvs->tx_lock);
482         INIT_LIST_HEAD(&vvs->rx_queue);
483
484         return 0;
485 }
486 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
487
488 /* sk_lock held by the caller */
489 void virtio_transport_notify_buffer_size(struct vsock_sock *vsk, u64 *val)
490 {
491         struct virtio_vsock_sock *vvs = vsk->trans;
492
493         if (*val > VIRTIO_VSOCK_MAX_BUF_SIZE)
494                 *val = VIRTIO_VSOCK_MAX_BUF_SIZE;
495
496         vvs->buf_alloc = *val;
497
498         virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
499                                             NULL);
500 }
501 EXPORT_SYMBOL_GPL(virtio_transport_notify_buffer_size);
502
503 int
504 virtio_transport_notify_poll_in(struct vsock_sock *vsk,
505                                 size_t target,
506                                 bool *data_ready_now)
507 {
508         if (vsock_stream_has_data(vsk))
509                 *data_ready_now = true;
510         else
511                 *data_ready_now = false;
512
513         return 0;
514 }
515 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
516
517 int
518 virtio_transport_notify_poll_out(struct vsock_sock *vsk,
519                                  size_t target,
520                                  bool *space_avail_now)
521 {
522         s64 free_space;
523
524         free_space = vsock_stream_has_space(vsk);
525         if (free_space > 0)
526                 *space_avail_now = true;
527         else if (free_space == 0)
528                 *space_avail_now = false;
529
530         return 0;
531 }
532 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
533
534 int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
535         size_t target, struct vsock_transport_recv_notify_data *data)
536 {
537         return 0;
538 }
539 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
540
541 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
542         size_t target, struct vsock_transport_recv_notify_data *data)
543 {
544         return 0;
545 }
546 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
547
548 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
549         size_t target, struct vsock_transport_recv_notify_data *data)
550 {
551         return 0;
552 }
553 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
554
555 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
556         size_t target, ssize_t copied, bool data_read,
557         struct vsock_transport_recv_notify_data *data)
558 {
559         return 0;
560 }
561 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
562
563 int virtio_transport_notify_send_init(struct vsock_sock *vsk,
564         struct vsock_transport_send_notify_data *data)
565 {
566         return 0;
567 }
568 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
569
570 int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
571         struct vsock_transport_send_notify_data *data)
572 {
573         return 0;
574 }
575 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
576
577 int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
578         struct vsock_transport_send_notify_data *data)
579 {
580         return 0;
581 }
582 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
583
584 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
585         ssize_t written, struct vsock_transport_send_notify_data *data)
586 {
587         return 0;
588 }
589 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
590
591 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
592 {
593         return vsk->buffer_size;
594 }
595 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
596
597 bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
598 {
599         return true;
600 }
601 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
602
603 bool virtio_transport_stream_allow(u32 cid, u32 port)
604 {
605         return true;
606 }
607 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
608
609 int virtio_transport_dgram_bind(struct vsock_sock *vsk,
610                                 struct sockaddr_vm *addr)
611 {
612         return -EOPNOTSUPP;
613 }
614 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
615
616 bool virtio_transport_dgram_allow(u32 cid, u32 port)
617 {
618         return false;
619 }
620 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
621
622 int virtio_transport_connect(struct vsock_sock *vsk)
623 {
624         struct virtio_vsock_pkt_info info = {
625                 .op = VIRTIO_VSOCK_OP_REQUEST,
626                 .type = VIRTIO_VSOCK_TYPE_STREAM,
627                 .vsk = vsk,
628         };
629
630         return virtio_transport_send_pkt_info(vsk, &info);
631 }
632 EXPORT_SYMBOL_GPL(virtio_transport_connect);
633
634 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
635 {
636         struct virtio_vsock_pkt_info info = {
637                 .op = VIRTIO_VSOCK_OP_SHUTDOWN,
638                 .type = VIRTIO_VSOCK_TYPE_STREAM,
639                 .flags = (mode & RCV_SHUTDOWN ?
640                           VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
641                          (mode & SEND_SHUTDOWN ?
642                           VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
643                 .vsk = vsk,
644         };
645
646         return virtio_transport_send_pkt_info(vsk, &info);
647 }
648 EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
649
650 int
651 virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
652                                struct sockaddr_vm *remote_addr,
653                                struct msghdr *msg,
654                                size_t dgram_len)
655 {
656         return -EOPNOTSUPP;
657 }
658 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
659
660 ssize_t
661 virtio_transport_stream_enqueue(struct vsock_sock *vsk,
662                                 struct msghdr *msg,
663                                 size_t len)
664 {
665         struct virtio_vsock_pkt_info info = {
666                 .op = VIRTIO_VSOCK_OP_RW,
667                 .type = VIRTIO_VSOCK_TYPE_STREAM,
668                 .msg = msg,
669                 .pkt_len = len,
670                 .vsk = vsk,
671         };
672
673         return virtio_transport_send_pkt_info(vsk, &info);
674 }
675 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
676
677 void virtio_transport_destruct(struct vsock_sock *vsk)
678 {
679         struct virtio_vsock_sock *vvs = vsk->trans;
680
681         kfree(vvs);
682 }
683 EXPORT_SYMBOL_GPL(virtio_transport_destruct);
684
685 static int virtio_transport_reset(struct vsock_sock *vsk,
686                                   struct virtio_vsock_pkt *pkt)
687 {
688         struct virtio_vsock_pkt_info info = {
689                 .op = VIRTIO_VSOCK_OP_RST,
690                 .type = VIRTIO_VSOCK_TYPE_STREAM,
691                 .reply = !!pkt,
692                 .vsk = vsk,
693         };
694
695         /* Send RST only if the original pkt is not a RST pkt */
696         if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
697                 return 0;
698
699         return virtio_transport_send_pkt_info(vsk, &info);
700 }
701
702 /* Normally packets are associated with a socket.  There may be no socket if an
703  * attempt was made to connect to a socket that does not exist.
704  */
705 static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
706                                           struct virtio_vsock_pkt *pkt)
707 {
708         struct virtio_vsock_pkt *reply;
709         struct virtio_vsock_pkt_info info = {
710                 .op = VIRTIO_VSOCK_OP_RST,
711                 .type = le16_to_cpu(pkt->hdr.type),
712                 .reply = true,
713         };
714
715         /* Send RST only if the original pkt is not a RST pkt */
716         if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
717                 return 0;
718
719         reply = virtio_transport_alloc_pkt(&info, 0,
720                                            le64_to_cpu(pkt->hdr.dst_cid),
721                                            le32_to_cpu(pkt->hdr.dst_port),
722                                            le64_to_cpu(pkt->hdr.src_cid),
723                                            le32_to_cpu(pkt->hdr.src_port));
724         if (!reply)
725                 return -ENOMEM;
726
727         if (!t) {
728                 virtio_transport_free_pkt(reply);
729                 return -ENOTCONN;
730         }
731
732         return t->send_pkt(reply);
733 }
734
735 static void virtio_transport_wait_close(struct sock *sk, long timeout)
736 {
737         if (timeout) {
738                 DEFINE_WAIT_FUNC(wait, woken_wake_function);
739
740                 add_wait_queue(sk_sleep(sk), &wait);
741
742                 do {
743                         if (sk_wait_event(sk, &timeout,
744                                           sock_flag(sk, SOCK_DONE), &wait))
745                                 break;
746                 } while (!signal_pending(current) && timeout);
747
748                 remove_wait_queue(sk_sleep(sk), &wait);
749         }
750 }
751
752 static void virtio_transport_do_close(struct vsock_sock *vsk,
753                                       bool cancel_timeout)
754 {
755         struct sock *sk = sk_vsock(vsk);
756
757         sock_set_flag(sk, SOCK_DONE);
758         vsk->peer_shutdown = SHUTDOWN_MASK;
759         if (vsock_stream_has_data(vsk) <= 0)
760                 sk->sk_state = TCP_CLOSING;
761         sk->sk_state_change(sk);
762
763         if (vsk->close_work_scheduled &&
764             (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
765                 vsk->close_work_scheduled = false;
766
767                 vsock_remove_sock(vsk);
768
769                 /* Release refcnt obtained when we scheduled the timeout */
770                 sock_put(sk);
771         }
772 }
773
774 static void virtio_transport_close_timeout(struct work_struct *work)
775 {
776         struct vsock_sock *vsk =
777                 container_of(work, struct vsock_sock, close_work.work);
778         struct sock *sk = sk_vsock(vsk);
779
780         sock_hold(sk);
781         lock_sock(sk);
782
783         if (!sock_flag(sk, SOCK_DONE)) {
784                 (void)virtio_transport_reset(vsk, NULL);
785
786                 virtio_transport_do_close(vsk, false);
787         }
788
789         vsk->close_work_scheduled = false;
790
791         release_sock(sk);
792         sock_put(sk);
793 }
794
795 /* User context, vsk->sk is locked */
796 static bool virtio_transport_close(struct vsock_sock *vsk)
797 {
798         struct sock *sk = &vsk->sk;
799
800         if (!(sk->sk_state == TCP_ESTABLISHED ||
801               sk->sk_state == TCP_CLOSING))
802                 return true;
803
804         /* Already received SHUTDOWN from peer, reply with RST */
805         if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
806                 (void)virtio_transport_reset(vsk, NULL);
807                 return true;
808         }
809
810         if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
811                 (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
812
813         if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
814                 virtio_transport_wait_close(sk, sk->sk_lingertime);
815
816         if (sock_flag(sk, SOCK_DONE)) {
817                 return true;
818         }
819
820         sock_hold(sk);
821         INIT_DELAYED_WORK(&vsk->close_work,
822                           virtio_transport_close_timeout);
823         vsk->close_work_scheduled = true;
824         schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
825         return false;
826 }
827
828 void virtio_transport_release(struct vsock_sock *vsk)
829 {
830         struct virtio_vsock_sock *vvs = vsk->trans;
831         struct virtio_vsock_pkt *pkt, *tmp;
832         struct sock *sk = &vsk->sk;
833         bool remove_sock = true;
834
835         lock_sock_nested(sk, SINGLE_DEPTH_NESTING);
836         if (sk->sk_type == SOCK_STREAM)
837                 remove_sock = virtio_transport_close(vsk);
838
839         list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) {
840                 list_del(&pkt->list);
841                 virtio_transport_free_pkt(pkt);
842         }
843         release_sock(sk);
844
845         if (remove_sock)
846                 vsock_remove_sock(vsk);
847 }
848 EXPORT_SYMBOL_GPL(virtio_transport_release);
849
850 static int
851 virtio_transport_recv_connecting(struct sock *sk,
852                                  struct virtio_vsock_pkt *pkt)
853 {
854         struct vsock_sock *vsk = vsock_sk(sk);
855         int err;
856         int skerr;
857
858         switch (le16_to_cpu(pkt->hdr.op)) {
859         case VIRTIO_VSOCK_OP_RESPONSE:
860                 sk->sk_state = TCP_ESTABLISHED;
861                 sk->sk_socket->state = SS_CONNECTED;
862                 vsock_insert_connected(vsk);
863                 sk->sk_state_change(sk);
864                 break;
865         case VIRTIO_VSOCK_OP_INVALID:
866                 break;
867         case VIRTIO_VSOCK_OP_RST:
868                 skerr = ECONNRESET;
869                 err = 0;
870                 goto destroy;
871         default:
872                 skerr = EPROTO;
873                 err = -EINVAL;
874                 goto destroy;
875         }
876         return 0;
877
878 destroy:
879         virtio_transport_reset(vsk, pkt);
880         sk->sk_state = TCP_CLOSE;
881         sk->sk_err = skerr;
882         sk->sk_error_report(sk);
883         return err;
884 }
885
886 static void
887 virtio_transport_recv_enqueue(struct vsock_sock *vsk,
888                               struct virtio_vsock_pkt *pkt)
889 {
890         struct virtio_vsock_sock *vvs = vsk->trans;
891         bool can_enqueue, free_pkt = false;
892
893         pkt->len = le32_to_cpu(pkt->hdr.len);
894         pkt->off = 0;
895
896         spin_lock_bh(&vvs->rx_lock);
897
898         can_enqueue = virtio_transport_inc_rx_pkt(vvs, pkt);
899         if (!can_enqueue) {
900                 free_pkt = true;
901                 goto out;
902         }
903
904         /* Try to copy small packets into the buffer of last packet queued,
905          * to avoid wasting memory queueing the entire buffer with a small
906          * payload.
907          */
908         if (pkt->len <= GOOD_COPY_LEN && !list_empty(&vvs->rx_queue)) {
909                 struct virtio_vsock_pkt *last_pkt;
910
911                 last_pkt = list_last_entry(&vvs->rx_queue,
912                                            struct virtio_vsock_pkt, list);
913
914                 /* If there is space in the last packet queued, we copy the
915                  * new packet in its buffer.
916                  */
917                 if (pkt->len <= last_pkt->buf_len - last_pkt->len) {
918                         memcpy(last_pkt->buf + last_pkt->len, pkt->buf,
919                                pkt->len);
920                         last_pkt->len += pkt->len;
921                         free_pkt = true;
922                         goto out;
923                 }
924         }
925
926         list_add_tail(&pkt->list, &vvs->rx_queue);
927
928 out:
929         spin_unlock_bh(&vvs->rx_lock);
930         if (free_pkt)
931                 virtio_transport_free_pkt(pkt);
932 }
933
934 static int
935 virtio_transport_recv_connected(struct sock *sk,
936                                 struct virtio_vsock_pkt *pkt)
937 {
938         struct vsock_sock *vsk = vsock_sk(sk);
939         int err = 0;
940
941         switch (le16_to_cpu(pkt->hdr.op)) {
942         case VIRTIO_VSOCK_OP_RW:
943                 virtio_transport_recv_enqueue(vsk, pkt);
944                 sk->sk_data_ready(sk);
945                 return err;
946         case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
947                 sk->sk_write_space(sk);
948                 break;
949         case VIRTIO_VSOCK_OP_SHUTDOWN:
950                 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
951                         vsk->peer_shutdown |= RCV_SHUTDOWN;
952                 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
953                         vsk->peer_shutdown |= SEND_SHUTDOWN;
954                 if (vsk->peer_shutdown == SHUTDOWN_MASK &&
955                     vsock_stream_has_data(vsk) <= 0 &&
956                     !sock_flag(sk, SOCK_DONE)) {
957                         (void)virtio_transport_reset(vsk, NULL);
958
959                         virtio_transport_do_close(vsk, true);
960                 }
961                 if (le32_to_cpu(pkt->hdr.flags))
962                         sk->sk_state_change(sk);
963                 break;
964         case VIRTIO_VSOCK_OP_RST:
965                 virtio_transport_do_close(vsk, true);
966                 break;
967         default:
968                 err = -EINVAL;
969                 break;
970         }
971
972         virtio_transport_free_pkt(pkt);
973         return err;
974 }
975
976 static void
977 virtio_transport_recv_disconnecting(struct sock *sk,
978                                     struct virtio_vsock_pkt *pkt)
979 {
980         struct vsock_sock *vsk = vsock_sk(sk);
981
982         if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
983                 virtio_transport_do_close(vsk, true);
984 }
985
986 static int
987 virtio_transport_send_response(struct vsock_sock *vsk,
988                                struct virtio_vsock_pkt *pkt)
989 {
990         struct virtio_vsock_pkt_info info = {
991                 .op = VIRTIO_VSOCK_OP_RESPONSE,
992                 .type = VIRTIO_VSOCK_TYPE_STREAM,
993                 .remote_cid = le64_to_cpu(pkt->hdr.src_cid),
994                 .remote_port = le32_to_cpu(pkt->hdr.src_port),
995                 .reply = true,
996                 .vsk = vsk,
997         };
998
999         return virtio_transport_send_pkt_info(vsk, &info);
1000 }
1001
1002 static bool virtio_transport_space_update(struct sock *sk,
1003                                           struct virtio_vsock_pkt *pkt)
1004 {
1005         struct vsock_sock *vsk = vsock_sk(sk);
1006         struct virtio_vsock_sock *vvs = vsk->trans;
1007         bool space_available;
1008
1009         /* Listener sockets are not associated with any transport, so we are
1010          * not able to take the state to see if there is space available in the
1011          * remote peer, but since they are only used to receive requests, we
1012          * can assume that there is always space available in the other peer.
1013          */
1014         if (!vvs)
1015                 return true;
1016
1017         /* buf_alloc and fwd_cnt is always included in the hdr */
1018         spin_lock_bh(&vvs->tx_lock);
1019         vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
1020         vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
1021         space_available = virtio_transport_has_space(vsk);
1022         spin_unlock_bh(&vvs->tx_lock);
1023         return space_available;
1024 }
1025
1026 /* Handle server socket */
1027 static int
1028 virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt,
1029                              struct virtio_transport *t)
1030 {
1031         struct vsock_sock *vsk = vsock_sk(sk);
1032         struct vsock_sock *vchild;
1033         struct sock *child;
1034         int ret;
1035
1036         if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
1037                 virtio_transport_reset_no_sock(t, pkt);
1038                 return -EINVAL;
1039         }
1040
1041         if (sk_acceptq_is_full(sk)) {
1042                 virtio_transport_reset_no_sock(t, pkt);
1043                 return -ENOMEM;
1044         }
1045
1046         child = vsock_create_connected(sk);
1047         if (!child) {
1048                 virtio_transport_reset_no_sock(t, pkt);
1049                 return -ENOMEM;
1050         }
1051
1052         sk_acceptq_added(sk);
1053
1054         lock_sock_nested(child, SINGLE_DEPTH_NESTING);
1055
1056         child->sk_state = TCP_ESTABLISHED;
1057
1058         vchild = vsock_sk(child);
1059         vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
1060                         le32_to_cpu(pkt->hdr.dst_port));
1061         vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
1062                         le32_to_cpu(pkt->hdr.src_port));
1063
1064         ret = vsock_assign_transport(vchild, vsk);
1065         /* Transport assigned (looking at remote_addr) must be the same
1066          * where we received the request.
1067          */
1068         if (ret || vchild->transport != &t->transport) {
1069                 release_sock(child);
1070                 virtio_transport_reset_no_sock(t, pkt);
1071                 sock_put(child);
1072                 return ret;
1073         }
1074
1075         if (virtio_transport_space_update(child, pkt))
1076                 child->sk_write_space(child);
1077
1078         vsock_insert_connected(vchild);
1079         vsock_enqueue_accept(sk, child);
1080         virtio_transport_send_response(vchild, pkt);
1081
1082         release_sock(child);
1083
1084         sk->sk_data_ready(sk);
1085         return 0;
1086 }
1087
1088 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
1089  * lock.
1090  */
1091 void virtio_transport_recv_pkt(struct virtio_transport *t,
1092                                struct virtio_vsock_pkt *pkt)
1093 {
1094         struct sockaddr_vm src, dst;
1095         struct vsock_sock *vsk;
1096         struct sock *sk;
1097         bool space_available;
1098
1099         vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
1100                         le32_to_cpu(pkt->hdr.src_port));
1101         vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
1102                         le32_to_cpu(pkt->hdr.dst_port));
1103
1104         trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
1105                                         dst.svm_cid, dst.svm_port,
1106                                         le32_to_cpu(pkt->hdr.len),
1107                                         le16_to_cpu(pkt->hdr.type),
1108                                         le16_to_cpu(pkt->hdr.op),
1109                                         le32_to_cpu(pkt->hdr.flags),
1110                                         le32_to_cpu(pkt->hdr.buf_alloc),
1111                                         le32_to_cpu(pkt->hdr.fwd_cnt));
1112
1113         if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
1114                 (void)virtio_transport_reset_no_sock(t, pkt);
1115                 goto free_pkt;
1116         }
1117
1118         /* The socket must be in connected or bound table
1119          * otherwise send reset back
1120          */
1121         sk = vsock_find_connected_socket(&src, &dst);
1122         if (!sk) {
1123                 sk = vsock_find_bound_socket(&dst);
1124                 if (!sk) {
1125                         (void)virtio_transport_reset_no_sock(t, pkt);
1126                         goto free_pkt;
1127                 }
1128         }
1129
1130         vsk = vsock_sk(sk);
1131
1132         space_available = virtio_transport_space_update(sk, pkt);
1133
1134         lock_sock(sk);
1135
1136         /* Update CID in case it has changed after a transport reset event */
1137         vsk->local_addr.svm_cid = dst.svm_cid;
1138
1139         if (space_available)
1140                 sk->sk_write_space(sk);
1141
1142         switch (sk->sk_state) {
1143         case TCP_LISTEN:
1144                 virtio_transport_recv_listen(sk, pkt, t);
1145                 virtio_transport_free_pkt(pkt);
1146                 break;
1147         case TCP_SYN_SENT:
1148                 virtio_transport_recv_connecting(sk, pkt);
1149                 virtio_transport_free_pkt(pkt);
1150                 break;
1151         case TCP_ESTABLISHED:
1152                 virtio_transport_recv_connected(sk, pkt);
1153                 break;
1154         case TCP_CLOSING:
1155                 virtio_transport_recv_disconnecting(sk, pkt);
1156                 virtio_transport_free_pkt(pkt);
1157                 break;
1158         default:
1159                 virtio_transport_free_pkt(pkt);
1160                 break;
1161         }
1162
1163         release_sock(sk);
1164
1165         /* Release refcnt obtained when we fetched this socket out of the
1166          * bound or connected list.
1167          */
1168         sock_put(sk);
1169         return;
1170
1171 free_pkt:
1172         virtio_transport_free_pkt(pkt);
1173 }
1174 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
1175
1176 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
1177 {
1178         kfree(pkt->buf);
1179         kfree(pkt);
1180 }
1181 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
1182
/* Module metadata: license, author and description strings. */
1183 MODULE_LICENSE("GPL v2");
1184 MODULE_AUTHOR("Asias He");
1185 MODULE_DESCRIPTION("common code for virtio vsock");