diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index 99d0aadc99..d0f6d62ad7 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -1931,11 +1931,93 @@ pub const FuncGen = struct { const lhs = try self.resolveInst(bin_op.lhs); const rhs = try self.resolveInst(bin_op.rhs); const operand_ty = self.air.typeOf(bin_op.lhs); - var buffer: Type.Payload.Bits = undefined; + + return self.cmp(lhs, rhs, operand_ty, op); + } + + fn cmp( + self: *FuncGen, + lhs: *const llvm.Value, + rhs: *const llvm.Value, + operand_ty: Type, + op: math.CompareOperator, + ) *const llvm.Value { + var int_buffer: Type.Payload.Bits = undefined; + var opt_buffer: Type.Payload.ElemType = undefined; const int_ty = switch (operand_ty.zigTypeTag()) { - .Enum => operand_ty.intTagType(&buffer), - .Int, .Bool, .Pointer, .Optional, .ErrorSet => operand_ty, + .Enum => operand_ty.intTagType(&int_buffer), + .Int, .Bool, .Pointer, .ErrorSet => operand_ty, + .Optional => blk: { + const payload_ty = operand_ty.optionalChild(&opt_buffer); + if (!payload_ty.hasCodeGenBits() or operand_ty.isPtrLikeOptional()) { + break :blk operand_ty; + } + // We need to emit instructions to check for equality/inequality + // of optionals that are not pointers. + const is_by_ref = isByRef(operand_ty); + const lhs_non_null = self.optIsNonNull(lhs, is_by_ref); + const rhs_non_null = self.optIsNonNull(rhs, is_by_ref); + const llvm_i2 = self.context.intType(2); + const lhs_non_null_i2 = self.builder.buildZExt(lhs_non_null, llvm_i2, ""); + const rhs_non_null_i2 = self.builder.buildZExt(rhs_non_null, llvm_i2, ""); + const lhs_shifted = self.builder.buildShl(lhs_non_null_i2, llvm_i2.constInt(1, .False), ""); + const lhs_rhs_ored = self.builder.buildOr(lhs_shifted, rhs_non_null_i2, ""); + const both_null_block = self.context.appendBasicBlock(self.llvm_func, "BothNull"); + const mixed_block = self.context.appendBasicBlock(self.llvm_func, "Mixed"); + const both_pl_block = self.context.appendBasicBlock(self.llvm_func, "BothNonNull"); + const end_block = self.context.appendBasicBlock(self.llvm_func, "End"); + const llvm_switch = self.builder.buildSwitch(lhs_rhs_ored, mixed_block, 2); + const llvm_i2_00 = llvm_i2.constInt(0b00, .False); + const llvm_i2_11 = llvm_i2.constInt(0b11, .False); + llvm_switch.addCase(llvm_i2_00, both_null_block); + llvm_switch.addCase(llvm_i2_11, both_pl_block); + + self.builder.positionBuilderAtEnd(both_null_block); + _ = self.builder.buildBr(end_block); + + self.builder.positionBuilderAtEnd(mixed_block); + _ = self.builder.buildBr(end_block); + + self.builder.positionBuilderAtEnd(both_pl_block); + const lhs_payload = self.optPayloadHandle(lhs, is_by_ref); + const rhs_payload = self.optPayloadHandle(rhs, is_by_ref); + const payload_cmp = self.cmp(lhs_payload, rhs_payload, payload_ty, op); + _ = self.builder.buildBr(end_block); + const both_pl_block_end = self.builder.getInsertBlock(); + + self.builder.positionBuilderAtEnd(end_block); + const incoming_blocks: [3]*const llvm.BasicBlock = .{ + both_null_block, + mixed_block, + both_pl_block_end, + }; + const llvm_i1 = self.context.intType(1); + const llvm_i1_0 = llvm_i1.constInt(0, .False); + const llvm_i1_1 = llvm_i1.constInt(1, .False); + const incoming_values: [3]*const llvm.Value = .{ + switch (op) { + .eq => llvm_i1_1, + .neq => llvm_i1_0, + else => unreachable, + }, + switch (op) { + .eq => llvm_i1_0, + .neq => llvm_i1_1, + else => unreachable, + }, + payload_cmp, + }; + + const phi_node = self.builder.buildPhi(llvm_i1, ""); + comptime assert(incoming_values.len == incoming_blocks.len); + phi_node.addIncoming( + &incoming_values, + &incoming_blocks, + incoming_values.len, + ); + return phi_node; + }, .Float => { const operation: llvm.RealPredicate = switch (op) { .eq => .OEQ, @@ -2493,24 +2575,8 @@ pub const FuncGen = struct { } } - if (operand_is_ptr or isByRef(optional_ty)) { - const index_type = self.context.intType(32); - - const indices: [2]*const llvm.Value = .{ - index_type.constNull(), - index_type.constInt(1, .False), - }; - - const field_ptr = self.builder.buildInBoundsGEP(operand, &indices, indices.len, ""); - const non_null_bit = self.builder.buildLoad(field_ptr, ""); - if (invert) { - return self.builder.buildNot(non_null_bit, ""); - } else { - return non_null_bit; - } - } - - const non_null_bit = self.builder.buildExtractValue(operand, 1, ""); + const is_by_ref = operand_is_ptr or isByRef(optional_ty); + const non_null_bit = self.optIsNonNull(operand, is_by_ref); if (invert) { return self.builder.buildNot(non_null_bit, ""); } else { @@ -2622,17 +2688,7 @@ pub const FuncGen = struct { return operand; } - if (isByRef(payload_ty)) { - // We have a pointer and we need to return a pointer to the first field. - const index_type = self.context.intType(32); - const indices: [2]*const llvm.Value = .{ - index_type.constNull(), // dereference the pointer - index_type.constNull(), // first field is the payload - }; - return self.builder.buildInBoundsGEP(operand, &indices, indices.len, ""); - } - - return self.builder.buildExtractValue(operand, 0, ""); + return self.optPayloadHandle(operand, isByRef(payload_ty)); } fn airErrUnionPayload( @@ -3748,6 +3804,38 @@ pub const FuncGen = struct { } } + /// Assumes the optional is not pointer-like and payload has bits. + fn optIsNonNull(self: *FuncGen, opt_handle: *const llvm.Value, is_by_ref: bool) *const llvm.Value { + if (is_by_ref) { + const index_type = self.context.intType(32); + + const indices: [2]*const llvm.Value = .{ + index_type.constNull(), + index_type.constInt(1, .False), + }; + + const field_ptr = self.builder.buildInBoundsGEP(opt_handle, &indices, indices.len, ""); + return self.builder.buildLoad(field_ptr, ""); + } + + return self.builder.buildExtractValue(opt_handle, 1, ""); + } + + /// Assumes the optional is not pointer-like and payload has bits. + fn optPayloadHandle(self: *FuncGen, opt_handle: *const llvm.Value, is_by_ref: bool) *const llvm.Value { + if (is_by_ref) { + // We have a pointer and we need to return a pointer to the first field. + const index_type = self.context.intType(32); + const indices: [2]*const llvm.Value = .{ + index_type.constNull(), // dereference the pointer + index_type.constNull(), // first field is the payload + }; + return self.builder.buildInBoundsGEP(opt_handle, &indices, indices.len, ""); + } + + return self.builder.buildExtractValue(opt_handle, 0, ""); + } + fn callFloor(self: *FuncGen, arg: *const llvm.Value, ty: Type) !*const llvm.Value { return self.callFloatUnary(arg, ty, "floor"); } diff --git a/src/codegen/llvm/bindings.zig b/src/codegen/llvm/bindings.zig index 43aca87532..4a837df9cd 100644 --- a/src/codegen/llvm/bindings.zig +++ b/src/codegen/llvm/bindings.zig @@ -98,7 +98,12 @@ pub const Value = opaque { extern fn LLVMAppendExistingBasicBlock(Fn: *const Value, BB: *const BasicBlock) void; pub const addIncoming = LLVMAddIncoming; - extern fn LLVMAddIncoming(PhiNode: *const Value, IncomingValues: [*]*const Value, IncomingBlocks: [*]*const BasicBlock, Count: c_uint) void; + extern fn LLVMAddIncoming( + PhiNode: *const Value, + IncomingValues: [*]const *const Value, + IncomingBlocks: [*]const *const BasicBlock, + Count: c_uint, + ) void; pub const getNextInstruction = LLVMGetNextInstruction; extern fn LLVMGetNextInstruction(Inst: *const Value) ?*const Value; diff --git a/src/type.zig b/src/type.zig index e02ec051cf..728ba8ef5f 100644 --- a/src/type.zig +++ b/src/type.zig @@ -175,7 +175,11 @@ pub const Type = extern union { => false, .Pointer => is_equality_cmp or ty.isCPtr(), - .Optional => is_equality_cmp and ty.isPtrLikeOptional(), + .Optional => { + if (!is_equality_cmp) return false; + var buf: Payload.ElemType = undefined; + return ty.optionalChild(&buf).isSelfComparable(is_equality_cmp); + }, }; } diff --git a/test/behavior/optional.zig b/test/behavior/optional.zig index 821e31eae2..7bd88ee4a0 100644 --- a/test/behavior/optional.zig +++ b/test/behavior/optional.zig @@ -103,3 +103,37 @@ test "nested optional field in struct" { }; try expect(s.x.?.y == 127); } + +test "equality compare optional with non-optional" { + try test_cmp_optional_non_optional(); + comptime try test_cmp_optional_non_optional(); +} + +fn test_cmp_optional_non_optional() !void { + var ten: i32 = 10; + var opt_ten: ?i32 = 10; + var five: i32 = 5; + var int_n: ?i32 = null; + + try expect(int_n != ten); + try expect(opt_ten == ten); + try expect(opt_ten != five); + + // test evaluation is always lexical + // ensure that the optional isn't always computed before the non-optional + var mutable_state: i32 = 0; + _ = blk1: { + mutable_state += 1; + break :blk1 @as(?f64, 10.0); + } != blk2: { + try expect(mutable_state == 1); + break :blk2 @as(f64, 5.0); + }; + _ = blk1: { + mutable_state += 1; + break :blk1 @as(f64, 10.0); + } != blk2: { + try expect(mutable_state == 2); + break :blk2 @as(?f64, 5.0); + }; +} diff --git a/test/behavior/optional_stage1.zig b/test/behavior/optional_stage1.zig index 14ab731a81..e46f9a847a 100644 --- a/test/behavior/optional_stage1.zig +++ b/test/behavior/optional_stage1.zig @@ -3,40 +3,6 @@ const testing = std.testing; const expect = testing.expect; const expectEqual = testing.expectEqual; -test "equality compare optional with non-optional" { - try test_cmp_optional_non_optional(); - comptime try test_cmp_optional_non_optional(); -} - -fn test_cmp_optional_non_optional() !void { - var ten: i32 = 10; - var opt_ten: ?i32 = 10; - var five: i32 = 5; - var int_n: ?i32 = null; - - try expect(int_n != ten); - try expect(opt_ten == ten); - try expect(opt_ten != five); - - // test evaluation is always lexical - // ensure that the optional isn't always computed before the non-optional - var mutable_state: i32 = 0; - _ = blk1: { - mutable_state += 1; - break :blk1 @as(?f64, 10.0); - } != blk2: { - try expect(mutable_state == 1); - break :blk2 @as(f64, 5.0); - }; - _ = blk1: { - mutable_state += 1; - break :blk1 @as(f64, 10.0); - } != blk2: { - try expect(mutable_state == 2); - break :blk2 @as(?f64, 5.0); - }; -} - test "unwrap function call with optional pointer return value" { const S = struct { fn entry() !void {