]> asedeno.scripts.mit.edu Git - linux.git/blobdiff - net/core/filter.c
Merge tag 'riscv-for-linus-4.20-rc4' of git://git.kernel.org/pub/scm/linux/kernel...
[linux.git] / net / core / filter.c
index 1a3ac6c468735a9319ac81b8b607f809f184d188..e521c5ebc7d11cdfdcc10307ad973bcac2d1602a 100644 (file)
@@ -2297,6 +2297,137 @@ static const struct bpf_func_proto bpf_msg_pull_data_proto = {
        .arg4_type      = ARG_ANYTHING,
 };
 
+/* BPF helper: make room for @len new bytes at byte offset @start of the
+ * sk_msg scatterlist, shifting the existing payload right (typical use:
+ * inserting a protocol header).  The room comes from a freshly allocated
+ * page and is NOT zeroed (no __GFP_ZERO), so the program must overwrite
+ * it.  @flags is reserved and must be zero.  Returns 0 on success,
+ * -EINVAL for bad @start/@flags, -ENOMEM on allocation failure.
+ */
+BPF_CALL_4(bpf_msg_push_data, struct sk_msg *, msg, u32, start,
+          u32, len, u64, flags)
+{
+       struct scatterlist sge, nsge, nnsge, rsge = {0}, *psge;
+       u32 new, i = 0, l, space, copy = 0, offset = 0;
+       u8 *raw, *to, *from;
+       struct page *page;
+
+       if (unlikely(flags))
+               return -EINVAL;
+
+       /* Nothing to push; avoid allocating a page and threading a
+        * zero-length element through the shift logic below.
+        */
+       if (unlikely(len == 0))
+               return 0;
+
+       /* First find the starting scatterlist element */
+       i = msg->sg.start;
+       do {
+               l = sk_msg_elem(msg, i)->length;
+
+               if (start < offset + l)
+                       break;
+               offset += l;
+               sk_msg_iter_var_next(i);
+       } while (i != msg->sg.end);
+
+       /* @start lies past the end of the message data */
+       if (start >= offset + l)
+               return -EINVAL;
+
+       space = MAX_MSG_FRAGS - sk_msg_elem_used(msg);
+
+       /* If no space available will fallback to copy, we need at
+        * least one scatterlist elem available to push data into
+        * when start aligns to the beginning of an element or two
+        * when it falls inside an element. We handle the start equals
+        * offset case because its the common case for inserting a
+        * header.
+        */
+       if (!space || (space == 1 && start != offset))
+               copy = msg->sg.data[i].length;
+
+       /* One allocation backs both the pushed bytes and, on the copy
+        * fallback, the displaced data of element @i.
+        */
+       page = alloc_pages(__GFP_NOWARN | GFP_ATOMIC | __GFP_COMP,
+                          get_order(copy + len));
+       if (unlikely(!page))
+               return -ENOMEM;
+
+       if (copy) {
+               int front, back;
+
+               raw = page_address(page);
+
+               /* Copy fallback: rebuild element @i inside the new page
+                * with a @len byte gap at @start, then drop the old page.
+                */
+               psge = sk_msg_elem(msg, i);
+               front = start - offset;
+               back = psge->length - front;
+               from = sg_virt(psge);
+
+               if (front)
+                       memcpy(raw, from, front);
+
+               if (back) {
+                       from += front;
+                       to = raw + front + len;
+
+                       memcpy(to, from, back);
+               }
+
+               put_page(sg_page(psge));
+       } else if (start - offset) {
+               /* @start falls inside element @i and two slots are free:
+                * split it, keeping the head in place and parking the
+                * tail in @rsge to be re-linked after the new slot.
+                */
+               psge = sk_msg_elem(msg, i);
+               rsge = sk_msg_elem_cpy(msg, i);
+
+               psge->length = start - offset;
+               rsge.length -= psge->length;
+               /* Advance by the bytes kept in @psge (start - offset),
+                * not by the absolute message offset @start, which is
+                * only equal when @i is the first element iterated.
+                */
+               rsge.offset += start - offset;
+
+               sk_msg_iter_var_next(i);
+               sg_unmark_end(psge);
+               sk_msg_iter_next(msg, end);
+       }
+
+       /* Slot(s) to place newly allocated data */
+       new = i;
+
+       /* Shift one or two slots as needed */
+       if (!copy) {
+               sge = sk_msg_elem_cpy(msg, i);
+
+               sk_msg_iter_var_next(i);
+               sg_unmark_end(&sge);
+               sk_msg_iter_next(msg, end);
+
+               nsge = sk_msg_elem_cpy(msg, i);
+               if (rsge.length) {
+                       sk_msg_iter_var_next(i);
+                       nnsge = sk_msg_elem_cpy(msg, i);
+               }
+
+               while (i != msg->sg.end) {
+                       msg->sg.data[i] = sge;
+                       sge = nsge;
+                       sk_msg_iter_var_next(i);
+                       if (rsge.length) {
+                               nsge = nnsge;
+                               nnsge = sk_msg_elem_cpy(msg, i);
+                       } else {
+                               nsge = sk_msg_elem_cpy(msg, i);
+                       }
+               }
+       }
+
+       /* Place newly allocated data buffer */
+       sk_mem_charge(msg->sk, len);
+       msg->sg.size += len;
+       msg->sg.copy[new] = false;
+       sg_set_page(&msg->sg.data[new], page, len + copy, 0);
+       if (rsge.length) {
+               get_page(sg_page(&rsge));
+               sk_msg_iter_var_next(new);
+               msg->sg.data[new] = rsge;
+       }
+
+       sk_msg_compute_data_pointers(msg);
+       return 0;
+}
+
+/* Verifier type signature for the bpf_msg_push_data() helper:
+ * int bpf_msg_push_data(ctx, start, len, flags).  Not GPL-restricted.
+ */
+static const struct bpf_func_proto bpf_msg_push_data_proto = {
+       .func           = bpf_msg_push_data,
+       .gpl_only       = false,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_ANYTHING,
+       .arg3_type      = ARG_ANYTHING,
+       .arg4_type      = ARG_ANYTHING,
+};
+
 BPF_CALL_1(bpf_get_cgroup_classid, const struct sk_buff *, skb)
 {
        return task_get_classid(skb);
@@ -4854,6 +4985,7 @@ bool bpf_helper_changes_pkt_data(void *func)
            func == bpf_xdp_adjust_head ||
            func == bpf_xdp_adjust_meta ||
            func == bpf_msg_pull_data ||
+           func == bpf_msg_push_data ||
            func == bpf_xdp_adjust_tail ||
 #if IS_ENABLED(CONFIG_IPV6_SEG6_BPF)
            func == bpf_lwt_seg6_store_bytes ||
@@ -4876,6 +5008,12 @@ bpf_base_func_proto(enum bpf_func_id func_id)
                return &bpf_map_update_elem_proto;
        case BPF_FUNC_map_delete_elem:
                return &bpf_map_delete_elem_proto;
+       case BPF_FUNC_map_push_elem:
+               return &bpf_map_push_elem_proto;
+       case BPF_FUNC_map_pop_elem:
+               return &bpf_map_pop_elem_proto;
+       case BPF_FUNC_map_peek_elem:
+               return &bpf_map_peek_elem_proto;
        case BPF_FUNC_get_prandom_u32:
                return &bpf_get_prandom_u32_proto;
        case BPF_FUNC_get_smp_processor_id:
@@ -5124,8 +5262,8 @@ sk_msg_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                return &bpf_msg_cork_bytes_proto;
        case BPF_FUNC_msg_pull_data:
                return &bpf_msg_pull_data_proto;
-       case BPF_FUNC_get_local_storage:
-               return &bpf_get_local_storage_proto;
+       case BPF_FUNC_msg_push_data:
+               return &bpf_msg_push_data_proto;
        default:
                return bpf_base_func_proto(func_id);
        }
@@ -5156,8 +5294,6 @@ sk_skb_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                return &bpf_sk_redirect_map_proto;
        case BPF_FUNC_sk_redirect_hash:
                return &bpf_sk_redirect_hash_proto;
-       case BPF_FUNC_get_local_storage:
-               return &bpf_get_local_storage_proto;
 #ifdef CONFIG_INET
        case BPF_FUNC_sk_lookup_tcp:
                return &bpf_sk_lookup_tcp_proto;
@@ -5346,6 +5482,46 @@ static bool sk_filter_is_valid_access(int off, int size,
        return bpf_skb_is_valid_access(off, size, type, prog, info);
 }
 
+/* Context access checker for cgroup/skb programs: decides which
+ * __sk_buff fields a program may read or write, then defers to the
+ * common skb rules.
+ */
+static bool cg_skb_is_valid_access(int off, int size,
+                                  enum bpf_access_type type,
+                                  const struct bpf_prog *prog,
+                                  struct bpf_insn_access_aux *info)
+{
+       switch (off) {
+       /* Fields with no meaning in this program type. */
+       case bpf_ctx_range(struct __sk_buff, tc_classid):
+       case bpf_ctx_range(struct __sk_buff, data_meta):
+       case bpf_ctx_range(struct __sk_buff, flow_keys):
+               return false;
+       case bpf_ctx_range(struct __sk_buff, data):
+       case bpf_ctx_range(struct __sk_buff, data_end):
+               /* Direct packet access is gated on CAP_SYS_ADMIN here. */
+               if (!capable(CAP_SYS_ADMIN))
+                       return false;
+               break;
+       }
+
+       if (type == BPF_WRITE) {
+               /* Only mark, priority and the cb[] scratch area are
+                * writable; everything else is read-only.
+                */
+               switch (off) {
+               case bpf_ctx_range(struct __sk_buff, mark):
+               case bpf_ctx_range(struct __sk_buff, priority):
+               case bpf_ctx_range_till(struct __sk_buff, cb[0], cb[4]):
+                       break;
+               default:
+                       return false;
+               }
+       }
+
+       /* Loads of data/data_end yield packet pointers so the verifier
+        * can enforce bounds checks against data_end.
+        */
+       switch (off) {
+       case bpf_ctx_range(struct __sk_buff, data):
+               info->reg_type = PTR_TO_PACKET;
+               break;
+       case bpf_ctx_range(struct __sk_buff, data_end):
+               info->reg_type = PTR_TO_PACKET_END;
+               break;
+       }
+
+       return bpf_skb_is_valid_access(off, size, type, prog, info);
+}
+
 static bool lwt_is_valid_access(int off, int size,
                                enum bpf_access_type type,
                                const struct bpf_prog *prog,
@@ -5464,6 +5640,15 @@ static bool sock_filter_is_valid_access(int off, int size,
                                               prog->expected_attach_type);
 }
 
+/* gen_prologue callback for program types (wired to xdp and sk_msg
+ * verifier ops below) whose context needs no setup instructions before
+ * direct packet reads or writes.  Returning 0 emits no prologue.
+ */
+static int bpf_noop_prologue(struct bpf_insn *insn_buf, bool direct_write,
+                            const struct bpf_prog *prog)
+{
+       /* Neither direct read nor direct write requires any preliminary
+        * action.
+        */
+       return 0;
+}
+
 static int bpf_unclone_prologue(struct bpf_insn *insn_buf, bool direct_write,
                                const struct bpf_prog *prog, int drop_verdict)
 {
@@ -7030,6 +7215,7 @@ const struct bpf_verifier_ops xdp_verifier_ops = {
        .get_func_proto         = xdp_func_proto,
        .is_valid_access        = xdp_is_valid_access,
        .convert_ctx_access     = xdp_convert_ctx_access,
+       .gen_prologue           = bpf_noop_prologue,
 };
 
 const struct bpf_prog_ops xdp_prog_ops = {
@@ -7038,7 +7224,7 @@ const struct bpf_prog_ops xdp_prog_ops = {
 
 const struct bpf_verifier_ops cg_skb_verifier_ops = {
        .get_func_proto         = cg_skb_func_proto,
-       .is_valid_access        = sk_filter_is_valid_access,
+       .is_valid_access        = cg_skb_is_valid_access,
        .convert_ctx_access     = bpf_convert_ctx_access,
 };
 
@@ -7128,6 +7314,7 @@ const struct bpf_verifier_ops sk_msg_verifier_ops = {
        .get_func_proto         = sk_msg_func_proto,
        .is_valid_access        = sk_msg_is_valid_access,
        .convert_ctx_access     = sk_msg_convert_ctx_access,
+       .gen_prologue           = bpf_noop_prologue,
 };
 
 const struct bpf_prog_ops sk_msg_prog_ops = {