diff --git a/src/all_types.hpp b/src/all_types.hpp index 464a1d6ba4..695f22ac90 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -2432,7 +2432,8 @@ enum IrInstructionId { IrInstructionIdIntType, IrInstructionIdVectorType, IrInstructionIdShuffleVector, - IrInstructionIdSplat, + IrInstructionIdSplatSrc, + IrInstructionIdSplatGen, IrInstructionIdBoolNot, IrInstructionIdMemset, IrInstructionIdMemcpy, @@ -3683,13 +3684,19 @@ struct IrInstructionShuffleVector { IrInstruction *mask; // This is in zig-format, not llvm format }; -struct IrInstructionSplat { +struct IrInstructionSplatSrc { IrInstruction base; IrInstruction *len; IrInstruction *scalar; }; +struct IrInstructionSplatGen { + IrInstruction base; + + IrInstruction *scalar; +}; + struct IrInstructionAssertZero { IrInstruction base; diff --git a/src/codegen.cpp b/src/codegen.cpp index 49681c20c1..b0817e8eb8 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -4619,18 +4619,16 @@ static LLVMValueRef ir_render_shuffle_vector(CodeGen *g, IrExecutable *executabl llvm_mask_value, ""); } -static LLVMValueRef ir_render_splat(CodeGen *g, IrExecutable *executable, IrInstructionSplat *instruction) { - uint64_t len = bigint_as_u64(&instruction->len->value.data.x_bigint); - LLVMValueRef wrapped_scalar_undef = LLVMGetUndef(instruction->base.value.type->llvm_type); - LLVMValueRef wrapped_scalar = LLVMBuildInsertElement(g->builder, wrapped_scalar_undef, - ir_llvm_value(g, instruction->scalar), - LLVMConstInt(LLVMInt32Type(), 0, false), - ""); - return LLVMBuildShuffleVector(g->builder, - wrapped_scalar, - wrapped_scalar_undef, - LLVMConstNull(LLVMVectorType(g->builtin_types.entry_u32->llvm_type, (uint32_t)len)), - ""); +static LLVMValueRef ir_render_splat(CodeGen *g, IrExecutable *executable, IrInstructionSplatGen *instruction) { + ZigType *result_type = instruction->base.value.type; + src_assert(result_type->id == ZigTypeIdVector, instruction->base.source_node); + uint32_t len = result_type->data.vector.len; + LLVMTypeRef op_llvm_type = LLVMVectorType(get_llvm_type(g, instruction->scalar->value.type), 1); + LLVMTypeRef mask_llvm_type = LLVMVectorType(LLVMInt32Type(), len); + LLVMValueRef undef_vector = LLVMGetUndef(op_llvm_type); + LLVMValueRef op_vector = LLVMBuildInsertElement(g->builder, undef_vector, + ir_llvm_value(g, instruction->scalar), LLVMConstInt(LLVMInt32Type(), 0, false), ""); + return LLVMBuildShuffleVector(g->builder, op_vector, undef_vector, LLVMConstNull(mask_llvm_type), ""); } static LLVMValueRef ir_render_pop_count(CodeGen *g, IrExecutable *executable, IrInstructionPopCount *instruction) { @@ -6000,6 +5998,7 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, case IrInstructionIdFrameSizeSrc: case IrInstructionIdAllocaGen: case IrInstructionIdAwaitSrc: + case IrInstructionIdSplatSrc: zig_unreachable(); case IrInstructionIdDeclVarGen: @@ -6160,8 +6159,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, return ir_render_spill_end(g, executable, (IrInstructionSpillEnd *)instruction); case IrInstructionIdShuffleVector: return ir_render_shuffle_vector(g, executable, (IrInstructionShuffleVector *) instruction); - case IrInstructionIdSplat: - return ir_render_splat(g, executable, (IrInstructionSplat *) instruction); + case IrInstructionIdSplatGen: + return ir_render_splat(g, executable, (IrInstructionSplatGen *) instruction); } zig_unreachable(); } diff --git a/src/ir.cpp b/src/ir.cpp index 8fca50c6f7..0c48a2f982 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -721,8 +721,12 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionShuffleVector *) return IrInstructionIdShuffleVector; } -static constexpr IrInstructionId ir_instruction_id(IrInstructionSplat *) { - return IrInstructionIdSplat; +static constexpr IrInstructionId ir_instruction_id(IrInstructionSplatSrc *) { + return IrInstructionIdSplatSrc; +} + +static constexpr IrInstructionId ir_instruction_id(IrInstructionSplatGen *) { + return IrInstructionIdSplatGen; } static constexpr IrInstructionId ir_instruction_id(IrInstructionBoolNot *) { @@ -2304,10 +2308,10 @@ static IrInstruction *ir_build_shuffle_vector(IrBuilder *irb, Scope *scope, AstN return &instruction->base; } -static IrInstruction *ir_build_splat(IrBuilder *irb, Scope *scope, AstNode *source_node, +static IrInstruction *ir_build_splat_src(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *len, IrInstruction *scalar) { - IrInstructionSplat *instruction = ir_build_instruction(irb, scope, source_node); + IrInstructionSplatSrc *instruction = ir_build_instruction(irb, scope, source_node); instruction->len = len; instruction->scalar = scalar; @@ -2373,6 +2377,19 @@ static IrInstruction *ir_build_slice_src(IrBuilder *irb, Scope *scope, AstNode * return &instruction->base; } +static IrInstruction *ir_build_splat_gen(IrAnalyze *ira, IrInstruction *source_instruction, ZigType *result_type, + IrInstruction *scalar) +{ + IrInstructionSplatGen *instruction = ir_build_instruction( + &ira->new_irb, source_instruction->scope, source_instruction->source_node); + instruction->base.value.type = result_type; + instruction->scalar = scalar; + + ir_ref_instruction(scalar, ira->new_irb.current_basic_block); + + return &instruction->base; +} + static IrInstruction *ir_build_slice_gen(IrAnalyze *ira, IrInstruction *source_instruction, ZigType *slice_type, IrInstruction *ptr, IrInstruction *start, IrInstruction *end, bool safety_check_on, IrInstruction *result_loc) { @@ -5014,7 +5031,7 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo if (arg1_value == irb->codegen->invalid_instruction) return arg1_value; - IrInstruction *splat = ir_build_splat(irb, scope, node, + IrInstruction *splat = ir_build_splat_src(irb, scope, node, arg0_value, arg1_value); return ir_lval_wrap(irb, scope, splat, lval, result_loc); } @@ -11082,16 +11099,23 @@ static ZigType *ir_resolve_type(IrAnalyze *ira, IrInstruction *type_value) { return ir_resolve_const_type(ira->codegen, ira->new_irb.exec, type_value->source_node, val); } +static Error ir_validate_vector_elem_type(IrAnalyze *ira, IrInstruction *source_instr, ZigType *elem_type) { + if (!is_valid_vector_elem_type(elem_type)) { + ir_add_error(ira, source_instr, + buf_sprintf("vector element type must be integer, float, bool, or pointer; '%s' is invalid", + buf_ptr(&elem_type->name))); + return ErrorSemanticAnalyzeFail; + } + return ErrorNone; +} + static ZigType *ir_resolve_vector_elem_type(IrAnalyze *ira, IrInstruction *elem_type_value) { + Error err; ZigType *elem_type = ir_resolve_type(ira, elem_type_value); if (type_is_invalid(elem_type)) return ira->codegen->builtin_types.entry_invalid; - if (!is_valid_vector_elem_type(elem_type)) { - ir_add_error(ira, elem_type_value, - buf_sprintf("vector element type must be integer, float, bool, or pointer; '%s' is invalid", - buf_ptr(&elem_type->name))); + if ((err = ir_validate_vector_elem_type(ira, elem_type_value, elem_type))) return ira->codegen->builtin_types.entry_invalid; - } return elem_type; } @@ -22357,7 +22381,9 @@ static IrInstruction *ir_analyze_instruction_shuffle_vector(IrAnalyze *ira, IrIn return ir_analyze_shuffle_vector(ira, &instruction->base, scalar_type, a, b, mask); } -static IrInstruction *ir_analyze_instruction_splat(IrAnalyze *ira, IrInstructionSplat *instruction) { +static IrInstruction *ir_analyze_instruction_splat(IrAnalyze *ira, IrInstructionSplatSrc *instruction) { + Error err; + IrInstruction *len = instruction->len->child; if (type_is_invalid(len->value.type)) return ira->codegen->invalid_instruction; @@ -22366,41 +22392,32 @@ static IrInstruction *ir_analyze_instruction_splat(IrAnalyze *ira, IrInstruction if (type_is_invalid(scalar->value.type)) return ira->codegen->invalid_instruction; - uint64_t len_int; - if (!ir_resolve_unsigned(ira, len, ira->codegen->builtin_types.entry_u32, &len_int)) { - ir_add_error(ira, len, - buf_sprintf("splat length must be comptime")); + uint64_t len_u64; + if (!ir_resolve_unsigned(ira, len, ira->codegen->builtin_types.entry_u32, &len_u64)) return ira->codegen->invalid_instruction; - } + uint32_t len_int = len_u64; - if (!is_valid_vector_elem_type(scalar->value.type)) { - ir_add_error(ira, len, - buf_sprintf("vector element type must be integer, float, bool, or pointer; '%s' is invalid", - buf_ptr(&scalar->value.type->name))); + if ((err = ir_validate_vector_elem_type(ira, scalar, scalar->value.type))) return ira->codegen->invalid_instruction; - } ZigType *return_type = get_vector_type(ira->codegen, len_int, scalar->value.type); if (instr_is_comptime(scalar)) { - IrInstruction *result = ir_const_undef(ira, scalar, return_type); - result->value.data.x_array.data.s_none.elements = - allocate(len_int); - for (uint32_t i = 0; i < len_int; i++) { - result->value.data.x_array.data.s_none.elements[i] = - scalar->value; + ConstExprValue *scalar_val = ir_resolve_const(ira, scalar, UndefOk); + if (scalar_val == nullptr) + return ira->codegen->invalid_instruction; + if (scalar_val->special == ConstValSpecialUndef) + return ir_const_undef(ira, &instruction->base, return_type); + + IrInstruction *result = ir_const(ira, &instruction->base, return_type); + result->value.data.x_array.data.s_none.elements = create_const_vals(len_int); + for (uint32_t i = 0; i < len_int; i += 1) { + copy_const_val(&result->value.data.x_array.data.s_none.elements[i], scalar_val, false); } - result->value.type = return_type; - result->value.special = ConstValSpecialStatic; return result; } - IrInstruction *result = ir_build_splat(&ira->new_irb, - instruction->base.scope, instruction->base.source_node, - instruction->len->child, instruction->scalar->child); - result->value.type = return_type; - result->value.special = ConstValSpecialRuntime; - return result; + return ir_build_splat_gen(ira, &instruction->base, return_type, scalar); } static IrInstruction *ir_analyze_instruction_bool_not(IrAnalyze *ira, IrInstructionBoolNot *instruction) { @@ -25857,6 +25874,7 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction case IrInstructionIdTestErrGen: case IrInstructionIdFrameSizeGen: case IrInstructionIdAwaitGen: + case IrInstructionIdSplatGen: zig_unreachable(); case IrInstructionIdReturn: @@ -25987,8 +26005,8 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction return ir_analyze_instruction_vector_type(ira, (IrInstructionVectorType *)instruction); case IrInstructionIdShuffleVector: return ir_analyze_instruction_shuffle_vector(ira, (IrInstructionShuffleVector *)instruction); - case IrInstructionIdSplat: - return ir_analyze_instruction_splat(ira, (IrInstructionSplat *)instruction); + case IrInstructionIdSplatSrc: + return ir_analyze_instruction_splat(ira, (IrInstructionSplatSrc *)instruction); case IrInstructionIdBoolNot: return ir_analyze_instruction_bool_not(ira, (IrInstructionBoolNot *)instruction); case IrInstructionIdMemset: @@ -26325,7 +26343,8 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdIntType: case IrInstructionIdVectorType: case IrInstructionIdShuffleVector: - case IrInstructionIdSplat: + case IrInstructionIdSplatSrc: + case IrInstructionIdSplatGen: case IrInstructionIdBoolNot: case IrInstructionIdSliceSrc: case IrInstructionIdMemberCount: diff --git a/src/ir_print.cpp b/src/ir_print.cpp index 0dee7d342a..aae65d50a9 100644 --- a/src/ir_print.cpp +++ b/src/ir_print.cpp @@ -44,8 +44,10 @@ static const char* ir_instruction_type_str(IrInstruction* instruction) { return "Invalid"; case IrInstructionIdShuffleVector: return "Shuffle"; - case IrInstructionIdSplat: - return "Splat"; + case IrInstructionIdSplatSrc: + return "SplatSrc"; + case IrInstructionIdSplatGen: + return "SplatGen"; case IrInstructionIdDeclVarSrc: return "DeclVarSrc"; case IrInstructionIdDeclVarGen: @@ -1224,7 +1226,7 @@ static void ir_print_shuffle_vector(IrPrint *irp, IrInstructionShuffleVector *in fprintf(irp->f, ")"); } -static void ir_print_splat(IrPrint *irp, IrInstructionSplat *instruction) { +static void ir_print_splat_src(IrPrint *irp, IrInstructionSplatSrc *instruction) { fprintf(irp->f, "@splat("); ir_print_other_instruction(irp, instruction->len); fprintf(irp->f, ", "); @@ -1232,6 +1234,12 @@ static void ir_print_splat(IrPrint *irp, IrInstructionSplat *instruction) { fprintf(irp->f, ")"); } +static void ir_print_splat_gen(IrPrint *irp, IrInstructionSplatGen *instruction) { + fprintf(irp->f, "@splat("); + ir_print_other_instruction(irp, instruction->scalar); + fprintf(irp->f, ")"); +} + static void ir_print_bool_not(IrPrint *irp, IrInstructionBoolNot *instruction) { fprintf(irp->f, "! "); ir_print_other_instruction(irp, instruction->value); @@ -2170,8 +2178,11 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction, bool case IrInstructionIdShuffleVector: ir_print_shuffle_vector(irp, (IrInstructionShuffleVector *)instruction); break; - case IrInstructionIdSplat: - ir_print_splat(irp, (IrInstructionSplat *)instruction); + case IrInstructionIdSplatSrc: + ir_print_splat_src(irp, (IrInstructionSplatSrc *)instruction); + break; + case IrInstructionIdSplatGen: + ir_print_splat_gen(irp, (IrInstructionSplatGen *)instruction); break; case IrInstructionIdBoolNot: ir_print_bool_not(irp, (IrInstructionBoolNot *)instruction); diff --git a/test/compile_errors.zig b/test/compile_errors.zig index 2909bffc3b..034800fd4c 100644 --- a/test/compile_errors.zig +++ b/test/compile_errors.zig @@ -6514,7 +6514,7 @@ pub fn addCases(cases: *tests.CompileErrorContext) void { \\ var v = @splat(4, c); \\} , - "tmp.zig:3:20: error: vector element type must be integer, float, bool, or pointer; 'comptime_int' is invalid", + "tmp.zig:3:23: error: vector element type must be integer, float, bool, or pointer; 'comptime_int' is invalid", ); cases.add("compileLog of tagged enum doesn't crash the compiler", diff --git a/test/stage1/behavior/vector.zig b/test/stage1/behavior/vector.zig index 88a332d87b..d3a771fca8 100644 --- a/test/stage1/behavior/vector.zig +++ b/test/stage1/behavior/vector.zig @@ -145,10 +145,11 @@ test "vector @splat" { var v: u32 = 5; var x = @splat(4, v); expect(@typeOf(x) == @Vector(4, u32)); - expect(x[0] == 5); - expect(x[1] == 5); - expect(x[2] == 5); - expect(x[3] == 5); + var array_x: [4]u32 = x; + expect(array_x[0] == 5); + expect(array_x[1] == 5); + expect(array_x[2] == 5); + expect(array_x[3] == 5); } }; S.doTheTest();