zig

fork of https://codeberg.org/ziglang/zig
Log | Files | Refs | README | LICENSE

commit cac814cf58ca65ffd8081dca6f2f5b26d822ef5d (tree)
parent 9ba9f457be13c0c32981e94cf75ba7ea6619d92f
Author: mlugg <mlugg@mlugg.co.uk>
Date:   Wed,  5 Feb 2025 20:13:56 +0000

Sema: fix comparison between error set and comptime-known error union

Resolves: #20613

Diffstat:
Msrc/Sema.zig | 9++++++++-
Mtest/behavior/error.zig | 11+++++++++++
2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/src/Sema.zig b/src/Sema.zig @@ -9145,6 +9145,7 @@ fn zirErrUnionCode(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileErro return sema.analyzeErrUnionCode(block, src, operand); } +/// If `operand` is comptime-known, asserts that it is an error value rather than a payload value. fn analyzeErrUnionCode(sema: *Sema, block: *Block, src: LazySrcLoc, operand: Air.Inst.Ref) CompileError!Air.Inst.Ref { const pt = sema.pt; const zcu = pt.zcu; @@ -17599,10 +17600,16 @@ fn analyzeCmp( return sema.cmpNumeric(block, src, lhs, rhs, op, lhs_src, rhs_src); } if (is_equality_cmp and lhs_ty.zigTypeTag(zcu) == .error_union and rhs_ty.zigTypeTag(zcu) == .error_set) { + if (try sema.resolveDefinedValue(block, lhs_src, lhs)) |lhs_val| { + if (lhs_val.errorUnionIsPayload(zcu)) return .bool_false; + } const casted_lhs = try sema.analyzeErrUnionCode(block, lhs_src, lhs); return sema.cmpSelf(block, src, casted_lhs, rhs, op, lhs_src, rhs_src); } if (is_equality_cmp and lhs_ty.zigTypeTag(zcu) == .error_set and rhs_ty.zigTypeTag(zcu) == .error_union) { + if (try sema.resolveDefinedValue(block, rhs_src, rhs)) |rhs_val| { + if (rhs_val.errorUnionIsPayload(zcu)) return .bool_false; + } const casted_rhs = try sema.analyzeErrUnionCode(block, rhs_src, rhs); return sema.cmpSelf(block, src, lhs, casted_rhs, op, lhs_src, rhs_src); } @@ -23254,7 +23261,7 @@ fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData zcu.backendSupportsFeature(.error_set_has_value)) { if (dest_tag == .error_union) { - const err_code = try sema.analyzeErrUnionCode(block, operand_src, operand); + const err_code = try block.addTyOp(.unwrap_errunion_err, operand_ty, operand); const err_int = try block.addBitCast(err_int_ty, err_code); const zero_err = try pt.intRef(try pt.errorIntType(), 0); diff --git a/test/behavior/error.zig b/test/behavior/error.zig @@ -1100,3 +1100,14 @@ test "return error union with i65" { fn add(x: i65, y: i65) anyerror!i65 { return x + y; } + +test "compare error union to error set" { + const S = struct { + fn doTheTest(val: error{Foo}!i32) !void { + if (error.Foo == val) return error.Unexpected; + if (val == error.Foo) return error.Unexpected; + } + }; + try S.doTheTest(0); + try comptime S.doTheTest(0); +}