flatten expression ast to hide operator precedence

This commit is contained in:
Josh Wolfe
2015-11-29 14:46:05 -07:00
parent 4466a4533c
commit 9a014b52cc
3 changed files with 230 additions and 421 deletions

View File

@@ -305,17 +305,9 @@ static void find_declarations(CodeGen *g, AstNode *node) {
case NodeTypeReturnExpr:
case NodeTypeRoot:
case NodeTypeBlock:
case NodeTypeBoolOrExpr:
case NodeTypeBinOpExpr:
case NodeTypeFnCall:
case NodeTypeRootExportDecl:
case NodeTypeBoolAndExpr:
case NodeTypeComparisonExpr:
case NodeTypeBinOrExpr:
case NodeTypeBinXorExpr:
case NodeTypeBinAndExpr:
case NodeTypeBitShiftExpr:
case NodeTypeAddExpr:
case NodeTypeMultExpr:
case NodeTypeCastExpr:
case NodeTypePrimaryExpr:
case NodeTypeGroupedExpr:
@@ -481,10 +473,9 @@ static void analyze_node(CodeGen *g, AstNode *node) {
analyze_node(g, node->data.return_expr.expr);
}
break;
case NodeTypeBoolOrExpr:
analyze_node(g, node->data.bool_or_expr.op1);
if (node->data.bool_or_expr.op2)
analyze_node(g, node->data.bool_or_expr.op2);
case NodeTypeBinOpExpr:
analyze_node(g, node->data.bin_op_expr.op1);
analyze_node(g, node->data.bin_op_expr.op2);
break;
case NodeTypeFnCall:
{
@@ -515,30 +506,6 @@ static void analyze_node(CodeGen *g, AstNode *node) {
case NodeTypeDirective:
// we looked at directives in the parent node
break;
case NodeTypeBoolAndExpr:
zig_panic("TODO");
break;
case NodeTypeComparisonExpr:
zig_panic("TODO");
break;
case NodeTypeBinOrExpr:
zig_panic("TODO");
break;
case NodeTypeBinXorExpr:
zig_panic("TODO");
break;
case NodeTypeBinAndExpr:
zig_panic("TODO");
break;
case NodeTypeBitShiftExpr:
zig_panic("TODO");
break;
case NodeTypeAddExpr:
zig_panic("TODO");
break;
case NodeTypeMultExpr:
zig_panic("TODO");
break;
case NodeTypeCastExpr:
zig_panic("TODO");
break;
@@ -752,168 +719,138 @@ static LLVMValueRef gen_cast_expr(CodeGen *g, AstNode *node) {
}
static LLVMValueRef gen_mult_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeMultExpr);
assert(node->type == NodeTypeBinOpExpr);
LLVMValueRef val1 = gen_cast_expr(g, node->data.mult_expr.op1);
LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
if (!node->data.mult_expr.op2)
return val1;
LLVMValueRef val2 = gen_cast_expr(g, node->data.mult_expr.op2);
switch (node->data.mult_expr.mult_op) {
case MultOpMult:
switch (node->data.bin_op_expr.bin_op) {
case BinOpTypeMult:
// TODO types so we know float vs int
add_debug_source_node(g, node);
return LLVMBuildMul(g->builder, val1, val2, "");
case MultOpDiv:
case BinOpTypeDiv:
// TODO types so we know float vs int and signed vs unsigned
add_debug_source_node(g, node);
return LLVMBuildSDiv(g->builder, val1, val2, "");
case MultOpMod:
case BinOpTypeMod:
// TODO types so we know float vs int and signed vs unsigned
add_debug_source_node(g, node);
return LLVMBuildSRem(g->builder, val1, val2, "");
case MultOpInvalid:
default:
zig_unreachable();
}
zig_unreachable();
}
static LLVMValueRef gen_add_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeAddExpr);
assert(node->type == NodeTypeBinOpExpr);
LLVMValueRef val1 = gen_mult_expr(g, node->data.add_expr.op1);
LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
if (!node->data.add_expr.op2)
return val1;
LLVMValueRef val2 = gen_mult_expr(g, node->data.add_expr.op2);
switch (node->data.add_expr.add_op) {
case AddOpAdd:
switch (node->data.bin_op_expr.bin_op) {
case BinOpTypeAdd:
add_debug_source_node(g, node);
return LLVMBuildAdd(g->builder, val1, val2, "");
case AddOpSub:
case BinOpTypeSub:
add_debug_source_node(g, node);
return LLVMBuildSub(g->builder, val1, val2, "");
case AddOpInvalid:
default:
zig_unreachable();
}
zig_unreachable();
}
static LLVMValueRef gen_bit_shift_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeBitShiftExpr);
assert(node->type == NodeTypeBinOpExpr);
LLVMValueRef val1 = gen_add_expr(g, node->data.bit_shift_expr.op1);
LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
if (!node->data.bit_shift_expr.op2)
return val1;
LLVMValueRef val2 = gen_add_expr(g, node->data.bit_shift_expr.op2);
switch (node->data.bit_shift_expr.bit_shift_op) {
case BitShiftOpLeft:
switch (node->data.bin_op_expr.bin_op) {
case BinOpTypeBitShiftLeft:
add_debug_source_node(g, node);
return LLVMBuildShl(g->builder, val1, val2, "");
case BitShiftOpRight:
case BinOpTypeBitShiftRight:
// TODO implement type system so that we know whether to do
// logical or arithmetic shifting here.
// signed -> arithmetic, unsigned -> logical
add_debug_source_node(g, node);
return LLVMBuildLShr(g->builder, val1, val2, "");
case BitShiftOpInvalid:
default:
zig_unreachable();
}
zig_unreachable();
}
static LLVMValueRef gen_bin_and_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeBinAndExpr);
assert(node->type == NodeTypeBinOpExpr);
LLVMValueRef val1 = gen_bit_shift_expr(g, node->data.bin_and_expr.op1);
if (!node->data.bin_and_expr.op2)
return val1;
LLVMValueRef val2 = gen_bit_shift_expr(g, node->data.bin_and_expr.op2);
LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
add_debug_source_node(g, node);
return LLVMBuildAnd(g->builder, val1, val2, "");
}
static LLVMValueRef gen_bin_xor_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeBinXorExpr);
assert(node->type == NodeTypeBinOpExpr);
LLVMValueRef val1 = gen_bin_and_expr(g, node->data.bin_xor_expr.op1);
if (!node->data.bin_xor_expr.op2)
return val1;
LLVMValueRef val2 = gen_bin_and_expr(g, node->data.bin_xor_expr.op2);
LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
add_debug_source_node(g, node);
return LLVMBuildXor(g->builder, val1, val2, "");
}
static LLVMValueRef gen_bin_or_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeBinOrExpr);
assert(node->type == NodeTypeBinOpExpr);
LLVMValueRef val1 = gen_bin_xor_expr(g, node->data.bin_or_expr.op1);
if (!node->data.bin_or_expr.op2)
return val1;
LLVMValueRef val2 = gen_bin_xor_expr(g, node->data.bin_or_expr.op2);
LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
add_debug_source_node(g, node);
return LLVMBuildOr(g->builder, val1, val2, "");
}
static LLVMIntPredicate cmp_op_to_int_predicate(CmpOp cmp_op, bool is_signed) {
static LLVMIntPredicate cmp_op_to_int_predicate(BinOpType cmp_op, bool is_signed) {
switch (cmp_op) {
case CmpOpInvalid:
case BinOpTypeInvalid:
zig_unreachable();
case CmpOpEq:
case BinOpTypeCmpEq:
return LLVMIntEQ;
case CmpOpNotEq:
case BinOpTypeCmpNotEq:
return LLVMIntNE;
case CmpOpLessThan:
case BinOpTypeCmpLessThan:
return is_signed ? LLVMIntSLT : LLVMIntULT;
case CmpOpGreaterThan:
case BinOpTypeCmpGreaterThan:
return is_signed ? LLVMIntSGT : LLVMIntUGT;
case CmpOpLessOrEq:
case BinOpTypeCmpLessOrEq:
return is_signed ? LLVMIntSLE : LLVMIntULE;
case CmpOpGreaterOrEq:
case BinOpTypeCmpGreaterOrEq:
return is_signed ? LLVMIntSGE : LLVMIntUGE;
default:
zig_unreachable();
}
zig_unreachable();
}
static LLVMValueRef gen_cmp_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeComparisonExpr);
assert(node->type == NodeTypeBinOpExpr);
LLVMValueRef val1 = gen_bin_or_expr(g, node->data.comparison_expr.op1);
if (!node->data.comparison_expr.op2)
return val1;
LLVMValueRef val2 = gen_bin_or_expr(g, node->data.comparison_expr.op2);
LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
// TODO implement type system so that we know whether to do signed or unsigned comparison here
LLVMIntPredicate pred = cmp_op_to_int_predicate(node->data.comparison_expr.cmp_op, true);
LLVMIntPredicate pred = cmp_op_to_int_predicate(node->data.bin_op_expr.bin_op, true);
add_debug_source_node(g, node);
return LLVMBuildICmp(g->builder, pred, val1, val2, "");
}
static LLVMValueRef gen_bool_and_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeBoolAndExpr);
assert(node->type == NodeTypeBinOpExpr);
LLVMValueRef val1 = gen_cmp_expr(g, node->data.bool_and_expr.op1);
if (!node->data.bool_and_expr.op2)
return val1;
LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1);
// block for when val1 == true
LLVMBasicBlockRef true_block = LLVMAppendBasicBlock(g->cur_fn, "BoolAndTrue");
@@ -926,7 +863,7 @@ static LLVMValueRef gen_bool_and_expr(CodeGen *g, AstNode *node) {
LLVMBuildCondBr(g->builder, val1_i1, false_block, true_block);
LLVMPositionBuilderAtEnd(g->builder, true_block);
LLVMValueRef val2 = gen_cmp_expr(g, node->data.bool_and_expr.op2);
LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2);
add_debug_source_node(g, node);
LLVMValueRef val2_i1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, zero, "");
@@ -942,12 +879,9 @@ static LLVMValueRef gen_bool_and_expr(CodeGen *g, AstNode *node) {
}
static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) {
assert(expr_node->type == NodeTypeBoolOrExpr);
assert(expr_node->type == NodeTypeBinOpExpr);
LLVMValueRef val1 = gen_bool_and_expr(g, expr_node->data.bool_or_expr.op1);
if (!expr_node->data.bool_or_expr.op2)
return val1;
LLVMValueRef val1 = gen_expr(g, expr_node->data.bin_op_expr.op1);
// block for when val1 == false
LLVMBasicBlockRef false_block = LLVMAppendBasicBlock(g->cur_fn, "BoolOrFalse");
@@ -960,7 +894,7 @@ static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) {
LLVMBuildCondBr(g->builder, val1_i1, false_block, true_block);
LLVMPositionBuilderAtEnd(g->builder, false_block);
LLVMValueRef val2 = gen_bool_and_expr(g, expr_node->data.bool_or_expr.op2);
LLVMValueRef val2 = gen_expr(g, expr_node->data.bin_op_expr.op2);
add_debug_source_node(g, expr_node);
LLVMValueRef val2_i1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, zero, "");
@@ -975,6 +909,41 @@ static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) {
return phi;
}
static LLVMValueRef gen_bin_op_expr(CodeGen *g, AstNode *node) {
switch (node->data.bin_op_expr.bin_op) {
case BinOpTypeInvalid:
zig_unreachable();
case BinOpTypeBoolOr:
return gen_bool_or_expr(g, node);
case BinOpTypeBoolAnd:
return gen_bool_and_expr(g, node);
case BinOpTypeCmpEq:
case BinOpTypeCmpNotEq:
case BinOpTypeCmpLessThan:
case BinOpTypeCmpGreaterThan:
case BinOpTypeCmpLessOrEq:
case BinOpTypeCmpGreaterOrEq:
return gen_cmp_expr(g, node);
case BinOpTypeBinOr:
return gen_bin_or_expr(g, node);
case BinOpTypeBinXor:
return gen_bin_xor_expr(g, node);
case BinOpTypeBinAnd:
return gen_bin_and_expr(g, node);
case BinOpTypeBitShiftLeft:
case BinOpTypeBitShiftRight:
return gen_bit_shift_expr(g, node);
case BinOpTypeAdd:
case BinOpTypeSub:
return gen_add_expr(g, node);
case BinOpTypeMult:
case BinOpTypeDiv:
case BinOpTypeMod:
return gen_mult_expr(g, node);
}
zig_unreachable();
}
static LLVMValueRef gen_return_expr(CodeGen *g, AstNode *node) {
assert(node->type == NodeTypeReturnExpr);
AstNode *param_node = node->data.return_expr.expr;
@@ -993,10 +962,12 @@ Expression : BoolOrExpression | ReturnExpression
*/
static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) {
switch (node->type) {
case NodeTypeBoolOrExpr:
return gen_bool_or_expr(g, node);
case NodeTypeBinOpExpr:
return gen_bin_op_expr(g, node);
case NodeTypeReturnExpr:
return gen_return_expr(g, node);
case NodeTypeCastExpr:
return gen_cast_expr(g, node);
case NodeTypeRoot:
case NodeTypeRootExportDecl:
case NodeTypeFnProto:
@@ -1008,15 +979,6 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) {
case NodeTypeFnCall:
case NodeTypeExternBlock:
case NodeTypeDirective:
case NodeTypeBoolAndExpr:
case NodeTypeComparisonExpr:
case NodeTypeBinOrExpr:
case NodeTypeBinXorExpr:
case NodeTypeBinAndExpr:
case NodeTypeBitShiftExpr:
case NodeTypeAddExpr:
case NodeTypeMultExpr:
case NodeTypeCastExpr:
case NodeTypePrimaryExpr:
return gen_primary_expr(g, node);
case NodeTypeGroupedExpr: