bpf: fix range propagation on direct packet access
[cascardo/linux.git] / kernel / bpf / verifier.c
index daea765..90493a6 100644 (file)
@@ -930,14 +930,14 @@ static int check_func_arg(struct verifier_env *env, u32 regno,
                          enum bpf_arg_type arg_type,
                          struct bpf_call_arg_meta *meta)
 {
-       struct reg_state *reg = env->cur_state.regs + regno;
-       enum bpf_reg_type expected_type;
+       struct reg_state *regs = env->cur_state.regs, *reg = &regs[regno];
+       enum bpf_reg_type expected_type, type = reg->type;
        int err = 0;
 
        if (arg_type == ARG_DONTCARE)
                return 0;
 
-       if (reg->type == NOT_INIT) {
+       if (type == NOT_INIT) {
                verbose("R%d !read_ok\n", regno);
                return -EACCES;
        }
@@ -950,16 +950,29 @@ static int check_func_arg(struct verifier_env *env, u32 regno,
                return 0;
        }
 
+       if (type == PTR_TO_PACKET && !may_write_pkt_data(env->prog->type)) {
+               verbose("helper access to the packet is not allowed for clsact\n");
+               return -EACCES;
+       }
+
        if (arg_type == ARG_PTR_TO_MAP_KEY ||
            arg_type == ARG_PTR_TO_MAP_VALUE) {
                expected_type = PTR_TO_STACK;
+               if (type != PTR_TO_PACKET && type != expected_type)
+                       goto err_type;
        } else if (arg_type == ARG_CONST_STACK_SIZE ||
                   arg_type == ARG_CONST_STACK_SIZE_OR_ZERO) {
                expected_type = CONST_IMM;
+               if (type != expected_type)
+                       goto err_type;
        } else if (arg_type == ARG_CONST_MAP_PTR) {
                expected_type = CONST_PTR_TO_MAP;
+               if (type != expected_type)
+                       goto err_type;
        } else if (arg_type == ARG_PTR_TO_CTX) {
                expected_type = PTR_TO_CTX;
+               if (type != expected_type)
+                       goto err_type;
        } else if (arg_type == ARG_PTR_TO_STACK ||
                   arg_type == ARG_PTR_TO_RAW_STACK) {
                expected_type = PTR_TO_STACK;
@@ -967,20 +980,16 @@ static int check_func_arg(struct verifier_env *env, u32 regno,
                 * passed in as argument, it's a CONST_IMM type. Final test
                 * happens during stack boundary checking.
                 */
-               if (reg->type == CONST_IMM && reg->imm == 0)
-                       expected_type = CONST_IMM;
+               if (type == CONST_IMM && reg->imm == 0)
+                       /* final test in check_stack_boundary() */;
+               else if (type != PTR_TO_PACKET && type != expected_type)
+                       goto err_type;
                meta->raw_mode = arg_type == ARG_PTR_TO_RAW_STACK;
        } else {
                verbose("unsupported arg_type %d\n", arg_type);
                return -EFAULT;
        }
 
-       if (reg->type != expected_type) {
-               verbose("R%d type=%s expected=%s\n", regno,
-                       reg_type_str[reg->type], reg_type_str[expected_type]);
-               return -EACCES;
-       }
-
        if (arg_type == ARG_CONST_MAP_PTR) {
                /* bpf_map_xxx(map_ptr) call: remember that map_ptr */
                meta->map_ptr = reg->map_ptr;
@@ -998,8 +1007,13 @@ static int check_func_arg(struct verifier_env *env, u32 regno,
                        verbose("invalid map_ptr to access map->key\n");
                        return -EACCES;
                }
-               err = check_stack_boundary(env, regno, meta->map_ptr->key_size,
-                                          false, NULL);
+               if (type == PTR_TO_PACKET)
+                       err = check_packet_access(env, regno, 0,
+                                                 meta->map_ptr->key_size);
+               else
+                       err = check_stack_boundary(env, regno,
+                                                  meta->map_ptr->key_size,
+                                                  false, NULL);
        } else if (arg_type == ARG_PTR_TO_MAP_VALUE) {
                /* bpf_map_xxx(..., map_ptr, ..., value) call:
                 * check [value, value + map->value_size) validity
@@ -1009,9 +1023,13 @@ static int check_func_arg(struct verifier_env *env, u32 regno,
                        verbose("invalid map_ptr to access map->value\n");
                        return -EACCES;
                }
-               err = check_stack_boundary(env, regno,
-                                          meta->map_ptr->value_size,
-                                          false, NULL);
+               if (type == PTR_TO_PACKET)
+                       err = check_packet_access(env, regno, 0,
+                                                 meta->map_ptr->value_size);
+               else
+                       err = check_stack_boundary(env, regno,
+                                                  meta->map_ptr->value_size,
+                                                  false, NULL);
        } else if (arg_type == ARG_CONST_STACK_SIZE ||
                   arg_type == ARG_CONST_STACK_SIZE_OR_ZERO) {
                bool zero_size_allowed = (arg_type == ARG_CONST_STACK_SIZE_OR_ZERO);
@@ -1025,11 +1043,18 @@ static int check_func_arg(struct verifier_env *env, u32 regno,
                        verbose("ARG_CONST_STACK_SIZE cannot be first argument\n");
                        return -EACCES;
                }
-               err = check_stack_boundary(env, regno - 1, reg->imm,
-                                          zero_size_allowed, meta);
+               if (regs[regno - 1].type == PTR_TO_PACKET)
+                       err = check_packet_access(env, regno - 1, 0, reg->imm);
+               else
+                       err = check_stack_boundary(env, regno - 1, reg->imm,
+                                                  zero_size_allowed, meta);
        }
 
        return err;
+err_type:
+       verbose("R%d type=%s expected=%s\n", regno,
+               reg_type_str[type], reg_type_str[expected_type]);
+       return -EACCES;
 }
 
 static int check_map_func_compatibility(struct bpf_map *map, int func_id)
@@ -1053,7 +1078,8 @@ static int check_map_func_compatibility(struct bpf_map *map, int func_id)
                        goto error;
                break;
        case BPF_MAP_TYPE_CGROUP_ARRAY:
-               if (func_id != BPF_FUNC_skb_under_cgroup)
+               if (func_id != BPF_FUNC_skb_under_cgroup &&
+                   func_id != BPF_FUNC_current_task_under_cgroup)
                        goto error;
                break;
        default:
@@ -1075,6 +1101,7 @@ static int check_map_func_compatibility(struct bpf_map *map, int func_id)
                if (map->map_type != BPF_MAP_TYPE_STACK_TRACE)
                        goto error;
                break;
+       case BPF_FUNC_current_task_under_cgroup:
        case BPF_FUNC_skb_under_cgroup:
                if (map->map_type != BPF_MAP_TYPE_CGROUP_ARRAY)
                        goto error;
@@ -1610,21 +1637,42 @@ static int check_alu_op(struct verifier_env *env, struct bpf_insn *insn)
        return 0;
 }
 
-static void find_good_pkt_pointers(struct verifier_env *env,
-                                  struct reg_state *dst_reg)
+static void find_good_pkt_pointers(struct verifier_state *state,
+                                  const struct reg_state *dst_reg)
 {
-       struct verifier_state *state = &env->cur_state;
        struct reg_state *regs = state->regs, *reg;
        int i;
-       /* r2 = r3;
-        * r2 += 8
-        * if (r2 > pkt_end) goto somewhere
-        * r2 == dst_reg, pkt_end == src_reg,
-        * r2=pkt(id=n,off=8,r=0)
-        * r3=pkt(id=n,off=0,r=0)
-        * find register r3 and mark its range as r3=pkt(id=n,off=0,r=8)
-        * so that range of bytes [r3, r3 + 8) is safe to access
+
+       /* LLVM can generate two kind of checks:
+        *
+        * Type 1:
+        *
+        *   r2 = r3;
+        *   r2 += 8;
+        *   if (r2 > pkt_end) goto <handle exception>
+        *   <access okay>
+        *
+        *   Where:
+        *     r2 == dst_reg, pkt_end == src_reg
+        *     r2=pkt(id=n,off=8,r=0)
+        *     r3=pkt(id=n,off=0,r=0)
+        *
+        * Type 2:
+        *
+        *   r2 = r3;
+        *   r2 += 8;
+        *   if (pkt_end >= r2) goto <access okay>
+        *   <handle exception>
+        *
+        *   Where:
+        *     pkt_end == dst_reg, r2 == src_reg
+        *     r2=pkt(id=n,off=8,r=0)
+        *     r3=pkt(id=n,off=0,r=0)
+        *
+        * Find register r3 and mark its range as r3=pkt(id=n,off=0,r=8)
+        * so that range of bytes [r3, r3 + 8) is safe to access.
         */
+
        for (i = 0; i < MAX_BPF_REG; i++)
                if (regs[i].type == PTR_TO_PACKET && regs[i].id == dst_reg->id)
                        regs[i].range = dst_reg->off;
@@ -1641,8 +1689,8 @@ static void find_good_pkt_pointers(struct verifier_env *env,
 static int check_cond_jmp_op(struct verifier_env *env,
                             struct bpf_insn *insn, int *insn_idx)
 {
-       struct reg_state *regs = env->cur_state.regs, *dst_reg;
-       struct verifier_state *other_branch;
+       struct verifier_state *other_branch, *this_branch = &env->cur_state;
+       struct reg_state *regs = this_branch->regs, *dst_reg;
        u8 opcode = BPF_OP(insn->code);
        int err;
 
@@ -1723,13 +1771,17 @@ static int check_cond_jmp_op(struct verifier_env *env,
        } else if (BPF_SRC(insn->code) == BPF_X && opcode == BPF_JGT &&
                   dst_reg->type == PTR_TO_PACKET &&
                   regs[insn->src_reg].type == PTR_TO_PACKET_END) {
-               find_good_pkt_pointers(env, dst_reg);
+               find_good_pkt_pointers(this_branch, dst_reg);
+       } else if (BPF_SRC(insn->code) == BPF_X && opcode == BPF_JGE &&
+                  dst_reg->type == PTR_TO_PACKET_END &&
+                  regs[insn->src_reg].type == PTR_TO_PACKET) {
+               find_good_pkt_pointers(other_branch, &regs[insn->src_reg]);
        } else if (is_pointer_value(env, insn->dst_reg)) {
                verbose("R%d pointer comparison prohibited\n", insn->dst_reg);
                return -EACCES;
        }
        if (log_level)
-               print_verifier_state(&env->cur_state);
+               print_verifier_state(this_branch);
        return 0;
 }
 
@@ -2306,7 +2358,8 @@ static int do_check(struct verifier_env *env)
                        if (err)
                                return err;
 
-                       if (BPF_SIZE(insn->code) != BPF_W) {
+                       if (BPF_SIZE(insn->code) != BPF_W &&
+                           BPF_SIZE(insn->code) != BPF_DW) {
                                insn_idx++;
                                continue;
                        }
@@ -2483,6 +2536,20 @@ process_bpf_exit:
        return 0;
 }
 
+static int check_map_prog_compatibility(struct bpf_map *map,
+                                       struct bpf_prog *prog)
+
+{
+       if (prog->type == BPF_PROG_TYPE_PERF_EVENT &&
+           (map->map_type == BPF_MAP_TYPE_HASH ||
+            map->map_type == BPF_MAP_TYPE_PERCPU_HASH) &&
+           (map->map_flags & BPF_F_NO_PREALLOC)) {
+               verbose("perf_event programs can only use preallocated hash map\n");
+               return -EINVAL;
+       }
+       return 0;
+}
+
 /* look for pseudo eBPF instructions that access map FDs and
  * replace them with actual map pointers
  */
@@ -2490,7 +2557,7 @@ static int replace_map_fd_with_map_ptr(struct verifier_env *env)
 {
        struct bpf_insn *insn = env->prog->insnsi;
        int insn_cnt = env->prog->len;
-       int i, j;
+       int i, j, err;
 
        for (i = 0; i < insn_cnt; i++, insn++) {
                if (BPF_CLASS(insn->code) == BPF_LDX &&
@@ -2534,6 +2601,12 @@ static int replace_map_fd_with_map_ptr(struct verifier_env *env)
                                return PTR_ERR(map);
                        }
 
+                       err = check_map_prog_compatibility(map, env->prog);
+                       if (err) {
+                               fdput(f);
+                               return err;
+                       }
+
                        /* store map pointer inside BPF_LD_IMM64 instruction */
                        insn[0].imm = (u32) (unsigned long) map;
                        insn[1].imm = ((u64) (unsigned long) map) >> 32;
@@ -2615,9 +2688,11 @@ static int convert_ctx_accesses(struct verifier_env *env)
        for (i = 0; i < insn_cnt; i++, insn++) {
                u32 insn_delta, cnt;
 
-               if (insn->code == (BPF_LDX | BPF_MEM | BPF_W))
+               if (insn->code == (BPF_LDX | BPF_MEM | BPF_W) ||
+                   insn->code == (BPF_LDX | BPF_MEM | BPF_DW))
                        type = BPF_READ;
-               else if (insn->code == (BPF_STX | BPF_MEM | BPF_W))
+               else if (insn->code == (BPF_STX | BPF_MEM | BPF_W) ||
+                        insn->code == (BPF_STX | BPF_MEM | BPF_DW))
                        type = BPF_WRITE;
                else
                        continue;