diff options
Diffstat (limited to 'arch/x86/net')
-rw-r--r-- | arch/x86/net/bpf_jit_comp.c | 50 |
1 files changed, 49 insertions, 1 deletions
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c index 308241187582..1d4d50199293 100644 --- a/arch/x86/net/bpf_jit_comp.c +++ b/arch/x86/net/bpf_jit_comp.c @@ -808,6 +808,10 @@ static int emit_atomic(u8 **pprog, u8 atomic_op, /* emit opcode */ switch (atomic_op) { case BPF_ADD: + case BPF_SUB: + case BPF_AND: + case BPF_OR: + case BPF_XOR: /* lock *(u32/u64*)(dst_reg + off) <op>= src_reg */ EMIT1(simple_alu_opcodes[atomic_op]); break; @@ -1292,8 +1296,52 @@ st: if (is_imm8(insn->off)) case BPF_STX | BPF_ATOMIC | BPF_W: case BPF_STX | BPF_ATOMIC | BPF_DW: + if (insn->imm == (BPF_AND | BPF_FETCH) || + insn->imm == (BPF_OR | BPF_FETCH) || + insn->imm == (BPF_XOR | BPF_FETCH)) { + u8 *branch_target; + bool is64 = BPF_SIZE(insn->code) == BPF_DW; + + /* + * Can't be implemented with a single x86 insn. + * Need to do a CMPXCHG loop. + */ + + /* Will need RAX as a CMPXCHG operand so save R0 */ + emit_mov_reg(&prog, true, BPF_REG_AX, BPF_REG_0); + branch_target = prog; + /* Load old value */ + emit_ldx(&prog, BPF_SIZE(insn->code), + BPF_REG_0, dst_reg, insn->off); + /* + * Perform the (commutative) operation locally, + * put the result in the AUX_REG. + */ + emit_mov_reg(&prog, is64, AUX_REG, BPF_REG_0); + maybe_emit_mod(&prog, AUX_REG, src_reg, is64); + EMIT2(simple_alu_opcodes[BPF_OP(insn->imm)], + add_2reg(0xC0, AUX_REG, src_reg)); + /* Attempt to swap in new value */ + err = emit_atomic(&prog, BPF_CMPXCHG, + dst_reg, AUX_REG, insn->off, + BPF_SIZE(insn->code)); + if (WARN_ON(err)) + return err; + /* + * ZF tells us whether we won the race. If it's + * cleared we need to try again. + */ + EMIT2(X86_JNE, -(prog - branch_target) - 2); + /* Return the pre-modification value */ + emit_mov_reg(&prog, is64, src_reg, BPF_REG_0); + /* Restore R0 after clobbering RAX */ + emit_mov_reg(&prog, true, BPF_REG_0, BPF_REG_AX); + break; + + } + err = emit_atomic(&prog, insn->imm, dst_reg, src_reg, - insn->off, BPF_SIZE(insn->code)); + insn->off, BPF_SIZE(insn->code)); if (err) return err; break; |