commit 51c6bb92b1c0c02b214ae21986dce3f2e9960099 (tree)
parent 83f6f730cdd5bb9c2a12b30c0aac33d858a1eaa8
Author: Andrew Kelley <andrew@ziglang.org>
Date: Wed, 11 Mar 2020 14:22:40 -0400
Merge pull request #4709 from LemonBoy/implement-2096
Stricter shift left/right safety checks
Diffstat:
8 files changed, 172 insertions(+), 32 deletions(-)
diff --git a/lib/std/io.zig b/lib/std/io.zig
@@ -350,12 +350,18 @@ pub fn BitInStream(endian: builtin.Endian, comptime Error: type) type {
switch (endian) {
.Big => {
out_buffer = @as(Buf, self.bit_buffer >> shift);
- self.bit_buffer <<= n;
+ if (n >= u7_bit_count)
+ self.bit_buffer = 0
+ else
+ self.bit_buffer <<= n;
},
.Little => {
const value = (self.bit_buffer << shift) >> shift;
out_buffer = @as(Buf, value);
- self.bit_buffer >>= n;
+ if (n >= u7_bit_count)
+ self.bit_buffer = 0
+ else
+ self.bit_buffer >>= n;
},
}
self.bit_count -= n;
diff --git a/lib/std/mem.zig b/lib/std/mem.zig
@@ -935,6 +935,9 @@ pub fn writeInt(comptime T: type, buffer: *[@divExact(T.bit_count, 8)]u8, value:
pub fn writeIntSliceLittle(comptime T: type, buffer: []u8, value: T) void {
assert(buffer.len >= @divExact(T.bit_count, 8));
+ if (T.bit_count == 0)
+ return set(u8, buffer, 0);
+
// TODO I want to call writeIntLittle here but comptime eval facilities aren't good enough
const uint = std.meta.IntType(false, T.bit_count);
var bits = @truncate(uint, value);
@@ -952,6 +955,9 @@ pub fn writeIntSliceLittle(comptime T: type, buffer: []u8, value: T) void {
pub fn writeIntSliceBig(comptime T: type, buffer: []u8, value: T) void {
assert(buffer.len >= @divExact(T.bit_count, 8));
+ if (T.bit_count == 0)
+ return set(u8, buffer, 0);
+
// TODO I want to call writeIntBig here but comptime eval facilities aren't good enough
const uint = std.meta.IntType(false, T.bit_count);
var bits = @truncate(uint, value);
@@ -1821,7 +1827,7 @@ test "sliceAsBytes" {
}
test "sliceAsBytes with sentinel slice" {
- const empty_string:[:0]const u8 = "";
+ const empty_string: [:0]const u8 = "";
const bytes = sliceAsBytes(empty_string);
testing.expect(bytes.len == 0);
}
diff --git a/src/all_types.hpp b/src/all_types.hpp
@@ -1834,6 +1834,7 @@ enum PanicMsgId {
PanicMsgIdBadNoAsyncCall,
PanicMsgIdResumeNotSuspendedFn,
PanicMsgIdBadSentinel,
+ PanicMsgIdShxTooBigRhs,
PanicMsgIdCount,
};
diff --git a/src/codegen.cpp b/src/codegen.cpp
@@ -974,6 +974,8 @@ static Buf *panic_msg_buf(PanicMsgId msg_id) {
return buf_create_from_str("resumed a non-suspended function");
case PanicMsgIdBadSentinel:
return buf_create_from_str("sentinel mismatch");
+ case PanicMsgIdShxTooBigRhs:
+ return buf_create_from_str("shift amount is greater than the type size");
}
zig_unreachable();
}
@@ -2841,6 +2843,26 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast
}
+static void gen_shift_rhs_check(CodeGen *g, ZigType *lhs_type, ZigType *rhs_type, LLVMValueRef value) {
+ // We only check if the rhs value of the shift expression is greater or
+ // equal to the number of bits of the lhs if it's not a power of two,
+ // otherwise the check is useful as the allowed values are limited by the
+ // operand type itself
+ if (!is_power_of_2(lhs_type->data.integral.bit_count)) {
+ LLVMValueRef bit_count_value = LLVMConstInt(get_llvm_type(g, rhs_type),
+ lhs_type->data.integral.bit_count, false);
+ LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, "");
+ LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckFail");
+ LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckOk");
+ LLVMBuildCondBr(g->builder, less_than_bit, ok_block, fail_block);
+
+ LLVMPositionBuilderAtEnd(g->builder, fail_block);
+ gen_safety_crash(g, PanicMsgIdShxTooBigRhs);
+
+ LLVMPositionBuilderAtEnd(g->builder, ok_block);
+ }
+}
+
static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable,
IrInstGenBinOp *bin_op_instruction)
{
@@ -2949,6 +2971,11 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable,
{
assert(scalar_type->id == ZigTypeIdInt);
LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value);
+
+ if (want_runtime_safety) {
+ gen_shift_rhs_check(g, scalar_type, op2->value->type, op2_value);
+ }
+
bool is_sloppy = (op_id == IrBinOpBitShiftLeftLossy);
if (is_sloppy) {
return LLVMBuildShl(g->builder, op1_value, op2_casted, "");
@@ -2965,6 +2992,11 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, IrExecutableGen *executable,
{
assert(scalar_type->id == ZigTypeIdInt);
LLVMValueRef op2_casted = gen_widen_or_shorten(g, false, op2->value->type, scalar_type, op2_value);
+
+ if (want_runtime_safety) {
+ gen_shift_rhs_check(g, scalar_type, op2->value->type, op2_value);
+ }
+
bool is_sloppy = (op_id == IrBinOpBitShiftRightLossy);
if (is_sloppy) {
if (scalar_type->data.integral.is_signed) {
diff --git a/src/ir.cpp b/src/ir.cpp
@@ -16635,49 +16635,69 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in
IrInstGen *casted_op2;
IrBinOp op_id = bin_op_instruction->op_id;
if (op1->value->type->id == ZigTypeIdComptimeInt) {
+ // comptime_int has no finite bit width
casted_op2 = op2;
if (op_id == IrBinOpBitShiftLeftLossy) {
op_id = IrBinOpBitShiftLeftExact;
}
- if (casted_op2->value->data.x_bigint.is_negative) {
+ if (!instr_is_comptime(op2)) {
+ ir_add_error(ira, &bin_op_instruction->base.base,
+ buf_sprintf("LHS of shift must be an integer type, or RHS must be compile-time known"));
+ return ira->codegen->invalid_inst_gen;
+ }
+
+ ZigValue *op2_val = ir_resolve_const(ira, casted_op2, UndefBad);
+ if (op2_val == nullptr)
+ return ira->codegen->invalid_inst_gen;
+
+ if (op2_val->data.x_bigint.is_negative) {
Buf *val_buf = buf_alloc();
- bigint_append_buf(val_buf, &casted_op2->value->data.x_bigint, 10);
- ir_add_error(ira, &casted_op2->base, buf_sprintf("shift by negative value %s", buf_ptr(val_buf)));
+ bigint_append_buf(val_buf, &op2_val->data.x_bigint, 10);
+ ir_add_error(ira, &casted_op2->base,
+ buf_sprintf("shift by negative value %s", buf_ptr(val_buf)));
return ira->codegen->invalid_inst_gen;
}
} else {
+ const unsigned bit_count = op1->value->type->data.integral.bit_count;
ZigType *shift_amt_type = get_smallest_unsigned_int_type(ira->codegen,
- op1->value->type->data.integral.bit_count - 1);
- if (bin_op_instruction->op_id == IrBinOpBitShiftLeftLossy &&
- op2->value->type->id == ZigTypeIdComptimeInt) {
+ bit_count > 0 ? bit_count - 1 : 0);
- ZigValue *op2_val = ir_resolve_const(ira, op2, UndefBad);
+ casted_op2 = ir_implicit_cast(ira, op2, shift_amt_type);
+ if (type_is_invalid(casted_op2->value->type))
+ return ira->codegen->invalid_inst_gen;
+
+ // This check is only valid iff op1 has at least one bit
+ if (bit_count > 0 && instr_is_comptime(casted_op2)) {
+ ZigValue *op2_val = ir_resolve_const(ira, casted_op2, UndefBad);
if (op2_val == nullptr)
return ira->codegen->invalid_inst_gen;
- if (!bigint_fits_in_bits(&op2_val->data.x_bigint,
- shift_amt_type->data.integral.bit_count,
- op2_val->data.x_bigint.is_negative)) {
- Buf *val_buf = buf_alloc();
- bigint_append_buf(val_buf, &op2_val->data.x_bigint, 10);
+
+ BigInt bit_count_value = {0};
+ bigint_init_unsigned(&bit_count_value, bit_count);
+
+ if (bigint_cmp(&op2_val->data.x_bigint, &bit_count_value) != CmpLT) {
ErrorMsg* msg = ir_add_error(ira,
&bin_op_instruction->base.base,
buf_sprintf("RHS of shift is too large for LHS type"));
- add_error_note(
- ira->codegen,
- msg,
- op2->base.source_node,
- buf_sprintf("value %s cannot fit into type %s",
- buf_ptr(val_buf),
- buf_ptr(&shift_amt_type->name)));
+ add_error_note(ira->codegen, msg, op1->base.source_node,
+ buf_sprintf("type %s has only %u bits",
+ buf_ptr(&op1->value->type->name), bit_count));
+
return ira->codegen->invalid_inst_gen;
}
}
+ }
- casted_op2 = ir_implicit_cast(ira, op2, shift_amt_type);
- if (type_is_invalid(casted_op2->value->type))
+ // Fast path for zero RHS
+ if (instr_is_comptime(casted_op2)) {
+ ZigValue *op2_val = ir_resolve_const(ira, casted_op2, UndefBad);
+ if (op2_val == nullptr)
return ira->codegen->invalid_inst_gen;
+
+ if (bigint_cmp_zero(&op2_val->data.x_bigint) == CmpEQ)
+ return ir_analyze_cast(ira, &bin_op_instruction->base.base, op1->value->type, op1);
}
if (instr_is_comptime(op1) && instr_is_comptime(casted_op2)) {
@@ -16690,12 +16710,6 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in
return ira->codegen->invalid_inst_gen;
return ir_analyze_math_op(ira, &bin_op_instruction->base.base, op1->value->type, op1_val, op_id, op2_val);
- } else if (op1->value->type->id == ZigTypeIdComptimeInt) {
- ir_add_error(ira, &bin_op_instruction->base.base,
- buf_sprintf("LHS of shift must be an integer type, or RHS must be compile-time known"));
- return ira->codegen->invalid_inst_gen;
- } else if (instr_is_comptime(casted_op2) && bigint_cmp_zero(&casted_op2->value->data.x_bigint) == CmpEQ) {
- return ir_build_cast(ira, &bin_op_instruction->base.base, op1->value->type, op1, CastOpNoop);
}
return ir_build_bin_op_gen(ira, &bin_op_instruction->base.base, op1->value->type,
diff --git a/test/compile_errors.zig b/test/compile_errors.zig
@@ -2,6 +2,38 @@ const tests = @import("tests.zig");
const std = @import("std");
pub fn addCases(cases: *tests.CompileErrorContext) void {
+ cases.addTest("shift on type with non-power-of-two size",
+ \\export fn entry() void {
+ \\ const S = struct {
+ \\ fn a() void {
+ \\ var x: u24 = 42;
+ \\ _ = x >> 24;
+ \\ }
+ \\ fn b() void {
+ \\ var x: u24 = 42;
+ \\ _ = x << 24;
+ \\ }
+ \\ fn c() void {
+ \\ var x: u24 = 42;
+ \\ _ = @shlExact(x, 24);
+ \\ }
+ \\ fn d() void {
+ \\ var x: u24 = 42;
+ \\ _ = @shrExact(x, 24);
+ \\ }
+ \\ };
+ \\ S.a();
+ \\ S.b();
+ \\ S.c();
+ \\ S.d();
+ \\}
+ , &[_][]const u8{
+ "tmp.zig:5:19: error: RHS of shift is too large for LHS type",
+ "tmp.zig:9:19: error: RHS of shift is too large for LHS type",
+ "tmp.zig:13:17: error: RHS of shift is too large for LHS type",
+ "tmp.zig:17:17: error: RHS of shift is too large for LHS type",
+ });
+
cases.addTest("combination of noasync and async",
\\export fn entry() void {
\\ noasync {
@@ -4029,8 +4061,7 @@ pub fn addCases(cases: *tests.CompileErrorContext) void {
\\}
\\export fn entry() u16 { return f(); }
, &[_][]const u8{
- "tmp.zig:3:14: error: RHS of shift is too large for LHS type",
- "tmp.zig:3:17: note: value 8 cannot fit into type u3",
+ "tmp.zig:3:17: error: integer value 8 cannot be coerced to type 'u3'",
});
cases.add("missing function call param",
diff --git a/test/runtime_safety.zig b/test/runtime_safety.zig
@@ -1,6 +1,37 @@
const tests = @import("tests.zig");
pub fn addCases(cases: *tests.CompareOutputContext) void {
+ cases.addRuntimeSafety("shift left by huge amount",
+ \\const std = @import("std");
+ \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
+ \\ std.debug.warn("{}\n", .{message});
+ \\ if (std.mem.eql(u8, message, "shift amount is greater than the type size")) {
+ \\ std.process.exit(126); // good
+ \\ }
+ \\ std.process.exit(0); // test failed
+ \\}
+ \\pub fn main() void {
+ \\ var x: u24 = 42;
+ \\ var y: u5 = 24;
+ \\ var z = x >> y;
+ \\}
+ );
+
+ cases.addRuntimeSafety("shift right by huge amount",
+ \\const std = @import("std");
+ \\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
+ \\ if (std.mem.eql(u8, message, "shift amount is greater than the type size")) {
+ \\ std.process.exit(126); // good
+ \\ }
+ \\ std.process.exit(0); // test failed
+ \\}
+ \\pub fn main() void {
+ \\ var x: u24 = 42;
+ \\ var y: u5 = 24;
+ \\ var z = x << y;
+ \\}
+ );
+
cases.addRuntimeSafety("slice sentinel mismatch - optional pointers",
\\const std = @import("std");
\\pub fn panic(message: []const u8, stack_trace: ?*@import("builtin").StackTrace) noreturn {
diff --git a/test/stage1/behavior/math.zig b/test/stage1/behavior/math.zig
@@ -453,6 +453,25 @@ fn testShrExact(x: u8) void {
expect(shifted == 0b00101101);
}
+test "shift left/right on u0 operand" {
+ const S = struct {
+ fn doTheTest() void {
+ var x: u0 = 0;
+ var y: u0 = 0;
+ expectEqual(@as(u0, 0), x << 0);
+ expectEqual(@as(u0, 0), x >> 0);
+ expectEqual(@as(u0, 0), x << y);
+ expectEqual(@as(u0, 0), x >> y);
+ expectEqual(@as(u0, 0), @shlExact(x, 0));
+ expectEqual(@as(u0, 0), @shrExact(x, 0));
+ expectEqual(@as(u0, 0), @shlExact(x, y));
+ expectEqual(@as(u0, 0), @shrExact(x, y));
+ }
+ };
+ S.doTheTest();
+ comptime S.doTheTest();
+}
+
test "comptime_int addition" {
comptime {
expect(35361831660712422535336160538497375248 + 101752735581729509668353361206450473702 == 137114567242441932203689521744947848950);