diff --git a/src/Module.zig b/src/Module.zig index e756cc3dfd..77bde605a2 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -2445,8 +2445,8 @@ pub const SrcLoc = struct { const case_nodes = tree.extra_data[extra.start..extra.end]; for (case_nodes) |case_node| { const case = switch (node_tags[case_node]) { - .switch_case_one => tree.switchCaseOne(case_node), - .switch_case => tree.switchCase(case_node), + .switch_case_one, .switch_case_inline_one => tree.switchCaseOne(case_node), + .switch_case, .switch_case_inline => tree.switchCase(case_node), else => unreachable, }; const is_special = (case.ast.values.len == 0) or @@ -2469,8 +2469,8 @@ pub const SrcLoc = struct { const case_nodes = tree.extra_data[extra.start..extra.end]; for (case_nodes) |case_node| { const case = switch (node_tags[case_node]) { - .switch_case_one => tree.switchCaseOne(case_node), - .switch_case => tree.switchCase(case_node), + .switch_case_one, .switch_case_inline_one => tree.switchCaseOne(case_node), + .switch_case, .switch_case_inline => tree.switchCase(case_node), else => unreachable, }; const is_special = (case.ast.values.len == 0) or @@ -2491,8 +2491,8 @@ pub const SrcLoc = struct { const case_node = src_loc.declRelativeToNodeIndex(node_off); const node_tags = tree.nodes.items(.tag); const case = switch (node_tags[case_node]) { - .switch_case_one => tree.switchCaseOne(case_node), - .switch_case => tree.switchCase(case_node), + .switch_case_one, .switch_case_inline_one => tree.switchCaseOne(case_node), + .switch_case, .switch_case_inline => tree.switchCase(case_node), else => unreachable, }; const start_tok = case.payload_token.?; @@ -5937,8 +5937,8 @@ pub const SwitchProngSrc = union(enum) { var scalar_i: u32 = 0; for (case_nodes) |case_node| { const case = switch (node_tags[case_node]) { - .switch_case_one => tree.switchCaseOne(case_node), - .switch_case => tree.switchCase(case_node), + .switch_case_one, .switch_case_inline_one => tree.switchCaseOne(case_node), + .switch_case, .switch_case_inline => tree.switchCase(case_node), else => unreachable, }; if (case.ast.values.len == 0) diff --git a/src/Sema.zig b/src/Sema.zig index cf7a2a3036..d27c0095ad 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -162,6 +162,9 @@ pub const Block = struct { /// type of `err` in `else => |err|` switch_else_err_ty: ?Type = null, + /// Value for switch_capture in an inline case + inline_case_capture: Air.Inst.Ref = .none, + const Param = struct { /// `noreturn` means `anytype`. ty: Type, @@ -9002,6 +9005,30 @@ fn zirSwitchCapture( const operand_ptr_ty = sema.typeOf(operand_ptr); const operand_ty = if (operand_is_ref) operand_ptr_ty.childType() else operand_ptr_ty; + if (block.inline_case_capture != .none) { + const item_val = sema.resolveConstValue(block, .unneeded, block.inline_case_capture, undefined) catch unreachable; + if (operand_ty.zigTypeTag() == .Union) { + const field_index = @intCast(u32, operand_ty.unionTagFieldIndex(item_val, sema.mod).?); + const union_obj = operand_ty.cast(Type.Payload.Union).?.data; + const field_ty = union_obj.fields.values()[field_index].ty; + if (is_ref) { + const ptr_field_ty = try Type.ptr(sema.arena, sema.mod, .{ + .pointee_type = field_ty, + .mutable = operand_ptr_ty.ptrIsMutable(), + .@"volatile" = operand_ptr_ty.isVolatilePtr(), + .@"addrspace" = operand_ptr_ty.ptrAddressSpace(), + }); + return block.addStructFieldPtr(operand_ptr, field_index, ptr_field_ty); + } else { + return block.addStructFieldVal(operand_ptr, field_index, field_ty); + } + } else if (is_ref) { + return sema.addConstantMaybeRef(block, operand_src, operand_ty, item_val, true); + } else { + return block.inline_case_capture; + } + } + const operand = if (operand_is_ref) try sema.analyzeLoad(block, operand_src, operand_ptr, operand_src) else @@ -9234,14 +9261,15 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError } else 0; const special_prong = extra.data.bits.specialProng(); - const special: struct { body: []const Zir.Inst.Index, end: usize } = switch (special_prong) { - .none => .{ .body = &.{}, .end = header_extra_index }, + const special: struct { body: []const Zir.Inst.Index, end: usize, is_inline: bool } = switch (special_prong) { + .none => .{ .body = &.{}, .end = header_extra_index, .is_inline = false }, .under, .@"else" => blk: { const body_len = @truncate(u31, sema.code.extra[header_extra_index]); const extra_body_start = header_extra_index + 1; break :blk .{ .body = sema.code.extra[extra_body_start..][0..body_len], .end = extra_body_start + body_len, + .is_inline = sema.code.extra[header_extra_index] >> 31 != 0, }; }, }; @@ -9901,6 +9929,9 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError if (special_prong == .none) { return sema.fail(block, src, "switch must handle all possibilities", .{}); } + if (special.is_inline) { + return sema.fail(block, src, "TODO special.is_inline", .{}); + } if (err_set and try sema.maybeErrorUnwrap(block, special.body, operand)) { return Air.Inst.Ref.unreachable_value; } @@ -9927,6 +9958,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError const item_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]); extra_index += 1; const body_len = @truncate(u31, sema.code.extra[extra_index]); + const is_inline = sema.code.extra[extra_index] >> 31 != 0; extra_index += 1; const body = sema.code.extra[extra_index..][0..body_len]; extra_index += body_len; @@ -9936,8 +9968,10 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError case_block.instructions.shrinkRetainingCapacity(0); case_block.wip_capture_scope = wip_captures.scope; + case_block.inline_case_capture = .none; const item = try sema.resolveInst(item_ref); + if (is_inline) case_block.inline_case_capture = item; // `item` is already guaranteed to be constant known. const analyze_body = if (union_originally) blk: { @@ -9989,12 +10023,118 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError const ranges_len = sema.code.extra[extra_index]; extra_index += 1; const body_len = @truncate(u31, sema.code.extra[extra_index]); + const is_inline = sema.code.extra[extra_index] >> 31 != 0; extra_index += 1; const items = sema.code.refSlice(extra_index, items_len); extra_index += items_len; case_block.instructions.shrinkRetainingCapacity(0); case_block.wip_capture_scope = child_block.wip_capture_scope; + case_block.inline_case_capture = .none; + + // Generate all possible cases as scalar prongs. + if (is_inline) { + const body_start = extra_index + 2 * ranges_len; + const body = sema.code.extra[body_start..][0..body_len]; + const case_src = src; // TODO better source location + var emit_bb = false; + + var range_i: usize = 0; + while (range_i < ranges_len) : (range_i += 1) { + const first_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]); + extra_index += 1; + const last_ref = @intToEnum(Zir.Inst.Ref, sema.code.extra[extra_index]); + extra_index += 1; + + const item_first_ref = try sema.resolveInst(first_ref); + var item_first = sema.resolveConstValue(block, .unneeded, item_first_ref, undefined) catch unreachable; + const item_last_ref = try sema.resolveInst(last_ref); + const item_last = sema.resolveConstValue(block, .unneeded, item_last_ref, undefined) catch unreachable; + + while (item_first.compare(.lte, item_last, operand_ty, sema.mod)) : ({ + item_first = try sema.intAddScalar(block, case_src, item_first, Value.one); + }) { + cases_len += 1; + + const item_ref = try sema.addConstant(operand_ty, item_first); + case_block.inline_case_capture = item_ref; + + case_block.instructions.shrinkRetainingCapacity(0); + case_block.wip_capture_scope = child_block.wip_capture_scope; + + if (emit_bb) try sema.emitBackwardBranch(block, case_src); + emit_bb = true; + + _ = sema.analyzeBodyInner(&case_block, body) catch |err| switch (err) { + error.ComptimeBreak => { + const zir_datas = sema.code.instructions.items(.data); + const break_data = zir_datas[sema.comptime_break_inst].@"break"; + try sema.addRuntimeBreak(&case_block, .{ + .block_inst = break_data.block_inst, + .operand = break_data.operand, + .inst = sema.comptime_break_inst, + }); + }, + else => |e| return e, + }; + + // try wip_captures.finalize(); + + try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); + cases_extra.appendAssumeCapacity(1); // items_len + cases_extra.appendAssumeCapacity(@intCast(u32, case_block.instructions.items.len)); + cases_extra.appendAssumeCapacity(@enumToInt(item_ref)); + cases_extra.appendSliceAssumeCapacity(case_block.instructions.items); + } + } + + for (items) |item_ref| { + cases_len += 1; + + const item = try sema.resolveInst(item_ref); + case_block.inline_case_capture = item; + + case_block.instructions.shrinkRetainingCapacity(0); + case_block.wip_capture_scope = child_block.wip_capture_scope; + + const analyze_body = if (union_originally) blk: { + const item_val = sema.resolveConstValue(block, .unneeded, item, undefined) catch unreachable; + const field_ty = maybe_union_ty.unionFieldType(item_val, sema.mod); + break :blk field_ty.zigTypeTag() != .NoReturn; + } else true; + + if (emit_bb) try sema.emitBackwardBranch(block, case_src); + emit_bb = true; + + if (analyze_body) { + _ = sema.analyzeBodyInner(&case_block, body) catch |err| switch (err) { + error.ComptimeBreak => { + const zir_datas = sema.code.instructions.items(.data); + const break_data = zir_datas[sema.comptime_break_inst].@"break"; + try sema.addRuntimeBreak(&case_block, .{ + .block_inst = break_data.block_inst, + .operand = break_data.operand, + .inst = sema.comptime_break_inst, + }); + }, + else => |e| return e, + }; + } else { + _ = try case_block.addNoOp(.unreach); + } + + // try wip_captures.finalize(); + + try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); + cases_extra.appendAssumeCapacity(1); // items_len + cases_extra.appendAssumeCapacity(@intCast(u32, case_block.instructions.items.len)); + cases_extra.appendAssumeCapacity(@enumToInt(item)); + cases_extra.appendSliceAssumeCapacity(case_block.instructions.items); + } + + extra_index += body_len; + continue; + } var any_ok: Air.Inst.Ref = .none; @@ -10158,6 +10298,10 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError case_block.instructions.shrinkRetainingCapacity(0); case_block.wip_capture_scope = wip_captures.scope; + case_block.inline_case_capture = .none; + if (special.is_inline) { + return sema.fail(block, src, "TODO special.is_inline", .{}); + } const analyze_body = if (union_originally) for (seen_union_fields) |seen_field, index| { diff --git a/test/behavior.zig b/test/behavior.zig index 648757d56f..78029b6dd1 100644 --- a/test/behavior.zig +++ b/test/behavior.zig @@ -180,6 +180,7 @@ test { _ = @import("behavior/decltest.zig"); _ = @import("behavior/packed_struct_explicit_backing_int.zig"); _ = @import("behavior/empty_union.zig"); + _ = @import("behavior/inline_switch.zig"); } if (builtin.os.tag != .wasi) { diff --git a/test/behavior/inline_switch.zig b/test/behavior/inline_switch.zig new file mode 100644 index 0000000000..d7863f8444 --- /dev/null +++ b/test/behavior/inline_switch.zig @@ -0,0 +1,57 @@ +const std = @import("std"); +const expect = std.testing.expect; +const builtin = @import("builtin"); + +test "inline scalar prongs" { + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO + + var x: usize = 0; + switch (x) { + 10 => |*item| try expect(@TypeOf(item) == *usize), + inline 11 => |*item| { + try expect(@TypeOf(item) == *const usize); + try expect(item.* == 11); + }, + else => {}, + } +} + +test "inline prong ranges" { + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO + + var x: usize = 0; + switch (x) { + inline 0...20, 24 => |item| { + if (item > 25) @compileError("bad"); + }, + else => {}, + } +} + +const E = enum { a, b, c, d }; +test "inline switch enums" { + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO + + var x: E = .a; + switch (x) { + inline .a, .b => |aorb| if (aorb != .a and aorb != .b) @compileError("bad"), + inline .c, .d => |cord| if (cord != .c and cord != .d) @compileError("bad"), + } +} + +const U = union(E) { a: void, b: u2, c: u3, d: u4 }; +test "inline switch unions" { + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + + var x: U = .a; + switch (x) { + inline .a, .b => |aorb| { + try expect(@TypeOf(aorb) == void or @TypeOf(aorb) == u2); + }, + inline .c, .d => |cord| { + try expect(@TypeOf(cord) == u3 or @TypeOf(cord) == u4); + }, + } +}