From 08a2311efd8b388cd431feb6000741f4a62da613 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Tue, 1 Dec 2015 21:19:38 -0700 Subject: [PATCH] support if conditionals --- README.md | 22 ++++++- doc/vim/syntax/zig.vim | 4 +- src/analyze.cpp | 112 ++++++++++++++++++++++++++++++-- src/codegen.cpp | 143 ++++++++++++++++++++++++++++++++--------- src/parser.cpp | 137 +++++++++++++++++++++++++++++++++------ src/parser.hpp | 8 +++ src/semantic_info.hpp | 6 ++ src/tokenizer.cpp | 6 ++ src/tokenizer.hpp | 2 + test/run_tests.cpp | 24 +++++++ 10 files changed, 400 insertions(+), 64 deletions(-) diff --git a/README.md b/README.md index 25d9ec10d6..7cfc9e7e41 100644 --- a/README.md +++ b/README.md @@ -44,7 +44,11 @@ make * variable declarations and assignment expressions * Type checking + * loops + * labels and goto * inline assembly and syscalls + * conditional compilation and ability to check target platform and architecture + * main function with command line arguments * running code at compile time * print! macro that takes var args * panic! macro that prints a stack trace to stderr in debug mode and calls @@ -104,14 +108,26 @@ Type : token(Symbol) | PointerType | token(Unreachable) PointerType : token(Star) token(Const) Type | token(Star) token(Mut) Type -Block : token(LBrace) list(option(Expression), token(Semicolon)) token(RBrace) +Block : token(LBrace) list(option(Statement), token(Semicolon)) token(RBrace) -Expression : BoolOrExpression | ReturnExpression +Statement : NonBlockExpression token(Semicolon) | BlockExpression + +Expression : BlockExpression | NonBlockExpression + +NonBlockExpression : BoolOrExpression | ReturnExpression + +BlockExpression : IfExpression | Block BoolOrExpression : BoolAndExpression token(BoolOr) BoolAndExpression | BoolAndExpression ReturnExpression : token(Return) option(Expression) +IfExpression : token(If) Expression Block option(Else | ElseIf) + +ElseIf : token(Else) IfExpression + +Else : token(Else) Block + BoolAndExpression : ComparisonExpression token(BoolAnd) ComparisonExpression | ComparisonExpression ComparisonExpression : BinaryOrExpression ComparisonOperator BinaryOrExpression | BinaryOrExpression @@ -144,7 +160,7 @@ FnCallExpression : PrimaryExpression token(LParen) list(Expression, token(Comma) PrefixOp : token(Not) | token(Dash) | token(Tilde) -PrimaryExpression : token(Number) | token(String) | token(Unreachable) | GroupedExpression | Block | token(Symbol) +PrimaryExpression : token(Number) | token(String) | token(Unreachable) | GroupedExpression | token(Symbol) GroupedExpression : token(LParen) Expression token(RParen) ``` diff --git a/doc/vim/syntax/zig.vim b/doc/vim/syntax/zig.vim index 1527cb8049..3804421849 100644 --- a/doc/vim/syntax/zig.vim +++ b/doc/vim/syntax/zig.vim @@ -7,8 +7,8 @@ if exists("b:current_syntax") finish endif -syn keyword zigKeyword fn return mut const extern unreachable export pub as use -syn keyword zigType bool i8 u8 i16 u16 i32 u32 i64 u64 isize usize f32 f64 f128 void +syn keyword zigKeyword fn return mut const extern unreachable export pub as use if else let void +syn keyword zigType bool i8 u8 i16 u16 i32 u32 i64 u64 isize usize f32 f64 f128 syn region zigCommentLine start="//" end="$" contains=zigTodo,@Spell syn region zigCommentLineDoc start="//\%(//\@!\|!\)" end="$" contains=zigTodo,@Spell diff --git a/src/analyze.cpp b/src/analyze.cpp index 9ecb4eb8d3..c70b2d6695 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -274,6 +274,7 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import, case NodeTypeSymbol: case NodeTypeCastExpr: case NodeTypePrefixOpExpr: + case NodeTypeIfExpr: zig_unreachable(); } } @@ -302,7 +303,9 @@ static void check_type_compatibility(CodeGen *g, AstNode *node, TypeTableEntry * add_node_error(g, node, buf_sprintf("type mismatch. expected %s. got %s", buf_ptr(&expected_type->name), buf_ptr(&actual_type->name))); } -static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, BlockContext *context, TypeTableEntry *expected_type, AstNode *node) { +static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, BlockContext *context, + TypeTableEntry *expected_type, AstNode *node) +{ TypeTableEntry *return_type = nullptr; switch (node->type) { case NodeTypeBlock: @@ -348,10 +351,64 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, case NodeTypeBinOpExpr: { - // TODO: think about expected types - analyze_expression(g, import, context, expected_type, node->data.bin_op_expr.op1); - analyze_expression(g, import, context, expected_type, node->data.bin_op_expr.op2); - return_type = expected_type; + switch (node->data.bin_op_expr.bin_op) { + case BinOpTypeBoolOr: + case BinOpTypeBoolAnd: + analyze_expression(g, import, context, g->builtin_types.entry_bool, + node->data.bin_op_expr.op1); + analyze_expression(g, import, context, g->builtin_types.entry_bool, + node->data.bin_op_expr.op2); + return_type = g->builtin_types.entry_bool; + break; + case BinOpTypeCmpEq: + case BinOpTypeCmpNotEq: + case BinOpTypeCmpLessThan: + case BinOpTypeCmpGreaterThan: + case BinOpTypeCmpLessOrEq: + case BinOpTypeCmpGreaterOrEq: + // TODO think how should type checking for these work? + analyze_expression(g, import, context, g->builtin_types.entry_i32, + node->data.bin_op_expr.op1); + analyze_expression(g, import, context, g->builtin_types.entry_i32, + node->data.bin_op_expr.op2); + return_type = g->builtin_types.entry_bool; + break; + case BinOpTypeBinOr: + zig_panic("TODO bin or type"); + break; + case BinOpTypeBinXor: + zig_panic("TODO bin xor type"); + break; + case BinOpTypeBinAnd: + zig_panic("TODO bin and type"); + break; + case BinOpTypeBitShiftLeft: + zig_panic("TODO bit shift left type"); + break; + case BinOpTypeBitShiftRight: + zig_panic("TODO bit shift right type"); + break; + case BinOpTypeAdd: + case BinOpTypeSub: + // TODO think how should type checking for these work? + analyze_expression(g, import, context, g->builtin_types.entry_i32, + node->data.bin_op_expr.op1); + analyze_expression(g, import, context, g->builtin_types.entry_i32, + node->data.bin_op_expr.op2); + return_type = g->builtin_types.entry_i32; + break; + case BinOpTypeMult: + zig_panic("TODO mult type"); + break; + case BinOpTypeDiv: + zig_panic("TODO div type"); + break; + case BinOpTypeMod: + zig_panic("TODO modulus type"); + break; + case BinOpTypeInvalid: + zig_unreachable(); + } break; } @@ -426,11 +483,46 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, case NodeTypeSymbol: // look up symbol in symbol table - zig_panic("TODO"); + zig_panic("TODO analyze_expression symbol"); case NodeTypeCastExpr: + zig_panic("TODO analyze_expression cast expr"); + break; + case NodeTypePrefixOpExpr: - zig_panic("TODO"); + switch (node->data.prefix_op_expr.prefix_op) { + case PrefixOpBoolNot: + analyze_expression(g, import, context, g->builtin_types.entry_bool, + node->data.prefix_op_expr.primary_expr); + return_type = g->builtin_types.entry_bool; + break; + case PrefixOpBinNot: + zig_panic("TODO type check bin not"); + break; + case PrefixOpNegation: + zig_panic("TODO type check negation"); + break; + case PrefixOpInvalid: + zig_unreachable(); + } + break; + case NodeTypeIfExpr: + { + analyze_expression(g, import, context, g->builtin_types.entry_bool, node->data.if_expr.condition); + + TypeTableEntry *else_type; + if (node->data.if_expr.else_node) { + else_type = analyze_expression(g, import, context, expected_type, node->data.if_expr.else_node); + } else { + else_type = g->builtin_types.entry_void; + } + TypeTableEntry *then_type = analyze_expression(g, import, context, expected_type, + node->data.if_expr.then_block); + + check_type_compatibility(g, node, expected_type, else_type); + return_type = then_type; + break; + } case NodeTypeDirective: case NodeTypeFnDecl: case NodeTypeFnProto: @@ -445,6 +537,11 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, } assert(return_type); check_type_compatibility(g, node, expected_type, return_type); + + assert(!node->codegen_node); + node->codegen_node = allocate(1); + node->codegen_node->data.expr_node.type_entry = return_type; + return return_type; } @@ -509,6 +606,7 @@ static void analyze_top_level_declaration(CodeGen *g, ImportTableEntry *import, case NodeTypeSymbol: case NodeTypeCastExpr: case NodeTypePrefixOpExpr: + case NodeTypeIfExpr: zig_unreachable(); } } diff --git a/src/codegen.cpp b/src/codegen.cpp index 18a5c4c5a1..d442856692 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -120,6 +120,10 @@ static LLVMValueRef get_variable_value(CodeGen *g, Buf *name) { zig_unreachable(); } +static TypeTableEntry *get_expr_type(AstNode *node) { + return node->codegen_node->data.expr_node.type_entry; +} + static LLVMValueRef gen_fn_call_expr(CodeGen *g, AstNode *node) { assert(node->type == NodeTypeFnCallExpr); @@ -283,6 +287,7 @@ static LLVMValueRef gen_bool_and_expr(CodeGen *g, AstNode *node) { LLVMValueRef val1 = gen_expr(g, node->data.bin_op_expr.op1); + LLVMBasicBlockRef orig_block = LLVMGetInsertBlock(g->builder); // block for when val1 == true LLVMBasicBlockRef true_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "BoolAndTrue"); // block for when val1 == false (don't even evaluate the second part) @@ -297,13 +302,14 @@ static LLVMValueRef gen_bool_and_expr(CodeGen *g, AstNode *node) { LLVMValueRef val2 = gen_expr(g, node->data.bin_op_expr.op2); add_debug_source_node(g, node); LLVMValueRef val2_i1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, zero, ""); + LLVMBuildBr(g->builder, false_block); LLVMPositionBuilderAtEnd(g->builder, false_block); add_debug_source_node(g, node); LLVMValueRef phi = LLVMBuildPhi(g->builder, LLVMInt1Type(), ""); LLVMValueRef one_i1 = LLVMConstAllOnes(LLVMInt1Type()); LLVMValueRef incoming_values[2] = {one_i1, val2_i1}; - LLVMBasicBlockRef incoming_blocks[2] = {LLVMGetInsertBlock(g->builder), true_block}; + LLVMBasicBlockRef incoming_blocks[2] = {orig_block, true_block}; LLVMAddIncoming(phi, incoming_values, incoming_blocks, 2); return phi; @@ -314,6 +320,8 @@ static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) { LLVMValueRef val1 = gen_expr(g, expr_node->data.bin_op_expr.op1); + LLVMBasicBlockRef orig_block = LLVMGetInsertBlock(g->builder); + // block for when val1 == false LLVMBasicBlockRef false_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "BoolOrFalse"); // block for when val1 == true (don't even evaluate the second part) @@ -328,13 +336,14 @@ static LLVMValueRef gen_bool_or_expr(CodeGen *g, AstNode *expr_node) { LLVMValueRef val2 = gen_expr(g, expr_node->data.bin_op_expr.op2); add_debug_source_node(g, expr_node); LLVMValueRef val2_i1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, zero, ""); + LLVMBuildBr(g->builder, true_block); LLVMPositionBuilderAtEnd(g->builder, true_block); add_debug_source_node(g, expr_node); LLVMValueRef phi = LLVMBuildPhi(g->builder, LLVMInt1Type(), ""); LLVMValueRef one_i1 = LLVMConstAllOnes(LLVMInt1Type()); LLVMValueRef incoming_values[2] = {one_i1, val2_i1}; - LLVMBasicBlockRef incoming_blocks[2] = {LLVMGetInsertBlock(g->builder), false_block}; + LLVMBasicBlockRef incoming_blocks[2] = {orig_block, false_block}; LLVMAddIncoming(phi, incoming_values, incoming_blocks, 2); return phi; @@ -383,9 +392,91 @@ static LLVMValueRef gen_return_expr(CodeGen *g, AstNode *node) { return LLVMBuildRetVoid(g->builder); } } -/* -Expression : BoolOrExpression | ReturnExpression -*/ + +static LLVMValueRef gen_if_expr(CodeGen *g, AstNode *node) { + assert(node->type == NodeTypeIfExpr); + assert(node->data.if_expr.condition); + assert(node->data.if_expr.then_block); + + LLVMValueRef cond_value = gen_expr(g, node->data.if_expr.condition); + + TypeTableEntry *then_type = get_expr_type(node->data.if_expr.then_block); + bool use_expr_value = (then_type != g->builtin_types.entry_unreachable && + then_type != g->builtin_types.entry_void); + + if (node->data.if_expr.else_node) { + LLVMBasicBlockRef then_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "Then"); + LLVMBasicBlockRef else_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "Else"); + LLVMBasicBlockRef endif_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "EndIf"); + + LLVMBuildCondBr(g->builder, cond_value, then_block, else_block); + + LLVMPositionBuilderAtEnd(g->builder, then_block); + LLVMValueRef then_expr_result = gen_expr(g, node->data.if_expr.then_block); + LLVMBuildBr(g->builder, endif_block); + + LLVMPositionBuilderAtEnd(g->builder, else_block); + LLVMValueRef else_expr_result = gen_expr(g, node->data.if_expr.else_node); + LLVMBuildBr(g->builder, endif_block); + + LLVMPositionBuilderAtEnd(g->builder, endif_block); + if (use_expr_value) { + LLVMValueRef phi = LLVMBuildPhi(g->builder, LLVMTypeOf(then_expr_result), ""); + LLVMValueRef incoming_values[2] = {then_expr_result, else_expr_result}; + LLVMBasicBlockRef incoming_blocks[2] = {then_block, else_block}; + LLVMAddIncoming(phi, incoming_values, incoming_blocks, 2); + + return phi; + } + + return nullptr; + } + + assert(!use_expr_value); + + LLVMBasicBlockRef then_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "Then"); + LLVMBasicBlockRef endif_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "EndIf"); + + LLVMBuildCondBr(g->builder, cond_value, then_block, endif_block); + + LLVMPositionBuilderAtEnd(g->builder, then_block); + gen_expr(g, node->data.if_expr.then_block); + LLVMBuildBr(g->builder, endif_block); + + LLVMPositionBuilderAtEnd(g->builder, endif_block); + return nullptr; +} + +static LLVMValueRef gen_block(CodeGen *g, AstNode *block_node, TypeTableEntry *implicit_return_type) { + assert(block_node->type == NodeTypeBlock); + + ImportTableEntry *import = g->cur_fn->import_entry; + + LLVMZigDILexicalBlock *di_block = LLVMZigCreateLexicalBlock(g->dbuilder, g->block_scopes.last(), + import->di_file, block_node->line + 1, block_node->column + 1); + g->block_scopes.append(LLVMZigLexicalBlockToScope(di_block)); + + add_debug_source_node(g, block_node); + + LLVMValueRef return_value; + for (int i = 0; i < block_node->data.block.statements.length; i += 1) { + AstNode *statement_node = block_node->data.block.statements.at(i); + return_value = gen_expr(g, statement_node); + } + + if (implicit_return_type) { + if (implicit_return_type == g->builtin_types.entry_void) { + LLVMBuildRetVoid(g->builder); + } else if (implicit_return_type != g->builtin_types.entry_unreachable) { + LLVMBuildRet(g->builder, return_value); + } + } + + g->block_scopes.pop(); + + return return_value; +} + static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) { switch (node->type) { case NodeTypeBinOpExpr: @@ -403,6 +494,8 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) { return LLVMBuildUnreachable(g->builder); case NodeTypeVoid: return nullptr; + case NodeTypeIfExpr: + return gen_if_expr(g, node); case NodeTypeNumberLiteral: { Buf *number_str = &node->data.number; @@ -427,6 +520,8 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) { Buf *name = &node->data.symbol; return get_variable_value(g, name); } + case NodeTypeBlock: + return gen_block(g, node, nullptr); case NodeTypeRoot: case NodeTypeRootExportDecl: case NodeTypeFnProto: @@ -434,7 +529,6 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) { case NodeTypeFnDecl: case NodeTypeParamDecl: case NodeTypeType: - case NodeTypeBlock: case NodeTypeExternBlock: case NodeTypeDirective: case NodeTypeUse: @@ -443,30 +537,6 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) { zig_unreachable(); } -static void gen_block(CodeGen *g, ImportTableEntry *import, AstNode *block_node, TypeTableEntry *implicit_return_type) { - assert(block_node->type == NodeTypeBlock); - - LLVMZigDILexicalBlock *di_block = LLVMZigCreateLexicalBlock(g->dbuilder, g->block_scopes.last(), - import->di_file, block_node->line + 1, block_node->column + 1); - g->block_scopes.append(LLVMZigLexicalBlockToScope(di_block)); - - add_debug_source_node(g, block_node); - - LLVMValueRef return_value; - for (int i = 0; i < block_node->data.block.statements.length; i += 1) { - AstNode *statement_node = block_node->data.block.statements.at(i); - return_value = gen_expr(g, statement_node); - } - - if (implicit_return_type == g->builtin_types.entry_void) { - LLVMBuildRetVoid(g->builder); - } else if (implicit_return_type != g->builtin_types.entry_unreachable) { - LLVMBuildRet(g->builder, return_value); - } - - g->block_scopes.pop(); -} - static LLVMZigDISubroutineType *create_di_function_type(CodeGen *g, AstNodeFnProto *fn_proto, LLVMZigDIFile *di_file) { @@ -558,7 +628,7 @@ static void do_code_gen(CodeGen *g) { LLVMGetParams(fn, codegen_fn_def->params); TypeTableEntry *implicit_return_type = codegen_fn_def->implicit_return_type; - gen_block(g, import, fn_def_node->data.fn_def.body, implicit_return_type); + gen_block(g, fn_def_node->data.fn_def.body, implicit_return_type); g->block_scopes.pop(); } @@ -585,6 +655,15 @@ static void define_primitive_types(CodeGen *g) { buf_init_from_str(&entry->name, "(invalid)"); g->builtin_types.entry_invalid = entry; } + { + TypeTableEntry *entry = allocate(1); + entry->type_ref = LLVMInt1Type(); + buf_init_from_str(&entry->name, "bool"); + entry->di_type = LLVMZigCreateDebugBasicType(g->dbuilder, buf_ptr(&entry->name), 1, 8, + LLVMZigEncoding_DW_ATE_unsigned()); + g->type_table.put(&entry->name, entry); + g->builtin_types.entry_bool = entry; + } { TypeTableEntry *entry = allocate(1); entry->type_ref = LLVMInt8Type(); @@ -803,7 +882,7 @@ static Buf *to_c_type(CodeGen *g, AstNode *type_node) { g->c_stdint_used = true; return buf_create_from_str("int32_t"); } else { - zig_panic("TODO"); + zig_panic("TODO to_c_type"); } } diff --git a/src/parser.cpp b/src/parser.cpp index 562ade0f21..b1b9e122a8 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -91,6 +91,8 @@ const char *node_type_str(NodeType node_type) { return "Use"; case NodeTypeVoid: return "Void"; + case NodeTypeIfExpr: + return "IfExpr"; } zig_unreachable(); } @@ -236,7 +238,15 @@ void ast_print(AstNode *node, int indent) { fprintf(stderr, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.use.path)); break; case NodeTypeVoid: - fprintf(stderr, "Void\n"); + fprintf(stderr, "%s\n", node_type_str(node->type)); + break; + case NodeTypeIfExpr: + fprintf(stderr, "%s\n", node_type_str(node->type)); + if (node->data.if_expr.condition) + ast_print(node->data.if_expr.condition, indent + 2); + ast_print(node->data.if_expr.then_block, indent + 2); + if (node->data.if_expr.else_node) + ast_print(node->data.if_expr.else_node, indent + 2); break; } } @@ -353,6 +363,7 @@ static void ast_invalid_token_error(ParseContext *pc, Token *token) { static AstNode *ast_parse_expression(ParseContext *pc, int *token_index, bool mandatory); static AstNode *ast_parse_block(ParseContext *pc, int *token_index, bool mandatory); +static AstNode *ast_parse_if_expr(ParseContext *pc, int *token_index, bool mandatory); static void ast_expect_token(ParseContext *pc, Token *token, TokenId token_id) { @@ -558,7 +569,7 @@ static AstNode *ast_parse_grouped_expr(ParseContext *pc, int *token_index, bool } /* -PrimaryExpression : token(Number) | token(String) | token(Unreachable) | GroupedExpression | Block | token(Symbol) +PrimaryExpression : token(Number) | token(String) | token(Unreachable) | GroupedExpression | token(Symbol) */ static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool mandatory) { Token *token = &pc->tokens->at(*token_index); @@ -588,11 +599,6 @@ static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool return node; } - AstNode *block_node = ast_parse_block(pc, token_index, false); - if (block_node) { - return block_node; - } - AstNode *grouped_expr_node = ast_parse_grouped_expr(pc, token_index, false); if (grouped_expr_node) { return grouped_expr_node; @@ -975,6 +981,50 @@ static AstNode *ast_parse_bool_and_expr(ParseContext *pc, int *token_index, bool return node; } +/* +ElseIf : token(Else) IfExpression +Else : token(Else) Block +*/ +static AstNode *ast_parse_else_or_else_if(ParseContext *pc, int *token_index, bool mandatory) { + Token *else_token = &pc->tokens->at(*token_index); + + if (else_token->id != TokenIdKeywordElse) { + if (mandatory) { + ast_invalid_token_error(pc, else_token); + } else { + return nullptr; + } + } + *token_index += 1; + + AstNode *if_expr = ast_parse_if_expr(pc, token_index, false); + if (if_expr) + return if_expr; + + return ast_parse_block(pc, token_index, true); +} + +/* +IfExpression : token(If) Expression Block option(Else | ElseIf) +*/ +static AstNode *ast_parse_if_expr(ParseContext *pc, int *token_index, bool mandatory) { + Token *if_tok = &pc->tokens->at(*token_index); + if (if_tok->id != TokenIdKeywordIf) { + if (mandatory) { + ast_invalid_token_error(pc, if_tok); + } else { + return nullptr; + } + } + *token_index += 1; + + AstNode *node = ast_create_node(pc, NodeTypeIfExpr, if_tok); + node->data.if_expr.condition = ast_parse_expression(pc, token_index, true); + node->data.if_expr.then_block = ast_parse_block(pc, token_index, true); + node->data.if_expr.else_node = ast_parse_else_or_else_if(pc, token_index, false); + return node; +} + /* ReturnExpression : token(Return) option(Expression) */ @@ -1016,27 +1066,68 @@ static AstNode *ast_parse_bool_or_expr(ParseContext *pc, int *token_index, bool } /* -Expression : BoolOrExpression | ReturnExpression +BlockExpression : IfExpression | Block */ -static AstNode *ast_parse_expression(ParseContext *pc, int *token_index, bool mandatory) { +static AstNode *ast_parse_block_expr(ParseContext *pc, int *token_index, bool mandatory) { Token *token = &pc->tokens->at(*token_index); - AstNode *return_expr = ast_parse_return_expr(pc, token_index, false); - if (return_expr) - return return_expr; + AstNode *if_expr = ast_parse_if_expr(pc, token_index, false); + if (if_expr) + return if_expr; + + AstNode *block = ast_parse_block(pc, token_index, false); + if (block) + return block; + + if (mandatory) + ast_invalid_token_error(pc, token); + + return nullptr; +} + +/* +NonBlockExpression : BoolOrExpression | ReturnExpression +*/ +static AstNode *ast_parse_non_block_expr(ParseContext *pc, int *token_index, bool mandatory) { + Token *token = &pc->tokens->at(*token_index); AstNode *bool_or_expr = ast_parse_bool_or_expr(pc, token_index, false); if (bool_or_expr) return bool_or_expr; - if (!mandatory) - return nullptr; + AstNode *return_expr = ast_parse_return_expr(pc, token_index, false); + if (return_expr) + return return_expr; - ast_invalid_token_error(pc, token); + if (mandatory) + ast_invalid_token_error(pc, token); + + return nullptr; } /* -Block : token(LBrace) list(option(Expression), token(Semicolon)) token(RBrace) +Expression : BlockExpression | NonBlockExpression +*/ +static AstNode *ast_parse_expression(ParseContext *pc, int *token_index, bool mandatory) { + Token *token = &pc->tokens->at(*token_index); + + AstNode *block_expr = ast_parse_block_expr(pc, token_index, false); + if (block_expr) + return block_expr; + + AstNode *non_block_expr = ast_parse_non_block_expr(pc, token_index, false); + if (non_block_expr) + return non_block_expr; + + if (mandatory) + ast_invalid_token_error(pc, token); + + return nullptr; +} + +/* +Statement : NonBlockExpression token(Semicolon) | BlockExpression +Block : token(LBrace) list(option(Statement), token(Semicolon)) token(RBrace) */ static AstNode *ast_parse_block(ParseContext *pc, int *token_index, bool mandatory) { Token *last_token = &pc->tokens->at(*token_index); @@ -1058,16 +1149,22 @@ static AstNode *ast_parse_block(ParseContext *pc, int *token_index, bool mandato // {2;} -> {2;void} // {;2} -> {void;2} for (;;) { - AstNode *expression_node = ast_parse_expression(pc, token_index, false); - if (!expression_node) { - expression_node = ast_create_node(pc, NodeTypeVoid, last_token); + AstNode *statement_node = ast_parse_block_expr(pc, token_index, false); + bool semicolon_expected = !statement_node; + if (!statement_node) { + statement_node = ast_parse_non_block_expr(pc, token_index, false); + if (!statement_node) { + statement_node = ast_create_node(pc, NodeTypeVoid, last_token); + } } - node->data.block.statements.append(expression_node); + node->data.block.statements.append(statement_node); last_token = &pc->tokens->at(*token_index); if (last_token->id == TokenIdRBrace) { *token_index += 1; return node; + } else if (!semicolon_expected) { + continue; } else if (last_token->id == TokenIdSemicolon) { *token_index += 1; } else { diff --git a/src/parser.hpp b/src/parser.hpp index 2a7d43c07f..3bfafd75ba 100644 --- a/src/parser.hpp +++ b/src/parser.hpp @@ -39,6 +39,7 @@ enum NodeType { NodeTypeFnCallExpr, NodeTypeUse, NodeTypeVoid, + NodeTypeIfExpr, }; struct AstNodeRoot { @@ -167,6 +168,12 @@ struct AstNodeUse { ZigList *directives; }; +struct AstNodeIfExpr { + AstNode *condition; + AstNode *then_block; + AstNode *else_node; // null, block node, or other if expr node +}; + struct AstNode { enum NodeType type; int line; @@ -190,6 +197,7 @@ struct AstNode { AstNodePrefixOpExpr prefix_op_expr; AstNodeFnCallExpr fn_call_expr; AstNodeUse use; + AstNodeIfExpr if_expr; Buf number; Buf string; Buf symbol; diff --git a/src/semantic_info.hpp b/src/semantic_info.hpp index 8636a3d394..85da09cb16 100644 --- a/src/semantic_info.hpp +++ b/src/semantic_info.hpp @@ -63,6 +63,7 @@ struct CodeGen { HashMap import_table; struct { + TypeTableEntry *entry_bool; TypeTableEntry *entry_u8; TypeTableEntry *entry_i32; TypeTableEntry *entry_string_literal; @@ -111,10 +112,15 @@ struct FnDefNode { LLVMValueRef *params; }; +struct ExprNode { + TypeTableEntry *type_entry; +}; + struct CodeGenNode { union { TypeNode type_node; // for NodeTypeType FnDefNode fn_def_node; // for NodeTypeFnDef + ExprNode expr_node; // for all the expression nodes } data; }; diff --git a/src/tokenizer.cpp b/src/tokenizer.cpp index 8e8219bf6b..68e072fd5e 100644 --- a/src/tokenizer.cpp +++ b/src/tokenizer.cpp @@ -183,6 +183,10 @@ static void end_token(Tokenize *t) { t->cur_tok->id = TokenIdKeywordUse; } else if (mem_eql_str(token_mem, token_len, "void")) { t->cur_tok->id = TokenIdKeywordVoid; + } else if (mem_eql_str(token_mem, token_len, "if")) { + t->cur_tok->id = TokenIdKeywordIf; + } else if (mem_eql_str(token_mem, token_len, "else")) { + t->cur_tok->id = TokenIdKeywordElse; } t->cur_tok = nullptr; @@ -577,6 +581,8 @@ static const char * token_name(Token *token) { case TokenIdKeywordAs: return "As"; case TokenIdKeywordUse: return "Use"; case TokenIdKeywordVoid: return "Void"; + case TokenIdKeywordIf: return "If"; + case TokenIdKeywordElse: return "Else"; case TokenIdLParen: return "LParen"; case TokenIdRParen: return "RParen"; case TokenIdComma: return "Comma"; diff --git a/src/tokenizer.hpp b/src/tokenizer.hpp index 9d5739089d..598176b365 100644 --- a/src/tokenizer.hpp +++ b/src/tokenizer.hpp @@ -24,6 +24,8 @@ enum TokenId { TokenIdKeywordAs, TokenIdKeywordUse, TokenIdKeywordVoid, + TokenIdKeywordIf, + TokenIdKeywordElse, TokenIdLParen, TokenIdRParen, TokenIdComma, diff --git a/test/run_tests.cpp b/test/run_tests.cpp index 334c35220b..c9855e8606 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -189,6 +189,30 @@ static void add_compiling_test_cases(void) { )SOURCE"); } + add_simple_case("if statements", R"SOURCE( + #link("c") + extern { + fn puts(s: *const u8) -> i32; + fn exit(code: i32) -> unreachable; + } + + export fn _start() -> unreachable { + if 1 != 0 { + puts("1 is true"); + } else { + puts("1 is false"); + } + if 0 != 0 { + puts("0 is true"); + } else if 1 - 1 != 0 { + puts("1 - 1 is true"); + } + if !(0 != 0) { + puts("!0 is true"); + } + exit(0); + } + )SOURCE", "1 is true\n!0 is true\n"); } static void add_compile_failure_test_cases(void) {