diff --git a/doc/langref.html.in b/doc/langref.html.in index b6496c6979..cad6dee59f 100644 --- a/doc/langref.html.in +++ b/doc/langref.html.in @@ -7031,6 +7031,16 @@ fn readFile(allocator: *Allocator, filename: []const u8) ![]u8 { If no overflow or underflow occurs, returns {#syntax#}false{#endsyntax#}.

{#header_close#} + {#header_open|@addWithSaturation#} +
{#syntax#}@addWithSaturation(a: T, b: T) T{#endsyntax#}
+

+ Returns {#syntax#}a + b{#endsyntax#}. The result will be clamped between the type maximum and minimum. +

+

+ Once Saturating arithmetic + is completed, the syntax {#syntax#}a +| b{#endsyntax#} will be equivalent to calling {#syntax#}@addWithSaturation(a, b){#endsyntax#}. +

+ {#header_close#} {#header_open|@alignCast#}
{#syntax#}@alignCast(comptime alignment: u29, ptr: anytype) anytype{#endsyntax#}

@@ -8143,6 +8153,22 @@ test "@wasmMemoryGrow" { If no overflow or underflow occurs, returns {#syntax#}false{#endsyntax#}.

{#header_close#} + + {#header_open|@mulWithSaturation#} +
{#syntax#}@mulWithSaturation(a: T, b: T) T{#endsyntax#}
+

+ Returns {#syntax#}a * b{#endsyntax#}. The result will be clamped between the type maximum and minimum. +

+

+ Once Saturating arithmetic + is completed, the syntax {#syntax#}a *| b{#endsyntax#} will be equivalent to calling {#syntax#}@mulWithSaturation(a, b){#endsyntax#}. +

+

+ NOTE: Currently there is a bug in the llvm.smul.fix.sat intrinsic which affects {#syntax#}@mulWithSaturation{#endsyntax#} of signed integers. + This may result in an incorrect sign bit when there is overflow. This will be fixed in Zig's 0.9.0 release. + See issue #9643 for more information. +

+ {#header_close#} {#header_open|@panic#}
{#syntax#}@panic(message: []const u8) noreturn{#endsyntax#}
@@ -8368,7 +8394,7 @@ test "@setRuntimeSafety" { The type of {#syntax#}shift_amt{#endsyntax#} is an unsigned integer with {#syntax#}log2(T.bit_count){#endsyntax#} bits. This is because {#syntax#}shift_amt >= T.bit_count{#endsyntax#} is undefined behavior.

- {#see_also|@shrExact|@shlWithOverflow#} + {#see_also|@shrExact|@shlWithOverflow|@shlWithSaturation#} {#header_close#} {#header_open|@shlWithOverflow#} @@ -8382,7 +8408,22 @@ test "@setRuntimeSafety" { The type of {#syntax#}shift_amt{#endsyntax#} is an unsigned integer with {#syntax#}log2(T.bit_count){#endsyntax#} bits. This is because {#syntax#}shift_amt >= T.bit_count{#endsyntax#} is undefined behavior.

- {#see_also|@shlExact|@shrExact#} + {#see_also|@shlExact|@shrExact|@shlWithSaturation#} + {#header_close#} + + {#header_open|@shlWithSaturation#} +
{#syntax#}@shlWithSaturation(a: T, shift_amt: T) T{#endsyntax#}
+

+ Returns {#syntax#}a << shift_amt{#endsyntax#}. The result will be clamped between the type maximum and minimum. +

+

+ Once Saturating arithmetic + is completed, the syntax {#syntax#}a <<| shift_amt{#endsyntax#} will be equivalent to calling {#syntax#}@shlWithSaturation(a, shift_amt){#endsyntax#}. +

+

+ Unlike the other {#syntax#}@shl{#endsyntax#} builtins, {#syntax#}shift_amt{#endsyntax#} does not need to be a {#syntax#}Log2T{#endsyntax#}, because saturated overshifting is well defined. +

+ {#see_also|@shlExact|@shrExact|@shlWithOverflow#} {#header_close#} {#header_open|@shrExact#} @@ -8395,7 +8436,7 @@ test "@setRuntimeSafety" { The type of {#syntax#}shift_amt{#endsyntax#} is an unsigned integer with {#syntax#}log2(T.bit_count){#endsyntax#} bits. This is because {#syntax#}shift_amt >= T.bit_count{#endsyntax#} is undefined behavior.

- {#see_also|@shlExact|@shlWithOverflow#} + {#see_also|@shlExact|@shlWithOverflow|@shlWithSaturation#} {#header_close#} {#header_open|@shuffle#} @@ -8694,6 +8735,17 @@ fn doTheTest() !void { If no overflow or underflow occurs, returns {#syntax#}false{#endsyntax#}.

{#header_close#} + + {#header_open|@subWithSaturation#} +
{#syntax#}@subWithSaturation(a: T, b: T) T{#endsyntax#}
+

+ Returns {#syntax#}a - b{#endsyntax#}. The result will be clamped between the type maximum and minimum. +

+

+ Once Saturating arithmetic + is completed, the syntax {#syntax#}a -| b{#endsyntax#} will be equivalent to calling {#syntax#}@subWithSaturation(a, b){#endsyntax#}. +

+ {#header_close#} {#header_open|@tagName#}
{#syntax#}@tagName(value: anytype) [:0]const u8{#endsyntax#}
diff --git a/src/AstGen.zig b/src/AstGen.zig index a51dd38f8c..75f506096f 100644 --- a/src/AstGen.zig +++ b/src/AstGen.zig @@ -7301,6 +7301,11 @@ fn builtinCall( return rvalue(gz, rl, result, node); }, + .add_with_saturation => return saturatingArithmetic(gz, scope, rl, node, params, .add_with_saturation), + .sub_with_saturation => return saturatingArithmetic(gz, scope, rl, node, params, .sub_with_saturation), + .mul_with_saturation => return saturatingArithmetic(gz, scope, rl, node, params, .mul_with_saturation), + .shl_with_saturation => return saturatingArithmetic(gz, scope, rl, node, params, .shl_with_saturation), + .atomic_load => { const int_type = try typeExpr(gz, scope, params[0]); const ptr_type = try gz.add(.{ .tag = .ptr_type_simple, .data = .{ @@ -7693,6 +7698,24 @@ fn overflowArithmetic( return rvalue(gz, rl, result, node); } +fn saturatingArithmetic( + gz: *GenZir, + scope: *Scope, + rl: ResultLoc, + node: ast.Node.Index, + params: []const ast.Node.Index, + tag: Zir.Inst.Extended, +) InnerError!Zir.Inst.Ref { + const lhs = try expr(gz, scope, .none, params[0]); + const rhs = try expr(gz, scope, .none, params[1]); + const result = try gz.addExtendedPayload(tag, Zir.Inst.SaturatingArithmetic{ + .node = gz.nodeIndexToRelative(node), + .lhs = lhs, + .rhs = rhs, + }); + return rvalue(gz, rl, result, node); +} + fn callExpr( gz: *GenZir, scope: *Scope, diff --git a/src/BuiltinFn.zig b/src/BuiltinFn.zig index 8f23ec86d7..e415d27a3a 100644 --- a/src/BuiltinFn.zig +++ b/src/BuiltinFn.zig @@ -2,6 +2,7 @@ const std = @import("std"); pub const Tag = enum { add_with_overflow, + add_with_saturation, align_cast, align_of, as, @@ -65,6 +66,7 @@ pub const Tag = enum { wasm_memory_grow, mod, mul_with_overflow, + mul_with_saturation, panic, pop_count, ptr_cast, @@ -79,10 +81,12 @@ pub const Tag = enum { set_runtime_safety, shl_exact, shl_with_overflow, + shl_with_saturation, shr_exact, shuffle, size_of, splat, + sub_with_saturation, reduce, src, sqrt, @@ -527,6 +531,34 
@@ pub const list = list: { .param_count = 2, }, }, + .{ + "@addWithSaturation", + .{ + .tag = .add_with_saturation, + .param_count = 2, + }, + }, + .{ + "@subWithSaturation", + .{ + .tag = .sub_with_saturation, + .param_count = 2, + }, + }, + .{ + "@mulWithSaturation", + .{ + .tag = .mul_with_saturation, + .param_count = 2, + }, + }, + .{ + "@shlWithSaturation", + .{ + .tag = .shl_with_saturation, + .param_count = 2, + }, + }, .{ "@memcpy", .{ diff --git a/src/Sema.zig b/src/Sema.zig index b05eba18f9..108470ec79 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -570,6 +570,10 @@ fn zirExtended(sema: *Sema, block: *Scope.Block, inst: Zir.Inst.Index) CompileEr .c_define => return sema.zirCDefine( block, extended), .wasm_memory_size => return sema.zirWasmMemorySize( block, extended), .wasm_memory_grow => return sema.zirWasmMemoryGrow( block, extended), + .add_with_saturation=> return sema.zirSatArithmetic( block, extended), + .sub_with_saturation=> return sema.zirSatArithmetic( block, extended), + .mul_with_saturation=> return sema.zirSatArithmetic( block, extended), + .shl_with_saturation=> return sema.zirSatArithmetic( block, extended), // zig fmt: on } } @@ -5691,6 +5695,19 @@ fn zirOverflowArithmetic( return sema.mod.fail(&block.base, src, "TODO implement Sema.zirOverflowArithmetic", .{}); } +fn zirSatArithmetic( + sema: *Sema, + block: *Scope.Block, + extended: Zir.Inst.Extended.InstData, +) CompileError!Air.Inst.Ref { + const tracy = trace(@src()); + defer tracy.end(); + + const extra = sema.code.extraData(Zir.Inst.SaturatingArithmetic, extended.operand).data; + const src: LazySrcLoc = .{ .node_offset = extra.node }; + return sema.mod.fail(&block.base, src, "TODO implement Sema.zirSatArithmetic", .{}); +} + fn analyzeArithmetic( sema: *Sema, block: *Scope.Block, diff --git a/src/Zir.zig b/src/Zir.zig index 2092a7b5e4..e8e79fe1b5 100644 --- a/src/Zir.zig +++ b/src/Zir.zig @@ -1629,6 +1629,22 @@ pub const Inst = struct { wasm_memory_size, /// `operand` is payload 
index to `BinNode`. wasm_memory_grow, + /// Implements the `@addWithSaturation` builtin. + /// `operand` is payload index to `SaturatingArithmetic`. + /// `small` is unused. + add_with_saturation, + /// Implements the `@subWithSaturation` builtin. + /// `operand` is payload index to `SaturatingArithmetic`. + /// `small` is unused. + sub_with_saturation, + /// Implements the `@mulWithSaturation` builtin. + /// `operand` is payload index to `SaturatingArithmetic`. + /// `small` is unused. + mul_with_saturation, + /// Implements the `@shlWithSaturation` builtin. + /// `operand` is payload index to `SaturatingArithmetic`. + /// `small` is unused. + shl_with_saturation, pub const InstData = struct { opcode: Extended, @@ -2751,6 +2767,12 @@ pub const Inst = struct { ptr: Ref, }; + pub const SaturatingArithmetic = struct { + node: i32, + lhs: Ref, + rhs: Ref, + }; + pub const Cmpxchg = struct { ptr: Ref, expected_value: Ref, @@ -3231,6 +3253,11 @@ const Writer = struct { .shl_with_overflow, => try self.writeOverflowArithmetic(stream, extended), + .add_with_saturation, + .sub_with_saturation, + .mul_with_saturation, + .shl_with_saturation, + => try self.writeSaturatingArithmetic(stream, extended), .struct_decl => try self.writeStructDecl(stream, extended), .union_decl => try self.writeUnionDecl(stream, extended), .enum_decl => try self.writeEnumDecl(stream, extended), @@ -3584,6 +3611,18 @@ const Writer = struct { try self.writeSrc(stream, src); } + /// Writes `lhs, rhs) ` followed by the source location of a `SaturatingArithmetic` payload. + fn writeSaturatingArithmetic(self: *Writer, stream: anytype, extended: Inst.Extended.InstData) !void { + const extra = self.code.extraData(Zir.Inst.SaturatingArithmetic, extended.operand).data; + const src: LazySrcLoc = .{ .node_offset = extra.node }; + + try self.writeInstRef(stream, extra.lhs); + try stream.writeAll(", "); + try self.writeInstRef(stream, extra.rhs); + try stream.writeAll(") "); + try self.writeSrc(stream, src); + } + fn writePlNodeCall(self: *Writer, stream: anytype, inst: 
Inst.Index) !void { const inst_data = self.code.instructions.items(.data)[inst].pl_node; const extra = self.code.extraData(Inst.Call, inst_data.payload_index); diff --git a/src/stage1/all_types.hpp b/src/stage1/all_types.hpp index fc3b00d6db..4004199eb6 100644 --- a/src/stage1/all_types.hpp +++ b/src/stage1/all_types.hpp @@ -1802,6 +1802,10 @@ enum BuiltinFnId { BuiltinFnIdReduce, BuiltinFnIdMaximum, BuiltinFnIdMinimum, + BuiltinFnIdSatAdd, + BuiltinFnIdSatSub, + BuiltinFnIdSatMul, + BuiltinFnIdSatShl, }; struct BuiltinFnEntry { @@ -2946,6 +2950,10 @@ enum IrBinOp { IrBinOpArrayMult, IrBinOpMaximum, IrBinOpMinimum, + IrBinOpSatAdd, + IrBinOpSatSub, + IrBinOpSatMul, + IrBinOpSatShl, }; struct Stage1ZirInstBinOp { diff --git a/src/stage1/astgen.cpp b/src/stage1/astgen.cpp index 367fd8ac08..9e5d9da9ee 100644 --- a/src/stage1/astgen.cpp +++ b/src/stage1/astgen.cpp @@ -4704,6 +4704,66 @@ static Stage1ZirInst *astgen_builtin_fn_call(Stage1AstGen *ag, Scope *scope, Ast Stage1ZirInst *bin_op = ir_build_bin_op(ag, scope, node, IrBinOpMaximum, arg0_value, arg1_value, true); return ir_lval_wrap(ag, scope, bin_op, lval, result_loc); } + case BuiltinFnIdSatAdd: + { + AstNode *arg0_node = node->data.fn_call_expr.params.at(0); + Stage1ZirInst *arg0_value = astgen_node(ag, arg0_node, scope); + if (arg0_value == ag->codegen->invalid_inst_src) + return arg0_value; + + AstNode *arg1_node = node->data.fn_call_expr.params.at(1); + Stage1ZirInst *arg1_value = astgen_node(ag, arg1_node, scope); + if (arg1_value == ag->codegen->invalid_inst_src) + return arg1_value; + + Stage1ZirInst *bin_op = ir_build_bin_op(ag, scope, node, IrBinOpSatAdd, arg0_value, arg1_value, true); + return ir_lval_wrap(ag, scope, bin_op, lval, result_loc); + } + case BuiltinFnIdSatSub: + { + AstNode *arg0_node = node->data.fn_call_expr.params.at(0); + Stage1ZirInst *arg0_value = astgen_node(ag, arg0_node, scope); + if (arg0_value == ag->codegen->invalid_inst_src) + return arg0_value; + + AstNode *arg1_node = 
node->data.fn_call_expr.params.at(1); + Stage1ZirInst *arg1_value = astgen_node(ag, arg1_node, scope); + if (arg1_value == ag->codegen->invalid_inst_src) + return arg1_value; + + Stage1ZirInst *bin_op = ir_build_bin_op(ag, scope, node, IrBinOpSatSub, arg0_value, arg1_value, true); + return ir_lval_wrap(ag, scope, bin_op, lval, result_loc); + } + case BuiltinFnIdSatMul: + { + AstNode *arg0_node = node->data.fn_call_expr.params.at(0); + Stage1ZirInst *arg0_value = astgen_node(ag, arg0_node, scope); + if (arg0_value == ag->codegen->invalid_inst_src) + return arg0_value; + + AstNode *arg1_node = node->data.fn_call_expr.params.at(1); + Stage1ZirInst *arg1_value = astgen_node(ag, arg1_node, scope); + if (arg1_value == ag->codegen->invalid_inst_src) + return arg1_value; + + Stage1ZirInst *bin_op = ir_build_bin_op(ag, scope, node, IrBinOpSatMul, arg0_value, arg1_value, true); + return ir_lval_wrap(ag, scope, bin_op, lval, result_loc); + } + case BuiltinFnIdSatShl: + { + AstNode *arg0_node = node->data.fn_call_expr.params.at(0); + Stage1ZirInst *arg0_value = astgen_node(ag, arg0_node, scope); + if (arg0_value == ag->codegen->invalid_inst_src) + return arg0_value; + + AstNode *arg1_node = node->data.fn_call_expr.params.at(1); + Stage1ZirInst *arg1_value = astgen_node(ag, arg1_node, scope); + if (arg1_value == ag->codegen->invalid_inst_src) + return arg1_value; + + Stage1ZirInst *bin_op = ir_build_bin_op(ag, scope, node, IrBinOpSatShl, arg0_value, arg1_value, true); + return ir_lval_wrap(ag, scope, bin_op, lval, result_loc); + } case BuiltinFnIdMemcpy: { AstNode *arg0_node = node->data.fn_call_expr.params.at(0); diff --git a/src/stage1/bigint.cpp b/src/stage1/bigint.cpp index 5c8efad698..3a7f10e699 100644 --- a/src/stage1/bigint.cpp +++ b/src/stage1/bigint.cpp @@ -468,6 +468,84 @@ void bigint_min(BigInt* dest, const BigInt *op1, const BigInt *op2) { } } +/// clamps op within bit_count/signedness boundaries +/// signed bounds are [-2^(bit_count-1)..2^(bit_count-1)-1] +/// 
unsigned bounds are [0..2^bit_count-1] +void bigint_clamp_by_bitcount(BigInt* dest, uint32_t bit_count, bool is_signed) { + // compute the number of bits required to store the value, and use that + // to decide whether to clamp the result + bool is_negative = dest->is_negative; + // to workaround the fact this bits_needed calculation would yield 65 or more for + // all negative numbers, set is_negative to false. this is a cheap way to find + // bits_needed(abs(dest)). + dest->is_negative = false; + // because we've set is_negative to false, we have to account for the extra bit here + // by adding 1 additional bit_needed when (is_negative && !is_signed). + size_t full_bits = dest->digit_count * 64; + size_t leading_zero_count = bigint_clz(dest, full_bits); + size_t bits_needed = full_bits - leading_zero_count + (is_negative && !is_signed); + + bit_count -= is_signed; + if(bits_needed > bit_count || (is_negative && !is_signed)) { // any negative value is below the unsigned minimum, so it always clamps (to 0) + BigInt one; + bigint_init_unsigned(&one, 1); + BigInt bit_count_big; + bigint_init_unsigned(&bit_count_big, bit_count); + + if(is_signed) { + if(is_negative) { + BigInt bound; + bigint_shl(&bound, &one, &bit_count_big); + bigint_deinit(dest); + *dest = bound; + } else { + BigInt bound; + bigint_shl(&bound, &one, &bit_count_big); + BigInt bound_sub_one; + bigint_sub(&bound_sub_one, &bound, &one); + bigint_deinit(&bound); + bigint_deinit(dest); + *dest = bound_sub_one; + } + } else { + if(is_negative) { + bigint_deinit(dest); + bigint_init_unsigned(dest, 0); + return; // skips setting is_negative which would be invalid + } else { + BigInt bound; + bigint_shl(&bound, &one, &bit_count_big); + BigInt bound_sub_one; + bigint_sub(&bound_sub_one, &bound, &one); + bigint_deinit(&bound); + bigint_deinit(dest); + *dest = bound_sub_one; + } + } + } + dest->is_negative = is_negative; +} + +void bigint_add_sat(BigInt* dest, const BigInt *op1, const BigInt *op2, uint32_t bit_count, bool is_signed) { + bigint_add(dest, op1, op2); + bigint_clamp_by_bitcount(dest, bit_count, is_signed); 
+} + +void bigint_sub_sat(BigInt* dest, const BigInt *op1, const BigInt *op2, uint32_t bit_count, bool is_signed) { + bigint_sub(dest, op1, op2); + bigint_clamp_by_bitcount(dest, bit_count, is_signed); +} + +void bigint_mul_sat(BigInt* dest, const BigInt *op1, const BigInt *op2, uint32_t bit_count, bool is_signed) { + bigint_mul(dest, op1, op2); + bigint_clamp_by_bitcount(dest, bit_count, is_signed); +} + +void bigint_shl_sat(BigInt* dest, const BigInt *op1, const BigInt *op2, uint32_t bit_count, bool is_signed) { + bigint_shl(dest, op1, op2); + bigint_clamp_by_bitcount(dest, bit_count, is_signed); +} + void bigint_add(BigInt *dest, const BigInt *op1, const BigInt *op2) { if (op1->digit_count == 0) { return bigint_init_bigint(dest, op2); diff --git a/src/stage1/bigint.hpp b/src/stage1/bigint.hpp index 53e07f9284..7d30fb1689 100644 --- a/src/stage1/bigint.hpp +++ b/src/stage1/bigint.hpp @@ -105,4 +105,8 @@ bool mul_u64_overflow(uint64_t op1, uint64_t op2, uint64_t *result); uint32_t bigint_hash(BigInt const *x); bool bigint_eql(BigInt const *a, BigInt const *b); +void bigint_add_sat(BigInt* dest, const BigInt *op1, const BigInt *op2, uint32_t bit_count, bool is_signed); +void bigint_sub_sat(BigInt* dest, const BigInt *op1, const BigInt *op2, uint32_t bit_count, bool is_signed); +void bigint_mul_sat(BigInt* dest, const BigInt *op1, const BigInt *op2, uint32_t bit_count, bool is_signed); +void bigint_shl_sat(BigInt* dest, const BigInt *op1, const BigInt *op2, uint32_t bit_count, bool is_signed); #endif diff --git a/src/stage1/codegen.cpp b/src/stage1/codegen.cpp index c44081c770..614ed8e26c 100644 --- a/src/stage1/codegen.cpp +++ b/src/stage1/codegen.cpp @@ -3335,6 +3335,46 @@ static LLVMValueRef ir_render_bin_op(CodeGen *g, Stage1Air *executable, } else { zig_unreachable(); } + case IrBinOpSatAdd: + if (scalar_type->id == ZigTypeIdInt) { + if (scalar_type->data.integral.is_signed) { + return ZigLLVMBuildSAddSat(g->builder, op1_value, op2_value, ""); + } else { + 
return ZigLLVMBuildUAddSat(g->builder, op1_value, op2_value, ""); + } + } else { + zig_unreachable(); + } + case IrBinOpSatSub: + if (scalar_type->id == ZigTypeIdInt) { + if (scalar_type->data.integral.is_signed) { + return ZigLLVMBuildSSubSat(g->builder, op1_value, op2_value, ""); + } else { + return ZigLLVMBuildUSubSat(g->builder, op1_value, op2_value, ""); + } + } else { + zig_unreachable(); + } + case IrBinOpSatMul: + if (scalar_type->id == ZigTypeIdInt) { + if (scalar_type->data.integral.is_signed) { + return ZigLLVMBuildSMulFixSat(g->builder, op1_value, op2_value, ""); + } else { + return ZigLLVMBuildUMulFixSat(g->builder, op1_value, op2_value, ""); + } + } else { + zig_unreachable(); + } + case IrBinOpSatShl: + if (scalar_type->id == ZigTypeIdInt) { + if (scalar_type->data.integral.is_signed) { + return ZigLLVMBuildSShlSat(g->builder, op1_value, op2_value, ""); + } else { + return ZigLLVMBuildUShlSat(g->builder, op1_value, op2_value, ""); + } + } else { + zig_unreachable(); + } } zig_unreachable(); } @@ -9096,6 +9136,10 @@ static void define_builtin_fns(CodeGen *g) { create_builtin_fn(g, BuiltinFnIdReduce, "reduce", 2); create_builtin_fn(g, BuiltinFnIdMaximum, "maximum", 2); create_builtin_fn(g, BuiltinFnIdMinimum, "minimum", 2); + create_builtin_fn(g, BuiltinFnIdSatAdd, "addWithSaturation", 2); + create_builtin_fn(g, BuiltinFnIdSatSub, "subWithSaturation", 2); + create_builtin_fn(g, BuiltinFnIdSatMul, "mulWithSaturation", 2); + create_builtin_fn(g, BuiltinFnIdSatShl, "shlWithSaturation", 2); } static const char *bool_to_str(bool b) { diff --git a/src/stage1/ir.cpp b/src/stage1/ir.cpp index 830ce76708..0604c05c46 100644 --- a/src/stage1/ir.cpp +++ b/src/stage1/ir.cpp @@ -9820,6 +9820,34 @@ static ErrorMsg *ir_eval_math_op_scalar(IrAnalyze *ira, Scope *scope, AstNode *s float_min(out_val, op1_val, op2_val); } break; + case IrBinOpSatAdd: + if (is_int) { + bigint_add_sat(&out_val->data.x_bigint, &op1_val->data.x_bigint, &op2_val->data.x_bigint, 
type_entry->data.integral.bit_count, type_entry->data.integral.is_signed); + } else { + zig_unreachable(); + } + break; + case IrBinOpSatSub: + if (is_int) { + bigint_sub_sat(&out_val->data.x_bigint, &op1_val->data.x_bigint, &op2_val->data.x_bigint, type_entry->data.integral.bit_count, type_entry->data.integral.is_signed); + } else { + zig_unreachable(); + } + break; + case IrBinOpSatMul: + if (is_int) { + bigint_mul_sat(&out_val->data.x_bigint, &op1_val->data.x_bigint, &op2_val->data.x_bigint, type_entry->data.integral.bit_count, type_entry->data.integral.is_signed); + } else { + zig_unreachable(); + } + break; + case IrBinOpSatShl: + if (is_int) { + bigint_shl_sat(&out_val->data.x_bigint, &op1_val->data.x_bigint, &op2_val->data.x_bigint, type_entry->data.integral.bit_count, type_entry->data.integral.is_signed); + } else { + zig_unreachable(); + } + break; } if (type_entry->id == ZigTypeIdInt) { @@ -10041,6 +10069,10 @@ static bool ok_float_op(IrBinOp op) { case IrBinOpBitShiftRightExact: case IrBinOpAddWrap: case IrBinOpSubWrap: + case IrBinOpSatAdd: + case IrBinOpSatSub: + case IrBinOpSatMul: + case IrBinOpSatShl: case IrBinOpMultWrap: case IrBinOpArrayCat: case IrBinOpArrayMult: @@ -11014,6 +11046,10 @@ static Stage1AirInst *ir_analyze_instruction_bin_op(IrAnalyze *ira, Stage1ZirIns case IrBinOpRemMod: case IrBinOpMaximum: case IrBinOpMinimum: + case IrBinOpSatAdd: + case IrBinOpSatSub: + case IrBinOpSatMul: + case IrBinOpSatShl: return ir_analyze_bin_op_math(ira, bin_op_instruction); case IrBinOpArrayCat: return ir_analyze_array_cat(ira, bin_op_instruction); diff --git a/src/stage1/ir_print.cpp b/src/stage1/ir_print.cpp index 96e924b768..152221926d 100644 --- a/src/stage1/ir_print.cpp +++ b/src/stage1/ir_print.cpp @@ -737,6 +737,14 @@ static const char *ir_bin_op_id_str(IrBinOp op_id) { return "@maximum"; case IrBinOpMinimum: return "@minimum"; + case IrBinOpSatAdd: + return "@addWithSaturation"; + case IrBinOpSatSub: + return "@subWithSaturation"; + case 
IrBinOpSatMul: + return "@mulWithSaturation"; + case IrBinOpSatShl: + return "@shlWithSaturation"; } zig_unreachable(); } diff --git a/src/zig_llvm.cpp b/src/zig_llvm.cpp index d0bc24ed1b..2089092c7c 100644 --- a/src/zig_llvm.cpp +++ b/src/zig_llvm.cpp @@ -488,6 +488,58 @@ LLVMValueRef ZigLLVMBuildSMin(LLVMBuilderRef B, LLVMValueRef LHS, LLVMValueRef R return wrap(call_inst); } +LLVMValueRef ZigLLVMBuildSAddSat(LLVMBuilderRef B, LLVMValueRef LHS, LLVMValueRef RHS, const char *name) { + CallInst *call_inst = unwrap(B)->CreateBinaryIntrinsic(Intrinsic::sadd_sat, unwrap(LHS), unwrap(RHS), nullptr, name); + return wrap(call_inst); +} + +LLVMValueRef ZigLLVMBuildUAddSat(LLVMBuilderRef B, LLVMValueRef LHS, LLVMValueRef RHS, const char *name) { + CallInst *call_inst = unwrap(B)->CreateBinaryIntrinsic(Intrinsic::uadd_sat, unwrap(LHS), unwrap(RHS), nullptr, name); + return wrap(call_inst); +} + +LLVMValueRef ZigLLVMBuildSSubSat(LLVMBuilderRef B, LLVMValueRef LHS, LLVMValueRef RHS, const char *name) { + CallInst *call_inst = unwrap(B)->CreateBinaryIntrinsic(Intrinsic::ssub_sat, unwrap(LHS), unwrap(RHS), nullptr, name); + return wrap(call_inst); +} + +LLVMValueRef ZigLLVMBuildUSubSat(LLVMBuilderRef B, LLVMValueRef LHS, LLVMValueRef RHS, const char *name) { + CallInst *call_inst = unwrap(B)->CreateBinaryIntrinsic(Intrinsic::usub_sat, unwrap(LHS), unwrap(RHS), nullptr, name); + return wrap(call_inst); +} + +LLVMValueRef ZigLLVMBuildSMulFixSat(LLVMBuilderRef B, LLVMValueRef LHS, LLVMValueRef RHS, const char *name) { + llvm::Type* types[1] = { + unwrap(LHS)->getType(), + }; + // pass scale = 0 as third argument + llvm::Value* values[3] = {unwrap(LHS), unwrap(RHS), unwrap(B)->getInt32(0)}; + + CallInst *call_inst = unwrap(B)->CreateIntrinsic(Intrinsic::smul_fix_sat, types, values, nullptr, name); + return wrap(call_inst); +} + +LLVMValueRef ZigLLVMBuildUMulFixSat(LLVMBuilderRef B, LLVMValueRef LHS, LLVMValueRef RHS, const char *name) { + llvm::Type* types[1] = { + 
unwrap(LHS)->getType(), + }; + // pass scale = 0 as third argument + llvm::Value* values[3] = {unwrap(LHS), unwrap(RHS), unwrap(B)->getInt32(0)}; + + CallInst *call_inst = unwrap(B)->CreateIntrinsic(Intrinsic::umul_fix_sat, types, values, nullptr, name); + return wrap(call_inst); +} + +LLVMValueRef ZigLLVMBuildSShlSat(LLVMBuilderRef B, LLVMValueRef LHS, LLVMValueRef RHS, const char *name) { + CallInst *call_inst = unwrap(B)->CreateBinaryIntrinsic(Intrinsic::sshl_sat, unwrap(LHS), unwrap(RHS), nullptr, name); + return wrap(call_inst); +} + +LLVMValueRef ZigLLVMBuildUShlSat(LLVMBuilderRef B, LLVMValueRef LHS, LLVMValueRef RHS, const char *name) { + CallInst *call_inst = unwrap(B)->CreateBinaryIntrinsic(Intrinsic::ushl_sat, unwrap(LHS), unwrap(RHS), nullptr, name); + return wrap(call_inst); +} + void ZigLLVMFnSetSubprogram(LLVMValueRef fn, ZigLLVMDISubprogram *subprogram) { assert( isa(unwrap(fn)) ); Function *unwrapped_function = reinterpret_cast(unwrap(fn)); diff --git a/src/zig_llvm.h b/src/zig_llvm.h index f49c2662c6..91407b7f12 100644 --- a/src/zig_llvm.h +++ b/src/zig_llvm.h @@ -136,6 +136,15 @@ ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildUMax(LLVMBuilderRef builder, LLVMValueRef ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildUMin(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS, const char* name); ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildSMax(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS, const char* name); ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildSMin(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS, const char* name); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildUAddSat(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS, const char* name); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildSAddSat(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS, const char* name); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildUSubSat(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS, const char* name); +ZIG_EXTERN_C LLVMValueRef 
ZigLLVMBuildSSubSat(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS, const char* name); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildSMulFixSat(LLVMBuilderRef B, LLVMValueRef LHS, LLVMValueRef RHS, const char *name); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildUMulFixSat(LLVMBuilderRef B, LLVMValueRef LHS, LLVMValueRef RHS, const char *name); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildUShlSat(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS, const char* name); +ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildSShlSat(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS, const char* name); + ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildCmpXchg(LLVMBuilderRef builder, LLVMValueRef ptr, LLVMValueRef cmp, LLVMValueRef new_val, LLVMAtomicOrdering success_ordering, diff --git a/test/behavior.zig b/test/behavior.zig index f459e23f7b..7028d6cbf6 100644 --- a/test/behavior.zig +++ b/test/behavior.zig @@ -125,6 +125,7 @@ test { _ = @import("behavior/pub_enum.zig"); _ = @import("behavior/ref_var_in_if_after_if_2nd_switch_prong.zig"); _ = @import("behavior/reflection.zig"); + _ = @import("behavior/saturating_arithmetic.zig"); _ = @import("behavior/shuffle.zig"); _ = @import("behavior/select.zig"); _ = @import("behavior/sizeof_and_typeof.zig"); diff --git a/test/behavior/saturating_arithmetic.zig b/test/behavior/saturating_arithmetic.zig new file mode 100644 index 0000000000..553e9ff21a --- /dev/null +++ b/test/behavior/saturating_arithmetic.zig @@ -0,0 +1,139 @@ +const std = @import("std"); +const builtin = @import("builtin"); +const mem = std.mem; +const expectEqual = std.testing.expectEqual; +const Vector = std.meta.Vector; +const minInt = std.math.minInt; +const maxInt = std.math.maxInt; + +const Op = enum { add, sub, mul, shl }; +fn testSaturatingOp(comptime op: Op, comptime T: type, test_data: [3]T) !void { + const a = test_data[0]; + const b = test_data[1]; + const expected = test_data[2]; + const actual = switch (op) { + .add => @addWithSaturation(a, b), + .sub => 
@subWithSaturation(a, b), + .mul => @mulWithSaturation(a, b), + .shl => @shlWithSaturation(a, b), + }; + try expectEqual(expected, actual); +} + +test "@addWithSaturation" { + const S = struct { + fn doTheTest() !void { + // .{a, b, expected a+b} + try testSaturatingOp(.add, i8, .{ -3, 10, 7 }); + try testSaturatingOp(.add, i8, .{ -128, -128, -128 }); + try testSaturatingOp(.add, i2, .{ 1, 1, 1 }); + try testSaturatingOp(.add, i64, .{ maxInt(i64), 1, maxInt(i64) }); + try testSaturatingOp(.add, i128, .{ maxInt(i128), -maxInt(i128), 0 }); + try testSaturatingOp(.add, i128, .{ minInt(i128), maxInt(i128), -1 }); + try testSaturatingOp(.add, i8, .{ 127, 127, 127 }); + try testSaturatingOp(.add, u8, .{ 3, 10, 13 }); + try testSaturatingOp(.add, u8, .{ 255, 255, 255 }); + try testSaturatingOp(.add, u2, .{ 3, 2, 3 }); + try testSaturatingOp(.add, u3, .{ 7, 1, 7 }); + try testSaturatingOp(.add, u128, .{ maxInt(u128), 1, maxInt(u128) }); + + const u8x3 = std.meta.Vector(3, u8); + try expectEqual(u8x3{ 255, 255, 255 }, @addWithSaturation( + u8x3{ 255, 254, 1 }, + u8x3{ 1, 2, 255 }, + )); + const i8x3 = std.meta.Vector(3, i8); + try expectEqual(i8x3{ 127, 127, 127 }, @addWithSaturation( + i8x3{ 127, 126, 1 }, + i8x3{ 1, 2, 127 }, + )); + } + }; + try S.doTheTest(); + comptime try S.doTheTest(); +} + +test "@subWithSaturation" { + const S = struct { + fn doTheTest() !void { + // .{a, b, expected a-b} + try testSaturatingOp(.sub, i8, .{ -3, 10, -13 }); + try testSaturatingOp(.sub, i8, .{ -128, -128, 0 }); + try testSaturatingOp(.sub, i8, .{ -1, 127, -128 }); + try testSaturatingOp(.sub, i64, .{ minInt(i64), 1, minInt(i64) }); + try testSaturatingOp(.sub, i128, .{ maxInt(i128), -1, maxInt(i128) }); + try testSaturatingOp(.sub, i128, .{ minInt(i128), -maxInt(i128), -1 }); + try testSaturatingOp(.sub, u8, .{ 10, 3, 7 }); + try testSaturatingOp(.sub, u8, .{ 0, 255, 0 }); + try testSaturatingOp(.sub, u5, .{ 0, 31, 0 }); + try testSaturatingOp(.sub, u128, .{ 0, maxInt(u128), 0 }); + 
+ const u8x3 = std.meta.Vector(3, u8); + try expectEqual(u8x3{ 0, 0, 0 }, @subWithSaturation( + u8x3{ 0, 0, 0 }, + u8x3{ 255, 255, 255 }, + )); + } + }; + try S.doTheTest(); + comptime try S.doTheTest(); +} + +test "@mulWithSaturation" { + // TODO: once #9660 has been solved, remove this line + if (std.builtin.target.cpu.arch == .wasm32) return error.SkipZigTest; + + const S = struct { + fn doTheTest() !void { + // .{a, b, expected a*b} + try testSaturatingOp(.mul, i8, .{ -3, 10, -30 }); + try testSaturatingOp(.mul, i4, .{ 2, 4, 7 }); + try testSaturatingOp(.mul, i8, .{ 2, 127, 127 }); + // TODO: uncomment these after #9643 has been solved - this should happen at 0.9.0/llvm-13 release + // try testSaturatingOp(.mul, i8, .{ -128, -128, 127 }); + // try testSaturatingOp(.mul, i8, .{ maxInt(i8), maxInt(i8), maxInt(i8) }); + try testSaturatingOp(.mul, i16, .{ maxInt(i16), -1, minInt(i16) + 1 }); + try testSaturatingOp(.mul, i128, .{ maxInt(i128), -1, minInt(i128) + 1 }); + try testSaturatingOp(.mul, i128, .{ minInt(i128), -1, maxInt(i128) }); + try testSaturatingOp(.mul, u8, .{ 10, 3, 30 }); + try testSaturatingOp(.mul, u8, .{ 2, 255, 255 }); + try testSaturatingOp(.mul, u128, .{ maxInt(u128), maxInt(u128), maxInt(u128) }); + + const u8x3 = std.meta.Vector(3, u8); + try expectEqual(u8x3{ 255, 255, 255 }, @mulWithSaturation( + u8x3{ 2, 2, 2 }, + u8x3{ 255, 255, 255 }, + )); + } + }; + + try S.doTheTest(); + comptime try S.doTheTest(); +} + +test "@shlWithSaturation" { + const S = struct { + fn doTheTest() !void { + // .{a, b, expected a< 64 bits on wasm due to miscompilation / wasmtime ci error + try testSaturatingOp(.shl, i128, .{ maxInt(i128), 64, maxInt(i128) }); + try testSaturatingOp(.shl, u128, .{ maxInt(u128), 64, maxInt(u128) }); + } + try testSaturatingOp(.shl, u8, .{ 1, 2, 4 }); + try testSaturatingOp(.shl, u8, .{ 255, 1, 255 }); + + const u8x3 = std.meta.Vector(3, u8); + try expectEqual(u8x3{ 255, 255, 255 }, @shlWithSaturation( + u8x3{ 255, 255, 255 }, + 
u8x3{ 1, 1, 1 }, + )); + } + }; + try S.doTheTest(); + comptime try S.doTheTest(); +} diff --git a/test/compile_errors.zig b/test/compile_errors.zig index 530ef6dcd6..4c55e39ac7 100644 --- a/test/compile_errors.zig +++ b/test/compile_errors.zig @@ -8838,4 +8838,12 @@ pub fn addCases(ctx: *TestContext) !void { "tmp.zig:2:9: note: declared mutable here", "tmp.zig:3:12: note: crosses namespace boundary here", }); + + ctx.objErrStage1("Issue #9619: saturating arithmetic builtins should fail to compile when given floats", + \\pub fn main() !void { + \\ _ = @addWithSaturation(@as(f32, 1.0), @as(f32, 1.0)); + \\} + , &[_][]const u8{ + "error: invalid operands to binary expression: 'f32' and 'f32'", + }); }