add inline switch union tag captures

This commit is contained in:
Veikka Tuominen
2022-09-26 17:30:24 +03:00
parent 5baaf90e3c
commit 0e77259f44
12 changed files with 186 additions and 59 deletions

View File

@@ -3100,7 +3100,7 @@ const Parser = struct {
return identifier;
}
/// SwitchProng <- KEYWORD_inline? SwitchCase EQUALRARROW PtrPayload? AssignExpr
/// SwitchProng <- KEYWORD_inline? SwitchCase EQUALRARROW PtrIndexPayload? AssignExpr
/// SwitchCase
/// <- SwitchItem (COMMA SwitchItem)* COMMA?
/// / KEYWORD_else
@@ -3123,7 +3123,7 @@ const Parser = struct {
}
}
const arrow_token = try p.expectToken(.equal_angle_bracket_right);
_ = try p.parsePtrPayload();
_ = try p.parsePtrIndexPayload();
const items = p.scratch.items[scratch_top..];
switch (items.len) {

View File

@@ -3276,6 +3276,8 @@ test "zig fmt: switch" {
\\ switch (u) {
\\ Union.Int => |int| {},
\\ Union.Float => |*float| unreachable,
\\ 1 => |a, b| unreachable,
\\ 2 => |*a, b| unreachable,
\\ }
\\}
\\

View File

@@ -1541,13 +1541,17 @@ fn renderSwitchCase(
if (switch_case.payload_token) |payload_token| {
try renderToken(ais, tree, payload_token - 1, .none); // pipe
const ident = payload_token + @boolToInt(token_tags[payload_token] == .asterisk);
if (token_tags[payload_token] == .asterisk) {
try renderToken(ais, tree, payload_token, .none); // asterisk
try renderToken(ais, tree, payload_token + 1, .none); // identifier
try renderToken(ais, tree, payload_token + 2, pre_target_space); // pipe
}
try renderToken(ais, tree, ident, .none); // identifier
if (token_tags[ident + 1] == .comma) {
try renderToken(ais, tree, ident + 1, .space); // ,
try renderToken(ais, tree, ident + 2, .none); // identifier
try renderToken(ais, tree, ident + 3, pre_target_space); // pipe
} else {
try renderToken(ais, tree, payload_token, .none); // identifier
try renderToken(ais, tree, payload_token + 1, pre_target_space); // pipe
try renderToken(ais, tree, ident + 1, pre_target_space); // pipe
}
}

View File

@@ -2373,6 +2373,7 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As
.switch_capture_ref,
.switch_capture_multi,
.switch_capture_multi_ref,
.switch_capture_tag,
.struct_init_empty,
.struct_init,
.struct_init_ref,
@@ -6378,8 +6379,12 @@ fn switchExpr(
var dbg_var_name: ?u32 = null;
var dbg_var_inst: Zir.Inst.Ref = undefined;
var dbg_var_tag_name: ?u32 = null;
var dbg_var_tag_inst: Zir.Inst.Ref = undefined;
var capture_inst: Zir.Inst.Index = 0;
var tag_inst: Zir.Inst.Index = 0;
var capture_val_scope: Scope.LocalVal = undefined;
var tag_scope: Scope.LocalVal = undefined;
const sub_scope = blk: {
const payload_token = case.payload_token orelse break :blk &case_scope.base;
const ident = if (token_tags[payload_token] == .asterisk)
@@ -6387,59 +6392,96 @@ fn switchExpr(
else
payload_token;
const is_ptr = ident != payload_token;
if (mem.eql(u8, tree.tokenSlice(ident), "_")) {
const ident_slice = tree.tokenSlice(ident);
var payload_sub_scope: *Scope = undefined;
if (mem.eql(u8, ident_slice, "_")) {
if (is_ptr) {
return astgen.failTok(payload_token, "pointer modifier invalid on discard", .{});
}
break :blk &case_scope.base;
}
if (case_node == special_node) {
const capture_tag: Zir.Inst.Tag = if (is_ptr)
.switch_capture_ref
else
.switch_capture;
capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len);
try astgen.instructions.append(gpa, .{
.tag = capture_tag,
.data = .{
.switch_capture = .{
.switch_inst = switch_block,
// Max int communicates that this is the else/underscore prong.
.prong_index = std.math.maxInt(u32),
},
},
});
payload_sub_scope = &case_scope.base;
} else {
const is_multi_case_bits: u2 = @boolToInt(is_multi_case);
const is_ptr_bits: u2 = @boolToInt(is_ptr);
const capture_tag: Zir.Inst.Tag = switch ((is_multi_case_bits << 1) | is_ptr_bits) {
0b00 => .switch_capture,
0b01 => .switch_capture_ref,
0b10 => .switch_capture_multi,
0b11 => .switch_capture_multi_ref,
if (case_node == special_node) {
const capture_tag: Zir.Inst.Tag = if (is_ptr)
.switch_capture_ref
else
.switch_capture;
capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len);
try astgen.instructions.append(gpa, .{
.tag = capture_tag,
.data = .{
.switch_capture = .{
.switch_inst = switch_block,
// Max int communicates that this is the else/underscore prong.
.prong_index = std.math.maxInt(u32),
},
},
});
} else {
const is_multi_case_bits: u2 = @boolToInt(is_multi_case);
const is_ptr_bits: u2 = @boolToInt(is_ptr);
const capture_tag: Zir.Inst.Tag = switch ((is_multi_case_bits << 1) | is_ptr_bits) {
0b00 => .switch_capture,
0b01 => .switch_capture_ref,
0b10 => .switch_capture_multi,
0b11 => .switch_capture_multi_ref,
};
const capture_index = if (is_multi_case) multi_case_index else scalar_case_index;
capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len);
try astgen.instructions.append(gpa, .{
.tag = capture_tag,
.data = .{ .switch_capture = .{
.switch_inst = switch_block,
.prong_index = capture_index,
} },
});
}
const capture_name = try astgen.identAsString(ident);
try astgen.detectLocalShadowing(&case_scope.base, capture_name, ident, ident_slice);
capture_val_scope = .{
.parent = &case_scope.base,
.gen_zir = &case_scope,
.name = capture_name,
.inst = indexToRef(capture_inst),
.token_src = payload_token,
.id_cat = .@"capture",
};
const capture_index = if (is_multi_case) multi_case_index else scalar_case_index;
capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len);
try astgen.instructions.append(gpa, .{
.tag = capture_tag,
.data = .{ .switch_capture = .{
.switch_inst = switch_block,
.prong_index = capture_index,
} },
});
dbg_var_name = capture_name;
dbg_var_inst = indexToRef(capture_inst);
payload_sub_scope = &capture_val_scope.base;
}
const capture_name = try astgen.identAsString(ident);
capture_val_scope = .{
.parent = &case_scope.base,
const tag_token = if (token_tags[ident + 1] == .comma)
ident + 2
else
break :blk payload_sub_scope;
const tag_slice = tree.tokenSlice(tag_token);
if (mem.eql(u8, tag_slice, "_")) {
return astgen.failTok(tag_token, "discard of tag capture; omit it instead", .{});
} else if (case.inline_token == null) {
return astgen.failTok(tag_token, "tag capture on non-inline prong", .{});
}
const tag_name = try astgen.identAsString(tag_token);
try astgen.detectLocalShadowing(payload_sub_scope, tag_name, tag_token, tag_slice);
tag_inst = @intCast(Zir.Inst.Index, astgen.instructions.len);
try astgen.instructions.append(gpa, .{
.tag = .switch_capture_tag,
.data = .{ .un_tok = .{
.operand = cond,
.src_tok = case_scope.tokenIndexToRelative(tag_token),
} },
});
tag_scope = .{
.parent = payload_sub_scope,
.gen_zir = &case_scope,
.name = capture_name,
.inst = indexToRef(capture_inst),
.token_src = payload_token,
.id_cat = .@"capture",
.name = tag_name,
.inst = indexToRef(tag_inst),
.token_src = tag_token,
.id_cat = .@"switch tag capture",
};
dbg_var_name = capture_name;
dbg_var_inst = indexToRef(capture_inst);
break :blk &capture_val_scope.base;
dbg_var_tag_name = tag_name;
dbg_var_tag_inst = indexToRef(tag_inst);
break :blk &tag_scope.base;
};
const header_index = @intCast(u32, payloads.items.len);
@@ -6494,10 +6536,14 @@ fn switchExpr(
defer case_scope.unstack();
if (capture_inst != 0) try case_scope.instructions.append(gpa, capture_inst);
if (tag_inst != 0) try case_scope.instructions.append(gpa, tag_inst);
try case_scope.addDbgBlockBegin();
if (dbg_var_name) |some| {
try case_scope.addDbgVar(.dbg_var_val, some, dbg_var_inst);
}
if (dbg_var_tag_name) |some| {
try case_scope.addDbgVar(.dbg_var_val, some, dbg_var_tag_inst);
}
const case_result = try expr(&case_scope, sub_scope, block_scope.break_result_loc, case.ast.target_expr);
try checkUsed(parent_gz, &case_scope.base, sub_scope);
try case_scope.addDbgBlockEnd();
@@ -10073,6 +10119,7 @@ const Scope = struct {
@"local constant",
@"local variable",
@"loop index capture",
@"switch tag capture",
@"capture",
};

View File

@@ -799,6 +799,7 @@ fn analyzeBodyInner(
.switch_capture_ref => try sema.zirSwitchCapture(block, inst, false, true),
.switch_capture_multi => try sema.zirSwitchCapture(block, inst, true, false),
.switch_capture_multi_ref => try sema.zirSwitchCapture(block, inst, true, true),
.switch_capture_tag => try sema.zirSwitchCaptureTag(block, inst),
.type_info => try sema.zirTypeInfo(block, inst),
.size_of => try sema.zirSizeOf(block, inst),
.bit_size_of => try sema.zirBitSizeOf(block, inst),
@@ -9164,6 +9165,33 @@ fn zirSwitchCapture(
}
}
fn zirSwitchCaptureTag(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
const zir_datas = sema.code.instructions.items(.data);
const inst_data = zir_datas[inst].un_tok;
const src = inst_data.src();
const switch_tag = sema.code.instructions.items(.tag)[Zir.refToIndex(inst_data.operand).?];
const is_ref = switch_tag == .switch_cond_ref;
const cond_data = zir_datas[Zir.refToIndex(inst_data.operand).?].un_node;
const operand_ptr = try sema.resolveInst(cond_data.operand);
const operand_ptr_ty = sema.typeOf(operand_ptr);
const operand_ty = if (is_ref) operand_ptr_ty.childType() else operand_ptr_ty;
if (operand_ty.zigTypeTag() != .Union) {
const msg = msg: {
const msg = try sema.errMsg(block, src, "cannot capture tag of non-union type '{}'", .{
operand_ty.fmt(sema.mod),
});
errdefer msg.destroy(sema.gpa);
try sema.addDeclaredHereNote(msg, operand_ty);
break :msg msg;
};
return sema.failWithOwnedErrorMsg(msg);
}
return block.inline_case_capture;
}
fn zirSwitchCond(
sema: *Sema,
block: *Block,

View File

@@ -683,6 +683,9 @@ pub const Inst = struct {
/// Result is a pointer to the value.
/// Uses the `switch_capture` field.
switch_capture_multi_ref,
/// Produces the capture value for an inline switch prong tag capture.
/// Uses the `un_tok` field.
switch_capture_tag,
/// Given a
/// *A returns *A
/// *E!A returns *A
@@ -1128,6 +1131,7 @@ pub const Inst = struct {
.switch_capture_ref,
.switch_capture_multi,
.switch_capture_multi_ref,
.switch_capture_tag,
.switch_block,
.switch_cond,
.switch_cond_ref,
@@ -1422,6 +1426,7 @@ pub const Inst = struct {
.switch_capture_ref,
.switch_capture_multi,
.switch_capture_multi_ref,
.switch_capture_tag,
.switch_block,
.switch_cond,
.switch_cond_ref,
@@ -1681,6 +1686,7 @@ pub const Inst = struct {
.switch_capture_ref = .switch_capture,
.switch_capture_multi = .switch_capture,
.switch_capture_multi_ref = .switch_capture,
.switch_capture_tag = .un_tok,
.array_base_ptr = .un_node,
.field_base_ptr = .un_node,
.validate_array_init_ty = .pl_node,

View File

@@ -2159,7 +2159,7 @@ const RegisterOrMemory = union(enum) {
/// Returns size in bits.
fn size(reg_or_mem: RegisterOrMemory) u64 {
return switch (reg_or_mem) {
.register => |reg| reg.size(),
.register => |register| register.size(),
.memory => |memory| memory.size(),
};
}

View File

@@ -237,6 +237,7 @@ const Writer = struct {
.ret_tok,
.ensure_err_payload_void,
.closure_capture,
.switch_capture_tag,
=> try self.writeUnTok(stream, inst),
.bool_br_and,

View File

@@ -2306,17 +2306,17 @@ static Optional<PtrIndexPayload> ast_parse_ptr_index_payload(ParseContext *pc) {
return Optional<PtrIndexPayload>::some(res);
}
// SwitchProng <- KEYWORD_inline? SwitchCase EQUALRARROW PtrPayload? AssignExpr
// SwitchProng <- KEYWORD_inline? SwitchCase EQUALRARROW PtrIndexPayload? AssignExpr
static AstNode *ast_parse_switch_prong(ParseContext *pc) {
AstNode *res = ast_parse_switch_case(pc);
if (res == nullptr)
return nullptr;
expect_token(pc, TokenIdFatArrow);
Optional<PtrPayload> opt_payload = ast_parse_ptr_payload(pc);
Optional<PtrIndexPayload> opt_payload = ast_parse_ptr_index_payload(pc);
AstNode *expr = ast_expect(pc, ast_parse_assign_expr);
PtrPayload payload;
PtrIndexPayload payload;
assert(res->type == NodeTypeSwitchProng);
res->data.switch_prong.expr = expr;
if (opt_payload.unwrap(&payload)) {

View File

@@ -47,11 +47,21 @@ test "inline switch unions" {
var x: U = .a;
switch (x) {
inline .a, .b => |aorb| {
try expect(@TypeOf(aorb) == void or @TypeOf(aorb) == u2);
inline .a, .b => |aorb, tag| {
if (tag == .a) {
try expect(@TypeOf(aorb) == void);
} else {
try expect(tag == .b);
try expect(@TypeOf(aorb) == u2);
}
},
inline .c, .d => |cord| {
try expect(@TypeOf(cord) == u3 or @TypeOf(cord) == u4);
inline .c, .d => |cord, tag| {
if (tag == .c) {
try expect(@TypeOf(cord) == u3);
} else {
try expect(tag == .d);
try expect(@TypeOf(cord) == u4);
}
},
}
}

View File

@@ -0,0 +1,15 @@
const E = enum { a, b, c, d };
pub export fn entry() void {
var x: E = .a;
switch (x) {
inline .a, .b => |aorb, d| @compileLog(aorb, d),
inline .c, .d => |*cord| @compileLog(cord),
}
}
// error
// backend=stage2
// target=native
//
// :5:33: error: cannot capture tag of non-union type 'tmp.E'
// :1:11: note: enum declared here

View File

@@ -0,0 +1,14 @@
const E = enum { a, b, c, d };
pub export fn entry() void {
var x: E = .a;
switch (x) {
.a, .b => |aorb, d| @compileLog(aorb, d),
inline .c, .d => |*cord| @compileLog(cord),
}
}
// error
// backend=stage2
// target=native
//
// :5:26: error: tag capture on non-inline prong