From b9d1d45dfd0f704bc762732c23aa2844f1d14e8d Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Fri, 9 Aug 2019 21:49:40 -0400 Subject: [PATCH] fix combining try with errdefer cancel --- src/all_types.hpp | 13 +++++++++ src/codegen.cpp | 45 +++++++++++++++++++++++------ src/ir.cpp | 33 +++++++++++++++++++++ src/ir_print.cpp | 9 ++++++ test/stage1/behavior/coroutines.zig | 29 +++++++++++++++++++ 5 files changed, 120 insertions(+), 9 deletions(-) diff --git a/src/all_types.hpp b/src/all_types.hpp index 0b03388502..45182f3db3 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -2366,6 +2366,7 @@ enum IrInstructionId { IrInstructionIdAwaitGen, IrInstructionIdCoroResume, IrInstructionIdTestCancelRequested, + IrInstructionIdSpill, }; struct IrInstruction { @@ -3643,6 +3644,18 @@ struct IrInstructionTestCancelRequested { IrInstruction base; }; +enum SpillId { + SpillIdInvalid, + SpillIdRetErrCode, +}; + +struct IrInstructionSpill { + IrInstruction base; + + SpillId spill_id; + IrInstruction *operand; +}; + enum ResultLocId { ResultLocIdInvalid, ResultLocIdNone, diff --git a/src/codegen.cpp b/src/codegen.cpp index 5a8fd3e9ca..976ee4181e 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -5113,6 +5113,18 @@ static LLVMValueRef ir_render_test_err(CodeGen *g, IrExecutable *executable, IrI return LLVMBuildICmp(g->builder, LLVMIntNE, err_val, zero, ""); } +static LLVMValueRef gen_unwrap_err_code(CodeGen *g, LLVMValueRef err_union_ptr, ZigType *ptr_type) { + ZigType *err_union_type = ptr_type->data.pointer.child_type; + ZigType *payload_type = err_union_type->data.error_union.payload_type; + if (!type_has_bits(payload_type)) { + return err_union_ptr; + } else { + // TODO assign undef to the payload + LLVMValueRef err_union_handle = get_handle_value(g, err_union_ptr, err_union_type, ptr_type); + return LLVMBuildStructGEP(g->builder, err_union_handle, err_union_err_index, ""); + } +} + static LLVMValueRef ir_render_unwrap_err_code(CodeGen *g, IrExecutable *executable, IrInstructionUnwrapErrCode *instruction) { @@ -5121,16 +5133,8 @@ static LLVMValueRef ir_render_unwrap_err_code(CodeGen *g, IrExecutable *executab ZigType *ptr_type = instruction->err_union_ptr->value.type; assert(ptr_type->id == ZigTypeIdPointer); - ZigType *err_union_type = ptr_type->data.pointer.child_type; - ZigType *payload_type = err_union_type->data.error_union.payload_type; LLVMValueRef err_union_ptr = ir_llvm_value(g, instruction->err_union_ptr); - if (!type_has_bits(payload_type)) { - return err_union_ptr; - } else { - // TODO assign undef to the payload - LLVMValueRef err_union_handle = get_handle_value(g, err_union_ptr, err_union_type, ptr_type); - return LLVMBuildStructGEP(g->builder, err_union_handle, err_union_err_index, ""); - } + return gen_unwrap_err_code(g, err_union_ptr, ptr_type); } static LLVMValueRef ir_render_unwrap_err_payload(CodeGen *g, IrExecutable *executable, @@ -5611,6 +5615,27 @@ static LLVMValueRef ir_render_test_cancel_requested(CodeGen *g, IrExecutable *ex } } +static LLVMValueRef ir_render_spill(CodeGen *g, IrExecutable *executable, IrInstructionSpill *instruction) { + if (!fn_is_async(g->cur_fn)) + return ir_llvm_value(g, instruction->operand); + + switch (instruction->spill_id) { + case SpillIdInvalid: + zig_unreachable(); + case SpillIdRetErrCode: { + LLVMValueRef ret_ptr = LLVMBuildLoad(g->builder, g->cur_ret_ptr, ""); + ZigType *ret_type = g->cur_fn->type_entry->data.fn.fn_type_id.return_type; + if (ret_type->id == ZigTypeIdErrorUnion) { + return gen_unwrap_err_code(g, ret_ptr, get_pointer_to_type(g, ret_type, true)); + } else { + zig_unreachable(); + } + } + + } + zig_unreachable(); +} + static void set_debug_location(CodeGen *g, IrInstruction *instruction) { AstNode *source_node = instruction->source_node; Scope *scope = instruction->scope; @@ -5866,6 +5891,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, return ir_render_await(g, executable, (IrInstructionAwaitGen *)instruction); case IrInstructionIdTestCancelRequested: return ir_render_test_cancel_requested(g, executable, (IrInstructionTestCancelRequested *)instruction); + case IrInstructionIdSpill: + return ir_render_spill(g, executable, (IrInstructionSpill *)instruction); } zig_unreachable(); } diff --git a/src/ir.cpp b/src/ir.cpp index 4dcfaa6cce..845ee03757 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -1066,6 +1066,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionTestCancelReques return IrInstructionIdTestCancelRequested; } +static constexpr IrInstructionId ir_instruction_id(IrInstructionSpill *) { + return IrInstructionIdSpill; +} + template static T *ir_create_instruction(IrBuilder *irb, Scope *scope, AstNode *source_node) { T *special_instruction = allocate(1); @@ -3332,6 +3336,18 @@ static IrInstruction *ir_build_test_cancel_requested(IrBuilder *irb, Scope *scop return &instruction->base; } +static IrInstruction *ir_build_spill(IrBuilder *irb, Scope *scope, AstNode *source_node, + IrInstruction *operand, SpillId spill_id) +{ + IrInstructionSpill *instruction = ir_build_instruction(irb, scope, source_node); + instruction->operand = operand; + instruction->spill_id = spill_id; + + ir_ref_instruction(operand, irb->current_basic_block); + + return &instruction->base; +} + static void ir_count_defers(IrBuilder *irb, Scope *inner_scope, Scope *outer_scope, size_t *results) { results[ReturnKindUnconditional] = 0; results[ReturnKindError] = 0; @@ -3591,6 +3607,7 @@ static IrInstruction *ir_gen_return(IrBuilder *irb, Scope *scope, AstNode *node, ResultLocReturn *result_loc_ret = allocate(1); result_loc_ret->base.id = ResultLocIdReturn; ir_build_reset_result(irb, scope, node, &result_loc_ret->base); + err_val = ir_build_spill(irb, scope, node, err_val, SpillIdRetErrCode); ir_build_end_expr(irb, scope, node, err_val, &result_loc_ret->base); if (irb->codegen->have_err_ret_tracing && !should_inline) { @@ -24725,6 +24742,19 @@ static IrInstruction *ir_analyze_instruction_test_cancel_requested(IrAnalyze *ir return ir_build_test_cancel_requested(&ira->new_irb, instruction->base.scope, instruction->base.source_node); } +static IrInstruction *ir_analyze_instruction_spill(IrAnalyze *ira, IrInstructionSpill *instruction) { + IrInstruction *operand = instruction->operand->child; + if (type_is_invalid(operand->value.type)) + return ira->codegen->invalid_instruction; + if (ir_should_inline(ira->new_irb.exec, instruction->base.scope)) { + return operand; + } + IrInstruction *result = ir_build_spill(&ira->new_irb, instruction->base.scope, instruction->base.source_node, + operand, instruction->spill_id); + result->value.type = operand->value.type; + return result; +} + static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction *instruction) { switch (instruction->id) { case IrInstructionIdInvalid: @@ -25024,6 +25054,8 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction return ir_analyze_instruction_await(ira, (IrInstructionAwaitSrc *)instruction); case IrInstructionIdTestCancelRequested: return ir_analyze_instruction_test_cancel_requested(ira, (IrInstructionTestCancelRequested *)instruction); + case IrInstructionIdSpill: + return ir_analyze_instruction_spill(ira, (IrInstructionSpill *)instruction); } zig_unreachable(); } @@ -25259,6 +25291,7 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdAllocaSrc: case IrInstructionIdAllocaGen: case IrInstructionIdTestCancelRequested: + case IrInstructionIdSpill: return false; case IrInstructionIdAsm: diff --git a/src/ir_print.cpp b/src/ir_print.cpp index 8c90eb02f3..39e781e4f0 100644 --- a/src/ir_print.cpp +++ b/src/ir_print.cpp @@ -1554,6 +1554,12 @@ static void ir_print_test_cancel_requested(IrPrint *irp, IrInstructionTestCancel fprintf(irp->f, "@testCancelRequested()"); } +static void ir_print_spill(IrPrint *irp, IrInstructionSpill *instruction) { + fprintf(irp->f, "@spill("); + ir_print_other_instruction(irp, instruction->operand); + fprintf(irp->f, ")"); +} + static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { ir_print_prefix(irp, instruction); switch (instruction->id) { @@ -2039,6 +2045,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { case IrInstructionIdTestCancelRequested: ir_print_test_cancel_requested(irp, (IrInstructionTestCancelRequested *)instruction); break; + case IrInstructionIdSpill: + ir_print_spill(irp, (IrInstructionSpill *)instruction); + break; } fprintf(irp->f, "\n"); } diff --git a/test/stage1/behavior/coroutines.zig b/test/stage1/behavior/coroutines.zig index c2b95e8559..c92cca9573 100644 --- a/test/stage1/behavior/coroutines.zig +++ b/test/stage1/behavior/coroutines.zig @@ -613,3 +613,32 @@ test "cancel inside an errdefer" { }; S.doTheTest(); } + +test "combining try with errdefer cancel" { + const S = struct { + var frame: anyframe = undefined; + var ok = false; + + fn doTheTest() void { + _ = async amain(); + resume frame; + expect(ok); + } + + fn amain() !void { + var f = async func("https://example.com/"); + errdefer cancel f; + + _ = try await f; + } + + fn func(url: []const u8) ![]u8 { + errdefer ok = true; + frame = @frame(); + suspend; + return error.Bad; + } + + }; + S.doTheTest(); +}