bpf: verifier support JMP32
authorJiong Wang <jiong.wang@netronome.com>
Sat, 26 Jan 2019 17:26:01 +0000 (12:26 -0500)
committerAlexei Starovoitov <ast@kernel.org>
Sat, 26 Jan 2019 21:33:01 +0000 (13:33 -0800)
This patch teach verifier about the new BPF_JMP32 instruction class.
Verifier need to treat it similar as the existing BPF_JMP class.
A BPF_JMP32 insn needs to go through all checks that have been done on
BPF_JMP.

Also, verifier is doing runtime optimizations based on the extra info
conditional jump instruction could offer, especially when the comparison is
between constant and register that the value range of the register could be
improved based on the comparison results. These code are updated
accordingly.

Acked-by: Jakub Kicinski <jakub.kicinski@netronome.com>
Signed-off-by: Jiong Wang <jiong.wang@netronome.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
kernel/bpf/core.c
kernel/bpf/verifier.c

index 2a81b8af37482144df5b85f0998671e6e03b7c41..1e443ba97310ab6f5c86a69a49e76bf54316c1e7 100644 (file)
@@ -362,7 +362,8 @@ static int bpf_adj_branches(struct bpf_prog *prog, u32 pos, s32 end_old,
                        insn = prog->insnsi + end_old;
                }
                code = insn->code;
-               if (BPF_CLASS(code) != BPF_JMP ||
+               if ((BPF_CLASS(code) != BPF_JMP &&
+                    BPF_CLASS(code) != BPF_JMP32) ||
                    BPF_OP(code) == BPF_EXIT)
                        continue;
                /* Adjust offset of jmps if we cross patch boundaries. */
index eae6cb1fe6536ab047a31c13f7b13b02424d99c0..8c1c21cd50b4eb446392f1f4a9f479053fc7876c 100644 (file)
@@ -1095,7 +1095,7 @@ static int check_subprogs(struct bpf_verifier_env *env)
        for (i = 0; i < insn_cnt; i++) {
                u8 code = insn[i].code;
 
-               if (BPF_CLASS(code) != BPF_JMP)
+               if (BPF_CLASS(code) != BPF_JMP && BPF_CLASS(code) != BPF_JMP32)
                        goto next;
                if (BPF_OP(code) == BPF_EXIT || BPF_OP(code) == BPF_CALL)
                        goto next;
@@ -4031,14 +4031,49 @@ static void find_good_pkt_pointers(struct bpf_verifier_state *vstate,
  *  0 - branch will not be taken and fall-through to next insn
  * -1 - unknown. Example: "if (reg < 5)" is unknown when register value range [0,10]
  */
-static int is_branch_taken(struct bpf_reg_state *reg, u64 val, u8 opcode)
+static int is_branch_taken(struct bpf_reg_state *reg, u64 val, u8 opcode,
+                          bool is_jmp32)
 {
+       struct bpf_reg_state reg_lo;
        s64 sval;
 
        if (__is_pointer_value(false, reg))
                return -1;
 
-       sval = (s64)val;
+       if (is_jmp32) {
+               reg_lo = *reg;
+               reg = &reg_lo;
+               /* For JMP32, only low 32 bits are compared, coerce_reg_to_size
+                * could truncate high bits and update umin/umax according to
+                * information of low bits.
+                */
+               coerce_reg_to_size(reg, 4);
+               /* smin/smax need special handling. For example, after coerce,
+                * if smin_value is 0x00000000ffffffffLL, the value is -1 when
+                * used as operand to JMP32. It is a negative number from s32's
+                * point of view, while it is a positive number when seen as
+                * s64. The smin/smax are kept as s64, therefore, when used with
+                * JMP32, they need to be transformed into s32, then sign
+                * extended back to s64.
+                *
+                * Also, smin/smax were copied from umin/umax. If umin/umax has
+                * different sign bit, then min/max relationship doesn't
+                * maintain after casting into s32, for this case, set smin/smax
+                * to safest range.
+                */
+               if ((reg->umax_value ^ reg->umin_value) &
+                   (1ULL << 31)) {
+                       reg->smin_value = S32_MIN;
+                       reg->smax_value = S32_MAX;
+               }
+               reg->smin_value = (s64)(s32)reg->smin_value;
+               reg->smax_value = (s64)(s32)reg->smax_value;
+
+               val = (u32)val;
+               sval = (s64)(s32)val;
+       } else {
+               sval = (s64)val;
+       }
 
        switch (opcode) {
        case BPF_JEQ:
@@ -4108,6 +4143,29 @@ static int is_branch_taken(struct bpf_reg_state *reg, u64 val, u8 opcode)
        return -1;
 }
 
+/* Generate min value of the high 32-bit from TNUM info. */
+static u64 gen_hi_min(struct tnum var)
+{
+       return var.value & ~0xffffffffULL;
+}
+
+/* Generate max value of the high 32-bit from TNUM info. */
+static u64 gen_hi_max(struct tnum var)
+{
+       return (var.value | var.mask) & ~0xffffffffULL;
+}
+
+/* Return true if VAL is compared with a s64 sign extended from s32, and they
+ * are with the same signedness.
+ */
+static bool cmp_val_with_extended_s64(s64 sval, struct bpf_reg_state *reg)
+{
+       return ((s32)sval >= 0 &&
+               reg->smin_value >= 0 && reg->smax_value <= S32_MAX) ||
+              ((s32)sval < 0 &&
+               reg->smax_value <= 0 && reg->smin_value >= S32_MIN);
+}
+
 /* Adjusts the register min/max values in the case that the dst_reg is the
  * variable register that we are working on, and src_reg is a constant or we're
  * simply doing a BPF_K check.
@@ -4115,7 +4173,7 @@ static int is_branch_taken(struct bpf_reg_state *reg, u64 val, u8 opcode)
  */
 static void reg_set_min_max(struct bpf_reg_state *true_reg,
                            struct bpf_reg_state *false_reg, u64 val,
-                           u8 opcode)
+                           u8 opcode, bool is_jmp32)
 {
        s64 sval;
 
@@ -4128,7 +4186,8 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg,
        if (__is_pointer_value(false, false_reg))
                return;
 
-       sval = (s64)val;
+       val = is_jmp32 ? (u32)val : val;
+       sval = is_jmp32 ? (s64)(s32)val : (s64)val;
 
        switch (opcode) {
        case BPF_JEQ:
@@ -4141,7 +4200,15 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg,
                 * if it is true we know the value for sure. Likewise for
                 * BPF_JNE.
                 */
-               __mark_reg_known(reg, val);
+               if (is_jmp32) {
+                       u64 old_v = reg->var_off.value;
+                       u64 hi_mask = ~0xffffffffULL;
+
+                       reg->var_off.value = (old_v & hi_mask) | val;
+                       reg->var_off.mask &= hi_mask;
+               } else {
+                       __mark_reg_known(reg, val);
+               }
                break;
        }
        case BPF_JSET:
@@ -4157,6 +4224,10 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg,
                u64 false_umax = opcode == BPF_JGT ? val    : val - 1;
                u64 true_umin = opcode == BPF_JGT ? val + 1 : val;
 
+               if (is_jmp32) {
+                       false_umax += gen_hi_max(false_reg->var_off);
+                       true_umin += gen_hi_min(true_reg->var_off);
+               }
                false_reg->umax_value = min(false_reg->umax_value, false_umax);
                true_reg->umin_value = max(true_reg->umin_value, true_umin);
                break;
@@ -4167,6 +4238,11 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg,
                s64 false_smax = opcode == BPF_JSGT ? sval    : sval - 1;
                s64 true_smin = opcode == BPF_JSGT ? sval + 1 : sval;
 
+               /* If the full s64 was not sign-extended from s32 then don't
+                * deduct further info.
+                */
+               if (is_jmp32 && !cmp_val_with_extended_s64(sval, false_reg))
+                       break;
                false_reg->smax_value = min(false_reg->smax_value, false_smax);
                true_reg->smin_value = max(true_reg->smin_value, true_smin);
                break;
@@ -4177,6 +4253,10 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg,
                u64 false_umin = opcode == BPF_JLT ? val    : val + 1;
                u64 true_umax = opcode == BPF_JLT ? val - 1 : val;
 
+               if (is_jmp32) {
+                       false_umin += gen_hi_min(false_reg->var_off);
+                       true_umax += gen_hi_max(true_reg->var_off);
+               }
                false_reg->umin_value = max(false_reg->umin_value, false_umin);
                true_reg->umax_value = min(true_reg->umax_value, true_umax);
                break;
@@ -4187,6 +4267,8 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg,
                s64 false_smin = opcode == BPF_JSLT ? sval    : sval + 1;
                s64 true_smax = opcode == BPF_JSLT ? sval - 1 : sval;
 
+               if (is_jmp32 && !cmp_val_with_extended_s64(sval, false_reg))
+                       break;
                false_reg->smin_value = max(false_reg->smin_value, false_smin);
                true_reg->smax_value = min(true_reg->smax_value, true_smax);
                break;
@@ -4213,14 +4295,15 @@ static void reg_set_min_max(struct bpf_reg_state *true_reg,
  */
 static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
                                struct bpf_reg_state *false_reg, u64 val,
-                               u8 opcode)
+                               u8 opcode, bool is_jmp32)
 {
        s64 sval;
 
        if (__is_pointer_value(false, false_reg))
                return;
 
-       sval = (s64)val;
+       val = is_jmp32 ? (u32)val : val;
+       sval = is_jmp32 ? (s64)(s32)val : (s64)val;
 
        switch (opcode) {
        case BPF_JEQ:
@@ -4229,7 +4312,15 @@ static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
                struct bpf_reg_state *reg =
                        opcode == BPF_JEQ ? true_reg : false_reg;
 
-               __mark_reg_known(reg, val);
+               if (is_jmp32) {
+                       u64 old_v = reg->var_off.value;
+                       u64 hi_mask = ~0xffffffffULL;
+
+                       reg->var_off.value = (old_v & hi_mask) | val;
+                       reg->var_off.mask &= hi_mask;
+               } else {
+                       __mark_reg_known(reg, val);
+               }
                break;
        }
        case BPF_JSET:
@@ -4245,6 +4336,10 @@ static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
                u64 false_umin = opcode == BPF_JGT ? val    : val + 1;
                u64 true_umax = opcode == BPF_JGT ? val - 1 : val;
 
+               if (is_jmp32) {
+                       false_umin += gen_hi_min(false_reg->var_off);
+                       true_umax += gen_hi_max(true_reg->var_off);
+               }
                false_reg->umin_value = max(false_reg->umin_value, false_umin);
                true_reg->umax_value = min(true_reg->umax_value, true_umax);
                break;
@@ -4255,6 +4350,8 @@ static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
                s64 false_smin = opcode == BPF_JSGT ? sval    : sval + 1;
                s64 true_smax = opcode == BPF_JSGT ? sval - 1 : sval;
 
+               if (is_jmp32 && !cmp_val_with_extended_s64(sval, false_reg))
+                       break;
                false_reg->smin_value = max(false_reg->smin_value, false_smin);
                true_reg->smax_value = min(true_reg->smax_value, true_smax);
                break;
@@ -4265,6 +4362,10 @@ static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
                u64 false_umax = opcode == BPF_JLT ? val    : val - 1;
                u64 true_umin = opcode == BPF_JLT ? val + 1 : val;
 
+               if (is_jmp32) {
+                       false_umax += gen_hi_max(false_reg->var_off);
+                       true_umin += gen_hi_min(true_reg->var_off);
+               }
                false_reg->umax_value = min(false_reg->umax_value, false_umax);
                true_reg->umin_value = max(true_reg->umin_value, true_umin);
                break;
@@ -4275,6 +4376,8 @@ static void reg_set_min_max_inv(struct bpf_reg_state *true_reg,
                s64 false_smax = opcode == BPF_JSLT ? sval    : sval - 1;
                s64 true_smin = opcode == BPF_JSLT ? sval + 1 : sval;
 
+               if (is_jmp32 && !cmp_val_with_extended_s64(sval, false_reg))
+                       break;
                false_reg->smax_value = min(false_reg->smax_value, false_smax);
                true_reg->smin_value = max(true_reg->smin_value, true_smin);
                break;
@@ -4416,6 +4519,10 @@ static bool try_match_pkt_pointers(const struct bpf_insn *insn,
        if (BPF_SRC(insn->code) != BPF_X)
                return false;
 
+       /* Pointers are always 64-bit. */
+       if (BPF_CLASS(insn->code) == BPF_JMP32)
+               return false;
+
        switch (BPF_OP(insn->code)) {
        case BPF_JGT:
                if ((dst_reg->type == PTR_TO_PACKET &&
@@ -4508,16 +4615,18 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
        struct bpf_reg_state *regs = this_branch->frame[this_branch->curframe]->regs;
        struct bpf_reg_state *dst_reg, *other_branch_regs;
        u8 opcode = BPF_OP(insn->code);
+       bool is_jmp32;
        int err;
 
-       if (opcode > BPF_JSLE) {
-               verbose(env, "invalid BPF_JMP opcode %x\n", opcode);
+       /* Only conditional jumps are expected to reach here. */
+       if (opcode == BPF_JA || opcode > BPF_JSLE) {
+               verbose(env, "invalid BPF_JMP/JMP32 opcode %x\n", opcode);
                return -EINVAL;
        }
 
        if (BPF_SRC(insn->code) == BPF_X) {
                if (insn->imm != 0) {
-                       verbose(env, "BPF_JMP uses reserved fields\n");
+                       verbose(env, "BPF_JMP/JMP32 uses reserved fields\n");
                        return -EINVAL;
                }
 
@@ -4533,7 +4642,7 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
                }
        } else {
                if (insn->src_reg != BPF_REG_0) {
-                       verbose(env, "BPF_JMP uses reserved fields\n");
+                       verbose(env, "BPF_JMP/JMP32 uses reserved fields\n");
                        return -EINVAL;
                }
        }
@@ -4544,9 +4653,11 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
                return err;
 
        dst_reg = &regs[insn->dst_reg];
+       is_jmp32 = BPF_CLASS(insn->code) == BPF_JMP32;
 
        if (BPF_SRC(insn->code) == BPF_K) {
-               int pred = is_branch_taken(dst_reg, insn->imm, opcode);
+               int pred = is_branch_taken(dst_reg, insn->imm, opcode,
+                                          is_jmp32);
 
                if (pred == 1) {
                         /* only follow the goto, ignore fall-through */
@@ -4574,30 +4685,51 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
         * comparable.
         */
        if (BPF_SRC(insn->code) == BPF_X) {
+               struct bpf_reg_state *src_reg = &regs[insn->src_reg];
+               struct bpf_reg_state lo_reg0 = *dst_reg;
+               struct bpf_reg_state lo_reg1 = *src_reg;
+               struct bpf_reg_state *src_lo, *dst_lo;
+
+               dst_lo = &lo_reg0;
+               src_lo = &lo_reg1;
+               coerce_reg_to_size(dst_lo, 4);
+               coerce_reg_to_size(src_lo, 4);
+
                if (dst_reg->type == SCALAR_VALUE &&
-                   regs[insn->src_reg].type == SCALAR_VALUE) {
-                       if (tnum_is_const(regs[insn->src_reg].var_off))
+                   src_reg->type == SCALAR_VALUE) {
+                       if (tnum_is_const(src_reg->var_off) ||
+                           (is_jmp32 && tnum_is_const(src_lo->var_off)))
                                reg_set_min_max(&other_branch_regs[insn->dst_reg],
-                                               dst_reg, regs[insn->src_reg].var_off.value,
-                                               opcode);
-                       else if (tnum_is_const(dst_reg->var_off))
+                                               dst_reg,
+                                               is_jmp32
+                                               ? src_lo->var_off.value
+                                               : src_reg->var_off.value,
+                                               opcode, is_jmp32);
+                       else if (tnum_is_const(dst_reg->var_off) ||
+                                (is_jmp32 && tnum_is_const(dst_lo->var_off)))
                                reg_set_min_max_inv(&other_branch_regs[insn->src_reg],
-                                                   &regs[insn->src_reg],
-                                                   dst_reg->var_off.value, opcode);
-                       else if (opcode == BPF_JEQ || opcode == BPF_JNE)
+                                                   src_reg,
+                                                   is_jmp32
+                                                   ? dst_lo->var_off.value
+                                                   : dst_reg->var_off.value,
+                                                   opcode, is_jmp32);
+                       else if (!is_jmp32 &&
+                                (opcode == BPF_JEQ || opcode == BPF_JNE))
                                /* Comparing for equality, we can combine knowledge */
                                reg_combine_min_max(&other_branch_regs[insn->src_reg],
                                                    &other_branch_regs[insn->dst_reg],
-                                                   &regs[insn->src_reg],
-                                                   &regs[insn->dst_reg], opcode);
+                                                   src_reg, dst_reg, opcode);
                }
        } else if (dst_reg->type == SCALAR_VALUE) {
                reg_set_min_max(&other_branch_regs[insn->dst_reg],
-                                       dst_reg, insn->imm, opcode);
+                                       dst_reg, insn->imm, opcode, is_jmp32);
        }
 
-       /* detect if R == 0 where R is returned from bpf_map_lookup_elem() */
-       if (BPF_SRC(insn->code) == BPF_K &&
+       /* detect if R == 0 where R is returned from bpf_map_lookup_elem().
+        * NOTE: these optimizations below are related with pointer comparison
+        *       which will never be JMP32.
+        */
+       if (!is_jmp32 && BPF_SRC(insn->code) == BPF_K &&
            insn->imm == 0 && (opcode == BPF_JEQ || opcode == BPF_JNE) &&
            reg_type_may_be_null(dst_reg->type)) {
                /* Mark all identical registers in each branch as either
@@ -4926,7 +5058,8 @@ peek_stack:
                goto check_state;
        t = insn_stack[cur_stack - 1];
 
-       if (BPF_CLASS(insns[t].code) == BPF_JMP) {
+       if (BPF_CLASS(insns[t].code) == BPF_JMP ||
+           BPF_CLASS(insns[t].code) == BPF_JMP32) {
                u8 opcode = BPF_OP(insns[t].code);
 
                if (opcode == BPF_EXIT) {
@@ -6082,7 +6215,7 @@ static int do_check(struct bpf_verifier_env *env)
                        if (err)
                                return err;
 
-               } else if (class == BPF_JMP) {
+               } else if (class == BPF_JMP || class == BPF_JMP32) {
                        u8 opcode = BPF_OP(insn->code);
 
                        if (opcode == BPF_CALL) {
@@ -6090,7 +6223,8 @@ static int do_check(struct bpf_verifier_env *env)
                                    insn->off != 0 ||
                                    (insn->src_reg != BPF_REG_0 &&
                                     insn->src_reg != BPF_PSEUDO_CALL) ||
-                                   insn->dst_reg != BPF_REG_0) {
+                                   insn->dst_reg != BPF_REG_0 ||
+                                   class == BPF_JMP32) {
                                        verbose(env, "BPF_CALL uses reserved fields\n");
                                        return -EINVAL;
                                }
@@ -6106,7 +6240,8 @@ static int do_check(struct bpf_verifier_env *env)
                                if (BPF_SRC(insn->code) != BPF_K ||
                                    insn->imm != 0 ||
                                    insn->src_reg != BPF_REG_0 ||
-                                   insn->dst_reg != BPF_REG_0) {
+                                   insn->dst_reg != BPF_REG_0 ||
+                                   class == BPF_JMP32) {
                                        verbose(env, "BPF_JA uses reserved fields\n");
                                        return -EINVAL;
                                }
@@ -6118,7 +6253,8 @@ static int do_check(struct bpf_verifier_env *env)
                                if (BPF_SRC(insn->code) != BPF_K ||
                                    insn->imm != 0 ||
                                    insn->src_reg != BPF_REG_0 ||
-                                   insn->dst_reg != BPF_REG_0) {
+                                   insn->dst_reg != BPF_REG_0 ||
+                                   class == BPF_JMP32) {
                                        verbose(env, "BPF_EXIT uses reserved fields\n");
                                        return -EINVAL;
                                }
@@ -6635,6 +6771,9 @@ static bool insn_is_cond_jump(u8 code)
 {
        u8 op;
 
+       if (BPF_CLASS(code) == BPF_JMP32)
+               return true;
+
        if (BPF_CLASS(code) != BPF_JMP)
                return false;