From 5c5da179fb930c9d8be9366a851eb4a36f4044f1 Mon Sep 17 00:00:00 2001
From: Jacob Young
Date: Sun, 7 May 2023 03:47:56 -0400
Subject: [PATCH] x86_64: implement `@sqrt` for vectors

airSqrt now picks the MIR tag from the operand type and emits the
instruction itself instead of going through genBinOpMir:

* scalar f32/f64 use sqrtss/sqrtsd, or vsqrtss/vsqrtsd when AVX is
  available (the AVX scalar forms take the destination as both the
  first and second operand);
* float vectors use the packed sqrtps/sqrtpd forms, or vsqrtps/vsqrtpd
  with AVX; f32 vectors of 5 to 8 elements and f64 vectors of 3 to 4
  elements are only handled when AVX is available;
* f16/f80/f128 elements and longer vectors fail with a
  "TODO implement airSqrt" error.

airMulAdd's tag selection is restructured the same way: the optional
tag is unwrapped once, unsupported float widths yield null, and the
remaining type tags and float widths are marked unreachable.

The new AVX sqrt mnemonics are added to Encoding, Lower, Mir and the
encoding table.
---
 src/arch/x86_64/CodeGen.zig   | 225 +++++++++++++++++++++-------------
 src/arch/x86_64/Encoding.zig  |   1 +
 src/arch/x86_64/Lower.zig     |   4 +
 src/arch/x86_64/Mir.zig       |   8 ++
 src/arch/x86_64/encodings.zig |  18 ++-
 5 files changed, 166 insertions(+), 90 deletions(-)

diff --git a/src/arch/x86_64/CodeGen.zig b/src/arch/x86_64/CodeGen.zig
index 38497400f2..19878bae17 100644
--- a/src/arch/x86_64/CodeGen.zig
+++ b/src/arch/x86_64/CodeGen.zig
@@ -4520,25 +4520,69 @@ fn airRound(self: *Self, inst: Air.Inst.Index, mode: Immediate) !void {
 fn airSqrt(self: *Self, inst: Air.Inst.Index) !void {
     const un_op = self.air.instructions.items(.data)[inst].un_op;
     const ty = self.air.typeOf(un_op);
+    const abi_size = @intCast(u32, ty.abiSize(self.target.*));
 
     const src_mcv = try self.resolveInst(un_op);
     const dst_mcv = if (src_mcv.isRegister() and self.reuseOperand(inst, un_op, 0, src_mcv))
         src_mcv
     else
         try self.copyToRegisterWithInstTracking(inst, ty, src_mcv);
+    const dst_reg = registerAlias(dst_mcv.getReg().?, abi_size);
+    const dst_lock = self.register_manager.lockReg(dst_reg);
+    defer if (dst_lock) |lock| self.register_manager.unlockReg(lock);
 
-    try self.genBinOpMir(switch (ty.zigTypeTag()) {
-        .Float => switch (ty.floatBits(self.target.*)) {
-            32 => .sqrtss,
-            64 => .sqrtsd,
-            else => return self.fail("TODO implement airSqrt for {}", .{
-                ty.fmt(self.bin_file.options.module.?),
-            }),
+    const tag = if (@as(?Mir.Inst.Tag, switch (ty.zigTypeTag()) {
+        .Float => switch (ty.floatBits(self.target.*)) {
+            32 => if (self.hasFeature(.avx)) .vsqrtss else .sqrtss,
+            64 => if (self.hasFeature(.avx)) .vsqrtsd else .sqrtsd,
+            16, 80, 128 => null,
+            else => unreachable,
         },
-        else => return self.fail("TODO implement airSqrt for {}", .{
-            ty.fmt(self.bin_file.options.module.?),
-        }),
-    }, ty, dst_mcv, src_mcv);
+        .Vector => switch (ty.childType().zigTypeTag()) {
+            .Float => switch (ty.childType().floatBits(self.target.*)) {
+                32 => switch (ty.vectorLen()) {
+                    1 => if (self.hasFeature(.avx)) .vsqrtss else .sqrtss,
+                    2...4 => if (self.hasFeature(.avx)) .vsqrtps else .sqrtps,
+                    5...8 => if (self.hasFeature(.avx)) .vsqrtps else null,
+                    else => null,
+                },
+                64 => switch (ty.vectorLen()) {
+                    1 => if (self.hasFeature(.avx)) .vsqrtsd else .sqrtsd,
+                    2 => if (self.hasFeature(.avx)) .vsqrtpd else .sqrtpd,
+                    3...4 => if (self.hasFeature(.avx)) .vsqrtpd else null,
+                    else => null,
+                },
+                16, 80, 128 => null,
+                else => unreachable,
+            },
+            else => unreachable,
+        },
+        else => unreachable,
+    })) |tag| tag else return self.fail("TODO implement airSqrt for {}", .{
+        ty.fmt(self.bin_file.options.module.?),
+    });
+    switch (tag) {
+        .vsqrtss, .vsqrtsd => if (src_mcv.isRegister()) try self.asmRegisterRegisterRegister(
+            tag,
+            dst_reg,
+            dst_reg,
+            registerAlias(src_mcv.getReg().?, abi_size),
+        ) else try self.asmRegisterRegisterMemory(
+            tag,
+            dst_reg,
+            dst_reg,
+            src_mcv.mem(Memory.PtrSize.fromSize(abi_size)),
+        ),
+        else => if (src_mcv.isRegister()) try self.asmRegisterRegister(
+            tag,
+            dst_reg,
+            registerAlias(src_mcv.getReg().?, abi_size),
+        ) else try self.asmRegisterMemory(
+            tag,
+            dst_reg,
+            src_mcv.mem(Memory.PtrSize.fromSize(abi_size)),
+        ),
+    }
     return self.finishAir(inst, dst_mcv, .{ un_op, .none, .none });
 }
 
@@ -9544,85 +9588,92 @@ fn airMulAdd(self: *Self, inst: Air.Inst.Index) !void {
         lock.* = self.register_manager.lockRegAssumeUnused(reg);
     }
 
-    const tag: ?Mir.Inst.Tag =
+    const tag = if (@as(
+        ?Mir.Inst.Tag,
         if (mem.eql(u2, &order, &.{ 1, 3, 2 }) or mem.eql(u2, &order, &.{ 3, 1, 2 }))
-        switch (ty.zigTypeTag()) {
-            .Float => switch (ty.floatBits(self.target.*)) {
-                32 => .vfmadd132ss,
-                64 => .vfmadd132sd,
-                else => null,
-            },
-            .Vector => switch (ty.childType().zigTypeTag()) {
-                .Float => switch (ty.childType().floatBits(self.target.*)) {
-                    32 => switch (ty.vectorLen()) {
-                        1 => .vfmadd132ss,
-                        2...8 => .vfmadd132ps,
-                        else => null,
-                    },
-                    64 => switch (ty.vectorLen()) {
-                        1 => .vfmadd132sd,
-                        2...4 => .vfmadd132pd,
-                        else => null,
-                    },
-                    else => null,
+            switch (ty.zigTypeTag()) {
+                .Float => switch (ty.floatBits(self.target.*)) {
+                    32 => .vfmadd132ss,
+                    64 => .vfmadd132sd,
+                    16, 80, 128 => null,
+                    else => unreachable,
                 },
-                else => null,
-            },
-            else => unreachable,
-        }
-    else if (mem.eql(u2, &order, &.{ 2, 1, 3 }) or mem.eql(u2, &order, &.{ 1, 2, 3 }))
-        switch (ty.zigTypeTag()) {
-            .Float => switch (ty.floatBits(self.target.*)) {
-                32 => .vfmadd213ss,
-                64 => .vfmadd213sd,
-                else => null,
-            },
-            .Vector => switch (ty.childType().zigTypeTag()) {
-                .Float => switch (ty.childType().floatBits(self.target.*)) {
-                    32 => switch (ty.vectorLen()) {
-                        1 => .vfmadd213ss,
-                        2...8 => .vfmadd213ps,
-                        else => null,
+                .Vector => switch (ty.childType().zigTypeTag()) {
+                    .Float => switch (ty.childType().floatBits(self.target.*)) {
+                        32 => switch (ty.vectorLen()) {
+                            1 => .vfmadd132ss,
+                            2...8 => .vfmadd132ps,
+                            else => null,
+                        },
+                        64 => switch (ty.vectorLen()) {
+                            1 => .vfmadd132sd,
+                            2...4 => .vfmadd132pd,
+                            else => null,
+                        },
+                        16, 80, 128 => null,
+                        else => unreachable,
                     },
-                    64 => switch (ty.vectorLen()) {
-                        1 => .vfmadd213sd,
-                        2...4 => .vfmadd213pd,
-                        else => null,
-                    },
-                    else => null,
+                    else => unreachable,
                 },
-                else => null,
-            },
-            else => unreachable,
-        }
-    else if (mem.eql(u2, &order, &.{ 2, 3, 1 }) or mem.eql(u2, &order, &.{ 3, 2, 1 }))
-        switch (ty.zigTypeTag()) {
-            .Float => switch (ty.floatBits(self.target.*)) {
-                32 => .vfmadd231ss,
-                64 => .vfmadd231sd,
-                else => null,
-            },
-            .Vector => switch (ty.childType().zigTypeTag()) {
-                .Float => switch (ty.childType().floatBits(self.target.*)) {
-                    32 => switch (ty.vectorLen()) {
-                        1 => .vfmadd231ss,
-                        2...8 => .vfmadd231ps,
-                        else => null,
-                    },
-                    64 => switch (ty.vectorLen()) {
-                        1 => .vfmadd231sd,
-                        2...4 => .vfmadd231pd,
-                        else => null,
-                    },
-                    else => null,
+                else => unreachable,
+            }
+        else if (mem.eql(u2, &order, &.{ 2, 1, 3 }) or mem.eql(u2, &order, &.{ 1, 2, 3 }))
+            switch (ty.zigTypeTag()) {
+                .Float => switch (ty.floatBits(self.target.*)) {
+                    32 => .vfmadd213ss,
+                    64 => .vfmadd213sd,
+                    16, 80, 128 => null,
+                    else => unreachable,
                 },
-                else => null,
-            },
-            else => null,
-        }
-    else
-        unreachable;
-    if (tag == null) return self.fail("TODO implement airMulAdd for {}", .{
+                .Vector => switch (ty.childType().zigTypeTag()) {
+                    .Float => switch (ty.childType().floatBits(self.target.*)) {
+                        32 => switch (ty.vectorLen()) {
+                            1 => .vfmadd213ss,
+                            2...8 => .vfmadd213ps,
+                            else => null,
+                        },
+                        64 => switch (ty.vectorLen()) {
+                            1 => .vfmadd213sd,
+                            2...4 => .vfmadd213pd,
+                            else => null,
+                        },
+                        16, 80, 128 => null,
+                        else => unreachable,
+                    },
+                    else => unreachable,
+                },
+                else => unreachable,
+            }
+        else if (mem.eql(u2, &order, &.{ 2, 3, 1 }) or mem.eql(u2, &order, &.{ 3, 2, 1 }))
+            switch (ty.zigTypeTag()) {
+                .Float => switch (ty.floatBits(self.target.*)) {
+                    32 => .vfmadd231ss,
+                    64 => .vfmadd231sd,
+                    16, 80, 128 => null,
+                    else => unreachable,
+                },
+                .Vector => switch (ty.childType().zigTypeTag()) {
+                    .Float => switch (ty.childType().floatBits(self.target.*)) {
+                        32 => switch (ty.vectorLen()) {
+                            1 => .vfmadd231ss,
+                            2...8 => .vfmadd231ps,
+                            else => null,
+                        },
+                        64 => switch (ty.vectorLen()) {
+                            1 => .vfmadd231sd,
+                            2...4 => .vfmadd231pd,
+                            else => null,
+                        },
+                        16, 80, 128 => null,
+                        else => unreachable,
+                    },
+                    else => unreachable,
+                },
+                else => unreachable,
+            }
+        else
+            unreachable,
+    )) |tag| tag else return self.fail("TODO implement airMulAdd for {}", .{
         ty.fmt(self.bin_file.options.module.?),
     });
 
@@ -9634,14 +9685,14 @@ fn airMulAdd(self: *Self, inst: Air.Inst.Index) !void {
     const mop2_reg = registerAlias(mops[1].getReg().?, abi_size);
     if (mops[2].isRegister())
         try self.asmRegisterRegisterRegister(
-            tag.?,
+            tag,
             mop1_reg,
             mop2_reg,
             registerAlias(mops[2].getReg().?, abi_size),
         )
     else
         try self.asmRegisterRegisterMemory(
-            tag.?,
+            tag,
             mop1_reg,
             mop2_reg,
             mops[2].mem(Memory.PtrSize.fromSize(abi_size)),
diff --git a/src/arch/x86_64/Encoding.zig b/src/arch/x86_64/Encoding.zig
index bd6e70c975..b242c98bdc 100644
--- a/src/arch/x86_64/Encoding.zig
+++ b/src/arch/x86_64/Encoding.zig
@@ -316,6 +316,7 @@ pub const Mnemonic = enum {
     vpsrld, vpsrlq, vpsrlw,
     vpunpckhbw, vpunpckhdq, vpunpckhqdq, vpunpckhwd,
     vpunpcklbw, vpunpckldq, vpunpcklqdq, vpunpcklwd,
+    vsqrtpd, vsqrtps, vsqrtsd, vsqrtss,
     // F16C
     vcvtph2ps, vcvtps2ph,
     // FMA
diff --git a/src/arch/x86_64/Lower.zig b/src/arch/x86_64/Lower.zig
index 40a5ccdb10..39ad2313e7 100644
--- a/src/arch/x86_64/Lower.zig
+++ b/src/arch/x86_64/Lower.zig
@@ -212,6 +212,10 @@ pub fn lowerMir(lower: *Lower, index: Mir.Inst.Index) Error!struct {
             .vpunpckldq,
             .vpunpcklqdq,
             .vpunpcklwd,
+            .vsqrtpd,
+            .vsqrtps,
+            .vsqrtsd,
+            .vsqrtss,
             .vcvtph2ps,
             .vcvtps2ph,
 
diff --git a/src/arch/x86_64/Mir.zig b/src/arch/x86_64/Mir.zig
index cb1a578bb6..b6df0fff09 100644
--- a/src/arch/x86_64/Mir.zig
+++ b/src/arch/x86_64/Mir.zig
@@ -338,6 +338,14 @@ pub const Inst = struct {
         vpunpcklqdq,
         /// Unpack low data
         vpunpcklwd,
+        /// Square root of packed double-precision floating-point value
+        vsqrtpd,
+        /// Square root of packed single-precision floating-point value
+        vsqrtps,
+        /// Square root of scalar double-precision floating-point value
+        vsqrtsd,
+        /// Square root of scalar single-precision floating-point value
+        vsqrtss,
 
         /// Convert 16-bit floating-point values to single-precision floating-point values
         vcvtph2ps,
diff --git a/src/arch/x86_64/encodings.zig b/src/arch/x86_64/encodings.zig
index 5e4dc2f04b..49ebc344fd 100644
--- a/src/arch/x86_64/encodings.zig
+++ b/src/arch/x86_64/encodings.zig
@@ -869,8 +869,9 @@ pub const table = [_]Entry{
 
     .{ .subss, .rm, &.{ .xmm, .xmm_m32 }, &.{ 0xf3, 0x0f, 0x5c }, 0, .none, .sse },
 
-    .{ .sqrtps, .rm, &.{ .xmm, .xmm_m128 }, &.{       0x0f, 0x51 }, 0, .none, .sse },
-    .{ .sqrtss, .rm, &.{ .xmm, .xmm_m32  }, &.{ 0xf3, 0x0f, 0x51 }, 0, .none, .sse },
+    .{ .sqrtps, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x0f, 0x51 }, 0, .none, .sse },
+
+    .{ .sqrtss, .rm, &.{ .xmm, .xmm_m32 }, &.{ 0xf3, 0x0f, 0x51 }, 0, .none, .sse },
 
     .{ .ucomiss, .rm, &.{ .xmm, .xmm_m32 }, &.{ 0x0f, 0x2e }, 0, .none, .sse },
 
@@ -943,7 +944,8 @@ pub const table = [_]Entry{
     .{ .punpcklqdq, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x6c }, 0, .none, .sse2 },
 
     .{ .sqrtpd, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x51 }, 0, .none, .sse2 },
-    .{ .sqrtsd, .rm, &.{ .xmm, .xmm_m64  }, &.{ 0xf2, 0x0f, 0x51 }, 0, .none, .sse2 },
+
+    .{ .sqrtsd, .rm, &.{ .xmm, .xmm_m64 }, &.{ 0xf2, 0x0f, 0x51 }, 0, .none, .sse2 },
 
     .{ .subsd, .rm, &.{ .xmm, .xmm_m64 }, &.{ 0xf2, 0x0f, 0x5c }, 0, .none, .sse2 },
 
@@ -1039,6 +1041,16 @@ pub const table = [_]Entry{
     .{ .vpunpckldq, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x62 }, 0, .vex_128_wig, .avx },
     .{ .vpunpcklqdq, .rvm, &.{ .xmm, .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x6c }, 0, .vex_128_wig, .avx },
 
+    .{ .vsqrtpd, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x51 }, 0, .vex_128_wig, .avx },
+    .{ .vsqrtpd, .rm, &.{ .ymm, .ymm_m256 }, &.{ 0x66, 0x0f, 0x51 }, 0, .vex_256_wig, .avx },
+
+    .{ .vsqrtps, .rm, &.{ .xmm, .xmm_m128 }, &.{ 0x0f, 0x51 }, 0, .vex_128_wig, .avx },
+    .{ .vsqrtps, .rm, &.{ .ymm, .ymm_m256 }, &.{ 0x0f, 0x51 }, 0, .vex_256_wig, .avx },
+
+    .{ .vsqrtsd, .rvm, &.{ .xmm, .xmm, .xmm_m64 }, &.{ 0xf2, 0x0f, 0x51 }, 0, .vex_lig_wig, .avx },
+
+    .{ .vsqrtss, .rvm, &.{ .xmm, .xmm, .xmm_m32 }, &.{ 0xf3, 0x0f, 0x51 }, 0, .vex_lig_wig, .avx },
+
     // F16C
     .{ .vcvtph2ps, .rm, &.{ .xmm, .xmm_m64 }, &.{ 0x66, 0x0f, 0x38, 0x13 }, 0, .vex_128_w0, .f16c },
     .{ .vcvtph2ps, .rm, &.{ .ymm, .xmm_m128 }, &.{ 0x66, 0x0f, 0x38, 0x13 }, 0, .vex_256_w0, .f16c },