suport checked arithmetic operations via intrinsics

closes #32
This commit is contained in:
Andrew Kelley
2016-01-08 23:41:40 -07:00
parent 14b9cbd43c
commit b7dd88ad68
10 changed files with 205 additions and 19 deletions

View File

@@ -22,6 +22,7 @@ CodeGen *codegen_create(Buf *root_source_dir) {
g->str_table.init(32);
g->link_table.init(32);
g->import_table.init(32);
g->builtin_fn_table.init(32);
g->build_type = CodeGenBuildTypeDebug;
g->root_source_dir = root_source_dir;
@@ -139,6 +140,41 @@ static TypeTableEntry *get_expr_type(AstNode *node) {
return cast_type ? cast_type : node->codegen_node->expr_node.type_entry;
}
static LLVMValueRef gen_builtin_fn_call_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeFnCallExpr);
AstNode *fn_ref_expr = node->data.fn_call_expr.fn_ref_expr;
assert(fn_ref_expr->type == NodeTypeSymbol);
BuiltinFnEntry *builtin_fn = node->codegen_node->data.fn_call_node.builtin_fn;
switch (builtin_fn->id) {
case BuiltinFnIdInvalid:
zig_unreachable();
case BuiltinFnIdArithmeticWithOverflow:
{
int fn_call_param_count = node->data.fn_call_expr.params.length;
assert(fn_call_param_count == 3);
LLVMValueRef op1 = gen_expr(g, node->data.fn_call_expr.params.at(0));
LLVMValueRef op2 = gen_expr(g, node->data.fn_call_expr.params.at(1));
LLVMValueRef ptr_result = gen_expr(g, node->data.fn_call_expr.params.at(2));
LLVMValueRef params[] = {
op1,
op2,
};
add_debug_source_node(g, node);
LLVMValueRef result_struct = LLVMBuildCall(g->builder, builtin_fn->fn_val, params, 2, "");
LLVMValueRef result = LLVMBuildExtractValue(g->builder, result_struct, 0, "");
LLVMValueRef overflow_bit = LLVMBuildExtractValue(g->builder, result_struct, 1, "");
LLVMBuildStore(g->builder, result, ptr_result);
return overflow_bit;
}
}
zig_unreachable();
}
static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeFnCallExpr);
@@ -159,7 +195,15 @@ static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) {
zig_unreachable();
}
} else if (fn_ref_expr->type == NodeTypeSymbol) {
Buf *name = hack_get_fn_call_name(g, fn_ref_expr);
if (node->data.fn_call_expr.is_builtin) {
return gen_builtin_fn_call_expr(g, node);
}
// Assume that the expression evaluates to a simple name and return the buf
// TODO after we support function pointers we can make this generic
assert(fn_ref_expr->type == NodeTypeSymbol);
Buf *name = &fn_ref_expr->data.symbol;
struct_type = nullptr;
first_param_expr = nullptr;
fn_table_entry = g->cur_fn->import_entry->fn_table.get(name);
@@ -2167,6 +2211,64 @@ static void define_builtin_types(CodeGen *g) {
}
}
static void define_builtin_fns_int(CodeGen *g, TypeTableEntry *type_entry) {
assert(type_entry->id == TypeTableEntryIdInt);
struct OverflowFn {
const char *bare_name;
const char *signed_name;
const char *unsigned_name;
};
OverflowFn overflow_fns[] = {
{"add", "sadd", "uadd"},
{"sub", "ssub", "usub"},
{"mul", "smul", "umul"},
};
for (int i = 0; i < sizeof(overflow_fns)/sizeof(overflow_fns[0]); i += 1) {
OverflowFn *overflow_fn = &overflow_fns[i];
BuiltinFnEntry *builtin_fn = allocate<BuiltinFnEntry>(1);
buf_resize(&builtin_fn->name, 0);
buf_appendf(&builtin_fn->name, "%s_with_overflow_%s", overflow_fn->bare_name, buf_ptr(&type_entry->name));
builtin_fn->id = BuiltinFnIdArithmeticWithOverflow;
builtin_fn->return_type = g->builtin_types.entry_bool;
builtin_fn->param_count = 3;
builtin_fn->param_types = allocate<TypeTableEntry *>(builtin_fn->param_count);
builtin_fn->param_types[0] = type_entry;
builtin_fn->param_types[1] = type_entry;
builtin_fn->param_types[2] = get_pointer_to_type(g, type_entry, false, false);
const char *signed_str = type_entry->data.integral.is_signed ?
overflow_fn->signed_name : overflow_fn->unsigned_name;
Buf *llvm_name = buf_sprintf("llvm.%s.with.overflow.i%" PRIu64, signed_str, type_entry->size_in_bits);
LLVMTypeRef return_elem_types[] = {
type_entry->type_ref,
LLVMInt1Type(),
};
LLVMTypeRef param_types[] = {
type_entry->type_ref,
type_entry->type_ref,
};
LLVMTypeRef return_struct_type = LLVMStructType(return_elem_types, 2, false);
LLVMTypeRef fn_type = LLVMFunctionType(return_struct_type, param_types, 2, false);
builtin_fn->fn_val = LLVMAddFunction(g->module, buf_ptr(llvm_name), fn_type);
assert(LLVMGetIntrinsicID(builtin_fn->fn_val));
g->builtin_fn_table.put(&builtin_fn->name, builtin_fn);
}
}
static void define_builtin_fns(CodeGen *g) {
define_builtin_fns_int(g, g->builtin_types.entry_u8);
define_builtin_fns_int(g, g->builtin_types.entry_u16);
define_builtin_fns_int(g, g->builtin_types.entry_u32);
define_builtin_fns_int(g, g->builtin_types.entry_u64);
define_builtin_fns_int(g, g->builtin_types.entry_i8);
define_builtin_fns_int(g, g->builtin_types.entry_i16);
define_builtin_fns_int(g, g->builtin_types.entry_i32);
define_builtin_fns_int(g, g->builtin_types.entry_i64);
}
static void init(CodeGen *g, Buf *source_path) {
@@ -2228,9 +2330,10 @@ static void init(CodeGen *g, Buf *source_path) {
"", 0, !g->strip_debug_symbols);
// This is for debug stuff that doesn't have a real file.
g->dummy_di_file = nullptr; //LLVMZigCreateFile(g->dbuilder, "", "");
g->dummy_di_file = nullptr;
define_builtin_types(g);
define_builtin_fns(g);
}