bpf: fix range propagation on direct packet access
[cascardo/linux.git] / kernel / bpf / verifier.c
index eec9f90..90493a6 100644 (file)
@@ -194,6 +194,7 @@ struct verifier_env {
        struct verifier_state_list **explored_states; /* search pruning optimization */
        struct bpf_map *used_maps[MAX_USED_MAPS]; /* array of map's used by eBPF program */
        u32 used_map_cnt;               /* number of used maps */
+       u32 id_gen;                     /* used to generate unique reg IDs */
        bool allow_ptr_leaks;
 };
 
@@ -653,6 +654,16 @@ static int check_map_access(struct verifier_env *env, u32 regno, int off,
 
 #define MAX_PACKET_OFF 0xffff
 
+static bool may_write_pkt_data(enum bpf_prog_type type)
+{
+       switch (type) {
+       case BPF_PROG_TYPE_XDP:
+               return true;
+       default:
+               return false;
+       }
+}
+
 static int check_packet_access(struct verifier_env *env, u32 regno, int off,
                               int size)
 {
@@ -713,6 +724,7 @@ static int check_ptr_alignment(struct verifier_env *env, struct reg_state *reg,
        switch (env->prog->type) {
        case BPF_PROG_TYPE_SCHED_CLS:
        case BPF_PROG_TYPE_SCHED_ACT:
+       case BPF_PROG_TYPE_XDP:
                break;
        default:
                verbose("verifier is misconfigured\n");
@@ -805,10 +817,15 @@ static int check_mem_access(struct verifier_env *env, u32 regno, int off,
                        err = check_stack_read(state, off, size, value_regno);
                }
        } else if (state->regs[regno].type == PTR_TO_PACKET) {
-               if (t == BPF_WRITE) {
+               if (t == BPF_WRITE && !may_write_pkt_data(env->prog->type)) {
                        verbose("cannot write into packet\n");
                        return -EACCES;
                }
+               if (t == BPF_WRITE && value_regno >= 0 &&
+                   is_pointer_value(env, value_regno)) {
+                       verbose("R%d leaks addr into packet\n", value_regno);
+                       return -EACCES;
+               }
                err = check_packet_access(env, regno, off, size);
                if (!err && t == BPF_READ && value_regno >= 0)
                        mark_reg_unknown_value(state->regs, value_regno);
@@ -913,14 +930,14 @@ static int check_func_arg(struct verifier_env *env, u32 regno,
                          enum bpf_arg_type arg_type,
                          struct bpf_call_arg_meta *meta)
 {
-       struct reg_state *reg = env->cur_state.regs + regno;
-       enum bpf_reg_type expected_type;
+       struct reg_state *regs = env->cur_state.regs, *reg = &regs[regno];
+       enum bpf_reg_type expected_type, type = reg->type;
        int err = 0;
 
        if (arg_type == ARG_DONTCARE)
                return 0;
 
-       if (reg->type == NOT_INIT) {
+       if (type == NOT_INIT) {
                verbose("R%d !read_ok\n", regno);
                return -EACCES;
        }
@@ -933,16 +950,29 @@ static int check_func_arg(struct verifier_env *env, u32 regno,
                return 0;
        }
 
+       if (type == PTR_TO_PACKET && !may_write_pkt_data(env->prog->type)) {
+               verbose("helper access to the packet is not allowed for clsact\n");
+               return -EACCES;
+       }
+
        if (arg_type == ARG_PTR_TO_MAP_KEY ||
            arg_type == ARG_PTR_TO_MAP_VALUE) {
                expected_type = PTR_TO_STACK;
+               if (type != PTR_TO_PACKET && type != expected_type)
+                       goto err_type;
        } else if (arg_type == ARG_CONST_STACK_SIZE ||
                   arg_type == ARG_CONST_STACK_SIZE_OR_ZERO) {
                expected_type = CONST_IMM;
+               if (type != expected_type)
+                       goto err_type;
        } else if (arg_type == ARG_CONST_MAP_PTR) {
                expected_type = CONST_PTR_TO_MAP;
+               if (type != expected_type)
+                       goto err_type;
        } else if (arg_type == ARG_PTR_TO_CTX) {
                expected_type = PTR_TO_CTX;
+               if (type != expected_type)
+                       goto err_type;
        } else if (arg_type == ARG_PTR_TO_STACK ||
                   arg_type == ARG_PTR_TO_RAW_STACK) {
                expected_type = PTR_TO_STACK;
@@ -950,20 +980,16 @@ static int check_func_arg(struct verifier_env *env, u32 regno,
                 * passed in as argument, it's a CONST_IMM type. Final test
                 * happens during stack boundary checking.
                 */
-               if (reg->type == CONST_IMM && reg->imm == 0)
-                       expected_type = CONST_IMM;
+               if (type == CONST_IMM && reg->imm == 0)
+                       /* final test in check_stack_boundary() */;
+               else if (type != PTR_TO_PACKET && type != expected_type)
+                       goto err_type;
                meta->raw_mode = arg_type == ARG_PTR_TO_RAW_STACK;
        } else {
                verbose("unsupported arg_type %d\n", arg_type);
                return -EFAULT;
        }
 
-       if (reg->type != expected_type) {
-               verbose("R%d type=%s expected=%s\n", regno,
-                       reg_type_str[reg->type], reg_type_str[expected_type]);
-               return -EACCES;
-       }
-
        if (arg_type == ARG_CONST_MAP_PTR) {
                /* bpf_map_xxx(map_ptr) call: remember that map_ptr */
                meta->map_ptr = reg->map_ptr;
@@ -981,8 +1007,13 @@ static int check_func_arg(struct verifier_env *env, u32 regno,
                        verbose("invalid map_ptr to access map->key\n");
                        return -EACCES;
                }
-               err = check_stack_boundary(env, regno, meta->map_ptr->key_size,
-                                          false, NULL);
+               if (type == PTR_TO_PACKET)
+                       err = check_packet_access(env, regno, 0,
+                                                 meta->map_ptr->key_size);
+               else
+                       err = check_stack_boundary(env, regno,
+                                                  meta->map_ptr->key_size,
+                                                  false, NULL);
        } else if (arg_type == ARG_PTR_TO_MAP_VALUE) {
                /* bpf_map_xxx(..., map_ptr, ..., value) call:
                 * check [value, value + map->value_size) validity
@@ -992,9 +1023,13 @@ static int check_func_arg(struct verifier_env *env, u32 regno,
                        verbose("invalid map_ptr to access map->value\n");
                        return -EACCES;
                }
-               err = check_stack_boundary(env, regno,
-                                          meta->map_ptr->value_size,
-                                          false, NULL);
+               if (type == PTR_TO_PACKET)
+                       err = check_packet_access(env, regno, 0,
+                                                 meta->map_ptr->value_size);
+               else
+                       err = check_stack_boundary(env, regno,
+                                                  meta->map_ptr->value_size,
+                                                  false, NULL);
        } else if (arg_type == ARG_CONST_STACK_SIZE ||
                   arg_type == ARG_CONST_STACK_SIZE_OR_ZERO) {
                bool zero_size_allowed = (arg_type == ARG_CONST_STACK_SIZE_OR_ZERO);
@@ -1008,11 +1043,18 @@ static int check_func_arg(struct verifier_env *env, u32 regno,
                        verbose("ARG_CONST_STACK_SIZE cannot be first argument\n");
                        return -EACCES;
                }
-               err = check_stack_boundary(env, regno - 1, reg->imm,
-                                          zero_size_allowed, meta);
+               if (regs[regno - 1].type == PTR_TO_PACKET)
+                       err = check_packet_access(env, regno - 1, 0, reg->imm);
+               else
+                       err = check_stack_boundary(env, regno - 1, reg->imm,
+                                                  zero_size_allowed, meta);
        }
 
        return err;
+err_type:
+       verbose("R%d type=%s expected=%s\n", regno,
+               reg_type_str[type], reg_type_str[expected_type]);
+       return -EACCES;
 }
 
 static int check_map_func_compatibility(struct bpf_map *map, int func_id)
@@ -1035,6 +1077,11 @@ static int check_map_func_compatibility(struct bpf_map *map, int func_id)
                if (func_id != BPF_FUNC_get_stackid)
                        goto error;
                break;
+       case BPF_MAP_TYPE_CGROUP_ARRAY:
+               if (func_id != BPF_FUNC_skb_under_cgroup &&
+                   func_id != BPF_FUNC_current_task_under_cgroup)
+                       goto error;
+               break;
        default:
                break;
        }
@@ -1054,6 +1101,11 @@ static int check_map_func_compatibility(struct bpf_map *map, int func_id)
                if (map->map_type != BPF_MAP_TYPE_STACK_TRACE)
                        goto error;
                break;
+       case BPF_FUNC_current_task_under_cgroup:
+       case BPF_FUNC_skb_under_cgroup:
+               if (map->map_type != BPF_MAP_TYPE_CGROUP_ARRAY)
+                       goto error;
+               break;
        default:
                break;
        }
@@ -1277,7 +1329,7 @@ add_imm:
                /* dst_reg stays as pkt_ptr type and since some positive
                 * integer value was added to the pointer, increment its 'id'
                 */
-               dst_reg->id++;
+               dst_reg->id = ++env->id_gen;
 
                /* something was added to pkt_ptr, set range and off to zero */
                dst_reg->off = 0;
@@ -1585,21 +1637,42 @@ static int check_alu_op(struct verifier_env *env, struct bpf_insn *insn)
        return 0;
 }
 
-static void find_good_pkt_pointers(struct verifier_env *env,
-                                  struct reg_state *dst_reg)
+static void find_good_pkt_pointers(struct verifier_state *state,
+                                  const struct reg_state *dst_reg)
 {
-       struct verifier_state *state = &env->cur_state;
        struct reg_state *regs = state->regs, *reg;
        int i;
-       /* r2 = r3;
-        * r2 += 8
-        * if (r2 > pkt_end) goto somewhere
-        * r2 == dst_reg, pkt_end == src_reg,
-        * r2=pkt(id=n,off=8,r=0)
-        * r3=pkt(id=n,off=0,r=0)
-        * find register r3 and mark its range as r3=pkt(id=n,off=0,r=8)
-        * so that range of bytes [r3, r3 + 8) is safe to access
+
+       /* LLVM can generate two kind of checks:
+        *
+        * Type 1:
+        *
+        *   r2 = r3;
+        *   r2 += 8;
+        *   if (r2 > pkt_end) goto <handle exception>
+        *   <access okay>
+        *
+        *   Where:
+        *     r2 == dst_reg, pkt_end == src_reg
+        *     r2=pkt(id=n,off=8,r=0)
+        *     r3=pkt(id=n,off=0,r=0)
+        *
+        * Type 2:
+        *
+        *   r2 = r3;
+        *   r2 += 8;
+        *   if (pkt_end >= r2) goto <access okay>
+        *   <handle exception>
+        *
+        *   Where:
+        *     pkt_end == dst_reg, r2 == src_reg
+        *     r2=pkt(id=n,off=8,r=0)
+        *     r3=pkt(id=n,off=0,r=0)
+        *
+        * Find register r3 and mark its range as r3=pkt(id=n,off=0,r=8)
+        * so that range of bytes [r3, r3 + 8) is safe to access.
         */
+
        for (i = 0; i < MAX_BPF_REG; i++)
                if (regs[i].type == PTR_TO_PACKET && regs[i].id == dst_reg->id)
                        regs[i].range = dst_reg->off;
@@ -1616,8 +1689,8 @@ static void find_good_pkt_pointers(struct verifier_env *env,
 static int check_cond_jmp_op(struct verifier_env *env,
                             struct bpf_insn *insn, int *insn_idx)
 {
-       struct reg_state *regs = env->cur_state.regs, *dst_reg;
-       struct verifier_state *other_branch;
+       struct verifier_state *other_branch, *this_branch = &env->cur_state;
+       struct reg_state *regs = this_branch->regs, *dst_reg;
        u8 opcode = BPF_OP(insn->code);
        int err;
 
@@ -1698,13 +1771,17 @@ static int check_cond_jmp_op(struct verifier_env *env,
        } else if (BPF_SRC(insn->code) == BPF_X && opcode == BPF_JGT &&
                   dst_reg->type == PTR_TO_PACKET &&
                   regs[insn->src_reg].type == PTR_TO_PACKET_END) {
-               find_good_pkt_pointers(env, dst_reg);
+               find_good_pkt_pointers(this_branch, dst_reg);
+       } else if (BPF_SRC(insn->code) == BPF_X && opcode == BPF_JGE &&
+                  dst_reg->type == PTR_TO_PACKET_END &&
+                  regs[insn->src_reg].type == PTR_TO_PACKET) {
+               find_good_pkt_pointers(other_branch, &regs[insn->src_reg]);
        } else if (is_pointer_value(env, insn->dst_reg)) {
                verbose("R%d pointer comparison prohibited\n", insn->dst_reg);
                return -EACCES;
        }
        if (log_level)
-               print_verifier_state(&env->cur_state);
+               print_verifier_state(this_branch);
        return 0;
 }
 
@@ -2281,7 +2358,8 @@ static int do_check(struct verifier_env *env)
                        if (err)
                                return err;
 
-                       if (BPF_SIZE(insn->code) != BPF_W) {
+                       if (BPF_SIZE(insn->code) != BPF_W &&
+                           BPF_SIZE(insn->code) != BPF_DW) {
                                insn_idx++;
                                continue;
                        }
@@ -2458,6 +2536,20 @@ process_bpf_exit:
        return 0;
 }
 
+static int check_map_prog_compatibility(struct bpf_map *map,
+                                       struct bpf_prog *prog)
+
+{
+       if (prog->type == BPF_PROG_TYPE_PERF_EVENT &&
+           (map->map_type == BPF_MAP_TYPE_HASH ||
+            map->map_type == BPF_MAP_TYPE_PERCPU_HASH) &&
+           (map->map_flags & BPF_F_NO_PREALLOC)) {
+               verbose("perf_event programs can only use preallocated hash map\n");
+               return -EINVAL;
+       }
+       return 0;
+}
+
 /* look for pseudo eBPF instructions that access map FDs and
  * replace them with actual map pointers
  */
@@ -2465,7 +2557,7 @@ static int replace_map_fd_with_map_ptr(struct verifier_env *env)
 {
        struct bpf_insn *insn = env->prog->insnsi;
        int insn_cnt = env->prog->len;
-       int i, j;
+       int i, j, err;
 
        for (i = 0; i < insn_cnt; i++, insn++) {
                if (BPF_CLASS(insn->code) == BPF_LDX &&
@@ -2509,6 +2601,12 @@ static int replace_map_fd_with_map_ptr(struct verifier_env *env)
                                return PTR_ERR(map);
                        }
 
+                       err = check_map_prog_compatibility(map, env->prog);
+                       if (err) {
+                               fdput(f);
+                               return err;
+                       }
+
                        /* store map pointer inside BPF_LD_IMM64 instruction */
                        insn[0].imm = (u32) (unsigned long) map;
                        insn[1].imm = ((u64) (unsigned long) map) >> 32;
@@ -2590,9 +2688,11 @@ static int convert_ctx_accesses(struct verifier_env *env)
        for (i = 0; i < insn_cnt; i++, insn++) {
                u32 insn_delta, cnt;
 
-               if (insn->code == (BPF_LDX | BPF_MEM | BPF_W))
+               if (insn->code == (BPF_LDX | BPF_MEM | BPF_W) ||
+                   insn->code == (BPF_LDX | BPF_MEM | BPF_DW))
                        type = BPF_READ;
-               else if (insn->code == (BPF_STX | BPF_MEM | BPF_W))
+               else if (insn->code == (BPF_STX | BPF_MEM | BPF_W) ||
+                        insn->code == (BPF_STX | BPF_MEM | BPF_DW))
                        type = BPF_WRITE;
                else
                        continue;