stage2: implement @popCount for SIMD vectors

This commit is contained in:
Andrew Kelley
2022-02-12 20:44:30 -07:00
parent 16ec848d2a
commit a005ac9d3c
6 changed files with 61 additions and 81 deletions

View File

@@ -720,7 +720,6 @@ fn analyzeBodyInner(
.align_cast => try sema.zirAlignCast(block, inst),
.has_decl => try sema.zirHasDecl(block, inst),
.has_field => try sema.zirHasField(block, inst),
.pop_count => try sema.zirPopCount(block, inst),
.byte_swap => try sema.zirByteSwap(block, inst),
.bit_reverse => try sema.zirBitReverse(block, inst),
.bit_offset_of => try sema.zirBitOffsetOf(block, inst),
@@ -743,8 +742,9 @@ fn analyzeBodyInner(
.await_nosuspend => try sema.zirAwait(block, inst, true),
.extended => try sema.zirExtended(block, inst),
.clz => try sema.zirClzCtz(block, inst, .clz, Value.clz),
.ctz => try sema.zirClzCtz(block, inst, .ctz, Value.ctz),
.clz => try sema.zirBitCount(block, inst, .clz, Value.clz),
.ctz => try sema.zirBitCount(block, inst, .ctz, Value.ctz),
.pop_count => try sema.zirBitCount(block, inst, .popcount, Value.popCount),
.sqrt => try sema.zirUnaryMath(block, inst, .sqrt, Value.sqrt),
.sin => try sema.zirUnaryMath(block, inst, .sin, Value.sin),
@@ -11487,7 +11487,7 @@ fn zirAlignCast(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!A
return sema.coerceCompatiblePtrs(block, dest_ty, ptr, ptr_src);
}
fn zirClzCtz(
fn zirBitCount(
sema: *Sema,
block: *Block,
inst: Zir.Inst.Index,
@@ -11550,34 +11550,6 @@ fn zirClzCtz(
}
}
fn zirPopCount(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
const inst_data = sema.code.instructions.items(.data)[inst].un_node;
const ty_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node };
const operand_src: LazySrcLoc = .{ .node_offset_builtin_call_arg1 = inst_data.src_node };
const operand = sema.resolveInst(inst_data.operand);
const operand_ty = sema.typeOf(operand);
// TODO implement support for vectors
if (operand_ty.zigTypeTag() != .Int) {
return sema.fail(block, ty_src, "expected integer type, found '{}'", .{
operand_ty,
});
}
const target = sema.mod.getTarget();
const bits = operand_ty.intInfo(target).bits;
if (bits == 0) return Air.Inst.Ref.zero;
const result_ty = try Type.smallestUnsignedInt(sema.arena, bits);
const runtime_src = if (try sema.resolveMaybeUndefVal(block, operand_src, operand)) |val| {
if (val.isUndef()) return sema.addConstUndef(result_ty);
const result_val = try val.popCount(operand_ty, target, sema.arena);
return sema.addConstant(result_ty, result_val);
} else operand_src;
try sema.requireRuntimeBlock(block, runtime_src);
return block.addTyOp(.popcount, result_ty, operand);
}
fn zirByteSwap(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
const inst_data = sema.code.instructions.items(.data)[inst].un_node;
const src = inst_data.src();

View File

@@ -2205,7 +2205,7 @@ pub const FuncGen = struct {
.get_union_tag => try self.airGetUnionTag(inst),
.clz => try self.airClzCtz(inst, "ctlz"),
.ctz => try self.airClzCtz(inst, "cttz"),
.popcount => try self.airPopCount(inst, "ctpop"),
.popcount => try self.airPopCount(inst),
.tag_name => try self.airTagName(inst),
.error_name => try self.airErrorName(inst),
.splat => try self.airSplat(inst),
@@ -4364,7 +4364,7 @@ pub const FuncGen = struct {
}
}
fn airPopCount(self: *FuncGen, inst: Air.Inst.Index, prefix: [*:0]const u8) !?*const llvm.Value {
fn airPopCount(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
if (self.liveness.isUnused(inst)) return null;
const ty_op = self.air.instructions.items(.data)[inst].ty_op;
@@ -4372,11 +4372,16 @@ pub const FuncGen = struct {
const operand = try self.resolveInst(ty_op.operand);
const target = self.dg.module.getTarget();
const bits = operand_ty.intInfo(target).bits;
const vec_len: ?u32 = switch (operand_ty.zigTypeTag()) {
.Vector => operand_ty.vectorLen(),
else => null,
};
var fn_name_buf: [100]u8 = undefined;
const llvm_fn_name = std.fmt.bufPrintZ(&fn_name_buf, "llvm.{s}.i{d}", .{
prefix, bits,
}) catch unreachable;
const llvm_fn_name = if (vec_len) |len|
std.fmt.bufPrintZ(&fn_name_buf, "llvm.ctpop.v{d}i{d}", .{ len, bits }) catch unreachable
else
std.fmt.bufPrintZ(&fn_name_buf, "llvm.ctpop.i{d}", .{bits}) catch unreachable;
const fn_val = self.dg.object.llvm_module.getNamedFunction(llvm_fn_name) orelse blk: {
const operand_llvm_ty = try self.dg.llvmType(operand_ty);
const param_types = [_]*const llvm.Type{operand_llvm_ty};

View File

@@ -1303,6 +1303,33 @@ pub const Value = extern union {
}
}
pub fn popCount(val: Value, ty: Type, target: Target) u64 {
assert(!val.isUndef());
switch (val.tag()) {
.zero, .bool_false => return 0,
.one, .bool_true => return 1,
.int_u64 => return @popCount(u64, val.castTag(.int_u64).?.data),
else => {
const info = ty.intInfo(target);
var buffer: Value.BigIntSpace = undefined;
const operand_bigint = val.toBigInt(&buffer);
var limbs_buffer: [4]std.math.big.Limb = undefined;
var result_bigint = BigIntMutable{
.limbs = &limbs_buffer,
.positive = undefined,
.len = undefined,
};
result_bigint.popCount(operand_bigint, info.bits);
return result_bigint.toConst().to(u64) catch unreachable;
},
}
}
/// Asserts the value is an integer and not undefined.
/// Returns the number of bits the value requires to represent stored in twos complement form.
pub fn intBitCountTwosComp(self: Value, target: Target) usize {
@@ -1340,24 +1367,6 @@ pub const Value = extern union {
}
}
pub fn popCount(val: Value, ty: Type, target: Target, arena: Allocator) !Value {
assert(!val.isUndef());
const info = ty.intInfo(target);
var buffer: Value.BigIntSpace = undefined;
const operand_bigint = val.toBigInt(&buffer);
const limbs = try arena.alloc(
std.math.big.Limb,
std.math.big.int.calcTwosCompLimbCount(info.bits),
);
var result_bigint = BigIntMutable{ .limbs = limbs, .positive = undefined, .len = undefined };
result_bigint.popCount(operand_bigint, info.bits);
return fromBigInt(arena, result_bigint.toConst());
}
/// Asserts the value is an integer, and the destination type is ComptimeInt or Int.
pub fn intFitsInType(self: Value, ty: Type, target: Target) bool {
switch (self.tag()) {

View File

@@ -153,7 +153,6 @@ test {
_ = @import("behavior/ir_block_deps.zig");
_ = @import("behavior/misc.zig");
_ = @import("behavior/muladd.zig");
_ = @import("behavior/popcount_stage1.zig");
_ = @import("behavior/reflection.zig");
_ = @import("behavior/select.zig");
_ = @import("behavior/shuffle.zig");

View File

@@ -1,7 +1,6 @@
const std = @import("std");
const expect = std.testing.expect;
const expectEqual = std.testing.expectEqual;
const Vector = std.meta.Vector;
test "@popCount integers" {
comptime try testPopCountIntegers();
@@ -44,3 +43,23 @@ fn testPopCountIntegers() !void {
try expect(@popCount(i128, @as(i128, 0b11111111000110001100010000100001000011000011100101010001)) == 24);
}
}
test "@popCount vectors" {
comptime try testPopCountVectors();
try testPopCountVectors();
}
fn testPopCountVectors() !void {
{
var x: @Vector(8, u32) = [1]u32{0xffffffff} ** 8;
const expected = [1]u6{32} ** 8;
const result: [8]u6 = @popCount(u32, x);
try expect(std.mem.eql(u6, &expected, &result));
}
{
var x: @Vector(8, i16) = [1]i16{-1} ** 8;
const expected = [1]u5{16} ** 8;
const result: [8]u5 = @popCount(i16, x);
try expect(std.mem.eql(u5, &expected, &result));
}
}

View File

@@ -1,24 +0,0 @@
const std = @import("std");
const expect = std.testing.expect;
const expectEqual = std.testing.expectEqual;
const Vector = std.meta.Vector;
test "@popCount vectors" {
comptime try testPopCountVectors();
try testPopCountVectors();
}
fn testPopCountVectors() !void {
{
var x: Vector(8, u32) = [1]u32{0xffffffff} ** 8;
const expected = [1]u6{32} ** 8;
const result: [8]u6 = @popCount(u32, x);
try expect(std.mem.eql(u6, &expected, &result));
}
{
var x: Vector(8, i16) = [1]i16{-1} ** 8;
const expected = [1]u5{16} ** 8;
const result: [8]u5 = @popCount(i16, x);
try expect(std.mem.eql(u5, &expected, &result));
}
}