diff --git a/src/codegen.cpp b/src/codegen.cpp index c43611e8af..0cb1e93211 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -3229,7 +3229,8 @@ static LLVMValueRef ir_render_br(CodeGen *g, IrExecutable *executable, IrInstruc static LLVMValueRef ir_render_un_op(CodeGen *g, IrExecutable *executable, IrInstructionUnOp *un_op_instruction) { IrUnOp op_id = un_op_instruction->op_id; LLVMValueRef expr = ir_llvm_value(g, un_op_instruction->value); - ZigType *expr_type = un_op_instruction->value->value.type; + ZigType *operand_type = un_op_instruction->value->value.type; + ZigType *scalar_type = (operand_type->id == ZigTypeIdVector) ? operand_type->data.vector.elem_type : operand_type; switch (op_id) { case IrUnOpInvalid: @@ -3239,16 +3240,16 @@ static LLVMValueRef ir_render_un_op(CodeGen *g, IrExecutable *executable, IrInst case IrUnOpNegation: case IrUnOpNegationWrap: { - if (expr_type->id == ZigTypeIdFloat) { + if (scalar_type->id == ZigTypeIdFloat) { ZigLLVMSetFastMath(g->builder, ir_want_fast_math(g, &un_op_instruction->base)); return LLVMBuildFNeg(g->builder, expr, ""); - } else if (expr_type->id == ZigTypeIdInt) { + } else if (scalar_type->id == ZigTypeIdInt) { if (op_id == IrUnOpNegationWrap) { return LLVMBuildNeg(g->builder, expr, ""); } else if (ir_want_runtime_safety(g, &un_op_instruction->base)) { LLVMValueRef zero = LLVMConstNull(LLVMTypeOf(expr)); - return gen_overflow_op(g, expr_type, AddSubMulSub, zero, expr); - } else if (expr_type->data.integral.is_signed) { + return gen_overflow_op(g, operand_type, AddSubMulSub, zero, expr); + } else if (scalar_type->data.integral.is_signed) { return LLVMBuildNSWNeg(g->builder, expr, ""); } else { return LLVMBuildNUWNeg(g->builder, expr, ""); diff --git a/src/ir.cpp b/src/ir.cpp index 36f11cb108..a3c15dbd0f 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -14620,6 +14620,41 @@ static IrInstruction *ir_analyze_maybe(IrAnalyze *ira, IrInstructionUnOp *un_op_ zig_unreachable(); } +static ErrorMsg *ir_eval_negation_scalar(IrAnalyze *ira, IrInstruction *source_instr, ZigType *scalar_type, + ConstExprValue *operand_val, ConstExprValue *scalar_out_val, bool is_wrap_op) +{ + bool is_float = (scalar_type->id == ZigTypeIdFloat || scalar_type->id == ZigTypeIdComptimeFloat); + + bool ok_type = ((scalar_type->id == ZigTypeIdInt && scalar_type->data.integral.is_signed) || + scalar_type->id == ZigTypeIdComptimeInt || (is_float && !is_wrap_op)); + + if (!ok_type) { + const char *fmt = is_wrap_op ? "invalid wrapping negation type: '%s'" : "invalid negation type: '%s'"; + return ir_add_error(ira, source_instr, buf_sprintf(fmt, buf_ptr(&scalar_type->name))); + } + + if (is_float) { + float_negate(scalar_out_val, operand_val); + } else if (is_wrap_op) { + bigint_negate_wrap(&scalar_out_val->data.x_bigint, &operand_val->data.x_bigint, + scalar_type->data.integral.bit_count); + } else { + bigint_negate(&scalar_out_val->data.x_bigint, &operand_val->data.x_bigint); + } + + scalar_out_val->type = scalar_type; + scalar_out_val->special = ConstValSpecialStatic; + + if (is_wrap_op || is_float || scalar_type->id == ZigTypeIdComptimeInt) { + return nullptr; + } + + if (!bigint_fits_in_bits(&scalar_out_val->data.x_bigint, scalar_type->data.integral.bit_count, true)) { + return ir_add_error(ira, source_instr, buf_sprintf("negation caused overflow")); + } + return nullptr; +} + static IrInstruction *ir_analyze_negation(IrAnalyze *ira, IrInstructionUnOp *instruction) { IrInstruction *value = instruction->value->child; ZigType *expr_type = value->value.type; @@ -14628,47 +14663,50 @@ static IrInstruction *ir_analyze_negation(IrAnalyze *ira, IrInstructionUnOp *ins bool is_wrap_op = (instruction->op_id == IrUnOpNegationWrap); - bool is_float = (expr_type->id == ZigTypeIdFloat || expr_type->id == ZigTypeIdComptimeFloat); + ZigType *scalar_type = (expr_type->id == ZigTypeIdVector) ? expr_type->data.vector.elem_type : expr_type; - if ((expr_type->id == ZigTypeIdInt && expr_type->data.integral.is_signed) || - expr_type->id == ZigTypeIdComptimeInt || (is_float && !is_wrap_op)) - { - if (instr_is_comptime(value)) { - ConstExprValue *target_const_val = ir_resolve_const(ira, value, UndefBad); - if (!target_const_val) - return ira->codegen->invalid_instruction; + if (instr_is_comptime(value)) { + ConstExprValue *operand_val = ir_resolve_const(ira, value, UndefBad); + if (!operand_val) + return ira->codegen->invalid_instruction; - IrInstruction *result = ir_const(ira, &instruction->base, expr_type); - ConstExprValue *out_val = &result->value; - if (is_float) { - float_negate(out_val, target_const_val); - } else if (is_wrap_op) { - bigint_negate_wrap(&out_val->data.x_bigint, &target_const_val->data.x_bigint, - expr_type->data.integral.bit_count); - } else { - bigint_negate(&out_val->data.x_bigint, &target_const_val->data.x_bigint); + IrInstruction *result_instruction = ir_const(ira, &instruction->base, expr_type); + ConstExprValue *out_val = &result_instruction->value; + if (expr_type->id == ZigTypeIdVector) { + expand_undef_array(ira->codegen, operand_val); + out_val->special = ConstValSpecialUndef; + expand_undef_array(ira->codegen, out_val); + size_t len = expr_type->data.vector.len; + for (size_t i = 0; i < len; i += 1) { + ConstExprValue *scalar_operand_val = &operand_val->data.x_array.data.s_none.elements[i]; + ConstExprValue *scalar_out_val = &out_val->data.x_array.data.s_none.elements[i]; + assert(scalar_operand_val->type == scalar_type); + assert(scalar_out_val->type == scalar_type); + ErrorMsg *msg = ir_eval_negation_scalar(ira, &instruction->base, scalar_type, + scalar_operand_val, scalar_out_val, is_wrap_op); + if (msg != nullptr) { + add_error_note(ira->codegen, msg, instruction->base.source_node, + buf_sprintf("when computing vector element at index %" ZIG_PRI_usize, i)); + return ira->codegen->invalid_instruction; + } } - if (is_wrap_op || is_float || expr_type->id == ZigTypeIdComptimeInt) { - return result; - } - - if (!bigint_fits_in_bits(&out_val->data.x_bigint, expr_type->data.integral.bit_count, true)) { - ir_add_error(ira, &instruction->base, buf_sprintf("negation caused overflow")); + out_val->type = expr_type; + out_val->special = ConstValSpecialStatic; + } else { + if (ir_eval_negation_scalar(ira, &instruction->base, scalar_type, operand_val, out_val, + is_wrap_op) != nullptr) + { return ira->codegen->invalid_instruction; } - return result; } - - IrInstruction *result = ir_build_un_op(&ira->new_irb, - instruction->base.scope, instruction->base.source_node, - instruction->op_id, value); - result->value.type = expr_type; - return result; + return result_instruction; } - const char *fmt = is_wrap_op ? "invalid wrapping negation type: '%s'" : "invalid negation type: '%s'"; - ir_add_error(ira, &instruction->base, buf_sprintf(fmt, buf_ptr(&expr_type->name))); - return ira->codegen->invalid_instruction; + IrInstruction *result = ir_build_un_op(&ira->new_irb, + instruction->base.scope, instruction->base.source_node, + instruction->op_id, value); + result->value.type = expr_type; + return result; } static IrInstruction *ir_analyze_bin_not(IrAnalyze *ira, IrInstructionUnOp *instruction) { diff --git a/test/compile_errors.zig b/test/compile_errors.zig index c39d34b3e9..48eb7cd85d 100644 --- a/test/compile_errors.zig +++ b/test/compile_errors.zig @@ -1,6 +1,18 @@ const tests = @import("tests.zig"); pub fn addCases(cases: *tests.CompileErrorContext) void { + cases.addTest( + "comptime vector overflow shows the index", + \\comptime { + \\ var a: @Vector(4, u8) = []u8{ 1, 2, 255, 4 }; + \\ var b: @Vector(4, u8) = []u8{ 5, 6, 1, 8 }; + \\ var x = a + b; + \\} + , + ".tmp_source.zig:4:15: error: operation caused overflow", + ".tmp_source.zig:4:15: note: when computing vector element at index 2", + ); + cases.addTest( "packed struct with fields of not allowed types", \\const A = packed struct { diff --git a/test/runtime_safety.zig b/test/runtime_safety.zig index 12cac64b3a..0427efabd5 100644 --- a/test/runtime_safety.zig +++ b/test/runtime_safety.zig @@ -118,6 +118,47 @@ pub fn addCases(cases: *tests.CompareOutputContext) void { \\} ); + cases.addRuntimeSafety("vector integer subtraction overflow", + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ @import("std").os.exit(126); + \\} + \\pub fn main() void { + \\ var a: @Vector(4, u32) = []u32{ 1, 2, 8, 4 }; + \\ var b: @Vector(4, u32) = []u32{ 5, 6, 7, 8 }; + \\ const x = sub(b, a); + \\} + \\fn sub(a: @Vector(4, u32), b: @Vector(4, u32)) @Vector(4, u32) { + \\ return a - b; + \\} + ); + + cases.addRuntimeSafety("vector integer multiplication overflow", + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ @import("std").os.exit(126); + \\} + \\pub fn main() void { + \\ var a: @Vector(4, u8) = []u8{ 1, 2, 200, 4 }; + \\ var b: @Vector(4, u8) = []u8{ 5, 6, 2, 8 }; + \\ const x = mul(b, a); + \\} + \\fn mul(a: @Vector(4, u8), b: @Vector(4, u8)) @Vector(4, u8) { + \\ return a * b; + \\} + ); + + cases.addRuntimeSafety("vector integer negation overflow", + \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { + \\ @import("std").os.exit(126); + \\} + \\pub fn main() void { + \\ var a: @Vector(4, i16) = []i16{ 1, -32768, 200, 4 }; + \\ const x = neg(a); + \\} + \\fn neg(a: @Vector(4, i16)) @Vector(4, i16) { + \\ return -a; + \\} + ); + cases.addRuntimeSafety("integer subtraction overflow", \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn { \\ @import("std").os.exit(126); diff --git a/test/stage1/behavior/vector.zig b/test/stage1/behavior/vector.zig index b0d2871454..21bbe3160d 100644 --- a/test/stage1/behavior/vector.zig +++ b/test/stage1/behavior/vector.zig @@ -5,11 +5,28 @@ const expect = std.testing.expect; test "vector wrap operators" { const S = struct { fn doTheTest() void { - const v: @Vector(4, i32) = [4]i32{ 10, 20, 30, 40 }; - const x: @Vector(4, i32) = [4]i32{ 1, 2, 3, 4 }; - expect(mem.eql(i32, ([4]i32)(v +% x), [4]i32{ 11, 22, 33, 44 })); - expect(mem.eql(i32, ([4]i32)(v -% x), [4]i32{ 9, 18, 27, 36 })); - expect(mem.eql(i32, ([4]i32)(v *% x), [4]i32{ 10, 40, 90, 160 })); + var v: @Vector(4, i32) = [4]i32{ 2147483647, -2, 30, 40 }; + var x: @Vector(4, i32) = [4]i32{ 1, 2147483647, 3, 4 }; + expect(mem.eql(i32, ([4]i32)(v +% x), [4]i32{ -2147483648, 2147483645, 33, 44 })); + expect(mem.eql(i32, ([4]i32)(v -% x), [4]i32{ 2147483646, 2147483647, 27, 36 })); + expect(mem.eql(i32, ([4]i32)(v *% x), [4]i32{ 2147483647, 2, 90, 160 })); + var z: @Vector(4, i32) = [4]i32{ 1, 2, 3, -2147483648 }; + expect(mem.eql(i32, ([4]i32)(-%z), [4]i32{ -1, -2, -3, -2147483648 })); + } + }; + S.doTheTest(); + comptime S.doTheTest(); +} + +test "vector int operators" { + const S = struct { + fn doTheTest() void { + var v: @Vector(4, i32) = [4]i32{ 10, 20, 30, 40 }; + var x: @Vector(4, i32) = [4]i32{ 1, 2, 3, 4 }; + expect(mem.eql(i32, ([4]i32)(v + x), [4]i32{ 11, 22, 33, 44 })); + expect(mem.eql(i32, ([4]i32)(v - x), [4]i32{ 9, 18, 27, 36 })); + expect(mem.eql(i32, ([4]i32)(v * x), [4]i32{ 10, 40, 90, 160 })); + expect(mem.eql(i32, ([4]i32)(-v), [4]i32{ -10, -20, -30, -40 })); } }; S.doTheTest(); @@ -19,11 +36,12 @@ test "vector wrap operators" { test "vector float operators" { const S = struct { fn doTheTest() void { - const v: @Vector(4, f32) = [4]f32{ 10, 20, 30, 40 }; - const x: @Vector(4, f32) = [4]f32{ 1, 2, 3, 4 }; + var v: @Vector(4, f32) = [4]f32{ 10, 20, 30, 40 }; + var x: @Vector(4, f32) = [4]f32{ 1, 2, 3, 4 }; expect(mem.eql(f32, ([4]f32)(v + x), [4]f32{ 11, 22, 33, 44 })); expect(mem.eql(f32, ([4]f32)(v - x), [4]f32{ 9, 18, 27, 36 })); expect(mem.eql(f32, ([4]f32)(v * x), [4]f32{ 10, 40, 90, 160 })); + expect(mem.eql(f32, ([4]f32)(-x), [4]f32{ -1, -2, -3, -4 })); } }; S.doTheTest(); @@ -33,8 +51,8 @@ test "vector float operators" { test "vector bit operators" { const S = struct { fn doTheTest() void { - const v: @Vector(4, u8) = [4]u8{ 0b10101010, 0b10101010, 0b10101010, 0b10101010 }; - const x: @Vector(4, u8) = [4]u8{ 0b11110000, 0b00001111, 0b10101010, 0b01010101 }; + var v: @Vector(4, u8) = [4]u8{ 0b10101010, 0b10101010, 0b10101010, 0b10101010 }; + var x: @Vector(4, u8) = [4]u8{ 0b11110000, 0b00001111, 0b10101010, 0b01010101 }; expect(mem.eql(u8, ([4]u8)(v ^ x), [4]u8{ 0b01011010, 0b10100101, 0b00000000, 0b11111111 })); expect(mem.eql(u8, ([4]u8)(v | x), [4]u8{ 0b11111010, 0b10101111, 0b10101010, 0b11111111 })); expect(mem.eql(u8, ([4]u8)(v & x), [4]u8{ 0b10100000, 0b00001010, 0b10101010, 0b00000000 }));