asedeno.scripts.mit.edu Git - linux.git/blobdiff - net/sched/cls_flower.c
net: sched: flower: introduce reference counting for filters
[linux.git] / net / sched / cls_flower.c
index 12ca9d13db83320b6fe3381b1874206dec433a74..9ed7c9b804a7598222a5bbd01f6cd8be06c33e2e 100644 (file)
@@ -14,6 +14,7 @@
 #include <linux/module.h>
 #include <linux/rhashtable.h>
 #include <linux/workqueue.h>
+#include <linux/refcount.h>
 
 #include <linux/if_ether.h>
 #include <linux/in6.h>
@@ -104,6 +105,11 @@ struct cls_fl_filter {
        u32 in_hw_count;
        struct rcu_work rwork;
        struct net_device *hw_dev;
+       /* Flower classifier is unlocked, which means that its reference counter
+        * can be changed concurrently without any kind of external
+        * synchronization. Use atomic reference counter to be concurrency-safe.
+        */
+       refcount_t refcnt;
 };
 
 static const struct rhashtable_params mask_ht_params = {
@@ -381,16 +387,31 @@ static int fl_hw_replace_filter(struct tcf_proto *tp,
        bool skip_sw = tc_skip_sw(f->flags);
        int err;
 
+       cls_flower.rule = flow_rule_alloc(tcf_exts_num_actions(&f->exts));
+       if (!cls_flower.rule)
+               return -ENOMEM;
+
        tc_cls_common_offload_init(&cls_flower.common, tp, f->flags, extack);
        cls_flower.command = TC_CLSFLOWER_REPLACE;
        cls_flower.cookie = (unsigned long) f;
-       cls_flower.dissector = &f->mask->dissector;
-       cls_flower.mask = &f->mask->key;
-       cls_flower.key = &f->mkey;
-       cls_flower.exts = &f->exts;
+       cls_flower.rule->match.dissector = &f->mask->dissector;
+       cls_flower.rule->match.mask = &f->mask->key;
+       cls_flower.rule->match.key = &f->mkey;
        cls_flower.classid = f->res.classid;
 
+       err = tc_setup_flow_action(&cls_flower.rule->action, &f->exts);
+       if (err) {
+               kfree(cls_flower.rule);
+               if (skip_sw) {
+                       NL_SET_ERR_MSG_MOD(extack, "Failed to setup flow action");
+                       return err;
+               }
+               return 0;
+       }
+
        err = tc_setup_cb_call(block, TC_SETUP_CLSFLOWER, &cls_flower, skip_sw);
+       kfree(cls_flower.rule);
+
        if (err < 0) {
                fl_hw_destroy_filter(tp, f, NULL);
                return err;
@@ -413,16 +434,71 @@ static void fl_hw_update_stats(struct tcf_proto *tp, struct cls_fl_filter *f)
        tc_cls_common_offload_init(&cls_flower.common, tp, f->flags, NULL);
        cls_flower.command = TC_CLSFLOWER_STATS;
        cls_flower.cookie = (unsigned long) f;
-       cls_flower.exts = &f->exts;
        cls_flower.classid = f->res.classid;
 
        tc_setup_cb_call(block, TC_SETUP_CLSFLOWER, &cls_flower, false);
+
+       tcf_exts_stats_update(&f->exts, cls_flower.stats.bytes,
+                             cls_flower.stats.pkts,
+                             cls_flower.stats.lastused);
+}
+
+static struct cls_fl_head *fl_head_dereference(struct tcf_proto *tp)
+{
+       /* Flower classifier only changes root pointer during init and destroy.
+        * Users must obtain reference to tcf_proto instance before calling its
+        * API, so tp->root pointer is protected from concurrent call to
+        * fl_destroy() by reference counting.
+        */
+       return rcu_dereference_raw(tp->root);
+}
+
+static void __fl_put(struct cls_fl_filter *f)
+{
+       if (!refcount_dec_and_test(&f->refcnt))
+               return;
+
+       if (tcf_exts_get_net(&f->exts))
+               tcf_queue_work(&f->rwork, fl_destroy_filter_work);
+       else
+               __fl_destroy_filter(f);
+}
+
+static struct cls_fl_filter *__fl_get(struct cls_fl_head *head, u32 handle)
+{
+       struct cls_fl_filter *f;
+
+       rcu_read_lock();
+       f = idr_find(&head->handle_idr, handle);
+       if (f && !refcount_inc_not_zero(&f->refcnt))
+               f = NULL;
+       rcu_read_unlock();
+
+       return f;
+}
+
+static struct cls_fl_filter *fl_get_next_filter(struct tcf_proto *tp,
+                                               unsigned long *handle)
+{
+       struct cls_fl_head *head = fl_head_dereference(tp);
+       struct cls_fl_filter *f;
+
+       rcu_read_lock();
+       while ((f = idr_get_next_ul(&head->handle_idr, handle))) {
+               /* don't return filters that are being deleted */
+               if (refcount_inc_not_zero(&f->refcnt))
+                       break;
+               ++(*handle);
+       }
+       rcu_read_unlock();
+
+       return f;
 }
 
 static bool __fl_delete(struct tcf_proto *tp, struct cls_fl_filter *f,
                        struct netlink_ext_ack *extack)
 {
-       struct cls_fl_head *head = rtnl_dereference(tp->root);
+       struct cls_fl_head *head = fl_head_dereference(tp);
        bool async = tcf_exts_get_net(&f->exts);
        bool last;
 
@@ -432,10 +508,7 @@ static bool __fl_delete(struct tcf_proto *tp, struct cls_fl_filter *f,
        if (!tc_skip_hw(f->flags))
                fl_hw_destroy_filter(tp, f, extack);
        tcf_unbind_filter(tp, &f->res);
-       if (async)
-               tcf_queue_work(&f->rwork, fl_destroy_filter_work);
-       else
-               __fl_destroy_filter(f);
+       __fl_put(f);
 
        return last;
 }
@@ -451,9 +524,10 @@ static void fl_destroy_sleepable(struct work_struct *work)
        module_put(THIS_MODULE);
 }
 
-static void fl_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack)
+static void fl_destroy(struct tcf_proto *tp, bool rtnl_held,
+                      struct netlink_ext_ack *extack)
 {
-       struct cls_fl_head *head = rtnl_dereference(tp->root);
+       struct cls_fl_head *head = fl_head_dereference(tp);
        struct fl_flow_mask *mask, *next_mask;
        struct cls_fl_filter *f, *next;
 
@@ -469,11 +543,18 @@ static void fl_destroy(struct tcf_proto *tp, struct netlink_ext_ack *extack)
        tcf_queue_work(&head->rwork, fl_destroy_sleepable);
 }
 
+static void fl_put(struct tcf_proto *tp, void *arg)
+{
+       struct cls_fl_filter *f = arg;
+
+       __fl_put(f);
+}
+
 static void *fl_get(struct tcf_proto *tp, u32 handle)
 {
-       struct cls_fl_head *head = rtnl_dereference(tp->root);
+       struct cls_fl_head *head = fl_head_dereference(tp);
 
-       return idr_find(&head->handle_idr, handle);
+       return __fl_get(head, handle);
 }
 
 static const struct nla_policy fl_policy[TCA_FLOWER_MAX + 1] = {
@@ -1258,7 +1339,8 @@ static int fl_set_parms(struct net *net, struct tcf_proto *tp,
 {
        int err;
 
-       err = tcf_exts_validate(net, tp, tb, est, &f->exts, ovr, extack);
+       err = tcf_exts_validate(net, tp, tb, est, &f->exts, ovr, true,
+                               extack);
        if (err < 0)
                return err;
 
@@ -1285,21 +1367,26 @@ static int fl_set_parms(struct net *net, struct tcf_proto *tp,
 static int fl_change(struct net *net, struct sk_buff *in_skb,
                     struct tcf_proto *tp, unsigned long base,
                     u32 handle, struct nlattr **tca,
-                    void **arg, bool ovr, struct netlink_ext_ack *extack)
+                    void **arg, bool ovr, bool rtnl_held,
+                    struct netlink_ext_ack *extack)
 {
-       struct cls_fl_head *head = rtnl_dereference(tp->root);
+       struct cls_fl_head *head = fl_head_dereference(tp);
        struct cls_fl_filter *fold = *arg;
        struct cls_fl_filter *fnew;
        struct fl_flow_mask *mask;
        struct nlattr **tb;
        int err;
 
-       if (!tca[TCA_OPTIONS])
-               return -EINVAL;
+       if (!tca[TCA_OPTIONS]) {
+               err = -EINVAL;
+               goto errout_fold;
+       }
 
        mask = kzalloc(sizeof(struct fl_flow_mask), GFP_KERNEL);
-       if (!mask)
-               return -ENOBUFS;
+       if (!mask) {
+               err = -ENOBUFS;
+               goto errout_fold;
+       }
 
        tb = kcalloc(TCA_FLOWER_MAX + 1, sizeof(struct nlattr *), GFP_KERNEL);
        if (!tb) {
@@ -1322,95 +1409,104 @@ static int fl_change(struct net *net, struct sk_buff *in_skb,
                err = -ENOBUFS;
                goto errout_tb;
        }
+       refcount_set(&fnew->refcnt, 1);
 
-       err = tcf_exts_init(&fnew->exts, TCA_FLOWER_ACT, 0);
+       err = tcf_exts_init(&fnew->exts, net, TCA_FLOWER_ACT, 0);
        if (err < 0)
                goto errout;
 
-       if (!handle) {
-               handle = 1;
-               err = idr_alloc_u32(&head->handle_idr, fnew, &handle,
-                                   INT_MAX, GFP_KERNEL);
-       } else if (!fold) {
-               /* user specifies a handle and it doesn't exist */
-               err = idr_alloc_u32(&head->handle_idr, fnew, &handle,
-                                   handle, GFP_KERNEL);
-       }
-       if (err)
-               goto errout;
-       fnew->handle = handle;
-
        if (tb[TCA_FLOWER_FLAGS]) {
                fnew->flags = nla_get_u32(tb[TCA_FLOWER_FLAGS]);
 
                if (!tc_flags_valid(fnew->flags)) {
                        err = -EINVAL;
-                       goto errout_idr;
+                       goto errout;
                }
        }
 
        err = fl_set_parms(net, tp, fnew, mask, base, tb, tca[TCA_RATE], ovr,
                           tp->chain->tmplt_priv, extack);
        if (err)
-               goto errout_idr;
+               goto errout;
 
        err = fl_check_assign_mask(head, fnew, fold, mask);
        if (err)
-               goto errout_idr;
-
-       if (!fold && __fl_lookup(fnew->mask, &fnew->mkey)) {
-               err = -EEXIST;
-               goto errout_mask;
-       }
-
-       err = rhashtable_insert_fast(&fnew->mask->ht, &fnew->ht_node,
-                                    fnew->mask->filter_ht_params);
-       if (err)
-               goto errout_mask;
+               goto errout;
 
        if (!tc_skip_hw(fnew->flags)) {
                err = fl_hw_replace_filter(tp, fnew, extack);
                if (err)
-                       goto errout_mask_ht;
+                       goto errout_mask;
        }
 
        if (!tc_in_hw(fnew->flags))
                fnew->flags |= TCA_CLS_FLAGS_NOT_IN_HW;
 
+       refcount_inc(&fnew->refcnt);
        if (fold) {
+               fnew->handle = handle;
+
+               err = rhashtable_insert_fast(&fnew->mask->ht, &fnew->ht_node,
+                                            fnew->mask->filter_ht_params);
+               if (err)
+                       goto errout_hw;
+
                rhashtable_remove_fast(&fold->mask->ht,
                                       &fold->ht_node,
                                       fold->mask->filter_ht_params);
-               if (!tc_skip_hw(fold->flags))
-                       fl_hw_destroy_filter(tp, fold, NULL);
-       }
-
-       *arg = fnew;
-
-       if (fold) {
                idr_replace(&head->handle_idr, fnew, fnew->handle);
                list_replace_rcu(&fold->list, &fnew->list);
+
+               if (!tc_skip_hw(fold->flags))
+                       fl_hw_destroy_filter(tp, fold, NULL);
                tcf_unbind_filter(tp, &fold->res);
                tcf_exts_get_net(&fold->exts);
-               tcf_queue_work(&fold->rwork, fl_destroy_filter_work);
+               /* Caller holds reference to fold, so refcnt is always > 0
+                * after this.
+                */
+               refcount_dec(&fold->refcnt);
+               __fl_put(fold);
        } else {
+               if (__fl_lookup(fnew->mask, &fnew->mkey)) {
+                       err = -EEXIST;
+                       goto errout_hw;
+               }
+
+               if (handle) {
+                       /* user specifies a handle and it doesn't exist */
+                       err = idr_alloc_u32(&head->handle_idr, fnew, &handle,
+                                           handle, GFP_ATOMIC);
+               } else {
+                       handle = 1;
+                       err = idr_alloc_u32(&head->handle_idr, fnew, &handle,
+                                           INT_MAX, GFP_ATOMIC);
+               }
+               if (err)
+                       goto errout_hw;
+
+               fnew->handle = handle;
+
+               err = rhashtable_insert_fast(&fnew->mask->ht, &fnew->ht_node,
+                                            fnew->mask->filter_ht_params);
+               if (err)
+                       goto errout_idr;
+
                list_add_tail_rcu(&fnew->list, &fnew->mask->filters);
        }
 
+       *arg = fnew;
+
        kfree(tb);
        kfree(mask);
        return 0;
 
-errout_mask_ht:
-       rhashtable_remove_fast(&fnew->mask->ht, &fnew->ht_node,
-                              fnew->mask->filter_ht_params);
-
+errout_idr:
+       idr_remove(&head->handle_idr, fnew->handle);
+errout_hw:
+       if (!tc_skip_hw(fnew->flags))
+               fl_hw_destroy_filter(tp, fnew, NULL);
 errout_mask:
        fl_mask_put(head, fnew->mask, false);
-
-errout_idr:
-       if (!fold)
-               idr_remove(&head->handle_idr, fnew->handle);
 errout:
        tcf_exts_destroy(&fnew->exts);
        kfree(fnew);
@@ -1418,36 +1514,42 @@ static int fl_change(struct net *net, struct sk_buff *in_skb,
        kfree(tb);
 errout_mask_alloc:
        kfree(mask);
+errout_fold:
+       if (fold)
+               __fl_put(fold);
        return err;
 }
 
 static int fl_delete(struct tcf_proto *tp, void *arg, bool *last,
-                    struct netlink_ext_ack *extack)
+                    bool rtnl_held, struct netlink_ext_ack *extack)
 {
-       struct cls_fl_head *head = rtnl_dereference(tp->root);
+       struct cls_fl_head *head = fl_head_dereference(tp);
        struct cls_fl_filter *f = arg;
 
        rhashtable_remove_fast(&f->mask->ht, &f->ht_node,
                               f->mask->filter_ht_params);
        __fl_delete(tp, f, extack);
        *last = list_empty(&head->masks);
+       __fl_put(f);
+
        return 0;
 }
 
-static void fl_walk(struct tcf_proto *tp, struct tcf_walker *arg)
+static void fl_walk(struct tcf_proto *tp, struct tcf_walker *arg,
+                   bool rtnl_held)
 {
-       struct cls_fl_head *head = rtnl_dereference(tp->root);
        struct cls_fl_filter *f;
 
        arg->count = arg->skip;
 
-       while ((f = idr_get_next_ul(&head->handle_idr,
-                                   &arg->cookie)) != NULL) {
+       while ((f = fl_get_next_filter(tp, &arg->cookie)) != NULL) {
                if (arg->fn(tp, f, arg) < 0) {
+                       __fl_put(f);
                        arg->stop = 1;
                        break;
                }
-               arg->cookie = f->handle + 1;
+               __fl_put(f);
+               arg->cookie++;
                arg->count++;
        }
 }
@@ -1455,7 +1557,7 @@ static void fl_walk(struct tcf_proto *tp, struct tcf_walker *arg)
 static int fl_reoffload(struct tcf_proto *tp, bool add, tc_setup_cb_t *cb,
                        void *cb_priv, struct netlink_ext_ack *extack)
 {
-       struct cls_fl_head *head = rtnl_dereference(tp->root);
+       struct cls_fl_head *head = fl_head_dereference(tp);
        struct tc_cls_flower_offload cls_flower = {};
        struct tcf_block *block = tp->chain->block;
        struct fl_flow_mask *mask;
@@ -1467,18 +1569,36 @@ static int fl_reoffload(struct tcf_proto *tp, bool add, tc_setup_cb_t *cb,
                        if (tc_skip_hw(f->flags))
                                continue;
 
+                       cls_flower.rule =
+                               flow_rule_alloc(tcf_exts_num_actions(&f->exts));
+                       if (!cls_flower.rule)
+                               return -ENOMEM;
+
                        tc_cls_common_offload_init(&cls_flower.common, tp,
                                                   f->flags, extack);
                        cls_flower.command = add ?
                                TC_CLSFLOWER_REPLACE : TC_CLSFLOWER_DESTROY;
                        cls_flower.cookie = (unsigned long)f;
-                       cls_flower.dissector = &mask->dissector;
-                       cls_flower.mask = &mask->key;
-                       cls_flower.key = &f->mkey;
-                       cls_flower.exts = &f->exts;
+                       cls_flower.rule->match.dissector = &mask->dissector;
+                       cls_flower.rule->match.mask = &mask->key;
+                       cls_flower.rule->match.key = &f->mkey;
+
+                       err = tc_setup_flow_action(&cls_flower.rule->action,
+                                                  &f->exts);
+                       if (err) {
+                               kfree(cls_flower.rule);
+                               if (tc_skip_sw(f->flags)) {
+                                       NL_SET_ERR_MSG_MOD(extack, "Failed to setup flow action");
+                                       return err;
+                               }
+                               continue;
+                       }
+
                        cls_flower.classid = f->res.classid;
 
                        err = cb(TC_SETUP_CLSFLOWER, &cls_flower, cb_priv);
+                       kfree(cls_flower.rule);
+
                        if (err) {
                                if (add && tc_skip_sw(f->flags))
                                        return err;
@@ -1493,25 +1613,30 @@ static int fl_reoffload(struct tcf_proto *tp, bool add, tc_setup_cb_t *cb,
        return 0;
 }
 
-static void fl_hw_create_tmplt(struct tcf_chain *chain,
-                              struct fl_flow_tmplt *tmplt)
+static int fl_hw_create_tmplt(struct tcf_chain *chain,
+                             struct fl_flow_tmplt *tmplt)
 {
        struct tc_cls_flower_offload cls_flower = {};
        struct tcf_block *block = chain->block;
-       struct tcf_exts dummy_exts = { 0, };
+
+       cls_flower.rule = flow_rule_alloc(0);
+       if (!cls_flower.rule)
+               return -ENOMEM;
 
        cls_flower.common.chain_index = chain->index;
        cls_flower.command = TC_CLSFLOWER_TMPLT_CREATE;
        cls_flower.cookie = (unsigned long) tmplt;
-       cls_flower.dissector = &tmplt->dissector;
-       cls_flower.mask = &tmplt->mask;
-       cls_flower.key = &tmplt->dummy_key;
-       cls_flower.exts = &dummy_exts;
+       cls_flower.rule->match.dissector = &tmplt->dissector;
+       cls_flower.rule->match.mask = &tmplt->mask;
+       cls_flower.rule->match.key = &tmplt->dummy_key;
 
        /* We don't care if driver (any of them) fails to handle this
         * call. It serves just as a hint for it.
         */
        tc_setup_cb_call(block, TC_SETUP_CLSFLOWER, &cls_flower, false);
+       kfree(cls_flower.rule);
+
+       return 0;
 }
 
 static void fl_hw_destroy_tmplt(struct tcf_chain *chain,
@@ -1555,12 +1680,14 @@ static void *fl_tmplt_create(struct net *net, struct tcf_chain *chain,
        err = fl_set_key(net, tb, &tmplt->dummy_key, &tmplt->mask, extack);
        if (err)
                goto errout_tmplt;
-       kfree(tb);
 
        fl_init_dissector(&tmplt->dissector, &tmplt->mask);
 
-       fl_hw_create_tmplt(chain, tmplt);
+       err = fl_hw_create_tmplt(chain, tmplt);
+       if (err)
+               goto errout_tmplt;
 
+       kfree(tb);
        return tmplt;
 
 errout_tmplt:
@@ -2008,7 +2135,7 @@ static int fl_dump_key(struct sk_buff *skb, struct net *net,
 }
 
 static int fl_dump(struct net *net, struct tcf_proto *tp, void *fh,
-                  struct sk_buff *skb, struct tcmsg *t)
+                  struct sk_buff *skb, struct tcmsg *t, bool rtnl_held)
 {
        struct cls_fl_filter *f = fh;
        struct nlattr *nest;
@@ -2096,6 +2223,7 @@ static struct tcf_proto_ops cls_fl_ops __read_mostly = {
        .init           = fl_init,
        .destroy        = fl_destroy,
        .get            = fl_get,
+       .put            = fl_put,
        .change         = fl_change,
        .delete         = fl_delete,
        .walk           = fl_walk,