From d112cd52f36cbb00e18009417044ab1e4496dd80 Mon Sep 17 00:00:00 2001 From: Jakub Konka Date: Wed, 4 May 2022 23:00:41 +0200 Subject: [PATCH] aarch64: fix mul_with_overflow for ints <= 32bits --- src/arch/aarch64/CodeGen.zig | 76 ++++++++++++++++----------- src/arch/aarch64/Emit.zig | 55 ++++++++++++++++++-- src/arch/aarch64/Mir.zig | 29 +++++++++++ src/arch/aarch64/bits.zig | 99 +++++++++++++++++++++++++++++++++++- 4 files changed, 223 insertions(+), 36 deletions(-) diff --git a/src/arch/aarch64/CodeGen.zig b/src/arch/aarch64/CodeGen.zig index fdf7eadb73..d146724188 100644 --- a/src/arch/aarch64/CodeGen.zig +++ b/src/arch/aarch64/CodeGen.zig @@ -1294,29 +1294,23 @@ fn binOpRegister( }; defer self.register_manager.unfreezeRegs(&.{rhs_reg}); - const dest_reg: Register = reg: { - const dest_reg = switch (mir_tag) { - .cmp_shifted_register => undefined, // cmp has no destination register - else => if (maybe_inst) |inst| blk: { - const bin_op = self.air.instructions.items(.data)[inst].bin_op; + const dest_reg = switch (mir_tag) { + .cmp_shifted_register => undefined, // cmp has no destination register + else => if (maybe_inst) |inst| blk: { + const bin_op = self.air.instructions.items(.data)[inst].bin_op; - if (lhs_is_register and self.reuseOperand(inst, bin_op.lhs, 0, lhs)) { - break :blk lhs_reg; - } else if (rhs_is_register and self.reuseOperand(inst, bin_op.rhs, 1, rhs)) { - break :blk rhs_reg; - } else { - const raw_reg = try self.register_manager.allocReg(inst); - break :blk registerAlias(raw_reg, lhs_ty.abiSize(self.target.*)); - } - } else blk: { - const raw_reg = try self.register_manager.allocReg(null); + if (lhs_is_register and self.reuseOperand(inst, bin_op.lhs, 0, lhs)) { + break :blk lhs_reg; + } else if (rhs_is_register and self.reuseOperand(inst, bin_op.rhs, 1, rhs)) { + break :blk rhs_reg; + } else { + const raw_reg = try self.register_manager.allocReg(inst); break :blk registerAlias(raw_reg, lhs_ty.abiSize(self.target.*)); - }, - }; - break :reg switch (mir_tag) { - .smull, .umull => dest_reg.to64(), - else => dest_reg, - }; + } + } else blk: { + const raw_reg = try self.register_manager.allocReg(null); + break :blk registerAlias(raw_reg, lhs_ty.abiSize(self.target.*)); + }, }; if (!lhs_is_register) try self.genSetReg(lhs_ty, lhs_reg, lhs); @@ -1341,9 +1335,7 @@ fn binOpRegister( .shift = .lsl, } }, .mul, - .smulh, .smull, - .umulh, .umull, .lsl_register, .asr_register, @@ -1932,16 +1924,38 @@ fn airMulWithOverflow(self: *Self, inst: Air.Inst.Index) !void { self.register_manager.freezeRegs(&.{truncated_reg}); defer self.register_manager.unfreezeRegs(&.{truncated_reg}); - try self.truncRegister(dest_reg, truncated_reg, int_info.signedness, int_info.bits); - _ = try self.binOp( - .cmp_eq, - null, - dest, - .{ .register = truncated_reg }, - Type.usize, - Type.usize, + try self.truncRegister( + dest_reg.to32(), + truncated_reg.to32(), + int_info.signedness, + int_info.bits, ); + switch (int_info.signedness) { + .signed => { + _ = try self.addInst(.{ + .tag = .cmp_extended_register, + .data = .{ .rr_extend_shift = .{ + .rn = dest_reg.to64(), + .rm = truncated_reg.to32(), + .ext_type = .sxtw, + .imm3 = 0, + } }, + }); + }, + .unsigned => { + _ = try self.addInst(.{ + .tag = .cmp_extended_register, + .data = .{ .rr_extend_shift = .{ + .rn = dest_reg.to64(), + .rm = truncated_reg.to32(), + .ext_type = .uxtw, + .imm3 = 0, + } }, + }); + }, + } + try self.genSetStack(lhs_ty, stack_offset, .{ .register = truncated_reg }); try self.genSetStack(Type.initTag(.u1), stack_offset - overflow_bit_offset, .{ .compare_flags_unsigned = .neq, diff --git a/src/arch/aarch64/Emit.zig b/src/arch/aarch64/Emit.zig index 1393533a7f..959ca4037c 100644 --- a/src/arch/aarch64/Emit.zig +++ b/src/arch/aarch64/Emit.zig @@ -114,6 +114,12 @@ pub fn emitMir( .sub_shifted_register => try emit.mirAddSubtractShiftedRegister(inst), .subs_shifted_register => try emit.mirAddSubtractShiftedRegister(inst), + .add_extended_register => try emit.mirAddSubtractExtendedRegister(inst), + .adds_extended_register => try emit.mirAddSubtractExtendedRegister(inst), + .sub_extended_register => try emit.mirAddSubtractExtendedRegister(inst), + .subs_extended_register => try emit.mirAddSubtractExtendedRegister(inst), + .cmp_extended_register => try emit.mirAddSubtractExtendedRegister(inst), + .cset => try emit.mirConditionalSelect(inst), .dbg_line => try emit.mirDbgLine(inst), @@ -732,6 +738,47 @@ fn mirAddSubtractShiftedRegister(emit: *Emit, inst: Mir.Inst.Index) !void { } } +fn mirAddSubtractExtendedRegister(emit: *Emit, inst: Mir.Inst.Index) !void { + const tag = emit.mir.instructions.items(.tag)[inst]; + switch (tag) { + .add_extended_register, + .adds_extended_register, + .sub_extended_register, + .subs_extended_register, + => { + const rrr_extend_shift = emit.mir.instructions.items(.data)[inst].rrr_extend_shift; + const rd = rrr_extend_shift.rd; + const rn = rrr_extend_shift.rn; + const rm = rrr_extend_shift.rm; + const ext_type = rrr_extend_shift.ext_type; + const imm3 = rrr_extend_shift.imm3; + + switch (tag) { + .add_extended_register => try emit.writeInstruction(Instruction.addExtendedRegister(rd, rn, rm, ext_type, imm3)), + .adds_extended_register => try emit.writeInstruction(Instruction.addsExtendedRegister(rd, rn, rm, ext_type, imm3)), + .sub_extended_register => try emit.writeInstruction(Instruction.subExtendedRegister(rd, rn, rm, ext_type, imm3)), + .subs_extended_register => try emit.writeInstruction(Instruction.subsExtendedRegister(rd, rn, rm, ext_type, imm3)), + else => unreachable, + } + }, + .cmp_extended_register => { + const rr_extend_shift = emit.mir.instructions.items(.data)[inst].rr_extend_shift; + const rn = rr_extend_shift.rn; + const rm = rr_extend_shift.rm; + const ext_type = rr_extend_shift.ext_type; + const imm3 = rr_extend_shift.imm3; + const zr: Register = switch (rn.size()) { + 32 => .wzr, + 64 => .xzr, + else => unreachable, + }; + + try emit.writeInstruction(Instruction.subsExtendedRegister(zr, rn, rm, ext_type, imm3)); + }, + else => unreachable, + } +} + fn mirConditionalSelect(emit: *Emit, inst: Mir.Inst.Index) !void { const tag = emit.mir.instructions.items(.tag)[inst]; switch (tag) { @@ -1013,10 +1060,10 @@ fn mirDataProcessing3Source(emit: *Emit, inst: Mir.Inst.Index) !void { switch (tag) { .mul => try emit.writeInstruction(Instruction.mul(rrr.rd, rrr.rn, rrr.rm)), - .smulh => try emit.writeInstruction(Instruction.smulh(rrr.rd, rrr.rn, rrr.rm)), - .smull => try emit.writeInstruction(Instruction.smull(rrr.rd, rrr.rn, rrr.rm)), - .umulh => try emit.writeInstruction(Instruction.umulh(rrr.rd, rrr.rn, rrr.rm)), - .umull => try emit.writeInstruction(Instruction.umull(rrr.rd, rrr.rn, rrr.rm)), + .smulh => try emit.writeInstruction(Instruction.smulh(rrr.rd.to64(), rrr.rn.to64(), rrr.rm.to64())), + .smull => try emit.writeInstruction(Instruction.smull(rrr.rd.to64(), rrr.rn.to32(), rrr.rm.to32())), + .umulh => try emit.writeInstruction(Instruction.umulh(rrr.rd.to64(), rrr.rn.to64(), rrr.rm.to64())), + .umull => try emit.writeInstruction(Instruction.umull(rrr.rd.to64(), rrr.rn.to32(), rrr.rm.to32())), else => unreachable, } } diff --git a/src/arch/aarch64/Mir.zig b/src/arch/aarch64/Mir.zig index 1b27303419..1d66a69c8e 100644 --- a/src/arch/aarch64/Mir.zig +++ b/src/arch/aarch64/Mir.zig @@ -32,6 +32,10 @@ pub const Inst = struct { add_shifted_register, /// Add, update condition flags (shifted register) adds_shifted_register, + /// Add (extended register) + add_extended_register, + /// Add, update condition flags (extended register) + adds_extended_register, /// Bitwise AND (shifted register) and_shifted_register, /// Arithmetic Shift Right (immediate) @@ -56,6 +60,8 @@ pub const Inst = struct { cmp_immediate, /// Compare (shifted register) cmp_shifted_register, + /// Compare (extended register) + cmp_extended_register, /// Conditional set cset, /// Pseudo-instruction: End of prologue @@ -184,6 +190,10 @@ pub const Inst = struct { sub_shifted_register, /// Subtract, update condition flags (shifted register) subs_shifted_register, + /// Subtract (extended register) + sub_extended_register, + /// Subtract, update condition flags (extended register) + subs_extended_register, /// Supervisor Call svc, /// Test bits (immediate) @@ -300,6 +310,15 @@ pub const Inst = struct { imm6: u6, shift: bits.Instruction.AddSubtractShiftedRegisterShift, }, + /// Two registers with sign-extension (extension type and 3-bit shift amount) + /// + /// Used by e.g. cmp_extended_register + rr_extend_shift: struct { + rn: Register, + rm: Register, + ext_type: bits.Instruction.AddSubtractExtendedRegisterOption, + imm3: u3, + }, /// Two registers and a shift (logical instruction version) /// (shift type and 6-bit amount) /// @@ -356,6 +375,16 @@ pub const Inst = struct { imm6: u6, shift: bits.Instruction.AddSubtractShiftedRegisterShift, }, + /// Three registers with sign-extension (extension type and 3-bit shift amount) + /// + /// Used by e.g. add_extended_register + rrr_extend_shift: struct { + rd: Register, + rn: Register, + rm: Register, + ext_type: bits.Instruction.AddSubtractExtendedRegisterOption, + imm3: u3, + }, /// Three registers and a shift (logical instruction version) /// (shift type and 6-bit amount) /// diff --git a/src/arch/aarch64/bits.zig b/src/arch/aarch64/bits.zig index d8cb868d66..a3f5fbac51 100644 --- a/src/arch/aarch64/bits.zig +++ b/src/arch/aarch64/bits.zig @@ -330,6 +330,17 @@ pub const Instruction = union(enum) { op: u1, sf: u1, }, + add_subtract_extended_register: packed struct { + rd: u5, + rn: u5, + imm3: u3, + option: u3, + rm: u5, + fixed: u8 = 0b01011_00_1, + s: u1, + op: u1, + sf: u1, + }, conditional_branch: struct { cond: u4, o0: u1, @@ -495,6 +506,7 @@ pub const Instruction = union(enum) { .logical_immediate => |v| @bitCast(u32, v), .bitfield => |v| @bitCast(u32, v), .add_subtract_shifted_register => |v| @bitCast(u32, v), + .add_subtract_extended_register => |v| @bitCast(u32, v), // TODO once packed structs work, this can be refactored .conditional_branch => |v| @as(u32, v.cond) | (@as(u32, v.o0) << 4) | (@as(u32, v.imm19) << 5) | (@as(u32, v.o1) << 24) | (@as(u32, v.fixed) << 25), .compare_and_branch => |v| @as(u32, v.rt) | (@as(u32, v.imm19) << 5) | (@as(u32, v.op) << 24) | (@as(u32, v.fixed) << 25) | (@as(u32, v.sf) << 31), @@ -1006,6 +1018,44 @@ pub const Instruction = union(enum) { }; } + pub const AddSubtractExtendedRegisterOption = enum(u3) { + uxtb, + uxth, + uxtw, + uxtx, // serves also as lsl + sxtb, + sxth, + sxtw, + sxtx, + }; + + fn addSubtractExtendedRegister( + op: u1, + s: u1, + rd: Register, + rn: Register, + rm: Register, + extend: AddSubtractExtendedRegisterOption, + imm3: u3, + ) Instruction { + return Instruction{ + .add_subtract_extended_register = .{ + .rd = rd.enc(), + .rn = rn.enc(), + .imm3 = imm3, + .option = @enumToInt(extend), + .rm = rm.enc(), + .s = s, + .op = op, + .sf = switch (rd.size()) { + 32 => 0b0, + 64 => 0b1, + else => unreachable, // unexpected register size + }, + }, + }; + } + fn conditionalBranch( o0: u1, o1: u1, @@ -1524,6 +1574,48 @@ pub const Instruction = union(enum) { return addSubtractShiftedRegister(0b1, 0b1, shift, rd, rn, rm, imm6); } + // Add/subtract (extended register) + + pub fn addExtendedRegister( + rd: Register, + rn: Register, + rm: Register, + extend: AddSubtractExtendedRegisterOption, + imm3: u3, + ) Instruction { + return addSubtractExtendedRegister(0b0, 0b0, rd, rn, rm, extend, imm3); + } + + pub fn addsExtendedRegister( + rd: Register, + rn: Register, + rm: Register, + extend: AddSubtractExtendedRegisterOption, + imm3: u3, + ) Instruction { + return addSubtractExtendedRegister(0b0, 0b1, rd, rn, rm, extend, imm3); + } + + pub fn subExtendedRegister( + rd: Register, + rn: Register, + rm: Register, + extend: AddSubtractExtendedRegisterOption, + imm3: u3, + ) Instruction { + return addSubtractExtendedRegister(0b1, 0b0, rd, rn, rm, extend, imm3); + } + + pub fn subsExtendedRegister( + rd: Register, + rn: Register, + rm: Register, + extend: AddSubtractExtendedRegisterOption, + imm3: u3, + ) Instruction { + return addSubtractExtendedRegister(0b1, 0b1, rd, rn, rm, extend, imm3); + } + // Conditional branch pub fn bCond(cond: Condition, offset: i21) Instruction { @@ -1565,11 +1657,12 @@ pub const Instruction = union(enum) { } pub fn smaddl(rd: Register, rn: Register, rm: Register, ra: Register) Instruction { + assert(rd.size() == 64 and rn.size() == 32 and rm.size() == 32 and ra.size() == 64); return dataProcessing3Source(0b00, 0b001, 0b0, rd, rn, rm, ra); } pub fn umaddl(rd: Register, rn: Register, rm: Register, ra: Register) Instruction { - assert(rd.size() == 64); + assert(rd.size() == 64 and rn.size() == 32 and rm.size() == 32 and ra.size() == 64); return dataProcessing3Source(0b00, 0b101, 0b0, rd, rn, rm, ra); } @@ -1837,6 +1930,10 @@ test "serialize instructions" { .inst = Instruction.smulh(.x0, .x1, .x2), .expected = 0b1_00_11011_0_10_00010_0_11111_00001_00000, }, + .{ // adds x0, x1, x2, sxtx + .inst = Instruction.addsExtendedRegister(.x0, .x1, .x2, .sxtx, 0), + .expected = 0b1_0_1_01011_00_1_00010_111_000_00001_00000, + }, }; for (testcases) |case| {