]> asedeno.scripts.mit.edu Git - linux.git/blob - net/netfilter/nft_ct.c
Merge tag 'nfsd-5.2-2' of git://linux-nfs.org/~bfields/linux
[linux.git] / net / netfilter / nft_ct.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (c) 2008-2009 Patrick McHardy <kaber@trash.net>
4  * Copyright (c) 2016 Pablo Neira Ayuso <pablo@netfilter.org>
5  *
6  * Development of this code funded by Astaro AG (http://www.astaro.com/)
7  */
8
9 #include <linux/kernel.h>
10 #include <linux/init.h>
11 #include <linux/module.h>
12 #include <linux/netlink.h>
13 #include <linux/netfilter.h>
14 #include <linux/netfilter/nf_tables.h>
15 #include <net/netfilter/nf_tables.h>
16 #include <net/netfilter/nf_conntrack.h>
17 #include <net/netfilter/nf_conntrack_acct.h>
18 #include <net/netfilter/nf_conntrack_tuple.h>
19 #include <net/netfilter/nf_conntrack_helper.h>
20 #include <net/netfilter/nf_conntrack_ecache.h>
21 #include <net/netfilter/nf_conntrack_labels.h>
22 #include <net/netfilter/nf_conntrack_timeout.h>
23 #include <net/netfilter/nf_conntrack_l4proto.h>
24
25 struct nft_ct {
26         enum nft_ct_keys        key:8;
27         enum ip_conntrack_dir   dir:8;
28         union {
29                 enum nft_registers      dreg:8;
30                 enum nft_registers      sreg:8;
31         };
32 };
33
34 struct nft_ct_helper_obj  {
35         struct nf_conntrack_helper *helper4;
36         struct nf_conntrack_helper *helper6;
37         u8 l4proto;
38 };
39
40 #ifdef CONFIG_NF_CONNTRACK_ZONES
41 static DEFINE_PER_CPU(struct nf_conn *, nft_ct_pcpu_template);
42 static unsigned int nft_ct_pcpu_template_refcnt __read_mostly;
43 #endif
44
45 static u64 nft_ct_get_eval_counter(const struct nf_conn_counter *c,
46                                    enum nft_ct_keys k,
47                                    enum ip_conntrack_dir d)
48 {
49         if (d < IP_CT_DIR_MAX)
50                 return k == NFT_CT_BYTES ? atomic64_read(&c[d].bytes) :
51                                            atomic64_read(&c[d].packets);
52
53         return nft_ct_get_eval_counter(c, k, IP_CT_DIR_ORIGINAL) +
54                nft_ct_get_eval_counter(c, k, IP_CT_DIR_REPLY);
55 }
56
57 static void nft_ct_get_eval(const struct nft_expr *expr,
58                             struct nft_regs *regs,
59                             const struct nft_pktinfo *pkt)
60 {
61         const struct nft_ct *priv = nft_expr_priv(expr);
62         u32 *dest = &regs->data[priv->dreg];
63         enum ip_conntrack_info ctinfo;
64         const struct nf_conn *ct;
65         const struct nf_conn_help *help;
66         const struct nf_conntrack_tuple *tuple;
67         const struct nf_conntrack_helper *helper;
68         unsigned int state;
69
70         ct = nf_ct_get(pkt->skb, &ctinfo);
71
72         switch (priv->key) {
73         case NFT_CT_STATE:
74                 if (ct)
75                         state = NF_CT_STATE_BIT(ctinfo);
76                 else if (ctinfo == IP_CT_UNTRACKED)
77                         state = NF_CT_STATE_UNTRACKED_BIT;
78                 else
79                         state = NF_CT_STATE_INVALID_BIT;
80                 *dest = state;
81                 return;
82         default:
83                 break;
84         }
85
86         if (ct == NULL)
87                 goto err;
88
89         switch (priv->key) {
90         case NFT_CT_DIRECTION:
91                 nft_reg_store8(dest, CTINFO2DIR(ctinfo));
92                 return;
93         case NFT_CT_STATUS:
94                 *dest = ct->status;
95                 return;
96 #ifdef CONFIG_NF_CONNTRACK_MARK
97         case NFT_CT_MARK:
98                 *dest = ct->mark;
99                 return;
100 #endif
101 #ifdef CONFIG_NF_CONNTRACK_SECMARK
102         case NFT_CT_SECMARK:
103                 *dest = ct->secmark;
104                 return;
105 #endif
106         case NFT_CT_EXPIRATION:
107                 *dest = jiffies_to_msecs(nf_ct_expires(ct));
108                 return;
109         case NFT_CT_HELPER:
110                 if (ct->master == NULL)
111                         goto err;
112                 help = nfct_help(ct->master);
113                 if (help == NULL)
114                         goto err;
115                 helper = rcu_dereference(help->helper);
116                 if (helper == NULL)
117                         goto err;
118                 strncpy((char *)dest, helper->name, NF_CT_HELPER_NAME_LEN);
119                 return;
120 #ifdef CONFIG_NF_CONNTRACK_LABELS
121         case NFT_CT_LABELS: {
122                 struct nf_conn_labels *labels = nf_ct_labels_find(ct);
123
124                 if (labels)
125                         memcpy(dest, labels->bits, NF_CT_LABELS_MAX_SIZE);
126                 else
127                         memset(dest, 0, NF_CT_LABELS_MAX_SIZE);
128                 return;
129         }
130 #endif
131         case NFT_CT_BYTES: /* fallthrough */
132         case NFT_CT_PKTS: {
133                 const struct nf_conn_acct *acct = nf_conn_acct_find(ct);
134                 u64 count = 0;
135
136                 if (acct)
137                         count = nft_ct_get_eval_counter(acct->counter,
138                                                         priv->key, priv->dir);
139                 memcpy(dest, &count, sizeof(count));
140                 return;
141         }
142         case NFT_CT_AVGPKT: {
143                 const struct nf_conn_acct *acct = nf_conn_acct_find(ct);
144                 u64 avgcnt = 0, bcnt = 0, pcnt = 0;
145
146                 if (acct) {
147                         pcnt = nft_ct_get_eval_counter(acct->counter,
148                                                        NFT_CT_PKTS, priv->dir);
149                         bcnt = nft_ct_get_eval_counter(acct->counter,
150                                                        NFT_CT_BYTES, priv->dir);
151                         if (pcnt != 0)
152                                 avgcnt = div64_u64(bcnt, pcnt);
153                 }
154
155                 memcpy(dest, &avgcnt, sizeof(avgcnt));
156                 return;
157         }
158         case NFT_CT_L3PROTOCOL:
159                 nft_reg_store8(dest, nf_ct_l3num(ct));
160                 return;
161         case NFT_CT_PROTOCOL:
162                 nft_reg_store8(dest, nf_ct_protonum(ct));
163                 return;
164 #ifdef CONFIG_NF_CONNTRACK_ZONES
165         case NFT_CT_ZONE: {
166                 const struct nf_conntrack_zone *zone = nf_ct_zone(ct);
167                 u16 zoneid;
168
169                 if (priv->dir < IP_CT_DIR_MAX)
170                         zoneid = nf_ct_zone_id(zone, priv->dir);
171                 else
172                         zoneid = zone->id;
173
174                 nft_reg_store16(dest, zoneid);
175                 return;
176         }
177 #endif
178         case NFT_CT_ID:
179                 if (!nf_ct_is_confirmed(ct))
180                         goto err;
181                 *dest = nf_ct_get_id(ct);
182                 return;
183         default:
184                 break;
185         }
186
187         tuple = &ct->tuplehash[priv->dir].tuple;
188         switch (priv->key) {
189         case NFT_CT_SRC:
190                 memcpy(dest, tuple->src.u3.all,
191                        nf_ct_l3num(ct) == NFPROTO_IPV4 ? 4 : 16);
192                 return;
193         case NFT_CT_DST:
194                 memcpy(dest, tuple->dst.u3.all,
195                        nf_ct_l3num(ct) == NFPROTO_IPV4 ? 4 : 16);
196                 return;
197         case NFT_CT_PROTO_SRC:
198                 nft_reg_store16(dest, (__force u16)tuple->src.u.all);
199                 return;
200         case NFT_CT_PROTO_DST:
201                 nft_reg_store16(dest, (__force u16)tuple->dst.u.all);
202                 return;
203         case NFT_CT_SRC_IP:
204                 if (nf_ct_l3num(ct) != NFPROTO_IPV4)
205                         goto err;
206                 *dest = tuple->src.u3.ip;
207                 return;
208         case NFT_CT_DST_IP:
209                 if (nf_ct_l3num(ct) != NFPROTO_IPV4)
210                         goto err;
211                 *dest = tuple->dst.u3.ip;
212                 return;
213         case NFT_CT_SRC_IP6:
214                 if (nf_ct_l3num(ct) != NFPROTO_IPV6)
215                         goto err;
216                 memcpy(dest, tuple->src.u3.ip6, sizeof(struct in6_addr));
217                 return;
218         case NFT_CT_DST_IP6:
219                 if (nf_ct_l3num(ct) != NFPROTO_IPV6)
220                         goto err;
221                 memcpy(dest, tuple->dst.u3.ip6, sizeof(struct in6_addr));
222                 return;
223         default:
224                 break;
225         }
226         return;
227 err:
228         regs->verdict.code = NFT_BREAK;
229 }
230
231 #ifdef CONFIG_NF_CONNTRACK_ZONES
232 static void nft_ct_set_zone_eval(const struct nft_expr *expr,
233                                  struct nft_regs *regs,
234                                  const struct nft_pktinfo *pkt)
235 {
236         struct nf_conntrack_zone zone = { .dir = NF_CT_DEFAULT_ZONE_DIR };
237         const struct nft_ct *priv = nft_expr_priv(expr);
238         struct sk_buff *skb = pkt->skb;
239         enum ip_conntrack_info ctinfo;
240         u16 value = nft_reg_load16(&regs->data[priv->sreg]);
241         struct nf_conn *ct;
242
243         ct = nf_ct_get(skb, &ctinfo);
244         if (ct) /* already tracked */
245                 return;
246
247         zone.id = value;
248
249         switch (priv->dir) {
250         case IP_CT_DIR_ORIGINAL:
251                 zone.dir = NF_CT_ZONE_DIR_ORIG;
252                 break;
253         case IP_CT_DIR_REPLY:
254                 zone.dir = NF_CT_ZONE_DIR_REPL;
255                 break;
256         default:
257                 break;
258         }
259
260         ct = this_cpu_read(nft_ct_pcpu_template);
261
262         if (likely(atomic_read(&ct->ct_general.use) == 1)) {
263                 nf_ct_zone_add(ct, &zone);
264         } else {
265                 /* previous skb got queued to userspace */
266                 ct = nf_ct_tmpl_alloc(nft_net(pkt), &zone, GFP_ATOMIC);
267                 if (!ct) {
268                         regs->verdict.code = NF_DROP;
269                         return;
270                 }
271         }
272
273         atomic_inc(&ct->ct_general.use);
274         nf_ct_set(skb, ct, IP_CT_NEW);
275 }
276 #endif
277
278 static void nft_ct_set_eval(const struct nft_expr *expr,
279                             struct nft_regs *regs,
280                             const struct nft_pktinfo *pkt)
281 {
282         const struct nft_ct *priv = nft_expr_priv(expr);
283         struct sk_buff *skb = pkt->skb;
284 #if defined(CONFIG_NF_CONNTRACK_MARK) || defined(CONFIG_NF_CONNTRACK_SECMARK)
285         u32 value = regs->data[priv->sreg];
286 #endif
287         enum ip_conntrack_info ctinfo;
288         struct nf_conn *ct;
289
290         ct = nf_ct_get(skb, &ctinfo);
291         if (ct == NULL || nf_ct_is_template(ct))
292                 return;
293
294         switch (priv->key) {
295 #ifdef CONFIG_NF_CONNTRACK_MARK
296         case NFT_CT_MARK:
297                 if (ct->mark != value) {
298                         ct->mark = value;
299                         nf_conntrack_event_cache(IPCT_MARK, ct);
300                 }
301                 break;
302 #endif
303 #ifdef CONFIG_NF_CONNTRACK_SECMARK
304         case NFT_CT_SECMARK:
305                 if (ct->secmark != value) {
306                         ct->secmark = value;
307                         nf_conntrack_event_cache(IPCT_SECMARK, ct);
308                 }
309                 break;
310 #endif
311 #ifdef CONFIG_NF_CONNTRACK_LABELS
312         case NFT_CT_LABELS:
313                 nf_connlabels_replace(ct,
314                                       &regs->data[priv->sreg],
315                                       &regs->data[priv->sreg],
316                                       NF_CT_LABELS_MAX_SIZE / sizeof(u32));
317                 break;
318 #endif
319 #ifdef CONFIG_NF_CONNTRACK_EVENTS
320         case NFT_CT_EVENTMASK: {
321                 struct nf_conntrack_ecache *e = nf_ct_ecache_find(ct);
322                 u32 ctmask = regs->data[priv->sreg];
323
324                 if (e) {
325                         if (e->ctmask != ctmask)
326                                 e->ctmask = ctmask;
327                         break;
328                 }
329
330                 if (ctmask && !nf_ct_is_confirmed(ct))
331                         nf_ct_ecache_ext_add(ct, ctmask, 0, GFP_ATOMIC);
332                 break;
333         }
334 #endif
335         default:
336                 break;
337         }
338 }
339
340 static const struct nla_policy nft_ct_policy[NFTA_CT_MAX + 1] = {
341         [NFTA_CT_DREG]          = { .type = NLA_U32 },
342         [NFTA_CT_KEY]           = { .type = NLA_U32 },
343         [NFTA_CT_DIRECTION]     = { .type = NLA_U8 },
344         [NFTA_CT_SREG]          = { .type = NLA_U32 },
345 };
346
347 #ifdef CONFIG_NF_CONNTRACK_ZONES
348 static void nft_ct_tmpl_put_pcpu(void)
349 {
350         struct nf_conn *ct;
351         int cpu;
352
353         for_each_possible_cpu(cpu) {
354                 ct = per_cpu(nft_ct_pcpu_template, cpu);
355                 if (!ct)
356                         break;
357                 nf_ct_put(ct);
358                 per_cpu(nft_ct_pcpu_template, cpu) = NULL;
359         }
360 }
361
362 static bool nft_ct_tmpl_alloc_pcpu(void)
363 {
364         struct nf_conntrack_zone zone = { .id = 0 };
365         struct nf_conn *tmp;
366         int cpu;
367
368         if (nft_ct_pcpu_template_refcnt)
369                 return true;
370
371         for_each_possible_cpu(cpu) {
372                 tmp = nf_ct_tmpl_alloc(&init_net, &zone, GFP_KERNEL);
373                 if (!tmp) {
374                         nft_ct_tmpl_put_pcpu();
375                         return false;
376                 }
377
378                 atomic_set(&tmp->ct_general.use, 1);
379                 per_cpu(nft_ct_pcpu_template, cpu) = tmp;
380         }
381
382         return true;
383 }
384 #endif
385
386 static int nft_ct_get_init(const struct nft_ctx *ctx,
387                            const struct nft_expr *expr,
388                            const struct nlattr * const tb[])
389 {
390         struct nft_ct *priv = nft_expr_priv(expr);
391         unsigned int len;
392         int err;
393
394         priv->key = ntohl(nla_get_be32(tb[NFTA_CT_KEY]));
395         priv->dir = IP_CT_DIR_MAX;
396         switch (priv->key) {
397         case NFT_CT_DIRECTION:
398                 if (tb[NFTA_CT_DIRECTION] != NULL)
399                         return -EINVAL;
400                 len = sizeof(u8);
401                 break;
402         case NFT_CT_STATE:
403         case NFT_CT_STATUS:
404 #ifdef CONFIG_NF_CONNTRACK_MARK
405         case NFT_CT_MARK:
406 #endif
407 #ifdef CONFIG_NF_CONNTRACK_SECMARK
408         case NFT_CT_SECMARK:
409 #endif
410         case NFT_CT_EXPIRATION:
411                 if (tb[NFTA_CT_DIRECTION] != NULL)
412                         return -EINVAL;
413                 len = sizeof(u32);
414                 break;
415 #ifdef CONFIG_NF_CONNTRACK_LABELS
416         case NFT_CT_LABELS:
417                 if (tb[NFTA_CT_DIRECTION] != NULL)
418                         return -EINVAL;
419                 len = NF_CT_LABELS_MAX_SIZE;
420                 break;
421 #endif
422         case NFT_CT_HELPER:
423                 if (tb[NFTA_CT_DIRECTION] != NULL)
424                         return -EINVAL;
425                 len = NF_CT_HELPER_NAME_LEN;
426                 break;
427
428         case NFT_CT_L3PROTOCOL:
429         case NFT_CT_PROTOCOL:
430                 /* For compatibility, do not report error if NFTA_CT_DIRECTION
431                  * attribute is specified.
432                  */
433                 len = sizeof(u8);
434                 break;
435         case NFT_CT_SRC:
436         case NFT_CT_DST:
437                 if (tb[NFTA_CT_DIRECTION] == NULL)
438                         return -EINVAL;
439
440                 switch (ctx->family) {
441                 case NFPROTO_IPV4:
442                         len = FIELD_SIZEOF(struct nf_conntrack_tuple,
443                                            src.u3.ip);
444                         break;
445                 case NFPROTO_IPV6:
446                 case NFPROTO_INET:
447                         len = FIELD_SIZEOF(struct nf_conntrack_tuple,
448                                            src.u3.ip6);
449                         break;
450                 default:
451                         return -EAFNOSUPPORT;
452                 }
453                 break;
454         case NFT_CT_SRC_IP:
455         case NFT_CT_DST_IP:
456                 if (tb[NFTA_CT_DIRECTION] == NULL)
457                         return -EINVAL;
458
459                 len = FIELD_SIZEOF(struct nf_conntrack_tuple, src.u3.ip);
460                 break;
461         case NFT_CT_SRC_IP6:
462         case NFT_CT_DST_IP6:
463                 if (tb[NFTA_CT_DIRECTION] == NULL)
464                         return -EINVAL;
465
466                 len = FIELD_SIZEOF(struct nf_conntrack_tuple, src.u3.ip6);
467                 break;
468         case NFT_CT_PROTO_SRC:
469         case NFT_CT_PROTO_DST:
470                 if (tb[NFTA_CT_DIRECTION] == NULL)
471                         return -EINVAL;
472                 len = FIELD_SIZEOF(struct nf_conntrack_tuple, src.u.all);
473                 break;
474         case NFT_CT_BYTES:
475         case NFT_CT_PKTS:
476         case NFT_CT_AVGPKT:
477                 len = sizeof(u64);
478                 break;
479 #ifdef CONFIG_NF_CONNTRACK_ZONES
480         case NFT_CT_ZONE:
481                 len = sizeof(u16);
482                 break;
483 #endif
484         case NFT_CT_ID:
485                 len = sizeof(u32);
486                 break;
487         default:
488                 return -EOPNOTSUPP;
489         }
490
491         if (tb[NFTA_CT_DIRECTION] != NULL) {
492                 priv->dir = nla_get_u8(tb[NFTA_CT_DIRECTION]);
493                 switch (priv->dir) {
494                 case IP_CT_DIR_ORIGINAL:
495                 case IP_CT_DIR_REPLY:
496                         break;
497                 default:
498                         return -EINVAL;
499                 }
500         }
501
502         priv->dreg = nft_parse_register(tb[NFTA_CT_DREG]);
503         err = nft_validate_register_store(ctx, priv->dreg, NULL,
504                                           NFT_DATA_VALUE, len);
505         if (err < 0)
506                 return err;
507
508         err = nf_ct_netns_get(ctx->net, ctx->family);
509         if (err < 0)
510                 return err;
511
512         if (priv->key == NFT_CT_BYTES ||
513             priv->key == NFT_CT_PKTS  ||
514             priv->key == NFT_CT_AVGPKT)
515                 nf_ct_set_acct(ctx->net, true);
516
517         return 0;
518 }
519
520 static void __nft_ct_set_destroy(const struct nft_ctx *ctx, struct nft_ct *priv)
521 {
522         switch (priv->key) {
523 #ifdef CONFIG_NF_CONNTRACK_LABELS
524         case NFT_CT_LABELS:
525                 nf_connlabels_put(ctx->net);
526                 break;
527 #endif
528 #ifdef CONFIG_NF_CONNTRACK_ZONES
529         case NFT_CT_ZONE:
530                 if (--nft_ct_pcpu_template_refcnt == 0)
531                         nft_ct_tmpl_put_pcpu();
532 #endif
533         default:
534                 break;
535         }
536 }
537
538 static int nft_ct_set_init(const struct nft_ctx *ctx,
539                            const struct nft_expr *expr,
540                            const struct nlattr * const tb[])
541 {
542         struct nft_ct *priv = nft_expr_priv(expr);
543         unsigned int len;
544         int err;
545
546         priv->dir = IP_CT_DIR_MAX;
547         priv->key = ntohl(nla_get_be32(tb[NFTA_CT_KEY]));
548         switch (priv->key) {
549 #ifdef CONFIG_NF_CONNTRACK_MARK
550         case NFT_CT_MARK:
551                 if (tb[NFTA_CT_DIRECTION])
552                         return -EINVAL;
553                 len = FIELD_SIZEOF(struct nf_conn, mark);
554                 break;
555 #endif
556 #ifdef CONFIG_NF_CONNTRACK_LABELS
557         case NFT_CT_LABELS:
558                 if (tb[NFTA_CT_DIRECTION])
559                         return -EINVAL;
560                 len = NF_CT_LABELS_MAX_SIZE;
561                 err = nf_connlabels_get(ctx->net, (len * BITS_PER_BYTE) - 1);
562                 if (err)
563                         return err;
564                 break;
565 #endif
566 #ifdef CONFIG_NF_CONNTRACK_ZONES
567         case NFT_CT_ZONE:
568                 if (!nft_ct_tmpl_alloc_pcpu())
569                         return -ENOMEM;
570                 nft_ct_pcpu_template_refcnt++;
571                 len = sizeof(u16);
572                 break;
573 #endif
574 #ifdef CONFIG_NF_CONNTRACK_EVENTS
575         case NFT_CT_EVENTMASK:
576                 if (tb[NFTA_CT_DIRECTION])
577                         return -EINVAL;
578                 len = sizeof(u32);
579                 break;
580 #endif
581 #ifdef CONFIG_NF_CONNTRACK_SECMARK
582         case NFT_CT_SECMARK:
583                 if (tb[NFTA_CT_DIRECTION])
584                         return -EINVAL;
585                 len = sizeof(u32);
586                 break;
587 #endif
588         default:
589                 return -EOPNOTSUPP;
590         }
591
592         if (tb[NFTA_CT_DIRECTION]) {
593                 priv->dir = nla_get_u8(tb[NFTA_CT_DIRECTION]);
594                 switch (priv->dir) {
595                 case IP_CT_DIR_ORIGINAL:
596                 case IP_CT_DIR_REPLY:
597                         break;
598                 default:
599                         err = -EINVAL;
600                         goto err1;
601                 }
602         }
603
604         priv->sreg = nft_parse_register(tb[NFTA_CT_SREG]);
605         err = nft_validate_register_load(priv->sreg, len);
606         if (err < 0)
607                 goto err1;
608
609         err = nf_ct_netns_get(ctx->net, ctx->family);
610         if (err < 0)
611                 goto err1;
612
613         return 0;
614
615 err1:
616         __nft_ct_set_destroy(ctx, priv);
617         return err;
618 }
619
620 static void nft_ct_get_destroy(const struct nft_ctx *ctx,
621                                const struct nft_expr *expr)
622 {
623         nf_ct_netns_put(ctx->net, ctx->family);
624 }
625
626 static void nft_ct_set_destroy(const struct nft_ctx *ctx,
627                                const struct nft_expr *expr)
628 {
629         struct nft_ct *priv = nft_expr_priv(expr);
630
631         __nft_ct_set_destroy(ctx, priv);
632         nf_ct_netns_put(ctx->net, ctx->family);
633 }
634
635 static int nft_ct_get_dump(struct sk_buff *skb, const struct nft_expr *expr)
636 {
637         const struct nft_ct *priv = nft_expr_priv(expr);
638
639         if (nft_dump_register(skb, NFTA_CT_DREG, priv->dreg))
640                 goto nla_put_failure;
641         if (nla_put_be32(skb, NFTA_CT_KEY, htonl(priv->key)))
642                 goto nla_put_failure;
643
644         switch (priv->key) {
645         case NFT_CT_SRC:
646         case NFT_CT_DST:
647         case NFT_CT_SRC_IP:
648         case NFT_CT_DST_IP:
649         case NFT_CT_SRC_IP6:
650         case NFT_CT_DST_IP6:
651         case NFT_CT_PROTO_SRC:
652         case NFT_CT_PROTO_DST:
653                 if (nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
654                         goto nla_put_failure;
655                 break;
656         case NFT_CT_BYTES:
657         case NFT_CT_PKTS:
658         case NFT_CT_AVGPKT:
659         case NFT_CT_ZONE:
660                 if (priv->dir < IP_CT_DIR_MAX &&
661                     nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
662                         goto nla_put_failure;
663                 break;
664         default:
665                 break;
666         }
667
668         return 0;
669
670 nla_put_failure:
671         return -1;
672 }
673
674 static int nft_ct_set_dump(struct sk_buff *skb, const struct nft_expr *expr)
675 {
676         const struct nft_ct *priv = nft_expr_priv(expr);
677
678         if (nft_dump_register(skb, NFTA_CT_SREG, priv->sreg))
679                 goto nla_put_failure;
680         if (nla_put_be32(skb, NFTA_CT_KEY, htonl(priv->key)))
681                 goto nla_put_failure;
682
683         switch (priv->key) {
684         case NFT_CT_ZONE:
685                 if (priv->dir < IP_CT_DIR_MAX &&
686                     nla_put_u8(skb, NFTA_CT_DIRECTION, priv->dir))
687                         goto nla_put_failure;
688                 break;
689         default:
690                 break;
691         }
692
693         return 0;
694
695 nla_put_failure:
696         return -1;
697 }
698
699 static struct nft_expr_type nft_ct_type;
700 static const struct nft_expr_ops nft_ct_get_ops = {
701         .type           = &nft_ct_type,
702         .size           = NFT_EXPR_SIZE(sizeof(struct nft_ct)),
703         .eval           = nft_ct_get_eval,
704         .init           = nft_ct_get_init,
705         .destroy        = nft_ct_get_destroy,
706         .dump           = nft_ct_get_dump,
707 };
708
709 static const struct nft_expr_ops nft_ct_set_ops = {
710         .type           = &nft_ct_type,
711         .size           = NFT_EXPR_SIZE(sizeof(struct nft_ct)),
712         .eval           = nft_ct_set_eval,
713         .init           = nft_ct_set_init,
714         .destroy        = nft_ct_set_destroy,
715         .dump           = nft_ct_set_dump,
716 };
717
718 #ifdef CONFIG_NF_CONNTRACK_ZONES
719 static const struct nft_expr_ops nft_ct_set_zone_ops = {
720         .type           = &nft_ct_type,
721         .size           = NFT_EXPR_SIZE(sizeof(struct nft_ct)),
722         .eval           = nft_ct_set_zone_eval,
723         .init           = nft_ct_set_init,
724         .destroy        = nft_ct_set_destroy,
725         .dump           = nft_ct_set_dump,
726 };
727 #endif
728
729 static const struct nft_expr_ops *
730 nft_ct_select_ops(const struct nft_ctx *ctx,
731                     const struct nlattr * const tb[])
732 {
733         if (tb[NFTA_CT_KEY] == NULL)
734                 return ERR_PTR(-EINVAL);
735
736         if (tb[NFTA_CT_DREG] && tb[NFTA_CT_SREG])
737                 return ERR_PTR(-EINVAL);
738
739         if (tb[NFTA_CT_DREG])
740                 return &nft_ct_get_ops;
741
742         if (tb[NFTA_CT_SREG]) {
743 #ifdef CONFIG_NF_CONNTRACK_ZONES
744                 if (nla_get_be32(tb[NFTA_CT_KEY]) == htonl(NFT_CT_ZONE))
745                         return &nft_ct_set_zone_ops;
746 #endif
747                 return &nft_ct_set_ops;
748         }
749
750         return ERR_PTR(-EINVAL);
751 }
752
753 static struct nft_expr_type nft_ct_type __read_mostly = {
754         .name           = "ct",
755         .select_ops     = nft_ct_select_ops,
756         .policy         = nft_ct_policy,
757         .maxattr        = NFTA_CT_MAX,
758         .owner          = THIS_MODULE,
759 };
760
761 static void nft_notrack_eval(const struct nft_expr *expr,
762                              struct nft_regs *regs,
763                              const struct nft_pktinfo *pkt)
764 {
765         struct sk_buff *skb = pkt->skb;
766         enum ip_conntrack_info ctinfo;
767         struct nf_conn *ct;
768
769         ct = nf_ct_get(pkt->skb, &ctinfo);
770         /* Previously seen (loopback or untracked)?  Ignore. */
771         if (ct || ctinfo == IP_CT_UNTRACKED)
772                 return;
773
774         nf_ct_set(skb, ct, IP_CT_UNTRACKED);
775 }
776
777 static struct nft_expr_type nft_notrack_type;
778 static const struct nft_expr_ops nft_notrack_ops = {
779         .type           = &nft_notrack_type,
780         .size           = NFT_EXPR_SIZE(0),
781         .eval           = nft_notrack_eval,
782 };
783
784 static struct nft_expr_type nft_notrack_type __read_mostly = {
785         .name           = "notrack",
786         .ops            = &nft_notrack_ops,
787         .owner          = THIS_MODULE,
788 };
789
790 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
791 static int
792 nft_ct_timeout_parse_policy(void *timeouts,
793                             const struct nf_conntrack_l4proto *l4proto,
794                             struct net *net, const struct nlattr *attr)
795 {
796         struct nlattr **tb;
797         int ret = 0;
798
799         tb = kcalloc(l4proto->ctnl_timeout.nlattr_max + 1, sizeof(*tb),
800                      GFP_KERNEL);
801
802         if (!tb)
803                 return -ENOMEM;
804
805         ret = nla_parse_nested_deprecated(tb,
806                                           l4proto->ctnl_timeout.nlattr_max,
807                                           attr,
808                                           l4proto->ctnl_timeout.nla_policy,
809                                           NULL);
810         if (ret < 0)
811                 goto err;
812
813         ret = l4proto->ctnl_timeout.nlattr_to_obj(tb, net, timeouts);
814
815 err:
816         kfree(tb);
817         return ret;
818 }
819
820 struct nft_ct_timeout_obj {
821         struct nf_ct_timeout    *timeout;
822         u8                      l4proto;
823 };
824
825 static void nft_ct_timeout_obj_eval(struct nft_object *obj,
826                                     struct nft_regs *regs,
827                                     const struct nft_pktinfo *pkt)
828 {
829         const struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
830         struct nf_conn *ct = (struct nf_conn *)skb_nfct(pkt->skb);
831         struct nf_conn_timeout *timeout;
832         const unsigned int *values;
833
834         if (priv->l4proto != pkt->tprot)
835                 return;
836
837         if (!ct || nf_ct_is_template(ct) || nf_ct_is_confirmed(ct))
838                 return;
839
840         timeout = nf_ct_timeout_find(ct);
841         if (!timeout) {
842                 timeout = nf_ct_timeout_ext_add(ct, priv->timeout, GFP_ATOMIC);
843                 if (!timeout) {
844                         regs->verdict.code = NF_DROP;
845                         return;
846                 }
847         }
848
849         rcu_assign_pointer(timeout->timeout, priv->timeout);
850
851         /* adjust the timeout as per 'new' state. ct is unconfirmed,
852          * so the current timestamp must not be added.
853          */
854         values = nf_ct_timeout_data(timeout);
855         if (values)
856                 nf_ct_refresh(ct, pkt->skb, values[0]);
857 }
858
859 static int nft_ct_timeout_obj_init(const struct nft_ctx *ctx,
860                                    const struct nlattr * const tb[],
861                                    struct nft_object *obj)
862 {
863         struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
864         const struct nf_conntrack_l4proto *l4proto;
865         struct nf_ct_timeout *timeout;
866         int l3num = ctx->family;
867         __u8 l4num;
868         int ret;
869
870         if (!tb[NFTA_CT_TIMEOUT_L4PROTO] ||
871             !tb[NFTA_CT_TIMEOUT_DATA])
872                 return -EINVAL;
873
874         if (tb[NFTA_CT_TIMEOUT_L3PROTO])
875                 l3num = ntohs(nla_get_be16(tb[NFTA_CT_TIMEOUT_L3PROTO]));
876
877         l4num = nla_get_u8(tb[NFTA_CT_TIMEOUT_L4PROTO]);
878         priv->l4proto = l4num;
879
880         l4proto = nf_ct_l4proto_find(l4num);
881
882         if (l4proto->l4proto != l4num) {
883                 ret = -EOPNOTSUPP;
884                 goto err_proto_put;
885         }
886
887         timeout = kzalloc(sizeof(struct nf_ct_timeout) +
888                           l4proto->ctnl_timeout.obj_size, GFP_KERNEL);
889         if (timeout == NULL) {
890                 ret = -ENOMEM;
891                 goto err_proto_put;
892         }
893
894         ret = nft_ct_timeout_parse_policy(&timeout->data, l4proto, ctx->net,
895                                           tb[NFTA_CT_TIMEOUT_DATA]);
896         if (ret < 0)
897                 goto err_free_timeout;
898
899         timeout->l3num = l3num;
900         timeout->l4proto = l4proto;
901
902         ret = nf_ct_netns_get(ctx->net, ctx->family);
903         if (ret < 0)
904                 goto err_free_timeout;
905
906         priv->timeout = timeout;
907         return 0;
908
909 err_free_timeout:
910         kfree(timeout);
911 err_proto_put:
912         return ret;
913 }
914
915 static void nft_ct_timeout_obj_destroy(const struct nft_ctx *ctx,
916                                        struct nft_object *obj)
917 {
918         struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
919         struct nf_ct_timeout *timeout = priv->timeout;
920
921         nf_ct_untimeout(ctx->net, timeout);
922         nf_ct_netns_put(ctx->net, ctx->family);
923         kfree(priv->timeout);
924 }
925
926 static int nft_ct_timeout_obj_dump(struct sk_buff *skb,
927                                    struct nft_object *obj, bool reset)
928 {
929         const struct nft_ct_timeout_obj *priv = nft_obj_data(obj);
930         const struct nf_ct_timeout *timeout = priv->timeout;
931         struct nlattr *nest_params;
932         int ret;
933
934         if (nla_put_u8(skb, NFTA_CT_TIMEOUT_L4PROTO, timeout->l4proto->l4proto) ||
935             nla_put_be16(skb, NFTA_CT_TIMEOUT_L3PROTO, htons(timeout->l3num)))
936                 return -1;
937
938         nest_params = nla_nest_start(skb, NFTA_CT_TIMEOUT_DATA);
939         if (!nest_params)
940                 return -1;
941
942         ret = timeout->l4proto->ctnl_timeout.obj_to_nlattr(skb, &timeout->data);
943         if (ret < 0)
944                 return -1;
945         nla_nest_end(skb, nest_params);
946         return 0;
947 }
948
949 static const struct nla_policy nft_ct_timeout_policy[NFTA_CT_TIMEOUT_MAX + 1] = {
950         [NFTA_CT_TIMEOUT_L3PROTO] = {.type = NLA_U16 },
951         [NFTA_CT_TIMEOUT_L4PROTO] = {.type = NLA_U8 },
952         [NFTA_CT_TIMEOUT_DATA]    = {.type = NLA_NESTED },
953 };
954
955 static struct nft_object_type nft_ct_timeout_obj_type;
956
957 static const struct nft_object_ops nft_ct_timeout_obj_ops = {
958         .type           = &nft_ct_timeout_obj_type,
959         .size           = sizeof(struct nft_ct_timeout_obj),
960         .eval           = nft_ct_timeout_obj_eval,
961         .init           = nft_ct_timeout_obj_init,
962         .destroy        = nft_ct_timeout_obj_destroy,
963         .dump           = nft_ct_timeout_obj_dump,
964 };
965
966 static struct nft_object_type nft_ct_timeout_obj_type __read_mostly = {
967         .type           = NFT_OBJECT_CT_TIMEOUT,
968         .ops            = &nft_ct_timeout_obj_ops,
969         .maxattr        = NFTA_CT_TIMEOUT_MAX,
970         .policy         = nft_ct_timeout_policy,
971         .owner          = THIS_MODULE,
972 };
973 #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */
974
975 static int nft_ct_helper_obj_init(const struct nft_ctx *ctx,
976                                   const struct nlattr * const tb[],
977                                   struct nft_object *obj)
978 {
979         struct nft_ct_helper_obj *priv = nft_obj_data(obj);
980         struct nf_conntrack_helper *help4, *help6;
981         char name[NF_CT_HELPER_NAME_LEN];
982         int family = ctx->family;
983         int err;
984
985         if (!tb[NFTA_CT_HELPER_NAME] || !tb[NFTA_CT_HELPER_L4PROTO])
986                 return -EINVAL;
987
988         priv->l4proto = nla_get_u8(tb[NFTA_CT_HELPER_L4PROTO]);
989         if (!priv->l4proto)
990                 return -ENOENT;
991
992         nla_strlcpy(name, tb[NFTA_CT_HELPER_NAME], sizeof(name));
993
994         if (tb[NFTA_CT_HELPER_L3PROTO])
995                 family = ntohs(nla_get_be16(tb[NFTA_CT_HELPER_L3PROTO]));
996
997         help4 = NULL;
998         help6 = NULL;
999
1000         switch (family) {
1001         case NFPROTO_IPV4:
1002                 if (ctx->family == NFPROTO_IPV6)
1003                         return -EINVAL;
1004
1005                 help4 = nf_conntrack_helper_try_module_get(name, family,
1006                                                            priv->l4proto);
1007                 break;
1008         case NFPROTO_IPV6:
1009                 if (ctx->family == NFPROTO_IPV4)
1010                         return -EINVAL;
1011
1012                 help6 = nf_conntrack_helper_try_module_get(name, family,
1013                                                            priv->l4proto);
1014                 break;
1015         case NFPROTO_NETDEV: /* fallthrough */
1016         case NFPROTO_BRIDGE: /* same */
1017         case NFPROTO_INET:
1018                 help4 = nf_conntrack_helper_try_module_get(name, NFPROTO_IPV4,
1019                                                            priv->l4proto);
1020                 help6 = nf_conntrack_helper_try_module_get(name, NFPROTO_IPV6,
1021                                                            priv->l4proto);
1022                 break;
1023         default:
1024                 return -EAFNOSUPPORT;
1025         }
1026
1027         /* && is intentional; only error if INET found neither ipv4 or ipv6 */
1028         if (!help4 && !help6)
1029                 return -ENOENT;
1030
1031         priv->helper4 = help4;
1032         priv->helper6 = help6;
1033
1034         err = nf_ct_netns_get(ctx->net, ctx->family);
1035         if (err < 0)
1036                 goto err_put_helper;
1037
1038         return 0;
1039
1040 err_put_helper:
1041         if (priv->helper4)
1042                 nf_conntrack_helper_put(priv->helper4);
1043         if (priv->helper6)
1044                 nf_conntrack_helper_put(priv->helper6);
1045         return err;
1046 }
1047
1048 static void nft_ct_helper_obj_destroy(const struct nft_ctx *ctx,
1049                                       struct nft_object *obj)
1050 {
1051         struct nft_ct_helper_obj *priv = nft_obj_data(obj);
1052
1053         if (priv->helper4)
1054                 nf_conntrack_helper_put(priv->helper4);
1055         if (priv->helper6)
1056                 nf_conntrack_helper_put(priv->helper6);
1057
1058         nf_ct_netns_put(ctx->net, ctx->family);
1059 }
1060
1061 static void nft_ct_helper_obj_eval(struct nft_object *obj,
1062                                    struct nft_regs *regs,
1063                                    const struct nft_pktinfo *pkt)
1064 {
1065         const struct nft_ct_helper_obj *priv = nft_obj_data(obj);
1066         struct nf_conn *ct = (struct nf_conn *)skb_nfct(pkt->skb);
1067         struct nf_conntrack_helper *to_assign = NULL;
1068         struct nf_conn_help *help;
1069
1070         if (!ct ||
1071             nf_ct_is_confirmed(ct) ||
1072             nf_ct_is_template(ct) ||
1073             priv->l4proto != nf_ct_protonum(ct))
1074                 return;
1075
1076         switch (nf_ct_l3num(ct)) {
1077         case NFPROTO_IPV4:
1078                 to_assign = priv->helper4;
1079                 break;
1080         case NFPROTO_IPV6:
1081                 to_assign = priv->helper6;
1082                 break;
1083         default:
1084                 WARN_ON_ONCE(1);
1085                 return;
1086         }
1087
1088         if (!to_assign)
1089                 return;
1090
1091         if (test_bit(IPS_HELPER_BIT, &ct->status))
1092                 return;
1093
1094         help = nf_ct_helper_ext_add(ct, GFP_ATOMIC);
1095         if (help) {
1096                 rcu_assign_pointer(help->helper, to_assign);
1097                 set_bit(IPS_HELPER_BIT, &ct->status);
1098         }
1099 }
1100
1101 static int nft_ct_helper_obj_dump(struct sk_buff *skb,
1102                                   struct nft_object *obj, bool reset)
1103 {
1104         const struct nft_ct_helper_obj *priv = nft_obj_data(obj);
1105         const struct nf_conntrack_helper *helper;
1106         u16 family;
1107
1108         if (priv->helper4 && priv->helper6) {
1109                 family = NFPROTO_INET;
1110                 helper = priv->helper4;
1111         } else if (priv->helper6) {
1112                 family = NFPROTO_IPV6;
1113                 helper = priv->helper6;
1114         } else {
1115                 family = NFPROTO_IPV4;
1116                 helper = priv->helper4;
1117         }
1118
1119         if (nla_put_string(skb, NFTA_CT_HELPER_NAME, helper->name))
1120                 return -1;
1121
1122         if (nla_put_u8(skb, NFTA_CT_HELPER_L4PROTO, priv->l4proto))
1123                 return -1;
1124
1125         if (nla_put_be16(skb, NFTA_CT_HELPER_L3PROTO, htons(family)))
1126                 return -1;
1127
1128         return 0;
1129 }
1130
1131 static const struct nla_policy nft_ct_helper_policy[NFTA_CT_HELPER_MAX + 1] = {
1132         [NFTA_CT_HELPER_NAME] = { .type = NLA_STRING,
1133                                   .len = NF_CT_HELPER_NAME_LEN - 1 },
1134         [NFTA_CT_HELPER_L3PROTO] = { .type = NLA_U16 },
1135         [NFTA_CT_HELPER_L4PROTO] = { .type = NLA_U8 },
1136 };
1137
1138 static struct nft_object_type nft_ct_helper_obj_type;
1139 static const struct nft_object_ops nft_ct_helper_obj_ops = {
1140         .type           = &nft_ct_helper_obj_type,
1141         .size           = sizeof(struct nft_ct_helper_obj),
1142         .eval           = nft_ct_helper_obj_eval,
1143         .init           = nft_ct_helper_obj_init,
1144         .destroy        = nft_ct_helper_obj_destroy,
1145         .dump           = nft_ct_helper_obj_dump,
1146 };
1147
1148 static struct nft_object_type nft_ct_helper_obj_type __read_mostly = {
1149         .type           = NFT_OBJECT_CT_HELPER,
1150         .ops            = &nft_ct_helper_obj_ops,
1151         .maxattr        = NFTA_CT_HELPER_MAX,
1152         .policy         = nft_ct_helper_policy,
1153         .owner          = THIS_MODULE,
1154 };
1155
1156 static int __init nft_ct_module_init(void)
1157 {
1158         int err;
1159
1160         BUILD_BUG_ON(NF_CT_LABELS_MAX_SIZE > NFT_REG_SIZE);
1161
1162         err = nft_register_expr(&nft_ct_type);
1163         if (err < 0)
1164                 return err;
1165
1166         err = nft_register_expr(&nft_notrack_type);
1167         if (err < 0)
1168                 goto err1;
1169
1170         err = nft_register_obj(&nft_ct_helper_obj_type);
1171         if (err < 0)
1172                 goto err2;
1173 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
1174         err = nft_register_obj(&nft_ct_timeout_obj_type);
1175         if (err < 0)
1176                 goto err3;
1177 #endif
1178         return 0;
1179
1180 #ifdef CONFIG_NF_CONNTRACK_TIMEOUT
1181 err3:
1182         nft_unregister_obj(&nft_ct_helper_obj_type);
1183 #endif
1184 err2:
1185         nft_unregister_expr(&nft_notrack_type);
1186 err1:
1187         nft_unregister_expr(&nft_ct_type);
1188         return err;
1189 }
1190
/*
 * Module exit point: unregister everything nft_ct_module_init() registered,
 * in the reverse order of registration.
 */
static void __exit nft_ct_module_exit(void)
{
#ifdef CONFIG_NF_CONNTRACK_TIMEOUT
	nft_unregister_obj(&nft_ct_timeout_obj_type);
#endif
	nft_unregister_obj(&nft_ct_helper_obj_type);
	nft_unregister_expr(&nft_notrack_type);
	nft_unregister_expr(&nft_ct_type);
}
1200
module_init(nft_ct_module_init);
module_exit(nft_ct_module_exit);

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Patrick McHardy <kaber@trash.net>");
/* Aliases for the expressions and object types this module registers. */
MODULE_ALIAS_NFT_EXPR("ct");
MODULE_ALIAS_NFT_EXPR("notrack");
MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_CT_HELPER);
/* NOTE(review): this alias is advertised unconditionally, but the CT timeout
 * object type is only registered when CONFIG_NF_CONNTRACK_TIMEOUT is set.
 */
MODULE_ALIAS_NFT_OBJ(NFT_OBJECT_CT_TIMEOUT);