Message-Id: <1545259460-13376-12-git-send-email-jiong.wang@netronome.com>
Date: Wed, 19 Dec 2018 17:44:18 -0500
From: Jiong Wang <jiong.wang@...ronome.com>
To: ast@...nel.org, daniel@...earbox.net
Cc: netdev@...r.kernel.org, oss-drivers@...ronome.com,
Jiong Wang <jiong.wang@...ronome.com>
Subject: [PATCH bpf-next 11/13] bpf: verifier support JMP32
The verifier performs some optimizations based on the extra information a
conditional jump instruction can offer, in particular when the comparison is
between a constant and a register, in which case the value range of the
register can be narrowed.

is_branch_taken, reg_set_min_max and reg_set_min_max_inv are the three
functions that need updating.

There are some other conditional-jump related optimizations, but they deal
with pointer type comparisons, for which JMP32 won't be generated.
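
To illustrate the kind of range refinement involved, here is a stand-alone
sketch (not part of the patch; struct urange and jgt_const are made-up names
for this message) of roughly what a BPF_JGT comparison against a constant
does to the unsigned bounds, with the fall-through upper bound left alone for
JMP32 since only the low 32 bits were compared:

    #include <stdint.h>
    #include <stdio.h>

    /* Toy model of the unsigned bounds the verifier keeps per register. */
    struct urange {
        uint64_t umin;
        uint64_t umax;
    };

    /* Model of "if reg > val" (BPF_JGT) against a constant val: the taken
     * branch learns reg >= val + 1; for JMP32 the fall-through upper bound
     * is left untouched because only the low 32 bits were compared.
     */
    static void jgt_const(struct urange *taken, struct urange *fallthru,
                          uint64_t val, int is_jmp32)
    {
        if (!is_jmp32 && fallthru->umax > val)
            fallthru->umax = val;           /* min(umax, val) */
        if (taken->umin < val + 1)
            taken->umin = val + 1;          /* max(umin, val + 1) */
    }

    int main(void)
    {
        struct urange taken = { 0, UINT64_MAX };
        struct urange fallthru = taken;

        jgt_const(&taken, &fallthru, 100, 1);   /* 32-bit "reg > 100" */
        printf("taken branch:  umin = %llu\n",
               (unsigned long long)taken.umin);
        printf("fall-through:  umax = %llu\n",
               (unsigned long long)fallthru.umax);
        return 0;
    }

With is_jmp32 set, only the taken branch learns reg >= 101; the 64-bit upper
bound of the fall-through branch is not tightened.
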
Signed-off-by: Jiong Wang <jiong.wang@...ronome.com>
---
kernel/bpf/verifier.c | 178 ++++++++++++++++++++++++++++++++++++++------------
1 file changed, 137 insertions(+), 41 deletions(-)
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index e0e77ff..3123c91 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -3919,7 +3919,7 @@ static int is_branch_taken(struct bpf_reg_state *reg, u64 val, u8 opcode)
*/
static void reg_set_min_max(struct bpf_reg_state *true_reg,
struct bpf_reg_state *false_reg, u64 val,
- u8 opcode)
+ u8 opcode, bool is_jmp32)
{
/* If the dst_reg is a pointer, we can't learn anything about its
* variable offset from the compare (unless src_reg were a pointer into
@@ -3935,45 +3935,69 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg,
/* If this is false then we know nothing Jon Snow, but if it is
* true then we know for sure.
*/
- __mark_reg_known(true_reg, val);
+ if (is_jmp32)
+ true_reg->var_off = tnum_or(true_reg->var_off,
+ tnum_const(val));
+ else
+ __mark_reg_known(true_reg, val);
break;
case BPF_JNE:
/* If this is true we know nothing Jon Snow, but if it is false
- * we know the value for sure;
+ * we know the value for sure.
*/
- __mark_reg_known(false_reg, val);
+ if (is_jmp32)
+ false_reg->var_off = tnum_or(false_reg->var_off,
+ tnum_const(val));
+ else
+ __mark_reg_known(false_reg, val);
break;
case BPF_JGT:
- false_reg->umax_value = min(false_reg->umax_value, val);
+ if (!is_jmp32)
+ false_reg->umax_value = min(false_reg->umax_value, val);
true_reg->umin_value = max(true_reg->umin_value, val + 1);
break;
case BPF_JSGT:
- false_reg->smax_value = min_t(s64, false_reg->smax_value, val);
+ if (!is_jmp32)
+ false_reg->smax_value = min_t(s64,
+ false_reg->smax_value,
+ val);
true_reg->smin_value = max_t(s64, true_reg->smin_value, val + 1);
break;
case BPF_JLT:
false_reg->umin_value = max(false_reg->umin_value, val);
- true_reg->umax_value = min(true_reg->umax_value, val - 1);
+ if (!is_jmp32)
+ true_reg->umax_value = min(true_reg->umax_value,
+ val - 1);
break;
case BPF_JSLT:
false_reg->smin_value = max_t(s64, false_reg->smin_value, val);
- true_reg->smax_value = min_t(s64, true_reg->smax_value, val - 1);
+ if (!is_jmp32)
+ true_reg->smax_value = min_t(s64, true_reg->smax_value,
+ val - 1);
break;
case BPF_JGE:
- false_reg->umax_value = min(false_reg->umax_value, val - 1);
+ if (!is_jmp32)
+ false_reg->umax_value = min(false_reg->umax_value,
+ val - 1);
true_reg->umin_value = max(true_reg->umin_value, val);
break;
case BPF_JSGE:
- false_reg->smax_value = min_t(s64, false_reg->smax_value, val - 1);
+ if (!is_jmp32)
+ false_reg->smax_value = min_t(s64,
+ false_reg->smax_value,
+ val - 1);
true_reg->smin_value = max_t(s64, true_reg->smin_value, val);
break;
case BPF_JLE:
false_reg->umin_value = max(false_reg->umin_value, val + 1);
- true_reg->umax_value = min(true_reg->umax_value, val);
+ if (!is_jmp32)
+ true_reg->umax_value = min(true_reg->umax_value, val);
break;
case BPF_JSLE:
false_reg->smin_value = max_t(s64, false_reg->smin_value, val + 1);
- true_reg->smax_value = min_t(s64, true_reg->smax_value, val);
+ if (!is_jmp32)
+ true_reg->smax_value = min_t(s64, true_reg->smax_value,
+ val);
break;
default:
break;
@@ -3997,7 +4021,7 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg,
*/
static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
struct bpf_reg_state *false_reg, u64 val,
- u8 opcode)
+ u8 opcode, bool is_jmp32)
{
if (__is_pointer_value(false, false_reg))
return;
@@ -4007,45 +4031,69 @@ static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
/* If this is false then we know nothing Jon Snow, but if it is
* true then we know for sure.
*/
- __mark_reg_known(true_reg, val);
+ if (is_jmp32)
+ true_reg->var_off = tnum_or(true_reg->var_off,
+ tnum_const(val));
+ else
+ __mark_reg_known(true_reg, val);
break;
case BPF_JNE:
/* If this is true we know nothing Jon Snow, but if it is false
- * we know the value for sure;
+ * we know the value for sure.
*/
- __mark_reg_known(false_reg, val);
+ if (is_jmp32)
+ false_reg->var_off = tnum_or(false_reg->var_off,
+ tnum_const(val));
+ else
+ __mark_reg_known(false_reg, val);
break;
case BPF_JGT:
- true_reg->umax_value = min(true_reg->umax_value, val - 1);
+ if (!is_jmp32)
+ true_reg->umax_value = min(true_reg->umax_value,
+ val - 1);
false_reg->umin_value = max(false_reg->umin_value, val);
break;
case BPF_JSGT:
- true_reg->smax_value = min_t(s64, true_reg->smax_value, val - 1);
+ if (!is_jmp32)
+ true_reg->smax_value = min_t(s64, true_reg->smax_value,
+ val - 1);
false_reg->smin_value = max_t(s64, false_reg->smin_value, val);
break;
case BPF_JLT:
true_reg->umin_value = max(true_reg->umin_value, val + 1);
- false_reg->umax_value = min(false_reg->umax_value, val);
+ if (!is_jmp32)
+ false_reg->umax_value = min(false_reg->umax_value, val);
break;
case BPF_JSLT:
true_reg->smin_value = max_t(s64, true_reg->smin_value, val + 1);
- false_reg->smax_value = min_t(s64, false_reg->smax_value, val);
+ if (!is_jmp32)
+ false_reg->smax_value = min_t(s64,
+ false_reg->smax_value,
+ val);
break;
case BPF_JGE:
- true_reg->umax_value = min(true_reg->umax_value, val);
+ if (!is_jmp32)
+ true_reg->umax_value = min(true_reg->umax_value, val);
false_reg->umin_value = max(false_reg->umin_value, val + 1);
break;
case BPF_JSGE:
- true_reg->smax_value = min_t(s64, true_reg->smax_value, val);
+ if (!is_jmp32)
+ true_reg->smax_value = min_t(s64, true_reg->smax_value,
+ val);
false_reg->smin_value = max_t(s64, false_reg->smin_value, val + 1);
break;
case BPF_JLE:
true_reg->umin_value = max(true_reg->umin_value, val);
- false_reg->umax_value = min(false_reg->umax_value, val - 1);
+ if (!is_jmp32)
+ false_reg->umax_value = min(false_reg->umax_value,
+ val - 1);
break;
case BPF_JSLE:
true_reg->smin_value = max_t(s64, true_reg->smin_value, val);
- false_reg->smax_value = min_t(s64, false_reg->smax_value, val - 1);
+ if (!is_jmp32)
+ false_reg->smax_value = min_t(s64,
+ false_reg->smax_value,
+ val - 1);
break;
default:
break;
@@ -4276,6 +4324,7 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
struct bpf_reg_state *regs = this_branch->frame[this_branch->curframe]->regs;
struct bpf_reg_state *dst_reg, *other_branch_regs;
u8 opcode = BPF_OP(insn->code);
+ bool is_jmp32;
int err;
if (opcode > BPF_JSLE) {
@@ -4284,11 +4333,13 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
}
if (BPF_SRC(insn->code) == BPF_X) {
- if (insn->imm != 0) {
+ if (insn->imm & ~BPF_JMP_SUBOP_MASK) {
verbose(env, "BPF_JMP uses reserved fields\n");
return -EINVAL;
}
+ is_jmp32 = insn->imm & BPF_JMP_SUBOP_32BIT;
+
/* check src1 operand */
err = check_reg_arg(env, insn->src_reg, SRC_OP);
if (err)
@@ -4300,10 +4351,12 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
return -EACCES;
}
} else {
- if (insn->src_reg != BPF_REG_0) {
+ if (insn->src_reg & ~BPF_JMP_SUBOP_MASK) {
verbose(env, "BPF_JMP uses reserved fields\n");
return -EINVAL;
}
+
+ is_jmp32 = insn->src_reg & BPF_JMP_SUBOP_32BIT;
}
/* check src2 operand */
@@ -4314,8 +4367,29 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
dst_reg = &regs[insn->dst_reg];
if (BPF_SRC(insn->code) == BPF_K) {
- int pred = is_branch_taken(dst_reg, insn->imm, opcode);
+ struct bpf_reg_state *reg = dst_reg;
+ struct bpf_reg_state reg_lo;
+ int pred;
+
+ if (is_jmp32) {
+ reg_lo = *dst_reg;
+ reg = &reg_lo;
+ coerce_reg_to_size(reg, 4);
+ /* If s32 min/max have the same sign bit, then the min/max
+ * relationship is still maintained after copying the value
+ * from the unsigned range inside coerce_reg_to_size.
+ */
+
+ if (reg->umax_value > (u32)S32_MAX &&
+ reg->umin_value <= (u32)S32_MAX) {
+ reg->smin_value = S32_MIN;
+ reg->smax_value = S32_MAX;
+ }
+ reg->smin_value = (s64)(s32)reg->smin_value;
+ reg->smax_value = (s64)(s32)reg->smax_value;
+ }
+ pred = is_branch_taken(reg, insn->imm, opcode);
if (pred == 1) {
/* only follow the goto, ignore fall-through */
*insn_idx += insn->off;
@@ -4341,30 +4415,51 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
* comparable.
*/
if (BPF_SRC(insn->code) == BPF_X) {
+ struct bpf_reg_state *src_reg = &regs[insn->src_reg];
+ struct bpf_reg_state lo_reg0 = *dst_reg;
+ struct bpf_reg_state lo_reg1 = *src_reg;
+ struct bpf_reg_state *src_lo, *dst_lo;
+
+ dst_lo = &lo_reg0;
+ src_lo = &lo_reg1;
+ coerce_reg_to_size(dst_lo, 4);
+ coerce_reg_to_size(src_lo, 4);
+
if (dst_reg->type == SCALAR_VALUE &&
- regs[insn->src_reg].type == SCALAR_VALUE) {
- if (tnum_is_const(regs[insn->src_reg].var_off))
+ src_reg->type == SCALAR_VALUE) {
+ if (tnum_is_const(src_reg->var_off) ||
+ (is_jmp32 && tnum_is_const(src_lo->var_off)))
reg_set_min_max(&other_branch_regs[insn->dst_reg],
- dst_reg, regs[insn->src_reg].var_off.value,
- opcode);
- else if (tnum_is_const(dst_reg->var_off))
+ dst_reg,
+ is_jmp32
+ ? src_lo->var_off.value
+ : src_reg->var_off.value,
+ opcode, is_jmp32);
+ else if (tnum_is_const(dst_reg->var_off) ||
+ (is_jmp32 && tnum_is_const(dst_lo->var_off)))
reg_set_min_max_inv(&other_branch_regs[insn->src_reg],
- &regs[insn->src_reg],
- dst_reg->var_off.value, opcode);
- else if (opcode == BPF_JEQ || opcode == BPF_JNE)
+ src_reg,
+ is_jmp32
+ ? dst_lo->var_off.value
+ : dst_reg->var_off.value,
+ opcode, is_jmp32);
+ else if (!is_jmp32 &&
+ (opcode == BPF_JEQ || opcode == BPF_JNE))
/* Comparing for equality, we can combine knowledge */
reg_combine_min_max(&other_branch_regs[insn->src_reg],
&other_branch_regs[insn->dst_reg],
- &regs[insn->src_reg],
- &regs[insn->dst_reg], opcode);
+ src_reg, dst_reg, opcode);
}
} else if (dst_reg->type == SCALAR_VALUE) {
reg_set_min_max(&other_branch_regs[insn->dst_reg],
- dst_reg, insn->imm, opcode);
+ dst_reg, insn->imm, opcode, is_jmp32);
}
- /* detect if R == 0 where R is returned from bpf_map_lookup_elem() */
- if (BPF_SRC(insn->code) == BPF_K &&
+ /* detect if R == 0 where R is returned from bpf_map_lookup_elem().
+ * NOTE: the optimizations below are related to pointer comparisons,
+ * for which JMP32 will never be generated.
+ */
+ if (!is_jmp32 && BPF_SRC(insn->code) == BPF_K &&
insn->imm == 0 && (opcode == BPF_JEQ || opcode == BPF_JNE) &&
reg_type_may_be_null(dst_reg->type)) {
/* Mark all identical registers in each branch as either
@@ -4374,7 +4469,8 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
opcode == BPF_JNE);
mark_ptr_or_null_regs(other_branch, insn->dst_reg,
opcode == BPF_JEQ);
- } else if (!try_match_pkt_pointers(insn, dst_reg, &regs[insn->src_reg],
+ } else if (!is_jmp32 &&
+ !try_match_pkt_pointers(insn, dst_reg, &regs[insn->src_reg],
this_branch, other_branch) &&
is_pointer_value(env, insn->dst_reg)) {
verbose(env, "R%d pointer comparison prohibited\n",
--
2.7.4