commit a1359ac3abbb83f30719f596180f1e923da3a8f2 (tree)
parent 9daf0140e5a78802fd294bce8a9019f59bd89b61
Author: Andrew Kelley <andrew@ziglang.org>
Date: Wed, 3 Jul 2019 13:55:50 -0400
Merge branch 'rbscott-comptime-union-init'
Diffstat:
6 files changed, 208 insertions(+), 28 deletions(-)
diff --git a/doc/langref.html.in b/doc/langref.html.in
@@ -5065,6 +5065,12 @@ test "@intToPtr for pointer to zero bit type" {
{#header_close#}
{#header_close#}
+ {#header_open|Result Location Semantics#}
+ <p>
+ <a href="https://github.com/ziglang/zig/issues/2809">TODO add documentation for this</a>
+ </p>
+ {#header_close#}
+
{#header_open|comptime#}
<p>
Zig places importance on the concept of whether an expression is known at compile-time.
@@ -7809,6 +7815,17 @@ pub const TypeInfo = union(TypeId) {
{#header_close#}
+ {#header_open|@unionInit#}
+ <pre>{#syntax#}@unionInit(comptime Union: type, comptime active_field_name: []const u8, init_expr) Union{#endsyntax#}</pre>
+ <p>
+ This is the same thing as {#link|union#} initialization syntax, except that the field name is a
+ {#link|comptime#}-known value rather than an identifier token.
+ </p>
+ <p>
+ {#syntax#}@unionInit{#endsyntax#} forwards its {#link|result location|Result Location Semantics#} to {#syntax#}init_expr{#endsyntax#}.
+ </p>
+ {#header_close#}
+
{#header_open|@Vector#}
<pre>{#syntax#}@Vector(comptime len: u32, comptime ElemType: type) type{#endsyntax#}</pre>
<p>
diff --git a/src/all_types.hpp b/src/all_types.hpp
@@ -1509,6 +1509,7 @@ enum BuiltinFnId {
BuiltinFnIdAtomicRmw,
BuiltinFnIdAtomicLoad,
BuiltinFnIdHasDecl,
+ BuiltinFnIdUnionInit,
};
struct BuiltinFnEntry {
@@ -2359,6 +2360,7 @@ enum IrInstructionId {
IrInstructionIdAllocaGen,
IrInstructionIdEndExpr,
IrInstructionIdPtrOfArrayToSlice,
+ IrInstructionIdUnionInitNamedField,
};
struct IrInstruction {
@@ -3603,6 +3605,15 @@ struct IrInstructionAssertNonNull {
IrInstruction *target;
};
+struct IrInstructionUnionInitNamedField {
+ IrInstruction base;
+
+ IrInstruction *union_type;
+ IrInstruction *field_name;
+ IrInstruction *field_result_loc;
+ IrInstruction *result_loc;
+};
+
struct IrInstructionHasDecl {
IrInstruction base;
diff --git a/src/codegen.cpp b/src/codegen.cpp
@@ -5635,6 +5635,7 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
case IrInstructionIdRef:
case IrInstructionIdBitCastSrc:
case IrInstructionIdTestErrSrc:
+ case IrInstructionIdUnionInitNamedField:
zig_unreachable();
case IrInstructionIdDeclVarGen:
@@ -7419,6 +7420,7 @@ static void define_builtin_fns(CodeGen *g) {
create_builtin_fn(g, BuiltinFnIdFromBytes, "bytesToSlice", 2);
create_builtin_fn(g, BuiltinFnIdThis, "This", 0);
create_builtin_fn(g, BuiltinFnIdHasDecl, "hasDecl", 2);
+ create_builtin_fn(g, BuiltinFnIdUnionInit, "unionInit", 3);
}
static const char *bool_to_str(bool b) {
diff --git a/src/ir.cpp b/src/ir.cpp
@@ -198,6 +198,9 @@ static IrInstruction *ir_analyze_unwrap_err_code(IrAnalyze *ira, IrInstruction *
IrInstruction *base_ptr, bool initializing);
static IrInstruction *ir_analyze_store_ptr(IrAnalyze *ira, IrInstruction *source_instr,
IrInstruction *ptr, IrInstruction *uncasted_value);
+static IrInstruction *ir_gen_union_init_expr(IrBuilder *irb, Scope *scope, AstNode *source_node,
+ IrInstruction *union_type, IrInstruction *field_name, AstNode *expr_node,
+ LVal lval, ResultLoc *parent_result_loc);
static ConstExprValue *const_ptr_pointee_unchecked(CodeGen *g, ConstExprValue *const_val) {
assert(get_src_ptr_type(const_val->type) != nullptr);
@@ -1089,6 +1092,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionEndExpr *) {
return IrInstructionIdEndExpr;
}
+static constexpr IrInstructionId ir_instruction_id(IrInstructionUnionInitNamedField *) {
+ return IrInstructionIdUnionInitNamedField;
+}
+
template<typename T>
static T *ir_create_instruction(IrBuilder *irb, Scope *scope, AstNode *source_node) {
T *special_instruction = allocate<T>(1);
@@ -1352,12 +1359,13 @@ static IrInstruction *ir_build_elem_ptr(IrBuilder *irb, Scope *scope, AstNode *s
}
static IrInstruction *ir_build_field_ptr_instruction(IrBuilder *irb, Scope *scope, AstNode *source_node,
- IrInstruction *container_ptr, IrInstruction *field_name_expr)
+ IrInstruction *container_ptr, IrInstruction *field_name_expr, bool initializing)
{
IrInstructionFieldPtr *instruction = ir_build_instruction<IrInstructionFieldPtr>(irb, scope, source_node);
instruction->container_ptr = container_ptr;
instruction->field_name_buffer = nullptr;
instruction->field_name_expr = field_name_expr;
+ instruction->initializing = initializing;
ir_ref_instruction(container_ptr, irb->current_basic_block);
ir_ref_instruction(field_name_expr, irb->current_basic_block);
@@ -3324,6 +3332,24 @@ static IrInstruction *ir_build_check_runtime_scope(IrBuilder *irb, Scope *scope,
return &instruction->base;
}
+static IrInstruction *ir_build_union_init_named_field(IrBuilder *irb, Scope *scope, AstNode *source_node,
+ IrInstruction *union_type, IrInstruction *field_name, IrInstruction *field_result_loc, IrInstruction *result_loc)
+{
+ IrInstructionUnionInitNamedField *instruction = ir_build_instruction<IrInstructionUnionInitNamedField>(irb, scope, source_node);
+ instruction->union_type = union_type;
+ instruction->field_name = field_name;
+ instruction->field_result_loc = field_result_loc;
+ instruction->result_loc = result_loc;
+
+ ir_ref_instruction(union_type, irb->current_basic_block);
+ ir_ref_instruction(field_name, irb->current_basic_block);
+ ir_ref_instruction(field_result_loc, irb->current_basic_block);
+ if (result_loc != nullptr) ir_ref_instruction(result_loc, irb->current_basic_block);
+
+ return &instruction->base;
+}
+
+
static IrInstruction *ir_build_vector_to_array(IrAnalyze *ira, IrInstruction *source_instruction,
ZigType *result_type, IrInstruction *vector, IrInstruction *result_loc)
{
@@ -5110,7 +5136,8 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
if (arg1_value == irb->codegen->invalid_instruction)
return arg1_value;
- IrInstruction *ptr_instruction = ir_build_field_ptr_instruction(irb, scope, node, arg0_value, arg1_value);
+ IrInstruction *ptr_instruction = ir_build_field_ptr_instruction(irb, scope, node,
+ arg0_value, arg1_value, false);
if (lval == LValPtr)
return ptr_instruction;
@@ -5651,6 +5678,23 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
IrInstruction *has_decl = ir_build_has_decl(irb, scope, node, arg0_value, arg1_value);
return ir_lval_wrap(irb, scope, has_decl, lval, result_loc);
}
+ case BuiltinFnIdUnionInit:
+ {
+ AstNode *union_type_node = node->data.fn_call_expr.params.at(0);
+ IrInstruction *union_type_inst = ir_gen_node(irb, union_type_node, scope);
+ if (union_type_inst == irb->codegen->invalid_instruction)
+ return union_type_inst;
+
+ AstNode *name_node = node->data.fn_call_expr.params.at(1);
+ IrInstruction *name_inst = ir_gen_node(irb, name_node, scope);
+ if (name_inst == irb->codegen->invalid_instruction)
+ return name_inst;
+
+ AstNode *init_node = node->data.fn_call_expr.params.at(2);
+
+ return ir_gen_union_init_expr(irb, scope, node, union_type_inst, name_inst, init_node,
+ lval, result_loc);
+ }
}
zig_unreachable();
}
@@ -5929,6 +5973,31 @@ static IrInstruction *ir_gen_prefix_op_expr(IrBuilder *irb, Scope *scope, AstNod
zig_unreachable();
}
+static IrInstruction *ir_gen_union_init_expr(IrBuilder *irb, Scope *scope, AstNode *source_node,
+ IrInstruction *union_type, IrInstruction *field_name, AstNode *expr_node,
+ LVal lval, ResultLoc *parent_result_loc)
+{
+ IrInstruction *container_ptr = ir_build_resolve_result(irb, scope, source_node, parent_result_loc, union_type);
+ IrInstruction *field_ptr = ir_build_field_ptr_instruction(irb, scope, source_node, container_ptr,
+ field_name, true);
+
+ ResultLocInstruction *result_loc_inst = allocate<ResultLocInstruction>(1);
+ result_loc_inst->base.id = ResultLocIdInstruction;
+ result_loc_inst->base.source_instruction = field_ptr;
+ ir_ref_instruction(field_ptr, irb->current_basic_block);
+ ir_build_reset_result(irb, scope, expr_node, &result_loc_inst->base);
+
+ IrInstruction *expr_value = ir_gen_node_extra(irb, expr_node, scope, LValNone,
+ &result_loc_inst->base);
+ if (expr_value == irb->codegen->invalid_instruction)
+ return expr_value;
+
+ IrInstruction *init_union = ir_build_union_init_named_field(irb, scope, source_node, union_type,
+ field_name, field_ptr, container_ptr);
+
+ return ir_lval_wrap(irb, scope, init_union, lval, parent_result_loc);
+}
+
static IrInstruction *ir_gen_container_init_expr(IrBuilder *irb, Scope *scope, AstNode *node, LVal lval,
ResultLoc *parent_result_loc)
{
@@ -19408,32 +19477,21 @@ static IrInstruction *ir_analyze_instruction_ref(IrAnalyze *ira, IrInstructionRe
return ir_get_ref(ira, &ref_instruction->base, value, ref_instruction->is_const, ref_instruction->is_volatile);
}
-static IrInstruction *ir_analyze_container_init_fields_union(IrAnalyze *ira, IrInstruction *instruction,
- ZigType *container_type, size_t instr_field_count, IrInstructionContainerInitFieldsField *fields,
- IrInstruction *result_loc)
+static IrInstruction *ir_analyze_union_init(IrAnalyze *ira, IrInstruction *source_instruction,
+ AstNode *field_source_node, ZigType *union_type, Buf *field_name, IrInstruction *field_result_loc,
+ IrInstruction *result_loc)
{
Error err;
- assert(container_type->id == ZigTypeIdUnion);
-
- if ((err = type_resolve(ira->codegen, container_type, ResolveStatusSizeKnown)))
- return ira->codegen->invalid_instruction;
-
- if (instr_field_count != 1) {
- ir_add_error(ira, instruction,
- buf_sprintf("union initialization expects exactly one field"));
- return ira->codegen->invalid_instruction;
- }
+ assert(union_type->id == ZigTypeIdUnion);
- IrInstructionContainerInitFieldsField *field = &fields[0];
- IrInstruction *field_result_loc = field->result_loc->child;
- if (type_is_invalid(field_result_loc->value.type))
+ if ((err = type_resolve(ira->codegen, union_type, ResolveStatusSizeKnown)))
return ira->codegen->invalid_instruction;
- TypeUnionField *type_field = find_union_type_field(container_type, field->name);
+ TypeUnionField *type_field = find_union_type_field(union_type, field_name);
if (type_field == nullptr) {
- ir_add_error_node(ira, field->source_node,
+ ir_add_error_node(ira, field_source_node,
buf_sprintf("no member named '%s' in union '%s'",
- buf_ptr(field->name), buf_ptr(&container_type->name)));
+ buf_ptr(field_name), buf_ptr(&union_type->name)));
return ira->codegen->invalid_instruction;
}
@@ -19450,12 +19508,12 @@ static IrInstruction *ir_analyze_container_init_fields_union(IrAnalyze *ira, IrI
}
}
- bool is_comptime = ir_should_inline(ira->new_irb.exec, instruction->scope)
- || type_requires_comptime(ira->codegen, container_type) == ReqCompTimeYes;
+ bool is_comptime = ir_should_inline(ira->new_irb.exec, source_instruction->scope)
+ || type_requires_comptime(ira->codegen, union_type) == ReqCompTimeYes;
- IrInstruction *result = ir_get_deref(ira, instruction, result_loc, nullptr);
+ IrInstruction *result = ir_get_deref(ira, source_instruction, result_loc, nullptr);
if (is_comptime && !instr_is_comptime(result)) {
- ir_add_error(ira, field->result_loc,
+ ir_add_error(ira, field_result_loc,
buf_sprintf("unable to evaluate constant expression"));
return ira->codegen->invalid_instruction;
}
@@ -19468,8 +19526,18 @@ static IrInstruction *ir_analyze_container_init_fields(IrAnalyze *ira, IrInstruc
{
Error err;
if (container_type->id == ZigTypeIdUnion) {
- return ir_analyze_container_init_fields_union(ira, instruction, container_type, instr_field_count,
- fields, result_loc);
+ if (instr_field_count != 1) {
+ ir_add_error(ira, instruction,
+ buf_sprintf("union initialization expects exactly one field"));
+ return ira->codegen->invalid_instruction;
+ }
+ IrInstructionContainerInitFieldsField *field = &fields[0];
+ IrInstruction *field_result_loc = field->result_loc->child;
+ if (type_is_invalid(field_result_loc->value.type))
+ return ira->codegen->invalid_instruction;
+
+ return ir_analyze_union_init(ira, instruction, field->source_node, container_type, field->name,
+ field_result_loc, result_loc);
}
if (container_type->id != ZigTypeIdStruct || is_slice(container_type)) {
ir_add_error(ira, instruction,
@@ -25326,6 +25394,35 @@ static IrInstruction *ir_analyze_instruction_bit_cast_src(IrAnalyze *ira, IrInst
return instruction->result_loc_bit_cast->parent->gen_instruction;
}
+static IrInstruction *ir_analyze_instruction_union_init_named_field(IrAnalyze *ira,
+ IrInstructionUnionInitNamedField *instruction)
+{
+ ZigType *union_type = ir_resolve_type(ira, instruction->union_type->child);
+ if (type_is_invalid(union_type))
+ return ira->codegen->invalid_instruction;
+
+ if (union_type->id != ZigTypeIdUnion) {
+ ir_add_error(ira, instruction->union_type,
+ buf_sprintf("non-union type '%s' passed to @unionInit", buf_ptr(&union_type->name)));
+ return ira->codegen->invalid_instruction;
+ }
+
+ Buf *field_name = ir_resolve_str(ira, instruction->field_name->child);
+ if (field_name == nullptr)
+ return ira->codegen->invalid_instruction;
+
+ IrInstruction *field_result_loc = instruction->field_result_loc->child;
+ if (type_is_invalid(field_result_loc->value.type))
+ return ira->codegen->invalid_instruction;
+
+ IrInstruction *result_loc = instruction->result_loc->child;
+ if (type_is_invalid(result_loc->value.type))
+ return ira->codegen->invalid_instruction;
+
+ return ir_analyze_union_init(ira, &instruction->base, instruction->base.source_node,
+ union_type, field_name, field_result_loc, result_loc);
+}
+
static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction *instruction) {
switch (instruction->id) {
case IrInstructionIdInvalid:
@@ -25641,6 +25738,8 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction
return ir_analyze_instruction_end_expr(ira, (IrInstructionEndExpr *)instruction);
case IrInstructionIdBitCastSrc:
return ir_analyze_instruction_bit_cast_src(ira, (IrInstructionBitCastSrc *)instruction);
+ case IrInstructionIdUnionInitNamedField:
+ return ir_analyze_instruction_union_init_named_field(ira, (IrInstructionUnionInitNamedField *)instruction);
}
zig_unreachable();
}
@@ -25794,6 +25893,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
case IrInstructionIdCast:
case IrInstructionIdContainerInitList:
case IrInstructionIdContainerInitFields:
+ case IrInstructionIdUnionInitNamedField:
case IrInstructionIdFieldPtr:
case IrInstructionIdElemPtr:
case IrInstructionIdVarPtr:
diff --git a/src/ir_print.cpp b/src/ir_print.cpp
@@ -1626,6 +1626,18 @@ static void ir_print_undeclared_ident(IrPrint *irp, IrInstructionUndeclaredIdent
fprintf(irp->f, "@undeclaredIdent(%s)", buf_ptr(instruction->name));
}
+static void ir_print_union_init_named_field(IrPrint *irp, IrInstructionUnionInitNamedField *instruction) {
+ fprintf(irp->f, "@unionInit(");
+ ir_print_other_instruction(irp, instruction->union_type);
+ fprintf(irp->f, ", ");
+ ir_print_other_instruction(irp, instruction->field_name);
+ fprintf(irp->f, ", ");
+ ir_print_other_instruction(irp, instruction->field_result_loc);
+ fprintf(irp->f, ", ");
+ ir_print_other_instruction(irp, instruction->result_loc);
+ fprintf(irp->f, ")");
+}
+
static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
ir_print_prefix(irp, instruction);
switch (instruction->id) {
@@ -2132,6 +2144,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
case IrInstructionIdEndExpr:
ir_print_end_expr(irp, (IrInstructionEndExpr *)instruction);
break;
+ case IrInstructionIdUnionInitNamedField:
+ ir_print_union_init_named_field(irp, (IrInstructionUnionInitNamedField *)instruction);
+ break;
}
fprintf(irp->f, "\n");
}
diff --git a/test/stage1/behavior/union.zig b/test/stage1/behavior/union.zig
@@ -416,9 +416,44 @@ test "return union init with void payload" {
two: u32,
};
fn func() Outer {
- return Outer{ .state = State{ .one = {} }};
+ return Outer{ .state = State{ .one = {} } };
}
};
S.entry();
comptime S.entry();
}
+
+test "@unionInit can modify a union type" {
+ const UnionInitEnum = union(enum) {
+ Boolean: bool,
+ Byte: u8,
+ };
+
+ var value: UnionInitEnum = undefined;
+
+ value = @unionInit(UnionInitEnum, "Boolean", true);
+ expect(value.Boolean == true);
+ value.Boolean = false;
+ expect(value.Boolean == false);
+
+ value = @unionInit(UnionInitEnum, "Byte", 2);
+ expect(value.Byte == 2);
+ value.Byte = 3;
+ expect(value.Byte == 3);
+}
+
+test "@unionInit can modify a pointer value" {
+ const UnionInitEnum = union(enum) {
+ Boolean: bool,
+ Byte: u8,
+ };
+
+ var value: UnionInitEnum = undefined;
+ var value_ptr = &value;
+
+ value_ptr.* = @unionInit(UnionInitEnum, "Boolean", true);
+ expect(value.Boolean == true);
+
+ value_ptr.* = @unionInit(UnionInitEnum, "Byte", 2);
+ expect(value.Byte == 2);
+}