diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig index c395c26437..f2710574d7 100644 --- a/src/arch/wasm/CodeGen.zig +++ b/src/arch/wasm/CodeGen.zig @@ -1432,8 +1432,10 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue { .const_ty => unreachable, .add => self.airBinOp(inst, .add), + .add_sat => self.airSatBinOp(inst, .add), .addwrap => self.airWrapBinOp(inst, .add), .sub => self.airBinOp(inst, .sub), + .sub_sat => self.airSatBinOp(inst, .sub), .subwrap => self.airWrapBinOp(inst, .sub), .mul => self.airBinOp(inst, .mul), .mulwrap => self.airWrapBinOp(inst, .mul), @@ -1452,6 +1454,7 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue { .rem => self.airBinOp(inst, .rem), .shl => self.airWrapBinOp(inst, .shl), .shl_exact => self.airBinOp(inst, .shl), + .shl_sat => self.airShlSat(inst), .shr, .shr_exact => self.airBinOp(inst, .shr), .xor => self.airBinOp(inst, .xor), .max => self.airMaxMin(inst, .max), @@ -1583,12 +1586,9 @@ fn genInst(self: *Self, inst: Air.Inst.Index) !WValue { .memcpy => self.airMemcpy(inst), - .add_sat, - .sub_sat, .mul_sat, .mod, .assembly, - .shl_sat, .ret_addr, .frame_addr, .bit_reverse, @@ -4878,3 +4878,216 @@ fn airCeilFloorTrunc(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValu try self.addLabel(.local_set, result.local); return result; } + +fn airSatBinOp(self: *Self, inst: Air.Inst.Index, op: Op) InnerError!WValue { + assert(op == .add or op == .sub); + if (self.liveness.isUnused(inst)) return WValue{ .none = {} }; + + const bin_op = self.air.instructions.items(.data)[inst].bin_op; + const ty = self.air.typeOfIndex(inst); + const lhs = try self.resolveInst(bin_op.lhs); + const rhs = try self.resolveInst(bin_op.rhs); + + const int_info = ty.intInfo(self.target); + const is_signed = int_info.signedness == .signed; + + if (int_info.bits > 64) { + return self.fail("TODO: saturating arithmetic for integers with bitsize '{d}'", .{int_info.bits}); + } + + if (is_signed) { + return signedSat(self, lhs, rhs, ty, op); + } + + const wasm_bits = toWasmBits(int_info.bits).?; + const bin_result = try self.binOp(lhs, rhs, ty, op); + if (wasm_bits != int_info.bits and op == .add) { + const val: u64 = @intCast(u64, (@as(u65, 1) << @intCast(u7, int_info.bits)) - 1); + const imm_val = switch (wasm_bits) { + 32 => WValue{ .imm32 = @intCast(u32, val) }, + 64 => WValue{ .imm64 = val }, + else => unreachable, + }; + + const cmp_result = try self.cmp(bin_result, imm_val, ty, .lt); + try self.emitWValue(bin_result); + try self.emitWValue(imm_val); + try self.emitWValue(cmp_result); + } else { + const cmp_result = try self.cmp(bin_result, lhs, ty, if (op == .add) .lt else .gt); + switch (wasm_bits) { + 32 => try self.addImm32(if (op == .add) @as(i32, -1) else 0), + 64 => try self.addImm64(if (op == .add) @bitCast(u64, @as(i64, -1)) else 0), + else => unreachable, + } + try self.emitWValue(bin_result); + try self.emitWValue(cmp_result); + } + + try self.addTag(.select); + const result = try self.allocLocal(ty); + try self.addLabel(.local_set, result.local); + return result; +} + +fn signedSat(self: *Self, lhs_operand: WValue, rhs_operand: WValue, ty: Type, op: Op) InnerError!WValue { + const int_info = ty.intInfo(self.target); + const wasm_bits = toWasmBits(int_info.bits).?; + const is_wasm_bits = wasm_bits == int_info.bits; + + const lhs = if (!is_wasm_bits) try self.signAbsValue(lhs_operand, ty) else lhs_operand; + const rhs = if (!is_wasm_bits) try self.signAbsValue(rhs_operand, ty) else rhs_operand; + + const max_val: u64 = @intCast(u64, (@as(u65, 1) << @intCast(u7, int_info.bits - 1)) - 1); + const min_val: i64 = (-@intCast(i64, @intCast(u63, max_val))) - 1; + const max_wvalue = switch (wasm_bits) { + 32 => WValue{ .imm32 = @truncate(u32, max_val) }, + 64 => WValue{ .imm64 = max_val }, + else => unreachable, + }; + const min_wvalue = switch (wasm_bits) { + 32 => WValue{ .imm32 = @bitCast(u32, @truncate(i32, min_val)) }, + 64 => WValue{ .imm64 = @bitCast(u64, min_val) }, + else => unreachable, + }; + + const bin_result = try self.binOp(lhs, rhs, ty, op); + if (!is_wasm_bits) { + const cmp_result_lt = try self.cmp(bin_result, max_wvalue, ty, .lt); + try self.emitWValue(bin_result); + try self.emitWValue(max_wvalue); + try self.emitWValue(cmp_result_lt); + try self.addTag(.select); + try self.addLabel(.local_set, bin_result.local); // re-use local + + const cmp_result_gt = try self.cmp(bin_result, min_wvalue, ty, .gt); + try self.emitWValue(bin_result); + try self.emitWValue(min_wvalue); + try self.emitWValue(cmp_result_gt); + try self.addTag(.select); + try self.addLabel(.local_set, bin_result.local); // re-use local + return self.wrapOperand(bin_result, ty); + } else { + const zero = switch (wasm_bits) { + 32 => WValue{ .imm32 = 0 }, + 64 => WValue{ .imm64 = 0 }, + else => unreachable, + }; + const cmp_bin_result = try self.cmp(bin_result, lhs, ty, .lt); + const cmp_zero_result = try self.cmp(rhs, zero, ty, if (op == .add) .lt else .gt); + const xor = try self.binOp(cmp_zero_result, cmp_bin_result, Type.u32, .xor); // comparisons always return i32, so provide u32 as type to xor. + const cmp_bin_zero_result = try self.cmp(bin_result, zero, ty, .lt); + try self.emitWValue(max_wvalue); + try self.emitWValue(min_wvalue); + try self.emitWValue(cmp_bin_zero_result); + try self.addTag(.select); + try self.emitWValue(bin_result); + try self.emitWValue(xor); + try self.addTag(.select); + try self.addLabel(.local_set, bin_result.local); // re-use local + return bin_result; + } +} + +fn airShlSat(self: *Self, inst: Air.Inst.Index) InnerError!WValue { + if (self.liveness.isUnused(inst)) return WValue{ .none = {} }; + + const bin_op = self.air.instructions.items(.data)[inst].bin_op; + const ty = self.air.typeOfIndex(inst); + const int_info = ty.intInfo(self.target); + const is_signed = int_info.signedness == .signed; + if (int_info.bits > 64) { + return self.fail("TODO: Saturating shifting left for integers with bitsize '{d}'", .{int_info.bits}); + } + + const lhs = try self.resolveInst(bin_op.lhs); + const rhs = try self.resolveInst(bin_op.rhs); + const wasm_bits = toWasmBits(int_info.bits).?; + const result = try self.allocLocal(ty); + + if (wasm_bits == int_info.bits) { + const shl = try self.binOp(lhs, rhs, ty, .shl); + const shr = try self.binOp(shl, rhs, ty, .shr); + const cmp_result = try self.cmp(lhs, shr, ty, .neq); + + switch (wasm_bits) { + 32 => blk: { + if (!is_signed) { + try self.addImm32(-1); + break :blk; + } + const less_than_zero = try self.cmp(lhs, .{ .imm32 = 0 }, ty, .lt); + try self.addImm32(std.math.minInt(i32)); + try self.addImm32(std.math.maxInt(i32)); + try self.emitWValue(less_than_zero); + try self.addTag(.select); + }, + 64 => blk: { + if (!is_signed) { + try self.addImm64(@bitCast(u64, @as(i64, -1))); + break :blk; + } + const less_than_zero = try self.cmp(lhs, .{ .imm64 = 0 }, ty, .lt); + try self.addImm64(@bitCast(u64, @as(i64, std.math.minInt(i64)))); + try self.addImm64(@bitCast(u64, @as(i64, std.math.maxInt(i64)))); + try self.emitWValue(less_than_zero); + try self.addTag(.select); + }, + else => unreachable, + } + try self.emitWValue(shl); + try self.emitWValue(cmp_result); + try self.addTag(.select); + try self.addLabel(.local_set, result.local); + return result; + } else { + const shift_size = wasm_bits - int_info.bits; + const shift_value = switch (wasm_bits) { + 32 => WValue{ .imm32 = shift_size }, + 64 => WValue{ .imm64 = shift_size }, + else => unreachable, + }; + + const shl_res = try self.binOp(lhs, shift_value, ty, .shl); + const shl = try self.binOp(shl_res, rhs, ty, .shl); + const shr = try self.binOp(shl, rhs, ty, .shr); + const cmp_result = try self.cmp(shl_res, shr, ty, .neq); + + switch (wasm_bits) { + 32 => blk: { + if (!is_signed) { + try self.addImm32(-1); + break :blk; + } + + const less_than_zero = try self.cmp(shl_res, .{ .imm32 = 0 }, ty, .lt); + try self.addImm32(std.math.minInt(i32)); + try self.addImm32(std.math.maxInt(i32)); + try self.emitWValue(less_than_zero); + try self.addTag(.select); + }, + 64 => blk: { + if (!is_signed) { + try self.addImm64(@bitCast(u64, @as(i64, -1))); + break :blk; + } + + const less_than_zero = try self.cmp(shl_res, .{ .imm64 = 0 }, ty, .lt); + try self.addImm64(@bitCast(u64, @as(i64, std.math.minInt(i64)))); + try self.addImm64(@bitCast(u64, @as(i64, std.math.maxInt(i64)))); + try self.emitWValue(less_than_zero); + try self.addTag(.select); + }, + else => unreachable, + } + try self.emitWValue(shl); + try self.emitWValue(cmp_result); + try self.addTag(.select); + try self.addLabel(.local_set, result.local); + const shift_result = try self.binOp(result, shift_value, ty, .shr); + if (is_signed) { + return self.wrapOperand(shift_result, ty); + } + return shift_result; + } +} diff --git a/test/behavior/saturating_arithmetic.zig b/test/behavior/saturating_arithmetic.zig index 33266f711e..1790fe4505 100644 --- a/test/behavior/saturating_arithmetic.zig +++ b/test/behavior/saturating_arithmetic.zig @@ -5,7 +5,6 @@ const maxInt = std.math.maxInt; const expect = std.testing.expect; test "saturating add" { - if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO @@ -20,8 +19,6 @@ test "saturating add" { try testSatAdd(i2, 1, -1, 0); try testSatAdd(i2, -1, -1, -2); try testSatAdd(i64, maxInt(i64), 1, maxInt(i64)); - try testSatAdd(i128, maxInt(i128), -maxInt(i128), 0); - try testSatAdd(i128, minInt(i128), maxInt(i128), -1); try testSatAdd(i8, 127, 127, 127); try testSatAdd(u2, 0, 0, 0); try testSatAdd(u2, 0, 1, 1); @@ -29,7 +26,6 @@ test "saturating add" { try testSatAdd(u8, 255, 255, 255); try testSatAdd(u2, 3, 2, 3); try testSatAdd(u3, 7, 1, 7); - try testSatAdd(u128, maxInt(u128), 1, maxInt(u128)); } fn testSatAdd(comptime T: type, lhs: T, rhs: T, expected: T) !void { @@ -54,12 +50,36 @@ test "saturating add" { comptime try S.testSatAdd(comptime_int, 7, -593423721213448152027139550640105366508, -593423721213448152027139550640105366501); } -test "saturating subtraction" { +test "saturating add 128bit" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO + const S = struct { + fn doTheTest() !void { + try testSatAdd(i128, maxInt(i128), -maxInt(i128), 0); + try testSatAdd(i128, minInt(i128), maxInt(i128), -1); + try testSatAdd(u128, maxInt(u128), 1, maxInt(u128)); + } + fn testSatAdd(comptime T: type, lhs: T, rhs: T, expected: T) !void { + try expect((lhs +| rhs) == expected); + + var x = lhs; + x +|= rhs; + try expect(x == expected); + } + }; + + try S.doTheTest(); + comptime try S.doTheTest(); +} + +test "saturating subtraction" { + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO const S = struct { fn doTheTest() !void { @@ -71,14 +91,11 @@ test "saturating subtraction" { try testSatSub(i2, 1, -1, 1); try testSatSub(i2, -2, -2, 0); try testSatSub(i64, minInt(i64), 1, minInt(i64)); - try testSatSub(i128, maxInt(i128), -1, maxInt(i128)); - try testSatSub(i128, minInt(i128), -maxInt(i128), -1); try testSatSub(u2, 0, 0, 0); try testSatSub(u2, 0, 1, 0); try testSatSub(u5, 0, 31, 0); try testSatSub(u8, 10, 3, 7); try testSatSub(u8, 0, 255, 0); - try testSatSub(u128, 0, maxInt(u128), 0); } fn testSatSub(comptime T: type, lhs: T, rhs: T, expected: T) !void { @@ -103,6 +120,33 @@ test "saturating subtraction" { comptime try S.testSatSub(comptime_int, 7, -593423721213448152027139550640105366508, 593423721213448152027139550640105366515); } +test "saturating subtraction 128bit" { + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO + + const S = struct { + fn doTheTest() !void { + try testSatSub(i128, maxInt(i128), -1, maxInt(i128)); + try testSatSub(i128, minInt(i128), -maxInt(i128), -1); + try testSatSub(u128, 0, maxInt(u128), 0); + } + + fn testSatSub(comptime T: type, lhs: T, rhs: T, expected: T) !void { + try expect((lhs -| rhs) == expected); + + var x = lhs; + x -|= rhs; + try expect(x == expected); + } + }; + + try S.doTheTest(); + comptime try S.doTheTest(); +} + test "saturating multiplication" { if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO @@ -153,7 +197,6 @@ test "saturating multiplication" { } test "saturating shift-left" { - if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO @@ -193,7 +236,6 @@ test "saturating shift-left" { } test "saturating shl uses the LHS type" { - if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO