parse and codegen for math expressions

This commit is contained in:
Andrew Kelley
2015-11-28 00:40:54 -07:00
parent f6529341a2
commit e5d1f0eea5
7 changed files with 1398 additions and 185 deletions

View File

@@ -77,6 +77,7 @@ struct CodeGen {
ZigList<FnTableEntry *> fn_defs;
Buf *out_name;
OutType out_type;
LLVMValueRef cur_fn;
};
struct TypeNode {
@@ -301,12 +302,23 @@ static void find_declarations(CodeGen *g, AstNode *node) {
// we handled directives in the parent function
break;
case NodeTypeFnDecl:
case NodeTypeStatementReturn:
case NodeTypeReturnExpr:
case NodeTypeRoot:
case NodeTypeBlock:
case NodeTypeExpression:
case NodeTypeBoolOrExpr:
case NodeTypeFnCall:
case NodeTypeRootExportDecl:
case NodeTypeBoolAndExpr:
case NodeTypeComparisonExpr:
case NodeTypeBinOrExpr:
case NodeTypeBinXorExpr:
case NodeTypeBinAndExpr:
case NodeTypeBitShiftExpr:
case NodeTypeAddExpr:
case NodeTypeMultExpr:
case NodeTypeCastExpr:
case NodeTypePrimaryExpr:
case NodeTypeGroupedExpr:
zig_unreachable();
}
}
@@ -341,7 +353,7 @@ static void check_fn_def_control_flow(CodeGen *g, AstNode *node) {
bool prev_statement_return = false;
for (int i = 0; i < body_node->data.block.statements.length; i += 1) {
AstNode *statement_node = body_node->data.block.statements.at(i);
if (statement_node->type == NodeTypeStatementReturn) {
if (statement_node->type == NodeTypeReturnExpr) {
if (type_id == TypeIdUnreachable) {
add_node_error(g, statement_node,
buf_sprintf("return statement in function with unreachable return type"));
@@ -464,23 +476,15 @@ static void analyze_node(CodeGen *g, AstNode *node) {
analyze_node(g, child);
}
break;
case NodeTypeStatementReturn:
if (node->data.statement_return.expression) {
analyze_node(g, node->data.statement_return.expression);
case NodeTypeReturnExpr:
if (node->data.return_expr.expr) {
analyze_node(g, node->data.return_expr.expr);
}
break;
case NodeTypeExpression:
switch (node->data.expression.type) {
case AstNodeExpressionTypeNumber:
break;
case AstNodeExpressionTypeString:
break;
case AstNodeExpressionTypeFnCall:
analyze_node(g, node->data.expression.data.fn_call);
break;
case AstNodeExpressionTypeUnreachable:
break;
}
case NodeTypeBoolOrExpr:
analyze_node(g, node->data.bool_or_expr.op1);
if (node->data.bool_or_expr.op2)
analyze_node(g, node->data.bool_or_expr.op2);
break;
case NodeTypeFnCall:
{
@@ -511,6 +515,54 @@ static void analyze_node(CodeGen *g, AstNode *node) {
case NodeTypeDirective:
// we looked at directives in the parent node
break;
case NodeTypeBoolAndExpr:
zig_panic("TODO");
break;
case NodeTypeComparisonExpr:
zig_panic("TODO");
break;
case NodeTypeBinOrExpr:
zig_panic("TODO");
break;
case NodeTypeBinXorExpr:
zig_panic("TODO");
break;
case NodeTypeBinAndExpr:
zig_panic("TODO");
break;
case NodeTypeBitShiftExpr:
zig_panic("TODO");
break;
case NodeTypeAddExpr:
zig_panic("TODO");
break;
case NodeTypeMultExpr:
zig_panic("TODO");
break;
case NodeTypeCastExpr:
zig_panic("TODO");
break;
case NodeTypePrimaryExpr:
switch (node->data.primary_expr.type) {
case PrimaryExprTypeNumber:
case PrimaryExprTypeString:
case PrimaryExprTypeUnreachable:
// nothing to do
break;
case PrimaryExprTypeFnCall:
analyze_node(g, node->data.primary_expr.data.fn_call);
break;
case PrimaryExprTypeGroupedExpr:
analyze_node(g, node->data.primary_expr.data.grouped_expr);
break;
case PrimaryExprTypeBlock:
analyze_node(g, node->data.primary_expr.data.block);
break;
}
break;
case NodeTypeGroupedExpr:
zig_panic("TODO");
break;
}
}
@@ -649,34 +701,326 @@ static LLVMValueRef find_or_create_string(CodeGen *g, Buf *str) {
return global_value;
}
static LLVMValueRef gen_expr(CodeGen *g, AstNode *expr_node) {
assert(expr_node->type == NodeTypeExpression);
switch (expr_node->data.expression.type) {
case AstNodeExpressionTypeNumber:
static LLVMValueRef gen_primary_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypePrimaryExpr);
AstNodePrimaryExpr *prim_expr = &node->data.primary_expr;
switch (node->data.primary_expr.type) {
case PrimaryExprTypeNumber:
{
Buf *number_str = &expr_node->data.expression.data.number;
Buf *number_str = &prim_expr->data.number;
LLVMTypeRef number_type = LLVMInt32Type();
LLVMValueRef number_val = LLVMConstIntOfStringAndSize(number_type,
buf_ptr(number_str), buf_len(number_str), 10);
return number_val;
}
case AstNodeExpressionTypeString:
case PrimaryExprTypeString:
{
Buf *str = &expr_node->data.expression.data.string;
Buf *str = &prim_expr->data.string;
LLVMValueRef str_val = find_or_create_string(g, str);
LLVMValueRef indices[] = {
LLVMConstInt(LLVMInt32Type(), 0, false),
LLVMConstInt(LLVMInt32Type(), 0, false)
};
LLVMValueRef ptr_val = LLVMBuildInBoundsGEP(g->builder, str_val,
indices, 2, "");
LLVMValueRef ptr_val = LLVMBuildInBoundsGEP(g->builder, str_val, indices, 2, "");
return ptr_val;
}
case AstNodeExpressionTypeFnCall:
return gen_fn_call(g, expr_node->data.expression.data.fn_call);
case AstNodeExpressionTypeUnreachable:
case PrimaryExprTypeUnreachable:
add_debug_source_node(g, node);
return LLVMBuildUnreachable(g->builder);
case PrimaryExprTypeFnCall:
return gen_fn_call(g, prim_expr->data.fn_call);
case PrimaryExprTypeGroupedExpr:
return gen_expr(g, prim_expr->data.grouped_expr);
case PrimaryExprTypeBlock:
break;
}
zig_unreachable();
}
static LLVMValueRef gen_cast_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeCastExpr);
LLVMValueRef expr = gen_primary_expr(g, node->data.cast_expr.primary_expr);
if (!node->data.cast_expr.type)
return expr;
zig_panic("TODO cast expression");
}
static LLVMValueRef gen_mult_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeMultExpr);
LLVMValueRef val1 = gen_cast_expr(g, node->data.mult_expr.op1);
if (!node->data.mult_expr.op2)
return val1;
LLVMValueRef val2 = gen_cast_expr(g, node->data.mult_expr.op2);
switch (node->data.mult_expr.mult_op) {
case MultOpMult:
// TODO types so we know float vs int
add_debug_source_node(g, node);
return LLVMBuildMul(g->builder, val1, val2, "");
case MultOpDiv:
// TODO types so we know float vs int and signed vs unsigned
add_debug_source_node(g, node);
return LLVMBuildSDiv(g->builder, val1, val2, "");
case MultOpMod:
// TODO types so we know float vs int and signed vs unsigned
add_debug_source_node(g, node);
return LLVMBuildSRem(g->builder, val1, val2, "");
case MultOpInvalid:
zig_unreachable();
}
zig_unreachable();
}
static LLVMValueRef gen_add_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeAddExpr);
LLVMValueRef val1 = gen_mult_expr(g, node->data.add_expr.op1);
if (!node->data.add_expr.op2)
return val1;
LLVMValueRef val2 = gen_mult_expr(g, node->data.add_expr.op2);
switch (node->data.add_expr.add_op) {
case AddOpAdd:
add_debug_source_node(g, node);
return LLVMBuildAdd(g->builder, val1, val2, "");
case AddOpSub:
add_debug_source_node(g, node);
return LLVMBuildSub(g->builder, val1, val2, "");
case AddOpInvalid:
zig_unreachable();
}
zig_unreachable();
}
static LLVMValueRef gen_bit_shift_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeBitShiftExpr);
LLVMValueRef val1 = gen_add_expr(g, node->data.bit_shift_expr.op1);
if (!node->data.bit_shift_expr.op2)
return val1;
LLVMValueRef val2 = gen_add_expr(g, node->data.bit_shift_expr.op2);
switch (node->data.bit_shift_expr.bit_shift_op) {
case BitShiftOpLeft:
add_debug_source_node(g, node);
return LLVMBuildShl(g->builder, val1, val2, "");
case BitShiftOpRight:
// TODO implement type system so that we know whether to do
// logical or arithmetic shifting here.
// signed -> arithmetic, unsigned -> logical
add_debug_source_node(g, node);
return LLVMBuildLShr(g->builder, val1, val2, "");
case BitShiftOpInvalid:
zig_unreachable();
}
zig_unreachable();
}
static LLVMValueRef gen_bin_and_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeBinAndExpr);
LLVMValueRef val1 = gen_bit_shift_expr(g, node->data.bin_and_expr.op1);
if (!node->data.bin_and_expr.op2)
return val1;
LLVMValueRef val2 = gen_bit_shift_expr(g, node->data.bin_and_expr.op2);
add_debug_source_node(g, node);
return LLVMBuildAnd(g->builder, val1, val2, "");
}
static LLVMValueRef gen_bin_xor_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeBinXorExpr);
LLVMValueRef val1 = gen_bin_and_expr(g, node->data.bin_xor_expr.op1);
if (!node->data.bin_xor_expr.op2)
return val1;
LLVMValueRef val2 = gen_bin_and_expr(g, node->data.bin_xor_expr.op2);
add_debug_source_node(g, node);
return LLVMBuildXor(g->builder, val1, val2, "");
}
static LLVMValueRef gen_bin_or_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeBinOrExpr);
LLVMValueRef val1 = gen_bin_xor_expr(g, node->data.bin_or_expr.op1);
if (!node->data.bin_or_expr.op2)
return val1;
LLVMValueRef val2 = gen_bin_xor_expr(g, node->data.bin_or_expr.op2);
add_debug_source_node(g, node);
return LLVMBuildOr(g->builder, val1, val2, "");
}
static LLVMIntPredicate cmp_op_to_int_predicate(CmpOp cmp_op, bool is_signed) {
switch (cmp_op) {
case CmpOpInvalid:
zig_unreachable();
case CmpOpEq:
return LLVMIntEQ;
case CmpOpNotEq:
return LLVMIntNE;
case CmpOpLessThan:
return is_signed ? LLVMIntSLT : LLVMIntULT;
case CmpOpGreaterThan:
return is_signed ? LLVMIntSGT : LLVMIntUGT;
case CmpOpLessOrEq:
return is_signed ? LLVMIntSLE : LLVMIntULE;
case CmpOpGreaterOrEq:
return is_signed ? LLVMIntSGE : LLVMIntUGE;
}
zig_unreachable();
}
static LLVMValueRef gen_cmp_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeComparisonExpr);
LLVMValueRef val1 = gen_bin_or_expr(g, node->data.comparison_expr.op1);
if (!node->data.comparison_expr.op2)
return val1;
LLVMValueRef val2 = gen_bin_or_expr(g, node->data.comparison_expr.op2);
// TODO implement type system so that we know whether to do signed or unsigned comparison here
LLVMIntPredicate pred = cmp_op_to_int_predicate(node->data.comparison_expr.cmp_op, true);
add_debug_source_node(g, node);
return LLVMBuildICmp(g->builder, pred, val1, val2, "");
}
static LLVMValueRef gen_bool_and_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeBoolAndExpr);
LLVMValueRef val1 = gen_cmp_expr(g, node->data.bool_and_expr.op1);
if (!node->data.bool_and_expr.op2)
return val1;
// block for when val1 == true
LLVMBasicBlockRef true_block = LLVMAppendBasicBlock(g->cur_fn, "BoolAndTrue");
// block for when val1 == false (don't even evaluate the second part)
LLVMBasicBlockRef false_block = LLVMAppendBasicBlock(g->cur_fn, "BoolAndFalse");
LLVMValueRef zero = LLVMConstNull(LLVMTypeOf(val1));
add_debug_source_node(g, node);
LLVMValueRef val1_i1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val1, zero, "");
LLVMBuildCondBr(g->builder, val1_i1, false_block, true_block);
LLVMPositionBuilderAtEnd(g->builder, true_block);
LLVMValueRef val2 = gen_cmp_expr(g, node->data.bool_and_expr.op2);
add_debug_source_node(g, node);
LLVMValueRef val2_i1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, zero, "");
LLVMPositionBuilderAtEnd(g->builder, false_block);
add_debug_source_node(g, node);
LLVMValueRef phi = LLVMBuildPhi(g->builder, LLVMInt1Type(), "");
LLVMValueRef one_i1 = LLVMConstAllOnes(LLVMInt1Type());
LLVMValueRef incoming_values[2] = {one_i1, val2_i1};
LLVMBasicBlockRef incoming_blocks[2] = {LLVMGetInsertBlock(g->builder), true_block};
LLVMAddIncoming(phi, incoming_values, incoming_blocks, 2);
return phi;
}
static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) {
assert(expr_node->type == NodeTypeBoolOrExpr);
LLVMValueRef val1 = gen_bool_and_expr(g, expr_node->data.bool_or_expr.op1);
if (!expr_node->data.bool_or_expr.op2)
return val1;
// block for when val1 == false
LLVMBasicBlockRef false_block = LLVMAppendBasicBlock(g->cur_fn, "BoolOrFalse");
// block for when val1 == true (don't even evaluate the second part)
LLVMBasicBlockRef true_block = LLVMAppendBasicBlock(g->cur_fn, "BoolOrTrue");
LLVMValueRef zero = LLVMConstNull(LLVMTypeOf(val1));
add_debug_source_node(g, expr_node);
LLVMValueRef val1_i1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val1, zero, "");
LLVMBuildCondBr(g->builder, val1_i1, false_block, true_block);
LLVMPositionBuilderAtEnd(g->builder, false_block);
LLVMValueRef val2 = gen_bool_and_expr(g, expr_node->data.bool_or_expr.op2);
add_debug_source_node(g, expr_node);
LLVMValueRef val2_i1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, zero, "");
LLVMPositionBuilderAtEnd(g->builder, true_block);
add_debug_source_node(g, expr_node);
LLVMValueRef phi = LLVMBuildPhi(g->builder, LLVMInt1Type(), "");
LLVMValueRef one_i1 = LLVMConstAllOnes(LLVMInt1Type());
LLVMValueRef incoming_values[2] = {one_i1, val2_i1};
LLVMBasicBlockRef incoming_blocks[2] = {LLVMGetInsertBlock(g->builder), false_block};
LLVMAddIncoming(phi, incoming_values, incoming_blocks, 2);
return phi;
}
static LLVMValueRef gen_return_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeReturnExpr);
AstNode *param_node = node->data.return_expr.expr;
if (param_node) {
LLVMValueRef value = gen_expr(g, param_node);
add_debug_source_node(g, node);
return LLVMBuildRet(g->builder, value);
} else {
add_debug_source_node(g, node);
return LLVMBuildRetVoid(g->builder);
}
}
/*
Expression : BoolOrExpression | ReturnExpression
*/
static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) {
switch (node->type) {
case NodeTypeBoolOrExpr:
return gen_bool_or_expr(g, node);
case NodeTypeReturnExpr:
return gen_return_expr(g, node);
case NodeTypeRoot:
case NodeTypeRootExportDecl:
case NodeTypeFnProto:
case NodeTypeFnDef:
case NodeTypeFnDecl:
case NodeTypeParamDecl:
case NodeTypeType:
case NodeTypeBlock:
case NodeTypeFnCall:
case NodeTypeExternBlock:
case NodeTypeDirective:
case NodeTypeBoolAndExpr:
case NodeTypeComparisonExpr:
case NodeTypeBinOrExpr:
case NodeTypeBinXorExpr:
case NodeTypeBinAndExpr:
case NodeTypeBitShiftExpr:
case NodeTypeAddExpr:
case NodeTypeMultExpr:
case NodeTypeCastExpr:
case NodeTypePrimaryExpr:
return gen_primary_expr(g, node);
case NodeTypeGroupedExpr:
zig_unreachable();
}
zig_unreachable();
}
@@ -692,39 +1036,7 @@ static void gen_block(CodeGen *g, AstNode *block_node, bool add_implicit_return)
for (int i = 0; i < block_node->data.block.statements.length; i += 1) {
AstNode *statement_node = block_node->data.block.statements.at(i);
switch (statement_node->type) {
case NodeTypeStatementReturn:
{
AstNode *expr_node = statement_node->data.statement_return.expression;
if (expr_node) {
LLVMValueRef value = gen_expr(g, expr_node);
add_debug_source_node(g, statement_node);
LLVMBuildRet(g->builder, value);
} else {
add_debug_source_node(g, statement_node);
LLVMBuildRetVoid(g->builder);
}
break;
}
case NodeTypeExpression:
{
gen_expr(g, statement_node);
break;
}
case NodeTypeRoot:
case NodeTypeFnProto:
case NodeTypeFnDef:
case NodeTypeFnDecl:
case NodeTypeParamDecl:
case NodeTypeType:
case NodeTypeBlock:
case NodeTypeFnCall:
case NodeTypeExternBlock:
case NodeTypeDirective:
case NodeTypeRootExportDecl:
zig_unreachable();
}
gen_expr(g, statement_node);
}
if (add_implicit_return) {
@@ -810,6 +1122,7 @@ void code_gen(CodeGen *g) {
FnTableEntry *fn_table_entry = g->fn_defs.at(i);
AstNode *fn_def_node = fn_table_entry->fn_def_node;
LLVMValueRef fn = fn_table_entry->fn_value;
g->cur_fn = fn;
AstNode *proto_node = fn_table_entry->proto_node;
assert(proto_node->type == NodeTypeFnProto);