From 623d5f442c832ec0ea2a07aba73b8e2eae57191c Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Sun, 28 Mar 2021 23:12:26 -0700 Subject: [PATCH] stage2: guidance on how to implement switch expressions Here's what I think the ZIR should be. AstGen is not yet implemented to match this, and the main implementation of analyzeSwitch in Sema is not yet implemented to match it either. Here are some example byte size reductions from master branch, with the ZIR memory layout from this commit: ``` switch (foo) { a => 1, b => 2, c => 3, d => 4, } ``` 184 bytes (master) => 40 bytes (this branch) ``` switch (foo) { a, b => 1, c..d, e, f => 2, g => 3, else => 4, } ``` 240 bytes (master) => 80 bytes (this branch) --- src/AstGen.zig | 19 ++++-- src/Module.zig | 9 +++ src/Sema.zig | 150 +++++++++++++++++++++++++++++------------------ src/zir.zig | 154 ++++++++++++++++++++++++++++++++++++------------- 4 files changed, 230 insertions(+), 102 deletions(-) diff --git a/src/AstGen.zig b/src/AstGen.zig index b904d58cd5..d91a0966ef 100644 --- a/src/AstGen.zig +++ b/src/AstGen.zig @@ -1257,6 +1257,18 @@ fn blockExprStmts( .break_inline, .condbr, .condbr_inline, + .switch_br, + .switch_br_range, + .switch_br_else, + .switch_br_else_range, + .switch_br_underscore, + .switch_br_underscore_range, + .switch_br_ref, + .switch_br_ref_range, + .switch_br_ref_else, + .switch_br_ref_else_range, + .switch_br_ref_underscore, + .switch_br_ref_underscore_range, .compile_error, .ret_node, .ret_tok, @@ -2536,20 +2548,19 @@ fn switchExpr( rl: ResultLoc, switch_node: ast.Node.Index, ) InnerError!zir.Inst.Ref { - if (true) @panic("TODO update for zir-memory-layout"); const tree = parent_gz.tree(); const node_datas = tree.nodes.items(.data); const main_tokens = tree.nodes.items(.main_token); const token_tags = tree.tokens.items(.tag); const node_tags = tree.nodes.items(.tag); + if (true) @panic("TODO rework for zir-memory-layout branch"); + const switch_token = main_tokens[switch_node]; const target_node = 
node_datas[switch_node].lhs; const extra = tree.extraData(node_datas[switch_node].rhs, ast.Node.SubRange); const case_nodes = tree.extra_data[extra.start..extra.end]; - const switch_src = token_starts[switch_token]; - var block_scope: GenZir = .{ .parent = scope, .decl = scope.ownerDecl().?, @@ -2627,7 +2638,7 @@ fn switchExpr( const msg = msg: { const msg = try mod.errMsg( scope, - switch_src, + parent_gz.nodeSrcLoc(switch_node), "else and '_' prong in switch expression", .{}, ); diff --git a/src/Module.zig b/src/Module.zig index c92df1aaae..3309a10b30 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -1532,6 +1532,7 @@ pub const SrcLoc = struct { .node_offset_bin_op, .node_offset_bin_lhs, .node_offset_bin_rhs, + .node_offset_switch_operand, => src_loc.container.decl.container.file_scope, }; } @@ -1663,6 +1664,7 @@ pub const SrcLoc = struct { const token_starts = tree.tokens.items(.start); return token_starts[tok_index]; }, + .node_offset_switch_operand => @panic("TODO"), } } }; @@ -1795,6 +1797,11 @@ pub const LazySrcLoc = union(enum) { /// which points to a binary expression AST node. Next, nagivate to the RHS. /// The Decl is determined contextually. node_offset_bin_rhs: i32, + /// The source location points to the operand of a switch expression, found + /// by taking this AST node index offset from the containing Decl AST node, + /// which points to a switch expression AST node. Next, navigate to the operand. + /// The Decl is determined contextually. + node_offset_switch_operand: i32, /// Upgrade to a `SrcLoc` based on the `Decl` or file in the provided scope. pub fn toSrcLoc(lazy: LazySrcLoc, scope: *Scope) SrcLoc { @@ -1828,6 +1835,7 @@ pub const LazySrcLoc = union(enum) { .node_offset_bin_op, .node_offset_bin_lhs, .node_offset_bin_rhs, + .node_offset_switch_operand, => .{ .container = .{ .decl = scope.srcDecl().? 
}, .lazy = lazy, @@ -1867,6 +1875,7 @@ pub const LazySrcLoc = union(enum) { .node_offset_bin_op, .node_offset_bin_lhs, .node_offset_bin_rhs, + .node_offset_switch_operand, => .{ .container = .{ .decl = decl }, .lazy = lazy, diff --git a/src/Sema.zig b/src/Sema.zig index c18c472930..20f7c9d9ca 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -229,10 +229,6 @@ pub fn analyzeBody( .typeof => try sema.zirTypeof(block, inst), .typeof_peer => try sema.zirTypeofPeer(block, inst), .xor => try sema.zirBitwise(block, inst, .xor), - // TODO - //.switchbr => try sema.zirSwitchBr(block, inst, false), - //.switchbr_ref => try sema.zirSwitchBr(block, inst, true), - //.switch_range => try sema.zirSwitchRange(block, inst), // Instructions that we know to *always* be noreturn based solely on their tag. // These functions match the return type of analyzeBody so that we can @@ -246,6 +242,18 @@ pub fn analyzeBody( .ret_tok => return sema.zirRetTok(block, inst, false), .@"unreachable" => return sema.zirUnreachable(block, inst), .repeat => return sema.zirRepeat(block, inst), + .switch_br => return sema.zirSwitchBr(block, inst, false, .full), + .switch_br_range => return sema.zirSwitchBrRange(block, inst, false, .full), + .switch_br_else => return sema.zirSwitchBr(block, inst, false, .@"else"), + .switch_br_else_range => return sema.zirSwitchBrRange(block, inst, false, .@"else"), + .switch_br_underscore => return sema.zirSwitchBr(block, inst, false, .under), + .switch_br_underscore_range => return sema.zirSwitchBrRange(block, inst, false, .under), + .switch_br_ref => return sema.zirSwitchBr(block, inst, true, .full), + .switch_br_ref_range => return sema.zirSwitchBrRange(block, inst, true, .full), + .switch_br_ref_else => return sema.zirSwitchBr(block, inst, true, .@"else"), + .switch_br_ref_else_range => return sema.zirSwitchBrRange(block, inst, true, .@"else"), + .switch_br_ref_underscore => return sema.zirSwitchBr(block, inst, true, .under), + .switch_br_ref_underscore_range => return 
sema.zirSwitchBrRange(block, inst, true, .under), // Instructions that we know can *never* be noreturn based solely on // their tag. We avoid needlessly checking if they are noreturn and @@ -2197,54 +2205,82 @@ fn zirSliceSentinel(sema: *Sema, block: *Scope.Block, inst: zir.Inst.Index) Inne return sema.analyzeSlice(block, src, array_ptr, start, end, sentinel, sentinel_src); } -fn zirSwitchRange(sema: *Sema, block: *Scope.Block, inst: zir.Inst.Index) InnerError!*Inst { - const tracy = trace(@src()); - defer tracy.end(); - - const src: LazySrcLoc = .todo; - const bin_inst = sema.code.instructions.items(.data)[inst].bin; - const start = try sema.resolveInst(bin_inst.lhs); - const end = try sema.resolveInst(bin_inst.rhs); - - switch (start.ty.zigTypeTag()) { - .Int, .ComptimeInt => {}, - else => return sema.mod.constVoid(sema.arena, .unneeded), - } - switch (end.ty.zigTypeTag()) { - .Int, .ComptimeInt => {}, - else => return sema.mod.constVoid(sema.arena, .unneeded), - } - // .switch_range must be inside a comptime scope - const start_val = start.value().?; - const end_val = end.value().?; - if (start_val.compare(.gte, end_val)) { - return sema.mod.fail(&block.base, src, "range start value must be smaller than the end value", .{}); - } - return sema.mod.constVoid(sema.arena, .unneeded); -} +const ElseProng = enum { full, @"else", under }; fn zirSwitchBr( sema: *Sema, - parent_block: *Scope.Block, + block: *Scope.Block, inst: zir.Inst.Index, - ref: bool, -) InnerError!zir.Inst.Ref { + is_ref: bool, + else_prong: ElseProng, +) InnerError!zir.Inst.Index { const tracy = trace(@src()); defer tracy.end(); - if (true) @panic("TODO rework with zir-memory-layout in mind"); + const inst_data = sema.code.instructions.items(.data)[inst].pl_node; + const src = inst_data.src(); + const operand_src: LazySrcLoc = .{ .node_offset_switch_operand = inst_data.src_node }; + const extra = sema.code.extraData(zir.Inst.SwitchBr, inst_data.payload_index); - const target_ptr = try 
sema.resolveInst(inst.positionals.target); - const target = if (ref) - try sema.analyzeLoad(parent_block, inst.base.src, target_ptr, inst.positionals.target.src) + const operand_ptr = try sema.resolveInst(extra.data.operand); + const operand = if (is_ref) + try sema.analyzeLoad(block, src, operand_ptr, operand_src) else - target_ptr; - try sema.validateSwitch(parent_block, target, inst); + operand_ptr; - if (try sema.resolveDefinedValue(parent_block, inst.base.src, target)) |target_val| { + return sema.analyzeSwitch(block, operand, extra.end, else_prong, extra.data.cases_len, 0, 0); +} + +fn zirSwitchBrRange( + sema: *Sema, + block: *Scope.Block, + inst: zir.Inst.Index, + is_ref: bool, + else_prong: ElseProng, +) InnerError!zir.Inst.Index { + const tracy = trace(@src()); + defer tracy.end(); + + const inst_data = sema.code.instructions.items(.data)[inst].pl_node; + const src = inst_data.src(); + const operand_src: LazySrcLoc = .{ .node_offset_switch_operand = inst_data.src_node }; + const extra = sema.code.extraData(zir.Inst.SwitchBrRange, inst_data.payload_index); + + const operand_ptr = try sema.resolveInst(extra.data.operand); + const operand = if (is_ref) + try sema.analyzeLoad(block, src, operand_ptr, operand_src) + else + operand_ptr; + + return sema.analyzeSwitch( + block, + operand, + extra.end, + else_prong, + extra.data.scalar_cases_len, + extra.data.multi_cases_len, + extra.data.range_cases_len, + ); +} + +fn analyzeSwitch( + sema: *Sema, + parent_block: *Scope.Block, + operand: *Inst, + extra_end: usize, + else_prong: ElseProng, + scalar_cases_len: usize, + multi_cases_len: usize, + range_cases_len: usize, +) InnerError!zir.Inst.Index { + if (true) @panic("TODO rework for zir-memory-layout branch"); + + try sema.validateSwitch(parent_block, operand, inst); + + if (try sema.resolveDefinedValue(parent_block, inst.base.src, operand)) |target_val| { for (inst.positionals.cases) |case| { const resolved = try sema.resolveInst(case.item); - const casted = try 
sema.coerce(block, target.ty, resolved, resolved_src); + const casted = try sema.coerce(block, operand.ty, resolved, resolved_src); const item = try sema.resolveConstValue(parent_block, case_src, casted); if (target_val.eql(item)) { @@ -2280,7 +2316,7 @@ fn zirSwitchBr( case_block.instructions.items.len = 0; const resolved = try sema.resolveInst(case.item); - const casted = try sema.coerce(block, target.ty, resolved, resolved_src); + const casted = try sema.coerce(block, operand.ty, resolved, resolved_src); const item = try sema.resolveConstValue(parent_block, case_src, casted); _ = try sema.analyzeBody(&case_block, case.body); @@ -2298,29 +2334,29 @@ fn zirSwitchBr( .instructions = try sema.arena.dupe(*Inst, case_block.instructions.items), }; - return mod.addSwitchBr(parent_block, inst.base.src, target, cases, else_body); + return mod.addSwitchBr(parent_block, inst.base.src, operand, cases, else_body); } -fn validateSwitch(sema: *Sema, block: *Scope.Block, target: *Inst, inst: zir.Inst.Index) InnerError!void { +fn validateSwitch(sema: *Sema, block: *Scope.Block, operand: *Inst, inst: zir.Inst.Index) InnerError!void { // validate usage of '_' prongs - if (inst.positionals.special_prong == .underscore and target.ty.zigTypeTag() != .Enum) { + if (inst.positionals.special_prong == .underscore and operand.ty.zigTypeTag() != .Enum) { return sema.mod.fail(&block.base, inst.base.src, "'_' prong only allowed when switching on non-exhaustive enums", .{}); // TODO notes "'_' prong here" inst.positionals.cases[last].src } - // check that target type supports ranges + // check that operand type supports ranges if (inst.positionals.range) |range_inst| { - switch (target.ty.zigTypeTag()) { + switch (operand.ty.zigTypeTag()) { .Int, .ComptimeInt => {}, else => { - return sema.mod.fail(&block.base, target.src, "ranges not allowed when switching on type {}", .{target.ty}); + return sema.mod.fail(&block.base, operand.src, "ranges not allowed when switching on type {}", 
.{operand.ty}); // TODO notes "range used here" range_inst.src }, } } // validate for duplicate items/missing else prong - switch (target.ty.zigTypeTag()) { + switch (operand.ty.zigTypeTag()) { .Enum => return sema.mod.fail(&block.base, inst.base.src, "TODO validateSwitch .Enum", .{}), .ErrorSet => return sema.mod.fail(&block.base, inst.base.src, "TODO validateSwitch .ErrorSet", .{}), .Union => return sema.mod.fail(&block.base, inst.base.src, "TODO validateSwitch .Union", .{}), @@ -2331,9 +2367,9 @@ fn validateSwitch(sema: *Sema, block: *Scope.Block, target: *Inst, inst: zir.Ins for (inst.positionals.items) |item| { const maybe_src = if (item.castTag(.switch_range)) |range| blk: { const start_resolved = try sema.resolveInst(range.positionals.lhs); - const start_casted = try sema.coerce(block, target.ty, start_resolved); + const start_casted = try sema.coerce(block, operand.ty, start_resolved); const end_resolved = try sema.resolveInst(range.positionals.rhs); - const end_casted = try sema.coerce(block, target.ty, end_resolved); + const end_casted = try sema.coerce(block, operand.ty, end_resolved); break :blk try range_set.add( try sema.resolveConstValue(block, range_start_src, start_casted), @@ -2342,7 +2378,7 @@ fn validateSwitch(sema: *Sema, block: *Scope.Block, target: *Inst, inst: zir.Ins ); } else blk: { const resolved = try sema.resolveInst(item); - const casted = try sema.coerce(block, target.ty, resolved); + const casted = try sema.coerce(block, operand.ty, resolved); const value = try sema.resolveConstValue(block, item_src, casted); break :blk try range_set.add(value, value, item.src); }; @@ -2353,12 +2389,12 @@ fn validateSwitch(sema: *Sema, block: *Scope.Block, target: *Inst, inst: zir.Ins } } - if (target.ty.zigTypeTag() == .Int) { + if (operand.ty.zigTypeTag() == .Int) { var arena = std.heap.ArenaAllocator.init(sema.gpa); defer arena.deinit(); - const start = try target.ty.minInt(&arena, mod.getTarget()); - const end = try target.ty.maxInt(&arena, 
mod.getTarget()); + const start = try operand.ty.minInt(&arena, mod.getTarget()); + const end = try operand.ty.maxInt(&arena, mod.getTarget()); if (try range_set.spans(start, end)) { if (inst.positionals.special_prong == .@"else") { return sema.mod.fail(&block.base, inst.base.src, "unreachable else prong, all cases already handled", .{}); @@ -2396,7 +2432,7 @@ fn validateSwitch(sema: *Sema, block: *Scope.Block, target: *Inst, inst: zir.Ins }, .EnumLiteral, .Void, .Fn, .Pointer, .Type => { if (inst.positionals.special_prong != .@"else") { - return sema.mod.fail(&block.base, inst.base.src, "else prong required when switching on type '{}'", .{target.ty}); + return sema.mod.fail(&block.base, inst.base.src, "else prong required when switching on type '{}'", .{operand.ty}); } var seen_values = std.HashMap(Value, usize, Value.hash, Value.eql, std.hash_map.DefaultMaxLoadPercentage).init(sema.gpa); @@ -2404,7 +2440,7 @@ fn validateSwitch(sema: *Sema, block: *Scope.Block, target: *Inst, inst: zir.Ins for (inst.positionals.items) |item| { const resolved = try sema.resolveInst(item); - const casted = try sema.coerce(block, target.ty, resolved); + const casted = try sema.coerce(block, operand.ty, resolved); const val = try sema.resolveConstValue(block, item_src, casted); if (try seen_values.fetchPut(val, item.src)) |prev| { @@ -2429,7 +2465,7 @@ fn validateSwitch(sema: *Sema, block: *Scope.Block, target: *Inst, inst: zir.Ins .ComptimeFloat, .Float, => { - return sema.mod.fail(&block.base, target.src, "invalid switch target type '{}'", .{target.ty}); + return sema.mod.fail(&block.base, operand.src, "invalid switch operand type '{}'", .{operand.ty}); }, } } diff --git a/src/zir.zig b/src/zir.zig index 928548ca83..09ba091e81 100644 --- a/src/zir.zig +++ b/src/zir.zig @@ -585,39 +585,35 @@ pub const Inst = struct { /// An enum literal 8 or fewer bytes. No source location. /// Uses the `small_str` field. enum_literal_small, - // /// A switch expression. 
- // /// lhs is target, SwitchBr[rhs] - // /// All prongs of target handled. - // switch_br, - // /// Same as switch_br, except has a range field. - // switch_br_range, - // /// Same as switch_br, except has an else prong. - // switch_br_else, - // /// Same as switch_br_else, except has a range field. - // switch_br_else_range, - // /// Same as switch_br, except has an underscore prong. - // switch_br_underscore, - // /// Same as switch_br, except has a range field. - // switch_br_underscore_range, - // /// Same as `switch_br` but the target is a pointer to the value being switched on. - // switch_br_ref, - // /// Same as `switch_br_range` but the target is a pointer to the value being switched on. - // switch_br_ref_range, - // /// Same as `switch_br_else` but the target is a pointer to the value being switched on. - // switch_br_ref_else, - // /// Same as `switch_br_else_range` but the target is a pointer to the - // /// value being switched on. - // switch_br_ref_else_range, - // /// Same as `switch_br_underscore` but the target is a pointer to the value - // /// being switched on. - // switch_br_ref_underscore, - // /// Same as `switch_br_underscore_range` but the target is a pointer to - // /// the value being switched on. - // switch_br_ref_underscore_range, - // /// A range in a switch case, `lhs...rhs`. - // /// Only checks that `lhs >= rhs` if they are ints, everything else is - // /// validated by the switch_br instruction. - // switch_range, + /// A switch expression. Uses the `pl_node` union field. + /// AST node is the switch, payload is `SwitchBr`. + /// All prongs of target handled. + switch_br, + /// Same as switch_br, except has a range field. + switch_br_range, + /// Same as switch_br, except has an else prong. + switch_br_else, + /// Same as switch_br_else, except has a range field. + switch_br_else_range, + /// Same as switch_br, except has an underscore prong. + switch_br_underscore, + /// Same as switch_br_underscore, except has a range field. 
+ switch_br_underscore_range, + /// Same as `switch_br` but the target is a pointer to the value being switched on. + switch_br_ref, + /// Same as `switch_br_range` but the target is a pointer to the value being switched on. + switch_br_ref_range, + /// Same as `switch_br_else` but the target is a pointer to the value being switched on. + switch_br_ref_else, + /// Same as `switch_br_else_range` but the target is a pointer to the + /// value being switched on. + switch_br_ref_else_range, + /// Same as `switch_br_underscore` but the target is a pointer to the value + /// being switched on. + switch_br_ref_underscore, + /// Same as `switch_br_underscore_range` but the target is a pointer to + /// the value being switched on. + switch_br_ref_underscore_range, /// Returns whether the instruction is one of the control flow "noreturn" types. /// Function calls do not count. @@ -760,6 +756,18 @@ pub const Inst = struct { .@"unreachable", .repeat, .repeat_inline, + .switch_br, + .switch_br_range, + .switch_br_else, + .switch_br_else_range, + .switch_br_underscore, + .switch_br_underscore_range, + .switch_br_ref, + .switch_br_ref_range, + .switch_br_ref_else, + .switch_br_ref_else_range, + .switch_br_ref_underscore, + .switch_br_ref_underscore_range, => true, }; } @@ -1322,22 +1330,53 @@ pub const Inst = struct { rhs: Ref, }; - /// Stored in extra. Depending on zir tag and len fields, extra fields trail + /// This form is supported when there are no ranges, and exactly 1 item per block. + /// Depending on zir tag and len fields, extra fields trail /// this one in the extra array. - /// 0. range: Ref // If the tag has "_range" in it. - /// 1. else_body: Ref // If the tag has "_else" or "_underscore" in it. - /// 2. items: list of all individual items and ranges. - /// 3. cases: { + /// 0. else_body { // If the tag has "_else" or "_underscore" in it. + /// body_len: u32, + /// body member Index for every body_len + /// } + /// 1. 
cases: { /// item: Ref, /// body_len: u32, - /// body member Ref for every body_len + /// body member Index for every body_len /// } for every cases_len pub const SwitchBr = struct { - /// TODO investigate, why do we need to store this? is it redundant? - items_len: u32, + operand: Ref, cases_len: u32, }; + /// This form is required when there exists a block which has more than one item, + /// or a range. + /// Depending on zir tag and len fields, extra fields trail + /// this one in the extra array. + /// 0. else_body { // If the tag has "_else" or "_underscore" in it. + /// body_len: u32, + /// body member Index for every body_len + /// } + /// 1. scalar_cases: { // for every scalar_cases_len + /// item: Ref, + /// body_len: u32, + /// body member Index for every body_len + /// } + /// 2. multi_cases: { // for every multi_cases_len + /// items_len: u32, + /// item: Ref for every items_len + /// block_index: u32, // index in extra to a `Block` + /// } + /// 3. range_cases: { // for every range_cases_len + /// item_start: Ref, + /// item_end: Ref, + /// block_index: u32, // index in extra to a `Block` + /// } + pub const SwitchBrRange = struct { + operand: Ref, + scalar_cases_len: u32, + multi_cases_len: u32, + range_cases_len: u32, + }; + pub const Field = struct { lhs: Ref, /// Offset into `string_bytes`. 
@@ -1503,6 +1542,22 @@ const Writer = struct { .condbr_inline, => try self.writePlNodeCondBr(stream, inst), + .switch_br, + .switch_br_else, + .switch_br_underscore, + .switch_br_ref, + .switch_br_ref_else, + .switch_br_ref_underscore, + => try self.writePlNodeSwitchBr(stream, inst), + + .switch_br_range, + .switch_br_else_range, + .switch_br_underscore_range, + .switch_br_ref_range, + .switch_br_ref_else_range, + .switch_br_ref_underscore_range, + => try self.writePlNodeSwitchBrRange(stream, inst), + .compile_log, .typeof_peer, => try self.writePlNodeMultiOp(stream, inst), @@ -1708,6 +1763,23 @@ const Writer = struct { try self.writeSrc(stream, inst_data.src()); } + fn writePlNodeSwitchBr(self: *Writer, stream: anytype, inst: Inst.Index) !void { + const inst_data = self.code.instructions.items(.data)[inst].pl_node; + const extra = self.code.extraData(Inst.SwitchBr, inst_data.payload_index); + + try self.writeInstRef(stream, extra.data.operand); + try stream.writeAll(", TODO) "); + try self.writeSrc(stream, inst_data.src()); + } + + fn writePlNodeSwitchBrRange(self: *Writer, stream: anytype, inst: Inst.Index) !void { + const inst_data = self.code.instructions.items(.data)[inst].pl_node; + const extra = self.code.extraData(Inst.SwitchBrRange, inst_data.payload_index); + try self.writeInstRef(stream, extra.data.operand); + try stream.writeAll(", TODO) "); + try self.writeSrc(stream, inst_data.src()); + } + fn writePlNodeMultiOp(self: *Writer, stream: anytype, inst: Inst.Index) !void { const inst_data = self.code.instructions.items(.data)[inst].pl_node; const extra = self.code.extraData(Inst.MultiOp, inst_data.payload_index);