]> asedeno.scripts.mit.edu Git - linux.git/blob - tools/testing/selftests/net/tls.c
6d78bd05081301ff9110f1a38e459d5ec723712e
[linux.git] / tools / testing / selftests / net / tls.c
1 // SPDX-License-Identifier: GPL-2.0
2
3 #define _GNU_SOURCE
4
5 #include <arpa/inet.h>
6 #include <errno.h>
7 #include <error.h>
8 #include <fcntl.h>
9 #include <poll.h>
10 #include <stdio.h>
11 #include <stdlib.h>
12 #include <unistd.h>
13
14 #include <linux/tls.h>
15 #include <linux/tcp.h>
16 #include <linux/socket.h>
17
18 #include <sys/types.h>
19 #include <sys/sendfile.h>
20 #include <sys/socket.h>
21 #include <sys/stat.h>
22
23 #include "../kselftest_harness.h"
24
25 #define TLS_PAYLOAD_MAX_LEN 16384
26 #define SOL_TLS 282
27
28 #ifndef ENOTSUPP
29 #define ENOTSUPP 524
30 #endif
31
32 FIXTURE(tls_basic)
33 {
34         int fd, cfd;
35         bool notls;
36 };
37
38 FIXTURE_SETUP(tls_basic)
39 {
40         struct sockaddr_in addr;
41         socklen_t len;
42         int sfd, ret;
43
44         self->notls = false;
45         len = sizeof(addr);
46
47         addr.sin_family = AF_INET;
48         addr.sin_addr.s_addr = htonl(INADDR_ANY);
49         addr.sin_port = 0;
50
51         self->fd = socket(AF_INET, SOCK_STREAM, 0);
52         sfd = socket(AF_INET, SOCK_STREAM, 0);
53
54         ret = bind(sfd, &addr, sizeof(addr));
55         ASSERT_EQ(ret, 0);
56         ret = listen(sfd, 10);
57         ASSERT_EQ(ret, 0);
58
59         ret = getsockname(sfd, &addr, &len);
60         ASSERT_EQ(ret, 0);
61
62         ret = connect(self->fd, &addr, sizeof(addr));
63         ASSERT_EQ(ret, 0);
64
65         self->cfd = accept(sfd, &addr, &len);
66         ASSERT_GE(self->cfd, 0);
67
68         close(sfd);
69
70         ret = setsockopt(self->fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
71         if (ret != 0) {
72                 ASSERT_EQ(errno, ENOTSUPP);
73                 self->notls = true;
74                 printf("Failure setting TCP_ULP, testing without tls\n");
75                 return;
76         }
77
78         ret = setsockopt(self->cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
79         ASSERT_EQ(ret, 0);
80 }
81
82 FIXTURE_TEARDOWN(tls_basic)
83 {
84         close(self->fd);
85         close(self->cfd);
86 }
87
88 /* Send some data through with ULP but no keys */
89 TEST_F(tls_basic, base_base)
90 {
91         char const *test_str = "test_read";
92         int send_len = 10;
93         char buf[10];
94
95         ASSERT_EQ(strlen(test_str) + 1, send_len);
96
97         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
98         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
99         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
100 };
101
102 FIXTURE(tls)
103 {
104         int fd, cfd;
105         bool notls;
106 };
107
108 FIXTURE_SETUP(tls)
109 {
110         struct tls12_crypto_info_aes_gcm_128 tls12;
111         struct sockaddr_in addr;
112         socklen_t len;
113         int sfd, ret;
114
115         self->notls = false;
116         len = sizeof(addr);
117
118         memset(&tls12, 0, sizeof(tls12));
119         tls12.info.version = TLS_1_3_VERSION;
120         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_128;
121
122         addr.sin_family = AF_INET;
123         addr.sin_addr.s_addr = htonl(INADDR_ANY);
124         addr.sin_port = 0;
125
126         self->fd = socket(AF_INET, SOCK_STREAM, 0);
127         sfd = socket(AF_INET, SOCK_STREAM, 0);
128
129         ret = bind(sfd, &addr, sizeof(addr));
130         ASSERT_EQ(ret, 0);
131         ret = listen(sfd, 10);
132         ASSERT_EQ(ret, 0);
133
134         ret = getsockname(sfd, &addr, &len);
135         ASSERT_EQ(ret, 0);
136
137         ret = connect(self->fd, &addr, sizeof(addr));
138         ASSERT_EQ(ret, 0);
139
140         ret = setsockopt(self->fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
141         if (ret != 0) {
142                 self->notls = true;
143                 printf("Failure setting TCP_ULP, testing without tls\n");
144         }
145
146         if (!self->notls) {
147                 ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12,
148                                  sizeof(tls12));
149                 ASSERT_EQ(ret, 0);
150         }
151
152         self->cfd = accept(sfd, &addr, &len);
153         ASSERT_GE(self->cfd, 0);
154
155         if (!self->notls) {
156                 ret = setsockopt(self->cfd, IPPROTO_TCP, TCP_ULP, "tls",
157                                  sizeof("tls"));
158                 ASSERT_EQ(ret, 0);
159
160                 ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12,
161                                  sizeof(tls12));
162                 ASSERT_EQ(ret, 0);
163         }
164
165         close(sfd);
166 }
167
168 FIXTURE_TEARDOWN(tls)
169 {
170         close(self->fd);
171         close(self->cfd);
172 }
173
174 TEST_F(tls, sendfile)
175 {
176         int filefd = open("/proc/self/exe", O_RDONLY);
177         struct stat st;
178
179         EXPECT_GE(filefd, 0);
180         fstat(filefd, &st);
181         EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
182 }
183
184 TEST_F(tls, send_then_sendfile)
185 {
186         int filefd = open("/proc/self/exe", O_RDONLY);
187         char const *test_str = "test_send";
188         int to_send = strlen(test_str) + 1;
189         char recv_buf[10];
190         struct stat st;
191         char *buf;
192
193         EXPECT_GE(filefd, 0);
194         fstat(filefd, &st);
195         buf = (char *)malloc(st.st_size);
196
197         EXPECT_EQ(send(self->fd, test_str, to_send, 0), to_send);
198         EXPECT_EQ(recv(self->cfd, recv_buf, to_send, MSG_WAITALL), to_send);
199         EXPECT_EQ(memcmp(test_str, recv_buf, to_send), 0);
200
201         EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
202         EXPECT_EQ(recv(self->cfd, buf, st.st_size, MSG_WAITALL), st.st_size);
203 }
204
205 TEST_F(tls, recv_max)
206 {
207         unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
208         char recv_mem[TLS_PAYLOAD_MAX_LEN];
209         char buf[TLS_PAYLOAD_MAX_LEN];
210
211         EXPECT_GE(send(self->fd, buf, send_len, 0), 0);
212         EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
213         EXPECT_EQ(memcmp(buf, recv_mem, send_len), 0);
214 }
215
216 TEST_F(tls, recv_small)
217 {
218         char const *test_str = "test_read";
219         int send_len = 10;
220         char buf[10];
221
222         send_len = strlen(test_str) + 1;
223         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
224         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
225         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
226 }
227
228 TEST_F(tls, msg_more)
229 {
230         char const *test_str = "test_read";
231         int send_len = 10;
232         char buf[10 * 2];
233
234         EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
235         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
236         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
237         EXPECT_EQ(recv(self->cfd, buf, send_len * 2, MSG_WAITALL),
238                   send_len * 2);
239         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
240 }
241
242 TEST_F(tls, sendmsg_single)
243 {
244         struct msghdr msg;
245
246         char const *test_str = "test_sendmsg";
247         size_t send_len = 13;
248         struct iovec vec;
249         char buf[13];
250
251         vec.iov_base = (char *)test_str;
252         vec.iov_len = send_len;
253         memset(&msg, 0, sizeof(struct msghdr));
254         msg.msg_iov = &vec;
255         msg.msg_iovlen = 1;
256         EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
257         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
258         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
259 }
260
261 TEST_F(tls, sendmsg_large)
262 {
263         void *mem = malloc(16384);
264         size_t send_len = 16384;
265         size_t sends = 128;
266         struct msghdr msg;
267         size_t recvs = 0;
268         size_t sent = 0;
269
270         memset(&msg, 0, sizeof(struct msghdr));
271         while (sent++ < sends) {
272                 struct iovec vec = { (void *)mem, send_len };
273
274                 msg.msg_iov = &vec;
275                 msg.msg_iovlen = 1;
276                 EXPECT_EQ(sendmsg(self->cfd, &msg, 0), send_len);
277         }
278
279         while (recvs++ < sends)
280                 EXPECT_NE(recv(self->fd, mem, send_len, 0), -1);
281
282         free(mem);
283 }
284
285 TEST_F(tls, sendmsg_multiple)
286 {
287         char const *test_str = "test_sendmsg_multiple";
288         struct iovec vec[5];
289         char *test_strs[5];
290         struct msghdr msg;
291         int total_len = 0;
292         int len_cmp = 0;
293         int iov_len = 5;
294         char *buf;
295         int i;
296
297         memset(&msg, 0, sizeof(struct msghdr));
298         for (i = 0; i < iov_len; i++) {
299                 test_strs[i] = (char *)malloc(strlen(test_str) + 1);
300                 snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
301                 vec[i].iov_base = (void *)test_strs[i];
302                 vec[i].iov_len = strlen(test_strs[i]) + 1;
303                 total_len += vec[i].iov_len;
304         }
305         msg.msg_iov = vec;
306         msg.msg_iovlen = iov_len;
307
308         EXPECT_EQ(sendmsg(self->cfd, &msg, 0), total_len);
309         buf = malloc(total_len);
310         EXPECT_NE(recv(self->fd, buf, total_len, 0), -1);
311         for (i = 0; i < iov_len; i++) {
312                 EXPECT_EQ(memcmp(test_strs[i], buf + len_cmp,
313                                  strlen(test_strs[i])),
314                           0);
315                 len_cmp += strlen(buf + len_cmp) + 1;
316         }
317         for (i = 0; i < iov_len; i++)
318                 free(test_strs[i]);
319         free(buf);
320 }
321
322 TEST_F(tls, sendmsg_multiple_stress)
323 {
324         char const *test_str = "abcdefghijklmno";
325         struct iovec vec[1024];
326         char *test_strs[1024];
327         int iov_len = 1024;
328         int total_len = 0;
329         char buf[1 << 14];
330         struct msghdr msg;
331         int len_cmp = 0;
332         int i;
333
334         memset(&msg, 0, sizeof(struct msghdr));
335         for (i = 0; i < iov_len; i++) {
336                 test_strs[i] = (char *)malloc(strlen(test_str) + 1);
337                 snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
338                 vec[i].iov_base = (void *)test_strs[i];
339                 vec[i].iov_len = strlen(test_strs[i]) + 1;
340                 total_len += vec[i].iov_len;
341         }
342         msg.msg_iov = vec;
343         msg.msg_iovlen = iov_len;
344
345         EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
346         EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
347
348         for (i = 0; i < iov_len; i++)
349                 len_cmp += strlen(buf + len_cmp) + 1;
350
351         for (i = 0; i < iov_len; i++)
352                 free(test_strs[i]);
353 }
354
355 TEST_F(tls, splice_from_pipe)
356 {
357         int send_len = TLS_PAYLOAD_MAX_LEN;
358         char mem_send[TLS_PAYLOAD_MAX_LEN];
359         char mem_recv[TLS_PAYLOAD_MAX_LEN];
360         int p[2];
361
362         ASSERT_GE(pipe(p), 0);
363         EXPECT_GE(write(p[1], mem_send, send_len), 0);
364         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), 0);
365         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
366         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
367 }
368
369 TEST_F(tls, splice_from_pipe2)
370 {
371         int send_len = 16000;
372         char mem_send[16000];
373         char mem_recv[16000];
374         int p2[2];
375         int p[2];
376
377         ASSERT_GE(pipe(p), 0);
378         ASSERT_GE(pipe(p2), 0);
379         EXPECT_GE(write(p[1], mem_send, 8000), 0);
380         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, 8000, 0), 0);
381         EXPECT_GE(write(p2[1], mem_send + 8000, 8000), 0);
382         EXPECT_GE(splice(p2[0], NULL, self->fd, NULL, 8000, 0), 0);
383         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
384         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
385 }
386
387 TEST_F(tls, send_and_splice)
388 {
389         int send_len = TLS_PAYLOAD_MAX_LEN;
390         char mem_send[TLS_PAYLOAD_MAX_LEN];
391         char mem_recv[TLS_PAYLOAD_MAX_LEN];
392         char const *test_str = "test_read";
393         int send_len2 = 10;
394         char buf[10];
395         int p[2];
396
397         ASSERT_GE(pipe(p), 0);
398         EXPECT_EQ(send(self->fd, test_str, send_len2, 0), send_len2);
399         EXPECT_EQ(recv(self->cfd, buf, send_len2, MSG_WAITALL), send_len2);
400         EXPECT_EQ(memcmp(test_str, buf, send_len2), 0);
401
402         EXPECT_GE(write(p[1], mem_send, send_len), send_len);
403         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), send_len);
404
405         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
406         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
407 }
408
409 TEST_F(tls, splice_to_pipe)
410 {
411         int send_len = TLS_PAYLOAD_MAX_LEN;
412         char mem_send[TLS_PAYLOAD_MAX_LEN];
413         char mem_recv[TLS_PAYLOAD_MAX_LEN];
414         int p[2];
415
416         ASSERT_GE(pipe(p), 0);
417         EXPECT_GE(send(self->fd, mem_send, send_len, 0), 0);
418         EXPECT_GE(splice(self->cfd, NULL, p[1], NULL, send_len, 0), 0);
419         EXPECT_GE(read(p[0], mem_recv, send_len), 0);
420         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
421 }
422
423 TEST_F(tls, recvmsg_single)
424 {
425         char const *test_str = "test_recvmsg_single";
426         int send_len = strlen(test_str) + 1;
427         char buf[20];
428         struct msghdr hdr;
429         struct iovec vec;
430
431         memset(&hdr, 0, sizeof(hdr));
432         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
433         vec.iov_base = (char *)buf;
434         vec.iov_len = send_len;
435         hdr.msg_iovlen = 1;
436         hdr.msg_iov = &vec;
437         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
438         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
439 }
440
441 TEST_F(tls, recvmsg_single_max)
442 {
443         int send_len = TLS_PAYLOAD_MAX_LEN;
444         char send_mem[TLS_PAYLOAD_MAX_LEN];
445         char recv_mem[TLS_PAYLOAD_MAX_LEN];
446         struct iovec vec;
447         struct msghdr hdr;
448
449         EXPECT_EQ(send(self->fd, send_mem, send_len, 0), send_len);
450         vec.iov_base = (char *)recv_mem;
451         vec.iov_len = TLS_PAYLOAD_MAX_LEN;
452
453         hdr.msg_iovlen = 1;
454         hdr.msg_iov = &vec;
455         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
456         EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
457 }
458
459 TEST_F(tls, recvmsg_multiple)
460 {
461         unsigned int msg_iovlen = 1024;
462         unsigned int len_compared = 0;
463         struct iovec vec[1024];
464         char *iov_base[1024];
465         unsigned int iov_len = 16;
466         int send_len = 1 << 14;
467         char buf[1 << 14];
468         struct msghdr hdr;
469         int i;
470
471         EXPECT_EQ(send(self->fd, buf, send_len, 0), send_len);
472         for (i = 0; i < msg_iovlen; i++) {
473                 iov_base[i] = (char *)malloc(iov_len);
474                 vec[i].iov_base = iov_base[i];
475                 vec[i].iov_len = iov_len;
476         }
477
478         hdr.msg_iovlen = msg_iovlen;
479         hdr.msg_iov = vec;
480         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
481         for (i = 0; i < msg_iovlen; i++)
482                 len_compared += iov_len;
483
484         for (i = 0; i < msg_iovlen; i++)
485                 free(iov_base[i]);
486 }
487
488 TEST_F(tls, single_send_multiple_recv)
489 {
490         unsigned int total_len = TLS_PAYLOAD_MAX_LEN * 2;
491         unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
492         char send_mem[TLS_PAYLOAD_MAX_LEN * 2];
493         char recv_mem[TLS_PAYLOAD_MAX_LEN * 2];
494
495         EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
496         memset(recv_mem, 0, total_len);
497
498         EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
499         EXPECT_NE(recv(self->cfd, recv_mem + send_len, send_len, 0), -1);
500         EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
501 }
502
503 TEST_F(tls, multiple_send_single_recv)
504 {
505         unsigned int total_len = 2 * 10;
506         unsigned int send_len = 10;
507         char recv_mem[2 * 10];
508         char send_mem[10];
509
510         EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
511         EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
512         memset(recv_mem, 0, total_len);
513         EXPECT_EQ(recv(self->cfd, recv_mem, total_len, MSG_WAITALL), total_len);
514
515         EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
516         EXPECT_EQ(memcmp(send_mem, recv_mem + send_len, send_len), 0);
517 }
518
519 TEST_F(tls, single_send_multiple_recv_non_align)
520 {
521         const unsigned int total_len = 15;
522         const unsigned int recv_len = 10;
523         char recv_mem[recv_len * 2];
524         char send_mem[total_len];
525
526         EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
527         memset(recv_mem, 0, total_len);
528
529         EXPECT_EQ(recv(self->cfd, recv_mem, recv_len, 0), recv_len);
530         EXPECT_EQ(recv(self->cfd, recv_mem + recv_len, recv_len, 0), 5);
531         EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
532 }
533
534 TEST_F(tls, recv_partial)
535 {
536         char const *test_str = "test_read_partial";
537         char const *test_str_first = "test_read";
538         char const *test_str_second = "_partial";
539         int send_len = strlen(test_str) + 1;
540         char recv_mem[18];
541
542         memset(recv_mem, 0, sizeof(recv_mem));
543         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
544         EXPECT_NE(recv(self->cfd, recv_mem, strlen(test_str_first),
545                        MSG_WAITALL), -1);
546         EXPECT_EQ(memcmp(test_str_first, recv_mem, strlen(test_str_first)), 0);
547         memset(recv_mem, 0, sizeof(recv_mem));
548         EXPECT_NE(recv(self->cfd, recv_mem, strlen(test_str_second),
549                        MSG_WAITALL), -1);
550         EXPECT_EQ(memcmp(test_str_second, recv_mem, strlen(test_str_second)),
551                   0);
552 }
553
554 TEST_F(tls, recv_nonblock)
555 {
556         char buf[4096];
557         bool err;
558
559         EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
560         err = (errno == EAGAIN || errno == EWOULDBLOCK);
561         EXPECT_EQ(err, true);
562 }
563
564 TEST_F(tls, recv_peek)
565 {
566         char const *test_str = "test_read_peek";
567         int send_len = strlen(test_str) + 1;
568         char buf[15];
569
570         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
571         EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
572         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
573         memset(buf, 0, sizeof(buf));
574         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
575         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
576 }
577
578 TEST_F(tls, recv_peek_multiple)
579 {
580         char const *test_str = "test_read_peek";
581         int send_len = strlen(test_str) + 1;
582         unsigned int num_peeks = 100;
583         char buf[15];
584         int i;
585
586         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
587         for (i = 0; i < num_peeks; i++) {
588                 EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
589                 EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
590                 memset(buf, 0, sizeof(buf));
591         }
592         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
593         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
594 }
595
596 TEST_F(tls, recv_peek_multiple_records)
597 {
598         char const *test_str = "test_read_peek_mult_recs";
599         char const *test_str_first = "test_read_peek";
600         char const *test_str_second = "_mult_recs";
601         int len;
602         char buf[64];
603
604         len = strlen(test_str_first);
605         EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
606
607         len = strlen(test_str_second) + 1;
608         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
609
610         len = strlen(test_str_first);
611         memset(buf, 0, len);
612         EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
613
614         /* MSG_PEEK can only peek into the current record. */
615         len = strlen(test_str_first);
616         EXPECT_EQ(memcmp(test_str_first, buf, len), 0);
617
618         len = strlen(test_str) + 1;
619         memset(buf, 0, len);
620         EXPECT_EQ(recv(self->cfd, buf, len, MSG_WAITALL), len);
621
622         /* Non-MSG_PEEK will advance strparser (and therefore record)
623          * however.
624          */
625         len = strlen(test_str) + 1;
626         EXPECT_EQ(memcmp(test_str, buf, len), 0);
627
628         /* MSG_MORE will hold current record open, so later MSG_PEEK
629          * will see everything.
630          */
631         len = strlen(test_str_first);
632         EXPECT_EQ(send(self->fd, test_str_first, len, MSG_MORE), len);
633
634         len = strlen(test_str_second) + 1;
635         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
636
637         len = strlen(test_str) + 1;
638         memset(buf, 0, len);
639         EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
640
641         len = strlen(test_str) + 1;
642         EXPECT_EQ(memcmp(test_str, buf, len), 0);
643 }
644
645 TEST_F(tls, recv_peek_large_buf_mult_recs)
646 {
647         char const *test_str = "test_read_peek_mult_recs";
648         char const *test_str_first = "test_read_peek";
649         char const *test_str_second = "_mult_recs";
650         int len;
651         char buf[64];
652
653         len = strlen(test_str_first);
654         EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
655
656         len = strlen(test_str_second) + 1;
657         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
658
659         len = strlen(test_str) + 1;
660         memset(buf, 0, len);
661         EXPECT_NE((len = recv(self->cfd, buf, len,
662                               MSG_PEEK | MSG_WAITALL)), -1);
663         len = strlen(test_str) + 1;
664         EXPECT_EQ(memcmp(test_str, buf, len), 0);
665 }
666
667 TEST_F(tls, recv_lowat)
668 {
669         char send_mem[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
670         char recv_mem[20];
671         int lowat = 8;
672
673         EXPECT_EQ(send(self->fd, send_mem, 10, 0), 10);
674         EXPECT_EQ(send(self->fd, send_mem, 5, 0), 5);
675
676         memset(recv_mem, 0, 20);
677         EXPECT_EQ(setsockopt(self->cfd, SOL_SOCKET, SO_RCVLOWAT,
678                              &lowat, sizeof(lowat)), 0);
679         EXPECT_EQ(recv(self->cfd, recv_mem, 1, MSG_WAITALL), 1);
680         EXPECT_EQ(recv(self->cfd, recv_mem + 1, 6, MSG_WAITALL), 6);
681         EXPECT_EQ(recv(self->cfd, recv_mem + 7, 10, 0), 8);
682
683         EXPECT_EQ(memcmp(send_mem, recv_mem, 10), 0);
684         EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0);
685 }
686
687 TEST_F(tls, bidir)
688 {
689         struct tls12_crypto_info_aes_gcm_128 tls12;
690         char const *test_str = "test_read";
691         int send_len = 10;
692         char buf[10];
693         int ret;
694
695         memset(&tls12, 0, sizeof(tls12));
696         tls12.info.version = TLS_1_3_VERSION;
697         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_128;
698
699         ret = setsockopt(self->fd, SOL_TLS, TLS_RX, &tls12, sizeof(tls12));
700         ASSERT_EQ(ret, 0);
701
702         ret = setsockopt(self->cfd, SOL_TLS, TLS_TX, &tls12, sizeof(tls12));
703         ASSERT_EQ(ret, 0);
704
705         ASSERT_EQ(strlen(test_str) + 1, send_len);
706
707         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
708         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
709         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
710
711         memset(buf, 0, sizeof(buf));
712
713         EXPECT_EQ(send(self->cfd, test_str, send_len, 0), send_len);
714         EXPECT_NE(recv(self->fd, buf, send_len, 0), -1);
715         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
716 };
717
718 TEST_F(tls, pollin)
719 {
720         char const *test_str = "test_poll";
721         struct pollfd fd = { 0, 0, 0 };
722         char buf[10];
723         int send_len = 10;
724
725         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
726         fd.fd = self->cfd;
727         fd.events = POLLIN;
728
729         EXPECT_EQ(poll(&fd, 1, 20), 1);
730         EXPECT_EQ(fd.revents & POLLIN, 1);
731         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
732         /* Test timing out */
733         EXPECT_EQ(poll(&fd, 1, 20), 0);
734 }
735
736 TEST_F(tls, poll_wait)
737 {
738         char const *test_str = "test_poll_wait";
739         int send_len = strlen(test_str) + 1;
740         struct pollfd fd = { 0, 0, 0 };
741         char recv_mem[15];
742
743         fd.fd = self->cfd;
744         fd.events = POLLIN;
745         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
746         /* Set timeout to inf. secs */
747         EXPECT_EQ(poll(&fd, 1, -1), 1);
748         EXPECT_EQ(fd.revents & POLLIN, 1);
749         EXPECT_EQ(recv(self->cfd, recv_mem, send_len, MSG_WAITALL), send_len);
750 }
751
752 TEST_F(tls, poll_wait_split)
753 {
754         struct pollfd fd = { 0, 0, 0 };
755         char send_mem[20] = {};
756         char recv_mem[15];
757
758         fd.fd = self->cfd;
759         fd.events = POLLIN;
760         /* Send 20 bytes */
761         EXPECT_EQ(send(self->fd, send_mem, sizeof(send_mem), 0),
762                   sizeof(send_mem));
763         /* Poll with inf. timeout */
764         EXPECT_EQ(poll(&fd, 1, -1), 1);
765         EXPECT_EQ(fd.revents & POLLIN, 1);
766         EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), MSG_WAITALL),
767                   sizeof(recv_mem));
768
769         /* Now the remaining 5 bytes of record data are in TLS ULP */
770         fd.fd = self->cfd;
771         fd.events = POLLIN;
772         EXPECT_EQ(poll(&fd, 1, -1), 1);
773         EXPECT_EQ(fd.revents & POLLIN, 1);
774         EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0),
775                   sizeof(send_mem) - sizeof(recv_mem));
776 }
777
778 TEST_F(tls, blocking)
779 {
780         size_t data = 100000;
781         int res = fork();
782
783         EXPECT_NE(res, -1);
784
785         if (res) {
786                 /* parent */
787                 size_t left = data;
788                 char buf[16384];
789                 int status;
790                 int pid2;
791
792                 while (left) {
793                         int res = send(self->fd, buf,
794                                        left > 16384 ? 16384 : left, 0);
795
796                         EXPECT_GE(res, 0);
797                         left -= res;
798                 }
799
800                 pid2 = wait(&status);
801                 EXPECT_EQ(status, 0);
802                 EXPECT_EQ(res, pid2);
803         } else {
804                 /* child */
805                 size_t left = data;
806                 char buf[16384];
807
808                 while (left) {
809                         int res = recv(self->cfd, buf,
810                                        left > 16384 ? 16384 : left, 0);
811
812                         EXPECT_GE(res, 0);
813                         left -= res;
814                 }
815         }
816 }
817
818 TEST_F(tls, nonblocking)
819 {
820         size_t data = 100000;
821         int sendbuf = 100;
822         int flags;
823         int res;
824
825         flags = fcntl(self->fd, F_GETFL, 0);
826         fcntl(self->fd, F_SETFL, flags | O_NONBLOCK);
827         fcntl(self->cfd, F_SETFL, flags | O_NONBLOCK);
828
829         /* Ensure nonblocking behavior by imposing a small send
830          * buffer.
831          */
832         EXPECT_EQ(setsockopt(self->fd, SOL_SOCKET, SO_SNDBUF,
833                              &sendbuf, sizeof(sendbuf)), 0);
834
835         res = fork();
836         EXPECT_NE(res, -1);
837
838         if (res) {
839                 /* parent */
840                 bool eagain = false;
841                 size_t left = data;
842                 char buf[16384];
843                 int status;
844                 int pid2;
845
846                 while (left) {
847                         int res = send(self->fd, buf,
848                                        left > 16384 ? 16384 : left, 0);
849
850                         if (res == -1 && errno == EAGAIN) {
851                                 eagain = true;
852                                 usleep(10000);
853                                 continue;
854                         }
855                         EXPECT_GE(res, 0);
856                         left -= res;
857                 }
858
859                 EXPECT_TRUE(eagain);
860                 pid2 = wait(&status);
861
862                 EXPECT_EQ(status, 0);
863                 EXPECT_EQ(res, pid2);
864         } else {
865                 /* child */
866                 bool eagain = false;
867                 size_t left = data;
868                 char buf[16384];
869
870                 while (left) {
871                         int res = recv(self->cfd, buf,
872                                        left > 16384 ? 16384 : left, 0);
873
874                         if (res == -1 && errno == EAGAIN) {
875                                 eagain = true;
876                                 usleep(10000);
877                                 continue;
878                         }
879                         EXPECT_GE(res, 0);
880                         left -= res;
881                 }
882                 EXPECT_TRUE(eagain);
883         }
884 }
885
886 TEST_F(tls, control_msg)
887 {
888         if (self->notls)
889                 return;
890
891         char cbuf[CMSG_SPACE(sizeof(char))];
892         char const *test_str = "test_read";
893         int cmsg_len = sizeof(char);
894         char record_type = 100;
895         struct cmsghdr *cmsg;
896         struct msghdr msg;
897         int send_len = 10;
898         struct iovec vec;
899         char buf[10];
900
901         vec.iov_base = (char *)test_str;
902         vec.iov_len = 10;
903         memset(&msg, 0, sizeof(struct msghdr));
904         msg.msg_iov = &vec;
905         msg.msg_iovlen = 1;
906         msg.msg_control = cbuf;
907         msg.msg_controllen = sizeof(cbuf);
908         cmsg = CMSG_FIRSTHDR(&msg);
909         cmsg->cmsg_level = SOL_TLS;
910         /* test sending non-record types. */
911         cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
912         cmsg->cmsg_len = CMSG_LEN(cmsg_len);
913         *CMSG_DATA(cmsg) = record_type;
914         msg.msg_controllen = cmsg->cmsg_len;
915
916         EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
917         /* Should fail because we didn't provide a control message */
918         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
919
920         vec.iov_base = buf;
921         EXPECT_EQ(recvmsg(self->cfd, &msg, MSG_WAITALL | MSG_PEEK), send_len);
922
923         cmsg = CMSG_FIRSTHDR(&msg);
924         EXPECT_NE(cmsg, NULL);
925         EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
926         EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
927         record_type = *((unsigned char *)CMSG_DATA(cmsg));
928         EXPECT_EQ(record_type, 100);
929         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
930
931         /* Recv the message again without MSG_PEEK */
932         record_type = 0;
933         memset(buf, 0, sizeof(buf));
934
935         EXPECT_EQ(recvmsg(self->cfd, &msg, MSG_WAITALL), send_len);
936         cmsg = CMSG_FIRSTHDR(&msg);
937         EXPECT_NE(cmsg, NULL);
938         EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
939         EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
940         record_type = *((unsigned char *)CMSG_DATA(cmsg));
941         EXPECT_EQ(record_type, 100);
942         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
943 }
944
945 TEST(non_established) {
946         struct tls12_crypto_info_aes_gcm_256 tls12;
947         struct sockaddr_in addr;
948         int sfd, ret, fd;
949         socklen_t len;
950
951         len = sizeof(addr);
952
953         memset(&tls12, 0, sizeof(tls12));
954         tls12.info.version = TLS_1_2_VERSION;
955         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
956
957         addr.sin_family = AF_INET;
958         addr.sin_addr.s_addr = htonl(INADDR_ANY);
959         addr.sin_port = 0;
960
961         fd = socket(AF_INET, SOCK_STREAM, 0);
962         sfd = socket(AF_INET, SOCK_STREAM, 0);
963
964         ret = bind(sfd, &addr, sizeof(addr));
965         ASSERT_EQ(ret, 0);
966         ret = listen(sfd, 10);
967         ASSERT_EQ(ret, 0);
968
969         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
970         EXPECT_EQ(ret, -1);
971         /* TLS ULP not supported */
972         if (errno == ENOENT)
973                 return;
974         EXPECT_EQ(errno, ENOTSUPP);
975
976         ret = setsockopt(sfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
977         EXPECT_EQ(ret, -1);
978         EXPECT_EQ(errno, ENOTSUPP);
979
980         ret = getsockname(sfd, &addr, &len);
981         ASSERT_EQ(ret, 0);
982
983         ret = connect(fd, &addr, sizeof(addr));
984         ASSERT_EQ(ret, 0);
985
986         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
987         ASSERT_EQ(ret, 0);
988
989         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
990         EXPECT_EQ(ret, -1);
991         EXPECT_EQ(errno, EEXIST);
992
993         close(fd);
994         close(sfd);
995 }
996
997 TEST(keysizes) {
998         struct tls12_crypto_info_aes_gcm_256 tls12;
999         struct sockaddr_in addr;
1000         int sfd, ret, fd, cfd;
1001         socklen_t len;
1002         bool notls;
1003
1004         notls = false;
1005         len = sizeof(addr);
1006
1007         memset(&tls12, 0, sizeof(tls12));
1008         tls12.info.version = TLS_1_2_VERSION;
1009         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
1010
1011         addr.sin_family = AF_INET;
1012         addr.sin_addr.s_addr = htonl(INADDR_ANY);
1013         addr.sin_port = 0;
1014
1015         fd = socket(AF_INET, SOCK_STREAM, 0);
1016         sfd = socket(AF_INET, SOCK_STREAM, 0);
1017
1018         ret = bind(sfd, &addr, sizeof(addr));
1019         ASSERT_EQ(ret, 0);
1020         ret = listen(sfd, 10);
1021         ASSERT_EQ(ret, 0);
1022
1023         ret = getsockname(sfd, &addr, &len);
1024         ASSERT_EQ(ret, 0);
1025
1026         ret = connect(fd, &addr, sizeof(addr));
1027         ASSERT_EQ(ret, 0);
1028
1029         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1030         if (ret != 0) {
1031                 notls = true;
1032                 printf("Failure setting TCP_ULP, testing without tls\n");
1033         }
1034
1035         if (!notls) {
1036                 ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
1037                                  sizeof(tls12));
1038                 EXPECT_EQ(ret, 0);
1039         }
1040
1041         cfd = accept(sfd, &addr, &len);
1042         ASSERT_GE(cfd, 0);
1043
1044         if (!notls) {
1045                 ret = setsockopt(cfd, IPPROTO_TCP, TCP_ULP, "tls",
1046                                  sizeof("tls"));
1047                 EXPECT_EQ(ret, 0);
1048
1049                 ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
1050                                  sizeof(tls12));
1051                 EXPECT_EQ(ret, 0);
1052         }
1053
1054         close(sfd);
1055         close(fd);
1056         close(cfd);
1057 }
1058
1059 TEST(tls12) {
1060         int fd, cfd;
1061         bool notls;
1062
1063         struct tls12_crypto_info_aes_gcm_128 tls12;
1064         struct sockaddr_in addr;
1065         socklen_t len;
1066         int sfd, ret;
1067
1068         notls = false;
1069         len = sizeof(addr);
1070
1071         memset(&tls12, 0, sizeof(tls12));
1072         tls12.info.version = TLS_1_2_VERSION;
1073         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_128;
1074
1075         addr.sin_family = AF_INET;
1076         addr.sin_addr.s_addr = htonl(INADDR_ANY);
1077         addr.sin_port = 0;
1078
1079         fd = socket(AF_INET, SOCK_STREAM, 0);
1080         sfd = socket(AF_INET, SOCK_STREAM, 0);
1081
1082         ret = bind(sfd, &addr, sizeof(addr));
1083         ASSERT_EQ(ret, 0);
1084         ret = listen(sfd, 10);
1085         ASSERT_EQ(ret, 0);
1086
1087         ret = getsockname(sfd, &addr, &len);
1088         ASSERT_EQ(ret, 0);
1089
1090         ret = connect(fd, &addr, sizeof(addr));
1091         ASSERT_EQ(ret, 0);
1092
1093         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1094         if (ret != 0) {
1095                 notls = true;
1096                 printf("Failure setting TCP_ULP, testing without tls\n");
1097         }
1098
1099         if (!notls) {
1100                 ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
1101                                  sizeof(tls12));
1102                 ASSERT_EQ(ret, 0);
1103         }
1104
1105         cfd = accept(sfd, &addr, &len);
1106         ASSERT_GE(cfd, 0);
1107
1108         if (!notls) {
1109                 ret = setsockopt(cfd, IPPROTO_TCP, TCP_ULP, "tls",
1110                                  sizeof("tls"));
1111                 ASSERT_EQ(ret, 0);
1112
1113                 ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
1114                                  sizeof(tls12));
1115                 ASSERT_EQ(ret, 0);
1116         }
1117
1118         close(sfd);
1119
1120         char const *test_str = "test_read";
1121         int send_len = 10;
1122         char buf[10];
1123
1124         send_len = strlen(test_str) + 1;
1125         EXPECT_EQ(send(fd, test_str, send_len, 0), send_len);
1126         EXPECT_NE(recv(cfd, buf, send_len, 0), -1);
1127         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1128
1129         close(fd);
1130         close(cfd);
1131 }
1132
1133 TEST_HARNESS_MAIN