compiler: implement better shuffle AIR

Runtime `@shuffle` has two cases which backends generally want to handle
differently for efficiency:

* One runtime vector operand; some result elements may be comptime-known
* Two runtime vector operands; some result elements may be undefined

The latter case happens if both vectors given to `@shuffle` are
runtime-known and they are both used (i.e. the mask refers to them).
Otherwise, if the result is not entirely comptime-known, we are in the
former case. `Sema` now diffentiates these two cases in the AIR so that
backends can easily handle them however they want to. Note that this
*doesn't* really involve Sema doing any more work than it would
otherwise need to, so there's not really a negative here!

Most existing backends have their lowerings for `@shuffle` migrated in
this commit. The LLVM backend uses new lowerings suggested by Jacob as
ones which it will handle effectively. The x86_64 backend has not yet
been migrated; for now there's a panic in there. Jacob will implement
that before this is merged anywhere.
This commit is contained in:
mlugg
2025-05-26 05:07:13 +01:00
parent b48d6ff619
commit add2976a9b
18 changed files with 755 additions and 321 deletions

View File

@@ -3374,7 +3374,8 @@ fn genBodyInner(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail,
.error_name => try airErrorName(f, inst),
.splat => try airSplat(f, inst),
.select => try airSelect(f, inst),
.shuffle => try airShuffle(f, inst),
.shuffle_one => try airShuffleOne(f, inst),
.shuffle_two => try airShuffleTwo(f, inst),
.reduce => try airReduce(f, inst),
.aggregate_init => try airAggregateInit(f, inst),
.union_init => try airUnionInit(f, inst),
@@ -7163,34 +7164,73 @@ fn airSelect(f: *Function, inst: Air.Inst.Index) !CValue {
return local;
}
fn airShuffle(f: *Function, inst: Air.Inst.Index) !CValue {
fn airShuffleOne(f: *Function, inst: Air.Inst.Index) !CValue {
const pt = f.object.dg.pt;
const zcu = pt.zcu;
const ty_pl = f.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl;
const extra = f.air.extraData(Air.Shuffle, ty_pl.payload).data;
const mask = Value.fromInterned(extra.mask);
const lhs = try f.resolveInst(extra.a);
const rhs = try f.resolveInst(extra.b);
const inst_ty = f.typeOfIndex(inst);
const unwrapped = f.air.unwrapShuffleOne(zcu, inst);
const mask = unwrapped.mask;
const operand = try f.resolveInst(unwrapped.operand);
const inst_ty = unwrapped.result_ty;
const writer = f.object.writer();
const local = try f.allocLocal(inst, inst_ty);
try reap(f, inst, &.{ extra.a, extra.b }); // local cannot alias operands
for (0..extra.mask_len) |index| {
try reap(f, inst, &.{unwrapped.operand}); // local cannot alias operand
for (mask, 0..) |mask_elem, out_idx| {
try f.writeCValue(writer, local, .Other);
try writer.writeByte('[');
try f.object.dg.renderValue(writer, try pt.intValue(.usize, index), .Other);
try f.object.dg.renderValue(writer, try pt.intValue(.usize, out_idx), .Other);
try writer.writeAll("] = ");
switch (mask_elem.unwrap()) {
.elem => |src_idx| {
try f.writeCValue(writer, operand, .Other);
try writer.writeByte('[');
try f.object.dg.renderValue(writer, try pt.intValue(.usize, src_idx), .Other);
try writer.writeByte(']');
},
.value => |val| try f.object.dg.renderValue(writer, .fromInterned(val), .Other),
}
try writer.writeAll(";\n");
}
const mask_elem = (try mask.elemValue(pt, index)).toSignedInt(zcu);
const src_val = try pt.intValue(.usize, @as(u64, @intCast(mask_elem ^ mask_elem >> 63)));
return local;
}
try f.writeCValue(writer, if (mask_elem >= 0) lhs else rhs, .Other);
fn airShuffleTwo(f: *Function, inst: Air.Inst.Index) !CValue {
const pt = f.object.dg.pt;
const zcu = pt.zcu;
const unwrapped = f.air.unwrapShuffleTwo(zcu, inst);
const mask = unwrapped.mask;
const operand_a = try f.resolveInst(unwrapped.operand_a);
const operand_b = try f.resolveInst(unwrapped.operand_b);
const inst_ty = unwrapped.result_ty;
const elem_ty = inst_ty.childType(zcu);
const writer = f.object.writer();
const local = try f.allocLocal(inst, inst_ty);
try reap(f, inst, &.{ unwrapped.operand_a, unwrapped.operand_b }); // local cannot alias operands
for (mask, 0..) |mask_elem, out_idx| {
try f.writeCValue(writer, local, .Other);
try writer.writeByte('[');
try f.object.dg.renderValue(writer, src_val, .Other);
try writer.writeAll("];\n");
try f.object.dg.renderValue(writer, try pt.intValue(.usize, out_idx), .Other);
try writer.writeAll("] = ");
switch (mask_elem.unwrap()) {
.a_elem => |src_idx| {
try f.writeCValue(writer, operand_a, .Other);
try writer.writeByte('[');
try f.object.dg.renderValue(writer, try pt.intValue(.usize, src_idx), .Other);
try writer.writeByte(']');
},
.b_elem => |src_idx| {
try f.writeCValue(writer, operand_b, .Other);
try writer.writeByte('[');
try f.object.dg.renderValue(writer, try pt.intValue(.usize, src_idx), .Other);
try writer.writeByte(']');
},
.undef => try f.object.dg.renderUndefValue(writer, elem_ty, .Other),
}
try writer.writeAll(";\n");
}
return local;

View File

@@ -4969,7 +4969,8 @@ pub const FuncGen = struct {
.error_name => try self.airErrorName(inst),
.splat => try self.airSplat(inst),
.select => try self.airSelect(inst),
.shuffle => try self.airShuffle(inst),
.shuffle_one => try self.airShuffleOne(inst),
.shuffle_two => try self.airShuffleTwo(inst),
.aggregate_init => try self.airAggregateInit(inst),
.union_init => try self.airUnionInit(inst),
.prefetch => try self.airPrefetch(inst),
@@ -9666,7 +9667,7 @@ pub const FuncGen = struct {
const zcu = o.pt.zcu;
const ip = &zcu.intern_pool;
for (body_tail[1..]) |body_inst| {
switch (fg.liveness.categorizeOperand(fg.air, body_inst, body_tail[0], ip)) {
switch (fg.liveness.categorizeOperand(fg.air, zcu, body_inst, body_tail[0], ip)) {
.none => continue,
.write, .noret, .complex => return false,
.tomb => return true,
@@ -10421,42 +10422,192 @@ pub const FuncGen = struct {
return self.wip.select(.normal, pred, a, b, "");
}
fn airShuffle(self: *FuncGen, inst: Air.Inst.Index) !Builder.Value {
const o = self.ng.object;
fn airShuffleOne(fg: *FuncGen, inst: Air.Inst.Index) !Builder.Value {
const o = fg.ng.object;
const pt = o.pt;
const zcu = pt.zcu;
const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl;
const extra = self.air.extraData(Air.Shuffle, ty_pl.payload).data;
const a = try self.resolveInst(extra.a);
const b = try self.resolveInst(extra.b);
const mask = Value.fromInterned(extra.mask);
const mask_len = extra.mask_len;
const a_len = self.typeOf(extra.a).vectorLen(zcu);
const gpa = zcu.gpa;
// LLVM uses integers larger than the length of the first array to
// index into the second array. This was deemed unnecessarily fragile
// when changing code, so Zig uses negative numbers to index the
// second vector. These start at -1 and go down, and are easiest to use
// with the ~ operator. Here we convert between the two formats.
const values = try self.gpa.alloc(Builder.Constant, mask_len);
defer self.gpa.free(values);
const unwrapped = fg.air.unwrapShuffleOne(zcu, inst);
for (values, 0..) |*val, i| {
const elem = try mask.elemValue(pt, i);
if (elem.isUndef(zcu)) {
val.* = try o.builder.undefConst(.i32);
} else {
const int = elem.toSignedInt(zcu);
const unsigned: u32 = @intCast(if (int >= 0) int else ~int + a_len);
val.* = try o.builder.intConst(.i32, unsigned);
const operand = try fg.resolveInst(unwrapped.operand);
const mask = unwrapped.mask;
const operand_ty = fg.typeOf(unwrapped.operand);
const llvm_operand_ty = try o.lowerType(operand_ty);
const llvm_result_ty = try o.lowerType(unwrapped.result_ty);
const llvm_elem_ty = try o.lowerType(unwrapped.result_ty.childType(zcu));
const llvm_poison_elem = try o.builder.poisonConst(llvm_elem_ty);
const llvm_poison_mask_elem = try o.builder.poisonConst(.i32);
const llvm_mask_ty = try o.builder.vectorType(.normal, @intCast(mask.len), .i32);
// LLVM requires that the two input vectors have the same length, so lowering isn't trivial.
// And, in the words of jacobly0: "llvm sucks at shuffles so we do have to hold its hand at
// least a bit". So, there are two cases here.
//
// If the operand length equals the mask length, we do just the one `shufflevector`, where
// the second operand is a constant vector with comptime-known elements at the right indices
// and poison values elsewhere (in the indices which won't be selected).
//
// Otherwise, we lower to *two* `shufflevector` instructions. The first shuffles the runtime
// operand with an all-poison vector to extract and correctly position all of the runtime
// elements. We also make a constant vector with all of the comptime elements correctly
// positioned. Then, our second instruction selects elements from those "runtime-or-poison"
// and "comptime-or-poison" vectors to compute the result.
// This buffer is used primarily for the mask constants.
const llvm_elem_buf = try gpa.alloc(Builder.Constant, mask.len);
defer gpa.free(llvm_elem_buf);
// ...but first, we'll collect all of the comptime-known values.
var any_defined_comptime_value = false;
for (mask, llvm_elem_buf) |mask_elem, *llvm_elem| {
llvm_elem.* = switch (mask_elem.unwrap()) {
.elem => llvm_poison_elem,
.value => |val| if (!Value.fromInterned(val).isUndef(zcu)) elem: {
any_defined_comptime_value = true;
break :elem try o.lowerValue(val);
} else llvm_poison_elem,
};
}
// This vector is like the result, but runtime elements are replaced with poison.
const comptime_and_poison: Builder.Value = if (any_defined_comptime_value) vec: {
break :vec try o.builder.vectorValue(llvm_result_ty, llvm_elem_buf);
} else try o.builder.poisonValue(llvm_result_ty);
if (operand_ty.vectorLen(zcu) == mask.len) {
// input length equals mask/output length, so we lower to one instruction
for (mask, llvm_elem_buf, 0..) |mask_elem, *llvm_elem, elem_idx| {
llvm_elem.* = switch (mask_elem.unwrap()) {
.elem => |idx| try o.builder.intConst(.i32, idx),
.value => |val| if (!Value.fromInterned(val).isUndef(zcu)) mask_val: {
break :mask_val try o.builder.intConst(.i32, mask.len + elem_idx);
} else llvm_poison_mask_elem,
};
}
return fg.wip.shuffleVector(
operand,
comptime_and_poison,
try o.builder.vectorValue(llvm_mask_ty, llvm_elem_buf),
"",
);
}
const llvm_mask_value = try o.builder.vectorValue(
try o.builder.vectorType(.normal, mask_len, .i32),
values,
for (mask, llvm_elem_buf) |mask_elem, *llvm_elem| {
llvm_elem.* = switch (mask_elem.unwrap()) {
.elem => |idx| try o.builder.intConst(.i32, idx),
.value => llvm_poison_mask_elem,
};
}
// This vector is like our result, but all comptime-known elements are poison.
const runtime_and_poison = try fg.wip.shuffleVector(
operand,
try o.builder.poisonValue(llvm_operand_ty),
try o.builder.vectorValue(llvm_mask_ty, llvm_elem_buf),
"",
);
if (!any_defined_comptime_value) {
// `comptime_and_poison` is just poison; a second shuffle would be a nop.
return runtime_and_poison;
}
// In this second shuffle, the inputs, the mask, and the output all have the same length.
for (mask, llvm_elem_buf, 0..) |mask_elem, *llvm_elem, elem_idx| {
llvm_elem.* = switch (mask_elem.unwrap()) {
.elem => try o.builder.intConst(.i32, elem_idx),
.value => |val| if (!Value.fromInterned(val).isUndef(zcu)) mask_val: {
break :mask_val try o.builder.intConst(.i32, mask.len + elem_idx);
} else llvm_poison_mask_elem,
};
}
// Merge the runtime and comptime elements with the mask we just built.
return fg.wip.shuffleVector(
runtime_and_poison,
comptime_and_poison,
try o.builder.vectorValue(llvm_mask_ty, llvm_elem_buf),
"",
);
}
fn airShuffleTwo(fg: *FuncGen, inst: Air.Inst.Index) !Builder.Value {
const o = fg.ng.object;
const pt = o.pt;
const zcu = pt.zcu;
const gpa = zcu.gpa;
const unwrapped = fg.air.unwrapShuffleTwo(zcu, inst);
const mask = unwrapped.mask;
const llvm_elem_ty = try o.lowerType(unwrapped.result_ty.childType(zcu));
const llvm_mask_ty = try o.builder.vectorType(.normal, @intCast(mask.len), .i32);
const llvm_poison_mask_elem = try o.builder.poisonConst(.i32);
// This is kind of simpler than in `airShuffleOne`. We extend the shorter vector to the
// length of the longer one with an initial `shufflevector` if necessary, and then do the
// actual computation with a second `shufflevector`.
const operand_a_len = fg.typeOf(unwrapped.operand_a).vectorLen(zcu);
const operand_b_len = fg.typeOf(unwrapped.operand_b).vectorLen(zcu);
const operand_len: u32 = @max(operand_a_len, operand_b_len);
// If we need to extend an operand, this is the type that mask will have.
const llvm_operand_mask_ty = try o.builder.vectorType(.normal, operand_len, .i32);
const llvm_elem_buf = try gpa.alloc(Builder.Constant, @max(mask.len, operand_len));
defer gpa.free(llvm_elem_buf);
const operand_a: Builder.Value = extend: {
const raw = try fg.resolveInst(unwrapped.operand_a);
if (operand_a_len == operand_len) break :extend raw;
// Extend with a `shufflevector`, with a mask `<0, 1, ..., n, poison, poison, ..., poison>`
const mask_elems = llvm_elem_buf[0..operand_len];
for (mask_elems[0..operand_a_len], 0..) |*llvm_elem, elem_idx| {
llvm_elem.* = try o.builder.intConst(.i32, elem_idx);
}
@memset(mask_elems[operand_a_len..], llvm_poison_mask_elem);
const llvm_this_operand_ty = try o.builder.vectorType(.normal, operand_a_len, llvm_elem_ty);
break :extend try fg.wip.shuffleVector(
raw,
try o.builder.poisonValue(llvm_this_operand_ty),
try o.builder.vectorValue(llvm_operand_mask_ty, mask_elems),
"",
);
};
const operand_b: Builder.Value = extend: {
const raw = try fg.resolveInst(unwrapped.operand_b);
if (operand_b_len == operand_len) break :extend raw;
// Extend with a `shufflevector`, with a mask `<0, 1, ..., n, poison, poison, ..., poison>`
const mask_elems = llvm_elem_buf[0..operand_len];
for (mask_elems[0..operand_b_len], 0..) |*llvm_elem, elem_idx| {
llvm_elem.* = try o.builder.intConst(.i32, elem_idx);
}
@memset(mask_elems[operand_b_len..], llvm_poison_mask_elem);
const llvm_this_operand_ty = try o.builder.vectorType(.normal, operand_b_len, llvm_elem_ty);
break :extend try fg.wip.shuffleVector(
raw,
try o.builder.poisonValue(llvm_this_operand_ty),
try o.builder.vectorValue(llvm_operand_mask_ty, mask_elems),
"",
);
};
// `operand_a` and `operand_b` now have the same length (we've extended the shorter one with
// an initial shuffle if necessary). Now for the easy bit.
const mask_elems = llvm_elem_buf[0..mask.len];
for (mask, mask_elems) |mask_elem, *llvm_mask_elem| {
llvm_mask_elem.* = switch (mask_elem.unwrap()) {
.a_elem => |idx| try o.builder.intConst(.i32, idx),
.b_elem => |idx| try o.builder.intConst(.i32, operand_len + idx),
.undef => llvm_poison_mask_elem,
};
}
return fg.wip.shuffleVector(
operand_a,
operand_b,
try o.builder.vectorValue(llvm_mask_ty, mask_elems),
"",
);
return self.wip.shuffleVector(a, b, llvm_mask_value, "");
}
/// Reduce a vector by repeatedly applying `llvm_fn` to produce an accumulated result.

View File

@@ -3252,7 +3252,8 @@ const NavGen = struct {
.splat => try self.airSplat(inst),
.reduce, .reduce_optimized => try self.airReduce(inst),
.shuffle => try self.airShuffle(inst),
.shuffle_one => try self.airShuffleOne(inst),
.shuffle_two => try self.airShuffleTwo(inst),
.ptr_add => try self.airPtrAdd(inst),
.ptr_sub => try self.airPtrSub(inst),
@@ -4047,40 +4048,57 @@ const NavGen = struct {
return result_id;
}
fn airShuffle(self: *NavGen, inst: Air.Inst.Index) !?IdRef {
const pt = self.pt;
fn airShuffleOne(ng: *NavGen, inst: Air.Inst.Index) !?IdRef {
const pt = ng.pt;
const zcu = pt.zcu;
const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl;
const extra = self.air.extraData(Air.Shuffle, ty_pl.payload).data;
const a = try self.resolve(extra.a);
const b = try self.resolve(extra.b);
const mask = Value.fromInterned(extra.mask);
const gpa = zcu.gpa;
// Note: number of components in the result, a, and b may differ.
const result_ty = self.typeOfIndex(inst);
const scalar_ty = result_ty.scalarType(zcu);
const scalar_ty_id = try self.resolveType(scalar_ty, .direct);
const unwrapped = ng.air.unwrapShuffleOne(zcu, inst);
const mask = unwrapped.mask;
const result_ty = unwrapped.result_ty;
const elem_ty = result_ty.childType(zcu);
const operand = try ng.resolve(unwrapped.operand);
const constituents = try self.gpa.alloc(IdRef, result_ty.vectorLen(zcu));
defer self.gpa.free(constituents);
const constituents = try gpa.alloc(IdRef, mask.len);
defer gpa.free(constituents);
for (constituents, 0..) |*id, i| {
const elem = try mask.elemValue(pt, i);
if (elem.isUndef(zcu)) {
id.* = try self.spv.constUndef(scalar_ty_id);
continue;
}
const index = elem.toSignedInt(zcu);
if (index >= 0) {
id.* = try self.extractVectorComponent(scalar_ty, a, @intCast(index));
} else {
id.* = try self.extractVectorComponent(scalar_ty, b, @intCast(~index));
}
for (constituents, mask) |*id, mask_elem| {
id.* = switch (mask_elem.unwrap()) {
.elem => |idx| try ng.extractVectorComponent(elem_ty, operand, idx),
.value => |val| try ng.constant(elem_ty, .fromInterned(val), .direct),
};
}
const result_ty_id = try self.resolveType(result_ty, .direct);
return try self.constructComposite(result_ty_id, constituents);
const result_ty_id = try ng.resolveType(result_ty, .direct);
return try ng.constructComposite(result_ty_id, constituents);
}
fn airShuffleTwo(ng: *NavGen, inst: Air.Inst.Index) !?IdRef {
const pt = ng.pt;
const zcu = pt.zcu;
const gpa = zcu.gpa;
const unwrapped = ng.air.unwrapShuffleTwo(zcu, inst);
const mask = unwrapped.mask;
const result_ty = unwrapped.result_ty;
const elem_ty = result_ty.childType(zcu);
const elem_ty_id = try ng.resolveType(elem_ty, .direct);
const operand_a = try ng.resolve(unwrapped.operand_a);
const operand_b = try ng.resolve(unwrapped.operand_b);
const constituents = try gpa.alloc(IdRef, mask.len);
defer gpa.free(constituents);
for (constituents, mask) |*id, mask_elem| {
id.* = switch (mask_elem.unwrap()) {
.a_elem => |idx| try ng.extractVectorComponent(elem_ty, operand_a, idx),
.b_elem => |idx| try ng.extractVectorComponent(elem_ty, operand_b, idx),
.undef => try ng.spv.constUndef(elem_ty_id),
};
}
const result_ty_id = try ng.resolveType(result_ty, .direct);
return try ng.constructComposite(result_ty_id, constituents);
}
fn indicesToIds(self: *NavGen, indices: []const u32) ![]IdRef {