x86_64: implement @mulAdd

Jacob Young
2023-05-06 20:31:48 -04:00
parent 0bd92da0e2
commit 3a5e3c52e0
7 changed files with 277 additions and 14 deletions
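For reference, @mulAdd(T, a, b, c) is Zig's fused multiply-add builtin: it computes a * b + c with a single rounding. The sketch below is an illustration only, not part of the diff; it shows the operation this commit teaches the self-hosted x86_64 backend to lower to a single vfmadd* instruction when the target has the fma feature.

const std = @import("std");

// Minimal usage sketch of the builtin implemented by this commit.
test "fused multiply-add via @mulAdd" {
    const a: f64 = 1.5;
    const b: f64 = 2.0;
    const c: f64 = 0.25;
    // a * b + c, rounded once.
    try std.testing.expectEqual(@as(f64, 3.25), @mulAdd(f64, a, b, c));
}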

@@ -1200,6 +1200,32 @@ fn asmRegisterRegisterImmediate(
});
}
fn asmRegisterRegisterMemory(
self: *Self,
tag: Mir.Inst.Tag,
reg1: Register,
reg2: Register,
m: Memory,
) !void {
_ = try self.addInst(.{
.tag = tag,
.ops = switch (m) {
.sib => .rrm_sib,
.rip => .rrm_rip,
else => unreachable,
},
.data = .{ .rrx = .{
.r1 = reg1,
.r2 = reg2,
.payload = switch (m) {
.sib => try self.addExtra(Mir.MemorySib.encode(m)),
.rip => try self.addExtra(Mir.MemoryRip.encode(m)),
else => unreachable,
},
} },
});
}
fn asmMemory(self: *Self, tag: Mir.Inst.Tag, m: Memory) !void {
_ = try self.addInst(.{
.tag = tag,
@@ -9369,9 +9395,146 @@ fn airPrefetch(self: *Self, inst: Air.Inst.Index) !void {
fn airMulAdd(self: *Self, inst: Air.Inst.Index) !void {
const pl_op = self.air.instructions.items(.data)[inst].pl_op;
const extra = self.air.extraData(Air.Bin, pl_op.payload).data;
_ = extra;
return self.fail("TODO implement airMulAdd for x86_64", .{});
//return self.finishAir(inst, result, .{ extra.lhs, extra.rhs, pl_op.operand });
const ty = self.air.typeOfIndex(inst);
if (!self.hasFeature(.fma)) return self.fail("TODO implement airMulAdd for {}", .{
ty.fmt(self.bin_file.options.module.?),
});
const ops = [3]Air.Inst.Ref{ extra.lhs, extra.rhs, pl_op.operand };
var mcvs: [3]MCValue = undefined;
var locks = [1]?RegisterManager.RegisterLock{null} ** 3;
defer for (locks) |reg_lock| if (reg_lock) |lock| self.register_manager.unlockReg(lock);
var order = [1]u2{0} ** 3;
var unused = std.StaticBitSet(3).initFull();
for (ops, &mcvs, &locks, 0..) |op, *mcv, *lock, op_i| {
const op_index = @intCast(u2, op_i);
mcv.* = try self.resolveInst(op);
if (unused.isSet(0) and mcv.isRegister() and self.reuseOperand(inst, op, op_index, mcv.*)) {
order[op_index] = 1;
unused.unset(0);
} else if (unused.isSet(2) and mcv.isMemory()) {
order[op_index] = 3;
unused.unset(2);
}
switch (mcv.*) {
.register => |reg| lock.* = self.register_manager.lockReg(reg),
else => {},
}
}
for (&order, &mcvs, &locks) |*mop_index, *mcv, *lock| {
if (mop_index.* != 0) continue;
mop_index.* = 1 + @intCast(u2, unused.toggleFirstSet().?);
if (mop_index.* > 1 and mcv.isRegister()) continue;
const reg = try self.copyToTmpRegister(ty, mcv.*);
mcv.* = .{ .register = reg };
if (lock.*) |old_lock| self.register_manager.unlockReg(old_lock);
lock.* = self.register_manager.lockRegAssumeUnused(reg);
}
const tag: ?Mir.Inst.Tag =
if (mem.eql(u2, &order, &.{ 1, 3, 2 }) or mem.eql(u2, &order, &.{ 3, 1, 2 }))
switch (ty.zigTypeTag()) {
.Float => switch (ty.floatBits(self.target.*)) {
32 => .vfmadd132ss,
64 => .vfmadd132sd,
else => null,
},
.Vector => switch (ty.childType().zigTypeTag()) {
.Float => switch (ty.childType().floatBits(self.target.*)) {
32 => switch (ty.vectorLen()) {
1 => .vfmadd132ss,
2...8 => .vfmadd132ps,
else => null,
},
64 => switch (ty.vectorLen()) {
1 => .vfmadd132sd,
2...4 => .vfmadd132pd,
else => null,
},
else => null,
},
else => null,
},
else => unreachable,
}
else if (mem.eql(u2, &order, &.{ 2, 1, 3 }) or mem.eql(u2, &order, &.{ 1, 2, 3 }))
switch (ty.zigTypeTag()) {
.Float => switch (ty.floatBits(self.target.*)) {
32 => .vfmadd213ss,
64 => .vfmadd213sd,
else => null,
},
.Vector => switch (ty.childType().zigTypeTag()) {
.Float => switch (ty.childType().floatBits(self.target.*)) {
32 => switch (ty.vectorLen()) {
1 => .vfmadd213ss,
2...8 => .vfmadd213ps,
else => null,
},
64 => switch (ty.vectorLen()) {
1 => .vfmadd213sd,
2...4 => .vfmadd213pd,
else => null,
},
else => null,
},
else => null,
},
else => unreachable,
}
else if (mem.eql(u2, &order, &.{ 2, 3, 1 }) or mem.eql(u2, &order, &.{ 3, 2, 1 }))
switch (ty.zigTypeTag()) {
.Float => switch (ty.floatBits(self.target.*)) {
32 => .vfmadd231ss,
64 => .vfmadd231sd,
else => null,
},
.Vector => switch (ty.childType().zigTypeTag()) {
.Float => switch (ty.childType().floatBits(self.target.*)) {
32 => switch (ty.vectorLen()) {
1 => .vfmadd231ss,
2...8 => .vfmadd231ps,
else => null,
},
64 => switch (ty.vectorLen()) {
1 => .vfmadd231sd,
2...4 => .vfmadd231pd,
else => null,
},
else => null,
},
else => null,
},
else => null,
}
else
unreachable;
if (tag == null) return self.fail("TODO implement airMulAdd for {}", .{
ty.fmt(self.bin_file.options.module.?),
});
var mops: [3]MCValue = undefined;
for (order, mcvs) |mop_index, mcv| mops[mop_index - 1] = mcv;
const abi_size = @intCast(u32, ty.abiSize(self.target.*));
const mop1_reg = registerAlias(mops[0].getReg().?, abi_size);
const mop2_reg = registerAlias(mops[1].getReg().?, abi_size);
if (mops[2].isRegister())
try self.asmRegisterRegisterRegister(
tag.?,
mop1_reg,
mop2_reg,
registerAlias(mops[2].getReg().?, abi_size),
)
else
try self.asmRegisterRegisterMemory(
tag.?,
mop1_reg,
mop2_reg,
mops[2].mem(Memory.PtrSize.fromSize(abi_size)),
);
return self.finishAir(inst, mops[0], ops);
}
fn resolveInst(self: *Self, ref: Air.Inst.Ref) InnerError!MCValue {
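For reference, the mnemonic selection in airMulAdd above follows the VFMADD naming convention: mop1 is the destination register, mop3 is the slot that may be a memory operand, and the digits 132/213/231 say which slots are multiplied and which is added, so the order array records which slot each @mulAdd operand (lhs, rhs, addend) landed in. The following is an illustration only, not part of the diff, with the three forms written out as plain Zig expressions:

const std = @import("std");
const expectEqual = std.testing.expectEqual;

// Illustration of the three VFMADD operand-order forms. mop1 is the
// destination register, mop3 is the slot that may come from memory.
fn fmadd132(mop1: f64, mop2: f64, mop3: f64) f64 {
    return mop1 * mop3 + mop2; // order {1, 3, 2}: lhs * rhs + addend
}
fn fmadd213(mop1: f64, mop2: f64, mop3: f64) f64 {
    return mop2 * mop1 + mop3; // order {2, 1, 3}: lhs * rhs + addend
}
fn fmadd231(mop1: f64, mop2: f64, mop3: f64) f64 {
    return mop2 * mop3 + mop1; // order {2, 3, 1}: lhs * rhs + addend
}

test "all three VFMADD forms compute lhs * rhs + addend" {
    const a: f64 = 3.0; // lhs
    const b: f64 = 4.0; // rhs
    const c: f64 = 5.0; // addend
    try expectEqual(@as(f64, 17.0), fmadd132(a, c, b)); // a -> mop1, c -> mop2, b -> mop3
    try expectEqual(@as(f64, 17.0), fmadd213(b, a, c)); // b -> mop1, a -> mop2, c -> mop3
    try expectEqual(@as(f64, 17.0), fmadd231(c, a, b)); // c -> mop1, a -> mop2, b -> mop3
}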

@@ -340,6 +340,11 @@ pub const Mnemonic = enum {
vpunpcklbw, vpunpckldq, vpunpcklqdq, vpunpcklwd,
// F16C
vcvtph2ps, vcvtps2ph,
// FMA
vfmadd132pd, vfmadd213pd, vfmadd231pd,
vfmadd132ps, vfmadd213ps, vfmadd231ps,
vfmadd132sd, vfmadd213sd, vfmadd231sd,
vfmadd132ss, vfmadd213ss, vfmadd231ss,
// zig fmt: on
};
@@ -368,12 +373,13 @@ pub const Op = enum {
r8, r16, r32, r64,
rm8, rm16, rm32, rm64,
r32_m16, r64_m16,
m8, m16, m32, m64, m80, m128,
m8, m16, m32, m64, m80, m128, m256,
rel8, rel16, rel32,
m,
moffs,
sreg,
xmm, xmm_m32, xmm_m64, xmm_m128,
ymm, ymm_m256,
// zig fmt: on
pub fn fromOperand(operand: Instruction.Operand) Op {
@@ -385,6 +391,7 @@ pub const Op = enum {
.segment => return .sreg,
.floating_point => return switch (reg.bitSize()) {
128 => .xmm,
256 => .ymm,
else => unreachable,
},
.general_purpose => {
@@ -418,6 +425,7 @@ pub const Op = enum {
64 => .m64,
80 => .m80,
128 => .m128,
256 => .m256,
else => unreachable,
};
},
@@ -454,7 +462,8 @@ pub const Op = enum {
.eax, .r32, .rm32, .r32_m16 => unreachable,
.rax, .r64, .rm64, .r64_m16 => unreachable,
.xmm, .xmm_m32, .xmm_m64, .xmm_m128 => unreachable,
.m8, .m16, .m32, .m64, .m80, .m128 => unreachable,
.ymm, .ymm_m256 => unreachable,
.m8, .m16, .m32, .m64, .m80, .m128, .m256 => unreachable,
.unity => 1,
.imm8, .imm8s, .rel8 => 8,
.imm16, .imm16s, .rel16 => 16,
@@ -468,12 +477,13 @@ pub const Op = enum {
.none, .o16, .o32, .o64, .moffs, .m, .sreg => unreachable,
.unity, .imm8, .imm8s, .imm16, .imm16s, .imm32, .imm32s, .imm64 => unreachable,
.rel8, .rel16, .rel32 => unreachable,
.m8, .m16, .m32, .m64, .m80, .m128 => unreachable,
.m8, .m16, .m32, .m64, .m80, .m128, .m256 => unreachable,
.al, .cl, .r8, .rm8 => 8,
.ax, .r16, .rm16 => 16,
.eax, .r32, .rm32, .r32_m16 => 32,
.rax, .r64, .rm64, .r64_m16 => 64,
.xmm, .xmm_m32, .xmm_m64, .xmm_m128 => 128,
.ymm, .ymm_m256 => 256,
};
}
@@ -482,13 +492,14 @@ pub const Op = enum {
.none, .o16, .o32, .o64, .moffs, .m, .sreg => unreachable,
.unity, .imm8, .imm8s, .imm16, .imm16s, .imm32, .imm32s, .imm64 => unreachable,
.rel8, .rel16, .rel32 => unreachable,
.al, .cl, .r8, .ax, .r16, .eax, .r32, .rax, .r64, .xmm => unreachable,
.al, .cl, .r8, .ax, .r16, .eax, .r32, .rax, .r64, .xmm, .ymm => unreachable,
.m8, .rm8 => 8,
.m16, .rm16, .r32_m16, .r64_m16 => 16,
.m32, .rm32, .xmm_m32 => 32,
.m64, .rm64, .xmm_m64 => 64,
.m80 => 80,
.m128, .xmm_m128 => 128,
.m256, .ymm_m256 => 256,
};
}
@@ -513,6 +524,7 @@ pub const Op = enum {
.rm8, .rm16, .rm32, .rm64,
.r32_m16, .r64_m16,
.xmm, .xmm_m32, .xmm_m64, .xmm_m128,
.ymm, .ymm_m256,
=> true,
else => false,
};
@@ -539,7 +551,7 @@ pub const Op = enum {
.r32_m16, .r64_m16,
.m8, .m16, .m32, .m64, .m80, .m128,
.m,
.xmm_m32, .xmm_m64, .xmm_m128,
.xmm_m32, .xmm_m64, .xmm_m128, .ymm_m256,
=> true,
else => false,
};
@@ -562,6 +574,7 @@ pub const Op = enum {
.r32_m16, .r64_m16 => .general_purpose,
.sreg => .segment,
.xmm, .xmm_m32, .xmm_m64, .xmm_m128 => .floating_point,
.ymm, .ymm_m256 => .floating_point,
};
}
@@ -625,6 +638,7 @@ pub const Feature = enum {
none,
avx,
f16c,
fma,
sse,
sse2,
sse3,

@@ -205,6 +205,19 @@ pub fn lowerMir(lower: *Lower, index: Mir.Inst.Index) Error!struct {
.vcvtph2ps,
.vcvtps2ph,
.vfmadd132pd,
.vfmadd213pd,
.vfmadd231pd,
.vfmadd132ps,
.vfmadd213ps,
.vfmadd231ps,
.vfmadd132sd,
.vfmadd213sd,
.vfmadd231sd,
.vfmadd132ss,
.vfmadd213ss,
.vfmadd231ss,
=> try lower.mirGeneric(inst),
.cmps,
@@ -288,6 +301,8 @@ fn imm(lower: Lower, ops: Mir.Inst.Ops, i: u32) Immediate {
.rmi_rip,
.mri_sib,
.mri_rip,
.rrm_sib,
.rrm_rip,
.rrmi_sib,
.rrmi_rip,
=> Immediate.u(i),
@@ -310,6 +325,7 @@ fn mem(lower: Lower, ops: Mir.Inst.Ops, payload: u32) Memory {
.mr_sib,
.mrr_sib,
.mri_sib,
.rrm_sib,
.rrmi_sib,
.lock_m_sib,
.lock_mi_sib_u,
@@ -327,6 +343,7 @@ fn mem(lower: Lower, ops: Mir.Inst.Ops, payload: u32) Memory {
.mr_rip,
.mrr_rip,
.mri_rip,
.rrm_rip,
.rrmi_rip,
.lock_m_rip,
.lock_mi_rip_u,
@@ -449,6 +466,11 @@ fn mirGeneric(lower: *Lower, inst: Mir.Inst) Error!void {
.{ .reg = inst.data.rix.r },
.{ .imm = lower.imm(inst.ops, inst.data.rix.i) },
},
.rrm_sib, .rrm_rip => &.{
.{ .reg = inst.data.rrx.r1 },
.{ .reg = inst.data.rrx.r2 },
.{ .mem = lower.mem(inst.ops, inst.data.rrx.payload) },
},
.rrmi_sib, .rrmi_rip => &.{
.{ .reg = inst.data.rrix.r1 },
.{ .reg = inst.data.rrix.r2 },

@@ -324,6 +324,31 @@ pub const Inst = struct {
/// Convert single-precision floating-point values to 16-bit floating-point values
vcvtps2ph,
/// Fused multiply-add of packed double-precision floating-point values
vfmadd132pd,
/// Fused multiply-add of packed double-precision floating-point values
vfmadd213pd,
/// Fused multiply-add of packed double-precision floating-point values
vfmadd231pd,
/// Fused multiply-add of packed single-precision floating-point values
vfmadd132ps,
/// Fused multiply-add of packed single-precision floating-point values
vfmadd213ps,
/// Fused multiply-add of packed single-precision floating-point values
vfmadd231ps,
/// Fused multiply-add of scalar double-precision floating-point values
vfmadd132sd,
/// Fused multiply-add of scalar double-precision floating-point values
vfmadd213sd,
/// Fused multiply-add of scalar double-precision floating-point values
vfmadd231sd,
/// Fused multiply-add of scalar single-precision floating-point values
vfmadd132ss,
/// Fused multiply-add of scalar single-precision floating-point values
vfmadd213ss,
/// Fused multiply-add of scalar single-precision floating-point values
vfmadd231ss,
/// Compare string operands
cmps,
/// Load string
@@ -434,6 +459,12 @@ pub const Inst = struct {
/// Register, memory (SIB), immediate (byte) operands.
/// Uses `rix` payload with extra data of type `MemorySib`.
rmi_sib,
/// Register, register, memory (RIP) operands.
/// Uses `rrx` payload with extra data of type `MemoryRip`.
rrm_rip,
/// Register, register, memory (SIB) operands.
/// Uses `rrx` payload with extra data of type `MemorySib`.
rrm_sib,
/// Register, register, memory (RIP), immediate (byte) operands.
/// Uses `rrix` payload with extra data of type `MemoryRip`.
rrmi_rip,

@@ -485,7 +485,9 @@ pub const Memory = union(enum) {
dword,
qword,
tbyte,
dqword,
xword,
yword,
zword,
pub fn fromSize(size: u32) PtrSize {
return switch (size) {
@@ -493,7 +495,9 @@ pub const Memory = union(enum) {
2...2 => .word,
3...4 => .dword,
5...8 => .qword,
9...16 => .dqword,
9...16 => .xword,
17...32 => .yword,
33...64 => .zword,
else => unreachable,
};
}
@@ -505,7 +509,9 @@ pub const Memory = union(enum) {
32 => .dword,
64 => .qword,
80 => .tbyte,
128 => .dqword,
128 => .xword,
256 => .yword,
512 => .zword,
else => unreachable,
};
}
@@ -517,7 +523,9 @@ pub const Memory = union(enum) {
.dword => 32,
.qword => 64,
.tbyte => 80,
.dqword => 128,
.xword => 128,
.yword => 256,
.zword => 512,
};
}
};
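For reference, the PtrSize change above renames the 128-bit dqword bucket to xword and adds yword and zword so that 256-bit and 512-bit memory operands can be sized. A standalone sketch of the same bucketing, illustration only; Width and widthFromSize are local names for the example, not the real API:

const std = @import("std");

// Byte sizes are bucketed up to the next memory-operand width, mirroring
// PtrSize.fromSize above.
const Width = enum { byte, word, dword, qword, xword, yword, zword };

fn widthFromSize(size: u32) Width {
    return switch (size) {
        1 => .byte,
        2 => .word,
        3...4 => .dword,
        5...8 => .qword,
        9...16 => .xword,
        17...32 => .yword,
        33...64 => .zword,
        else => unreachable,
    };
}

test "abi sizes bucket up to the next operand width" {
    try std.testing.expectEqual(Width.xword, widthFromSize(16)); // 128-bit vector
    try std.testing.expectEqual(Width.yword, widthFromSize(32)); // 256-bit (ymm) vector
}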

@@ -1016,5 +1016,28 @@ pub const table = [_]Entry{
.{ .vcvtph2ps, .rm, &.{ .xmm, .xmm_m64 }, &.{ 0x66, 0x0f, 0x38, 0x13 }, 0, .vex_128, .f16c },
.{ .vcvtps2ph, .mri, &.{ .xmm_m64, .xmm, .imm8 }, &.{ 0x66, 0x0f, 0x3a, 0x1d }, 0, .vex_128, .f16c },
// FMA
.{ .vfmadd132pd, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x38, 0x98 }, 0, .vex_128_long, .fma },
.{ .vfmadd132pd, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0x38, 0x98 }, 0, .vex_256_long, .fma },
.{ .vfmadd213pd, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x38, 0xa8 }, 0, .vex_128_long, .fma },
.{ .vfmadd213pd, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0x38, 0xa8 }, 0, .vex_256_long, .fma },
.{ .vfmadd231pd, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x38, 0xb8 }, 0, .vex_128_long, .fma },
.{ .vfmadd231pd, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0x38, 0xb8 }, 0, .vex_256_long, .fma },
.{ .vfmadd132ps, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x38, 0x98 }, 0, .vex_128, .fma },
.{ .vfmadd132ps, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0x38, 0x98 }, 0, .vex_256, .fma },
.{ .vfmadd213ps, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x38, 0xa8 }, 0, .vex_128, .fma },
.{ .vfmadd213ps, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0x38, 0xa8 }, 0, .vex_256, .fma },
.{ .vfmadd231ps, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x38, 0xb8 }, 0, .vex_128, .fma },
.{ .vfmadd231ps, .rvm, &.{ .ymm, .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0x38, 0xb8 }, 0, .vex_256, .fma },
.{ .vfmadd132sd, .rvm, &.{ .xmm, .xmm, .xmm_m64 }, &.{ 0x66, 0x0f, 0x38, 0x99 }, 0, .vex_128_long, .fma },
.{ .vfmadd213sd, .rvm, &.{ .xmm, .xmm, .xmm_m64 }, &.{ 0x66, 0x0f, 0x38, 0xa9 }, 0, .vex_128_long, .fma },
.{ .vfmadd231sd, .rvm, &.{ .xmm, .xmm, .xmm_m64 }, &.{ 0x66, 0x0f, 0x38, 0xb9 }, 0, .vex_128_long, .fma },
.{ .vfmadd132ss, .rvm, &.{ .xmm, .xmm, .xmm_m32 }, &.{ 0x66, 0x0f, 0x38, 0x99 }, 0, .vex_128, .fma },
.{ .vfmadd213ss, .rvm, &.{ .xmm, .xmm, .xmm_m32 }, &.{ 0x66, 0x0f, 0x38, 0xa9 }, 0, .vex_128, .fma },
.{ .vfmadd231ss, .rvm, &.{ .xmm, .xmm, .xmm_m32 }, &.{ 0x66, 0x0f, 0x38, 0xb9 }, 0, .vex_128, .fma },
};
// zig fmt: on

@@ -1,8 +1,10 @@
const std = @import("std");
const builtin = @import("builtin");
const expect = @import("std").testing.expect;
const expect = std.testing.expect;
test "@mulAdd" {
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64 and
!comptime std.Target.x86.featureSetHas(builtin.cpu.features, .fma)) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO