implement safety for resuming non-suspended function

closes #3469
This commit is contained in:
Andrew Kelley
2019-10-22 23:43:27 -04:00
parent 1dcf540426
commit e98e5dda52
3 changed files with 74 additions and 0 deletions

View File

@@ -933,6 +933,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) {
return buf_create_from_str("resumed an async function which can only be awaited");
case PanicMsgIdBadNoAsyncCall:
return buf_create_from_str("async function called with noasync suspended");
case PanicMsgIdResumeNotSuspendedFn:
return buf_create_from_str("resumed a non-suspended function");
}
zig_unreachable();
}
@@ -2234,6 +2236,12 @@ static void gen_assert_resume_id(CodeGen *g, IrInstruction *source_instr, Resume
LLVMBasicBlockRef end_bb)
{
LLVMTypeRef usize_type_ref = g->builtin_types.entry_usize->llvm_type;
if (ir_want_runtime_safety(g, source_instr)) {
// Write a value to the resume index which indicates the function was resumed while not suspended.
LLVMBuildStore(g->builder, g->cur_bad_not_suspended_index, g->cur_async_resume_index_ptr);
}
LLVMBasicBlockRef bad_resume_block = LLVMAppendBasicBlock(g->cur_fn_val, "BadResume");
if (end_bb == nullptr) end_bb = LLVMAppendBasicBlock(g->cur_fn_val, "OkResume");
LLVMValueRef expected_value = LLVMConstSub(LLVMConstAllOnes(usize_type_ref),
@@ -5764,6 +5772,9 @@ static LLVMValueRef ir_render_suspend_finish(CodeGen *g, IrExecutable *executabl
LLVMBuildRetVoid(g->builder);
LLVMPositionBuilderAtEnd(g->builder, instruction->begin->resume_bb);
if (ir_want_runtime_safety(g, &instruction->base)) {
LLVMBuildStore(g->builder, g->cur_bad_not_suspended_index, g->cur_async_resume_index_ptr);
}
render_async_var_decls(g, instruction->base.scope);
return nullptr;
}
@@ -7542,7 +7553,20 @@ static void do_code_gen(CodeGen *g) {
IrBasicBlock *entry_block = executable->basic_block_list.at(0);
LLVMAddCase(switch_instr, zero, entry_block->llvm_block);
g->cur_resume_block_count += 1;
{
LLVMBasicBlockRef bad_not_suspended_bb = LLVMAppendBasicBlock(g->cur_fn_val, "NotSuspended");
size_t new_block_index = g->cur_resume_block_count;
g->cur_resume_block_count += 1;
g->cur_bad_not_suspended_index = LLVMConstInt(usize_type_ref, new_block_index, false);
LLVMAddCase(g->cur_async_switch_instr, g->cur_bad_not_suspended_index, bad_not_suspended_bb);
LLVMPositionBuilderAtEnd(g->builder, bad_not_suspended_bb);
gen_assertion_scope(g, PanicMsgIdResumeNotSuspendedFn, fn_table_entry->child_scope);
}
LLVMPositionBuilderAtEnd(g->builder, entry_block->llvm_block);
LLVMBuildStore(g->builder, g->cur_bad_not_suspended_index, g->cur_async_resume_index_ptr);
if (trace_field_index_stack != UINT32_MAX) {
if (codegen_fn_has_err_ret_tracing_arg(g, fn_type_id->return_type)) {
LLVMValueRef trace_ptr_ptr = LLVMBuildStructGEP(g->builder, g->cur_frame_ptr,