#include <stdbool.h>
#include <stdio.h>
#include <string.h>

/*
 * User-space harness for sk_msg_trim(): a minimal copy of the sk_msg
 * scatterlist ring, a few stubs standing in for kernel helpers, and a table
 * of trim scenarios whose resulting ring state is compared (byte for byte)
 * against an expected state.  Exits nonzero if any scenario fails.
 */

typedef unsigned int u32;
struct sock;

#define MAX_MSG_FRAGS 5

#define ARRAY_SIZE(a) (sizeof(a) / sizeof((a)[0]))

/* Report the caller's location on stderr, but only when @cond is true. */
#define WARN_ON(cond)							\
	do {								\
		if (cond)						\
			fprintf(stderr, "WARNING %s:%d\n", __func__, __LINE__); \
	} while (0)

/* Step a ring index backwards, wrapping from 0 to MAX_MSG_FRAGS - 1. */
#define sk_msg_iter_var_prev(var)					\
	do {								\
		if ((var) == 0)						\
			(var) = MAX_MSG_FRAGS - 1;			\
		else							\
			(var)--;					\
	} while (0)

/* Step a ring index forwards, wrapping from MAX_MSG_FRAGS - 1 to 0. */
#define sk_msg_iter_var_next(var)					\
	do {								\
		(var)++;						\
		if ((var) == MAX_MSG_FRAGS)				\
			(var) = 0;					\
	} while (0)

struct scatterlist {
	u32 length;
};

struct sk_msg_sg {
	u32 start;	/* first element holding data */
	u32 curr;	/* element holding the current copy position */
	u32 end;	/* one past the last element holding data (ring order) */
	u32 size;	/* total bytes across all elements */
	u32 copybreak;	/* copy position inside data[curr]; trim clamps it */
	/* The extra element is used for chaining the front and back sections
	 * when the list becomes partitioned (e.g. end < start). The crypto
	 * APIs require the chaining.
	 */
	struct scatterlist data[MAX_MSG_FRAGS + 1];
};

struct sk_msg {
	struct sk_msg_sg sg;
};

/* Stub: releasing an element just zeroes its length. */
static void sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, int i,
			     bool charge)
{
	(void)sk;
	(void)charge;
	msg->sg.data[i].length = 0;
}

/* Stub: there is no socket memory accounting in user space. */
static inline void sk_mem_uncharge(struct sock *sk, int size)
{
	(void)sk;
	(void)size;
}

/* Number of forward steps from @start to @end around the ring. */
static inline u32 sk_msg_iter_dist(u32 start, u32 end)
{
	return end >= start ? end - start : end + (MAX_MSG_FRAGS - start);
}

/*
 * sk_msg_trim - shrink @msg so that it holds exactly @len bytes.
 * @sk:  socket, only forwarded to the stubs above
 * @msg: message whose scatterlist ring is trimmed from its tail
 * @len: new total size; a value above msg->sg.size triggers WARN_ON and
 *       leaves @msg untouched, len == size is a silent no-op
 *
 * Whole trailing elements are released, the last surviving element is
 * shortened by the remainder, and sg.end / sg.curr / sg.copybreak are
 * updated to match.
 */
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
{
	int trim = msg->sg.size - len;	/* bytes to remove from the tail */
	u32 i = msg->sg.end;

	if (trim <= 0) {
		WARN_ON(trim < 0);
		return;
	}

	sk_msg_iter_var_prev(i);
	msg->sg.size = len;
	/* Release whole tail elements while each fits inside @trim. */
	while (msg->sg.data[i].length &&
	       trim >= msg->sg.data[i].length) {
		trim -= msg->sg.data[i].length;
		sk_msg_free_elem(sk, msg, i, true);
		sk_msg_iter_var_prev(i);
		if (!trim)
			goto out;
	}

	/* The remainder comes off the last surviving element. */
	msg->sg.data[i].length -= trim;
	sk_mem_uncharge(sk, trim);
	/* Adjust copybreak if it falls into the trimmed part of last buf */
	if (msg->sg.curr == i && msg->sg.copybreak > msg->sg.data[i].length)
		msg->sg.copybreak = msg->sg.data[i].length;
out:
	/* i is the last surviving element (the one before start if all
	 * elements were released), so the new end is one past it.
	 */
	sk_msg_iter_var_next(i);
	msg->sg.end = i;

	/* If curr now lies in the trimmed region, move it back to the tail
	 * of the last surviving element so any future copy operation starts
	 * at the new copy location. Trimmed data that has not yet been used
	 * in a copy op does not require an update.
	 *
	 * dist(start, curr) > dist(end, curr) holds exactly when curr is at
	 * or beyond the new end, except for a still-full ring (end == start):
	 * there only the last element was shortened, and copybreak was
	 * already clamped above if curr points at it.
	 */
	if (!msg->sg.size) {
		msg->sg.curr = msg->sg.start;
		msg->sg.copybreak = 0;
	} else if (sk_msg_iter_dist(msg->sg.start, msg->sg.curr) >
		   sk_msg_iter_dist(msg->sg.end, msg->sg.curr)) {
		sk_msg_iter_var_prev(i);
		msg->sg.curr = i;
		msg->sg.copybreak = msg->sg.data[i].length;
	}
}

/* Currently unused helper for passing parenthesised argument lists. */
#define NOPAREN(...) __VA_ARGS__

/* Print the ring bookkeeping and every element length of @msg to stderr. */
static void dump_msg(const char *str, struct sk_msg *msg, const char *end)
{
	int i;

	fprintf(stderr, "%s start:%u curr:%u end:%u cb:%3u size:%4u ",
		str, msg->sg.start, msg->sg.curr, msg->sg.end,
		msg->sg.copybreak, msg->sg.size);
	for (i = 0; i < MAX_MSG_FRAGS; i++)
		fprintf(stderr, " %3u", msg->sg.data[i].length);
	fprintf(stderr, "%s", end);
}

static unsigned int test_ID;
static unsigned int test_failures;

/*
 * test_one - trim one ring layout and compare it with the expected result.
 * Arguments, in order:
 *   _as, _ac, _acb  start, curr and copybreak before the trim
 *   _len            length passed to sk_msg_trim()
 *   _bc, _bcb, _be  curr, copybreak and end expected afterwards
 *   ...             element lengths, placed consecutively from start
 * The expected element lengths and size are derived from the inputs by
 * keeping the first _len bytes.
 */
#define test_one(_as, _ac, _acb, _len, _bc, _bcb, _be, ...)		\
	do {								\
		struct sk_msg msg = {					\
			.sg = {						\
				.start = (_as),				\
				.curr = (_ac),				\
				.end = 0,				\
				.copybreak = (_acb),			\
			},						\
		};							\
		int in_lens[] = { __VA_ARGS__ };			\
		struct sk_msg omsg = {					\
			.sg = {						\
				.start = (_as),				\
				.curr = (_bc),				\
				.end = (_be),				\
				.copybreak = (_bcb),			\
			},						\
		};							\
		int i;							\
									\
		for (i = 0; i < (int)ARRAY_SIZE(in_lens); i++) {	\
			int var = (i + msg.sg.start) % MAX_MSG_FRAGS;	\
									\
			msg.sg.data[var].length = in_lens[i];		\
			msg.sg.size += in_lens[i];			\
		}							\
		msg.sg.end = (msg.sg.start + i) % MAX_MSG_FRAGS;	\
									\
		for (i = 0; i < (int)ARRAY_SIZE(in_lens); i++) {	\
			int var = (i + msg.sg.start) % MAX_MSG_FRAGS;	\
			int len = in_lens[i];				\
									\
			if ((_len) < omsg.sg.size + len)		\
				len = (_len) - omsg.sg.size;		\
			omsg.sg.data[var].length = len;			\
			omsg.sg.size += len;				\
		}							\
									\
		fprintf(stderr, "test #%2u ", test_ID);			\
		dump_msg("", &msg, "");					\
		sk_msg_trim(NULL, &msg, (_len));			\
									\
		if (memcmp(&msg, &omsg, sizeof(msg))) {			\
			fprintf(stderr, "\tfailed\n");			\
			dump_msg(" result", &msg, "\n");		\
			dump_msg(" expect", &omsg, "\n");		\
			test_failures++;				\
		} else {						\
			fprintf(stderr, "\tOKAY\n");			\
		}							\
		test_ID++;						\
	} while (0)

/* Run every scenario; exit status is nonzero if any of them failed. */
int main(void)
{
	test_one(/* start */ 0, /* curr */ 0, /* copybreak */ 200,
		 /* trim */ 100,
		 /* curr */ 0, /* copybreak */ 100, /* end */ 1,
		 /* data */ 200);
	test_one(/* start */ 1, /* curr */ 1, /* copybreak */ 200,
		 /* trim */ 100,
		 /* curr */ 1, /* copybreak */ 100, /* end */ 2,
		 /* data */ 200);
	test_one(/* start */ 1, /* curr */ 1, /* copybreak */ 200,
		 /* trim */ 300,
		 /* curr */ 1, /* copybreak */ 200, /* end */ 3,
		 /* data */ 200, 200);
	test_one(/* start */ 1, /* curr */ 2, /* copybreak */ 200,
		 /* trim */ 200,
		 /* curr */ 1, /* copybreak */ 200, /* end */ 2,
		 /* data */ 200, 200, 200);
	test_one(/* start */ 1, /* curr */ 3, /* copybreak */ 200,
		 /* trim */ 200,
		 /* curr */ 1, /* copybreak */ 200, /* end */ 2,
		 /* data */ 200, 200, 200);
	test_one(/* start */ 1, /* curr */ 2, /* copybreak */ 200,
		 /* trim */ 0,
		 /* curr */ 1, /* copybreak */ 0, /* end */ 1,
		 /* data */ 200, 200, 200);
	test_one(/* start */ 1, /* curr */ 1, /* copybreak */ 200,
		 /* trim */ 0,
		 /* curr */ 1, /* copybreak */ 0, /* end */ 1,
		 /* data */ 200, 200, 200, 200, 200);
	test_one(/* start */ 1, /* curr */ 3, /* copybreak */ 100,
		 /* trim */ 0,
		 /* curr */ 1, /* copybreak */ 0, /* end */ 1,
		 /* data */ 200, 200, 200, 200, 200);
	test_one(/* start */ 1, /* curr */ 0, /* copybreak */ 200,
		 /* trim */ 0,
		 /* curr */ 1, /* copybreak */ 0, /* end */ 1,
		 /* data */ 200, 200, 200, 200, 200);
	test_one(/* start */ 1, /* curr */ 0, /* copybreak */ 200,
		 /* trim */ 900,
		 /* curr */ 0, /* copybreak */ 100, /* end */ 1,
		 /* data */ 200, 200, 200, 200, 200);
	/* NOTE(review): identical to the previous case; presumably meant to
	 * cover a different layout — confirm.
	 */
	test_one(/* start */ 1, /* curr */ 0, /* copybreak */ 200,
		 /* trim */ 900,
		 /* curr */ 0, /* copybreak */ 100, /* end */ 1,
		 /* data */ 200, 200, 200, 200, 200);
	test_one(/* start */ 0, /* curr */ 1, /* copybreak */ 0,
		 /* trim */ 100,
		 /* curr */ 0, /* copybreak */ 100, /* end */ 1,
		 /* data */ 200);
	test_one(/* start */ 1, /* curr */ 2, /* copybreak */ 0,
		 /* trim */ 100,
		 /* curr */ 1, /* copybreak */ 100, /* end */ 2,
		 /* data */ 200);

	return test_failures ? 1 : 0;
}