// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/skbuff.h>
#include <linux/scatterlist.h>

#include <net/sock.h>
#include <net/tcp.h>

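/* The sg array is used as a ring: coalescing into the element before
 * sg.end is only safe when elem_first_coalesce still lies inside the
 * occupied region, whether the ring is linear (end > start) or has
 * wrapped around (end < start).
 */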
static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
{
	if (msg->sg.end > msg->sg.start &&
	    elem_first_coalesce < msg->sg.end)
		return true;

	if (msg->sg.end < msg->sg.start &&
	    (elem_first_coalesce > msg->sg.start ||
	     elem_first_coalesce < msg->sg.end))
		return true;

	return false;
}

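/* Grow msg with page-frag memory until it holds len bytes in total,
 * charging sk for each chunk. If the last element ends exactly where the
 * new frag begins and coalescing is allowed, it is extended in place;
 * otherwise a fresh ring slot is consumed. Returns 0 on success, -ENOMEM
 * on allocation or accounting failure, and -ENOSPC once the ring is full.
 */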
int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
		 int elem_first_coalesce)
{
	struct page_frag *pfrag = sk_page_frag(sk);
	int ret = 0;

	len -= msg->sg.size;
	while (len > 0) {
		struct scatterlist *sge;
		u32 orig_offset;
		int use, i;

		if (!sk_page_frag_refill(sk, pfrag))
			return -ENOMEM;

		orig_offset = pfrag->offset;
		use = min_t(int, len, pfrag->size - orig_offset);
		if (!sk_wmem_schedule(sk, use))
			return -ENOMEM;

		i = msg->sg.end;
		sk_msg_iter_var_prev(i);
		sge = &msg->sg.data[i];

		if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
		    sg_page(sge) == pfrag->page &&
		    sge->offset + sge->length == orig_offset) {
			sge->length += use;
		} else {
			if (sk_msg_full(msg)) {
				ret = -ENOSPC;
				break;
			}

			sge = &msg->sg.data[msg->sg.end];
			sg_unmark_end(sge);
			sg_set_page(sge, pfrag->page, use, orig_offset);
			get_page(pfrag->page);
			sk_msg_iter_next(msg, end);
		}

		sk_mem_charge(sk, use);
		msg->sg.size += use;
		pfrag->offset += use;
		len -= use;
	}

	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_alloc);

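/* Append references for the byte range [off, off + len) of src into dst
 * via sk_msg_page_add(), charging sk for the cloned bytes. No data is
 * copied; both messages end up referencing the same pages. Returns
 * -ENOSPC if dst is full or src is exhausted before the range is covered.
 */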
int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
		 u32 off, u32 len)
{
	int i = src->sg.start;
	struct scatterlist *sge = sk_msg_elem(src, i);
	u32 sge_len, sge_off;

	if (sk_msg_full(dst))
		return -ENOSPC;

	while (off) {
		if (sge->length > off)
			break;
		off -= sge->length;
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && off)
			return -ENOSPC;
		sge = sk_msg_elem(src, i);
	}

	while (len) {
		sge_len = sge->length - off;
		sge_off = sge->offset + off;
		if (sge_len > len)
			sge_len = len;
		off = 0;
		len -= sge_len;
		sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off);
		sk_mem_charge(sk, sge_len);
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && len)
			return -ENOSPC;
		sge = sk_msg_elem(src, i);
	}

	return 0;
}
EXPORT_SYMBOL_GPL(sk_msg_clone);

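/* Uncharge up to bytes from the front of msg, zeroing the length and
 * offset of every fully consumed element and advancing sg.start past
 * them; a partially consumed element keeps its remainder.
 */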
void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
{
	int i = msg->sg.start;

	do {
		struct scatterlist *sge = sk_msg_elem(msg, i);

		if (bytes < sge->length) {
			sge->length -= bytes;
			sge->offset += bytes;
			sk_mem_uncharge(sk, bytes);
			break;
		}

		sk_mem_uncharge(sk, sge->length);
		bytes -= sge->length;
		sge->length = 0;
		sge->offset = 0;
		sk_msg_iter_var_next(i);
	} while (bytes && i != msg->sg.end);
	msg->sg.start = i;
}
EXPORT_SYMBOL_GPL(sk_msg_return_zero);

void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
{
	int i = msg->sg.start;

	do {
		struct scatterlist *sge = &msg->sg.data[i];
		int uncharge = (bytes < sge->length) ? bytes : sge->length;

		sk_mem_uncharge(sk, uncharge);
		bytes -= uncharge;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);
}
EXPORT_SYMBOL_GPL(sk_msg_return);

static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
			    bool charge)
{
	struct scatterlist *sge = sk_msg_elem(msg, i);
	u32 len = sge->length;

	if (charge)
		sk_mem_uncharge(sk, len);
	if (!msg->skb)
		put_page(sg_page(sge));
	memset(sge, 0, sizeof(*sge));
	return len;
}

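/* Release every element starting at ring index i until sg.size drops to
 * zero, optionally uncharging sk, then consume any attached skb and
 * reinitialize msg. Returns the number of bytes freed. Note that
 * sk_msg_free_elem() only puts page references when the data is not
 * backed by msg->skb.
 */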
static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
			 bool charge)
{
	struct scatterlist *sge = sk_msg_elem(msg, i);
	int freed = 0;

	while (msg->sg.size) {
		msg->sg.size -= sge->length;
		freed += sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, msg->sg.size);
		sge = sk_msg_elem(msg, i);
	}
	if (msg->skb)
		consume_skb(msg->skb);
	sk_msg_init(msg);
	return freed;
}

int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, false);
}
EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);

int sk_msg_free(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free);

static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
				  u32 bytes, bool charge)
{
	struct scatterlist *sge;
	u32 i = msg->sg.start;

	while (bytes) {
		sge = sk_msg_elem(msg, i);
		if (!sge->length)
			break;
		if (bytes < sge->length) {
			if (charge)
				sk_mem_uncharge(sk, bytes);
			sge->length -= bytes;
			sge->offset += bytes;
			msg->sg.size -= bytes;
			break;
		}

		msg->sg.size -= sge->length;
		bytes -= sge->length;
		sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, bytes);
	}
	msg->sg.start = i;
}

void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free_partial);

void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
				  u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, false);
}

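/* Shrink msg so it holds len bytes, freeing whole elements from the tail
 * and trimming the final partial one, with memory accounting adjusted to
 * match.
 */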
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
{
	int trim = msg->sg.size - len;
	u32 i = msg->sg.end;

	if (trim <= 0) {
		WARN_ON(trim < 0);
		return;
	}

	sk_msg_iter_var_prev(i);
	msg->sg.size = len;
	while (msg->sg.data[i].length &&
	       trim >= msg->sg.data[i].length) {
		trim -= msg->sg.data[i].length;
		sk_msg_free_elem(sk, msg, i, true);
		sk_msg_iter_var_prev(i);
		if (!trim)
			goto out;
	}

	msg->sg.data[i].length -= trim;
	sk_mem_uncharge(sk, trim);
out:
	/* If we trim data before the curr pointer, update copybreak and
	 * curr so that any future copy operations start at the new copy
	 * location. However, trimmed data that has not yet been used in
	 * a copy op does not require an update.
	 */
	if (msg->sg.curr >= i) {
		msg->sg.curr = i;
		msg->sg.copybreak = msg->sg.data[i].length;
	}
	sk_msg_iter_var_next(i);
	msg->sg.end = i;
}
EXPORT_SYMBOL_GPL(sk_msg_trim);

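/* Pin user pages from the iterator directly into msg's ring without
 * copying data, charging sk as elements are added. On failure the
 * iterator is reverted; trimming msg back is left to the caller.
 */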
int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      struct sk_msg *msg, u32 bytes)
{
	int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
	const int to_max_pages = MAX_MSG_FRAGS;
	struct page *pages[MAX_MSG_FRAGS];
	ssize_t orig, copied, use, offset;

	orig = msg->sg.size;
	while (bytes > 0) {
		i = 0;
		maxpages = to_max_pages - num_elems;
		if (maxpages == 0) {
			ret = -EFAULT;
			goto out;
		}

		copied = iov_iter_get_pages(from, pages, bytes, maxpages,
					    &offset);
		if (copied <= 0) {
			ret = -EFAULT;
			goto out;
		}

		iov_iter_advance(from, copied);
		bytes -= copied;
		msg->sg.size += copied;

		while (copied) {
			use = min_t(int, copied, PAGE_SIZE - offset);
			sg_set_page(&msg->sg.data[msg->sg.end],
				    pages[i], use, offset);
			sg_unmark_end(&msg->sg.data[msg->sg.end]);
			sk_mem_charge(sk, use);

			offset = 0;
			copied -= use;
			sk_msg_iter_next(msg, end);
			num_elems++;
			i++;
		}
		/* When zerocopy is mixed with sk_msg_*copy* operations we
		 * may have a copybreak set; in that case clear it and
		 * prefer the zerocopy remainder when possible.
		 */
		msg->sg.copybreak = 0;
		msg->sg.curr = msg->sg.end;
	}
out:
	/* Revert iov_iter updates; msg will need to use 'trim' later if
	 * it also needs to be cleared.
	 */
	if (ret)
		iov_iter_revert(from, msg->sg.size - orig);
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);

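/* Copy bytes from the iterator into memory already allocated in msg,
 * resuming at sg.curr/copybreak. The nocache copy variant is used when
 * the route advertises NETIF_F_NOCACHE_COPY. Returns -EFAULT on a short
 * copy, or -ENOSPC if msg runs out of elements before copying anything.
 */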
int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     struct sk_msg *msg, u32 bytes)
{
	int ret = -ENOSPC, i = msg->sg.curr;
	struct scatterlist *sge;
	u32 copy, buf_size;
	void *to;

	do {
		sge = sk_msg_elem(msg, i);
		/* This is possible if a trim operation shrunk the buffer */
		if (msg->sg.copybreak >= sge->length) {
			msg->sg.copybreak = 0;
			sk_msg_iter_var_next(i);
			if (i == msg->sg.end)
				break;
			sge = sk_msg_elem(msg, i);
		}

		buf_size = sge->length - msg->sg.copybreak;
		copy = (buf_size > bytes) ? bytes : buf_size;
		to = sg_virt(sge) + msg->sg.copybreak;
		msg->sg.copybreak += copy;
		if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
			ret = copy_from_iter_nocache(to, copy, from);
		else
			ret = copy_from_iter(to, copy, from);
		if (ret != copy) {
			ret = -EFAULT;
			goto out;
		}
		bytes -= copy;
		if (!bytes)
			break;
		msg->sg.copybreak = 0;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);
out:
	msg->sg.curr = i;
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);

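/* Wrap an skb in a freshly allocated sk_msg, mapping its data into the
 * sg ring with skb_to_sgvec(), charge the socket, queue the message on
 * the psock ingress list and notify the socket via sk_data_ready.
 * Returns the number of bytes queued, or -EAGAIN when allocation or
 * rmem scheduling fails.
 */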
static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
{
	struct sock *sk = psock->sk;
	int copied = 0, num_sge;
	struct sk_msg *msg;

	msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
	if (unlikely(!msg))
		return -EAGAIN;
	if (!sk_rmem_schedule(sk, skb, skb->len)) {
		kfree(msg);
		return -EAGAIN;
	}

	sk_msg_init(msg);
	num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
	if (unlikely(num_sge < 0)) {
		kfree(msg);
		return num_sge;
	}

	sk_mem_charge(sk, skb->len);
	copied = skb->len;
	msg->sg.start = 0;
	msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge;
	msg->skb = skb;

	sk_psock_queue_msg(psock, msg);
	sk->sk_data_ready(sk);
	return copied;
}

static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
			       u32 off, u32 len, bool ingress)
{
	if (ingress)
		return sk_psock_skb_ingress(psock, skb);
	else
		return skb_send_sock_locked(psock->sk, skb, off, len);
}

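/* Workqueue callback that drains psock->ingress_skb under the socket
 * lock, first resuming any skb parked in work_state by a previous
 * -EAGAIN. Each skb is re-ingressed locally or transmitted depending on
 * its BPF ingress mark; hard errors report the failure, disable TX and
 * drop the skb.
 */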
static void sk_psock_backlog(struct work_struct *work)
{
	struct sk_psock *psock = container_of(work, struct sk_psock, work);
	struct sk_psock_work_state *state = &psock->work_state;
	struct sk_buff *skb;
	bool ingress;
	u32 len, off;
	int ret;

	/* Lock sock to avoid losing sk_socket during loop. */
	lock_sock(psock->sk);
	if (state->skb) {
		skb = state->skb;
		len = state->len;
		off = state->off;
		state->skb = NULL;
		goto start;
	}

	while ((skb = skb_dequeue(&psock->ingress_skb))) {
		len = skb->len;
		off = 0;
start:
		ingress = tcp_skb_bpf_ingress(skb);
		do {
			ret = -EIO;
			if (likely(psock->sk->sk_socket))
				ret = sk_psock_handle_skb(psock, skb, off,
							  len, ingress);
			if (ret <= 0) {
				if (ret == -EAGAIN) {
					state->skb = skb;
					state->len = len;
					state->off = off;
					goto end;
				}
				/* Hard errors break pipe and stop xmit. */
				sk_psock_report_error(psock, ret ? -ret : EPIPE);
				sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
				kfree_skb(skb);
				goto end;
			}
			off += ret;
			len -= ret;
		} while (len);

		if (!ingress)
			kfree_skb(skb);
	}
end:
	release_sock(psock->sk);
}

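/* Allocate a psock for sk and wire it up: link list and lock, backlog
 * work, ingress queues, TX-enabled state and an initial reference. The
 * psock is published through sk_user_data, and the socket itself is held
 * until sk_psock_destroy_deferred() releases it.
 */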
struct sk_psock *sk_psock_init(struct sock *sk, int node)
{
	struct sk_psock *psock = kzalloc_node(sizeof(*psock),
					      GFP_ATOMIC | __GFP_NOWARN,
					      node);
	if (!psock)
		return NULL;

	psock->sk = sk;
	psock->eval = __SK_NONE;

	INIT_LIST_HEAD(&psock->link);
	spin_lock_init(&psock->link_lock);

	INIT_WORK(&psock->work, sk_psock_backlog);
	INIT_LIST_HEAD(&psock->ingress_msg);
	skb_queue_head_init(&psock->ingress_skb);

	sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
	refcount_set(&psock->refcnt, 1);

	rcu_assign_sk_user_data(sk, psock);
	sock_hold(sk);

	return psock;
}
EXPORT_SYMBOL_GPL(sk_psock_init);

struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
{
	struct sk_psock_link *link;

	spin_lock_bh(&psock->link_lock);
	link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
					list);
	if (link)
		list_del(&link->list);
	spin_unlock_bh(&psock->link_lock);
	return link;
}

void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
{
	struct sk_msg *msg, *tmp;

	list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
		list_del(&msg->list);
		sk_msg_free(psock->sk, msg);
		kfree(msg);
	}
}

static void sk_psock_zap_ingress(struct sk_psock *psock)
{
	__skb_queue_purge(&psock->ingress_skb);
	__sk_psock_purge_ingress_msg(psock);
}

static void sk_psock_link_destroy(struct sk_psock *psock)
{
	struct sk_psock_link *link, *tmp;

	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		list_del(&link->list);
		sk_psock_free_link(link);
	}
}

static void sk_psock_destroy_deferred(struct work_struct *gc)
{
	struct sk_psock *psock = container_of(gc, struct sk_psock, gc);

	/* No sk_callback_lock since already detached. */
	if (psock->parser.enabled)
		strp_done(&psock->parser.strp);

	cancel_work_sync(&psock->work);

	psock_progs_drop(&psock->progs);

	sk_psock_link_destroy(psock);
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);

	if (psock->sk_redir)
		sock_put(psock->sk_redir);
	sock_put(psock->sk);
	kfree(psock);
}

void sk_psock_destroy(struct rcu_head *rcu)
{
	struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);

	INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
	schedule_work(&psock->gc);
}
EXPORT_SYMBOL_GPL(sk_psock_destroy);

void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
{
	rcu_assign_sk_user_data(sk, NULL);
	sk_psock_cork_free(psock);
	sk_psock_restore_proto(sk, psock);

	write_lock_bh(&sk->sk_callback_lock);
	if (psock->progs.skb_parser)
		sk_psock_stop_strp(sk, psock);
	write_unlock_bh(&sk->sk_callback_lock);
	sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);

	call_rcu_sched(&psock->rcu, sk_psock_destroy);
}
EXPORT_SYMBOL_GPL(sk_psock_drop);

static int sk_psock_map_verd(int verdict, bool redir)
{
	switch (verdict) {
	case SK_PASS:
		return redir ? __SK_REDIRECT : __SK_PASS;
	case SK_DROP:
	default:
		break;
	}

	return __SK_DROP;
}

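/* Run the msg_parser BPF program over msg, defaulting to __SK_PASS when
 * none is attached, and map the raw verdict. On __SK_REDIRECT the socket
 * chosen by the program is pinned in psock->sk_redir; a missing redirect
 * target demotes the verdict to __SK_DROP.
 */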
int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
			 struct sk_msg *msg)
{
	struct bpf_prog *prog;
	int ret;

	preempt_disable();
	rcu_read_lock();
	prog = READ_ONCE(psock->progs.msg_parser);
	if (unlikely(!prog)) {
		ret = __SK_PASS;
		goto out;
	}

	sk_msg_compute_data_pointers(msg);
	msg->sk = sk;
	ret = BPF_PROG_RUN(prog, msg);
	ret = sk_psock_map_verd(ret, msg->sk_redir);
	psock->apply_bytes = msg->apply_bytes;
	if (ret == __SK_REDIRECT) {
		if (psock->sk_redir)
			sock_put(psock->sk_redir);
		psock->sk_redir = msg->sk_redir;
		if (!psock->sk_redir) {
			ret = __SK_DROP;
			goto out;
		}
		sock_hold(psock->sk_redir);
	}
out:
	rcu_read_unlock();
	preempt_enable();
	return ret;
}
EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);

static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
			    struct sk_buff *skb)
{
	int ret;

	skb->sk = psock->sk;
	bpf_compute_data_end_sk_skb(skb);
	preempt_disable();
	ret = BPF_PROG_RUN(prog, skb);
	preempt_enable();
	/* strparser clones the skb before handing it to an upper layer,
	 * meaning skb_orphan has been called. We NULL sk on the way out
	 * to ensure we don't trigger a BUG_ON() in skb/sk operations
	 * later and because we are not charging the memory of this skb
	 * to any socket yet.
	 */
	skb->sk = NULL;
	return ret;
}

static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
{
	struct sk_psock_parser *parser;

	parser = container_of(strp, struct sk_psock_parser, strp);
	return container_of(parser, struct sk_psock, parser);
}

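/* Act on a mapped verdict: for __SK_REDIRECT the skb is queued on the
 * target psock's ingress list and its backlog work is kicked, provided
 * the target is alive, TX-enabled and has room (writeable for egress,
 * within sk_rcvbuf for ingress). Everything else, including __SK_DROP,
 * frees the skb.
 */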
static void sk_psock_verdict_apply(struct sk_psock *psock,
				   struct sk_buff *skb, int verdict)
{
	struct sk_psock *psock_other;
	struct sock *sk_other;
	bool ingress;

	switch (verdict) {
	case __SK_REDIRECT:
		sk_other = tcp_skb_bpf_redirect_fetch(skb);
		if (unlikely(!sk_other))
			goto out_free;
		psock_other = sk_psock(sk_other);
		if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
		    !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED))
			goto out_free;
		ingress = tcp_skb_bpf_ingress(skb);
		if ((!ingress && sock_writeable(sk_other)) ||
		    (ingress &&
		     atomic_read(&sk_other->sk_rmem_alloc) <=
		     sk_other->sk_rcvbuf)) {
			if (!ingress)
				skb_set_owner_w(skb, sk_other);
			skb_queue_tail(&psock_other->ingress_skb, skb);
			schedule_work(&psock_other->work);
			break;
		}
		/* fall-through */
	case __SK_DROP:
		/* fall-through */
	default:
out_free:
		kfree_skb(skb);
	}
}

699
700 static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
701 {
702         struct sk_psock *psock = sk_psock_from_strp(strp);
703         struct bpf_prog *prog;
704         int ret = __SK_DROP;
705
706         rcu_read_lock();
707         prog = READ_ONCE(psock->progs.skb_verdict);
708         if (likely(prog)) {
709                 skb_orphan(skb);
710                 tcp_skb_bpf_redirect_clear(skb);
711                 ret = sk_psock_bpf_run(psock, prog, skb);
712                 ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
713         }
714         rcu_read_unlock();
715         sk_psock_verdict_apply(psock, skb, ret);
716 }
717
718 static int sk_psock_strp_read_done(struct strparser *strp, int err)
719 {
720         return err;
721 }
722
723 static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
724 {
725         struct sk_psock *psock = sk_psock_from_strp(strp);
726         struct bpf_prog *prog;
727         int ret = skb->len;
728
729         rcu_read_lock();
730         prog = READ_ONCE(psock->progs.skb_parser);
731         if (likely(prog))
732                 ret = sk_psock_bpf_run(psock, prog, skb);
733         rcu_read_unlock();
734         return ret;
735 }
736
737 /* Called with socket lock held. */
738 static void sk_psock_data_ready(struct sock *sk)
739 {
740         struct sk_psock *psock;
741
742         rcu_read_lock();
743         psock = sk_psock(sk);
744         if (likely(psock)) {
745                 write_lock_bh(&sk->sk_callback_lock);
746                 strp_data_ready(&psock->parser.strp);
747                 write_unlock_bh(&sk->sk_callback_lock);
748         }
749         rcu_read_unlock();
750 }
751
752 static void sk_psock_write_space(struct sock *sk)
753 {
754         struct sk_psock *psock;
755         void (*write_space)(struct sock *sk);
756
757         rcu_read_lock();
758         psock = sk_psock(sk);
759         if (likely(psock && sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)))
760                 schedule_work(&psock->work);
761         write_space = psock->saved_write_space;
762         rcu_read_unlock();
763         write_space(sk);
764 }
765
766 int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
767 {
768         static const struct strp_callbacks cb = {
769                 .rcv_msg        = sk_psock_strp_read,
770                 .read_sock_done = sk_psock_strp_read_done,
771                 .parse_msg      = sk_psock_strp_parse,
772         };
773
774         psock->parser.enabled = false;
775         return strp_init(&psock->parser.strp, sk, &cb);
776 }
777
778 void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
779 {
780         struct sk_psock_parser *parser = &psock->parser;
781
782         if (parser->enabled)
783                 return;
784
785         parser->saved_data_ready = sk->sk_data_ready;
786         sk->sk_data_ready = sk_psock_data_ready;
787         sk->sk_write_space = sk_psock_write_space;
788         parser->enabled = true;
789 }
790
791 void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
792 {
793         struct sk_psock_parser *parser = &psock->parser;
794
795         if (!parser->enabled)
796                 return;
797
798         sk->sk_data_ready = parser->saved_data_ready;
799         parser->saved_data_ready = NULL;
800         strp_stop(&parser->strp);
801         parser->enabled = false;
802 }