diff --git a/src/all_types.hpp b/src/all_types.hpp index c386587a68..ea52be51cc 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -1045,6 +1045,13 @@ enum FnAnalState { FnAnalStateSkipped, }; + +enum WantPure { + WantPureAuto, + WantPureFalse, + WantPureTrue, +}; + struct FnTableEntry { LLVMValueRef fn_value; AstNode *proto_node; @@ -1060,6 +1067,7 @@ struct FnTableEntry { bool is_extern; bool is_test; bool is_pure; + WantPure want_pure; bool safety_off; bool is_noinline; BlockContext *parent_block_context; diff --git a/src/analyze.cpp b/src/analyze.cpp index b7f325bde3..5eb1be4381 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -1060,7 +1060,7 @@ static void resolve_function_proto(CodeGen *g, AstNode *node, FnTableEntry *fn_t bool ok = resolve_const_expr_bool(g, import, import->block_context, &directive_node->data.directive.expr, &enable); if (!enable || !ok) { - fn_table_entry->is_pure = false; + fn_table_entry->want_pure = WantPureFalse; } // TODO cause compile error if enable is true and impure fn } @@ -5153,7 +5153,7 @@ static TypeTableEntry *analyze_fn_call_ptr(CodeGen *g, ImportTableEntry *import, FnTableEntry *fn_table_entry = node->data.fn_call_expr.fn_entry; ConstExprValue *result_val = &get_resolved_expr(node)->const_val; - if (ok_invocation && fn_table_entry && fn_table_entry->is_pure) { + if (ok_invocation && fn_table_entry && fn_table_entry->is_pure && fn_table_entry->want_pure != WantPureFalse) { if (fn_table_entry->anal_state == FnAnalStateReady) { analyze_fn_body(g, fn_table_entry); } @@ -5167,7 +5167,7 @@ static TypeTableEntry *analyze_fn_call_ptr(CodeGen *g, ImportTableEntry *import, } } } - if (!ok_invocation || !fn_table_entry || !fn_table_entry->is_pure) { + if (!ok_invocation || !fn_table_entry || !fn_table_entry->is_pure || fn_table_entry->want_pure == WantPureFalse) { // calling an impure fn is impure mark_impure_fn(context); } diff --git a/src/codegen.cpp b/src/codegen.cpp index dc66680ecf..d9a581b570 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -3885,6 +3885,7 @@ static void do_code_gen(CodeGen *g) { TypeTableEntry *fn_type = fn_table_entry->type_entry; + bool is_sret = false; if (!type_has_bits(fn_type->data.fn.fn_type_id.return_type)) { // nothing to do } else if (fn_type->data.fn.fn_type_id.return_type->id == TypeTableEntryIdPointer) { @@ -3893,7 +3894,12 @@ static void do_code_gen(CodeGen *g) { LLVMValueRef first_arg = LLVMGetParam(fn_table_entry->fn_value, 0); LLVMAddAttribute(first_arg, LLVMStructRetAttribute); LLVMZigAddNonNullAttr(fn_table_entry->fn_value, 1); + is_sret = true; } + if (fn_table_entry->is_pure && !is_sret) { + LLVMAddFunctionAttr(fn_table_entry->fn_value, LLVMReadOnlyAttribute); + } + // set parameter attributes for (int param_decl_i = 0; param_decl_i < fn_proto->params.length; param_decl_i += 1) { @@ -3914,7 +3920,7 @@ static void do_code_gen(CodeGen *g) { if (param_type->id == TypeTableEntryIdPointer && param_is_noalias) { LLVMAddAttribute(argument_val, LLVMNoAliasAttribute); } - if ((param_type->id == TypeTableEntryIdPointer && param_type->data.pointer.is_const) || + if ((param_type->id == TypeTableEntryIdPointer && (param_type->data.pointer.is_const || fn_table_entry->is_pure)) || is_byval) { LLVMAddAttribute(argument_val, LLVMReadOnlyAttribute); diff --git a/test/run_tests.cpp b/test/run_tests.cpp index c710c0325b..1d7e82e0b0 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -1410,8 +1410,10 @@ fn baz(a: i32) {} )SOURCE"); add_debug_safety_case("integer addition overflow", R"SOURCE( +error Whatever; pub fn main(args: [][]u8) -> %void { - add(65530, 10); + const x = add(65530, 10); + if (x == 0) return error.Whatever; } #static_eval_enable(false) fn add(a: u16, b: u16) -> u16 { @@ -1420,8 +1422,10 @@ fn add(a: u16, b: u16) -> u16 { )SOURCE"); add_debug_safety_case("integer subtraction overflow", R"SOURCE( +error Whatever; pub fn main(args: [][]u8) -> %void { - sub(10, 20); + const x = sub(10, 20); + if (x == 0) return error.Whatever; } #static_eval_enable(false) fn sub(a: u16, b: u16) -> u16 { @@ -1430,8 +1434,10 @@ fn sub(a: u16, b: u16) -> u16 { )SOURCE"); add_debug_safety_case("integer multiplication overflow", R"SOURCE( +error Whatever; pub fn main(args: [][]u8) -> %void { - mul(300, 6000); + const x = mul(300, 6000); + if (x == 0) return error.Whatever; } #static_eval_enable(false) fn mul(a: u16, b: u16) -> u16 { @@ -1440,8 +1446,10 @@ fn mul(a: u16, b: u16) -> u16 { )SOURCE"); add_debug_safety_case("integer negation overflow", R"SOURCE( +error Whatever; pub fn main(args: [][]u8) -> %void { - neg(-32768); + const x = neg(-32768); + if (x == 0) return error.Whatever; } #static_eval_enable(false) fn neg(a: i16) -> i16 { @@ -1450,8 +1458,10 @@ fn neg(a: i16) -> i16 { )SOURCE"); add_debug_safety_case("signed shift left overflow", R"SOURCE( +error Whatever; pub fn main(args: [][]u8) -> %void { - shl(-16385, 1); + const x = shl(-16385, 1); + if (x == 0) return error.Whatever; } #static_eval_enable(false) fn shl(a: i16, b: i16) -> i16 { @@ -1460,8 +1470,10 @@ fn shl(a: i16, b: i16) -> i16 { )SOURCE"); add_debug_safety_case("unsigned shift left overflow", R"SOURCE( +error Whatever; pub fn main(args: [][]u8) -> %void { - shl(0b0010111111111111, 3); + const x = shl(0b0010111111111111, 3); + if (x == 0) return error.Whatever; } #static_eval_enable(false) fn shl(a: u16, b: u16) -> u16 { @@ -1470,8 +1482,9 @@ fn shl(a: u16, b: u16) -> u16 { )SOURCE"); add_debug_safety_case("integer division by zero", R"SOURCE( +error Whatever; pub fn main(args: [][]u8) -> %void { - div0(999, 0); + const x = div0(999, 0); } #static_eval_enable(false) fn div0(a: i32, b: i32) -> i32 { @@ -1480,8 +1493,10 @@ fn div0(a: i32, b: i32) -> i32 { )SOURCE"); add_debug_safety_case("exact division failure", R"SOURCE( +error Whatever; pub fn main(args: [][]u8) -> %void { - div_exact(10, 3); + const x = div_exact(10, 3); + if (x == 0) return error.Whatever; } #static_eval_enable(false) fn div_exact(a: i32, b: i32) -> i32 { @@ -1490,8 +1505,10 @@ fn div_exact(a: i32, b: i32) -> i32 { )SOURCE"); add_debug_safety_case("cast []u8 to bigger slice of wrong size", R"SOURCE( +error Whatever; pub fn main(args: [][]u8) -> %void { - widen_slice([]u8{1, 2, 3, 4, 5}); + const x = widen_slice([]u8{1, 2, 3, 4, 5}); + if (x.len == 0) return error.Whatever; } #static_eval_enable(false) fn widen_slice(slice: []u8) -> []i32 { @@ -1500,8 +1517,10 @@ fn widen_slice(slice: []u8) -> []i32 { )SOURCE"); add_debug_safety_case("value does not fit in shortening cast", R"SOURCE( +error Whatever; pub fn main(args: [][]u8) -> %void { - shorten_cast(200); + const x = shorten_cast(200); + if (x == 0) return error.Whatever; } #static_eval_enable(false) fn shorten_cast(x: i32) -> i8 { @@ -1510,8 +1529,10 @@ fn shorten_cast(x: i32) -> i8 { )SOURCE"); add_debug_safety_case("signed integer not fitting in cast to unsigned integer", R"SOURCE( +error Whatever; pub fn main(args: [][]u8) -> %void { - unsigned_cast(-10); + const x = unsigned_cast(-10); + if (x == 0) return error.Whatever; } #static_eval_enable(false) fn unsigned_cast(x: i32) -> u32 {