commit 88e98a0611b9fb41c1da026febac2467548bb129 (tree)
parent bae35bdf2d8919b60dee9a0af3afbdd93dd72b59
Author: Andrew Kelley <andrew@ziglang.org>
Date: Sat, 26 Mar 2022 00:33:22 -0400
Merge pull request #11289 from schmee/stage2-select
stage2: implement `@select`
Diffstat:
12 files changed, 239 insertions(+), 23 deletions(-)
diff --git a/src/Air.zig b/src/Air.zig
@@ -344,7 +344,7 @@ pub const Inst = struct {
/// to the storage for the variable. The local may be a const or a var.
/// Result type is always void.
/// Uses `pl_op`. The payload index is the variable name. It points to the extra
- /// array, reinterpreting the bytes there as a null-terminated string.
+ /// array, reinterpreting the bytes there as a null-terminated string.
dbg_var_ptr,
/// Same as `dbg_var_ptr` except the local is a const, not a var, and the
/// operand is the local's value.
@@ -553,6 +553,9 @@ pub const Inst = struct {
/// Constructs a vector by selecting elements from `a` and `b` based on `mask`.
/// Uses the `ty_pl` field with payload `Shuffle`.
shuffle,
+ /// Constructs a vector element-wise from `a` or `b` based on `pred`.
+ /// Uses the `pl_op` field with `pred` as operand, and payload `Bin`.
+ select,
/// Given dest ptr, value, and len, set all elements at dest to value.
/// Result type is always void.
@@ -1067,6 +1070,10 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index) Type {
.reduce => return air.typeOf(datas[inst].reduce.operand).childType(),
.mul_add => return air.typeOf(datas[inst].pl_op.operand),
+ .select => {
+ const extra = air.extraData(Air.Bin, datas[inst].pl_op.payload).data;
+ return air.typeOf(extra.lhs);
+ },
.add_with_overflow,
.sub_with_overflow,
diff --git a/src/Liveness.zig b/src/Liveness.zig
@@ -433,6 +433,11 @@ fn analyzeInst(
}
return extra_tombs.finish();
},
+ .select => {
+ const pl_op = inst_datas[inst].pl_op;
+ const extra = a.air.extraData(Air.Bin, pl_op.payload).data;
+ return trackOperands(a, new_set, inst, main_tomb, .{ pl_op.operand, extra.lhs, extra.rhs });
+ },
.shuffle => {
const extra = a.air.extraData(Air.Shuffle, inst_datas[inst].ty_pl.payload).data;
return trackOperands(a, new_set, inst, main_tomb, .{ extra.a, extra.b, .none });
diff --git a/src/Sema.zig b/src/Sema.zig
@@ -14890,8 +14890,91 @@ fn analyzeShuffle(
fn zirSelect(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
const inst_data = sema.code.instructions.items(.data)[inst].pl_node;
- const src = inst_data.src();
- return sema.fail(block, src, "TODO: Sema.zirSelect", .{});
+ const extra = sema.code.extraData(Zir.Inst.Select, inst_data.payload_index).data;
+ const target = sema.mod.getTarget();
+
+ const elem_ty_src: LazySrcLoc = .{ .node_offset_builtin_call_arg0 = inst_data.src_node };
+ const pred_src: LazySrcLoc = .{ .node_offset_builtin_call_arg1 = inst_data.src_node };
+ const a_src: LazySrcLoc = .{ .node_offset_builtin_call_arg2 = inst_data.src_node };
+ const b_src: LazySrcLoc = .{ .node_offset_builtin_call_arg3 = inst_data.src_node };
+
+ const elem_ty = try sema.resolveType(block, elem_ty_src, extra.elem_type);
+ try sema.checkVectorElemType(block, elem_ty_src, elem_ty);
+ const pred_uncoerced = sema.resolveInst(extra.pred);
+ const pred_ty = sema.typeOf(pred_uncoerced);
+
+ const vec_len_u64 = switch (try pred_ty.zigTypeTagOrPoison()) {
+ .Vector, .Array => pred_ty.arrayLen(),
+ else => return sema.fail(block, pred_src, "expected vector or array, found '{}'", .{pred_ty.fmt(target)}),
+ };
+ const vec_len = try sema.usizeCast(block, pred_src, vec_len_u64);
+
+ const bool_vec_ty = try Type.vector(sema.arena, vec_len, Type.bool);
+ const pred = try sema.coerce(block, bool_vec_ty, pred_uncoerced, pred_src);
+
+ const vec_ty = try Type.vector(sema.arena, vec_len, elem_ty);
+ const a = try sema.coerce(block, vec_ty, sema.resolveInst(extra.a), a_src);
+ const b = try sema.coerce(block, vec_ty, sema.resolveInst(extra.b), b_src);
+
+ const maybe_pred = try sema.resolveMaybeUndefVal(block, pred_src, pred);
+ const maybe_a = try sema.resolveMaybeUndefVal(block, a_src, a);
+ const maybe_b = try sema.resolveMaybeUndefVal(block, b_src, b);
+
+ const runtime_src = if (maybe_pred) |pred_val| rs: {
+ if (pred_val.isUndef()) return sema.addConstUndef(vec_ty);
+
+ if (maybe_a) |a_val| {
+ if (a_val.isUndef()) return sema.addConstUndef(vec_ty);
+
+ if (maybe_b) |b_val| {
+ if (b_val.isUndef()) return sema.addConstUndef(vec_ty);
+
+ var buf: Value.ElemValueBuffer = undefined;
+ const elems = try sema.gpa.alloc(Value, vec_len);
+ for (elems) |*elem, i| {
+ const pred_elem_val = pred_val.elemValueBuffer(i, &buf);
+ const should_choose_a = pred_elem_val.toBool();
+ if (should_choose_a) {
+ elem.* = a_val.elemValueBuffer(i, &buf);
+ } else {
+ elem.* = b_val.elemValueBuffer(i, &buf);
+ }
+ }
+
+ return sema.addConstant(
+ vec_ty,
+ try Value.Tag.aggregate.create(sema.arena, elems),
+ );
+ } else {
+ break :rs b_src;
+ }
+ } else {
+ if (maybe_b) |b_val| {
+ if (b_val.isUndef()) return sema.addConstUndef(vec_ty);
+ }
+ break :rs a_src;
+ }
+ } else rs: {
+ if (maybe_a) |a_val| {
+ if (a_val.isUndef()) return sema.addConstUndef(vec_ty);
+ }
+ if (maybe_b) |b_val| {
+ if (b_val.isUndef()) return sema.addConstUndef(vec_ty);
+ }
+ break :rs pred_src;
+ };
+
+ try sema.requireRuntimeBlock(block, runtime_src);
+ return block.addInst(.{
+ .tag = .select,
+ .data = .{ .pl_op = .{
+ .operand = pred,
+ .payload = try block.sema.addExtra(Air.Bin{
+ .lhs = a,
+ .rhs = b,
+ }),
+ } },
+ });
}
fn zirAtomicLoad(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
diff --git a/src/arch/aarch64/CodeGen.zig b/src/arch/aarch64/CodeGen.zig
@@ -633,6 +633,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
.tag_name => try self.airTagName(inst),
.error_name => try self.airErrorName(inst),
.splat => try self.airSplat(inst),
+ .select => try self.airSelect(inst),
.shuffle => try self.airShuffle(inst),
.reduce => try self.airReduce(inst),
.aggregate_init => try self.airAggregateInit(inst),
@@ -3666,6 +3667,13 @@ fn airSplat(self: *Self, inst: Air.Inst.Index) !void {
return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
}
+fn airSelect(self: *Self, inst: Air.Inst.Index) !void {
+ const pl_op = self.air.instructions.items(.data)[inst].pl_op;
+ const extra = self.air.extraData(Air.Bin, pl_op.payload).data;
+ const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airSelect for {}", .{self.target.cpu.arch});
+ return self.finishAir(inst, result, .{ pl_op.operand, extra.lhs, extra.rhs });
+}
+
fn airShuffle(self: *Self, inst: Air.Inst.Index) !void {
const ty_op = self.air.instructions.items(.data)[inst].ty_op;
const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airShuffle for {}", .{self.target.cpu.arch});
diff --git a/src/arch/arm/CodeGen.zig b/src/arch/arm/CodeGen.zig
@@ -630,6 +630,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
.tag_name => try self.airTagName(inst),
.error_name => try self.airErrorName(inst),
.splat => try self.airSplat(inst),
+ .select => try self.airSelect(inst),
.shuffle => try self.airShuffle(inst),
.reduce => try self.airReduce(inst),
.aggregate_init => try self.airAggregateInit(inst),
@@ -4323,6 +4324,13 @@ fn airSplat(self: *Self, inst: Air.Inst.Index) !void {
return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
}
+fn airSelect(self: *Self, inst: Air.Inst.Index) !void {
+ const pl_op = self.air.instructions.items(.data)[inst].pl_op;
+ const extra = self.air.extraData(Air.Bin, pl_op.payload).data;
+ const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airSelect for arm", .{});
+ return self.finishAir(inst, result, .{ pl_op.operand, extra.lhs, extra.rhs });
+}
+
fn airShuffle(self: *Self, inst: Air.Inst.Index) !void {
const ty_op = self.air.instructions.items(.data)[inst].ty_op;
const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airShuffle for arm", .{});
diff --git a/src/arch/riscv64/CodeGen.zig b/src/arch/riscv64/CodeGen.zig
@@ -600,6 +600,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
.tag_name => try self.airTagName(inst),
.error_name => try self.airErrorName(inst),
.splat => try self.airSplat(inst),
+ .select => try self.airSelect(inst),
.shuffle => try self.airShuffle(inst),
.reduce => try self.airReduce(inst),
.aggregate_init => try self.airAggregateInit(inst),
@@ -2396,6 +2397,13 @@ fn airSplat(self: *Self, inst: Air.Inst.Index) !void {
return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
}
+fn airSelect(self: *Self, inst: Air.Inst.Index) !void {
+ const pl_op = self.air.instructions.items(.data)[inst].pl_op;
+ const extra = self.air.extraData(Air.Bin, pl_op.payload).data;
+ const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airSelect for riscv64", .{});
+ return self.finishAir(inst, result, .{ pl_op.operand, extra.lhs, extra.rhs });
+}
+
fn airShuffle(self: *Self, inst: Air.Inst.Index) !void {
const ty_op = self.air.instructions.items(.data)[inst].ty_op;
const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airShuffle for riscv64", .{});
diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig
@@ -1371,6 +1371,7 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue {
.ret_ptr => self.airRetPtr(inst),
.ret_load => self.airRetLoad(inst),
.splat => self.airSplat(inst),
+ .select => self.airSelect(inst),
.shuffle => self.airShuffle(inst),
.reduce => self.airReduce(inst),
.aggregate_init => self.airAggregateInit(inst),
@@ -3265,6 +3266,16 @@ fn airSplat(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
return self.fail("TODO: Implement wasm airSplat", .{});
}
+fn airSelect(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
+ if (self.liveness.isUnused(inst)) return WValue{ .none = {} };
+
+ const pl_op = self.air.instructions.items(.data)[inst].pl_op;
+ const operand = try self.resolveInst(pl_op.operand);
+
+ _ = operand;
+ return self.fail("TODO: Implement wasm airSelect", .{});
+}
+
fn airShuffle(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
if (self.liveness.isUnused(inst)) return WValue{ .none = {} };
diff --git a/src/arch/x86_64/CodeGen.zig b/src/arch/x86_64/CodeGen.zig
@@ -714,6 +714,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void {
.tag_name => try self.airTagName(inst),
.error_name => try self.airErrorName(inst),
.splat => try self.airSplat(inst),
+ .select => try self.airSelect(inst),
.shuffle => try self.airShuffle(inst),
.reduce => try self.airReduce(inst),
.aggregate_init => try self.airAggregateInit(inst),
@@ -5624,6 +5625,13 @@ fn airSplat(self: *Self, inst: Air.Inst.Index) !void {
return self.finishAir(inst, result, .{ ty_op.operand, .none, .none });
}
+fn airSelect(self: *Self, inst: Air.Inst.Index) !void {
+ const pl_op = self.air.instructions.items(.data)[inst].pl_op;
+ const extra = self.air.extraData(Air.Bin, pl_op.payload).data;
+ const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airSelect for x86_64", .{});
+ return self.finishAir(inst, result, .{ pl_op.operand, extra.lhs, extra.rhs });
+}
+
fn airShuffle(self: *Self, inst: Air.Inst.Index) !void {
const ty_op = self.air.instructions.items(.data)[inst].ty_op;
const result: MCValue = if (self.liveness.isUnused(inst)) .dead else return self.fail("TODO implement airShuffle for x86_64", .{});
diff --git a/src/codegen/c.zig b/src/codegen/c.zig
@@ -1825,6 +1825,7 @@ fn genBody(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, OutO
.tag_name => try airTagName(f, inst),
.error_name => try airErrorName(f, inst),
.splat => try airSplat(f, inst),
+ .select => try airSelect(f, inst),
.shuffle => try airShuffle(f, inst),
.reduce => try airReduce(f, inst),
.aggregate_init => try airAggregateInit(f, inst),
@@ -3794,6 +3795,21 @@ fn airSplat(f: *Function, inst: Air.Inst.Index) !CValue {
return f.fail("TODO: C backend: implement airSplat", .{});
}
+fn airSelect(f: *Function, inst: Air.Inst.Index) !CValue {
+ if (f.liveness.isUnused(inst)) return CValue.none;
+
+ const inst_ty = f.air.typeOfIndex(inst);
+ const ty_pl = f.air.instructions.items(.data)[inst].ty_pl;
+
+ const writer = f.object.writer();
+ const local = try f.allocLocal(inst_ty, .Const);
+ try writer.writeAll(" = ");
+
+ _ = local;
+ _ = ty_pl;
+ return f.fail("TODO: C backend: implement airSelect", .{});
+}
+
fn airShuffle(f: *Function, inst: Air.Inst.Index) !CValue {
if (f.liveness.isUnused(inst)) return CValue.none;
diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig
@@ -3444,6 +3444,7 @@ pub const FuncGen = struct {
.tag_name => try self.airTagName(inst),
.error_name => try self.airErrorName(inst),
.splat => try self.airSplat(inst),
+ .select => try self.airSelect(inst),
.shuffle => try self.airShuffle(inst),
.reduce => try self.airReduce(inst),
.aggregate_init => try self.airAggregateInit(inst),
@@ -6355,6 +6356,18 @@ pub const FuncGen = struct {
return self.builder.buildShuffleVector(op_vector, undef_vector, mask_llvm_ty.constNull(), "");
}
+ fn airSelect(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
+ if (self.liveness.isUnused(inst)) return null;
+
+ const pl_op = self.air.instructions.items(.data)[inst].pl_op;
+ const extra = self.air.extraData(Air.Bin, pl_op.payload).data;
+ const pred = try self.resolveInst(pl_op.operand);
+ const a = try self.resolveInst(extra.lhs);
+ const b = try self.resolveInst(extra.rhs);
+
+ return self.builder.buildSelect(pred, a, b, "");
+ }
+
fn airShuffle(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value {
if (self.liveness.isUnused(inst)) return null;
diff --git a/src/print_air.zig b/src/print_air.zig
@@ -264,6 +264,7 @@ const Writer = struct {
.wasm_memory_size => try w.writeWasmMemorySize(s, inst),
.wasm_memory_grow => try w.writeWasmMemoryGrow(s, inst),
.mul_add => try w.writeMulAdd(s, inst),
+ .select => try w.writeSelect(s, inst),
.shuffle => try w.writeShuffle(s, inst),
.reduce => try w.writeReduce(s, inst),
.cmp_vector => try w.writeCmpVector(s, inst),
@@ -396,6 +397,19 @@ const Writer = struct {
try s.print(", mask {d}, len {d}", .{ extra.mask, extra.mask_len });
}
+ fn writeSelect(w: *Writer, s: anytype, inst: Air.Inst.Index) @TypeOf(s).Error!void {
+ const pl_op = w.air.instructions.items(.data)[inst].pl_op;
+ const extra = w.air.extraData(Air.Bin, pl_op.payload).data;
+
+ const elem_ty = w.air.typeOfIndex(inst).childType();
+ try s.print("{}, ", .{elem_ty.fmtDebug()});
+ try w.writeOperand(s, inst, 0, pl_op.operand);
+ try s.writeAll(", ");
+ try w.writeOperand(s, inst, 1, extra.lhs);
+ try s.writeAll(", ");
+ try w.writeOperand(s, inst, 2, extra.rhs);
+ }
+
fn writeReduce(w: *Writer, s: anytype, inst: Air.Inst.Index) @TypeOf(s).Error!void {
const reduce = w.air.instructions.items(.data)[inst].reduce;
diff --git a/test/behavior/select.zig b/test/behavior/select.zig
@@ -3,24 +3,59 @@ const builtin = @import("builtin");
const mem = std.mem;
const expect = std.testing.expect;
-test "@select" {
- if (@import("builtin").zig_backend != .stage1) return error.SkipZigTest; // TODO
-
- const S = struct {
- fn doTheTest() !void {
- var a: @Vector(4, bool) = [4]bool{ true, false, true, false };
- var b: @Vector(4, i32) = [4]i32{ -1, 4, 999, -31 };
- var c: @Vector(4, i32) = [4]i32{ -5, 1, 0, 1234 };
- var abc = @select(i32, a, b, c);
- try expect(mem.eql(i32, &@as([4]i32, abc), &[4]i32{ -1, 1, 999, 1234 }));
-
- var x: @Vector(4, bool) = [4]bool{ false, false, false, true };
- var y: @Vector(4, f32) = [4]f32{ 0.001, 33.4, 836, -3381.233 };
- var z: @Vector(4, f32) = [4]f32{ 0.0, 312.1, -145.9, 9993.55 };
- var xyz = @select(f32, x, y, z);
- try expect(mem.eql(f32, &@as([4]f32, xyz), &[4]f32{ 0.0, 312.1, -145.9, -3381.233 }));
- }
- };
- try S.doTheTest();
- comptime try S.doTheTest();
+test "@select vectors" {
+ if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+ if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+ if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+ if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+ if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+
+ comptime try selectVectors();
+ try selectVectors();
+}
+
+fn selectVectors() !void {
+ var a = @Vector(4, bool){ true, false, true, false };
+ var b = @Vector(4, i32){ -1, 4, 999, -31 };
+ var c = @Vector(4, i32){ -5, 1, 0, 1234 };
+ var abc = @select(i32, a, b, c);
+ try expect(abc[0] == -1);
+ try expect(abc[1] == 1);
+ try expect(abc[2] == 999);
+ try expect(abc[3] == 1234);
+
+ var x = @Vector(4, bool){ false, false, false, true };
+ var y = @Vector(4, f32){ 0.001, 33.4, 836, -3381.233 };
+ var z = @Vector(4, f32){ 0.0, 312.1, -145.9, 9993.55 };
+ var xyz = @select(f32, x, y, z);
+ try expect(mem.eql(f32, &@as([4]f32, xyz), &[4]f32{ 0.0, 312.1, -145.9, -3381.233 }));
+}
+
+test "@select arrays" {
+ if (builtin.zig_backend == .stage1) return error.SkipZigTest; // TODO
+ if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
+ if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
+ if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO
+ if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
+ if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
+
+ comptime try selectArrays();
+ try selectArrays();
+}
+
+fn selectArrays() !void {
+ var a = [4]bool{ false, true, false, true };
+ var b = [4]usize{ 0, 1, 2, 3 };
+ var c = [4]usize{ 4, 5, 6, 7 };
+ var abc = @select(usize, a, b, c);
+ try expect(abc[0] == 4);
+ try expect(abc[1] == 1);
+ try expect(abc[2] == 6);
+ try expect(abc[3] == 3);
+
+ var x = [4]bool{ false, false, false, true };
+ var y = [4]f32{ 0.001, 33.4, 836, -3381.233 };
+ var z = [4]f32{ 0.0, 312.1, -145.9, 9993.55 };
+ var xyz = @select(f32, x, y, z);
+ try expect(mem.eql(f32, &@as([4]f32, xyz), &[4]f32{ 0.0, 312.1, -145.9, -3381.233 }));
}