diff --git a/src/Air.zig b/src/Air.zig index 40070dccfb..b4552f9d7b 100644 --- a/src/Air.zig +++ b/src/Air.zig @@ -290,6 +290,9 @@ pub const Inst = struct { /// Result type is always void. /// Uses the `bin_op` field. LHS is union pointer, RHS is new tag value. set_union_tag, + /// Given a tagged union value, get its tag value. + /// Uses the `ty_op` field. + get_union_tag, /// Given a slice value, return the length. /// Result type is always usize. /// Uses the `ty_op` field. @@ -630,6 +633,7 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index) Type { .array_to_slice, .float_to_int, .int_to_float, + .get_union_tag, => return air.getRefType(datas[inst].ty_op.ty), .loop, diff --git a/src/Liveness.zig b/src/Liveness.zig index 9a7126d135..a9ff586aeb 100644 --- a/src/Liveness.zig +++ b/src/Liveness.zig @@ -297,6 +297,7 @@ fn analyzeInst( .array_to_slice, .float_to_int, .int_to_float, + .get_union_tag, => { const o = inst_datas[inst].ty_op; return trackOperands(a, new_set, inst, main_tomb, .{ o.operand, .none, .none }); diff --git a/src/Sema.zig b/src/Sema.zig index f076389797..b669cdb979 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -1349,7 +1349,13 @@ fn zirUnionDecl( errdefer new_decl_arena.deinit(); const union_obj = try new_decl_arena.allocator.create(Module.Union); - const union_ty = try Type.Tag.@"union".create(&new_decl_arena.allocator, union_obj); + const type_tag: Type.Tag = if (small.has_tag_type or small.auto_enum_tag) .union_tagged else .@"union"; + const union_payload = try new_decl_arena.allocator.create(Type.Payload.Union); + union_payload.* = .{ + .base = .{ .tag = type_tag }, + .data = union_obj, + }; + const union_ty = Type.initPayload(&union_payload.base); const union_val = try Value.Tag.ty.create(&new_decl_arena.allocator, union_ty); const type_name = try sema.createTypeName(block, small.name_strategy); const new_decl = try sema.mod.createAnonymousDeclNamed(&block.base, .{ @@ -6477,10 +6483,11 @@ fn zirCmpEq( const non_null_type = if (lhs_ty_tag == .Null) rhs_ty else lhs_ty; return mod.fail(&block.base, src, "comparison of '{}' with null", .{non_null_type}); } - if (((lhs_ty_tag == .EnumLiteral and rhs_ty_tag == .Union) or - (rhs_ty_tag == .EnumLiteral and lhs_ty_tag == .Union))) - { - return mod.fail(&block.base, src, "TODO implement equality comparison between a union's tag value and an enum literal", .{}); + if (lhs_ty_tag == .EnumLiteral and rhs_ty_tag == .Union) { + return sema.analyzeCmpUnionTag(block, rhs, rhs_src, lhs, lhs_src, op); + } + if (rhs_ty_tag == .EnumLiteral and lhs_ty_tag == .Union) { + return sema.analyzeCmpUnionTag(block, lhs, lhs_src, rhs, rhs_src, op); } if (lhs_ty_tag == .ErrorSet and rhs_ty_tag == .ErrorSet) { const runtime_src: LazySrcLoc = src: { @@ -6521,6 +6528,28 @@ fn zirCmpEq( return sema.analyzeCmp(block, src, lhs, rhs, op, lhs_src, rhs_src, true); } +fn analyzeCmpUnionTag( + sema: *Sema, + block: *Scope.Block, + un: Air.Inst.Ref, + un_src: LazySrcLoc, + tag: Air.Inst.Ref, + tag_src: LazySrcLoc, + op: std.math.CompareOperator, +) CompileError!Air.Inst.Ref { + const union_ty = sema.typeOf(un); + const union_tag_ty = union_ty.unionTagType() orelse { + // TODO note at declaration site that says "union foo is not tagged" + return sema.mod.fail(&block.base, un_src, "comparison of union and enum literal is only valid for tagged union types", .{}); + }; + // Coerce both the union and the tag to the union's tag type, and then execute the + // enum comparison codepath. + const coerced_tag = try sema.coerce(block, union_tag_ty, tag, tag_src); + const coerced_union = try sema.coerce(block, union_tag_ty, un, un_src); + + return sema.cmpSelf(block, coerced_union, coerced_tag, op, un_src, tag_src); +} + /// Only called for non-equality operators. See also `zirCmpEq`. fn zirCmp( sema: *Sema, @@ -6567,10 +6596,21 @@ fn analyzeCmp( @tagName(op), resolved_type, }); } - const casted_lhs = try sema.coerce(block, resolved_type, lhs, lhs_src); const casted_rhs = try sema.coerce(block, resolved_type, rhs, rhs_src); + return sema.cmpSelf(block, casted_lhs, casted_rhs, op, lhs_src, rhs_src); +} +fn cmpSelf( + sema: *Sema, + block: *Scope.Block, + casted_lhs: Air.Inst.Ref, + casted_rhs: Air.Inst.Ref, + op: std.math.CompareOperator, + lhs_src: LazySrcLoc, + rhs_src: LazySrcLoc, +) CompileError!Air.Inst.Ref { + const resolved_type = sema.typeOf(casted_lhs); const runtime_src: LazySrcLoc = src: { if (try sema.resolveMaybeUndefVal(block, lhs_src, casted_lhs)) |lhs_val| { if (lhs_val.isUndef()) return sema.addConstUndef(resolved_type); @@ -9919,9 +9959,9 @@ fn coerce( } } }, - .Enum => { - // enum literal to enum - if (inst_ty.zigTypeTag() == .EnumLiteral) { + .Enum => switch (inst_ty.zigTypeTag()) { + .EnumLiteral => { + // enum literal to enum const val = try sema.resolveConstValue(block, inst_src, inst); const bytes = val.castTag(.enum_literal).?.data; const resolved_dest_type = try sema.resolveTypeFields(block, inst_src, dest_type); @@ -9948,7 +9988,15 @@ fn coerce( resolved_dest_type, try Value.Tag.enum_field_index.create(arena, @intCast(u32, field_index)), ); - } + }, + .Union => blk: { + // union to its own tag type + const union_tag_ty = inst_ty.unionTagType() orelse break :blk; + if (union_tag_ty.eql(dest_type)) { + return sema.unionToTag(block, dest_type, inst, inst_src); + } + }, + else => {}, }, .ErrorUnion => { // T to E!T or E to E!T @@ -10802,6 +10850,20 @@ fn wrapErrorUnion( } } +fn unionToTag( + sema: *Sema, + block: *Scope.Block, + dest_type: Type, + un: Air.Inst.Ref, + un_src: LazySrcLoc, +) !Air.Inst.Ref { + if (try sema.resolveMaybeUndefVal(block, un_src, un)) |un_val| { + return sema.addConstant(dest_type, un_val.unionTag()); + } + try sema.requireRuntimeBlock(block, un_src); + return block.addTyOp(.get_union_tag, dest_type, un); +} + fn resolvePeerTypes( sema: *Sema, block: *Scope.Block, diff --git a/src/codegen.zig b/src/codegen.zig index 6a605edca9..4eda3f2594 100644 --- a/src/codegen.zig +++ b/src/codegen.zig @@ -890,6 +890,7 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type { .memcpy => try self.airMemcpy(inst), .memset => try self.airMemset(inst), .set_union_tag => try self.airSetUnionTag(inst), + .get_union_tag => try self.airGetUnionTag(inst), .atomic_store_unordered => try self.airAtomicStore(inst, .Unordered), .atomic_store_monotonic => try self.airAtomicStore(inst, .Monotonic), @@ -1552,6 +1553,14 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type { return self.finishAir(inst, result, .{ bin_op.lhs, bin_op.rhs, .none }); } + fn airGetUnionTag(self: *Self, inst: Air.Inst.Index) !void { + const ty_op = self.air.instructions.items(.data)[inst].ty_op; + const result: MCValue = if (self.liveness.isUnused(inst)) .dead else switch (arch) { + else => return self.fail("TODO implement airGetUnionTag for {}", .{self.target.cpu.arch}), + }; + return self.finishAir(inst, result, .{ ty_op.operand, .none, .none }); + } + fn reuseOperand(self: *Self, inst: Air.Inst.Index, operand: Air.Inst.Ref, op_index: Liveness.OperandInt, mcv: MCValue) bool { if (!self.liveness.operandDies(inst, op_index)) return false; diff --git a/src/codegen/c.zig b/src/codegen/c.zig index fc0c86b8f1..a6534b1eba 100644 --- a/src/codegen/c.zig +++ b/src/codegen/c.zig @@ -956,6 +956,7 @@ fn genBody(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, OutO .memset => try airMemset(f, inst), .memcpy => try airMemcpy(f, inst), .set_union_tag => try airSetUnionTag(f, inst), + .get_union_tag => try airGetUnionTag(f, inst), .int_to_float, .float_to_int, @@ -2096,6 +2097,22 @@ fn airSetUnionTag(f: *Function, inst: Air.Inst.Index) !CValue { return CValue.none; } +fn airGetUnionTag(f: *Function, inst: Air.Inst.Index) !CValue { + if (f.liveness.isUnused(inst)) + return CValue.none; + + const inst_ty = f.air.typeOfIndex(inst); + const local = try f.allocLocal(inst_ty, .Const); + const ty_op = f.air.instructions.items(.data)[inst].ty_op; + const writer = f.object.writer(); + const operand = try f.resolveInst(ty_op.operand); + + try writer.writeAll("get_union_tag("); + try f.writeCValue(writer, operand); + try writer.writeAll(");\n"); + return local; +} + fn toMemoryOrder(order: std.builtin.AtomicOrder) [:0]const u8 { return switch (order) { .Unordered => "memory_order_relaxed", diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index ab164b5d91..4a0d218ead 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -1304,6 +1304,7 @@ pub const FuncGen = struct { .memset => try self.airMemset(inst), .memcpy => try self.airMemcpy(inst), .set_union_tag => try self.airSetUnionTag(inst), + .get_union_tag => try self.airGetUnionTag(inst), .atomic_store_unordered => try self.airAtomicStore(inst, .Unordered), .atomic_store_monotonic => try self.airAtomicStore(inst, .Monotonic), @@ -2557,6 +2558,18 @@ pub const FuncGen = struct { return null; } + fn airGetUnionTag(self: *FuncGen, inst: Air.Inst.Index) !?*const llvm.Value { + if (self.liveness.isUnused(inst)) + return null; + + const ty_op = self.air.instructions.items(.data)[inst].ty_op; + const un_ty = self.air.typeOf(ty_op.operand); + const un = try self.resolveInst(ty_op.operand); + + _ = un_ty; // TODO handle when onlyTagHasCodegenBits() == true and other union forms + return self.builder.buildExtractValue(un, 1, ""); + } + fn fieldPtr( self: *FuncGen, inst: Air.Inst.Index, diff --git a/src/print_air.zig b/src/print_air.zig index e735d03bd3..2a7538f81a 100644 --- a/src/print_air.zig +++ b/src/print_air.zig @@ -179,6 +179,7 @@ const Writer = struct { .array_to_slice, .int_to_float, .float_to_int, + .get_union_tag, => try w.writeTyOp(s, inst), .block, diff --git a/src/type.zig b/src/type.zig index bb798959f4..781fe74d45 100644 --- a/src/type.zig +++ b/src/type.zig @@ -2487,6 +2487,12 @@ pub const Type = extern union { }; } + pub fn unionFieldType(ty: Type, enum_tag: Value) Type { + const union_obj = ty.cast(Payload.Union).?.data; + const index = union_obj.tag_ty.enumTagFieldIndex(enum_tag).?; + return union_obj.fields.values()[index].ty; + } + /// Asserts that the type is an error union. pub fn errorUnionPayload(self: Type) Type { return switch (self.tag()) { @@ -3801,6 +3807,8 @@ pub const Type = extern union { }; }; + pub const @"bool" = initTag(.bool); + pub fn ptr(arena: *Allocator, d: Payload.Pointer.Data) !Type { assert(d.host_size == 0 or d.bit_offset < d.host_size * 8); diff --git a/src/value.zig b/src/value.zig index cb5d211b1e..69f8945e01 100644 --- a/src/value.zig +++ b/src/value.zig @@ -1275,7 +1275,12 @@ pub const Value = extern union { } }, .Union => { - @panic("TODO implement hashing union values"); + const union_obj = val.castTag(.@"union").?.data; + if (ty.unionTagType()) |tag_ty| { + union_obj.tag.hash(tag_ty, hasher); + } + const active_field_ty = ty.unionFieldType(union_obj.tag); + union_obj.val.hash(active_field_ty, hasher); }, .Fn => { @panic("TODO implement hashing function values"); @@ -1431,6 +1436,14 @@ pub const Value = extern union { } } + pub fn unionTag(val: Value) Value { + switch (val.tag()) { + .undef => return val, + .@"union" => return val.castTag(.@"union").?.data.tag, + else => unreachable, + } + } + /// Returns a pointer to the element value at the index. pub fn elemPtr(self: Value, allocator: *Allocator, index: usize) !Value { if (self.castTag(.elem_ptr)) |elem_ptr| { diff --git a/test/behavior/union.zig b/test/behavior/union.zig index 6b8705e044..afefa7cf85 100644 --- a/test/behavior/union.zig +++ b/test/behavior/union.zig @@ -14,3 +14,21 @@ test "basic unions" { foo = Foo{ .float = 12.34 }; try expect(foo.float == 12.34); } + +test "init union with runtime value" { + var foo: Foo = undefined; + + setFloat(&foo, 12.34); + try expect(foo.float == 12.34); + + setInt(&foo, 42); + try expect(foo.int == 42); +} + +fn setFloat(foo: *Foo, x: f64) void { + foo.* = Foo{ .float = x }; +} + +fn setInt(foo: *Foo, x: i32) void { + foo.* = Foo{ .int = x }; +} diff --git a/test/behavior/union_stage1.zig b/test/behavior/union_stage1.zig index 5741858d51..725d7bd028 100644 --- a/test/behavior/union_stage1.zig +++ b/test/behavior/union_stage1.zig @@ -49,24 +49,6 @@ test "comptime union field access" { } } -test "init union with runtime value" { - var foo: Foo = undefined; - - setFloat(&foo, 12.34); - try expect(foo.float == 12.34); - - setInt(&foo, 42); - try expect(foo.int == 42); -} - -fn setFloat(foo: *Foo, x: f64) void { - foo.* = Foo{ .float = x }; -} - -fn setInt(foo: *Foo, x: i32) void { - foo.* = Foo{ .int = x }; -} - const FooExtern = extern union { float: f64, int: i32, @@ -185,12 +167,13 @@ test "union field access gives the enum values" { } test "cast union to tag type of union" { - try testCastUnionToTag(TheUnion{ .B = 1234 }); - comptime try testCastUnionToTag(TheUnion{ .B = 1234 }); + try testCastUnionToTag(); + comptime try testCastUnionToTag(); } -fn testCastUnionToTag(x: TheUnion) !void { - try expect(@as(TheTag, x) == TheTag.B); +fn testCastUnionToTag() !void { + var u = TheUnion{ .B = 1234 }; + try expect(@as(TheTag, u) == TheTag.B); } test "cast tag type of union to union" {