zig

fork of https://codeberg.org/ziglang/zig
Log | Files | Refs | README | LICENSE

commit a1359ac3abbb83f30719f596180f1e923da3a8f2 (tree)
parent 9daf0140e5a78802fd294bce8a9019f59bd89b61
Author: Andrew Kelley <andrew@ziglang.org>
Date:   Wed,  3 Jul 2019 13:55:50 -0400

Merge branch 'rbscott-comptime-union-init'

Diffstat:
Mdoc/langref.html.in | 17+++++++++++++++++
Msrc/all_types.hpp | 11+++++++++++
Msrc/codegen.cpp | 2++
Msrc/ir.cpp | 154+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------
Msrc/ir_print.cpp | 15+++++++++++++++
Mtest/stage1/behavior/union.zig | 37++++++++++++++++++++++++++++++++++++-
6 files changed, 208 insertions(+), 28 deletions(-)

diff --git a/doc/langref.html.in b/doc/langref.html.in @@ -5065,6 +5065,12 @@ test "@intToPtr for pointer to zero bit type" { {#header_close#} {#header_close#} + {#header_open|Result Location Semantics#} + <p> + <a href="https://github.com/ziglang/zig/issues/2809">TODO add documentation for this</a> + </p> + {#header_close#} + {#header_open|comptime#} <p> Zig places importance on the concept of whether an expression is known at compile-time. @@ -7809,6 +7815,17 @@ pub const TypeInfo = union(TypeId) { {#header_close#} + {#header_open|@unionInit#} + <pre>{#syntax#}@unionInit(comptime Union: type, comptime active_field_name: []const u8, init_expr) Union{#endsyntax#}</pre> + <p> + This is the same thing as {#link|union#} initialization syntax, except that the field name is a + {#link|comptime#}-known value rather than an identifier token. + </p> + <p> + {#syntax#}@unionInit{#endsyntax#} forwards its {#link|result location|Result Location Semantics#} to {#syntax#}init_expr{#endsyntax#}. + </p> + {#header_close#} + {#header_open|@Vector#} <pre>{#syntax#}@Vector(comptime len: u32, comptime ElemType: type) type{#endsyntax#}</pre> <p> diff --git a/src/all_types.hpp b/src/all_types.hpp @@ -1509,6 +1509,7 @@ enum BuiltinFnId { BuiltinFnIdAtomicRmw, BuiltinFnIdAtomicLoad, BuiltinFnIdHasDecl, + BuiltinFnIdUnionInit, }; struct BuiltinFnEntry { @@ -2359,6 +2360,7 @@ enum IrInstructionId { IrInstructionIdAllocaGen, IrInstructionIdEndExpr, IrInstructionIdPtrOfArrayToSlice, + IrInstructionIdUnionInitNamedField, }; struct IrInstruction { @@ -3603,6 +3605,15 @@ struct IrInstructionAssertNonNull { IrInstruction *target; }; +struct IrInstructionUnionInitNamedField { + IrInstruction base; + + IrInstruction *union_type; + IrInstruction *field_name; + IrInstruction *field_result_loc; + IrInstruction *result_loc; +}; + struct IrInstructionHasDecl { IrInstruction base; diff --git a/src/codegen.cpp b/src/codegen.cpp @@ -5635,6 +5635,7 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, case IrInstructionIdRef: case IrInstructionIdBitCastSrc: case IrInstructionIdTestErrSrc: + case IrInstructionIdUnionInitNamedField: zig_unreachable(); case IrInstructionIdDeclVarGen: @@ -7419,6 +7420,7 @@ static void define_builtin_fns(CodeGen *g) { create_builtin_fn(g, BuiltinFnIdFromBytes, "bytesToSlice", 2); create_builtin_fn(g, BuiltinFnIdThis, "This", 0); create_builtin_fn(g, BuiltinFnIdHasDecl, "hasDecl", 2); + create_builtin_fn(g, BuiltinFnIdUnionInit, "unionInit", 3); } static const char *bool_to_str(bool b) { diff --git a/src/ir.cpp b/src/ir.cpp @@ -198,6 +198,9 @@ static IrInstruction *ir_analyze_unwrap_err_code(IrAnalyze *ira, IrInstruction * IrInstruction *base_ptr, bool initializing); static IrInstruction *ir_analyze_store_ptr(IrAnalyze *ira, IrInstruction *source_instr, IrInstruction *ptr, IrInstruction *uncasted_value); +static IrInstruction *ir_gen_union_init_expr(IrBuilder *irb, Scope *scope, AstNode *source_node, + IrInstruction *union_type, IrInstruction *field_name, AstNode *expr_node, + LVal lval, ResultLoc *parent_result_loc); static ConstExprValue *const_ptr_pointee_unchecked(CodeGen *g, ConstExprValue *const_val) { assert(get_src_ptr_type(const_val->type) != nullptr); @@ -1089,6 +1092,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionEndExpr *) { return IrInstructionIdEndExpr; } +static constexpr IrInstructionId ir_instruction_id(IrInstructionUnionInitNamedField *) { + return IrInstructionIdUnionInitNamedField; +} + template<typename T> static T *ir_create_instruction(IrBuilder *irb, Scope *scope, AstNode *source_node) { T *special_instruction = allocate<T>(1); @@ -1352,12 +1359,13 @@ static IrInstruction *ir_build_elem_ptr(IrBuilder *irb, Scope *scope, AstNode *s } static IrInstruction *ir_build_field_ptr_instruction(IrBuilder *irb, Scope *scope, AstNode *source_node, - IrInstruction *container_ptr, IrInstruction *field_name_expr) + IrInstruction *container_ptr, IrInstruction *field_name_expr, bool initializing) { IrInstructionFieldPtr *instruction = ir_build_instruction<IrInstructionFieldPtr>(irb, scope, source_node); instruction->container_ptr = container_ptr; instruction->field_name_buffer = nullptr; instruction->field_name_expr = field_name_expr; + instruction->initializing = initializing; ir_ref_instruction(container_ptr, irb->current_basic_block); ir_ref_instruction(field_name_expr, irb->current_basic_block); @@ -3324,6 +3332,24 @@ static IrInstruction *ir_build_check_runtime_scope(IrBuilder *irb, Scope *scope, return &instruction->base; } +static IrInstruction *ir_build_union_init_named_field(IrBuilder *irb, Scope *scope, AstNode *source_node, + IrInstruction *union_type, IrInstruction *field_name, IrInstruction *field_result_loc, IrInstruction *result_loc) +{ + IrInstructionUnionInitNamedField *instruction = ir_build_instruction<IrInstructionUnionInitNamedField>(irb, scope, source_node); + instruction->union_type = union_type; + instruction->field_name = field_name; + instruction->field_result_loc = field_result_loc; + instruction->result_loc = result_loc; + + ir_ref_instruction(union_type, irb->current_basic_block); + ir_ref_instruction(field_name, irb->current_basic_block); + ir_ref_instruction(field_result_loc, irb->current_basic_block); + if (result_loc != nullptr) ir_ref_instruction(result_loc, irb->current_basic_block); + + return &instruction->base; +} + + static IrInstruction *ir_build_vector_to_array(IrAnalyze *ira, IrInstruction *source_instruction, ZigType *result_type, IrInstruction *vector, IrInstruction *result_loc) { @@ -5110,7 +5136,8 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo if (arg1_value == irb->codegen->invalid_instruction) return arg1_value; - IrInstruction *ptr_instruction = ir_build_field_ptr_instruction(irb, scope, node, arg0_value, arg1_value); + IrInstruction *ptr_instruction = ir_build_field_ptr_instruction(irb, scope, node, + arg0_value, arg1_value, false); if (lval == LValPtr) return ptr_instruction; @@ -5651,6 +5678,23 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo IrInstruction *has_decl = ir_build_has_decl(irb, scope, node, arg0_value, arg1_value); return ir_lval_wrap(irb, scope, has_decl, lval, result_loc); } + case BuiltinFnIdUnionInit: + { + AstNode *union_type_node = node->data.fn_call_expr.params.at(0); + IrInstruction *union_type_inst = ir_gen_node(irb, union_type_node, scope); + if (union_type_inst == irb->codegen->invalid_instruction) + return union_type_inst; + + AstNode *name_node = node->data.fn_call_expr.params.at(1); + IrInstruction *name_inst = ir_gen_node(irb, name_node, scope); + if (name_inst == irb->codegen->invalid_instruction) + return name_inst; + + AstNode *init_node = node->data.fn_call_expr.params.at(2); + + return ir_gen_union_init_expr(irb, scope, node, union_type_inst, name_inst, init_node, + lval, result_loc); + } } zig_unreachable(); } @@ -5929,6 +5973,31 @@ static IrInstruction *ir_gen_prefix_op_expr(IrBuilder *irb, Scope *scope, AstNod zig_unreachable(); } +static IrInstruction *ir_gen_union_init_expr(IrBuilder *irb, Scope *scope, AstNode *source_node, + IrInstruction *union_type, IrInstruction *field_name, AstNode *expr_node, + LVal lval, ResultLoc *parent_result_loc) +{ + IrInstruction *container_ptr = ir_build_resolve_result(irb, scope, source_node, parent_result_loc, union_type); + IrInstruction *field_ptr = ir_build_field_ptr_instruction(irb, scope, source_node, container_ptr, + field_name, true); + + ResultLocInstruction *result_loc_inst = allocate<ResultLocInstruction>(1); + result_loc_inst->base.id = ResultLocIdInstruction; + result_loc_inst->base.source_instruction = field_ptr; + ir_ref_instruction(field_ptr, irb->current_basic_block); + ir_build_reset_result(irb, scope, expr_node, &result_loc_inst->base); + + IrInstruction *expr_value = ir_gen_node_extra(irb, expr_node, scope, LValNone, + &result_loc_inst->base); + if (expr_value == irb->codegen->invalid_instruction) + return expr_value; + + IrInstruction *init_union = ir_build_union_init_named_field(irb, scope, source_node, union_type, + field_name, field_ptr, container_ptr); + + return ir_lval_wrap(irb, scope, init_union, lval, parent_result_loc); +} + static IrInstruction *ir_gen_container_init_expr(IrBuilder *irb, Scope *scope, AstNode *node, LVal lval, ResultLoc *parent_result_loc) { @@ -19408,32 +19477,21 @@ static IrInstruction *ir_analyze_instruction_ref(IrAnalyze *ira, IrInstructionRe return ir_get_ref(ira, &ref_instruction->base, value, ref_instruction->is_const, ref_instruction->is_volatile); } -static IrInstruction *ir_analyze_container_init_fields_union(IrAnalyze *ira, IrInstruction *instruction, - ZigType *container_type, size_t instr_field_count, IrInstructionContainerInitFieldsField *fields, - IrInstruction *result_loc) +static IrInstruction *ir_analyze_union_init(IrAnalyze *ira, IrInstruction *source_instruction, + AstNode *field_source_node, ZigType *union_type, Buf *field_name, IrInstruction *field_result_loc, + IrInstruction *result_loc) { Error err; - assert(container_type->id == ZigTypeIdUnion); - - if ((err = type_resolve(ira->codegen, container_type, ResolveStatusSizeKnown))) - return ira->codegen->invalid_instruction; - - if (instr_field_count != 1) { - ir_add_error(ira, instruction, - buf_sprintf("union initialization expects exactly one field")); - return ira->codegen->invalid_instruction; - } + assert(union_type->id == ZigTypeIdUnion); - IrInstructionContainerInitFieldsField *field = &fields[0]; - IrInstruction *field_result_loc = field->result_loc->child; - if (type_is_invalid(field_result_loc->value.type)) + if ((err = type_resolve(ira->codegen, union_type, ResolveStatusSizeKnown))) return ira->codegen->invalid_instruction; - TypeUnionField *type_field = find_union_type_field(container_type, field->name); + TypeUnionField *type_field = find_union_type_field(union_type, field_name); if (type_field == nullptr) { - ir_add_error_node(ira, field->source_node, + ir_add_error_node(ira, field_source_node, buf_sprintf("no member named '%s' in union '%s'", - buf_ptr(field->name), buf_ptr(&container_type->name))); + buf_ptr(field_name), buf_ptr(&union_type->name))); return ira->codegen->invalid_instruction; } @@ -19450,12 +19508,12 @@ static IrInstruction *ir_analyze_container_init_fields_union(IrAnalyze *ira, IrI } } - bool is_comptime = ir_should_inline(ira->new_irb.exec, instruction->scope) - || type_requires_comptime(ira->codegen, container_type) == ReqCompTimeYes; + bool is_comptime = ir_should_inline(ira->new_irb.exec, source_instruction->scope) + || type_requires_comptime(ira->codegen, union_type) == ReqCompTimeYes; - IrInstruction *result = ir_get_deref(ira, instruction, result_loc, nullptr); + IrInstruction *result = ir_get_deref(ira, source_instruction, result_loc, nullptr); if (is_comptime && !instr_is_comptime(result)) { - ir_add_error(ira, field->result_loc, + ir_add_error(ira, field_result_loc, buf_sprintf("unable to evaluate constant expression")); return ira->codegen->invalid_instruction; } @@ -19468,8 +19526,18 @@ static IrInstruction *ir_analyze_container_init_fields(IrAnalyze *ira, IrInstruc { Error err; if (container_type->id == ZigTypeIdUnion) { - return ir_analyze_container_init_fields_union(ira, instruction, container_type, instr_field_count, - fields, result_loc); + if (instr_field_count != 1) { + ir_add_error(ira, instruction, + buf_sprintf("union initialization expects exactly one field")); + return ira->codegen->invalid_instruction; + } + IrInstructionContainerInitFieldsField *field = &fields[0]; + IrInstruction *field_result_loc = field->result_loc->child; + if (type_is_invalid(field_result_loc->value.type)) + return ira->codegen->invalid_instruction; + + return ir_analyze_union_init(ira, instruction, field->source_node, container_type, field->name, + field_result_loc, result_loc); } if (container_type->id != ZigTypeIdStruct || is_slice(container_type)) { ir_add_error(ira, instruction, @@ -25326,6 +25394,35 @@ static IrInstruction *ir_analyze_instruction_bit_cast_src(IrAnalyze *ira, IrInst return instruction->result_loc_bit_cast->parent->gen_instruction; } +static IrInstruction *ir_analyze_instruction_union_init_named_field(IrAnalyze *ira, + IrInstructionUnionInitNamedField *instruction) +{ + ZigType *union_type = ir_resolve_type(ira, instruction->union_type->child); + if (type_is_invalid(union_type)) + return ira->codegen->invalid_instruction; + + if (union_type->id != ZigTypeIdUnion) { + ir_add_error(ira, instruction->union_type, + buf_sprintf("non-union type '%s' passed to @unionInit", buf_ptr(&union_type->name))); + return ira->codegen->invalid_instruction; + } + + Buf *field_name = ir_resolve_str(ira, instruction->field_name->child); + if (field_name == nullptr) + return ira->codegen->invalid_instruction; + + IrInstruction *field_result_loc = instruction->field_result_loc->child; + if (type_is_invalid(field_result_loc->value.type)) + return ira->codegen->invalid_instruction; + + IrInstruction *result_loc = instruction->result_loc->child; + if (type_is_invalid(result_loc->value.type)) + return ira->codegen->invalid_instruction; + + return ir_analyze_union_init(ira, &instruction->base, instruction->base.source_node, + union_type, field_name, field_result_loc, result_loc); +} + static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction *instruction) { switch (instruction->id) { case IrInstructionIdInvalid: @@ -25641,6 +25738,8 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction return ir_analyze_instruction_end_expr(ira, (IrInstructionEndExpr *)instruction); case IrInstructionIdBitCastSrc: return ir_analyze_instruction_bit_cast_src(ira, (IrInstructionBitCastSrc *)instruction); + case IrInstructionIdUnionInitNamedField: + return ir_analyze_instruction_union_init_named_field(ira, (IrInstructionUnionInitNamedField *)instruction); } zig_unreachable(); } @@ -25794,6 +25893,7 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdCast: case IrInstructionIdContainerInitList: case IrInstructionIdContainerInitFields: + case IrInstructionIdUnionInitNamedField: case IrInstructionIdFieldPtr: case IrInstructionIdElemPtr: case IrInstructionIdVarPtr: diff --git a/src/ir_print.cpp b/src/ir_print.cpp @@ -1626,6 +1626,18 @@ static void ir_print_undeclared_ident(IrPrint *irp, IrInstructionUndeclaredIdent fprintf(irp->f, "@undeclaredIdent(%s)", buf_ptr(instruction->name)); } +static void ir_print_union_init_named_field(IrPrint *irp, IrInstructionUnionInitNamedField *instruction) { + fprintf(irp->f, "@unionInit("); + ir_print_other_instruction(irp, instruction->union_type); + fprintf(irp->f, ", "); + ir_print_other_instruction(irp, instruction->field_name); + fprintf(irp->f, ", "); + ir_print_other_instruction(irp, instruction->field_result_loc); + fprintf(irp->f, ", "); + ir_print_other_instruction(irp, instruction->result_loc); + fprintf(irp->f, ")"); +} + static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { ir_print_prefix(irp, instruction); switch (instruction->id) { @@ -2132,6 +2144,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) { case IrInstructionIdEndExpr: ir_print_end_expr(irp, (IrInstructionEndExpr *)instruction); break; + case IrInstructionIdUnionInitNamedField: + ir_print_union_init_named_field(irp, (IrInstructionUnionInitNamedField *)instruction); + break; } fprintf(irp->f, "\n"); } diff --git a/test/stage1/behavior/union.zig b/test/stage1/behavior/union.zig @@ -416,9 +416,44 @@ test "return union init with void payload" { two: u32, }; fn func() Outer { - return Outer{ .state = State{ .one = {} }}; + return Outer{ .state = State{ .one = {} } }; } }; S.entry(); comptime S.entry(); } + +test "@unionInit can modify a union type" { + const UnionInitEnum = union(enum) { + Boolean: bool, + Byte: u8, + }; + + var value: UnionInitEnum = undefined; + + value = @unionInit(UnionInitEnum, "Boolean", true); + expect(value.Boolean == true); + value.Boolean = false; + expect(value.Boolean == false); + + value = @unionInit(UnionInitEnum, "Byte", 2); + expect(value.Byte == 2); + value.Byte = 3; + expect(value.Byte == 3); +} + +test "@unionInit can modify a pointer value" { + const UnionInitEnum = union(enum) { + Boolean: bool, + Byte: u8, + }; + + var value: UnionInitEnum = undefined; + var value_ptr = &value; + + value_ptr.* = @unionInit(UnionInitEnum, "Boolean", true); + expect(value.Boolean == true); + + value_ptr.* = @unionInit(UnionInitEnum, "Byte", 2); + expect(value.Byte == 2); +}