diff --git a/src/analyze.cpp b/src/analyze.cpp index 29729070af..3bfa599061 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -1944,8 +1944,11 @@ static TypeTableEntry *analyze_while_expr(CodeGen *g, ImportTableEntry *import, if (resolved_type->id != TypeTableEntryIdInvalid) { assert(resolved_type->id == TypeTableEntryIdBool); bool constant_cond_value = number_literal.data.x_uint; - if (constant_cond_value && !node->codegen_node->data.while_node.contains_break) { - expr_return_type = g->builtin_types.entry_unreachable; + if (constant_cond_value) { + node->codegen_node->data.while_node.condition_always_true = true; + if (!node->codegen_node->data.while_node.contains_break) { + expr_return_type = g->builtin_types.entry_unreachable; + } } } } @@ -2085,13 +2088,74 @@ static TypeTableEntry *analyze_builtin_fn_call_expr(CodeGen *g, ImportTableEntry builtin_fn->param_count, actual_param_count)); } - for (int i = 0; i < actual_param_count; i += 1) { - AstNode *child = node->data.fn_call_expr.params.at(i); - TypeTableEntry *expected_param_type = builtin_fn->param_types[i]; - analyze_expression(g, import, context, expected_param_type, child); - } + switch (builtin_fn->id) { + case BuiltinFnIdInvalid: + zig_unreachable(); + case BuiltinFnIdArithmeticWithOverflow: + for (int i = 0; i < actual_param_count; i += 1) { + AstNode *child = node->data.fn_call_expr.params.at(i); + TypeTableEntry *expected_param_type = builtin_fn->param_types[i]; + analyze_expression(g, import, context, expected_param_type, child); + } + return builtin_fn->return_type; + case BuiltinFnIdMemcpy: + { + AstNode *dest_node = node->data.fn_call_expr.params.at(0); + AstNode *src_node = node->data.fn_call_expr.params.at(1); + AstNode *len_node = node->data.fn_call_expr.params.at(2); + TypeTableEntry *dest_type = analyze_expression(g, import, context, nullptr, dest_node); + TypeTableEntry *src_type = analyze_expression(g, import, context, nullptr, src_node); + analyze_expression(g, import, context, builtin_fn->param_types[2], len_node); - return builtin_fn->return_type; + if (dest_type->id != TypeTableEntryIdInvalid && + dest_type->id != TypeTableEntryIdPointer) + { + add_node_error(g, dest_node, + buf_sprintf("expected pointer argument, got '%s'", buf_ptr(&dest_type->name))); + } + + if (src_type->id != TypeTableEntryIdInvalid && + src_type->id != TypeTableEntryIdPointer) + { + add_node_error(g, src_node, + buf_sprintf("expected pointer argument, got '%s'", buf_ptr(&src_type->name))); + } + + if (dest_type->id == TypeTableEntryIdPointer && + src_type->id == TypeTableEntryIdPointer) + { + uint64_t dest_align_bits = dest_type->data.pointer.child_type->align_in_bits; + uint64_t src_align_bits = src_type->data.pointer.child_type->align_in_bits; + if (dest_align_bits != src_align_bits) { + add_node_error(g, dest_node, buf_sprintf( + "misaligned memcpy, '%s' has alignment '%" PRIu64 ", '%s' has alignment %" PRIu64, + buf_ptr(&dest_type->name), dest_align_bits / 8, + buf_ptr(&src_type->name), src_align_bits / 8)); + } + } + + return builtin_fn->return_type; + } + case BuiltinFnIdMemset: + { + AstNode *dest_node = node->data.fn_call_expr.params.at(0); + AstNode *char_node = node->data.fn_call_expr.params.at(1); + AstNode *len_node = node->data.fn_call_expr.params.at(2); + TypeTableEntry *dest_type = analyze_expression(g, import, context, nullptr, dest_node); + analyze_expression(g, import, context, builtin_fn->param_types[1], char_node); + analyze_expression(g, import, context, builtin_fn->param_types[2], len_node); + + if (dest_type->id != TypeTableEntryIdInvalid && + dest_type->id != TypeTableEntryIdPointer) + { + add_node_error(g, dest_node, + buf_sprintf("expected pointer argument, got '%s'", buf_ptr(&dest_type->name))); + } + + return builtin_fn->return_type; + } + } + zig_unreachable(); } else { add_node_error(g, node, buf_sprintf("invalid builtin function: '%s'", buf_ptr(name))); diff --git a/src/analyze.hpp b/src/analyze.hpp index 37d1a49d3f..8e86ef00af 100644 --- a/src/analyze.hpp +++ b/src/analyze.hpp @@ -151,6 +151,8 @@ struct FnTableEntry { enum BuiltinFnId { BuiltinFnIdInvalid, BuiltinFnIdArithmeticWithOverflow, + BuiltinFnIdMemcpy, + BuiltinFnIdMemset, }; struct BuiltinFnEntry { @@ -354,6 +356,7 @@ struct ImportNode { }; struct WhileNode { + bool condition_always_true; bool contains_break; }; diff --git a/src/codegen.cpp b/src/codegen.cpp index 6c7aaca341..66ab62bfca 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -171,6 +171,67 @@ static LLVMValueRef gen_builtin_fn_call_expr(CodeGen *g, AstNode *node) { return overflow_bit; } + case BuiltinFnIdMemcpy: + { + int fn_call_param_count = node->data.fn_call_expr.params.length; + assert(fn_call_param_count == 3); + + AstNode *dest_node = node->data.fn_call_expr.params.at(0); + TypeTableEntry *dest_type = get_expr_type(dest_node); + + LLVMValueRef dest_ptr = gen_expr(g, dest_node); + LLVMValueRef src_ptr = gen_expr(g, node->data.fn_call_expr.params.at(1)); + LLVMValueRef len_val = gen_expr(g, node->data.fn_call_expr.params.at(2)); + + LLVMTypeRef ptr_u8 = LLVMPointerType(LLVMInt8Type(), 0); + + add_debug_source_node(g, node); + LLVMValueRef dest_ptr_casted = LLVMBuildBitCast(g->builder, dest_ptr, ptr_u8, ""); + LLVMValueRef src_ptr_casted = LLVMBuildBitCast(g->builder, src_ptr, ptr_u8, ""); + + uint64_t align_in_bytes = dest_type->data.pointer.child_type->align_in_bits / 8; + + LLVMValueRef params[] = { + dest_ptr_casted, // dest pointer + src_ptr_casted, // source pointer + len_val, // byte count + LLVMConstInt(LLVMInt32Type(), align_in_bytes, false), // align in bytes + LLVMConstNull(LLVMInt1Type()), // is volatile + }; + + LLVMBuildCall(g->builder, builtin_fn->fn_val, params, 5, ""); + return nullptr; + } + case BuiltinFnIdMemset: + { + int fn_call_param_count = node->data.fn_call_expr.params.length; + assert(fn_call_param_count == 3); + + AstNode *dest_node = node->data.fn_call_expr.params.at(0); + TypeTableEntry *dest_type = get_expr_type(dest_node); + + LLVMValueRef dest_ptr = gen_expr(g, dest_node); + LLVMValueRef char_val = gen_expr(g, node->data.fn_call_expr.params.at(1)); + LLVMValueRef len_val = gen_expr(g, node->data.fn_call_expr.params.at(2)); + + LLVMTypeRef ptr_u8 = LLVMPointerType(LLVMInt8Type(), 0); + + add_debug_source_node(g, node); + LLVMValueRef dest_ptr_casted = LLVMBuildBitCast(g->builder, dest_ptr, ptr_u8, ""); + + uint64_t align_in_bytes = dest_type->data.pointer.child_type->align_in_bits / 8; + + LLVMValueRef params[] = { + dest_ptr_casted, // dest pointer + char_val, // source pointer + len_val, // byte count + LLVMConstInt(LLVMInt32Type(), align_in_bytes, false), // align in bytes + LLVMConstNull(LLVMInt1Type()), // is volatile + }; + + LLVMBuildCall(g->builder, builtin_fn->fn_val, params, 5, ""); + return nullptr; + } } zig_unreachable(); } @@ -1376,23 +1437,35 @@ static LLVMValueRef gen_while_expr(CodeGen *g, AstNode *node) { assert(node->data.while_expr.condition); assert(node->data.while_expr.body); - if (get_expr_type(node)->id == TypeTableEntryIdUnreachable) { - // generate a forever loop. guarantees no break statements + bool condition_always_true = node->codegen_node->data.while_node.condition_always_true; + bool contains_break = node->codegen_node->data.while_node.contains_break; + if (condition_always_true) { + // generate a forever loop LLVMBasicBlockRef body_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileBody"); + LLVMBasicBlockRef end_block = nullptr; + if (contains_break) { + end_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileEnd"); + } add_debug_source_node(g, node); LLVMBuildBr(g->builder, body_block); LLVMPositionBuilderAtEnd(g->builder, body_block); + g->break_block_stack.append(end_block); g->continue_block_stack.append(body_block); gen_expr(g, node->data.while_expr.body); + g->break_block_stack.pop(); g->continue_block_stack.pop(); if (get_expr_type(node->data.while_expr.body)->id != TypeTableEntryIdUnreachable) { add_debug_source_node(g, node); LLVMBuildBr(g->builder, body_block); } + + if (contains_break) { + LLVMPositionBuilderAtEnd(g->builder, end_block); + } } else { // generate a normal while loop @@ -1755,20 +1828,6 @@ static LLVMAttribute to_llvm_fn_attr(FnAttrId attr_id) { static void do_code_gen(CodeGen *g) { assert(!g->errors.length); - { - LLVMTypeRef param_types[] = { - LLVMPointerType(LLVMInt8Type(), 0), - LLVMPointerType(LLVMInt8Type(), 0), - LLVMIntType(g->pointer_size_bytes * 8), - LLVMInt32Type(), - LLVMInt1Type(), - }; - LLVMTypeRef fn_type = LLVMFunctionType(LLVMVoidType(), param_types, 5, false); - Buf *name = buf_sprintf("llvm.memcpy.p0i8.p0i8.i%d", g->pointer_size_bytes * 8); - g->memcpy_fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type); - assert(LLVMGetIntrinsicID(g->memcpy_fn_val)); - } - // Generate module level variables for (int i = 0; i < g->global_vars.length; i += 1) { VariableTableEntry *var = g->global_vars.at(i); @@ -2267,6 +2326,57 @@ static void define_builtin_fns(CodeGen *g) { define_builtin_fns_int(g, g->builtin_types.entry_i16); define_builtin_fns_int(g, g->builtin_types.entry_i32); define_builtin_fns_int(g, g->builtin_types.entry_i64); + { + BuiltinFnEntry *builtin_fn = allocate(1); + buf_init_from_str(&builtin_fn->name, "memcpy"); + builtin_fn->id = BuiltinFnIdMemcpy; + builtin_fn->return_type = g->builtin_types.entry_void; + builtin_fn->param_count = 3; + builtin_fn->param_types = allocate(builtin_fn->param_count); + builtin_fn->param_types[0] = nullptr; // manually checked later + builtin_fn->param_types[1] = nullptr; // manually checked later + builtin_fn->param_types[2] = g->builtin_types.entry_usize; + + LLVMTypeRef param_types[] = { + LLVMPointerType(LLVMInt8Type(), 0), + LLVMPointerType(LLVMInt8Type(), 0), + LLVMIntType(g->pointer_size_bytes * 8), + LLVMInt32Type(), + LLVMInt1Type(), + }; + LLVMTypeRef fn_type = LLVMFunctionType(LLVMVoidType(), param_types, 5, false); + Buf *name = buf_sprintf("llvm.memcpy.p0i8.p0i8.i%d", g->pointer_size_bytes * 8); + g->memcpy_fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type); + builtin_fn->fn_val = g->memcpy_fn_val; + assert(LLVMGetIntrinsicID(g->memcpy_fn_val)); + + g->builtin_fn_table.put(&builtin_fn->name, builtin_fn); + } + { + BuiltinFnEntry *builtin_fn = allocate(1); + buf_init_from_str(&builtin_fn->name, "memset"); + builtin_fn->id = BuiltinFnIdMemset; + builtin_fn->return_type = g->builtin_types.entry_void; + builtin_fn->param_count = 3; + builtin_fn->param_types = allocate(builtin_fn->param_count); + builtin_fn->param_types[0] = nullptr; // manually checked later + builtin_fn->param_types[1] = g->builtin_types.entry_u8; + builtin_fn->param_types[2] = g->builtin_types.entry_usize; + + LLVMTypeRef param_types[] = { + LLVMPointerType(LLVMInt8Type(), 0), + LLVMInt8Type(), + LLVMIntType(g->pointer_size_bytes * 8), + LLVMInt32Type(), + LLVMInt1Type(), + }; + LLVMTypeRef fn_type = LLVMFunctionType(LLVMVoidType(), param_types, 5, false); + Buf *name = buf_sprintf("llvm.memset.p0i8.i%d", g->pointer_size_bytes * 8); + builtin_fn->fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type); + assert(LLVMGetIntrinsicID(builtin_fn->fn_val)); + + g->builtin_fn_table.put(&builtin_fn->name, builtin_fn); + } } diff --git a/std/std.zig b/std/std.zig index e1d669624f..d42074a340 100644 --- a/std/std.zig +++ b/std/std.zig @@ -118,13 +118,7 @@ fn buf_print_u64(out_buf: []u8, x: u64) -> usize { const len = buf.len - index; - // TODO memcpy intrinsic - // @memcpy(out_buf, buf, len); - var i: usize = 0; - while (i < len) { - out_buf[i] = buf[index + i]; - i += 1; - } + @memcpy(out_buf.ptr, &buf[index], len); return len; } diff --git a/test/run_tests.cpp b/test/run_tests.cpp index b72e23d23b..168fc3566f 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -973,6 +973,24 @@ pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 { return 0; } )SOURCE", "OK\n"); + + add_simple_case("memcpy and memset intrinsics", R"SOURCE( +use "std.zig"; +pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 { + var foo : [20]u8; + var bar : [20]u8; + + @memset(foo.ptr, 'A', foo.len); + @memcpy(bar.ptr, foo.ptr, bar.len); + + if (bar[11] != 'A') { + print_str("BAD\n"); + } + + print_str("OK\n"); + return 0; +} + )SOURCE", "OK\n"); }