add array bounds checking in debug mode

closes #27
This commit is contained in:
Andrew Kelley
2016-04-26 11:35:56 -07:00
parent 61e6c49bc5
commit d1fa5692c6
5 changed files with 216 additions and 55 deletions

View File

@@ -330,6 +330,46 @@ static LLVMValueRef get_handle_value(CodeGen *g, AstNode *source_node, LLVMValue
}
}
static bool want_debug_safety(CodeGen *g, AstNode *node) {
return !g->is_release_build && !node->block_context->safety_off;
}
static void add_bounds_check(CodeGen *g, AstNode *source_node, LLVMValueRef target_val,
LLVMIntPredicate lower_pred, LLVMValueRef lower_value,
LLVMIntPredicate upper_pred, LLVMValueRef upper_value)
{
if (!lower_value && !upper_value) {
return;
}
if (upper_value && !lower_value) {
lower_value = upper_value;
lower_pred = upper_pred;
upper_value = nullptr;
}
add_debug_source_node(g, source_node);
LLVMBasicBlockRef bounds_check_fail_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "BoundsCheckFail");
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "BoundsCheckOk");
LLVMBasicBlockRef lower_ok_block = upper_value ?
LLVMAppendBasicBlock(g->cur_fn->fn_value, "FirstBoundsCheckOk") : ok_block;
LLVMValueRef lower_ok_val = LLVMBuildICmp(g->builder, lower_pred, target_val, lower_value, "");
LLVMBuildCondBr(g->builder, lower_ok_val, lower_ok_block, bounds_check_fail_block);
LLVMPositionBuilderAtEnd(g->builder, bounds_check_fail_block);
LLVMBuildCall(g->builder, g->trap_fn_val, nullptr, 0, "");
LLVMBuildUnreachable(g->builder);
if (upper_value) {
LLVMPositionBuilderAtEnd(g->builder, lower_ok_block);
LLVMValueRef upper_ok_val = LLVMBuildICmp(g->builder, upper_pred, target_val, upper_value, "");
LLVMBuildCondBr(g->builder, upper_ok_val, ok_block, bounds_check_fail_block);
}
LLVMPositionBuilderAtEnd(g->builder, ok_block);
}
static LLVMValueRef gen_err_name(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeFnCallExpr);
assert(g->generate_error_name_table);
@@ -344,25 +384,10 @@ static LLVMValueRef gen_err_name(CodeGen *g, AstNode *node) {
LLVMValueRef err_val = gen_expr(g, err_val_node);
add_debug_source_node(g, node);
if (!g->is_release_build) {
LLVMBasicBlockRef bounds_check_fail_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "BoundsCheckFail");
LLVMBasicBlockRef lower_ok_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "LowerBoundsCheckOk");
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "BoundsCheckOk");
if (want_debug_safety(g, node)) {
LLVMValueRef zero = LLVMConstNull(LLVMTypeOf(err_val));
LLVMValueRef is_zero_val = LLVMBuildICmp(g->builder, LLVMIntEQ, err_val, zero, "");
LLVMBuildCondBr(g->builder, is_zero_val, bounds_check_fail_block, lower_ok_block);
LLVMPositionBuilderAtEnd(g->builder, bounds_check_fail_block);
LLVMBuildCall(g->builder, g->trap_fn_val, nullptr, 0, "");
LLVMBuildUnreachable(g->builder);
LLVMPositionBuilderAtEnd(g->builder, lower_ok_block);
LLVMValueRef end_val = LLVMConstInt(LLVMTypeOf(err_val), g->error_decls.length, false);
LLVMValueRef is_too_big_val = LLVMBuildICmp(g->builder, LLVMIntUGE, err_val, end_val, "");
LLVMBuildCondBr(g->builder, is_too_big_val, bounds_check_fail_block, ok_block);
LLVMPositionBuilderAtEnd(g->builder, ok_block);
add_bounds_check(g, node, err_val, LLVMIntNE, zero, LLVMIntULT, end_val);
}
LLVMValueRef indices[] = {
@@ -869,6 +894,11 @@ static LLVMValueRef gen_array_elem_ptr(CodeGen *g, AstNode *source_node, LLVMVal
}
if (array_type->id == TypeTableEntryIdArray) {
if (want_debug_safety(g, source_node)) {
LLVMValueRef end = LLVMConstInt(g->builtin_types.entry_isize->type_ref,
array_type->data.array.len, false);
add_bounds_check(g, source_node, subscript_value, LLVMIntEQ, nullptr, LLVMIntULT, end);
}
LLVMValueRef indices[] = {
LLVMConstNull(g->builtin_types.entry_isize->type_ref),
subscript_value
@@ -887,6 +917,15 @@ static LLVMValueRef gen_array_elem_ptr(CodeGen *g, AstNode *source_node, LLVMVal
assert(LLVMGetTypeKind(LLVMTypeOf(array_ptr)) == LLVMPointerTypeKind);
assert(LLVMGetTypeKind(LLVMGetElementType(LLVMTypeOf(array_ptr))) == LLVMStructTypeKind);
if (want_debug_safety(g, source_node)) {
add_debug_source_node(g, source_node);
int len_index = array_type->data.structure.fields[1].gen_index;
assert(len_index >= 0);
LLVMValueRef len_ptr = LLVMBuildStructGEP(g->builder, array_ptr, len_index, "");
LLVMValueRef len = LLVMBuildLoad(g->builder, len_ptr, "");
add_bounds_check(g, source_node, subscript_value, LLVMIntEQ, nullptr, LLVMIntULT, len);
}
add_debug_source_node(g, source_node);
int ptr_index = array_type->data.structure.fields[0].gen_index;
assert(ptr_index >= 0);
@@ -907,7 +946,6 @@ static LLVMValueRef gen_array_ptr(CodeGen *g, AstNode *node) {
LLVMValueRef array_ptr = gen_array_base_ptr(g, array_expr_node);
LLVMValueRef subscript_value = gen_expr(g, node->data.array_access_expr.subscript);
return gen_array_elem_ptr(g, node, array_ptr, array_type, subscript_value);
}
@@ -969,6 +1007,15 @@ static LLVMValueRef gen_slice_expr(CodeGen *g, AstNode *node) {
end_val = LLVMConstInt(g->builtin_types.entry_isize->type_ref, array_type->data.array.len, false);
}
if (want_debug_safety(g, node)) {
add_bounds_check(g, node, start_val, LLVMIntEQ, nullptr, LLVMIntULE, end_val);
if (node->data.slice_expr.end) {
LLVMValueRef array_end = LLVMConstInt(g->builtin_types.entry_isize->type_ref,
array_type->data.array.len, false);
add_bounds_check(g, node, end_val, LLVMIntEQ, nullptr, LLVMIntULE, array_end);
}
}
add_debug_source_node(g, node);
LLVMValueRef ptr_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, 0, "");
LLVMValueRef indices[] = {
@@ -987,6 +1034,10 @@ static LLVMValueRef gen_slice_expr(CodeGen *g, AstNode *node) {
LLVMValueRef start_val = gen_expr(g, node->data.slice_expr.start);
LLVMValueRef end_val = gen_expr(g, node->data.slice_expr.end);
if (want_debug_safety(g, node)) {
add_bounds_check(g, node, start_val, LLVMIntEQ, nullptr, LLVMIntULE, end_val);
}
add_debug_source_node(g, node);
LLVMValueRef ptr_field_ptr = LLVMBuildStructGEP(g->builder, tmp_struct_ptr, 0, "");
LLVMValueRef slice_start_ptr = LLVMBuildInBoundsGEP(g->builder, array_ptr, &start_val, 1, "");
@@ -1002,22 +1053,33 @@ static LLVMValueRef gen_slice_expr(CodeGen *g, AstNode *node) {
assert(LLVMGetTypeKind(LLVMTypeOf(array_ptr)) == LLVMPointerTypeKind);
assert(LLVMGetTypeKind(LLVMGetElementType(LLVMTypeOf(array_ptr))) == LLVMStructTypeKind);
int ptr_index = array_type->data.structure.fields[0].gen_index;
assert(ptr_index >= 0);
int len_index = array_type->data.structure.fields[1].gen_index;
assert(len_index >= 0);
LLVMValueRef prev_end = nullptr;
if (!node->data.slice_expr.end || want_debug_safety(g, node)) {
add_debug_source_node(g, node);
LLVMValueRef src_len_ptr = LLVMBuildStructGEP(g->builder, array_ptr, len_index, "");
prev_end = LLVMBuildLoad(g->builder, src_len_ptr, "");
}
LLVMValueRef start_val = gen_expr(g, node->data.slice_expr.start);
LLVMValueRef end_val;
if (node->data.slice_expr.end) {
end_val = gen_expr(g, node->data.slice_expr.end);
} else {
add_debug_source_node(g, node);
int len_index = array_type->data.structure.fields[1].gen_index;
assert(len_index >= 0);
LLVMValueRef src_len_ptr = LLVMBuildStructGEP(g->builder, array_ptr, len_index, "");
end_val = LLVMBuildLoad(g->builder, src_len_ptr, "");
end_val = prev_end;
}
int ptr_index = array_type->data.structure.fields[0].gen_index;
assert(ptr_index >= 0);
int len_index = array_type->data.structure.fields[1].gen_index;
assert(len_index >= 0);
if (want_debug_safety(g, node)) {
assert(prev_end);
add_bounds_check(g, node, start_val, LLVMIntEQ, nullptr, LLVMIntULE, end_val);
if (node->data.slice_expr.end) {
add_bounds_check(g, node, end_val, LLVMIntEQ, nullptr, LLVMIntULE, prev_end);
}
}
add_debug_source_node(g, node);
LLVMValueRef src_ptr_ptr = LLVMBuildStructGEP(g->builder, array_ptr, ptr_index, "");
@@ -1225,7 +1287,7 @@ static LLVMValueRef gen_prefix_op_expr(CodeGen *g, AstNode *node) {
assert(expr_type->id == TypeTableEntryIdErrorUnion);
TypeTableEntry *child_type = expr_type->data.error.child_type;
if (!g->is_release_build) {
if (want_debug_safety(g, node)) {
LLVMValueRef err_val;
if (type_has_bits(child_type)) {
add_debug_source_node(g, node);
@@ -1263,7 +1325,7 @@ static LLVMValueRef gen_prefix_op_expr(CodeGen *g, AstNode *node) {
assert(expr_type->id == TypeTableEntryIdMaybe);
TypeTableEntry *child_type = expr_type->data.maybe.child_type;
if (!g->is_release_build) {
if (want_debug_safety(g, node)) {
add_debug_source_node(g, node);
LLVMValueRef cond_val;
if (child_type->id == TypeTableEntryIdPointer ||
@@ -2261,7 +2323,7 @@ static LLVMValueRef gen_container_init_expr(CodeGen *g, AstNode *node) {
} else if (type_entry->id == TypeTableEntryIdUnreachable) {
assert(node->data.container_init_expr.entries.length == 0);
add_debug_source_node(g, node);
if (!g->is_release_build) {
if (want_debug_safety(g, node)) {
LLVMBuildCall(g->builder, g->trap_fn_val, nullptr, 0, "");
}
LLVMBuildUnreachable(g->builder);
@@ -2575,7 +2637,7 @@ static LLVMValueRef gen_var_decl_raw(CodeGen *g, AstNode *source_node, AstNodeVa
}
}
}
if (!ignore_uninit && !g->is_release_build) {
if (!ignore_uninit && want_debug_safety(g, source_node)) {
TypeTableEntry *isize = g->builtin_types.entry_isize;
uint64_t size_bytes = LLVMStoreSizeOfType(g->target_data_ref, variable->type->type_ref);
uint64_t align_bytes = get_memcpy_align(g, variable->type);
@@ -2790,7 +2852,7 @@ static LLVMValueRef gen_switch_expr(CodeGen *g, AstNode *node) {
if (!else_prong) {
LLVMPositionBuilderAtEnd(g->builder, else_block);
add_debug_source_node(g, node);
if (!g->is_release_build) {
if (want_debug_safety(g, node)) {
LLVMBuildCall(g->builder, g->trap_fn_val, nullptr, 0, "");
}
LLVMBuildUnreachable(g->builder);
@@ -3383,6 +3445,10 @@ static void do_code_gen(CodeGen *g) {
// Generate the list of test function pointers.
if (g->is_test_build) {
if (g->test_fn_count == 0) {
fprintf(stderr, "No tests to run.\n");
exit(0);
}
assert(g->test_fn_count > 0);
assert(next_test_index == g->test_fn_count);