net/openvswitch/conntrack.c
1 /*
2  * Copyright (c) 2015 Nicira, Inc.
3  *
4  * This program is free software; you can redistribute it and/or
5  * modify it under the terms of version 2 of the GNU General Public
6  * License as published by the Free Software Foundation.
7  *
8  * This program is distributed in the hope that it will be useful, but
9  * WITHOUT ANY WARRANTY; without even the implied warranty of
10  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
11  * General Public License for more details.
12  */
13
14 #include <linux/module.h>
15 #include <linux/openvswitch.h>
16 #include <linux/tcp.h>
17 #include <linux/udp.h>
18 #include <linux/sctp.h>
19 #include <linux/static_key.h>
20 #include <net/ip.h>
21 #include <net/genetlink.h>
22 #include <net/netfilter/nf_conntrack_core.h>
23 #include <net/netfilter/nf_conntrack_count.h>
24 #include <net/netfilter/nf_conntrack_helper.h>
25 #include <net/netfilter/nf_conntrack_labels.h>
26 #include <net/netfilter/nf_conntrack_seqadj.h>
27 #include <net/netfilter/nf_conntrack_zones.h>
28 #include <net/netfilter/ipv6/nf_defrag_ipv6.h>
29 #include <net/ipv6_frag.h>
30
31 #ifdef CONFIG_NF_NAT_NEEDED
32 #include <net/netfilter/nf_nat.h>
33 #endif
34
35 #include "datapath.h"
36 #include "conntrack.h"
37 #include "flow.h"
38 #include "flow_netlink.h"
39
40 struct ovs_ct_len_tbl {
41         int maxlen;
42         int minlen;
43 };
44
45 /* Metadata mark for masked write to conntrack mark */
46 struct md_mark {
47         u32 value;
48         u32 mask;
49 };
50
51 /* Metadata label for masked write to conntrack label. */
52 struct md_labels {
53         struct ovs_key_ct_labels value;
54         struct ovs_key_ct_labels mask;
55 };
56
57 enum ovs_ct_nat {
58         OVS_CT_NAT = 1 << 0,     /* NAT for committed connections only. */
59         OVS_CT_SRC_NAT = 1 << 1, /* Source NAT for NEW connections. */
60         OVS_CT_DST_NAT = 1 << 2, /* Destination NAT for NEW connections. */
61 };
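/* For example, parse_nat() further below combines these bits: an
 * OVS_NAT_ATTR_SRC attribute yields OVS_CT_NAT | OVS_CT_SRC_NAT, an
 * OVS_NAT_ATTR_DST attribute yields OVS_CT_NAT | OVS_CT_DST_NAT, and a
 * NAT action without an explicit type leaves only OVS_CT_NAT set, i.e.
 * only connections that already carry NAT state are translated.
 */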
62
63 /* Conntrack action context for execution. */
64 struct ovs_conntrack_info {
65         struct nf_conntrack_helper *helper;
66         struct nf_conntrack_zone zone;
67         struct nf_conn *ct;
68         u8 commit : 1;
69         u8 nat : 3;                 /* enum ovs_ct_nat */
70         u8 force : 1;
71         u8 have_eventmask : 1;
72         u16 family;
73         u32 eventmask;              /* Mask of 1 << IPCT_*. */
74         struct md_mark mark;
75         struct md_labels labels;
76 #ifdef CONFIG_NF_NAT_NEEDED
77         struct nf_nat_range2 range;  /* Only present for SRC NAT and DST NAT. */
78 #endif
79 };
80
81 #if     IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT)
82 #define OVS_CT_LIMIT_UNLIMITED  0
83 #define OVS_CT_LIMIT_DEFAULT OVS_CT_LIMIT_UNLIMITED
84 #define CT_LIMIT_HASH_BUCKETS 512
85 static DEFINE_STATIC_KEY_FALSE(ovs_ct_limit_enabled);
86
87 struct ovs_ct_limit {
88         /* Elements in ovs_ct_limit_info->limits hash table */
89         struct hlist_node hlist_node;
90         struct rcu_head rcu;
91         u16 zone;
92         u32 limit;
93 };
94
95 struct ovs_ct_limit_info {
96         u32 default_limit;
97         struct hlist_head *limits;
98         struct nf_conncount_data *data;
99 };
100
101 static const struct nla_policy ct_limit_policy[OVS_CT_LIMIT_ATTR_MAX + 1] = {
102         [OVS_CT_LIMIT_ATTR_ZONE_LIMIT] = { .type = NLA_NESTED, },
103 };
104 #endif
105
106 static bool labels_nonzero(const struct ovs_key_ct_labels *labels);
107
108 static void __ovs_ct_free_action(struct ovs_conntrack_info *ct_info);
109
110 static u16 key_to_nfproto(const struct sw_flow_key *key)
111 {
112         switch (ntohs(key->eth.type)) {
113         case ETH_P_IP:
114                 return NFPROTO_IPV4;
115         case ETH_P_IPV6:
116                 return NFPROTO_IPV6;
117         default:
118                 return NFPROTO_UNSPEC;
119         }
120 }
121
122 /* Map SKB connection state into the values used by flow definition. */
123 static u8 ovs_ct_get_state(enum ip_conntrack_info ctinfo)
124 {
125         u8 ct_state = OVS_CS_F_TRACKED;
126
127         switch (ctinfo) {
128         case IP_CT_ESTABLISHED_REPLY:
129         case IP_CT_RELATED_REPLY:
130                 ct_state |= OVS_CS_F_REPLY_DIR;
131                 break;
132         default:
133                 break;
134         }
135
136         switch (ctinfo) {
137         case IP_CT_ESTABLISHED:
138         case IP_CT_ESTABLISHED_REPLY:
139                 ct_state |= OVS_CS_F_ESTABLISHED;
140                 break;
141         case IP_CT_RELATED:
142         case IP_CT_RELATED_REPLY:
143                 ct_state |= OVS_CS_F_RELATED;
144                 break;
145         case IP_CT_NEW:
146                 ct_state |= OVS_CS_F_NEW;
147                 break;
148         default:
149                 break;
150         }
151
152         return ct_state;
153 }
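/* Worked example for the mapping above: a reply-direction packet on an
 * established connection (ctinfo == IP_CT_ESTABLISHED_REPLY) yields
 * OVS_CS_F_TRACKED | OVS_CS_F_REPLY_DIR | OVS_CS_F_ESTABLISHED, while the
 * first packet of a connection (IP_CT_NEW) yields
 * OVS_CS_F_TRACKED | OVS_CS_F_NEW.
 */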
154
155 static u32 ovs_ct_get_mark(const struct nf_conn *ct)
156 {
157 #if IS_ENABLED(CONFIG_NF_CONNTRACK_MARK)
158         return ct ? ct->mark : 0;
159 #else
160         return 0;
161 #endif
162 }
163
164 /* Guard against conntrack labels max size shrinking below 128 bits. */
165 #if NF_CT_LABELS_MAX_SIZE < 16
166 #error NF_CT_LABELS_MAX_SIZE must be at least 16 bytes
167 #endif
168
169 static void ovs_ct_get_labels(const struct nf_conn *ct,
170                               struct ovs_key_ct_labels *labels)
171 {
172         struct nf_conn_labels *cl = ct ? nf_ct_labels_find(ct) : NULL;
173
174         if (cl)
175                 memcpy(labels, cl->bits, OVS_CT_LABELS_LEN);
176         else
177                 memset(labels, 0, OVS_CT_LABELS_LEN);
178 }
179
180 static void __ovs_ct_update_key_orig_tp(struct sw_flow_key *key,
181                                         const struct nf_conntrack_tuple *orig,
182                                         u8 icmp_proto)
183 {
184         key->ct_orig_proto = orig->dst.protonum;
185         if (orig->dst.protonum == icmp_proto) {
186                 key->ct.orig_tp.src = htons(orig->dst.u.icmp.type);
187                 key->ct.orig_tp.dst = htons(orig->dst.u.icmp.code);
188         } else {
189                 key->ct.orig_tp.src = orig->src.u.all;
190                 key->ct.orig_tp.dst = orig->dst.u.all;
191         }
192 }
193
194 static void __ovs_ct_update_key(struct sw_flow_key *key, u8 state,
195                                 const struct nf_conntrack_zone *zone,
196                                 const struct nf_conn *ct)
197 {
198         key->ct_state = state;
199         key->ct_zone = zone->id;
200         key->ct.mark = ovs_ct_get_mark(ct);
201         ovs_ct_get_labels(ct, &key->ct.labels);
202
203         if (ct) {
204                 const struct nf_conntrack_tuple *orig;
205
206                 /* Use the master if we have one. */
207                 if (ct->master)
208                         ct = ct->master;
209                 orig = &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple;
210
211                 /* IP version must match with the master connection. */
212                 if (key->eth.type == htons(ETH_P_IP) &&
213                     nf_ct_l3num(ct) == NFPROTO_IPV4) {
214                         key->ipv4.ct_orig.src = orig->src.u3.ip;
215                         key->ipv4.ct_orig.dst = orig->dst.u3.ip;
216                         __ovs_ct_update_key_orig_tp(key, orig, IPPROTO_ICMP);
217                         return;
218                 } else if (key->eth.type == htons(ETH_P_IPV6) &&
219                            !sw_flow_key_is_nd(key) &&
220                            nf_ct_l3num(ct) == NFPROTO_IPV6) {
221                         key->ipv6.ct_orig.src = orig->src.u3.in6;
222                         key->ipv6.ct_orig.dst = orig->dst.u3.in6;
223                         __ovs_ct_update_key_orig_tp(key, orig, NEXTHDR_ICMP);
224                         return;
225                 }
226         }
227         /* Clear 'ct_orig_proto' to mark the non-existence of conntrack
228          * original direction key fields.
229          */
230         key->ct_orig_proto = 0;
231 }
232
233 /* Update 'key' based on skb->_nfct.  If 'post_ct' is true, then OVS has
234  * previously sent the packet to conntrack via the ct action.  If
235  * 'keep_nat_flags' is true, the existing NAT flags are retained, else they are
236  * initialized from the connection status.
237  */
238 static void ovs_ct_update_key(const struct sk_buff *skb,
239                               const struct ovs_conntrack_info *info,
240                               struct sw_flow_key *key, bool post_ct,
241                               bool keep_nat_flags)
242 {
243         const struct nf_conntrack_zone *zone = &nf_ct_zone_dflt;
244         enum ip_conntrack_info ctinfo;
245         struct nf_conn *ct;
246         u8 state = 0;
247
248         ct = nf_ct_get(skb, &ctinfo);
249         if (ct) {
250                 state = ovs_ct_get_state(ctinfo);
251                 /* All unconfirmed entries are NEW connections. */
252                 if (!nf_ct_is_confirmed(ct))
253                         state |= OVS_CS_F_NEW;
254                 /* OVS persists the related flag for the duration of the
255                  * connection.
256                  */
257                 if (ct->master)
258                         state |= OVS_CS_F_RELATED;
259                 if (keep_nat_flags) {
260                         state |= key->ct_state & OVS_CS_F_NAT_MASK;
261                 } else {
262                         if (ct->status & IPS_SRC_NAT)
263                                 state |= OVS_CS_F_SRC_NAT;
264                         if (ct->status & IPS_DST_NAT)
265                                 state |= OVS_CS_F_DST_NAT;
266                 }
267                 zone = nf_ct_zone(ct);
268         } else if (post_ct) {
269                 state = OVS_CS_F_TRACKED | OVS_CS_F_INVALID;
270                 if (info)
271                         zone = &info->zone;
272         }
273         __ovs_ct_update_key(key, state, zone, ct);
274 }
275
276 /* This is called to initialize CT key fields possibly coming in from the local
277  * stack.
278  */
279 void ovs_ct_fill_key(const struct sk_buff *skb, struct sw_flow_key *key)
280 {
281         ovs_ct_update_key(skb, NULL, key, false, false);
282 }
283
284 #define IN6_ADDR_INITIALIZER(ADDR) \
285         { (ADDR).s6_addr32[0], (ADDR).s6_addr32[1], \
286           (ADDR).s6_addr32[2], (ADDR).s6_addr32[3] }
287
288 int ovs_ct_put_key(const struct sw_flow_key *swkey,
289                    const struct sw_flow_key *output, struct sk_buff *skb)
290 {
291         if (nla_put_u32(skb, OVS_KEY_ATTR_CT_STATE, output->ct_state))
292                 return -EMSGSIZE;
293
294         if (IS_ENABLED(CONFIG_NF_CONNTRACK_ZONES) &&
295             nla_put_u16(skb, OVS_KEY_ATTR_CT_ZONE, output->ct_zone))
296                 return -EMSGSIZE;
297
298         if (IS_ENABLED(CONFIG_NF_CONNTRACK_MARK) &&
299             nla_put_u32(skb, OVS_KEY_ATTR_CT_MARK, output->ct.mark))
300                 return -EMSGSIZE;
301
302         if (IS_ENABLED(CONFIG_NF_CONNTRACK_LABELS) &&
303             nla_put(skb, OVS_KEY_ATTR_CT_LABELS, sizeof(output->ct.labels),
304                     &output->ct.labels))
305                 return -EMSGSIZE;
306
307         if (swkey->ct_orig_proto) {
308                 if (swkey->eth.type == htons(ETH_P_IP)) {
309                         struct ovs_key_ct_tuple_ipv4 orig = {
310                                 output->ipv4.ct_orig.src,
311                                 output->ipv4.ct_orig.dst,
312                                 output->ct.orig_tp.src,
313                                 output->ct.orig_tp.dst,
314                                 output->ct_orig_proto,
315                         };
316                         if (nla_put(skb, OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV4,
317                                     sizeof(orig), &orig))
318                                 return -EMSGSIZE;
319                 } else if (swkey->eth.type == htons(ETH_P_IPV6)) {
320                         struct ovs_key_ct_tuple_ipv6 orig = {
321                                 IN6_ADDR_INITIALIZER(output->ipv6.ct_orig.src),
322                                 IN6_ADDR_INITIALIZER(output->ipv6.ct_orig.dst),
323                                 output->ct.orig_tp.src,
324                                 output->ct.orig_tp.dst,
325                                 output->ct_orig_proto,
326                         };
327                         if (nla_put(skb, OVS_KEY_ATTR_CT_ORIG_TUPLE_IPV6,
328                                     sizeof(orig), &orig))
329                                 return -EMSGSIZE;
330                 }
331         }
332
333         return 0;
334 }
335
336 static int ovs_ct_set_mark(struct nf_conn *ct, struct sw_flow_key *key,
337                            u32 ct_mark, u32 mask)
338 {
339 #if IS_ENABLED(CONFIG_NF_CONNTRACK_MARK)
340         u32 new_mark;
341
342         new_mark = ct_mark | (ct->mark & ~(mask));
343         if (ct->mark != new_mark) {
344                 ct->mark = new_mark;
345                 if (nf_ct_is_confirmed(ct))
346                         nf_conntrack_event_cache(IPCT_MARK, ct);
347                 key->ct.mark = new_mark;
348         }
349
350         return 0;
351 #else
352         return -ENOTSUPP;
353 #endif
354 }
355
356 static struct nf_conn_labels *ovs_ct_get_conn_labels(struct nf_conn *ct)
357 {
358         struct nf_conn_labels *cl;
359
360         cl = nf_ct_labels_find(ct);
361         if (!cl) {
362                 nf_ct_labels_ext_add(ct);
363                 cl = nf_ct_labels_find(ct);
364         }
365
366         return cl;
367 }
368
369 /* Initialize labels for a new, yet to be committed conntrack entry.  Note that
370  * since the new connection is not yet confirmed, and thus no-one else has
371  * access to its labels, we simply write them over.
372  */
373 static int ovs_ct_init_labels(struct nf_conn *ct, struct sw_flow_key *key,
374                               const struct ovs_key_ct_labels *labels,
375                               const struct ovs_key_ct_labels *mask)
376 {
377         struct nf_conn_labels *cl, *master_cl;
378         bool have_mask = labels_nonzero(mask);
379
380         /* Inherit master's labels to the related connection? */
381         master_cl = ct->master ? nf_ct_labels_find(ct->master) : NULL;
382
383         if (!master_cl && !have_mask)
384                 return 0;   /* Nothing to do. */
385
386         cl = ovs_ct_get_conn_labels(ct);
387         if (!cl)
388                 return -ENOSPC;
389
390         /* Inherit the master's labels, if any. */
391         if (master_cl)
392                 *cl = *master_cl;
393
394         if (have_mask) {
395                 u32 *dst = (u32 *)cl->bits;
396                 int i;
397
398                 for (i = 0; i < OVS_CT_LABELS_LEN_32; i++)
399                         dst[i] = (dst[i] & ~mask->ct_labels_32[i]) |
400                                 (labels->ct_labels_32[i]
401                                  & mask->ct_labels_32[i]);
402         }
403
404         /* Labels are included in the IPCTNL_MSG_CT_NEW event only if the
405          * IPCT_LABEL bit is set in the event cache.
406          */
407         nf_conntrack_event_cache(IPCT_LABEL, ct);
408
409         memcpy(&key->ct.labels, cl->bits, OVS_CT_LABELS_LEN);
410
411         return 0;
412 }
413
414 static int ovs_ct_set_labels(struct nf_conn *ct, struct sw_flow_key *key,
415                              const struct ovs_key_ct_labels *labels,
416                              const struct ovs_key_ct_labels *mask)
417 {
418         struct nf_conn_labels *cl;
419         int err;
420
421         cl = ovs_ct_get_conn_labels(ct);
422         if (!cl)
423                 return -ENOSPC;
424
425         err = nf_connlabels_replace(ct, labels->ct_labels_32,
426                                     mask->ct_labels_32,
427                                     OVS_CT_LABELS_LEN_32);
428         if (err)
429                 return err;
430
431         memcpy(&key->ct.labels, cl->bits, OVS_CT_LABELS_LEN);
432
433         return 0;
434 }
435
436 /* 'skb' should already be pulled to nh_ofs. */
437 static int ovs_ct_helper(struct sk_buff *skb, u16 proto)
438 {
439         const struct nf_conntrack_helper *helper;
440         const struct nf_conn_help *help;
441         enum ip_conntrack_info ctinfo;
442         unsigned int protoff;
443         struct nf_conn *ct;
444         int err;
445
446         ct = nf_ct_get(skb, &ctinfo);
447         if (!ct || ctinfo == IP_CT_RELATED_REPLY)
448                 return NF_ACCEPT;
449
450         help = nfct_help(ct);
451         if (!help)
452                 return NF_ACCEPT;
453
454         helper = rcu_dereference(help->helper);
455         if (!helper)
456                 return NF_ACCEPT;
457
458         switch (proto) {
459         case NFPROTO_IPV4:
460                 protoff = ip_hdrlen(skb);
461                 break;
462         case NFPROTO_IPV6: {
463                 u8 nexthdr = ipv6_hdr(skb)->nexthdr;
464                 __be16 frag_off;
465                 int ofs;
466
467                 ofs = ipv6_skip_exthdr(skb, sizeof(struct ipv6hdr), &nexthdr,
468                                        &frag_off);
469                 if (ofs < 0 || (frag_off & htons(~0x7)) != 0) {
470                         pr_debug("proto header not found\n");
471                         return NF_ACCEPT;
472                 }
473                 protoff = ofs;
474                 break;
475         }
476         default:
477                 WARN_ONCE(1, "helper invoked on non-IP family!");
478                 return NF_DROP;
479         }
480
481         err = helper->help(skb, protoff, ct, ctinfo);
482         if (err != NF_ACCEPT)
483                 return err;
484
485         /* Adjust seqs after helper.  This is needed due to some helpers (e.g.,
486          * FTP with NAT) adjusting the TCP payload size when mangling IP
487          * addresses and/or port numbers in the text-based control connection.
488          */
489         if (test_bit(IPS_SEQ_ADJUST_BIT, &ct->status) &&
490             !nf_ct_seq_adjust(skb, ct, ctinfo, protoff))
491                 return NF_DROP;
492         return NF_ACCEPT;
493 }
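/* Example of why the seqadj step above matters: an FTP helper combined
 * with NAT may rewrite a "PORT h1,h2,h3,h4,p1,p2" command to carry the
 * translated address (the command text here is only an illustration).
 * That changes the TCP payload length, so later sequence and
 * acknowledgment numbers no longer line up unless nf_ct_seq_adjust()
 * compensates for the difference.
 */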
494
495 /* Returns 0 on success, -EINPROGRESS if 'skb' is stolen, or other nonzero
496  * value if 'skb' is freed.
497  */
498 static int handle_fragments(struct net *net, struct sw_flow_key *key,
499                             u16 zone, struct sk_buff *skb)
500 {
501         struct ovs_skb_cb ovs_cb = *OVS_CB(skb);
502         int err;
503
504         if (key->eth.type == htons(ETH_P_IP)) {
505                 enum ip_defrag_users user = IP_DEFRAG_CONNTRACK_IN + zone;
506
507                 memset(IPCB(skb), 0, sizeof(struct inet_skb_parm));
508                 err = ip_defrag(net, skb, user);
509                 if (err)
510                         return err;
511
512                 ovs_cb.mru = IPCB(skb)->frag_max_size;
513 #if IS_ENABLED(CONFIG_NF_DEFRAG_IPV6)
514         } else if (key->eth.type == htons(ETH_P_IPV6)) {
515                 enum ip6_defrag_users user = IP6_DEFRAG_CONNTRACK_IN + zone;
516
517                 memset(IP6CB(skb), 0, sizeof(struct inet6_skb_parm));
518                 err = nf_ct_frag6_gather(net, skb, user);
519                 if (err) {
520                         if (err != -EINPROGRESS)
521                                 kfree_skb(skb);
522                         return err;
523                 }
524
525                 key->ip.proto = ipv6_hdr(skb)->nexthdr;
526                 ovs_cb.mru = IP6CB(skb)->frag_max_size;
527 #endif
528         } else {
529                 kfree_skb(skb);
530                 return -EPFNOSUPPORT;
531         }
532
533         key->ip.frag = OVS_FRAG_TYPE_NONE;
534         skb_clear_hash(skb);
535         skb->ignore_df = 1;
536         *OVS_CB(skb) = ovs_cb;
537
538         return 0;
539 }
540
541 static struct nf_conntrack_expect *
542 ovs_ct_expect_find(struct net *net, const struct nf_conntrack_zone *zone,
543                    u16 proto, const struct sk_buff *skb)
544 {
545         struct nf_conntrack_tuple tuple;
546         struct nf_conntrack_expect *exp;
547
548         if (!nf_ct_get_tuplepr(skb, skb_network_offset(skb), proto, net, &tuple))
549                 return NULL;
550
551         exp = __nf_ct_expect_find(net, zone, &tuple);
552         if (exp) {
553                 struct nf_conntrack_tuple_hash *h;
554
555                 /* Delete existing conntrack entry, if it clashes with the
556                  * expectation.  This can happen since conntrack ALGs do not
557                  * check for clashes between (new) expectations and existing
558                  * conntrack entries.  nf_conntrack_in() will check the
559                  * expectations only if a conntrack entry can not be found,
560                  * which can lead to OVS finding the expectation (here) in the
561                  * init direction, but which will not be removed by the
562                  * nf_conntrack_in() call, if a matching conntrack entry is
563                  * found instead.  In this case all init direction packets
564                  * would be reported as new related packets, while reply
565                  * direction packets would be reported as un-related
566                  * established packets.
567                  */
568                 h = nf_conntrack_find_get(net, zone, &tuple);
569                 if (h) {
570                         struct nf_conn *ct = nf_ct_tuplehash_to_ctrack(h);
571
572                         nf_ct_delete(ct, 0, 0);
573                         nf_conntrack_put(&ct->ct_general);
574                 }
575         }
576
577         return exp;
578 }
579
580 /* This replicates logic from nf_conntrack_core.c that is not exported. */
581 static enum ip_conntrack_info
582 ovs_ct_get_info(const struct nf_conntrack_tuple_hash *h)
583 {
584         const struct nf_conn *ct = nf_ct_tuplehash_to_ctrack(h);
585
586         if (NF_CT_DIRECTION(h) == IP_CT_DIR_REPLY)
587                 return IP_CT_ESTABLISHED_REPLY;
588         /* Once we've had two way comms, always ESTABLISHED. */
589         if (test_bit(IPS_SEEN_REPLY_BIT, &ct->status))
590                 return IP_CT_ESTABLISHED;
591         if (test_bit(IPS_EXPECTED_BIT, &ct->status))
592                 return IP_CT_RELATED;
593         return IP_CT_NEW;
594 }
595
596 /* Find an existing connection which this packet belongs to without
597  * re-attributing statistics or modifying the connection state.  This allows an
598  * skb->_nfct lost due to an upcall to be recovered during actions execution.
599  *
600  * Must be called with rcu_read_lock.
601  *
602  * On success, populates skb->_nfct and returns the connection.  Returns NULL
603  * if there is no existing entry.
604  */
605 static struct nf_conn *
606 ovs_ct_find_existing(struct net *net, const struct nf_conntrack_zone *zone,
607                      u8 l3num, struct sk_buff *skb, bool natted)
608 {
609         struct nf_conntrack_tuple tuple;
610         struct nf_conntrack_tuple_hash *h;
611         struct nf_conn *ct;
612
613         if (!nf_ct_get_tuplepr(skb, skb_network_offset(skb), l3num,
614                                net, &tuple)) {
615                 pr_debug("ovs_ct_find_existing: Can't get tuple\n");
616                 return NULL;
617         }
618
619         /* Must invert the tuple if skb has been transformed by NAT. */
620         if (natted) {
621                 struct nf_conntrack_tuple inverse;
622
623                 if (!nf_ct_invert_tuple(&inverse, &tuple)) {
624                         pr_debug("ovs_ct_find_existing: Inversion failed!\n");
625                         return NULL;
626                 }
627                 tuple = inverse;
628         }
629
630         /* look for tuple match */
631         h = nf_conntrack_find_get(net, zone, &tuple);
632         if (!h)
633                 return NULL;   /* Not found. */
634
635         ct = nf_ct_tuplehash_to_ctrack(h);
636
637         /* Inverted packet tuple matches the reverse direction conntrack tuple,
638          * select the other tuplehash to get the right 'ctinfo' bits for this
639          * packet.
640          */
641         if (natted)
642                 h = &ct->tuplehash[!h->tuple.dst.dir];
643
644         nf_ct_set(skb, ct, ovs_ct_get_info(h));
645         return ct;
646 }
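/* Example of the 'natted' handling above, with illustrative addresses and
 * ports omitted for brevity: suppose a connection from 10.0.0.1 to
 * 198.51.100.5 was source-NATted to 192.0.2.1.  A post-NAT
 * original-direction packet carries the tuple (192.0.2.1 -> 198.51.100.5),
 * which matches neither stored tuplehash; inverting it gives
 * (198.51.100.5 -> 192.0.2.1), the reply tuple, so the lookup succeeds.
 * The hit is then on the reply tuplehash even though the packet travels in
 * the original direction, which is why the opposite tuplehash is selected
 * before deriving 'ctinfo'.
 */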
647
648 static
649 struct nf_conn *ovs_ct_executed(struct net *net,
650                                 const struct sw_flow_key *key,
651                                 const struct ovs_conntrack_info *info,
652                                 struct sk_buff *skb,
653                                 bool *ct_executed)
654 {
655         struct nf_conn *ct = NULL;
656
657         /* If no ct, check if we have evidence that an existing conntrack entry
658          * might be found for this skb.  This happens when we lose a skb->_nfct
659          * due to an upcall, or if the direction is being forced.  If the
660          * connection was not confirmed, it is not cached and needs to be run
661          * through conntrack again.
662          */
663         *ct_executed = (key->ct_state & OVS_CS_F_TRACKED) &&
664                        !(key->ct_state & OVS_CS_F_INVALID) &&
665                        (key->ct_zone == info->zone.id);
666
667         if (*ct_executed || (!key->ct_state && info->force)) {
668                 ct = ovs_ct_find_existing(net, &info->zone, info->family, skb,
669                                           !!(key->ct_state &
670                                           OVS_CS_F_NAT_MASK));
671         }
672
673         return ct;
674 }
675
676 /* Determine whether skb->_nfct is equal to the result of conntrack lookup. */
677 static bool skb_nfct_cached(struct net *net,
678                             const struct sw_flow_key *key,
679                             const struct ovs_conntrack_info *info,
680                             struct sk_buff *skb)
681 {
682         enum ip_conntrack_info ctinfo;
683         struct nf_conn *ct;
684         bool ct_executed = true;
685
686         ct = nf_ct_get(skb, &ctinfo);
687         if (!ct)
688                 ct = ovs_ct_executed(net, key, info, skb, &ct_executed);
689
690         if (ct)
691                 nf_ct_get(skb, &ctinfo);
692         else
693                 return false;
694
695         if (!net_eq(net, read_pnet(&ct->ct_net)))
696                 return false;
697         if (!nf_ct_zone_equal_any(info->ct, nf_ct_zone(ct)))
698                 return false;
699         if (info->helper) {
700                 struct nf_conn_help *help;
701
702                 help = nf_ct_ext_find(ct, NF_CT_EXT_HELPER);
703                 if (help && rcu_access_pointer(help->helper) != info->helper)
704                         return false;
705         }
706         /* Force conntrack entry direction to the current packet? */
707         if (info->force && CTINFO2DIR(ctinfo) != IP_CT_DIR_ORIGINAL) {
708                 /* Delete the conntrack entry if confirmed, else just release
709                  * the reference.
710                  */
711                 if (nf_ct_is_confirmed(ct))
712                         nf_ct_delete(ct, 0, 0);
713
714                 nf_conntrack_put(&ct->ct_general);
715                 nf_ct_set(skb, NULL, 0);
716                 return false;
717         }
718
719         return ct_executed;
720 }
721
722 #ifdef CONFIG_NF_NAT_NEEDED
723 /* Modelled after nf_nat_ipv[46]_fn().
724  * range is only used for new, uninitialized NAT state.
725  * Returns either NF_ACCEPT or NF_DROP.
726  */
727 static int ovs_ct_nat_execute(struct sk_buff *skb, struct nf_conn *ct,
728                               enum ip_conntrack_info ctinfo,
729                               const struct nf_nat_range2 *range,
730                               enum nf_nat_manip_type maniptype)
731 {
732         int hooknum, nh_off, err = NF_ACCEPT;
733
734         nh_off = skb_network_offset(skb);
735         skb_pull_rcsum(skb, nh_off);
736
737         /* See HOOK2MANIP(). */
738         if (maniptype == NF_NAT_MANIP_SRC)
739                 hooknum = NF_INET_LOCAL_IN; /* Source NAT */
740         else
741                 hooknum = NF_INET_LOCAL_OUT; /* Destination NAT */
742
743         switch (ctinfo) {
744         case IP_CT_RELATED:
745         case IP_CT_RELATED_REPLY:
746                 if (IS_ENABLED(CONFIG_NF_NAT) &&
747                     skb->protocol == htons(ETH_P_IP) &&
748                     ip_hdr(skb)->protocol == IPPROTO_ICMP) {
749                         if (!nf_nat_icmp_reply_translation(skb, ct, ctinfo,
750                                                            hooknum))
751                                 err = NF_DROP;
752                         goto push;
753                 } else if (IS_ENABLED(CONFIG_IPV6) &&
754                            skb->protocol == htons(ETH_P_IPV6)) {
755                         __be16 frag_off;
756                         u8 nexthdr = ipv6_hdr(skb)->nexthdr;
757                         int hdrlen = ipv6_skip_exthdr(skb,
758                                                       sizeof(struct ipv6hdr),
759                                                       &nexthdr, &frag_off);
760
761                         if (hdrlen >= 0 && nexthdr == IPPROTO_ICMPV6) {
762                                 if (!nf_nat_icmpv6_reply_translation(skb, ct,
763                                                                      ctinfo,
764                                                                      hooknum,
765                                                                      hdrlen))
766                                         err = NF_DROP;
767                                 goto push;
768                         }
769                 }
770                 /* Non-ICMP, fall thru to initialize if needed. */
771                 /* fall through */
772         case IP_CT_NEW:
773                 /* Seen it before?  This can happen for loopback, retrans,
774                  * or local packets.
775                  */
776                 if (!nf_nat_initialized(ct, maniptype)) {
777                         /* Initialize according to the NAT action. */
778                         err = (range && range->flags & NF_NAT_RANGE_MAP_IPS)
779                                 /* Action is set up to establish a new
780                                  * mapping.
781                                  */
782                                 ? nf_nat_setup_info(ct, range, maniptype)
783                                 : nf_nat_alloc_null_binding(ct, hooknum);
784                         if (err != NF_ACCEPT)
785                                 goto push;
786                 }
787                 break;
788
789         case IP_CT_ESTABLISHED:
790         case IP_CT_ESTABLISHED_REPLY:
791                 break;
792
793         default:
794                 err = NF_DROP;
795                 goto push;
796         }
797
798         err = nf_nat_packet(ct, ctinfo, hooknum, skb);
799 push:
800         skb_push(skb, nh_off);
801         skb_postpush_rcsum(skb, skb->data, nh_off);
802
803         return err;
804 }
805
806 static void ovs_nat_update_key(struct sw_flow_key *key,
807                                const struct sk_buff *skb,
808                                enum nf_nat_manip_type maniptype)
809 {
810         if (maniptype == NF_NAT_MANIP_SRC) {
811                 __be16 src;
812
813                 key->ct_state |= OVS_CS_F_SRC_NAT;
814                 if (key->eth.type == htons(ETH_P_IP))
815                         key->ipv4.addr.src = ip_hdr(skb)->saddr;
816                 else if (key->eth.type == htons(ETH_P_IPV6))
817                         memcpy(&key->ipv6.addr.src, &ipv6_hdr(skb)->saddr,
818                                sizeof(key->ipv6.addr.src));
819                 else
820                         return;
821
822                 if (key->ip.proto == IPPROTO_UDP)
823                         src = udp_hdr(skb)->source;
824                 else if (key->ip.proto == IPPROTO_TCP)
825                         src = tcp_hdr(skb)->source;
826                 else if (key->ip.proto == IPPROTO_SCTP)
827                         src = sctp_hdr(skb)->source;
828                 else
829                         return;
830
831                 key->tp.src = src;
832         } else {
833                 __be16 dst;
834
835                 key->ct_state |= OVS_CS_F_DST_NAT;
836                 if (key->eth.type == htons(ETH_P_IP))
837                         key->ipv4.addr.dst = ip_hdr(skb)->daddr;
838                 else if (key->eth.type == htons(ETH_P_IPV6))
839                         memcpy(&key->ipv6.addr.dst, &ipv6_hdr(skb)->daddr,
840                                sizeof(key->ipv6.addr.dst));
841                 else
842                         return;
843
844                 if (key->ip.proto == IPPROTO_UDP)
845                         dst = udp_hdr(skb)->dest;
846                 else if (key->ip.proto == IPPROTO_TCP)
847                         dst = tcp_hdr(skb)->dest;
848                 else if (key->ip.proto == IPPROTO_SCTP)
849                         dst = sctp_hdr(skb)->dest;
850                 else
851                         return;
852
853                 key->tp.dst = dst;
854         }
855 }
856
857 /* Returns NF_DROP if the packet should be dropped, NF_ACCEPT otherwise. */
858 static int ovs_ct_nat(struct net *net, struct sw_flow_key *key,
859                       const struct ovs_conntrack_info *info,
860                       struct sk_buff *skb, struct nf_conn *ct,
861                       enum ip_conntrack_info ctinfo)
862 {
863         enum nf_nat_manip_type maniptype;
864         int err;
865
866         /* Add NAT extension if not confirmed yet. */
867         if (!nf_ct_is_confirmed(ct) && !nf_ct_nat_ext_add(ct))
868                 return NF_ACCEPT;   /* Can't NAT. */
869
870         /* Determine NAT type.
871          * Check if the NAT type can be deduced from the tracked connection.
872          * Make sure new expected connections (IP_CT_RELATED) are NATted only
873          * when committing.
874          */
875         if (info->nat & OVS_CT_NAT && ctinfo != IP_CT_NEW &&
876             ct->status & IPS_NAT_MASK &&
877             (ctinfo != IP_CT_RELATED || info->commit)) {
878                 /* NAT an established or related connection like before. */
879                 if (CTINFO2DIR(ctinfo) == IP_CT_DIR_REPLY)
880                         /* This is the REPLY direction for a connection
881                          * for which NAT was applied in the forward
882                          * direction.  Do the reverse NAT.
883                          */
884                         maniptype = ct->status & IPS_SRC_NAT
885                                 ? NF_NAT_MANIP_DST : NF_NAT_MANIP_SRC;
886                 else
887                         maniptype = ct->status & IPS_SRC_NAT
888                                 ? NF_NAT_MANIP_SRC : NF_NAT_MANIP_DST;
889         } else if (info->nat & OVS_CT_SRC_NAT) {
890                 maniptype = NF_NAT_MANIP_SRC;
891         } else if (info->nat & OVS_CT_DST_NAT) {
892                 maniptype = NF_NAT_MANIP_DST;
893         } else {
894                 return NF_ACCEPT; /* Connection is not NATed. */
895         }
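        /* Example: for a connection committed with source NAT (IPS_SRC_NAT
         * set), original-direction packets keep getting NF_NAT_MANIP_SRC,
         * while reply-direction packets get NF_NAT_MANIP_DST so that the
         * mapping is undone on the way back.
         */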
896         err = ovs_ct_nat_execute(skb, ct, ctinfo, &info->range, maniptype);
897
898         /* Mark NAT done if successful and update the flow key. */
899         if (err == NF_ACCEPT)
900                 ovs_nat_update_key(key, skb, maniptype);
901
902         return err;
903 }
904 #else /* !CONFIG_NF_NAT_NEEDED */
905 static int ovs_ct_nat(struct net *net, struct sw_flow_key *key,
906                       const struct ovs_conntrack_info *info,
907                       struct sk_buff *skb, struct nf_conn *ct,
908                       enum ip_conntrack_info ctinfo)
909 {
910         return NF_ACCEPT;
911 }
912 #endif
913
914 /* Pass 'skb' through conntrack in 'net', using zone configured in 'info', if
915  * not done already.  Update key with new CT state after passing the packet
916  * through conntrack.
917  * Note that if the packet is deemed invalid by conntrack, skb->_nfct will be
918  * set to NULL and 0 will be returned.
919  */
920 static int __ovs_ct_lookup(struct net *net, struct sw_flow_key *key,
921                            const struct ovs_conntrack_info *info,
922                            struct sk_buff *skb)
923 {
924         /* If we are recirculating packets to match on conntrack fields and
925          * committing with a separate conntrack action, then we don't need to
926          * actually run the packet through conntrack twice unless it's for a
927          * different zone.
928          */
929         bool cached = skb_nfct_cached(net, key, info, skb);
930         enum ip_conntrack_info ctinfo;
931         struct nf_conn *ct;
932
933         if (!cached) {
934                 struct nf_hook_state state = {
935                         .hook = NF_INET_PRE_ROUTING,
936                         .pf = info->family,
937                         .net = net,
938                 };
939                 struct nf_conn *tmpl = info->ct;
940                 int err;
941
942                 /* Associate skb with specified zone. */
943                 if (tmpl) {
944                         if (skb_nfct(skb))
945                                 nf_conntrack_put(skb_nfct(skb));
946                         nf_conntrack_get(&tmpl->ct_general);
947                         nf_ct_set(skb, tmpl, IP_CT_NEW);
948                 }
949
950                 err = nf_conntrack_in(skb, &state);
951                 if (err != NF_ACCEPT)
952                         return -ENOENT;
953
954                 /* Clear CT state NAT flags to mark that we have not yet done
955                  * NAT after the nf_conntrack_in() call.  We can actually clear
956                  * the whole state, as it will be re-initialized below.
957                  */
958                 key->ct_state = 0;
959
960                 /* Update the key, but keep the NAT flags. */
961                 ovs_ct_update_key(skb, info, key, true, true);
962         }
963
964         ct = nf_ct_get(skb, &ctinfo);
965         if (ct) {
966                 /* Packets starting a new connection must be NATted before the
967                  * helper, so that the helper knows about the NAT.  We enforce
968                  * this by delaying both NAT and helper calls for unconfirmed
969                  * connections until the committing CT action.  For later
970                  * packets NAT and Helper may be called in either order.
971                  *
972                  * NAT will be done only if the CT action has NAT, and only
973                  * once per packet (per zone), as guarded by the NAT bits in
974                  * the key->ct_state.
975                  */
976                 if (info->nat && !(key->ct_state & OVS_CS_F_NAT_MASK) &&
977                     (nf_ct_is_confirmed(ct) || info->commit) &&
978                     ovs_ct_nat(net, key, info, skb, ct, ctinfo) != NF_ACCEPT) {
979                         return -EINVAL;
980                 }
981
982                 /* Userspace may decide to perform a ct lookup without a helper
983                  * specified followed by a (recirculate and) commit with one.
984                  * Therefore, for unconfirmed connections which we will commit,
985                  * we need to attach the helper here.
986                  */
987                 if (!nf_ct_is_confirmed(ct) && info->commit &&
988                     info->helper && !nfct_help(ct)) {
989                         int err = __nf_ct_try_assign_helper(ct, info->ct,
990                                                             GFP_ATOMIC);
991                         if (err)
992                                 return err;
993                 }
994
995                 /* Call the helper only if:
996                  * - nf_conntrack_in() was executed above ("!cached") for a
997                  *   confirmed connection, or
998                  * - When committing an unconfirmed connection.
999                  */
1000                 if ((nf_ct_is_confirmed(ct) ? !cached : info->commit) &&
1001                     ovs_ct_helper(skb, info->family) != NF_ACCEPT) {
1002                         return -EINVAL;
1003                 }
1004         }
1005
1006         return 0;
1007 }
1008
1009 /* Lookup connection and read fields into key. */
1010 static int ovs_ct_lookup(struct net *net, struct sw_flow_key *key,
1011                          const struct ovs_conntrack_info *info,
1012                          struct sk_buff *skb)
1013 {
1014         struct nf_conntrack_expect *exp;
1015
1016         /* If we pass an expected packet through nf_conntrack_in() the
1017          * expectation is typically removed, but the packet could still be
1018          * lost in upcall processing.  To prevent this from happening we
1019          * perform an explicit expectation lookup.  Expected connections are
1020          * always new, and will be passed through conntrack only when they are
1021          * committed, as it is OK to remove the expectation at that time.
1022          */
1023         exp = ovs_ct_expect_find(net, &info->zone, info->family, skb);
1024         if (exp) {
1025                 u8 state;
1026
1027                 /* NOTE: New connections are NATted and Helped only when
1028                  * committed, so we are not calling into NAT here.
1029                  */
1030                 state = OVS_CS_F_TRACKED | OVS_CS_F_NEW | OVS_CS_F_RELATED;
1031                 __ovs_ct_update_key(key, state, &info->zone, exp->master);
1032         } else {
1033                 struct nf_conn *ct;
1034                 int err;
1035
1036                 err = __ovs_ct_lookup(net, key, info, skb);
1037                 if (err)
1038                         return err;
1039
1040                 ct = (struct nf_conn *)skb_nfct(skb);
1041                 if (ct)
1042                         nf_ct_deliver_cached_events(ct);
1043         }
1044
1045         return 0;
1046 }
1047
1048 static bool labels_nonzero(const struct ovs_key_ct_labels *labels)
1049 {
1050         size_t i;
1051
1052         for (i = 0; i < OVS_CT_LABELS_LEN_32; i++)
1053                 if (labels->ct_labels_32[i])
1054                         return true;
1055
1056         return false;
1057 }
1058
1059 #if     IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT)
1060 static struct hlist_head *ct_limit_hash_bucket(
1061         const struct ovs_ct_limit_info *info, u16 zone)
1062 {
1063         return &info->limits[zone & (CT_LIMIT_HASH_BUCKETS - 1)];
1064 }
1065
1066 /* Call with ovs_mutex */
1067 static void ct_limit_set(const struct ovs_ct_limit_info *info,
1068                          struct ovs_ct_limit *new_ct_limit)
1069 {
1070         struct ovs_ct_limit *ct_limit;
1071         struct hlist_head *head;
1072
1073         head = ct_limit_hash_bucket(info, new_ct_limit->zone);
1074         hlist_for_each_entry_rcu(ct_limit, head, hlist_node) {
1075                 if (ct_limit->zone == new_ct_limit->zone) {
1076                         hlist_replace_rcu(&ct_limit->hlist_node,
1077                                           &new_ct_limit->hlist_node);
1078                         kfree_rcu(ct_limit, rcu);
1079                         return;
1080                 }
1081         }
1082
1083         hlist_add_head_rcu(&new_ct_limit->hlist_node, head);
1084 }
1085
1086 /* Call with ovs_mutex */
1087 static void ct_limit_del(const struct ovs_ct_limit_info *info, u16 zone)
1088 {
1089         struct ovs_ct_limit *ct_limit;
1090         struct hlist_head *head;
1091         struct hlist_node *n;
1092
1093         head = ct_limit_hash_bucket(info, zone);
1094         hlist_for_each_entry_safe(ct_limit, n, head, hlist_node) {
1095                 if (ct_limit->zone == zone) {
1096                         hlist_del_rcu(&ct_limit->hlist_node);
1097                         kfree_rcu(ct_limit, rcu);
1098                         return;
1099                 }
1100         }
1101 }
1102
1103 /* Call with RCU read lock */
1104 static u32 ct_limit_get(const struct ovs_ct_limit_info *info, u16 zone)
1105 {
1106         struct ovs_ct_limit *ct_limit;
1107         struct hlist_head *head;
1108
1109         head = ct_limit_hash_bucket(info, zone);
1110         hlist_for_each_entry_rcu(ct_limit, head, hlist_node) {
1111                 if (ct_limit->zone == zone)
1112                         return ct_limit->limit;
1113         }
1114
1115         return info->default_limit;
1116 }
1117
1118 static int ovs_ct_check_limit(struct net *net,
1119                               const struct ovs_conntrack_info *info,
1120                               const struct nf_conntrack_tuple *tuple)
1121 {
1122         struct ovs_net *ovs_net = net_generic(net, ovs_net_id);
1123         const struct ovs_ct_limit_info *ct_limit_info = ovs_net->ct_limit_info;
1124         u32 per_zone_limit, connections;
1125         u32 conncount_key;
1126
1127         conncount_key = info->zone.id;
1128
1129         per_zone_limit = ct_limit_get(ct_limit_info, info->zone.id);
1130         if (per_zone_limit == OVS_CT_LIMIT_UNLIMITED)
1131                 return 0;
1132
1133         connections = nf_conncount_count(net, ct_limit_info->data,
1134                                          &conncount_key, tuple, &info->zone);
1135         if (connections > per_zone_limit)
1136                 return -ENOMEM;
1137
1138         return 0;
1139 }
1140 #endif
1141
1142 /* Lookup connection and confirm if unconfirmed. */
1143 static int ovs_ct_commit(struct net *net, struct sw_flow_key *key,
1144                          const struct ovs_conntrack_info *info,
1145                          struct sk_buff *skb)
1146 {
1147         enum ip_conntrack_info ctinfo;
1148         struct nf_conn *ct;
1149         int err;
1150
1151         err = __ovs_ct_lookup(net, key, info, skb);
1152         if (err)
1153                 return err;
1154
1155         /* The connection could be invalid, in which case this is a no-op. */
1156         ct = nf_ct_get(skb, &ctinfo);
1157         if (!ct)
1158                 return 0;
1159
1160 #if     IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT)
1161         if (static_branch_unlikely(&ovs_ct_limit_enabled)) {
1162                 if (!nf_ct_is_confirmed(ct)) {
1163                         err = ovs_ct_check_limit(net, info,
1164                                 &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple);
1165                         if (err) {
1166                                 net_warn_ratelimited("openvswitch: zone: %u "
1167                                         "exceeds conntrack limit\n",
1168                                         info->zone.id);
1169                                 return err;
1170                         }
1171                 }
1172         }
1173 #endif
1174
1175         /* Set the conntrack event mask if given.  NEW and DELETE events have
1176          * their own groups, but the NFNLGRP_CONNTRACK_UPDATE group listener
1177          * typically would receive many kinds of updates.  Setting the event
1178          * mask allows those events to be filtered.  The set event mask will
1179          * remain in effect for the lifetime of the connection unless changed
1180          * by a further CT action with both the commit flag and the eventmask
1181          * option. */
1182         if (info->have_eventmask) {
1183                 struct nf_conntrack_ecache *cache = nf_ct_ecache_find(ct);
1184
1185                 if (cache)
1186                         cache->ctmask = info->eventmask;
1187         }
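        /* As an illustration, an eventmask of (1 << IPCT_MARK) |
         * (1 << IPCT_LABEL) limits the update events delivered for this
         * connection to mark and label changes; the combination shown is
         * only an example, userspace may pass any set of IPCT_* bits.
         */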
1188
1189         /* Apply changes before confirming the connection so that the initial
1190          * conntrack NEW netlink event carries the values given in the CT
1191          * action.
1192          */
1193         if (info->mark.mask) {
1194                 err = ovs_ct_set_mark(ct, key, info->mark.value,
1195                                       info->mark.mask);
1196                 if (err)
1197                         return err;
1198         }
1199         if (!nf_ct_is_confirmed(ct)) {
1200                 err = ovs_ct_init_labels(ct, key, &info->labels.value,
1201                                          &info->labels.mask);
1202                 if (err)
1203                         return err;
1204         } else if (IS_ENABLED(CONFIG_NF_CONNTRACK_LABELS) &&
1205                    labels_nonzero(&info->labels.mask)) {
1206                 err = ovs_ct_set_labels(ct, key, &info->labels.value,
1207                                         &info->labels.mask);
1208                 if (err)
1209                         return err;
1210         }
1211         /* This will take care of sending queued events even if the connection
1212          * is already confirmed.
1213          */
1214         if (nf_conntrack_confirm(skb) != NF_ACCEPT)
1215                 return -EINVAL;
1216
1217         return 0;
1218 }
1219
1220 /* Trim the skb to the length specified by the IP/IPv6 header,
1221  * removing any trailing lower-layer padding. This prepares the skb
1222  * for higher-layer processing that assumes skb->len excludes padding
1223  * (such as nf_ip_checksum). The caller needs to pull the skb to the
1224  * network header, and ensure ip_hdr/ipv6_hdr points to valid data.
1225  */
1226 static int ovs_skb_network_trim(struct sk_buff *skb)
1227 {
1228         unsigned int len;
1229         int err;
1230
1231         switch (skb->protocol) {
1232         case htons(ETH_P_IP):
1233                 len = ntohs(ip_hdr(skb)->tot_len);
1234                 break;
1235         case htons(ETH_P_IPV6):
1236                 len = sizeof(struct ipv6hdr)
1237                         + ntohs(ipv6_hdr(skb)->payload_len);
1238                 break;
1239         default:
1240                 len = skb->len;
1241         }
1242
1243         err = pskb_trim_rcsum(skb, len);
1244         if (err)
1245                 kfree_skb(skb);
1246
1247         return err;
1248 }
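/* Illustrative sizing for the trim above: a minimal IPv4 TCP ACK is 40
 * bytes at L3 (20-byte IP header plus 20-byte TCP header), but Ethernet
 * pads frames to a 60-byte minimum (excluding the FCS), so such a packet
 * arrives with roughly 46 bytes after the 14-byte Ethernet header.
 * Trimming to tot_len drops the ~6 bytes of padding before conntrack
 * checksums the packet.
 */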
1249
1250 /* Returns 0 on success, -EINPROGRESS if 'skb' is stolen, or other nonzero
1251  * value if 'skb' is freed.
1252  */
1253 int ovs_ct_execute(struct net *net, struct sk_buff *skb,
1254                    struct sw_flow_key *key,
1255                    const struct ovs_conntrack_info *info)
1256 {
1257         int nh_ofs;
1258         int err;
1259
1260         /* The conntrack module expects to be working at L3. */
1261         nh_ofs = skb_network_offset(skb);
1262         skb_pull_rcsum(skb, nh_ofs);
1263
1264         err = ovs_skb_network_trim(skb);
1265         if (err)
1266                 return err;
1267
1268         if (key->ip.frag != OVS_FRAG_TYPE_NONE) {
1269                 err = handle_fragments(net, key, info->zone.id, skb);
1270                 if (err)
1271                         return err;
1272         }
1273
1274         if (info->commit)
1275                 err = ovs_ct_commit(net, key, info, skb);
1276         else
1277                 err = ovs_ct_lookup(net, key, info, skb);
1278
1279         skb_push(skb, nh_ofs);
1280         skb_postpush_rcsum(skb, skb->data, nh_ofs);
1281         if (err)
1282                 kfree_skb(skb);
1283         return err;
1284 }
1285
1286 int ovs_ct_clear(struct sk_buff *skb, struct sw_flow_key *key)
1287 {
1288         if (skb_nfct(skb)) {
1289                 nf_conntrack_put(skb_nfct(skb));
1290                 nf_ct_set(skb, NULL, IP_CT_UNTRACKED);
1291                 ovs_ct_fill_key(skb, key);
1292         }
1293
1294         return 0;
1295 }
1296
1297 static int ovs_ct_add_helper(struct ovs_conntrack_info *info, const char *name,
1298                              const struct sw_flow_key *key, bool log)
1299 {
1300         struct nf_conntrack_helper *helper;
1301         struct nf_conn_help *help;
1302
1303         helper = nf_conntrack_helper_try_module_get(name, info->family,
1304                                                     key->ip.proto);
1305         if (!helper) {
1306                 OVS_NLERR(log, "Unknown helper \"%s\"", name);
1307                 return -EINVAL;
1308         }
1309
1310         help = nf_ct_helper_ext_add(info->ct, GFP_KERNEL);
1311         if (!help) {
1312                 nf_conntrack_helper_put(helper);
1313                 return -ENOMEM;
1314         }
1315
1316         rcu_assign_pointer(help->helper, helper);
1317         info->helper = helper;
1318
1319         if (info->nat)
1320                 request_module("ip_nat_%s", name);
1321
1322         return 0;
1323 }
1324
1325 #ifdef CONFIG_NF_NAT_NEEDED
1326 static int parse_nat(const struct nlattr *attr,
1327                      struct ovs_conntrack_info *info, bool log)
1328 {
1329         struct nlattr *a;
1330         int rem;
1331         bool have_ip_max = false;
1332         bool have_proto_max = false;
1333         bool ip_vers = (info->family == NFPROTO_IPV6);
1334
1335         nla_for_each_nested(a, attr, rem) {
1336                 static const int ovs_nat_attr_lens[OVS_NAT_ATTR_MAX + 1][2] = {
1337                         [OVS_NAT_ATTR_SRC] = {0, 0},
1338                         [OVS_NAT_ATTR_DST] = {0, 0},
1339                         [OVS_NAT_ATTR_IP_MIN] = {sizeof(struct in_addr),
1340                                                  sizeof(struct in6_addr)},
1341                         [OVS_NAT_ATTR_IP_MAX] = {sizeof(struct in_addr),
1342                                                  sizeof(struct in6_addr)},
1343                         [OVS_NAT_ATTR_PROTO_MIN] = {sizeof(u16), sizeof(u16)},
1344                         [OVS_NAT_ATTR_PROTO_MAX] = {sizeof(u16), sizeof(u16)},
1345                         [OVS_NAT_ATTR_PERSISTENT] = {0, 0},
1346                         [OVS_NAT_ATTR_PROTO_HASH] = {0, 0},
1347                         [OVS_NAT_ATTR_PROTO_RANDOM] = {0, 0},
1348                 };
1349                 int type = nla_type(a);
1350
1351                 if (type > OVS_NAT_ATTR_MAX) {
1352                         OVS_NLERR(log, "Unknown NAT attribute (type=%d, max=%d)",
1353                                   type, OVS_NAT_ATTR_MAX);
1354                         return -EINVAL;
1355                 }
1356
1357                 if (nla_len(a) != ovs_nat_attr_lens[type][ip_vers]) {
1358                         OVS_NLERR(log, "NAT attribute type %d has unexpected length (%d != %d)",
1359                                   type, nla_len(a),
1360                                   ovs_nat_attr_lens[type][ip_vers]);
1361                         return -EINVAL;
1362                 }
1363
1364                 switch (type) {
1365                 case OVS_NAT_ATTR_SRC:
1366                 case OVS_NAT_ATTR_DST:
1367                         if (info->nat) {
1368                                 OVS_NLERR(log, "Only one type of NAT may be specified");
1369                                 return -ERANGE;
1370                         }
1371                         info->nat |= OVS_CT_NAT;
1372                         info->nat |= ((type == OVS_NAT_ATTR_SRC)
1373                                         ? OVS_CT_SRC_NAT : OVS_CT_DST_NAT);
1374                         break;
1375
1376                 case OVS_NAT_ATTR_IP_MIN:
1377                         nla_memcpy(&info->range.min_addr, a,
1378                                    sizeof(info->range.min_addr));
1379                         info->range.flags |= NF_NAT_RANGE_MAP_IPS;
1380                         break;
1381
1382                 case OVS_NAT_ATTR_IP_MAX:
1383                         have_ip_max = true;
1384                         nla_memcpy(&info->range.max_addr, a,
1385                                    sizeof(info->range.max_addr));
1386                         info->range.flags |= NF_NAT_RANGE_MAP_IPS;
1387                         break;
1388
1389                 case OVS_NAT_ATTR_PROTO_MIN:
1390                         info->range.min_proto.all = htons(nla_get_u16(a));
1391                         info->range.flags |= NF_NAT_RANGE_PROTO_SPECIFIED;
1392                         break;
1393
1394                 case OVS_NAT_ATTR_PROTO_MAX:
1395                         have_proto_max = true;
1396                         info->range.max_proto.all = htons(nla_get_u16(a));
1397                         info->range.flags |= NF_NAT_RANGE_PROTO_SPECIFIED;
1398                         break;
1399
1400                 case OVS_NAT_ATTR_PERSISTENT:
1401                         info->range.flags |= NF_NAT_RANGE_PERSISTENT;
1402                         break;
1403
1404                 case OVS_NAT_ATTR_PROTO_HASH:
1405                         info->range.flags |= NF_NAT_RANGE_PROTO_RANDOM;
1406                         break;
1407
1408                 case OVS_NAT_ATTR_PROTO_RANDOM:
1409                         info->range.flags |= NF_NAT_RANGE_PROTO_RANDOM_FULLY;
1410                         break;
1411
1412                 default:
1413                         OVS_NLERR(log, "Unknown NAT attribute (%d)", type);
1414                         return -EINVAL;
1415                 }
1416         }
1417
1418         if (rem > 0) {
1419                 OVS_NLERR(log, "NAT attribute has %d unknown bytes", rem);
1420                 return -EINVAL;
1421         }
1422         if (!info->nat) {
1423                 /* Do not allow flags if no type is given. */
1424                 if (info->range.flags) {
1425                         OVS_NLERR(log,
1426                                   "NAT flags may be given only when NAT range (SRC or DST) is also specified."
1427                                   );
1428                         return -EINVAL;
1429                 }
1430                 info->nat = OVS_CT_NAT;   /* NAT existing connections. */
1431         } else if (!info->commit) {
1432                 OVS_NLERR(log,
1433                           "NAT attributes may be specified only when CT COMMIT flag is also specified."
1434                           );
1435                 return -EINVAL;
1436         }
1437         /* Allow missing IP_MAX. */
1438         if (info->range.flags & NF_NAT_RANGE_MAP_IPS && !have_ip_max) {
1439                 memcpy(&info->range.max_addr, &info->range.min_addr,
1440                        sizeof(info->range.max_addr));
1441         }
1442         /* Allow missing PROTO_MAX. */
1443         if (info->range.flags & NF_NAT_RANGE_PROTO_SPECIFIED &&
1444             !have_proto_max) {
1445                 info->range.max_proto.all = info->range.min_proto.all;
1446         }
1447         return 0;
1448 }
1449 #endif
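/*
 * Illustrative example: a datapath ct action with source NAT, roughly as
 * the OVS userspace tools would print it (the addresses and port range
 * here are hypothetical):
 *
 *   ct(commit,nat(src=10.0.0.240-10.0.0.254:32768-65535))
 *
 * Userspace encodes the nat() part as a nested OVS_CT_ATTR_NAT attribute
 * containing OVS_NAT_ATTR_SRC, OVS_NAT_ATTR_IP_MIN/IP_MAX and
 * OVS_NAT_ATTR_PROTO_MIN/PROTO_MAX; parse_nat() above turns that into
 * OVS_CT_NAT | OVS_CT_SRC_NAT in info->nat and sets NF_NAT_RANGE_MAP_IPS
 * and NF_NAT_RANGE_PROTO_SPECIFIED in info->range.flags.
 */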
1450
1451 static const struct ovs_ct_len_tbl ovs_ct_attr_lens[OVS_CT_ATTR_MAX + 1] = {
1452         [OVS_CT_ATTR_COMMIT]    = { .minlen = 0, .maxlen = 0 },
1453         [OVS_CT_ATTR_FORCE_COMMIT]      = { .minlen = 0, .maxlen = 0 },
1454         [OVS_CT_ATTR_ZONE]      = { .minlen = sizeof(u16),
1455                                     .maxlen = sizeof(u16) },
1456         [OVS_CT_ATTR_MARK]      = { .minlen = sizeof(struct md_mark),
1457                                     .maxlen = sizeof(struct md_mark) },
1458         [OVS_CT_ATTR_LABELS]    = { .minlen = sizeof(struct md_labels),
1459                                     .maxlen = sizeof(struct md_labels) },
1460         [OVS_CT_ATTR_HELPER]    = { .minlen = 1,
1461                                     .maxlen = NF_CT_HELPER_NAME_LEN },
1462 #ifdef CONFIG_NF_NAT_NEEDED
1463         /* NAT length is checked when parsing the nested attributes. */
1464         [OVS_CT_ATTR_NAT]       = { .minlen = 0, .maxlen = INT_MAX },
1465 #endif
1466         [OVS_CT_ATTR_EVENTMASK] = { .minlen = sizeof(u32),
1467                                     .maxlen = sizeof(u32) },
1468 };
1469
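/*
 * Parse the nested attributes of an OVS_ACTION_ATTR_CT action into @info,
 * validating each length against ovs_ct_attr_lens above.  *helper is set
 * to the NUL-terminated helper name, if one was given, so the caller can
 * resolve it after the conntrack template has been allocated.  Writes to
 * the conntrack mark or labels, like NAT, are accepted only together with
 * the commit flag.
 */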
1470 static int parse_ct(const struct nlattr *attr, struct ovs_conntrack_info *info,
1471                     const char **helper, bool log)
1472 {
1473         struct nlattr *a;
1474         int rem;
1475
1476         nla_for_each_nested(a, attr, rem) {
1477                 int type = nla_type(a);
1478                 int maxlen;
1479                 int minlen;
1480
1481                 if (type > OVS_CT_ATTR_MAX) {
1482                         OVS_NLERR(log,
1483                                   "Unknown conntrack attr (type=%d, max=%d)",
1484                                   type, OVS_CT_ATTR_MAX);
1485                         return -EINVAL;
1486                 }
1487
1488                 maxlen = ovs_ct_attr_lens[type].maxlen;
1489                 minlen = ovs_ct_attr_lens[type].minlen;
1490                 if (nla_len(a) < minlen || nla_len(a) > maxlen) {
1491                         OVS_NLERR(log,
1492                                   "Conntrack attr type has unexpected length (type=%d, length=%d, expected=%d)",
1493                                   type, nla_len(a), maxlen);
1494                         return -EINVAL;
1495                 }
1496
1497                 switch (type) {
1498                 case OVS_CT_ATTR_FORCE_COMMIT:
1499                         info->force = true;
1500                         /* fall through. */
1501                 case OVS_CT_ATTR_COMMIT:
1502                         info->commit = true;
1503                         break;
1504 #ifdef CONFIG_NF_CONNTRACK_ZONES
1505                 case OVS_CT_ATTR_ZONE:
1506                         info->zone.id = nla_get_u16(a);
1507                         break;
1508 #endif
1509 #ifdef CONFIG_NF_CONNTRACK_MARK
1510                 case OVS_CT_ATTR_MARK: {
1511                         struct md_mark *mark = nla_data(a);
1512
1513                         if (!mark->mask) {
1514                                 OVS_NLERR(log, "ct_mark mask cannot be 0");
1515                                 return -EINVAL;
1516                         }
1517                         info->mark = *mark;
1518                         break;
1519                 }
1520 #endif
1521 #ifdef CONFIG_NF_CONNTRACK_LABELS
1522                 case OVS_CT_ATTR_LABELS: {
1523                         struct md_labels *labels = nla_data(a);
1524
1525                         if (!labels_nonzero(&labels->mask)) {
1526                                 OVS_NLERR(log, "ct_labels mask cannot be 0");
1527                                 return -EINVAL;
1528                         }
1529                         info->labels = *labels;
1530                         break;
1531                 }
1532 #endif
1533                 case OVS_CT_ATTR_HELPER:
1534                         *helper = nla_data(a);
1535                         if (!memchr(*helper, '\0', nla_len(a))) {
1536                                 OVS_NLERR(log, "Invalid conntrack helper");
1537                                 return -EINVAL;
1538                         }
1539                         break;
1540 #ifdef CONFIG_NF_NAT_NEEDED
1541                 case OVS_CT_ATTR_NAT: {
1542                         int err = parse_nat(a, info, log);
1543
1544                         if (err)
1545                                 return err;
1546                         break;
1547                 }
1548 #endif
1549                 case OVS_CT_ATTR_EVENTMASK:
1550                         info->have_eventmask = true;
1551                         info->eventmask = nla_get_u32(a);
1552                         break;
1553
1554                 default:
1555                         OVS_NLERR(log, "Unknown conntrack attr (%d)",
1556                                   type);
1557                         return -EINVAL;
1558                 }
1559         }
1560
1561 #ifdef CONFIG_NF_CONNTRACK_MARK
1562         if (!info->commit && info->mark.mask) {
1563                 OVS_NLERR(log,
1564                           "Setting conntrack mark requires 'commit' flag.");
1565                 return -EINVAL;
1566         }
1567 #endif
1568 #ifdef CONFIG_NF_CONNTRACK_LABELS
1569         if (!info->commit && labels_nonzero(&info->labels.mask)) {
1570                 OVS_NLERR(log,
1571                           "Setting conntrack labels requires 'commit' flag.");
1572                 return -EINVAL;
1573         }
1574 #endif
1575         if (rem > 0) {
1576                 OVS_NLERR(log, "Conntrack attr has %d unknown bytes", rem);
1577                 return -EINVAL;
1578         }
1579
1580         return 0;
1581 }
1582
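/*
 * Report whether a conntrack-related flow key attribute is usable with
 * this kernel configuration; ct_labels additionally requires that
 * connlabels were successfully enabled in ovs_ct_init() below.
 */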
1583 bool ovs_ct_verify(struct net *net, enum ovs_key_attr attr)
1584 {
1585         if (attr == OVS_KEY_ATTR_CT_STATE)
1586                 return true;
1587         if (IS_ENABLED(CONFIG_NF_CONNTRACK_ZONES) &&
1588             attr == OVS_KEY_ATTR_CT_ZONE)
1589                 return true;
1590         if (IS_ENABLED(CONFIG_NF_CONNTRACK_MARK) &&
1591             attr == OVS_KEY_ATTR_CT_MARK)
1592                 return true;
1593         if (IS_ENABLED(CONFIG_NF_CONNTRACK_LABELS) &&
1594             attr == OVS_KEY_ATTR_CT_LABELS) {
1595                 struct ovs_net *ovs_net = net_generic(net, ovs_net_id);
1596
1597                 return ovs_net->xt_label;
1598         }
1599
1600         return false;
1601 }
1602
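/*
 * Validate a ct action from userspace and append it to the flow's action
 * list: the netlink attributes are parsed into a struct ovs_conntrack_info,
 * a conntrack template is allocated for the requested zone (with a helper
 * attached, if one was named), and the result is copied into @sfa via
 * ovs_nla_add_action().  On failure the partially set up template and
 * helper are released again.
 */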
1603 int ovs_ct_copy_action(struct net *net, const struct nlattr *attr,
1604                        const struct sw_flow_key *key,
1605                        struct sw_flow_actions **sfa,  bool log)
1606 {
1607         struct ovs_conntrack_info ct_info;
1608         const char *helper = NULL;
1609         u16 family;
1610         int err;
1611
1612         family = key_to_nfproto(key);
1613         if (family == NFPROTO_UNSPEC) {
1614                 OVS_NLERR(log, "ct family unspecified");
1615                 return -EINVAL;
1616         }
1617
1618         memset(&ct_info, 0, sizeof(ct_info));
1619         ct_info.family = family;
1620
1621         nf_ct_zone_init(&ct_info.zone, NF_CT_DEFAULT_ZONE_ID,
1622                         NF_CT_DEFAULT_ZONE_DIR, 0);
1623
1624         err = parse_ct(attr, &ct_info, &helper, log);
1625         if (err)
1626                 return err;
1627
1628         /* Set up template for tracking connections in specific zones. */
1629         ct_info.ct = nf_ct_tmpl_alloc(net, &ct_info.zone, GFP_KERNEL);
1630         if (!ct_info.ct) {
1631                 OVS_NLERR(log, "Failed to allocate conntrack template");
1632                 return -ENOMEM;
1633         }
1634         if (helper) {
1635                 err = ovs_ct_add_helper(&ct_info, helper, key, log);
1636                 if (err)
1637                         goto err_free_ct;
1638         }
1639
1640         err = ovs_nla_add_action(sfa, OVS_ACTION_ATTR_CT, &ct_info,
1641                                  sizeof(ct_info), log);
1642         if (err)
1643                 goto err_free_ct;
1644
1645         __set_bit(IPS_CONFIRMED_BIT, &ct_info.ct->status);
1646         nf_conntrack_get(&ct_info.ct->ct_general);
1647         return 0;
1648 err_free_ct:
1649         __ovs_ct_free_action(&ct_info);
1650         return err;
1651 }
1652
1653 #ifdef CONFIG_NF_NAT_NEEDED
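/*
 * Serialize the NAT part of @info back into an OVS_CT_ATTR_NAT nest, i.e.
 * the inverse of parse_nat().  Maximum addresses and ports are omitted
 * when they equal the corresponding minimum, matching what parse_nat()
 * reconstructs for a missing IP_MAX or PROTO_MAX.
 */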
1654 static bool ovs_ct_nat_to_attr(const struct ovs_conntrack_info *info,
1655                                struct sk_buff *skb)
1656 {
1657         struct nlattr *start;
1658
1659         start = nla_nest_start(skb, OVS_CT_ATTR_NAT);
1660         if (!start)
1661                 return false;
1662
1663         if (info->nat & OVS_CT_SRC_NAT) {
1664                 if (nla_put_flag(skb, OVS_NAT_ATTR_SRC))
1665                         return false;
1666         } else if (info->nat & OVS_CT_DST_NAT) {
1667                 if (nla_put_flag(skb, OVS_NAT_ATTR_DST))
1668                         return false;
1669         } else {
1670                 goto out;
1671         }
1672
1673         if (info->range.flags & NF_NAT_RANGE_MAP_IPS) {
1674                 if (IS_ENABLED(CONFIG_NF_NAT) &&
1675                     info->family == NFPROTO_IPV4) {
1676                         if (nla_put_in_addr(skb, OVS_NAT_ATTR_IP_MIN,
1677                                             info->range.min_addr.ip) ||
1678                             (info->range.max_addr.ip
1679                              != info->range.min_addr.ip &&
1680                              (nla_put_in_addr(skb, OVS_NAT_ATTR_IP_MAX,
1681                                               info->range.max_addr.ip))))
1682                                 return false;
1683                 } else if (IS_ENABLED(CONFIG_IPV6) &&
1684                            info->family == NFPROTO_IPV6) {
1685                         if (nla_put_in6_addr(skb, OVS_NAT_ATTR_IP_MIN,
1686                                              &info->range.min_addr.in6) ||
1687                             (memcmp(&info->range.max_addr.in6,
1688                                     &info->range.min_addr.in6,
1689                                     sizeof(info->range.max_addr.in6)) &&
1690                              (nla_put_in6_addr(skb, OVS_NAT_ATTR_IP_MAX,
1691                                                &info->range.max_addr.in6))))
1692                                 return false;
1693                 } else {
1694                         return false;
1695                 }
1696         }
1697         if (info->range.flags & NF_NAT_RANGE_PROTO_SPECIFIED &&
1698             (nla_put_u16(skb, OVS_NAT_ATTR_PROTO_MIN,
1699                          ntohs(info->range.min_proto.all)) ||
1700              (info->range.max_proto.all != info->range.min_proto.all &&
1701               nla_put_u16(skb, OVS_NAT_ATTR_PROTO_MAX,
1702                           ntohs(info->range.max_proto.all)))))
1703                 return false;
1704
1705         if (info->range.flags & NF_NAT_RANGE_PERSISTENT &&
1706             nla_put_flag(skb, OVS_NAT_ATTR_PERSISTENT))
1707                 return false;
1708         if (info->range.flags & NF_NAT_RANGE_PROTO_RANDOM &&
1709             nla_put_flag(skb, OVS_NAT_ATTR_PROTO_HASH))
1710                 return false;
1711         if (info->range.flags & NF_NAT_RANGE_PROTO_RANDOM_FULLY &&
1712             nla_put_flag(skb, OVS_NAT_ATTR_PROTO_RANDOM))
1713                 return false;
1714 out:
1715         nla_nest_end(skb, start);
1716
1717         return true;
1718 }
1719 #endif
1720
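/*
 * Serialize a previously parsed ct action back into netlink form for flow
 * dumps, writing back the commit/force flag, the zone, and whichever of
 * mark, labels, helper, eventmask and NAT were configured.
 */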
1721 int ovs_ct_action_to_attr(const struct ovs_conntrack_info *ct_info,
1722                           struct sk_buff *skb)
1723 {
1724         struct nlattr *start;
1725
1726         start = nla_nest_start(skb, OVS_ACTION_ATTR_CT);
1727         if (!start)
1728                 return -EMSGSIZE;
1729
1730         if (ct_info->commit && nla_put_flag(skb, ct_info->force
1731                                             ? OVS_CT_ATTR_FORCE_COMMIT
1732                                             : OVS_CT_ATTR_COMMIT))
1733                 return -EMSGSIZE;
1734         if (IS_ENABLED(CONFIG_NF_CONNTRACK_ZONES) &&
1735             nla_put_u16(skb, OVS_CT_ATTR_ZONE, ct_info->zone.id))
1736                 return -EMSGSIZE;
1737         if (IS_ENABLED(CONFIG_NF_CONNTRACK_MARK) && ct_info->mark.mask &&
1738             nla_put(skb, OVS_CT_ATTR_MARK, sizeof(ct_info->mark),
1739                     &ct_info->mark))
1740                 return -EMSGSIZE;
1741         if (IS_ENABLED(CONFIG_NF_CONNTRACK_LABELS) &&
1742             labels_nonzero(&ct_info->labels.mask) &&
1743             nla_put(skb, OVS_CT_ATTR_LABELS, sizeof(ct_info->labels),
1744                     &ct_info->labels))
1745                 return -EMSGSIZE;
1746         if (ct_info->helper) {
1747                 if (nla_put_string(skb, OVS_CT_ATTR_HELPER,
1748                                    ct_info->helper->name))
1749                         return -EMSGSIZE;
1750         }
1751         if (ct_info->have_eventmask &&
1752             nla_put_u32(skb, OVS_CT_ATTR_EVENTMASK, ct_info->eventmask))
1753                 return -EMSGSIZE;
1754
1755 #ifdef CONFIG_NF_NAT_NEEDED
1756         if (ct_info->nat && !ovs_ct_nat_to_attr(ct_info, skb))
1757                 return -EMSGSIZE;
1758 #endif
1759         nla_nest_end(skb, start);
1760
1761         return 0;
1762 }
1763
1764 void ovs_ct_free_action(const struct nlattr *a)
1765 {
1766         struct ovs_conntrack_info *ct_info = nla_data(a);
1767
1768         __ovs_ct_free_action(ct_info);
1769 }
1770
1771 static void __ovs_ct_free_action(struct ovs_conntrack_info *ct_info)
1772 {
1773         if (ct_info->helper)
1774                 nf_conntrack_helper_put(ct_info->helper);
1775         if (ct_info->ct)
1776                 nf_ct_tmpl_free(ct_info->ct);
1777 }
1778
1779 #if     IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT)
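/*
 * Per-zone connection limiting.  Limits live in a small hash table of
 * struct ovs_ct_limit entries keyed by zone id, with a separate default
 * limit for zones that have no explicit entry; the actual per-zone
 * connection counting is delegated to nf_conncount, using the zone id as
 * the conncount key.  Limits are configured over the OVS_CT_LIMIT generic
 * netlink family below, and the ovs_ct_limit_enabled static key is flipped
 * on the first time a limit is set, so the datapath only pays for the
 * check once the feature is actually in use.
 */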
1780 static int ovs_ct_limit_init(struct net *net, struct ovs_net *ovs_net)
1781 {
1782         int i, err;
1783
1784         ovs_net->ct_limit_info = kmalloc(sizeof(*ovs_net->ct_limit_info),
1785                                          GFP_KERNEL);
1786         if (!ovs_net->ct_limit_info)
1787                 return -ENOMEM;
1788
1789         ovs_net->ct_limit_info->default_limit = OVS_CT_LIMIT_DEFAULT;
1790         ovs_net->ct_limit_info->limits =
1791                 kmalloc_array(CT_LIMIT_HASH_BUCKETS, sizeof(struct hlist_head),
1792                               GFP_KERNEL);
1793         if (!ovs_net->ct_limit_info->limits) {
1794                 kfree(ovs_net->ct_limit_info);
1795                 return -ENOMEM;
1796         }
1797
1798         for (i = 0; i < CT_LIMIT_HASH_BUCKETS; i++)
1799                 INIT_HLIST_HEAD(&ovs_net->ct_limit_info->limits[i]);
1800
1801         ovs_net->ct_limit_info->data =
1802                 nf_conncount_init(net, NFPROTO_INET, sizeof(u32));
1803
1804         if (IS_ERR(ovs_net->ct_limit_info->data)) {
1805                 err = PTR_ERR(ovs_net->ct_limit_info->data);
1806                 kfree(ovs_net->ct_limit_info->limits);
1807                 kfree(ovs_net->ct_limit_info);
1808                 pr_err("openvswitch: failed to init nf_conncount %d\n", err);
1809                 return err;
1810         }
1811         return 0;
1812 }
1813
1814 static void ovs_ct_limit_exit(struct net *net, struct ovs_net *ovs_net)
1815 {
1816         const struct ovs_ct_limit_info *info = ovs_net->ct_limit_info;
1817         int i;
1818
1819         nf_conncount_destroy(net, NFPROTO_INET, info->data);
1820         for (i = 0; i < CT_LIMIT_HASH_BUCKETS; ++i) {
1821                 struct hlist_head *head = &info->limits[i];
1822                 struct ovs_ct_limit *ct_limit;
1823
1824                 hlist_for_each_entry_rcu(ct_limit, head, hlist_node)
1825                         kfree_rcu(ct_limit, rcu);
1826         }
1827         kfree(ovs_net->ct_limit_info->limits);
1828         kfree(ovs_net->ct_limit_info);
1829 }
1830
1831 static struct sk_buff *
1832 ovs_ct_limit_cmd_reply_start(struct genl_info *info, u8 cmd,
1833                              struct ovs_header **ovs_reply_header)
1834 {
1835         struct ovs_header *ovs_header = info->userhdr;
1836         struct sk_buff *skb;
1837
1838         skb = genlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1839         if (!skb)
1840                 return ERR_PTR(-ENOMEM);
1841
1842         *ovs_reply_header = genlmsg_put(skb, info->snd_portid,
1843                                         info->snd_seq,
1844                                         &dp_ct_limit_genl_family, 0, cmd);
1845
1846         if (!*ovs_reply_header) {
1847                 nlmsg_free(skb);
1848                 return ERR_PTR(-EMSGSIZE);
1849         }
1850         (*ovs_reply_header)->dp_ifindex = ovs_header->dp_ifindex;
1851
1852         return skb;
1853 }
1854
1855 static bool check_zone_id(int zone_id, u16 *pzone)
1856 {
1857         if (zone_id >= 0 && zone_id <= 65535) {
1858                 *pzone = (u16)zone_id;
1859                 return true;
1860         }
1861         return false;
1862 }
1863
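/*
 * The OVS_CT_LIMIT_ATTR_ZONE_LIMIT payload is not a nest of netlink
 * attributes but a flat array of struct ovs_zone_limit entries, so the
 * set/del/get handlers below walk it by hand in NLA_ALIGN()ed steps.
 * An entry whose zone_id is OVS_ZONE_LIMIT_DEFAULT_ZONE addresses the
 * default limit rather than a specific zone.
 */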
1864 static int ovs_ct_limit_set_zone_limit(struct nlattr *nla_zone_limit,
1865                                        struct ovs_ct_limit_info *info)
1866 {
1867         struct ovs_zone_limit *zone_limit;
1868         int rem;
1869         u16 zone;
1870
1871         rem = NLA_ALIGN(nla_len(nla_zone_limit));
1872         zone_limit = (struct ovs_zone_limit *)nla_data(nla_zone_limit);
1873
1874         while (rem >= sizeof(*zone_limit)) {
1875                 if (unlikely(zone_limit->zone_id ==
1876                                 OVS_ZONE_LIMIT_DEFAULT_ZONE)) {
1877                         ovs_lock();
1878                         info->default_limit = zone_limit->limit;
1879                         ovs_unlock();
1880                 } else if (unlikely(!check_zone_id(
1881                                 zone_limit->zone_id, &zone))) {
1882                         OVS_NLERR(true, "zone id is out of range");
1883                 } else {
1884                         struct ovs_ct_limit *ct_limit;
1885
1886                         ct_limit = kmalloc(sizeof(*ct_limit), GFP_KERNEL);
1887                         if (!ct_limit)
1888                                 return -ENOMEM;
1889
1890                         ct_limit->zone = zone;
1891                         ct_limit->limit = zone_limit->limit;
1892
1893                         ovs_lock();
1894                         ct_limit_set(info, ct_limit);
1895                         ovs_unlock();
1896                 }
1897                 rem -= NLA_ALIGN(sizeof(*zone_limit));
1898                 zone_limit = (struct ovs_zone_limit *)((u8 *)zone_limit +
1899                                 NLA_ALIGN(sizeof(*zone_limit)));
1900         }
1901
1902         if (rem)
1903                 OVS_NLERR(true, "set zone limit has %d unknown bytes", rem);
1904
1905         return 0;
1906 }
1907
1908 static int ovs_ct_limit_del_zone_limit(struct nlattr *nla_zone_limit,
1909                                        struct ovs_ct_limit_info *info)
1910 {
1911         struct ovs_zone_limit *zone_limit;
1912         int rem;
1913         u16 zone;
1914
1915         rem = NLA_ALIGN(nla_len(nla_zone_limit));
1916         zone_limit = (struct ovs_zone_limit *)nla_data(nla_zone_limit);
1917
1918         while (rem >= sizeof(*zone_limit)) {
1919                 if (unlikely(zone_limit->zone_id ==
1920                                 OVS_ZONE_LIMIT_DEFAULT_ZONE)) {
1921                         ovs_lock();
1922                         info->default_limit = OVS_CT_LIMIT_DEFAULT;
1923                         ovs_unlock();
1924                 } else if (unlikely(!check_zone_id(
1925                                 zone_limit->zone_id, &zone))) {
1926                         OVS_NLERR(true, "zone id is out of range");
1927                 } else {
1928                         ovs_lock();
1929                         ct_limit_del(info, zone);
1930                         ovs_unlock();
1931                 }
1932                 rem -= NLA_ALIGN(sizeof(*zone_limit));
1933                 zone_limit = (struct ovs_zone_limit *)((u8 *)zone_limit +
1934                                 NLA_ALIGN(sizeof(*zone_limit)));
1935         }
1936
1937         if (rem)
1938                 OVS_NLERR(true, "del zone limit has %d unknown bytes", rem);
1939
1940         return 0;
1941 }
1942
1943 static int ovs_ct_limit_get_default_limit(struct ovs_ct_limit_info *info,
1944                                           struct sk_buff *reply)
1945 {
1946         struct ovs_zone_limit zone_limit;
1947         int err;
1948
1949         zone_limit.zone_id = OVS_ZONE_LIMIT_DEFAULT_ZONE;
1950         zone_limit.limit = info->default_limit;
1951         err = nla_put_nohdr(reply, sizeof(zone_limit), &zone_limit);
1952         if (err)
1953                 return err;
1954
1955         return 0;
1956 }
1957
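/*
 * Fill in one struct ovs_zone_limit reply entry for @zone_id: the
 * configured @limit plus the number of connections currently tracked in
 * that zone, obtained from nf_conncount with the zone id as the key.
 */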
1958 static int __ovs_ct_limit_get_zone_limit(struct net *net,
1959                                          struct nf_conncount_data *data,
1960                                          u16 zone_id, u32 limit,
1961                                          struct sk_buff *reply)
1962 {
1963         struct nf_conntrack_zone ct_zone;
1964         struct ovs_zone_limit zone_limit;
1965         u32 conncount_key = zone_id;
1966
1967         zone_limit.zone_id = zone_id;
1968         zone_limit.limit = limit;
1969         nf_ct_zone_init(&ct_zone, zone_id, NF_CT_DEFAULT_ZONE_DIR, 0);
1970
1971         zone_limit.count = nf_conncount_count(net, data, &conncount_key, NULL,
1972                                               &ct_zone);
1973         return nla_put_nohdr(reply, sizeof(zone_limit), &zone_limit);
1974 }
1975
1976 static int ovs_ct_limit_get_zone_limit(struct net *net,
1977                                        struct nlattr *nla_zone_limit,
1978                                        struct ovs_ct_limit_info *info,
1979                                        struct sk_buff *reply)
1980 {
1981         struct ovs_zone_limit *zone_limit;
1982         int rem, err;
1983         u32 limit;
1984         u16 zone;
1985
1986         rem = NLA_ALIGN(nla_len(nla_zone_limit));
1987         zone_limit = (struct ovs_zone_limit *)nla_data(nla_zone_limit);
1988
1989         while (rem >= sizeof(*zone_limit)) {
1990                 if (unlikely(zone_limit->zone_id ==
1991                                 OVS_ZONE_LIMIT_DEFAULT_ZONE)) {
1992                         err = ovs_ct_limit_get_default_limit(info, reply);
1993                         if (err)
1994                                 return err;
1995                 } else if (unlikely(!check_zone_id(zone_limit->zone_id,
1996                                                         &zone))) {
1997                         OVS_NLERR(true, "zone id is out of range");
1998                 } else {
1999                         rcu_read_lock();
2000                         limit = ct_limit_get(info, zone);
2001                         rcu_read_unlock();
2002
2003                         err = __ovs_ct_limit_get_zone_limit(
2004                                 net, info->data, zone, limit, reply);
2005                         if (err)
2006                                 return err;
2007                 }
2008                 rem -= NLA_ALIGN(sizeof(*zone_limit));
2009                 zone_limit = (struct ovs_zone_limit *)((u8 *)zone_limit +
2010                                 NLA_ALIGN(sizeof(*zone_limit)));
2011         }
2012
2013         if (rem)
2014                 OVS_NLERR(true, "get zone limit has %d unknown bytes", rem);
2015
2016         return 0;
2017 }
2018
2019 static int ovs_ct_limit_get_all_zone_limit(struct net *net,
2020                                            struct ovs_ct_limit_info *info,
2021                                            struct sk_buff *reply)
2022 {
2023         struct ovs_ct_limit *ct_limit;
2024         struct hlist_head *head;
2025         int i, err = 0;
2026
2027         err = ovs_ct_limit_get_default_limit(info, reply);
2028         if (err)
2029                 return err;
2030
2031         rcu_read_lock();
2032         for (i = 0; i < CT_LIMIT_HASH_BUCKETS; ++i) {
2033                 head = &info->limits[i];
2034                 hlist_for_each_entry_rcu(ct_limit, head, hlist_node) {
2035                         err = __ovs_ct_limit_get_zone_limit(net, info->data,
2036                                 ct_limit->zone, ct_limit->limit, reply);
2037                         if (err)
2038                                 goto exit_err;
2039                 }
2040         }
2041
2042 exit_err:
2043         rcu_read_unlock();
2044         return err;
2045 }
2046
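/*
 * Generic netlink handlers for OVS_CT_LIMIT_CMD_SET/DEL/GET.  SET and DEL
 * require CAP_NET_ADMIN, GET does not.  In current OVS userspace these are
 * typically exercised through ovs-dpctl or the equivalent ovs-appctl
 * dpctl/ commands, for example (zone and limit values are hypothetical):
 *
 *   ovs-appctl dpctl/ct-set-limits default=10 zone=5,limit=3
 *   ovs-appctl dpctl/ct-get-limits zone=5
 *   ovs-appctl dpctl/ct-del-limits zone=5
 */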
2047 static int ovs_ct_limit_cmd_set(struct sk_buff *skb, struct genl_info *info)
2048 {
2049         struct nlattr **a = info->attrs;
2050         struct sk_buff *reply;
2051         struct ovs_header *ovs_reply_header;
2052         struct ovs_net *ovs_net = net_generic(sock_net(skb->sk), ovs_net_id);
2053         struct ovs_ct_limit_info *ct_limit_info = ovs_net->ct_limit_info;
2054         int err;
2055
2056         reply = ovs_ct_limit_cmd_reply_start(info, OVS_CT_LIMIT_CMD_SET,
2057                                              &ovs_reply_header);
2058         if (IS_ERR(reply))
2059                 return PTR_ERR(reply);
2060
2061         if (!a[OVS_CT_LIMIT_ATTR_ZONE_LIMIT]) {
2062                 err = -EINVAL;
2063                 goto exit_err;
2064         }
2065
2066         err = ovs_ct_limit_set_zone_limit(a[OVS_CT_LIMIT_ATTR_ZONE_LIMIT],
2067                                           ct_limit_info);
2068         if (err)
2069                 goto exit_err;
2070
2071         static_branch_enable(&ovs_ct_limit_enabled);
2072
2073         genlmsg_end(reply, ovs_reply_header);
2074         return genlmsg_reply(reply, info);
2075
2076 exit_err:
2077         nlmsg_free(reply);
2078         return err;
2079 }
2080
2081 static int ovs_ct_limit_cmd_del(struct sk_buff *skb, struct genl_info *info)
2082 {
2083         struct nlattr **a = info->attrs;
2084         struct sk_buff *reply;
2085         struct ovs_header *ovs_reply_header;
2086         struct ovs_net *ovs_net = net_generic(sock_net(skb->sk), ovs_net_id);
2087         struct ovs_ct_limit_info *ct_limit_info = ovs_net->ct_limit_info;
2088         int err;
2089
2090         reply = ovs_ct_limit_cmd_reply_start(info, OVS_CT_LIMIT_CMD_DEL,
2091                                              &ovs_reply_header);
2092         if (IS_ERR(reply))
2093                 return PTR_ERR(reply);
2094
2095         if (!a[OVS_CT_LIMIT_ATTR_ZONE_LIMIT]) {
2096                 err = -EINVAL;
2097                 goto exit_err;
2098         }
2099
2100         err = ovs_ct_limit_del_zone_limit(a[OVS_CT_LIMIT_ATTR_ZONE_LIMIT],
2101                                           ct_limit_info);
2102         if (err)
2103                 goto exit_err;
2104
2105         genlmsg_end(reply, ovs_reply_header);
2106         return genlmsg_reply(reply, info);
2107
2108 exit_err:
2109         nlmsg_free(reply);
2110         return err;
2111 }
2112
2113 static int ovs_ct_limit_cmd_get(struct sk_buff *skb, struct genl_info *info)
2114 {
2115         struct nlattr **a = info->attrs;
2116         struct nlattr *nla_reply;
2117         struct sk_buff *reply;
2118         struct ovs_header *ovs_reply_header;
2119         struct net *net = sock_net(skb->sk);
2120         struct ovs_net *ovs_net = net_generic(net, ovs_net_id);
2121         struct ovs_ct_limit_info *ct_limit_info = ovs_net->ct_limit_info;
2122         int err;
2123
2124         reply = ovs_ct_limit_cmd_reply_start(info, OVS_CT_LIMIT_CMD_GET,
2125                                              &ovs_reply_header);
2126         if (IS_ERR(reply))
2127                 return PTR_ERR(reply);
2128
2129         nla_reply = nla_nest_start(reply, OVS_CT_LIMIT_ATTR_ZONE_LIMIT);
2130
2131         if (a[OVS_CT_LIMIT_ATTR_ZONE_LIMIT]) {
2132                 err = ovs_ct_limit_get_zone_limit(
2133                         net, a[OVS_CT_LIMIT_ATTR_ZONE_LIMIT], ct_limit_info,
2134                         reply);
2135                 if (err)
2136                         goto exit_err;
2137         } else {
2138                 err = ovs_ct_limit_get_all_zone_limit(net, ct_limit_info,
2139                                                       reply);
2140                 if (err)
2141                         goto exit_err;
2142         }
2143
2144         nla_nest_end(reply, nla_reply);
2145         genlmsg_end(reply, ovs_reply_header);
2146         return genlmsg_reply(reply, info);
2147
2148 exit_err:
2149         nlmsg_free(reply);
2150         return err;
2151 }
2152
2153 static struct genl_ops ct_limit_genl_ops[] = {
2154         { .cmd = OVS_CT_LIMIT_CMD_SET,
2155                 .flags = GENL_ADMIN_PERM, /* Requires CAP_NET_ADMIN
2156                                            * privilege. */
2157                 .policy = ct_limit_policy,
2158                 .doit = ovs_ct_limit_cmd_set,
2159         },
2160         { .cmd = OVS_CT_LIMIT_CMD_DEL,
2161                 .flags = GENL_ADMIN_PERM, /* Requires CAP_NET_ADMIN
2162                                            * privilege. */
2163                 .policy = ct_limit_policy,
2164                 .doit = ovs_ct_limit_cmd_del,
2165         },
2166         { .cmd = OVS_CT_LIMIT_CMD_GET,
2167                 .flags = 0,               /* OK for unprivileged users. */
2168                 .policy = ct_limit_policy,
2169                 .doit = ovs_ct_limit_cmd_get,
2170         },
2171 };
2172
2173 static const struct genl_multicast_group ovs_ct_limit_multicast_group = {
2174         .name = OVS_CT_LIMIT_MCGROUP,
2175 };
2176
2177 struct genl_family dp_ct_limit_genl_family __ro_after_init = {
2178         .hdrsize = sizeof(struct ovs_header),
2179         .name = OVS_CT_LIMIT_FAMILY,
2180         .version = OVS_CT_LIMIT_VERSION,
2181         .maxattr = OVS_CT_LIMIT_ATTR_MAX,
2182         .netnsok = true,
2183         .parallel_ops = true,
2184         .ops = ct_limit_genl_ops,
2185         .n_ops = ARRAY_SIZE(ct_limit_genl_ops),
2186         .mcgrps = &ovs_ct_limit_multicast_group,
2187         .n_mcgrps = 1,
2188         .module = THIS_MODULE,
2189 };
2190 #endif
2191
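/*
 * Per-netns conntrack setup for openvswitch: request the full 128-bit
 * connlabel area (so OVS_KEY_ATTR_CT_LABELS can be used) and record in
 * ovs_net->xt_label whether that succeeded, then initialize the ct limit
 * state when nf_conncount is available.  ovs_ct_exit() undoes both.
 */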
2192 int ovs_ct_init(struct net *net)
2193 {
2194         unsigned int n_bits = sizeof(struct ovs_key_ct_labels) * BITS_PER_BYTE;
2195         struct ovs_net *ovs_net = net_generic(net, ovs_net_id);
2196
2197         if (nf_connlabels_get(net, n_bits - 1)) {
2198                 ovs_net->xt_label = false;
2199                 OVS_NLERR(true, "Failed to set connlabel length");
2200         } else {
2201                 ovs_net->xt_label = true;
2202         }
2203
2204 #if     IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT)
2205         return ovs_ct_limit_init(net, ovs_net);
2206 #else
2207         return 0;
2208 #endif
2209 }
2210
2211 void ovs_ct_exit(struct net *net)
2212 {
2213         struct ovs_net *ovs_net = net_generic(net, ovs_net_id);
2214
2215 #if     IS_ENABLED(CONFIG_NETFILTER_CONNCOUNT)
2216         ovs_ct_limit_exit(net, ovs_net);
2217 #endif
2218
2219         if (ovs_net->xt_label)
2220                 nf_connlabels_put(net);
2221 }