[linux.git] net/core/sock_map.c, blob at commit "bpf: move memory size checks to bpf_map_charge_init()"
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
3
4 #include <linux/bpf.h>
5 #include <linux/filter.h>
6 #include <linux/errno.h>
7 #include <linux/file.h>
8 #include <linux/net.h>
9 #include <linux/workqueue.h>
10 #include <linux/skmsg.h>
11 #include <linux/list.h>
12 #include <linux/jhash.h>
13
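/* A sockmap (BPF_MAP_TYPE_SOCKMAP) is a plain array of sock pointers indexed
 * by a u32 key. The lock serializes slot updates and deletions; lookups from
 * BPF programs run under RCU only.
 */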
14 struct bpf_stab {
15         struct bpf_map map;
16         struct sock **sks;
17         struct sk_psock_progs progs;
18         raw_spinlock_t lock;
19 };
20
21 #define SOCK_CREATE_FLAG_MASK                           \
22         (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
23
24 static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
25 {
26         struct bpf_stab *stab;
27         u64 cost;
28         int err;
29
30         if (!capable(CAP_NET_ADMIN))
31                 return ERR_PTR(-EPERM);
32         if (attr->max_entries == 0 ||
33             attr->key_size    != 4 ||
34             attr->value_size  != 4 ||
35             attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
36                 return ERR_PTR(-EINVAL);
37
38         stab = kzalloc(sizeof(*stab), GFP_USER);
39         if (!stab)
40                 return ERR_PTR(-ENOMEM);
41
42         bpf_map_init_from_attr(&stab->map, attr);
43         raw_spin_lock_init(&stab->lock);
44
45         /* Make sure page count doesn't overflow. */
46         cost = (u64) stab->map.max_entries * sizeof(struct sock *);
47         err = bpf_map_charge_init(&stab->map.memory, cost);
48         if (err)
49                 goto free_stab;
50
51         stab->sks = bpf_map_area_alloc(stab->map.max_entries *
52                                        sizeof(struct sock *),
53                                        stab->map.numa_node);
54         if (stab->sks)
55                 return &stab->map;
56         err = -ENOMEM;
57         bpf_map_charge_finish(&stab->map.memory);
58 free_stab:
59         kfree(stab);
60         return ERR_PTR(err);
61 }
62
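/* Called on the BPF_PROG_ATTACH path: resolve the target map fd and install
 * prog for the requested attach type (BPF_SK_MSG_VERDICT,
 * BPF_SK_SKB_STREAM_PARSER or BPF_SK_SKB_STREAM_VERDICT).
 */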
63 int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
64 {
65         u32 ufd = attr->target_fd;
66         struct bpf_map *map;
67         struct fd f;
68         int ret;
69
70         f = fdget(ufd);
71         map = __bpf_map_get(f);
72         if (IS_ERR(map))
73                 return PTR_ERR(map);
74         ret = sock_map_prog_update(map, prog, attr->attach_type);
75         fdput(f);
76         return ret;
77 }
78
79 static void sock_map_sk_acquire(struct sock *sk)
80         __acquires(&sk->sk_lock.slock)
81 {
82         lock_sock(sk);
83         preempt_disable();
84         rcu_read_lock();
85 }
86
87 static void sock_map_sk_release(struct sock *sk)
88         __releases(&sk->sk_lock.slock)
89 {
90         rcu_read_unlock();
91         preempt_enable();
92         release_sock(sk);
93 }
94
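/* Each psock carries a list of links back to the map slots (or hash table
 * elements) that currently hold its socket, so the socket can later be
 * removed from every map that still references it (see sk_psock_unlink() at
 * the end of this file).
 */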
95 static void sock_map_add_link(struct sk_psock *psock,
96                               struct sk_psock_link *link,
97                               struct bpf_map *map, void *link_raw)
98 {
99         link->link_raw = link_raw;
100         link->map = map;
101         spin_lock_bh(&psock->link_lock);
102         list_add_tail(&link->list, &psock->link);
103         spin_unlock_bh(&psock->link_lock);
104 }
105
106 static void sock_map_del_link(struct sock *sk,
107                               struct sk_psock *psock, void *link_raw)
108 {
109         struct sk_psock_link *link, *tmp;
110         bool strp_stop = false;
111
112         spin_lock_bh(&psock->link_lock);
113         list_for_each_entry_safe(link, tmp, &psock->link, list) {
114                 if (link->link_raw == link_raw) {
115                         struct bpf_map *map = link->map;
116                         struct bpf_stab *stab = container_of(map, struct bpf_stab,
117                                                              map);
118                         if (psock->parser.enabled && stab->progs.skb_parser)
119                                 strp_stop = true;
120                         list_del(&link->list);
121                         sk_psock_free_link(link);
122                 }
123         }
124         spin_unlock_bh(&psock->link_lock);
125         if (strp_stop) {
126                 write_lock_bh(&sk->sk_callback_lock);
127                 sk_psock_stop_strp(sk, psock);
128                 write_unlock_bh(&sk->sk_callback_lock);
129         }
130 }
131
132 static void sock_map_unref(struct sock *sk, void *link_raw)
133 {
134         struct sk_psock *psock = sk_psock(sk);
135
136         if (likely(psock)) {
137                 sock_map_del_link(sk, psock, link_raw);
138                 sk_psock_put(sk, psock);
139         }
140 }
141
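/* Attach the map's BPF programs to sk: take references on the configured
 * parser/verdict programs, create the socket's psock if it does not exist
 * yet, install the programs and switch the socket over to the tcp_bpf proto
 * callbacks (and the strparser, if skb programs are attached).
 */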
142 static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
143                          struct sock *sk)
144 {
145         struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
146         bool skb_progs, sk_psock_is_new = false;
147         struct sk_psock *psock;
148         int ret;
149
150         skb_verdict = READ_ONCE(progs->skb_verdict);
151         skb_parser = READ_ONCE(progs->skb_parser);
152         skb_progs = skb_parser && skb_verdict;
153         if (skb_progs) {
154                 skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
155                 if (IS_ERR(skb_verdict))
156                         return PTR_ERR(skb_verdict);
157                 skb_parser = bpf_prog_inc_not_zero(skb_parser);
158                 if (IS_ERR(skb_parser)) {
159                         bpf_prog_put(skb_verdict);
160                         return PTR_ERR(skb_parser);
161                 }
162         }
163
164         msg_parser = READ_ONCE(progs->msg_parser);
165         if (msg_parser) {
166                 msg_parser = bpf_prog_inc_not_zero(msg_parser);
167                 if (IS_ERR(msg_parser)) {
168                         ret = PTR_ERR(msg_parser);
169                         goto out;
170                 }
171         }
172
173         psock = sk_psock_get_checked(sk);
174         if (IS_ERR(psock)) {
175                 ret = PTR_ERR(psock);
176                 goto out_progs;
177         }
178
179         if (psock) {
180                 if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
181                     (skb_progs  && READ_ONCE(psock->progs.skb_parser))) {
182                         sk_psock_put(sk, psock);
183                         ret = -EBUSY;
184                         goto out_progs;
185                 }
186         } else {
187                 psock = sk_psock_init(sk, map->numa_node);
188                 if (!psock) {
189                         ret = -ENOMEM;
190                         goto out_progs;
191                 }
192                 sk_psock_is_new = true;
193         }
194
195         if (msg_parser)
196                 psock_set_prog(&psock->progs.msg_parser, msg_parser);
197         if (sk_psock_is_new) {
198                 ret = tcp_bpf_init(sk);
199                 if (ret < 0)
200                         goto out_drop;
201         } else {
202                 tcp_bpf_reinit(sk);
203         }
204
205         write_lock_bh(&sk->sk_callback_lock);
206         if (skb_progs && !psock->parser.enabled) {
207                 ret = sk_psock_init_strp(sk, psock);
208                 if (ret) {
209                         write_unlock_bh(&sk->sk_callback_lock);
210                         goto out_drop;
211                 }
212                 psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
213                 psock_set_prog(&psock->progs.skb_parser, skb_parser);
214                 sk_psock_start_strp(sk, psock);
215         }
216         write_unlock_bh(&sk->sk_callback_lock);
217         return 0;
218 out_drop:
219         sk_psock_put(sk, psock);
220 out_progs:
221         if (msg_parser)
222                 bpf_prog_put(msg_parser);
223 out:
224         if (skb_progs) {
225                 bpf_prog_put(skb_verdict);
226                 bpf_prog_put(skb_parser);
227         }
228         return ret;
229 }
230
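/* Map teardown: the lookup, update and delete paths above all run under
 * rcu_read_lock(), so synchronize_rcu() waits for any of them still in
 * flight before the entries are unlinked and the backing array is freed.
 */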
231 static void sock_map_free(struct bpf_map *map)
232 {
233         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
234         int i;
235
236         synchronize_rcu();
237         rcu_read_lock();
238         raw_spin_lock_bh(&stab->lock);
239         for (i = 0; i < stab->map.max_entries; i++) {
240                 struct sock **psk = &stab->sks[i];
241                 struct sock *sk;
242
243                 sk = xchg(psk, NULL);
244                 if (sk)
245                         sock_map_unref(sk, psk);
246         }
247         raw_spin_unlock_bh(&stab->lock);
248         rcu_read_unlock();
249
250         bpf_map_area_free(stab->sks);
251         kfree(stab);
252 }
253
254 static void sock_map_release_progs(struct bpf_map *map)
255 {
256         psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
257 }
258
259 static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
260 {
261         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
262
263         WARN_ON_ONCE(!rcu_read_lock_held());
264
265         if (unlikely(key >= map->max_entries))
266                 return NULL;
267         return READ_ONCE(stab->sks[key]);
268 }
269
270 static void *sock_map_lookup(struct bpf_map *map, void *key)
271 {
272         return ERR_PTR(-EOPNOTSUPP);
273 }
274
275 static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
276                              struct sock **psk)
277 {
278         struct sock *sk;
279
280         raw_spin_lock_bh(&stab->lock);
281         sk = *psk;
282         if (!sk_test || sk_test == sk)
283                 *psk = NULL;
284         raw_spin_unlock_bh(&stab->lock);
285         if (unlikely(!sk))
286                 return -EINVAL;
287         sock_map_unref(sk, psk);
288         return 0;
289 }
290
291 static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
292                                       void *link_raw)
293 {
294         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
295
296         __sock_map_delete(stab, sk, link_raw);
297 }
298
299 static int sock_map_delete_elem(struct bpf_map *map, void *key)
300 {
301         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
302         u32 i = *(u32 *)key;
303         struct sock **psk;
304
305         if (unlikely(i >= map->max_entries))
306                 return -EINVAL;
307
308         psk = &stab->sks[i];
309         return __sock_map_delete(stab, NULL, psk);
310 }
311
312 static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
313 {
314         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
315         u32 i = key ? *(u32 *)key : U32_MAX;
316         u32 *key_next = next;
317
318         if (i == stab->map.max_entries - 1)
319                 return -ENOENT;
320         if (i >= stab->map.max_entries)
321                 *key_next = 0;
322         else
323                 *key_next = i + 1;
324         return 0;
325 }
326
327 static int sock_map_update_common(struct bpf_map *map, u32 idx,
328                                   struct sock *sk, u64 flags)
329 {
330         struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
331         struct sk_psock_link *link;
332         struct sk_psock *psock;
333         struct sock *osk;
334         int ret;
335
336         WARN_ON_ONCE(!rcu_read_lock_held());
337         if (unlikely(flags > BPF_EXIST))
338                 return -EINVAL;
339         if (unlikely(idx >= map->max_entries))
340                 return -E2BIG;
341
342         link = sk_psock_init_link();
343         if (!link)
344                 return -ENOMEM;
345
346         ret = sock_map_link(map, &stab->progs, sk);
347         if (ret < 0)
348                 goto out_free;
349
350         psock = sk_psock(sk);
351         WARN_ON_ONCE(!psock);
352
353         raw_spin_lock_bh(&stab->lock);
354         osk = stab->sks[idx];
355         if (osk && flags == BPF_NOEXIST) {
356                 ret = -EEXIST;
357                 goto out_unlock;
358         } else if (!osk && flags == BPF_EXIST) {
359                 ret = -ENOENT;
360                 goto out_unlock;
361         }
362
363         sock_map_add_link(psock, link, map, &stab->sks[idx]);
364         stab->sks[idx] = sk;
365         if (osk)
366                 sock_map_unref(osk, &stab->sks[idx]);
367         raw_spin_unlock_bh(&stab->lock);
368         return 0;
369 out_unlock:
370         raw_spin_unlock_bh(&stab->lock);
371         if (psock)
372                 sk_psock_put(sk, psock);
373 out_free:
374         sk_psock_free_link(link);
375         return ret;
376 }
377
378 static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
379 {
380         return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
381                ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
382 }
383
384 static bool sock_map_sk_is_suitable(const struct sock *sk)
385 {
386         return sk->sk_type == SOCK_STREAM &&
387                sk->sk_protocol == IPPROTO_TCP;
388 }
389
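/* Syscall update path: the map value is a socket fd. Only established TCP
 * sockets are accepted, and the update runs under lock_sock() plus RCU via
 * sock_map_sk_acquire().
 */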
390 static int sock_map_update_elem(struct bpf_map *map, void *key,
391                                 void *value, u64 flags)
392 {
393         u32 ufd = *(u32 *)value;
394         u32 idx = *(u32 *)key;
395         struct socket *sock;
396         struct sock *sk;
397         int ret;
398
399         sock = sockfd_lookup(ufd, &ret);
400         if (!sock)
401                 return ret;
402         sk = sock->sk;
403         if (!sk) {
404                 ret = -EINVAL;
405                 goto out;
406         }
407         if (!sock_map_sk_is_suitable(sk) ||
408             sk->sk_state != TCP_ESTABLISHED) {
409                 ret = -EOPNOTSUPP;
410                 goto out;
411         }
412
413         sock_map_sk_acquire(sk);
414         ret = sock_map_update_common(map, idx, sk, flags);
415         sock_map_sk_release(sk);
416 out:
417         fput(sock->file);
418         return ret;
419 }
420
421 BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
422            struct bpf_map *, map, void *, key, u64, flags)
423 {
424         WARN_ON_ONCE(!rcu_read_lock_held());
425
426         if (likely(sock_map_sk_is_suitable(sops->sk) &&
427                    sock_map_op_okay(sops)))
428                 return sock_map_update_common(map, *(u32 *)key, sops->sk,
429                                               flags);
430         return -EOPNOTSUPP;
431 }
432
433 const struct bpf_func_proto bpf_sock_map_update_proto = {
434         .func           = bpf_sock_map_update,
435         .gpl_only       = false,
436         .pkt_access     = true,
437         .ret_type       = RET_INTEGER,
438         .arg1_type      = ARG_PTR_TO_CTX,
439         .arg2_type      = ARG_CONST_MAP_PTR,
440         .arg3_type      = ARG_PTR_TO_MAP_KEY,
441         .arg4_type      = ARG_ANYTHING,
442 };
443
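/* Redirect helpers: look the target socket up in the map and stash it in the
 * skb/msg control block; the actual redirect is performed later by the
 * sk_skb/sk_msg verdict handling.
 */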
444 BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
445            struct bpf_map *, map, u32, key, u64, flags)
446 {
447         struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
448
449         if (unlikely(flags & ~(BPF_F_INGRESS)))
450                 return SK_DROP;
451         tcb->bpf.flags = flags;
452         tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
453         if (!tcb->bpf.sk_redir)
454                 return SK_DROP;
455         return SK_PASS;
456 }
457
458 const struct bpf_func_proto bpf_sk_redirect_map_proto = {
459         .func           = bpf_sk_redirect_map,
460         .gpl_only       = false,
461         .ret_type       = RET_INTEGER,
462         .arg1_type      = ARG_PTR_TO_CTX,
463         .arg2_type      = ARG_CONST_MAP_PTR,
464         .arg3_type      = ARG_ANYTHING,
465         .arg4_type      = ARG_ANYTHING,
466 };
467
468 BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
469            struct bpf_map *, map, u32, key, u64, flags)
470 {
471         if (unlikely(flags & ~(BPF_F_INGRESS)))
472                 return SK_DROP;
473         msg->flags = flags;
474         msg->sk_redir = __sock_map_lookup_elem(map, key);
475         if (!msg->sk_redir)
476                 return SK_DROP;
477         return SK_PASS;
478 }
479
480 const struct bpf_func_proto bpf_msg_redirect_map_proto = {
481         .func           = bpf_msg_redirect_map,
482         .gpl_only       = false,
483         .ret_type       = RET_INTEGER,
484         .arg1_type      = ARG_PTR_TO_CTX,
485         .arg2_type      = ARG_CONST_MAP_PTR,
486         .arg3_type      = ARG_ANYTHING,
487         .arg4_type      = ARG_ANYTHING,
488 };
489
490 const struct bpf_map_ops sock_map_ops = {
491         .map_alloc              = sock_map_alloc,
492         .map_free               = sock_map_free,
493         .map_get_next_key       = sock_map_get_next_key,
494         .map_update_elem        = sock_map_update_elem,
495         .map_delete_elem        = sock_map_delete_elem,
496         .map_lookup_elem        = sock_map_lookup,
497         .map_release_uref       = sock_map_release_progs,
498         .map_check_btf          = map_check_no_btf,
499 };
500
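/* The sockhash (BPF_MAP_TYPE_SOCKHASH) variant: sockets live in a hash table
 * with a caller-chosen key size, bucketed by jhash. Each bucket has its own
 * raw spinlock for updates; lookups again run under RCU.
 */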
501 struct bpf_htab_elem {
502         struct rcu_head rcu;
503         u32 hash;
504         struct sock *sk;
505         struct hlist_node node;
506         u8 key[0];
507 };
508
509 struct bpf_htab_bucket {
510         struct hlist_head head;
511         raw_spinlock_t lock;
512 };
513
514 struct bpf_htab {
515         struct bpf_map map;
516         struct bpf_htab_bucket *buckets;
517         u32 buckets_num;
518         u32 elem_size;
519         struct sk_psock_progs progs;
520         atomic_t count;
521 };
522
523 static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
524 {
525         return jhash(key, len, 0);
526 }
527
528 static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
529                                                        u32 hash)
530 {
531         return &htab->buckets[hash & (htab->buckets_num - 1)];
532 }
533
534 static struct bpf_htab_elem *
535 sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
536                           u32 key_size)
537 {
538         struct bpf_htab_elem *elem;
539
540         hlist_for_each_entry_rcu(elem, head, node) {
541                 if (elem->hash == hash &&
542                     !memcmp(&elem->key, key, key_size))
543                         return elem;
544         }
545
546         return NULL;
547 }
548
549 static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
550 {
551         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
552         u32 key_size = map->key_size, hash;
553         struct bpf_htab_bucket *bucket;
554         struct bpf_htab_elem *elem;
555
556         WARN_ON_ONCE(!rcu_read_lock_held());
557
558         hash = sock_hash_bucket_hash(key, key_size);
559         bucket = sock_hash_select_bucket(htab, hash);
560         elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
561
562         return elem ? elem->sk : NULL;
563 }
564
565 static void sock_hash_free_elem(struct bpf_htab *htab,
566                                 struct bpf_htab_elem *elem)
567 {
568         atomic_dec(&htab->count);
569         kfree_rcu(elem, rcu);
570 }
571
572 static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
573                                        void *link_raw)
574 {
575         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
576         struct bpf_htab_elem *elem_probe, *elem = link_raw;
577         struct bpf_htab_bucket *bucket;
578
579         WARN_ON_ONCE(!rcu_read_lock_held());
580         bucket = sock_hash_select_bucket(htab, elem->hash);
581
582         /* elem may be deleted in parallel from the map, but access here
583          * is okay since it's going away only after an RCU grace period.
584          * However, we need to check whether it's still present.
585          */
586         raw_spin_lock_bh(&bucket->lock);
587         elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
588                                                elem->key, map->key_size);
589         if (elem_probe && elem_probe == elem) {
590                 hlist_del_rcu(&elem->node);
591                 sock_map_unref(elem->sk, elem);
592                 sock_hash_free_elem(htab, elem);
593         }
594         raw_spin_unlock_bh(&bucket->lock);
595 }
596
597 static int sock_hash_delete_elem(struct bpf_map *map, void *key)
598 {
599         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
600         u32 hash, key_size = map->key_size;
601         struct bpf_htab_bucket *bucket;
602         struct bpf_htab_elem *elem;
603         int ret = -ENOENT;
604
605         hash = sock_hash_bucket_hash(key, key_size);
606         bucket = sock_hash_select_bucket(htab, hash);
607
608         raw_spin_lock_bh(&bucket->lock);
609         elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
610         if (elem) {
611                 hlist_del_rcu(&elem->node);
612                 sock_map_unref(elem->sk, elem);
613                 sock_hash_free_elem(htab, elem);
614                 ret = 0;
615         }
616         raw_spin_unlock_bh(&bucket->lock);
617         return ret;
618 }
619
620 static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
621                                                   void *key, u32 key_size,
622                                                   u32 hash, struct sock *sk,
623                                                   struct bpf_htab_elem *old)
624 {
625         struct bpf_htab_elem *new;
626
627         if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
628                 if (!old) {
629                         atomic_dec(&htab->count);
630                         return ERR_PTR(-E2BIG);
631                 }
632         }
633
634         new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
635                            htab->map.numa_node);
636         if (!new) {
637                 atomic_dec(&htab->count);
638                 return ERR_PTR(-ENOMEM);
639         }
640         memcpy(new->key, key, key_size);
641         new->sk = sk;
642         new->hash = hash;
643         return new;
644 }
645
646 static int sock_hash_update_common(struct bpf_map *map, void *key,
647                                    struct sock *sk, u64 flags)
648 {
649         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
650         u32 key_size = map->key_size, hash;
651         struct bpf_htab_elem *elem, *elem_new;
652         struct bpf_htab_bucket *bucket;
653         struct sk_psock_link *link;
654         struct sk_psock *psock;
655         int ret;
656
657         WARN_ON_ONCE(!rcu_read_lock_held());
658         if (unlikely(flags > BPF_EXIST))
659                 return -EINVAL;
660
661         link = sk_psock_init_link();
662         if (!link)
663                 return -ENOMEM;
664
665         ret = sock_map_link(map, &htab->progs, sk);
666         if (ret < 0)
667                 goto out_free;
668
669         psock = sk_psock(sk);
670         WARN_ON_ONCE(!psock);
671
672         hash = sock_hash_bucket_hash(key, key_size);
673         bucket = sock_hash_select_bucket(htab, hash);
674
675         raw_spin_lock_bh(&bucket->lock);
676         elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
677         if (elem && flags == BPF_NOEXIST) {
678                 ret = -EEXIST;
679                 goto out_unlock;
680         } else if (!elem && flags == BPF_EXIST) {
681                 ret = -ENOENT;
682                 goto out_unlock;
683         }
684
685         elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
686         if (IS_ERR(elem_new)) {
687                 ret = PTR_ERR(elem_new);
688                 goto out_unlock;
689         }
690
691         sock_map_add_link(psock, link, map, elem_new);
692         /* Add the new element to the head of the list, so that a
693          * concurrent search will find it before the old element.
694          */
695         hlist_add_head_rcu(&elem_new->node, &bucket->head);
696         if (elem) {
697                 hlist_del_rcu(&elem->node);
698                 sock_map_unref(elem->sk, elem);
699                 sock_hash_free_elem(htab, elem);
700         }
701         raw_spin_unlock_bh(&bucket->lock);
702         return 0;
703 out_unlock:
704         raw_spin_unlock_bh(&bucket->lock);
705         sk_psock_put(sk, psock);
706 out_free:
707         sk_psock_free_link(link);
708         return ret;
709 }
710
711 static int sock_hash_update_elem(struct bpf_map *map, void *key,
712                                  void *value, u64 flags)
713 {
714         u32 ufd = *(u32 *)value;
715         struct socket *sock;
716         struct sock *sk;
717         int ret;
718
719         sock = sockfd_lookup(ufd, &ret);
720         if (!sock)
721                 return ret;
722         sk = sock->sk;
723         if (!sk) {
724                 ret = -EINVAL;
725                 goto out;
726         }
727         if (!sock_map_sk_is_suitable(sk) ||
728             sk->sk_state != TCP_ESTABLISHED) {
729                 ret = -EOPNOTSUPP;
730                 goto out;
731         }
732
733         sock_map_sk_acquire(sk);
734         ret = sock_hash_update_common(map, key, sk, flags);
735         sock_map_sk_release(sk);
736 out:
737         fput(sock->file);
738         return ret;
739 }
740
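/* Iterate in bucket order: return the element following the current key in
 * its bucket chain, otherwise the first element of the next non-empty bucket.
 */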
741 static int sock_hash_get_next_key(struct bpf_map *map, void *key,
742                                   void *key_next)
743 {
744         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
745         struct bpf_htab_elem *elem, *elem_next;
746         u32 hash, key_size = map->key_size;
747         struct hlist_head *head;
748         int i = 0;
749
750         if (!key)
751                 goto find_first_elem;
752         hash = sock_hash_bucket_hash(key, key_size);
753         head = &sock_hash_select_bucket(htab, hash)->head;
754         elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
755         if (!elem)
756                 goto find_first_elem;
757
758         elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
759                                      struct bpf_htab_elem, node);
760         if (elem_next) {
761                 memcpy(key_next, elem_next->key, key_size);
762                 return 0;
763         }
764
765         i = hash & (htab->buckets_num - 1);
766         i++;
767 find_first_elem:
768         for (; i < htab->buckets_num; i++) {
769                 head = &sock_hash_select_bucket(htab, i)->head;
770                 elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
771                                              struct bpf_htab_elem, node);
772                 if (elem_next) {
773                         memcpy(key_next, elem_next->key, key_size);
774                         return 0;
775                 }
776         }
777
778         return -ENOENT;
779 }
780
781 static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
782 {
783         struct bpf_htab *htab;
784         int i, err;
785         u64 cost;
786
787         if (!capable(CAP_NET_ADMIN))
788                 return ERR_PTR(-EPERM);
789         if (attr->max_entries == 0 ||
790             attr->key_size    == 0 ||
791             attr->value_size  != 4 ||
792             attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
793                 return ERR_PTR(-EINVAL);
794         if (attr->key_size > MAX_BPF_STACK)
795                 return ERR_PTR(-E2BIG);
796
797         htab = kzalloc(sizeof(*htab), GFP_USER);
798         if (!htab)
799                 return ERR_PTR(-ENOMEM);
800
801         bpf_map_init_from_attr(&htab->map, attr);
802
803         htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
804         htab->elem_size = sizeof(struct bpf_htab_elem) +
805                           round_up(htab->map.key_size, 8);
806         if (htab->buckets_num == 0 ||
807             htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
808                 err = -EINVAL;
809                 goto free_htab;
810         }
811
812         cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
813                (u64) htab->elem_size * htab->map.max_entries;
814         if (cost >= U32_MAX - PAGE_SIZE) {
815                 err = -EINVAL;
816                 goto free_htab;
817         }
818
819         htab->buckets = bpf_map_area_alloc(htab->buckets_num *
820                                            sizeof(struct bpf_htab_bucket),
821                                            htab->map.numa_node);
822         if (!htab->buckets) {
823                 err = -ENOMEM;
824                 goto free_htab;
825         }
826
827         for (i = 0; i < htab->buckets_num; i++) {
828                 INIT_HLIST_HEAD(&htab->buckets[i].head);
829                 raw_spin_lock_init(&htab->buckets[i].lock);
830         }
831
832         return &htab->map;
833 free_htab:
834         kfree(htab);
835         return ERR_PTR(err);
836 }
837
838 static void sock_hash_free(struct bpf_map *map)
839 {
840         struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
841         struct bpf_htab_bucket *bucket;
842         struct bpf_htab_elem *elem;
843         struct hlist_node *node;
844         int i;
845
846         synchronize_rcu();
847         rcu_read_lock();
848         for (i = 0; i < htab->buckets_num; i++) {
849                 bucket = sock_hash_select_bucket(htab, i);
850                 raw_spin_lock_bh(&bucket->lock);
851                 hlist_for_each_entry_safe(elem, node, &bucket->head, node) {
852                         hlist_del_rcu(&elem->node);
853                         sock_map_unref(elem->sk, elem);
854                 }
855                 raw_spin_unlock_bh(&bucket->lock);
856         }
857         rcu_read_unlock();
858
859         bpf_map_area_free(htab->buckets);
860         kfree(htab);
861 }
862
863 static void sock_hash_release_progs(struct bpf_map *map)
864 {
865         psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
866 }
867
868 BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
869            struct bpf_map *, map, void *, key, u64, flags)
870 {
871         WARN_ON_ONCE(!rcu_read_lock_held());
872
873         if (likely(sock_map_sk_is_suitable(sops->sk) &&
874                    sock_map_op_okay(sops)))
875                 return sock_hash_update_common(map, key, sops->sk, flags);
876         return -EOPNOTSUPP;
877 }
878
879 const struct bpf_func_proto bpf_sock_hash_update_proto = {
880         .func           = bpf_sock_hash_update,
881         .gpl_only       = false,
882         .pkt_access     = true,
883         .ret_type       = RET_INTEGER,
884         .arg1_type      = ARG_PTR_TO_CTX,
885         .arg2_type      = ARG_CONST_MAP_PTR,
886         .arg3_type      = ARG_PTR_TO_MAP_KEY,
887         .arg4_type      = ARG_ANYTHING,
888 };
889
890 BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
891            struct bpf_map *, map, void *, key, u64, flags)
892 {
893         struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
894
895         if (unlikely(flags & ~(BPF_F_INGRESS)))
896                 return SK_DROP;
897         tcb->bpf.flags = flags;
898         tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
899         if (!tcb->bpf.sk_redir)
900                 return SK_DROP;
901         return SK_PASS;
902 }
903
904 const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
905         .func           = bpf_sk_redirect_hash,
906         .gpl_only       = false,
907         .ret_type       = RET_INTEGER,
908         .arg1_type      = ARG_PTR_TO_CTX,
909         .arg2_type      = ARG_CONST_MAP_PTR,
910         .arg3_type      = ARG_PTR_TO_MAP_KEY,
911         .arg4_type      = ARG_ANYTHING,
912 };
913
914 BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
915            struct bpf_map *, map, void *, key, u64, flags)
916 {
917         if (unlikely(flags & ~(BPF_F_INGRESS)))
918                 return SK_DROP;
919         msg->flags = flags;
920         msg->sk_redir = __sock_hash_lookup_elem(map, key);
921         if (!msg->sk_redir)
922                 return SK_DROP;
923         return SK_PASS;
924 }
925
926 const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
927         .func           = bpf_msg_redirect_hash,
928         .gpl_only       = false,
929         .ret_type       = RET_INTEGER,
930         .arg1_type      = ARG_PTR_TO_CTX,
931         .arg2_type      = ARG_CONST_MAP_PTR,
932         .arg3_type      = ARG_PTR_TO_MAP_KEY,
933         .arg4_type      = ARG_ANYTHING,
934 };
935
936 const struct bpf_map_ops sock_hash_ops = {
937         .map_alloc              = sock_hash_alloc,
938         .map_free               = sock_hash_free,
939         .map_get_next_key       = sock_hash_get_next_key,
940         .map_update_elem        = sock_hash_update_elem,
941         .map_delete_elem        = sock_hash_delete_elem,
942         .map_lookup_elem        = sock_map_lookup,
943         .map_release_uref       = sock_hash_release_progs,
944         .map_check_btf          = map_check_no_btf,
945 };
946
947 static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
948 {
949         switch (map->map_type) {
950         case BPF_MAP_TYPE_SOCKMAP:
951                 return &container_of(map, struct bpf_stab, map)->progs;
952         case BPF_MAP_TYPE_SOCKHASH:
953                 return &container_of(map, struct bpf_htab, map)->progs;
954         default:
955                 break;
956         }
957
958         return NULL;
959 }
960
961 int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
962                          u32 which)
963 {
964         struct sk_psock_progs *progs = sock_map_progs(map);
965
966         if (!progs)
967                 return -EOPNOTSUPP;
968
969         switch (which) {
970         case BPF_SK_MSG_VERDICT:
971                 psock_set_prog(&progs->msg_parser, prog);
972                 break;
973         case BPF_SK_SKB_STREAM_PARSER:
974                 psock_set_prog(&progs->skb_parser, prog);
975                 break;
976         case BPF_SK_SKB_STREAM_VERDICT:
977                 psock_set_prog(&progs->skb_verdict, prog);
978                 break;
979         default:
980                 return -EOPNOTSUPP;
981         }
982
983         return 0;
984 }
985
986 void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
987 {
988         switch (link->map->map_type) {
989         case BPF_MAP_TYPE_SOCKMAP:
990                 return sock_map_delete_from_link(link->map, sk,
991                                                  link->link_raw);
992         case BPF_MAP_TYPE_SOCKHASH:
993                 return sock_hash_delete_from_link(link->map, sk,
994                                                   link->link_raw);
995         default:
996                 break;
997         }
998 }