summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--include/linux/skmsg.h28
-rw-r--r--net/core/filter.c8
-rw-r--r--net/core/skmsg.c2
-rw-r--r--net/ipv4/tcp_bpf.c2
4 files changed, 20 insertions, 20 deletions
diff --git a/include/linux/skmsg.h b/include/linux/skmsg.h
index 6cb077b646a5..ef7031f8a304 100644
--- a/include/linux/skmsg.h
+++ b/include/linux/skmsg.h
@@ -14,6 +14,7 @@
#include <net/strparser.h>
#define MAX_MSG_FRAGS MAX_SKB_FRAGS
+#define NR_MSG_FRAG_IDS (MAX_MSG_FRAGS + 1)
enum __sk_action {
__SK_DROP = 0,
@@ -29,13 +30,15 @@ struct sk_msg_sg {
u32 size;
u32 copybreak;
unsigned long copy;
- /* The extra element is used for chaining the front and sections when
- * the list becomes partitioned (e.g. end < start). The crypto APIs
- * require the chaining.
+ /* The extra two elements:
+ * 1) used for chaining the front and sections when the list becomes
+ * partitioned (e.g. end < start). The crypto APIs require the
+ * chaining;
+ * 2) to chain tailer SG entries after the message.
*/
- struct scatterlist data[MAX_MSG_FRAGS + 1];
+ struct scatterlist data[MAX_MSG_FRAGS + 2];
};
-static_assert(BITS_PER_LONG >= MAX_MSG_FRAGS);
+static_assert(BITS_PER_LONG >= NR_MSG_FRAG_IDS);
/* UAPI in filter.c depends on struct sk_msg_sg being first element. */
struct sk_msg {
@@ -142,13 +145,13 @@ static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes)
static inline u32 sk_msg_iter_dist(u32 start, u32 end)
{
- return end >= start ? end - start : end + (MAX_MSG_FRAGS - start);
+ return end >= start ? end - start : end + (NR_MSG_FRAG_IDS - start);
}
#define sk_msg_iter_var_prev(var) \
do { \
if (var == 0) \
- var = MAX_MSG_FRAGS - 1; \
+ var = NR_MSG_FRAG_IDS - 1; \
else \
var--; \
} while (0)
@@ -156,7 +159,7 @@ static inline u32 sk_msg_iter_dist(u32 start, u32 end)
#define sk_msg_iter_var_next(var) \
do { \
var++; \
- if (var == MAX_MSG_FRAGS) \
+ if (var == NR_MSG_FRAG_IDS) \
var = 0; \
} while (0)
@@ -173,9 +176,9 @@ static inline void sk_msg_clear_meta(struct sk_msg *msg)
static inline void sk_msg_init(struct sk_msg *msg)
{
- BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != MAX_MSG_FRAGS);
+ BUILD_BUG_ON(ARRAY_SIZE(msg->sg.data) - 1 != NR_MSG_FRAG_IDS);
memset(msg, 0, sizeof(*msg));
- sg_init_marker(msg->sg.data, MAX_MSG_FRAGS);
+ sg_init_marker(msg->sg.data, NR_MSG_FRAG_IDS);
}
static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src,
@@ -196,14 +199,11 @@ static inline void sk_msg_xfer_full(struct sk_msg *dst, struct sk_msg *src)
static inline bool sk_msg_full(const struct sk_msg *msg)
{
- return (msg->sg.end == msg->sg.start) && msg->sg.size;
+ return sk_msg_iter_dist(msg->sg.start, msg->sg.end) == MAX_MSG_FRAGS;
}
static inline u32 sk_msg_elem_used(const struct sk_msg *msg)
{
- if (sk_msg_full(msg))
- return MAX_MSG_FRAGS;
-
return sk_msg_iter_dist(msg->sg.start, msg->sg.end);
}
diff --git a/net/core/filter.c b/net/core/filter.c
index b0ed048585ba..f1e703eed3d2 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -2299,7 +2299,7 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
WARN_ON_ONCE(last_sge == first_sge);
shift = last_sge > first_sge ?
last_sge - first_sge - 1 :
- MAX_SKB_FRAGS - first_sge + last_sge - 1;
+ NR_MSG_FRAG_IDS - first_sge + last_sge - 1;
if (!shift)
goto out;
@@ -2308,8 +2308,8 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
do {
u32 move_from;
- if (i + shift >= MAX_MSG_FRAGS)
- move_from = i + shift - MAX_MSG_FRAGS;
+ if (i + shift >= NR_MSG_FRAG_IDS)
+ move_from = i + shift - NR_MSG_FRAG_IDS;
else
move_from = i + shift;
if (move_from == msg->sg.end)
@@ -2323,7 +2323,7 @@ BPF_CALL_4(bpf_msg_pull_data, struct sk_msg *, msg, u32, start,
} while (1);
msg->sg.end = msg->sg.end - shift > msg->sg.end ?
- msg->sg.end - shift + MAX_MSG_FRAGS :
+ msg->sg.end - shift + NR_MSG_FRAG_IDS :
msg->sg.end - shift;
out:
msg->data = sg_virt(&msg->sg.data[first_sge]) + start - offset;
diff --git a/net/core/skmsg.c b/net/core/skmsg.c
index a469d2124f3f..ded2d5227678 100644
--- a/net/core/skmsg.c
+++ b/net/core/skmsg.c
@@ -421,7 +421,7 @@ static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
copied = skb->len;
msg->sg.start = 0;
msg->sg.size = copied;
- msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge;
+ msg->sg.end = num_sge;
msg->skb = skb;
sk_psock_queue_msg(psock, msg);
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
index 8a56e09cfb0e..e38705165ac9 100644
--- a/net/ipv4/tcp_bpf.c
+++ b/net/ipv4/tcp_bpf.c
@@ -301,7 +301,7 @@ EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
struct sk_msg *msg, int *copied, int flags)
{
- bool cork = false, enospc = msg->sg.start == msg->sg.end;
+ bool cork = false, enospc = sk_msg_full(msg);
struct sock *sk_redir;
u32 tosend, delta = 0;
int ret;