diff options
author | David S. Miller <davem@davemloft.net> | 2020-09-23 22:11:11 +0200 |
---|---|---|
committer | David S. Miller <davem@davemloft.net> | 2020-09-23 22:11:11 +0200 |
commit | 6d772f328d6ad3e4fb64385784571be4be25e63d (patch) | |
tree | de6e9d5b1aac58a7e1fd9502f0baa45b5c12b296 /arch | |
parent | Merge tag 'linux-can-next-for-5.10-20200923' of git://git.kernel.org/pub/scm/... (diff) | |
parent | tools resolve_btfids: Always force HOSTARCH (diff) | |
download | linux-6d772f328d6ad3e4fb64385784571be4be25e63d.tar.xz linux-6d772f328d6ad3e4fb64385784571be4be25e63d.zip |
Merge git://git.kernel.org/pub/scm/linux/kernel/git/bpf/bpf-next
Alexei Starovoitov says:
====================
pull-request: bpf-next 2020-09-23
The following pull-request contains BPF updates for your *net-next* tree.
We've added 95 non-merge commits during the last 22 day(s) which contain
a total of 124 files changed, 4211 insertions(+), 2040 deletions(-).
The main changes are:
1) Full multi function support in libbpf, from Andrii.
2) Refactoring of function argument checks, from Lorenz.
3) Make bpf_tail_call compatible with functions (subprograms), from Maciej.
4) Program metadata support, from YiFei.
5) bpf iterator optimizations, from Yonghong.
====================
Signed-off-by: David S. Miller <davem@davemloft.net>
Diffstat (limited to 'arch')
-rw-r--r-- | arch/s390/net/bpf_jit_comp.c | 61 | ||||
-rw-r--r-- | arch/x86/include/asm/nospec-branch.h | 16 | ||||
-rw-r--r-- | arch/x86/net/bpf_jit_comp.c | 265 |
3 files changed, 239 insertions, 103 deletions
diff --git a/arch/s390/net/bpf_jit_comp.c b/arch/s390/net/bpf_jit_comp.c index be4b8532dd3c..0a4182792876 100644 --- a/arch/s390/net/bpf_jit_comp.c +++ b/arch/s390/net/bpf_jit_comp.c @@ -50,7 +50,6 @@ struct bpf_jit { int r14_thunk_ip; /* Address of expoline thunk for 'br %r14' */ int tail_call_start; /* Tail call start offset */ int excnt; /* Number of exception table entries */ - int labels[1]; /* Labels for local jumps */ }; #define SEEN_MEM BIT(0) /* use mem[] for temporary storage */ @@ -229,18 +228,18 @@ static inline void reg_set_seen(struct bpf_jit *jit, u32 b1) REG_SET_SEEN(b3); \ }) -#define EMIT6_PCREL_LABEL(op1, op2, b1, b2, label, mask) \ +#define EMIT6_PCREL_RIEB(op1, op2, b1, b2, mask, target) \ ({ \ - int rel = (jit->labels[label] - jit->prg) >> 1; \ + unsigned int rel = (int)((target) - jit->prg) / 2; \ _EMIT6((op1) | reg(b1, b2) << 16 | (rel & 0xffff), \ (op2) | (mask) << 12); \ REG_SET_SEEN(b1); \ REG_SET_SEEN(b2); \ }) -#define EMIT6_PCREL_IMM_LABEL(op1, op2, b1, imm, label, mask) \ +#define EMIT6_PCREL_RIEC(op1, op2, b1, imm, mask, target) \ ({ \ - int rel = (jit->labels[label] - jit->prg) >> 1; \ + unsigned int rel = (int)((target) - jit->prg) / 2; \ _EMIT6((op1) | (reg_high(b1) | (mask)) << 16 | \ (rel & 0xffff), (op2) | ((imm) & 0xff) << 8); \ REG_SET_SEEN(b1); \ @@ -1282,7 +1281,9 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp, EMIT4(0xb9040000, BPF_REG_0, REG_2); break; } - case BPF_JMP | BPF_TAIL_CALL: + case BPF_JMP | BPF_TAIL_CALL: { + int patch_1_clrj, patch_2_clij, patch_3_brc; + /* * Implicit input: * B1: pointer to ctx @@ -1300,16 +1301,10 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp, EMIT6_DISP_LH(0xe3000000, 0x0016, REG_W1, REG_0, BPF_REG_2, offsetof(struct bpf_array, map.max_entries)); /* if ((u32)%b3 >= (u32)%w1) goto out; */ - if (!is_first_pass(jit) && can_use_rel(jit, jit->labels[0])) { - /* clrj %b3,%w1,0xa,label0 */ - EMIT6_PCREL_LABEL(0xec000000, 0x0077, BPF_REG_3, - REG_W1, 0, 0xa); - } else { - /* clr %b3,%w1 */ - EMIT2(0x1500, BPF_REG_3, REG_W1); - /* brcl 0xa,label0 */ - EMIT6_PCREL_RILC(0xc0040000, 0xa, jit->labels[0]); - } + /* clrj %b3,%w1,0xa,out */ + patch_1_clrj = jit->prg; + EMIT6_PCREL_RIEB(0xec000000, 0x0077, BPF_REG_3, REG_W1, 0xa, + jit->prg); /* * if (tail_call_cnt++ > MAX_TAIL_CALL_CNT) @@ -1324,16 +1319,10 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp, EMIT4_IMM(0xa7080000, REG_W0, 1); /* laal %w1,%w0,off(%r15) */ EMIT6_DISP_LH(0xeb000000, 0x00fa, REG_W1, REG_W0, REG_15, off); - if (!is_first_pass(jit) && can_use_rel(jit, jit->labels[0])) { - /* clij %w1,MAX_TAIL_CALL_CNT,0x2,label0 */ - EMIT6_PCREL_IMM_LABEL(0xec000000, 0x007f, REG_W1, - MAX_TAIL_CALL_CNT, 0, 0x2); - } else { - /* clfi %w1,MAX_TAIL_CALL_CNT */ - EMIT6_IMM(0xc20f0000, REG_W1, MAX_TAIL_CALL_CNT); - /* brcl 0x2,label0 */ - EMIT6_PCREL_RILC(0xc0040000, 0x2, jit->labels[0]); - } + /* clij %w1,MAX_TAIL_CALL_CNT,0x2,out */ + patch_2_clij = jit->prg; + EMIT6_PCREL_RIEC(0xec000000, 0x007f, REG_W1, MAX_TAIL_CALL_CNT, + 2, jit->prg); /* * prog = array->ptrs[index]; @@ -1348,13 +1337,9 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp, /* ltg %r1,prog(%b2,%r1) */ EMIT6_DISP_LH(0xe3000000, 0x0002, REG_1, BPF_REG_2, REG_1, offsetof(struct bpf_array, ptrs)); - if (!is_first_pass(jit) && can_use_rel(jit, jit->labels[0])) { - /* brc 0x8,label0 */ - EMIT4_PCREL_RIC(0xa7040000, 0x8, jit->labels[0]); - } else { - /* brcl 0x8,label0 */ - EMIT6_PCREL_RILC(0xc0040000, 0x8, jit->labels[0]); - } + /* brc 0x8,out */ + patch_3_brc = jit->prg; + EMIT4_PCREL_RIC(0xa7040000, 8, jit->prg); /* * Restore registers before calling function @@ -1371,8 +1356,16 @@ static noinline int bpf_jit_insn(struct bpf_jit *jit, struct bpf_prog *fp, /* bc 0xf,tail_call_start(%r1) */ _EMIT4(0x47f01000 + jit->tail_call_start); /* out: */ - jit->labels[0] = jit->prg; + if (jit->prg_buf) { + *(u16 *)(jit->prg_buf + patch_1_clrj + 2) = + (jit->prg - patch_1_clrj) >> 1; + *(u16 *)(jit->prg_buf + patch_2_clij + 2) = + (jit->prg - patch_2_clij) >> 1; + *(u16 *)(jit->prg_buf + patch_3_brc + 2) = + (jit->prg - patch_3_brc) >> 1; + } break; + } case BPF_JMP | BPF_EXIT: /* return b0 */ last = (i == fp->len - 1) ? 1 : 0; if (last) diff --git a/arch/x86/include/asm/nospec-branch.h b/arch/x86/include/asm/nospec-branch.h index e7752b4038ff..e491c3d9f227 100644 --- a/arch/x86/include/asm/nospec-branch.h +++ b/arch/x86/include/asm/nospec-branch.h @@ -314,19 +314,19 @@ static inline void mds_idle_clear_cpu_buffers(void) * lfence * jmp spec_trap * do_rop: - * mov %rax,(%rsp) for x86_64 + * mov %rcx,(%rsp) for x86_64 * mov %edx,(%esp) for x86_32 * retq * * Without retpolines configured: * - * jmp *%rax for x86_64 + * jmp *%rcx for x86_64 * jmp *%edx for x86_32 */ #ifdef CONFIG_RETPOLINE # ifdef CONFIG_X86_64 -# define RETPOLINE_RAX_BPF_JIT_SIZE 17 -# define RETPOLINE_RAX_BPF_JIT() \ +# define RETPOLINE_RCX_BPF_JIT_SIZE 17 +# define RETPOLINE_RCX_BPF_JIT() \ do { \ EMIT1_off32(0xE8, 7); /* callq do_rop */ \ /* spec_trap: */ \ @@ -334,7 +334,7 @@ do { \ EMIT3(0x0F, 0xAE, 0xE8); /* lfence */ \ EMIT2(0xEB, 0xF9); /* jmp spec_trap */ \ /* do_rop: */ \ - EMIT4(0x48, 0x89, 0x04, 0x24); /* mov %rax,(%rsp) */ \ + EMIT4(0x48, 0x89, 0x0C, 0x24); /* mov %rcx,(%rsp) */ \ EMIT1(0xC3); /* retq */ \ } while (0) # else /* !CONFIG_X86_64 */ @@ -352,9 +352,9 @@ do { \ # endif #else /* !CONFIG_RETPOLINE */ # ifdef CONFIG_X86_64 -# define RETPOLINE_RAX_BPF_JIT_SIZE 2 -# define RETPOLINE_RAX_BPF_JIT() \ - EMIT2(0xFF, 0xE0); /* jmp *%rax */ +# define RETPOLINE_RCX_BPF_JIT_SIZE 2 +# define RETPOLINE_RCX_BPF_JIT() \ + EMIT2(0xFF, 0xE1); /* jmp *%rcx */ # else /* !CONFIG_X86_64 */ # define RETPOLINE_EDX_BPF_JIT() \ EMIT2(0xFF, 0xE2) /* jmp *%edx */ diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c index 7d9ea7b41c71..26f43279b78b 100644 --- a/arch/x86/net/bpf_jit_comp.c +++ b/arch/x86/net/bpf_jit_comp.c @@ -221,14 +221,48 @@ struct jit_context { /* Number of bytes emit_patch() needs to generate instructions */ #define X86_PATCH_SIZE 5 +/* Number of bytes that will be skipped on tailcall */ +#define X86_TAIL_CALL_OFFSET 11 -#define PROLOGUE_SIZE 25 +static void push_callee_regs(u8 **pprog, bool *callee_regs_used) +{ + u8 *prog = *pprog; + int cnt = 0; + + if (callee_regs_used[0]) + EMIT1(0x53); /* push rbx */ + if (callee_regs_used[1]) + EMIT2(0x41, 0x55); /* push r13 */ + if (callee_regs_used[2]) + EMIT2(0x41, 0x56); /* push r14 */ + if (callee_regs_used[3]) + EMIT2(0x41, 0x57); /* push r15 */ + *pprog = prog; +} + +static void pop_callee_regs(u8 **pprog, bool *callee_regs_used) +{ + u8 *prog = *pprog; + int cnt = 0; + + if (callee_regs_used[3]) + EMIT2(0x41, 0x5F); /* pop r15 */ + if (callee_regs_used[2]) + EMIT2(0x41, 0x5E); /* pop r14 */ + if (callee_regs_used[1]) + EMIT2(0x41, 0x5D); /* pop r13 */ + if (callee_regs_used[0]) + EMIT1(0x5B); /* pop rbx */ + *pprog = prog; +} /* - * Emit x86-64 prologue code for BPF program and check its size. - * bpf_tail_call helper will skip it while jumping into another program + * Emit x86-64 prologue code for BPF program. + * bpf_tail_call helper will skip the first X86_TAIL_CALL_OFFSET bytes + * while jumping to another program */ -static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf) +static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf, + bool tail_call_reachable, bool is_subprog) { u8 *prog = *pprog; int cnt = X86_PATCH_SIZE; @@ -238,19 +272,18 @@ static void emit_prologue(u8 **pprog, u32 stack_depth, bool ebpf_from_cbpf) */ memcpy(prog, ideal_nops[NOP_ATOMIC5], cnt); prog += cnt; + if (!ebpf_from_cbpf) { + if (tail_call_reachable && !is_subprog) + EMIT2(0x31, 0xC0); /* xor eax, eax */ + else + EMIT2(0x66, 0x90); /* nop2 */ + } EMIT1(0x55); /* push rbp */ EMIT3(0x48, 0x89, 0xE5); /* mov rbp, rsp */ /* sub rsp, rounded_stack_depth */ EMIT3_off32(0x48, 0x81, 0xEC, round_up(stack_depth, 8)); - EMIT1(0x53); /* push rbx */ - EMIT2(0x41, 0x55); /* push r13 */ - EMIT2(0x41, 0x56); /* push r14 */ - EMIT2(0x41, 0x57); /* push r15 */ - if (!ebpf_from_cbpf) { - /* zero init tail_call_cnt */ - EMIT2(0x6a, 0x00); - BUILD_BUG_ON(cnt != PROLOGUE_SIZE); - } + if (tail_call_reachable) + EMIT1(0x50); /* push rax */ *pprog = prog; } @@ -314,13 +347,14 @@ static int __bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, mutex_lock(&text_mutex); if (memcmp(ip, old_insn, X86_PATCH_SIZE)) goto out; + ret = 1; if (memcmp(ip, new_insn, X86_PATCH_SIZE)) { if (text_live) text_poke_bp(ip, new_insn, X86_PATCH_SIZE, NULL); else memcpy(ip, new_insn, X86_PATCH_SIZE); + ret = 0; } - ret = 0; out: mutex_unlock(&text_mutex); return ret; @@ -337,6 +371,22 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, return __bpf_arch_text_poke(ip, t, old_addr, new_addr, true); } +static int get_pop_bytes(bool *callee_regs_used) +{ + int bytes = 0; + + if (callee_regs_used[3]) + bytes += 2; + if (callee_regs_used[2]) + bytes += 2; + if (callee_regs_used[1]) + bytes += 2; + if (callee_regs_used[0]) + bytes += 1; + + return bytes; +} + /* * Generate the following code: * @@ -351,12 +401,26 @@ int bpf_arch_text_poke(void *ip, enum bpf_text_poke_type t, * goto *(prog->bpf_func + prologue_size); * out: */ -static void emit_bpf_tail_call_indirect(u8 **pprog) +static void emit_bpf_tail_call_indirect(u8 **pprog, bool *callee_regs_used, + u32 stack_depth) { + int tcc_off = -4 - round_up(stack_depth, 8); u8 *prog = *pprog; - int label1, label2, label3; + int pop_bytes = 0; + int off1 = 49; + int off2 = 38; + int off3 = 16; int cnt = 0; + /* count the additional bytes used for popping callee regs from stack + * that need to be taken into account for each of the offsets that + * are used for bailing out of the tail call + */ + pop_bytes = get_pop_bytes(callee_regs_used); + off1 += pop_bytes; + off2 += pop_bytes; + off3 += pop_bytes; + /* * rdi - pointer to ctx * rsi - pointer to bpf_array @@ -370,72 +434,106 @@ static void emit_bpf_tail_call_indirect(u8 **pprog) EMIT2(0x89, 0xD2); /* mov edx, edx */ EMIT3(0x39, 0x56, /* cmp dword ptr [rsi + 16], edx */ offsetof(struct bpf_array, map.max_entries)); -#define OFFSET1 (41 + RETPOLINE_RAX_BPF_JIT_SIZE) /* Number of bytes to jump */ +#define OFFSET1 (off1 + RETPOLINE_RCX_BPF_JIT_SIZE) /* Number of bytes to jump */ EMIT2(X86_JBE, OFFSET1); /* jbe out */ - label1 = cnt; /* * if (tail_call_cnt > MAX_TAIL_CALL_CNT) * goto out; */ - EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */ + EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ -#define OFFSET2 (30 + RETPOLINE_RAX_BPF_JIT_SIZE) +#define OFFSET2 (off2 + RETPOLINE_RCX_BPF_JIT_SIZE) EMIT2(X86_JA, OFFSET2); /* ja out */ - label2 = cnt; EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ - EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */ + EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ /* prog = array->ptrs[index]; */ - EMIT4_off32(0x48, 0x8B, 0x84, 0xD6, /* mov rax, [rsi + rdx * 8 + offsetof(...)] */ + EMIT4_off32(0x48, 0x8B, 0x8C, 0xD6, /* mov rcx, [rsi + rdx * 8 + offsetof(...)] */ offsetof(struct bpf_array, ptrs)); /* * if (prog == NULL) * goto out; */ - EMIT3(0x48, 0x85, 0xC0); /* test rax,rax */ -#define OFFSET3 (8 + RETPOLINE_RAX_BPF_JIT_SIZE) + EMIT3(0x48, 0x85, 0xC9); /* test rcx,rcx */ +#define OFFSET3 (off3 + RETPOLINE_RCX_BPF_JIT_SIZE) EMIT2(X86_JE, OFFSET3); /* je out */ - label3 = cnt; - /* goto *(prog->bpf_func + prologue_size); */ - EMIT4(0x48, 0x8B, 0x40, /* mov rax, qword ptr [rax + 32] */ - offsetof(struct bpf_prog, bpf_func)); - EMIT4(0x48, 0x83, 0xC0, PROLOGUE_SIZE); /* add rax, prologue_size */ + *pprog = prog; + pop_callee_regs(pprog, callee_regs_used); + prog = *pprog; + + EMIT1(0x58); /* pop rax */ + EMIT3_off32(0x48, 0x81, 0xC4, /* add rsp, sd */ + round_up(stack_depth, 8)); + /* goto *(prog->bpf_func + X86_TAIL_CALL_OFFSET); */ + EMIT4(0x48, 0x8B, 0x49, /* mov rcx, qword ptr [rcx + 32] */ + offsetof(struct bpf_prog, bpf_func)); + EMIT4(0x48, 0x83, 0xC1, /* add rcx, X86_TAIL_CALL_OFFSET */ + X86_TAIL_CALL_OFFSET); /* - * Wow we're ready to jump into next BPF program + * Now we're ready to jump into next BPF program * rdi == ctx (1st arg) - * rax == prog->bpf_func + prologue_size + * rcx == prog->bpf_func + X86_TAIL_CALL_OFFSET */ - RETPOLINE_RAX_BPF_JIT(); + RETPOLINE_RCX_BPF_JIT(); /* out: */ - BUILD_BUG_ON(cnt - label1 != OFFSET1); - BUILD_BUG_ON(cnt - label2 != OFFSET2); - BUILD_BUG_ON(cnt - label3 != OFFSET3); *pprog = prog; } static void emit_bpf_tail_call_direct(struct bpf_jit_poke_descriptor *poke, - u8 **pprog, int addr, u8 *image) + u8 **pprog, int addr, u8 *image, + bool *callee_regs_used, u32 stack_depth) { + int tcc_off = -4 - round_up(stack_depth, 8); u8 *prog = *pprog; + int pop_bytes = 0; + int off1 = 27; + int poke_off; int cnt = 0; + /* count the additional bytes used for popping callee regs to stack + * that need to be taken into account for jump offset that is used for + * bailing out from of the tail call when limit is reached + */ + pop_bytes = get_pop_bytes(callee_regs_used); + off1 += pop_bytes; + + /* + * total bytes for: + * - nop5/ jmpq $off + * - pop callee regs + * - sub rsp, $val + * - pop rax + */ + poke_off = X86_PATCH_SIZE + pop_bytes + 7 + 1; + /* * if (tail_call_cnt > MAX_TAIL_CALL_CNT) * goto out; */ - EMIT2_off32(0x8B, 0x85, -36 - MAX_BPF_STACK); /* mov eax, dword ptr [rbp - 548] */ + EMIT2_off32(0x8B, 0x85, tcc_off); /* mov eax, dword ptr [rbp - tcc_off] */ EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT); /* cmp eax, MAX_TAIL_CALL_CNT */ - EMIT2(X86_JA, 14); /* ja out */ + EMIT2(X86_JA, off1); /* ja out */ EMIT3(0x83, 0xC0, 0x01); /* add eax, 1 */ - EMIT2_off32(0x89, 0x85, -36 - MAX_BPF_STACK); /* mov dword ptr [rbp -548], eax */ + EMIT2_off32(0x89, 0x85, tcc_off); /* mov dword ptr [rbp - tcc_off], eax */ + + poke->tailcall_bypass = image + (addr - poke_off - X86_PATCH_SIZE); + poke->adj_off = X86_TAIL_CALL_OFFSET; + poke->tailcall_target = image + (addr - X86_PATCH_SIZE); + poke->bypass_addr = (u8 *)poke->tailcall_target + X86_PATCH_SIZE; - poke->ip = image + (addr - X86_PATCH_SIZE); - poke->adj_off = PROLOGUE_SIZE; + emit_jump(&prog, (u8 *)poke->tailcall_target + X86_PATCH_SIZE, + poke->tailcall_bypass); + + *pprog = prog; + pop_callee_regs(pprog, callee_regs_used); + prog = *pprog; + EMIT1(0x58); /* pop rax */ + EMIT3_off32(0x48, 0x81, 0xC4, round_up(stack_depth, 8)); memcpy(prog, ideal_nops[NOP_ATOMIC5], X86_PATCH_SIZE); prog += X86_PATCH_SIZE; @@ -453,7 +551,7 @@ static void bpf_tail_call_direct_fixup(struct bpf_prog *prog) for (i = 0; i < prog->aux->size_poke_tab; i++) { poke = &prog->aux->poke_tab[i]; - WARN_ON_ONCE(READ_ONCE(poke->ip_stable)); + WARN_ON_ONCE(READ_ONCE(poke->tailcall_target_stable)); if (poke->reason != BPF_POKE_REASON_TAIL_CALL) continue; @@ -464,18 +562,25 @@ static void bpf_tail_call_direct_fixup(struct bpf_prog *prog) if (target) { /* Plain memcpy is used when image is not live yet * and still not locked as read-only. Once poke - * location is active (poke->ip_stable), any parallel - * bpf_arch_text_poke() might occur still on the - * read-write image until we finally locked it as - * read-only. Both modifications on the given image - * are under text_mutex to avoid interference. + * location is active (poke->tailcall_target_stable), + * any parallel bpf_arch_text_poke() might occur + * still on the read-write image until we finally + * locked it as read-only. Both modifications on + * the given image are under text_mutex to avoid + * interference. */ - ret = __bpf_arch_text_poke(poke->ip, BPF_MOD_JUMP, NULL, + ret = __bpf_arch_text_poke(poke->tailcall_target, + BPF_MOD_JUMP, NULL, (u8 *)target->bpf_func + poke->adj_off, false); BUG_ON(ret < 0); + ret = __bpf_arch_text_poke(poke->tailcall_bypass, + BPF_MOD_JUMP, + (u8 *)poke->tailcall_target + + X86_PATCH_SIZE, NULL, false); + BUG_ON(ret < 0); } - WRITE_ONCE(poke->ip_stable, true); + WRITE_ONCE(poke->tailcall_target_stable, true); mutex_unlock(&array->aux->poke_mutex); } } @@ -652,19 +757,49 @@ static bool ex_handler_bpf(const struct exception_table_entry *x, return true; } +static void detect_reg_usage(struct bpf_insn *insn, int insn_cnt, + bool *regs_used, bool *tail_call_seen) +{ + int i; + + for (i = 1; i <= insn_cnt; i++, insn++) { + if (insn->code == (BPF_JMP | BPF_TAIL_CALL)) + *tail_call_seen = true; + if (insn->dst_reg == BPF_REG_6 || insn->src_reg == BPF_REG_6) + regs_used[0] = true; + if (insn->dst_reg == BPF_REG_7 || insn->src_reg == BPF_REG_7) + regs_used[1] = true; + if (insn->dst_reg == BPF_REG_8 || insn->src_reg == BPF_REG_8) + regs_used[2] = true; + if (insn->dst_reg == BPF_REG_9 || insn->src_reg == BPF_REG_9) + regs_used[3] = true; + } +} + static int do_jit(struct bpf_prog *bpf_prog, int *addrs, u8 *image, int oldproglen, struct jit_context *ctx) { + bool tail_call_reachable = bpf_prog->aux->tail_call_reachable; struct bpf_insn *insn = bpf_prog->insnsi; + bool callee_regs_used[4] = {}; int insn_cnt = bpf_prog->len; + bool tail_call_seen = false; bool seen_exit = false; u8 temp[BPF_MAX_INSN_SIZE + BPF_INSN_SAFETY]; int i, cnt = 0, excnt = 0; int proglen = 0; u8 *prog = temp; + detect_reg_usage(insn, insn_cnt, callee_regs_used, + &tail_call_seen); + + /* tail call's presence in current prog implies it is reachable */ + tail_call_reachable |= tail_call_seen; + emit_prologue(&prog, bpf_prog->aux->stack_depth, - bpf_prog_was_classic(bpf_prog)); + bpf_prog_was_classic(bpf_prog), tail_call_reachable, + bpf_prog->aux->func_idx != 0); + push_callee_regs(&prog, callee_regs_used); addrs[0] = prog - temp; for (i = 1; i <= insn_cnt; i++, insn++) { @@ -1102,16 +1237,27 @@ xadd: if (is_imm8(insn->off)) /* call */ case BPF_JMP | BPF_CALL: func = (u8 *) __bpf_call_base + imm32; - if (!imm32 || emit_call(&prog, func, image + addrs[i - 1])) - return -EINVAL; + if (tail_call_reachable) { + EMIT3_off32(0x48, 0x8B, 0x85, + -(bpf_prog->aux->stack_depth + 8)); + if (!imm32 || emit_call(&prog, func, image + addrs[i - 1] + 7)) + return -EINVAL; + } else { + if (!imm32 || emit_call(&prog, func, image + addrs[i - 1])) + return -EINVAL; + } break; case BPF_JMP | BPF_TAIL_CALL: if (imm32) emit_bpf_tail_call_direct(&bpf_prog->aux->poke_tab[imm32 - 1], - &prog, addrs[i], image); + &prog, addrs[i], image, + callee_regs_used, + bpf_prog->aux->stack_depth); else - emit_bpf_tail_call_indirect(&prog); + emit_bpf_tail_call_indirect(&prog, + callee_regs_used, + bpf_prog->aux->stack_depth); break; /* cond jump */ @@ -1294,12 +1440,9 @@ emit_jmp: seen_exit = true; /* Update cleanup_addr */ ctx->cleanup_addr = proglen; - if (!bpf_prog_was_classic(bpf_prog)) - EMIT1(0x5B); /* get rid of tail_call_cnt */ - EMIT2(0x41, 0x5F); /* pop r15 */ - EMIT2(0x41, 0x5E); /* pop r14 */ - EMIT2(0x41, 0x5D); /* pop r13 */ - EMIT1(0x5B); /* pop rbx */ + pop_callee_regs(&prog, callee_regs_used); + if (tail_call_reachable) + EMIT1(0x59); /* pop rcx, get rid of tail_call_cnt */ EMIT1(0xC9); /* leave */ EMIT1(0xC3); /* ret */ break; |