bpf: fix range propagation on direct packet access
[cascardo/linux.git] / kernel / bpf / verifier.c
index abb61f3..90493a6 100644 (file)
@@ -1637,21 +1637,42 @@ static int check_alu_op(struct verifier_env *env, struct bpf_insn *insn)
        return 0;
 }
 
-static void find_good_pkt_pointers(struct verifier_env *env,
-                                  struct reg_state *dst_reg)
+static void find_good_pkt_pointers(struct verifier_state *state,
+                                  const struct reg_state *dst_reg)
 {
-       struct verifier_state *state = &env->cur_state;
        struct reg_state *regs = state->regs, *reg;
        int i;
-       /* r2 = r3;
-        * r2 += 8
-        * if (r2 > pkt_end) goto somewhere
-        * r2 == dst_reg, pkt_end == src_reg,
-        * r2=pkt(id=n,off=8,r=0)
-        * r3=pkt(id=n,off=0,r=0)
-        * find register r3 and mark its range as r3=pkt(id=n,off=0,r=8)
-        * so that range of bytes [r3, r3 + 8) is safe to access
+
+       /* LLVM can generate two kind of checks:
+        *
+        * Type 1:
+        *
+        *   r2 = r3;
+        *   r2 += 8;
+        *   if (r2 > pkt_end) goto <handle exception>
+        *   <access okay>
+        *
+        *   Where:
+        *     r2 == dst_reg, pkt_end == src_reg
+        *     r2=pkt(id=n,off=8,r=0)
+        *     r3=pkt(id=n,off=0,r=0)
+        *
+        * Type 2:
+        *
+        *   r2 = r3;
+        *   r2 += 8;
+        *   if (pkt_end >= r2) goto <access okay>
+        *   <handle exception>
+        *
+        *   Where:
+        *     pkt_end == dst_reg, r2 == src_reg
+        *     r2=pkt(id=n,off=8,r=0)
+        *     r3=pkt(id=n,off=0,r=0)
+        *
+        * Find register r3 and mark its range as r3=pkt(id=n,off=0,r=8)
+        * so that range of bytes [r3, r3 + 8) is safe to access.
         */
+
        for (i = 0; i < MAX_BPF_REG; i++)
                if (regs[i].type == PTR_TO_PACKET && regs[i].id == dst_reg->id)
                        regs[i].range = dst_reg->off;
@@ -1668,8 +1689,8 @@ static void find_good_pkt_pointers(struct verifier_env *env,
 static int check_cond_jmp_op(struct verifier_env *env,
                             struct bpf_insn *insn, int *insn_idx)
 {
-       struct reg_state *regs = env->cur_state.regs, *dst_reg;
-       struct verifier_state *other_branch;
+       struct verifier_state *other_branch, *this_branch = &env->cur_state;
+       struct reg_state *regs = this_branch->regs, *dst_reg;
        u8 opcode = BPF_OP(insn->code);
        int err;
 
@@ -1750,13 +1771,17 @@ static int check_cond_jmp_op(struct verifier_env *env,
        } else if (BPF_SRC(insn->code) == BPF_X && opcode == BPF_JGT &&
                   dst_reg->type == PTR_TO_PACKET &&
                   regs[insn->src_reg].type == PTR_TO_PACKET_END) {
-               find_good_pkt_pointers(env, dst_reg);
+               find_good_pkt_pointers(this_branch, dst_reg);
+       } else if (BPF_SRC(insn->code) == BPF_X && opcode == BPF_JGE &&
+                  dst_reg->type == PTR_TO_PACKET_END &&
+                  regs[insn->src_reg].type == PTR_TO_PACKET) {
+               find_good_pkt_pointers(other_branch, &regs[insn->src_reg]);
        } else if (is_pointer_value(env, insn->dst_reg)) {
                verbose("R%d pointer comparison prohibited\n", insn->dst_reg);
                return -EACCES;
        }
        if (log_level)
-               print_verifier_state(&env->cur_state);
+               print_verifier_state(this_branch);
        return 0;
 }
 
@@ -2333,7 +2358,8 @@ static int do_check(struct verifier_env *env)
                        if (err)
                                return err;
 
-                       if (BPF_SIZE(insn->code) != BPF_W) {
+                       if (BPF_SIZE(insn->code) != BPF_W &&
+                           BPF_SIZE(insn->code) != BPF_DW) {
                                insn_idx++;
                                continue;
                        }
@@ -2510,6 +2536,20 @@ process_bpf_exit:
        return 0;
 }
 
+static int check_map_prog_compatibility(struct bpf_map *map,
+                                       struct bpf_prog *prog)
+
+{
+       if (prog->type == BPF_PROG_TYPE_PERF_EVENT &&
+           (map->map_type == BPF_MAP_TYPE_HASH ||
+            map->map_type == BPF_MAP_TYPE_PERCPU_HASH) &&
+           (map->map_flags & BPF_F_NO_PREALLOC)) {
+               verbose("perf_event programs can only use preallocated hash map\n");
+               return -EINVAL;
+       }
+       return 0;
+}
+
 /* look for pseudo eBPF instructions that access map FDs and
  * replace them with actual map pointers
  */
@@ -2517,7 +2557,7 @@ static int replace_map_fd_with_map_ptr(struct verifier_env *env)
 {
        struct bpf_insn *insn = env->prog->insnsi;
        int insn_cnt = env->prog->len;
-       int i, j;
+       int i, j, err;
 
        for (i = 0; i < insn_cnt; i++, insn++) {
                if (BPF_CLASS(insn->code) == BPF_LDX &&
@@ -2561,6 +2601,12 @@ static int replace_map_fd_with_map_ptr(struct verifier_env *env)
                                return PTR_ERR(map);
                        }
 
+                       err = check_map_prog_compatibility(map, env->prog);
+                       if (err) {
+                               fdput(f);
+                               return err;
+                       }
+
                        /* store map pointer inside BPF_LD_IMM64 instruction */
                        insn[0].imm = (u32) (unsigned long) map;
                        insn[1].imm = ((u64) (unsigned long) map) >> 32;
@@ -2642,9 +2688,11 @@ static int convert_ctx_accesses(struct verifier_env *env)
        for (i = 0; i < insn_cnt; i++, insn++) {
                u32 insn_delta, cnt;
 
-               if (insn->code == (BPF_LDX | BPF_MEM | BPF_W))
+               if (insn->code == (BPF_LDX | BPF_MEM | BPF_W) ||
+                   insn->code == (BPF_LDX | BPF_MEM | BPF_DW))
                        type = BPF_READ;
-               else if (insn->code == (BPF_STX | BPF_MEM | BPF_W))
+               else if (insn->code == (BPF_STX | BPF_MEM | BPF_W) ||
+                        insn->code == (BPF_STX | BPF_MEM | BPF_DW))
                        type = BPF_WRITE;
                else
                        continue;