diff --git a/README.md b/README.md index fd3396620a..25d9ec10d6 100644 --- a/README.md +++ b/README.md @@ -104,11 +104,7 @@ Type : token(Symbol) | PointerType | token(Unreachable) PointerType : token(Star) token(Const) Type | token(Star) token(Mut) Type -Block : token(LBrace) many(Statement) token(RBrace) - -Statement : ExpressionStatement - -ExpressionStatement : Expression token(Semicolon) +Block : token(LBrace) list(option(Expression), token(Semicolon)) token(RBrace) Expression : BoolOrExpression | ReturnExpression diff --git a/src/analyze.cpp b/src/analyze.cpp index de0d5b1cdb..9ecb4eb8d3 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -270,6 +270,7 @@ static void preview_function_declarations(CodeGen *g, ImportTableEntry *import, case NodeTypeNumberLiteral: case NodeTypeStringLiteral: case NodeTypeUnreachable: + case NodeTypeVoid: case NodeTypeSymbol: case NodeTypeCastExpr: case NodeTypePrefixOpExpr: @@ -311,8 +312,12 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, for (int i = 0; i < node->data.block.statements.length; i += 1) { AstNode *child = node->data.block.statements.at(i); if (return_type == g->builtin_types.entry_unreachable) { - add_node_error(g, child, - buf_sprintf("unreachable code")); + if (child->type == NodeTypeVoid) { + // {unreachable;void;void} is allowed. + // ignore void statements once we enter unreachable land. 
+ continue; + } + add_node_error(g, child, buf_sprintf("unreachable code")); break; } return_type = analyze_expression(g, import, context, nullptr, child); @@ -415,6 +420,10 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, return_type = g->builtin_types.entry_unreachable; break; + case NodeTypeVoid: + return_type = g->builtin_types.entry_void; + break; + case NodeTypeSymbol: // look up symbol in symbol table zig_panic("TODO"); @@ -439,59 +448,6 @@ static TypeTableEntry * analyze_expression(CodeGen *g, ImportTableEntry *import, return return_type; } -static void check_fn_def_control_flow(CodeGen *g, AstNode *node) { - // Follow the execution flow and make sure the code returns appropriately. - // * A `return` statement in an unreachable type function should be an error. - // * Control flow should not be able to reach the end of an unreachable type function. - // * Functions that have a type other than void should not return without a value. - // * void functions without explicit return statements at the end need the - // add_implicit_return flag set on the codegen node. - assert(node->type == NodeTypeFnDef); - AstNode *proto_node = node->data.fn_def.fn_proto; - assert(proto_node->type == NodeTypeFnProto); - AstNode *return_type_node = proto_node->data.fn_proto.return_type; - assert(return_type_node->type == NodeTypeType); - - node->codegen_node = allocate(1); - FnDefNode *codegen_fn_def = &node->codegen_node->data.fn_def_node; - - assert(return_type_node->codegen_node); - TypeTableEntry *type_entry = return_type_node->codegen_node->data.type_node.entry; - assert(type_entry); - - AstNode *body_node = node->data.fn_def.body; - assert(body_node->type == NodeTypeBlock); - - // TODO once we understand types, do this pass after type checking, and - // if an expression has an unreachable value then stop looking at statements after - // it. then we can remove the check to `unreachable` in the end of this function. 
- bool prev_statement_return = false; - for (int i = 0; i < body_node->data.block.statements.length; i += 1) { - AstNode *statement_node = body_node->data.block.statements.at(i); - if (statement_node->type == NodeTypeReturnExpr) { - if (type_entry == g->builtin_types.entry_unreachable) { - add_node_error(g, statement_node, - buf_sprintf("return statement in function with unreachable return type")); - return; - } else { - prev_statement_return = true; - } - } else if (prev_statement_return) { - add_node_error(g, statement_node, - buf_sprintf("unreachable code")); - } - } - - if (!prev_statement_return) { - if (type_entry == g->builtin_types.entry_void) { - codegen_fn_def->add_implicit_return = true; - } else if (type_entry != g->builtin_types.entry_unreachable) { - add_node_error(g, node, - buf_sprintf("control reaches end of non-void function")); - } - } -} - static void analyze_top_level_declaration(CodeGen *g, ImportTableEntry *import, AstNode *node) { switch (node->type) { case NodeTypeFnDef: @@ -512,14 +468,15 @@ static void analyze_top_level_declaration(CodeGen *g, ImportTableEntry *import, // TODO: define local variables for parameters } - check_fn_def_control_flow(g, node); - BlockContext context; context.node = node; context.root = &context; context.parent = nullptr; TypeTableEntry *expected_type = fn_proto->return_type->codegen_node->data.type_node.entry; - analyze_expression(g, import, &context, expected_type, node->data.fn_def.body); + TypeTableEntry *block_return_type = analyze_expression(g, import, &context, expected_type, node->data.fn_def.body); + + node->codegen_node = allocate(1); + node->codegen_node->data.fn_def_node.implicit_return_type = block_return_type; } break; @@ -548,6 +505,7 @@ static void analyze_top_level_declaration(CodeGen *g, ImportTableEntry *import, case NodeTypeNumberLiteral: case NodeTypeStringLiteral: case NodeTypeUnreachable: + case NodeTypeVoid: case NodeTypeSymbol: case NodeTypeCastExpr: case NodeTypePrefixOpExpr: diff --git 
a/src/codegen.cpp b/src/codegen.cpp index ab74ae19ce..18a5c4c5a1 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -401,6 +401,8 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) { case NodeTypeUnreachable: add_debug_source_node(g, node); return LLVMBuildUnreachable(g->builder); + case NodeTypeVoid: + return nullptr; case NodeTypeNumberLiteral: { Buf *number_str = &node->data.number; @@ -441,7 +443,7 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *node) { zig_unreachable(); } -static void gen_block(CodeGen *g, ImportTableEntry *import, AstNode *block_node, bool add_implicit_return) { +static void gen_block(CodeGen *g, ImportTableEntry *import, AstNode *block_node, TypeTableEntry *implicit_return_type) { assert(block_node->type == NodeTypeBlock); LLVMZigDILexicalBlock *di_block = LLVMZigCreateLexicalBlock(g->dbuilder, g->block_scopes.last(), @@ -450,13 +452,16 @@ static void gen_block(CodeGen *g, ImportTableEntry *import, AstNode *block_node, add_debug_source_node(g, block_node); + LLVMValueRef return_value = nullptr; /* value of the last statement; starts as nullptr so the implicit return below never reads an indeterminate value if the statement list is empty */ for (int i = 0; i < block_node->data.block.statements.length; i += 1) { AstNode *statement_node = block_node->data.block.statements.at(i); - gen_expr(g, statement_node); + return_value = gen_expr(g, statement_node); } - if (add_implicit_return) { + if (implicit_return_type == g->builtin_types.entry_void) { LLVMBuildRetVoid(g->builder); + } else if (implicit_return_type != g->builtin_types.entry_unreachable) { + LLVMBuildRet(g->builder, return_value); } g->block_scopes.pop(); @@ -552,8 +557,8 @@ static void do_code_gen(CodeGen *g) { codegen_fn_def->params = allocate(LLVMCountParams(fn)); LLVMGetParams(fn, codegen_fn_def->params); - bool add_implicit_return = codegen_fn_def->add_implicit_return; - gen_block(g, import, fn_def_node->data.fn_def.body, add_implicit_return); + TypeTableEntry *implicit_return_type = codegen_fn_def->implicit_return_type; + gen_block(g, import, fn_def_node->data.fn_def.body, implicit_return_type); g->block_scopes.pop();
} diff --git a/src/parser.cpp b/src/parser.cpp index e17e3c83fe..562ade0f21 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -89,6 +89,8 @@ const char *node_type_str(NodeType node_type) { return "PrefixOpExpr"; case NodeTypeUse: return "Use"; + case NodeTypeVoid: + return "Void"; } zig_unreachable(); } @@ -233,6 +235,9 @@ void ast_print(AstNode *node, int indent) { case NodeTypeUse: fprintf(stderr, "%s '%s'\n", node_type_str(node->type), buf_ptr(&node->data.use.path)); break; + case NodeTypeVoid: + fprintf(stderr, "Void\n"); + break; } } @@ -416,6 +421,9 @@ static AstNode *ast_parse_type(ParseContext *pc, int token_index, int *new_token if (token->id == TokenIdKeywordUnreachable) { node->data.type.type = AstNodeTypeTypePrimitive; buf_init_from_str(&node->data.type.primitive_name, "unreachable"); + } else if (token->id == TokenIdKeywordVoid) { + node->data.type.type = AstNodeTypeTypePrimitive; + buf_init_from_str(&node->data.type.primitive_name, "void"); } else if (token->id == TokenIdSymbol) { node->data.type.type = AstNodeTypeTypePrimitive; ast_buf_from_token(pc, token, &node->data.type.primitive_name); @@ -569,6 +577,10 @@ static AstNode *ast_parse_primary_expr(ParseContext *pc, int *token_index, bool AstNode *node = ast_create_node(pc, NodeTypeUnreachable, token); *token_index += 1; return node; + } else if (token->id == TokenIdKeywordVoid) { + AstNode *node = ast_create_node(pc, NodeTypeVoid, token); + *token_index += 1; + return node; } else if (token->id == TokenIdSymbol) { AstNode *node = ast_create_node(pc, NodeTypeSymbol, token); ast_buf_from_token(pc, token, &node->data.symbol); @@ -1024,50 +1036,42 @@ static AstNode *ast_parse_expression(ParseContext *pc, int *token_index, bool ma } /* -ExpressionStatement : Expression token(Semicolon) -*/ -static AstNode *ast_parse_expression_statement(ParseContext *pc, int *token_index) { - AstNode *expr_node = ast_parse_expression(pc, token_index, true); - - Token *semicolon = &pc->tokens->at(*token_index); - 
*token_index += 1; - ast_expect_token(pc, semicolon, TokenIdSemicolon); - - return expr_node; -} - -/* -Statement : ExpressionStatement -*/ -static AstNode *ast_parse_statement(ParseContext *pc, int *token_index) { - return ast_parse_expression_statement(pc, token_index); -} - -/* -Block : token(LBrace) many(Statement) token(RBrace); +Block : token(LBrace) list(option(Expression), token(Semicolon)) token(RBrace) */ static AstNode *ast_parse_block(ParseContext *pc, int *token_index, bool mandatory) { - Token *l_brace = &pc->tokens->at(*token_index); + Token *last_token = &pc->tokens->at(*token_index); - if (l_brace->id != TokenIdLBrace) { + if (last_token->id != TokenIdLBrace) { if (mandatory) { - ast_invalid_token_error(pc, l_brace); + ast_invalid_token_error(pc, last_token); } else { return nullptr; } } *token_index += 1; - AstNode *node = ast_create_node(pc, NodeTypeBlock, l_brace); + AstNode *node = ast_create_node(pc, NodeTypeBlock, last_token); + // {} -> {void} + // {;} -> {void;void} + // {2} -> {2} + // {2;} -> {2;void} + // {;2} -> {void;2} for (;;) { - Token *token = &pc->tokens->at(*token_index); - if (token->id == TokenIdRBrace) { + AstNode *expression_node = ast_parse_expression(pc, token_index, false); + if (!expression_node) { + expression_node = ast_create_node(pc, NodeTypeVoid, last_token); + } + node->data.block.statements.append(expression_node); + + last_token = &pc->tokens->at(*token_index); + if (last_token->id == TokenIdRBrace) { *token_index += 1; return node; + } else if (last_token->id == TokenIdSemicolon) { + *token_index += 1; } else { - AstNode *statement_node = ast_parse_statement(pc, token_index); - node->data.block.statements.append(statement_node); + ast_invalid_token_error(pc, last_token); } } zig_unreachable(); diff --git a/src/parser.hpp b/src/parser.hpp index 4f433fe8eb..2a7d43c07f 100644 --- a/src/parser.hpp +++ b/src/parser.hpp @@ -38,6 +38,7 @@ enum NodeType { NodeTypePrefixOpExpr, NodeTypeFnCallExpr, NodeTypeUse, + 
NodeTypeVoid, }; struct AstNodeRoot { diff --git a/src/semantic_info.hpp b/src/semantic_info.hpp index 66b57f5a55..8636a3d394 100644 --- a/src/semantic_info.hpp +++ b/src/semantic_info.hpp @@ -106,7 +106,7 @@ struct TypeNode { }; struct FnDefNode { - bool add_implicit_return; + TypeTableEntry *implicit_return_type; bool skip; LLVMValueRef *params; }; diff --git a/src/tokenizer.cpp b/src/tokenizer.cpp index 0a61bb9168..8e8219bf6b 100644 --- a/src/tokenizer.cpp +++ b/src/tokenizer.cpp @@ -181,6 +181,8 @@ static void end_token(Tokenize *t) { t->cur_tok->id = TokenIdKeywordAs; } else if (mem_eql_str(token_mem, token_len, "use")) { t->cur_tok->id = TokenIdKeywordUse; + } else if (mem_eql_str(token_mem, token_len, "void")) { + t->cur_tok->id = TokenIdKeywordVoid; } t->cur_tok = nullptr; @@ -574,6 +576,7 @@ static const char * token_name(Token *token) { case TokenIdKeywordExport: return "Export"; case TokenIdKeywordAs: return "As"; case TokenIdKeywordUse: return "Use"; + case TokenIdKeywordVoid: return "Void"; case TokenIdLParen: return "LParen"; case TokenIdRParen: return "RParen"; case TokenIdComma: return "Comma"; diff --git a/src/tokenizer.hpp b/src/tokenizer.hpp index 45bc011e03..9d5739089d 100644 --- a/src/tokenizer.hpp +++ b/src/tokenizer.hpp @@ -23,6 +23,7 @@ enum TokenId { TokenIdKeywordExport, TokenIdKeywordAs, TokenIdKeywordUse, + TokenIdKeywordVoid, TokenIdLParen, TokenIdRParen, TokenIdComma, diff --git a/test/run_tests.cpp b/test/run_tests.cpp index 6b6d039f4c..334c35220b 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -209,11 +209,11 @@ fn a() {} add_compile_fail_case("unreachable with return", R"SOURCE( fn a() -> unreachable {return;} - )SOURCE", 1, ".tmp_source.zig:2:24: error: return statement in function with unreachable return type"); + )SOURCE", 1, ".tmp_source.zig:2:24: error: type mismatch. expected unreachable. 
got void"); add_compile_fail_case("control reaches end of non-void function", R"SOURCE( fn a() -> i32 {} - )SOURCE", 1, ".tmp_source.zig:2:1: error: control reaches end of non-void function"); + )SOURCE", 1, ".tmp_source.zig:2:15: error: type mismatch. expected i32. got void"); add_compile_fail_case("undefined function call", R"SOURCE( fn a() {