diff options
-rw-r--r-- | include/net/pkt_cls.h | 2 | ||||
-rw-r--r-- | net/sched/cls_api.c | 70 | ||||
-rw-r--r-- | net/sched/sch_api.c | 4 |
3 files changed, 64 insertions, 12 deletions
diff --git a/include/net/pkt_cls.h b/include/net/pkt_cls.h index 38bee7dd21d1..e5dafa5ee1b2 100644 --- a/include/net/pkt_cls.h +++ b/include/net/pkt_cls.h @@ -46,6 +46,8 @@ struct tcf_chain *tcf_chain_get_by_act(struct tcf_block *block, void tcf_chain_put_by_act(struct tcf_chain *chain); struct tcf_chain *tcf_get_next_chain(struct tcf_block *block, struct tcf_chain *chain); +struct tcf_proto *tcf_get_next_proto(struct tcf_chain *chain, + struct tcf_proto *tp); void tcf_block_netif_keep_dst(struct tcf_block *block); int tcf_block_get(struct tcf_block **p_block, struct tcf_proto __rcu **p_filter_chain, struct Qdisc *q, diff --git a/net/sched/cls_api.c b/net/sched/cls_api.c index 37c05b96898f..dca8a3bee9c2 100644 --- a/net/sched/cls_api.c +++ b/net/sched/cls_api.c @@ -980,6 +980,45 @@ tcf_get_next_chain(struct tcf_block *block, struct tcf_chain *chain) } EXPORT_SYMBOL(tcf_get_next_chain); +static struct tcf_proto * +__tcf_get_next_proto(struct tcf_chain *chain, struct tcf_proto *tp) +{ + ASSERT_RTNL(); + mutex_lock(&chain->filter_chain_lock); + + if (!tp) + tp = tcf_chain_dereference(chain->filter_chain, chain); + else + tp = tcf_chain_dereference(tp->next, chain); + + if (tp) + tcf_proto_get(tp); + + mutex_unlock(&chain->filter_chain_lock); + + return tp; +} + +/* Function to be used by all clients that want to iterate over all tp's on + * chain. Users of this function must be tolerant to concurrent tp + * insertion/deletion or ensure that no concurrent chain modification is + * possible. Note that all netlink dump callbacks cannot guarantee to provide + * consistent dump because rtnl lock is released each time skb is filled with + * data and sent to user-space. + */ + +struct tcf_proto * +tcf_get_next_proto(struct tcf_chain *chain, struct tcf_proto *tp) +{ + struct tcf_proto *tp_next = __tcf_get_next_proto(chain, tp); + + if (tp) + tcf_proto_put(tp, NULL); + + return tp_next; +} +EXPORT_SYMBOL(tcf_get_next_proto); + static void tcf_block_flush_all_chains(struct tcf_block *block) { struct tcf_chain *chain; @@ -1352,7 +1391,7 @@ tcf_block_playback_offloads(struct tcf_block *block, tc_setup_cb_t *cb, struct netlink_ext_ack *extack) { struct tcf_chain *chain, *chain_prev; - struct tcf_proto *tp; + struct tcf_proto *tp, *tp_prev; int err; for (chain = __tcf_get_next_chain(block, NULL); @@ -1360,8 +1399,10 @@ tcf_block_playback_offloads(struct tcf_block *block, tc_setup_cb_t *cb, chain_prev = chain, chain = __tcf_get_next_chain(block, chain), tcf_chain_put(chain_prev)) { - for (tp = rtnl_dereference(chain->filter_chain); tp; - tp = rtnl_dereference(tp->next)) { + for (tp = __tcf_get_next_proto(chain, NULL); tp; + tp_prev = tp, + tp = __tcf_get_next_proto(chain, tp), + tcf_proto_put(tp_prev, NULL)) { if (tp->ops->reoffload) { err = tp->ops->reoffload(tp, add, cb, cb_priv, extack); @@ -1378,6 +1419,7 @@ tcf_block_playback_offloads(struct tcf_block *block, tc_setup_cb_t *cb, return 0; err_playback_remove: + tcf_proto_put(tp, NULL); tcf_chain_put(chain); tcf_block_playback_offloads(block, cb, cb_priv, false, offload_in_use, extack); @@ -1677,8 +1719,8 @@ static void tfilter_notify_chain(struct net *net, struct sk_buff *oskb, { struct tcf_proto *tp; - for (tp = rtnl_dereference(chain->filter_chain); - tp; tp = rtnl_dereference(tp->next)) + for (tp = tcf_get_next_proto(chain, NULL); + tp; tp = tcf_get_next_proto(chain, tp)) tfilter_notify(net, oskb, n, tp, block, q, parent, NULL, event, false); } @@ -2104,11 +2146,15 @@ static bool tcf_chain_dump(struct tcf_chain *chain, struct Qdisc *q, u32 parent, struct net *net = sock_net(skb->sk); struct tcf_block *block = chain->block; struct tcmsg *tcm = nlmsg_data(cb->nlh); + struct tcf_proto *tp, *tp_prev; struct tcf_dump_args arg; - struct tcf_proto *tp; - for (tp = rtnl_dereference(chain->filter_chain); - tp; tp = rtnl_dereference(tp->next), (*p_index)++) { + for (tp = __tcf_get_next_proto(chain, NULL); + tp; + tp_prev = tp, + tp = __tcf_get_next_proto(chain, tp), + tcf_proto_put(tp_prev, NULL), + (*p_index)++) { if (*p_index < index_start) continue; if (TC_H_MAJ(tcm->tcm_info) && @@ -2125,7 +2171,7 @@ static bool tcf_chain_dump(struct tcf_chain *chain, struct Qdisc *q, u32 parent, NETLINK_CB(cb->skb).portid, cb->nlh->nlmsg_seq, NLM_F_MULTI, RTM_NEWTFILTER) <= 0) - return false; + goto errout; cb->args[1] = 1; } @@ -2145,9 +2191,13 @@ static bool tcf_chain_dump(struct tcf_chain *chain, struct Qdisc *q, u32 parent, cb->args[2] = arg.w.cookie; cb->args[1] = arg.w.count + 1; if (arg.w.stop) - return false; + goto errout; } return true; + +errout: + tcf_proto_put(tp, NULL); + return false; } /* called with RTNL */ diff --git a/net/sched/sch_api.c b/net/sched/sch_api.c index 80058abc729f..9a530cad2759 100644 --- a/net/sched/sch_api.c +++ b/net/sched/sch_api.c @@ -1914,8 +1914,8 @@ static void tc_bind_tclass(struct Qdisc *q, u32 portid, u32 clid, chain = tcf_get_next_chain(block, chain)) { struct tcf_proto *tp; - for (tp = rtnl_dereference(chain->filter_chain); - tp; tp = rtnl_dereference(tp->next)) { + for (tp = tcf_get_next_proto(chain, NULL); + tp; tp = tcf_get_next_proto(chain, tp)) { struct tcf_bind_args arg = {}; arg.w.fn = tcf_node_bind; |