allow empty function and return with no expression

This commit is contained in:
Andrew Kelley
2015-11-27 10:52:31 -07:00
parent 821907317e
commit 9ca9a2c554
4 changed files with 107 additions and 23 deletions

View File

@@ -80,9 +80,14 @@ struct TypeNode {
TypeTableEntry *entry;
};
struct FnDefNode {
bool add_implicit_return;
};
struct CodeGenNode {
union {
TypeNode type_node; // for NodeTypeType
FnDefNode fn_def_node; // for NodeTypeFnDef
} data;
};
@@ -275,6 +280,60 @@ static void find_declarations(CodeGen *g, AstNode *node) {
}
}
static void check_fn_def_control_flow(CodeGen *g, AstNode *node) {
// Follow the execution flow and make sure the code returns appropriately.
// * A `return` statement in an unreachable type function should be an error.
// * Control flow should not be able to reach the end of an unreachable type function.
// * Functions that have a type other than void should not return without a value.
// * void functions without explicit return statements at the end need the
// add_implicit_return flag set on the codegen node.
assert(node->type == NodeTypeFnDef);
AstNode *proto_node = node->data.fn_def.fn_proto;
assert(proto_node->type == NodeTypeFnProto);
AstNode *return_type_node = proto_node->data.fn_proto.return_type;
assert(return_type_node->type == NodeTypeType);
node->codegen_node = allocate<CodeGenNode>(1);
FnDefNode *codegen_fn_def = &node->codegen_node->data.fn_def_node;
assert(return_type_node->codegen_node);
TypeTableEntry *type_entry = return_type_node->codegen_node->data.type_node.entry;
assert(type_entry);
TypeId type_id = type_entry->id;
AstNode *body_node = node->data.fn_def.body;
assert(body_node->type == NodeTypeBlock);
// TODO once we understand types, do this pass after type checking, and
// if an expression has an unreachable value then stop looking at statements after
// it. then we can remove the check to `unreachable` in the end of this function.
bool prev_statement_return = false;
for (int i = 0; i < body_node->data.block.statements.length; i += 1) {
AstNode *statement_node = body_node->data.block.statements.at(i);
if (statement_node->type == NodeTypeStatementReturn) {
if (type_id == TypeIdUnreachable) {
add_node_error(g, statement_node,
buf_sprintf("return statement in function with unreachable return type"));
return;
} else {
prev_statement_return = true;
}
} else if (prev_statement_return) {
add_node_error(g, statement_node,
buf_sprintf("unreachable code"));
}
}
if (!prev_statement_return) {
if (type_id == TypeIdVoid) {
codegen_fn_def->add_implicit_return = true;
} else if (type_id != TypeIdUnreachable) {
add_node_error(g, node,
buf_sprintf("control reaches end of non-void function"));
}
}
}
static void analyze_node(CodeGen *g, AstNode *node) {
switch (node->type) {
case NodeTypeRoot:
@@ -299,6 +358,8 @@ static void analyze_node(CodeGen *g, AstNode *node) {
AstNode *proto_node = node->data.fn_def.fn_proto;
assert(proto_node->type == NodeTypeFnProto);
analyze_node(g, proto_node);
check_fn_def_control_flow(g, node);
break;
}
case NodeTypeFnDecl:
@@ -331,7 +392,9 @@ static void analyze_node(CodeGen *g, AstNode *node) {
}
break;
case NodeTypeStatementReturn:
analyze_node(g, node->data.statement_return.expression);
if (node->data.statement_return.expression) {
analyze_node(g, node->data.statement_return.expression);
}
break;
case NodeTypeExpression:
switch (node->data.expression.type) {
@@ -545,7 +608,7 @@ static LLVMValueRef gen_expr(CodeGen *g, AstNode *expr_node) {
zig_unreachable();
}
static void gen_block(CodeGen *g, AstNode *block_node) {
static void gen_block(CodeGen *g, AstNode *block_node, bool add_implicit_return) {
assert(block_node->type == NodeTypeBlock);
llvm::DILexicalBlock *di_block = g->dbuilder->createLexicalBlock(g->block_scopes.last(),
@@ -558,10 +621,15 @@ static void gen_block(CodeGen *g, AstNode *block_node) {
case NodeTypeStatementReturn:
{
AstNode *expr_node = statement_node->data.statement_return.expression;
LLVMValueRef value = gen_expr(g, expr_node);
if (expr_node) {
LLVMValueRef value = gen_expr(g, expr_node);
add_debug_source_node(g, statement_node);
LLVMBuildRet(g->builder, value);
add_debug_source_node(g, statement_node);
LLVMBuildRet(g->builder, value);
} else {
add_debug_source_node(g, statement_node);
LLVMBuildRetVoid(g->builder);
}
break;
}
case NodeTypeExpression:
@@ -583,6 +651,10 @@ static void gen_block(CodeGen *g, AstNode *block_node) {
}
}
if (add_implicit_return) {
LLVMBuildRetVoid(g->builder);
}
g->block_scopes.pop();
}
@@ -685,7 +757,10 @@ void code_gen(CodeGen *g) {
LLVMBasicBlockRef entry_block = LLVMAppendBasicBlock(fn, "entry");
LLVMPositionBuilderAtEnd(g->builder, entry_block);
gen_block(g, fn_def_node->data.fn_def.body);
CodeGenNode *codegen_node = fn_def_node->codegen_node;
assert(codegen_node);
bool add_implicit_return = codegen_node->data.fn_def_node.add_implicit_return;
gen_block(g, fn_def_node->data.fn_def.body, add_implicit_return);
g->block_scopes.pop();
}