]> asedeno.scripts.mit.edu Git - linux.git/blob - tools/testing/selftests/bpf/test_sock_addr.c
d488f20926e8962f1b8d59ecf5b85c6ee4767c7a
[linux.git] / tools / testing / selftests / bpf / test_sock_addr.c
1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2018 Facebook
3
4 #include <stdio.h>
5 #include <stdlib.h>
6 #include <unistd.h>
7
8 #include <arpa/inet.h>
9 #include <sys/types.h>
10 #include <sys/socket.h>
11
12 #include <linux/filter.h>
13
14 #include <bpf/bpf.h>
15 #include <bpf/libbpf.h>
16
17 #include "cgroup_helpers.h"
18
/* Cgroup created, joined and cleaned up by run_test(). */
#define CG_PATH "/foo"

/* Compiled BPF objects loaded for the connect4/connect6 hooks. */
#define CONNECT4_PROG_PATH	"./connect4_prog.o"
#define CONNECT6_PROG_PATH	"./connect6_prog.o"

/*
 * Endpoints requested by the test (SERVx_IP/SERVx_PORT) and the values the
 * BPF programs rewrite them to (SERVx_REWRITE_IP/SERVx_REWRITE_PORT).
 */
#define SERV4_IP		"192.168.1.254"
#define SERV4_REWRITE_IP	"127.0.0.1"
#define SERV4_PORT		4040
#define SERV4_REWRITE_PORT	4444

#define SERV6_IP		"face:b00c:1234:5678::abcd"
#define SERV6_REWRITE_IP	"::1"
#define SERV6_PORT		6060
#define SERV6_REWRITE_PORT	6666

/*
 * Buffer size for inet_ntop() output. INET6_ADDRSTRLEN is the size POSIX
 * guarantees for any IPv6 (and therefore IPv4) textual address, instead of
 * a hand-picked constant.
 */
#define INET_NTOP_BUF		INET6_ADDRSTRLEN
34
35 typedef int (*load_fn)(enum bpf_attach_type, const char *comment);
36 typedef int (*info_fn)(int, struct sockaddr *, socklen_t *);
37
38 struct program {
39         enum bpf_attach_type type;
40         load_fn loadfn;
41         int fd;
42         const char *name;
43         enum bpf_attach_type invalid_type;
44 };
45
46 char bpf_log_buf[BPF_LOG_BUF_SIZE];
47
/*
 * mk_sockaddr() - build a zeroed IPv4 or IPv6 socket address.
 * @domain:   AF_INET or AF_INET6; anything else is rejected.
 * @ip:       textual address parsed with inet_pton().
 * @port:     port in host byte order (stored in network order).
 * @addr:     output buffer, cleared over @addr_len bytes first.
 * @addr_len: size of @addr; must fit the family's sockaddr structure.
 *
 * Return: 0 on success, -1 on unsupported family, short buffer or bad @ip.
 */
static int mk_sockaddr(int domain, const char *ip, unsigned short port,
		       struct sockaddr *addr, socklen_t addr_len)
{
	switch (domain) {
	case AF_INET: {
		struct sockaddr_in *sin = (struct sockaddr_in *)addr;

		memset(addr, 0, addr_len);
		if (addr_len < sizeof(struct sockaddr_in))
			return -1;
		sin->sin_family = domain;
		sin->sin_port = htons(port);
		if (inet_pton(domain, ip, &sin->sin_addr) != 1) {
			log_err("Invalid IPv4: %s", ip);
			return -1;
		}
		return 0;
	}
	case AF_INET6: {
		struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)addr;

		memset(addr, 0, addr_len);
		if (addr_len < sizeof(struct sockaddr_in6))
			return -1;
		sin6->sin6_family = domain;
		sin6->sin6_port = htons(port);
		if (inet_pton(domain, ip, &sin6->sin6_addr) != 1) {
			log_err("Invalid IPv6: %s", ip);
			return -1;
		}
		return 0;
	}
	default:
		log_err("Unsupported address family");
		return -1;
	}
}
85
86 static int load_insns(enum bpf_attach_type attach_type,
87                       const struct bpf_insn *insns, size_t insns_cnt,
88                       const char *comment)
89 {
90         struct bpf_load_program_attr load_attr;
91         int ret;
92
93         memset(&load_attr, 0, sizeof(struct bpf_load_program_attr));
94         load_attr.prog_type = BPF_PROG_TYPE_CGROUP_SOCK_ADDR;
95         load_attr.expected_attach_type = attach_type;
96         load_attr.insns = insns;
97         load_attr.insns_cnt = insns_cnt;
98         load_attr.license = "GPL";
99
100         ret = bpf_load_program_xattr(&load_attr, bpf_log_buf, BPF_LOG_BUF_SIZE);
101         if (ret < 0 && comment) {
102                 log_err(">>> Loading %s program error.\n"
103                         ">>> Output from verifier:\n%s\n-------\n",
104                         comment, bpf_log_buf);
105         }
106
107         return ret;
108 }
109
/* [1] These test programs read several context fields, including narrow
 * loads of different sizes from user_ip4 and user_ip6, and write to the
 * fields that are allowed to be overridden.
 *
 * [2] BPF_LD_IMM64 + BPF_JMP_REG are used below whenever a register must be
 * compared with an unsigned 32-bit integer. BPF_JMP_IMM can't be used in such
 * cases because its IMM argument is a _signed_ 32-bit integer. Also note that
 * BPF_LD_IMM64 occupies two instruction slots, which matters when counting
 * jump offsets.
 */
120
121 static int bind4_prog_load(enum bpf_attach_type attach_type,
122                            const char *comment)
123 {
124         union {
125                 uint8_t u4_addr8[4];
126                 uint16_t u4_addr16[2];
127                 uint32_t u4_addr32;
128         } ip4;
129         struct sockaddr_in addr4_rw;
130
131         if (inet_pton(AF_INET, SERV4_IP, (void *)&ip4) != 1) {
132                 log_err("Invalid IPv4: %s", SERV4_IP);
133                 return -1;
134         }
135
136         if (mk_sockaddr(AF_INET, SERV4_REWRITE_IP, SERV4_REWRITE_PORT,
137                         (struct sockaddr *)&addr4_rw, sizeof(addr4_rw)) == -1)
138                 return -1;
139
140         /* See [1]. */
141         struct bpf_insn insns[] = {
142                 BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
143
144                 /* if (sk.family == AF_INET && */
145                 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
146                             offsetof(struct bpf_sock_addr, family)),
147                 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, AF_INET, 16),
148
149                 /*     (sk.type == SOCK_DGRAM || sk.type == SOCK_STREAM) && */
150                 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
151                             offsetof(struct bpf_sock_addr, type)),
152                 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, SOCK_DGRAM, 1),
153                 BPF_JMP_A(1),
154                 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, SOCK_STREAM, 12),
155
156                 /*     1st_byte_of_user_ip4 == expected && */
157                 BPF_LDX_MEM(BPF_B, BPF_REG_7, BPF_REG_6,
158                             offsetof(struct bpf_sock_addr, user_ip4)),
159                 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr8[0], 10),
160
161                 /*     1st_half_of_user_ip4 == expected && */
162                 BPF_LDX_MEM(BPF_H, BPF_REG_7, BPF_REG_6,
163                             offsetof(struct bpf_sock_addr, user_ip4)),
164                 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr16[0], 8),
165
166                 /*     whole_user_ip4 == expected) { */
167                 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
168                             offsetof(struct bpf_sock_addr, user_ip4)),
169                 BPF_LD_IMM64(BPF_REG_8, ip4.u4_addr32), /* See [2]. */
170                 BPF_JMP_REG(BPF_JNE, BPF_REG_7, BPF_REG_8, 4),
171
172                 /*      user_ip4 = addr4_rw.sin_addr */
173                 BPF_MOV32_IMM(BPF_REG_7, addr4_rw.sin_addr.s_addr),
174                 BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
175                             offsetof(struct bpf_sock_addr, user_ip4)),
176
177                 /*      user_port = addr4_rw.sin_port */
178                 BPF_MOV32_IMM(BPF_REG_7, addr4_rw.sin_port),
179                 BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
180                             offsetof(struct bpf_sock_addr, user_port)),
181                 /* } */
182
183                 /* return 1 */
184                 BPF_MOV64_IMM(BPF_REG_0, 1),
185                 BPF_EXIT_INSN(),
186         };
187
188         return load_insns(attach_type, insns,
189                           sizeof(insns) / sizeof(struct bpf_insn), comment);
190 }
191
192 static int bind6_prog_load(enum bpf_attach_type attach_type,
193                            const char *comment)
194 {
195         struct sockaddr_in6 addr6_rw;
196         struct in6_addr ip6;
197
198         if (inet_pton(AF_INET6, SERV6_IP, (void *)&ip6) != 1) {
199                 log_err("Invalid IPv6: %s", SERV6_IP);
200                 return -1;
201         }
202
203         if (mk_sockaddr(AF_INET6, SERV6_REWRITE_IP, SERV6_REWRITE_PORT,
204                         (struct sockaddr *)&addr6_rw, sizeof(addr6_rw)) == -1)
205                 return -1;
206
207         /* See [1]. */
208         struct bpf_insn insns[] = {
209                 BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
210
211                 /* if (sk.family == AF_INET6 && */
212                 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
213                             offsetof(struct bpf_sock_addr, family)),
214                 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, AF_INET6, 18),
215
216                 /*            5th_byte_of_user_ip6 == expected && */
217                 BPF_LDX_MEM(BPF_B, BPF_REG_7, BPF_REG_6,
218                             offsetof(struct bpf_sock_addr, user_ip6[1])),
219                 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip6.s6_addr[4], 16),
220
221                 /*            3rd_half_of_user_ip6 == expected && */
222                 BPF_LDX_MEM(BPF_H, BPF_REG_7, BPF_REG_6,
223                             offsetof(struct bpf_sock_addr, user_ip6[1])),
224                 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip6.s6_addr16[2], 14),
225
226                 /*            last_word_of_user_ip6 == expected) { */
227                 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
228                             offsetof(struct bpf_sock_addr, user_ip6[3])),
229                 BPF_LD_IMM64(BPF_REG_8, ip6.s6_addr32[3]),  /* See [2]. */
230                 BPF_JMP_REG(BPF_JNE, BPF_REG_7, BPF_REG_8, 10),
231
232
233 #define STORE_IPV6_WORD(N)                                                     \
234                 BPF_MOV32_IMM(BPF_REG_7, addr6_rw.sin6_addr.s6_addr32[N]),     \
235                 BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,                       \
236                             offsetof(struct bpf_sock_addr, user_ip6[N]))
237
238                 /*      user_ip6 = addr6_rw.sin6_addr */
239                 STORE_IPV6_WORD(0),
240                 STORE_IPV6_WORD(1),
241                 STORE_IPV6_WORD(2),
242                 STORE_IPV6_WORD(3),
243
244                 /*      user_port = addr6_rw.sin6_port */
245                 BPF_MOV32_IMM(BPF_REG_7, addr6_rw.sin6_port),
246                 BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7,
247                             offsetof(struct bpf_sock_addr, user_port)),
248
249                 /* } */
250
251                 /* return 1 */
252                 BPF_MOV64_IMM(BPF_REG_0, 1),
253                 BPF_EXIT_INSN(),
254         };
255
256         return load_insns(attach_type, insns,
257                           sizeof(insns) / sizeof(struct bpf_insn), comment);
258 }
259
260 static int connect_prog_load_path(const char *path,
261                                   enum bpf_attach_type attach_type,
262                                   const char *comment)
263 {
264         struct bpf_prog_load_attr attr;
265         struct bpf_object *obj;
266         int prog_fd;
267
268         memset(&attr, 0, sizeof(struct bpf_prog_load_attr));
269         attr.file = path;
270         attr.prog_type = BPF_PROG_TYPE_CGROUP_SOCK_ADDR;
271         attr.expected_attach_type = attach_type;
272
273         if (bpf_prog_load_xattr(&attr, &obj, &prog_fd)) {
274                 if (comment)
275                         log_err(">>> Loading %s program at %s error.\n",
276                                 comment, path);
277                 return -1;
278         }
279
280         return prog_fd;
281 }
282
283 static int connect4_prog_load(enum bpf_attach_type attach_type,
284                               const char *comment)
285 {
286         return connect_prog_load_path(CONNECT4_PROG_PATH, attach_type, comment);
287 }
288
289 static int connect6_prog_load(enum bpf_attach_type attach_type,
290                               const char *comment)
291 {
292         return connect_prog_load_path(CONNECT6_PROG_PATH, attach_type, comment);
293 }
294
295 static void print_ip_port(int sockfd, info_fn fn, const char *fmt)
296 {
297         char addr_buf[INET_NTOP_BUF];
298         struct sockaddr_storage addr;
299         struct sockaddr_in6 *addr6;
300         struct sockaddr_in *addr4;
301         socklen_t addr_len;
302         unsigned short port;
303         void *nip;
304
305         addr_len = sizeof(struct sockaddr_storage);
306         memset(&addr, 0, addr_len);
307
308         if (fn(sockfd, (struct sockaddr *)&addr, (socklen_t *)&addr_len) == 0) {
309                 if (addr.ss_family == AF_INET) {
310                         addr4 = (struct sockaddr_in *)&addr;
311                         nip = (void *)&addr4->sin_addr;
312                         port = ntohs(addr4->sin_port);
313                 } else if (addr.ss_family == AF_INET6) {
314                         addr6 = (struct sockaddr_in6 *)&addr;
315                         nip = (void *)&addr6->sin6_addr;
316                         port = ntohs(addr6->sin6_port);
317                 } else {
318                         return;
319                 }
320                 const char *addr_str =
321                         inet_ntop(addr.ss_family, nip, addr_buf, INET_NTOP_BUF);
322                 printf(fmt, addr_str ? addr_str : "??", port);
323         }
324 }
325
/* Print the local address of @sockfd (via getsockname()) using @fmt. */
static void print_local_ip_port(int sockfd, const char *fmt)
{
	print_ip_port(sockfd, getsockname,
		      fmt);
}
330
/* Print the peer address of @sockfd (via getpeername()) using @fmt. */
static void print_remote_ip_port(int sockfd, const char *fmt)
{
	print_ip_port(sockfd, getpeername,
		      fmt);
}
335
/*
 * start_server() - create a socket of @type bound to @addr.
 * @type:     SOCK_STREAM sockets are also put into listening state.
 * @addr:     address to bind (may be rewritten by an attached bind program).
 * @addr_len: length passed to bind().
 *
 * Prints the address actually bound.
 *
 * Return: the server fd, or -1 on failure (the socket is closed).
 */
static int start_server(int type, const struct sockaddr_storage *addr,
			socklen_t addr_len)
{
	int sock = socket(addr->ss_family, type, 0);

	if (sock == -1) {
		log_err("Failed to create server socket");
		return -1;
	}

	if (bind(sock, (const struct sockaddr *)addr, addr_len) == -1) {
		log_err("Failed to bind server socket");
		close(sock);
		return -1;
	}

	if (type == SOCK_STREAM && listen(sock, 128) == -1) {
		log_err("Failed to listen on server socket");
		close(sock);
		return -1;
	}

	print_local_ip_port(sock, "\t   Actual: bind(%s, %d)\n");
	return sock;
}
369
/*
 * connect_to_server() - connect a fresh client socket to @addr and print the
 * peer and local endpoints the kernel actually used.
 * @type:     socket type (SOCK_STREAM or SOCK_DGRAM).
 * @addr:     destination (may be rewritten by an attached connect program).
 * @addr_len: length passed to connect().
 *
 * The client socket is always closed before returning: callers only receive
 * a status, so keeping it open would leak one fd per test case.
 *
 * Return: 0 on success, -1 on failure.
 */
static int connect_to_server(int type, const struct sockaddr_storage *addr,
			     socklen_t addr_len)
{
	int domain = addr->ss_family;
	int err = -1;
	int fd;

	if (domain != AF_INET && domain != AF_INET6) {
		log_err("Unsupported address family");
		return -1;
	}

	fd = socket(domain, type, 0);
	if (fd == -1) {
		log_err("Failed to create client socket");
		return -1;
	}

	if (connect(fd, (const struct sockaddr *)addr, addr_len) == -1) {
		log_err("Failed to connect to server");
		goto out;
	}

	print_remote_ip_port(fd, "\t   Actual: connect(%s, %d)");
	print_local_ip_port(fd, " from (%s, %d)\n");
	err = 0;
out:
	close(fd);
	return err;
}
402
/*
 * print_test_case_num() - print a numbered header for the next test case.
 * @domain: AF_INET/AF_INET6 (printed as IPv4/IPv6).
 * @type:   SOCK_STREAM/SOCK_DGRAM (printed as TCP/UDP).
 *
 * The case number comes from a function-static counter starting at 1.
 */
static void print_test_case_num(int domain, int type)
{
	static int test_num;
	const char *family_name = "unknown_domain";
	const char *proto_name = "unknown_type";

	if (domain == AF_INET)
		family_name = "IPv4";
	else if (domain == AF_INET6)
		family_name = "IPv6";

	if (type == SOCK_STREAM)
		proto_name = "TCP";
	else if (type == SOCK_DGRAM)
		proto_name = "UDP";

	test_num++;
	printf("Test case #%d (%s/%s):\n", test_num, family_name, proto_name);
}
415
/*
 * run_test_case() - bind a server to @ip:@port and connect a client to it.
 * @domain: address family.
 * @type:   socket type.
 * @ip:     requested address (attached programs may rewrite it).
 * @port:   requested port in host byte order.
 *
 * Return: 0 if both bind and connect succeed, -1 otherwise.
 */
static int run_test_case(int domain, int type, const char *ip,
			 unsigned short port)
{
	struct sockaddr_storage addr;
	socklen_t addr_len = sizeof(addr);
	int ret = -1;
	int servfd;

	print_test_case_num(domain, type);

	if (mk_sockaddr(domain, ip, port, (struct sockaddr *)&addr,
			addr_len) == -1)
		return -1;

	printf("\tRequested: bind(%s, %d) ..\n", ip, port);
	servfd = start_server(type, &addr, addr_len);
	if (servfd != -1) {
		printf("\tRequested: connect(%s, %d) from (*, *) ..\n",
		       ip, port);
		if (connect_to_server(type, &addr, addr_len) == 0)
			ret = 0;
	}

	close(servfd);
	return ret;
}
446
447 static void close_progs_fds(struct program *progs, size_t prog_cnt)
448 {
449         size_t i;
450
451         for (i = 0; i < prog_cnt; ++i) {
452                 close(progs[i].fd);
453                 progs[i].fd = -1;
454         }
455 }
456
457 static int load_and_attach_progs(int cgfd, struct program *progs,
458                                  size_t prog_cnt)
459 {
460         size_t i;
461
462         for (i = 0; i < prog_cnt; ++i) {
463                 printf("Load %s with invalid type (can pollute stderr) ",
464                        progs[i].name);
465                 fflush(stdout);
466                 progs[i].fd = progs[i].loadfn(progs[i].invalid_type, NULL);
467                 if (progs[i].fd != -1) {
468                         log_err("Load with invalid type accepted for %s",
469                                 progs[i].name);
470                         goto err;
471                 }
472                 printf("... REJECTED\n");
473
474                 printf("Load %s with valid type", progs[i].name);
475                 progs[i].fd = progs[i].loadfn(progs[i].type, progs[i].name);
476                 if (progs[i].fd == -1) {
477                         log_err("Failed to load program %s", progs[i].name);
478                         goto err;
479                 }
480                 printf(" ... OK\n");
481
482                 printf("Attach %s with invalid type", progs[i].name);
483                 if (bpf_prog_attach(progs[i].fd, cgfd, progs[i].invalid_type,
484                                     BPF_F_ALLOW_OVERRIDE) != -1) {
485                         log_err("Attach with invalid type accepted for %s",
486                                 progs[i].name);
487                         goto err;
488                 }
489                 printf(" ... REJECTED\n");
490
491                 printf("Attach %s with valid type", progs[i].name);
492                 if (bpf_prog_attach(progs[i].fd, cgfd, progs[i].type,
493                                     BPF_F_ALLOW_OVERRIDE) == -1) {
494                         log_err("Failed to attach program %s", progs[i].name);
495                         goto err;
496                 }
497                 printf(" ... OK\n");
498         }
499
500         return 0;
501 err:
502         close_progs_fds(progs, prog_cnt);
503         return -1;
504 }
505
/*
 * run_domain_test() - attach @progs to @cgfd, then run a TCP and a UDP case
 * against @ip:@port. Program fds are closed before returning.
 *
 * Return: 0 if every step succeeds, -1 otherwise.
 */
static int run_domain_test(int domain, int cgfd, struct program *progs,
			   size_t prog_cnt, const char *ip, unsigned short port)
{
	int ok;

	ok = load_and_attach_progs(cgfd, progs, prog_cnt) != -1 &&
	     run_test_case(domain, SOCK_STREAM, ip, port) != -1 &&
	     run_test_case(domain, SOCK_DGRAM, ip, port) != -1;

	close_progs_fds(progs, prog_cnt);
	return ok ? 0 : -1;
}
527
528 static int run_test(void)
529 {
530         size_t inet6_prog_cnt;
531         size_t inet_prog_cnt;
532         int cgfd = -1;
533         int err = 0;
534
535         struct program inet6_progs[] = {
536                 {BPF_CGROUP_INET6_BIND, bind6_prog_load, -1, "bind6",
537                  BPF_CGROUP_INET4_BIND},
538                 {BPF_CGROUP_INET6_CONNECT, connect6_prog_load, -1, "connect6",
539                  BPF_CGROUP_INET4_CONNECT},
540         };
541         inet6_prog_cnt = sizeof(inet6_progs) / sizeof(struct program);
542
543         struct program inet_progs[] = {
544                 {BPF_CGROUP_INET4_BIND, bind4_prog_load, -1, "bind4",
545                  BPF_CGROUP_INET6_BIND},
546                 {BPF_CGROUP_INET4_CONNECT, connect4_prog_load, -1, "connect4",
547                  BPF_CGROUP_INET6_CONNECT},
548         };
549         inet_prog_cnt = sizeof(inet_progs) / sizeof(struct program);
550
551         if (setup_cgroup_environment())
552                 goto err;
553
554         cgfd = create_and_get_cgroup(CG_PATH);
555         if (!cgfd)
556                 goto err;
557
558         if (join_cgroup(CG_PATH))
559                 goto err;
560
561         if (run_domain_test(AF_INET, cgfd, inet_progs, inet_prog_cnt, SERV4_IP,
562                             SERV4_PORT) == -1)
563                 goto err;
564
565         if (run_domain_test(AF_INET6, cgfd, inet6_progs, inet6_prog_cnt,
566                             SERV6_IP, SERV6_PORT) == -1)
567                 goto err;
568
569         goto out;
570 err:
571         err = -1;
572 out:
573         close(cgfd);
574         cleanup_cgroup_environment();
575         printf(err ? "### FAIL\n" : "### SUCCESS\n");
576         return err;
577 }
578
/*
 * Entry point. A bare invocation (no arguments) is treated as a direct run
 * and skipped with exit status 0. NOTE(review): presumably the companion
 * .sh wrapper passes at least one argument; confirm against that script.
 */
int main(int argc, char **argv)
{
	if (argc >= 2)
		return run_test();

	fprintf(stderr, "%s has to be run via %s.sh. Skip direct run.\n",
		argv[0], argv[0]);
	exit(0);
}