zig

fork of https://codeberg.org/ziglang/zig
Log | Files | Refs | README | LICENSE

commit cbe468a787e93d57ca64799e715ea7493f4a6ebb (tree)
parent b6aebc41177dec5d7dddcc7594e5435d54f17963
Author: Andrew Kelley <andrew@ziglang.org>
Date:   Sat, 11 Apr 2026 16:44:41 +0200

Merge pull request 'Sema: allow @round, @floor, @ceil, and @trunc to coerce to integer types' (#30906) from adria/zig:sema-rounding-casts into master

Reviewed-on: https://codeberg.org/ziglang/zig/pulls/30906

Diffstat:
Mlib/std/zig/AstGen.zig | 45+++++++++++++++++++++++++++++++++++++++++----
Mlib/std/zig/Zir.zig | 17+++++++++++++++++
Msrc/Sema.zig | 184++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----
Msrc/print_zir.zig | 12++++++++++++
Mtest/behavior/cast.zig | 64++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
5 files changed, 307 insertions(+), 15 deletions(-)

diff --git a/lib/std/zig/AstGen.zig b/lib/std/zig/AstGen.zig @@ -9266,10 +9266,10 @@ fn builtinCall( .log => return floatUnOp(gz, scope, ri, node, params[0], .log), .log2 => return floatUnOp(gz, scope, ri, node, params[0], .log2), .log10 => return floatUnOp(gz, scope, ri, node, params[0], .log10), - .floor => return floatUnOp(gz, scope, ri, node, params[0], .floor), - .ceil => return floatUnOp(gz, scope, ri, node, params[0], .ceil), - .trunc => return floatUnOp(gz, scope, ri, node, params[0], .trunc), - .round => return floatUnOp(gz, scope, ri, node, params[0], .round), + .floor => return floatRoundOp(gz, scope, ri, node, params[0], .floor), + .ceil => return floatRoundOp(gz, scope, ri, node, params[0], .ceil), + .trunc => return floatRoundOp(gz, scope, ri, node, params[0], .trunc), + .round => return floatRoundOp(gz, scope, ri, node, params[0], .round), .int_from_float => return typeCast(gz, scope, ri, node, params[0], .int_from_float, builtin_name), .float_from_int => return typeCast(gz, scope, ri, node, params[0], .float_from_int, builtin_name), @@ -9819,6 +9819,43 @@ fn simpleUnOp( return rvalue(gz, ri, result, node); } +fn floatRoundOp( + gz: *GenZir, + scope: *Scope, + ri: ResultInfo, + node: Ast.Node.Index, + operand_node: Ast.Node.Index, + float_tag: Zir.Inst.Tag, +) InnerError!Zir.Inst.Ref { + if (try ri.rl.resultType(gz, node)) |dest_type| { + const cursor = maybeAdvanceSourceCursorToMainToken(gz, node); + + const operand_ty_inst = try gz.addExtendedPayload(.round_op_ty, Zir.Inst.UnNode{ + .node = gz.nodeIndexToRelative(node), + .operand = dest_type, + }); + + const operand = try expr(gz, scope, .{ .rl = .{ .coerced_ty = operand_ty_inst } }, operand_node); + + try emitDbgStmt(gz, cursor); + const round_op: Zir.Inst.RoundOp = switch (float_tag) { + .round => .round, + .floor => .floor, + .ceil => .ceil, + .trunc => .trunc, + else => unreachable, + }; + const result = try gz.addExtendedPayloadSmall(.round_op, @intFromEnum(round_op), Zir.Inst.BinNode{ + .node = gz.nodeIndexToRelative(node), + .lhs = dest_type, + .rhs = operand, + }); + return rvalue(gz, ri, result, node); + } else { + return floatUnOp(gz, scope, ri, node, operand_node, float_tag); + } +} + fn floatUnOp( gz: *GenZir, scope: *Scope, diff --git a/lib/std/zig/Zir.zig b/lib/std/zig/Zir.zig @@ -2009,6 +2009,14 @@ pub const Inst = struct { /// `operand` is payload index to `BinNode`. /// `small` is unused. shl_with_overflow, + /// `@round`, `@floor`, `@ceil`, or `@trunc`, with a result type. + /// `operand` is payload index to `BinNode`. + /// `small` is a `RoundOp` representing the specific operation being performed. + round_op, + /// Returns the type for the operand of a rounding op. + /// `operand` is `UnNode`. + /// `small` is unused. + round_op_ty, /// `operand` is payload index to `UnNode`. c_undef, /// `operand` is payload index to `UnNode`. @@ -3233,6 +3241,13 @@ pub const Inst = struct { string_to_union_field_attrs, }; + pub const RoundOp = enum(u16) { + round, + floor, + ceil, + trunc, + }; + pub const UnNode = struct { node: Ast.Node.Offset, operand: Ref, @@ -4344,6 +4359,7 @@ fn findTrackableInner( .sub_with_overflow, .mul_with_overflow, .shl_with_overflow, + .round_op, .c_undef, .c_include, .c_define, @@ -4385,6 +4401,7 @@ fn findTrackableInner( .dbg_empty_stmt, .astgen_error, .float_op_result_ty, + .round_op_ty, => return, // `@TypeOf` has a body. diff --git a/src/Sema.zig b/src/Sema.zig @@ -1405,6 +1405,8 @@ fn analyzeBodyInner( .@"asm" => try sema.zirAsm( block, extended, false), .asm_expr => try sema.zirAsm( block, extended, true), .typeof_peer => try sema.zirTypeofPeer( block, extended, inst), + .round_op => try sema.zirRoundCast( block, extended), + .round_op_ty => try sema.zirRoundOpType( block, extended), .compile_log => try sema.zirCompileLog( block, extended), .min_multi => try sema.zirMinMaxMulti( block, extended, .min), .max_multi => try sema.zirMinMaxMulti( block, extended, .max), @@ -19685,21 +19687,16 @@ fn maybeConstantUnaryMath( return null; } -fn zirUnaryMath( +fn unaryMath( sema: *Sema, block: *Block, - inst: Zir.Inst.Index, + operand_src: LazySrcLoc, + operand: Air.Inst.Ref, air_tag: Air.Inst.Tag, comptime eval: fn (Value, Type, Allocator, Zcu.PerThread) Allocator.Error!Value, ) CompileError!Air.Inst.Ref { - const tracy = trace(@src()); - defer tracy.end(); - const pt = sema.pt; const zcu = pt.zcu; - const inst_data = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node; - const operand = sema.resolveInst(inst_data.operand); - const operand_src = block.builtinCallArgSrc(inst_data.src_node, 0); const operand_ty = sema.typeOf(operand); const scalar_ty = operand_ty.scalarType(zcu); @@ -19719,6 +19716,23 @@ fn zirUnaryMath( }; } +fn zirUnaryMath( + sema: *Sema, + block: *Block, + inst: Zir.Inst.Index, + air_tag: Air.Inst.Tag, + comptime eval: fn (Value, Type, Allocator, Zcu.PerThread) Allocator.Error!Value, +) CompileError!Air.Inst.Ref { + const tracy = trace(@src()); + defer tracy.end(); + + const inst_data = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node; + const operand = sema.resolveInst(inst_data.operand); + const operand_src = block.builtinCallArgSrc(inst_data.src_node, 0); + + return sema.unaryMath(block, operand_src, operand, air_tag, eval); +} + fn zirTagName(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { const inst_data = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node; const operand_src = block.builtinCallArgSrc(inst_data.src_node, 0); @@ -20900,6 +20914,130 @@ fn zirIntFromFloat(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileErro }, dest_ty, operand); } +fn zirRoundCast( + sema: *Sema, + block: *Block, + extended: Zir.Inst.Extended.InstData, +) CompileError!Air.Inst.Ref { + const pt = sema.pt; + const zcu = pt.zcu; + const extra = sema.code.extraData(Zir.Inst.BinNode, extended.operand).data; + const src = block.nodeOffset(extra.node); + const operand_src = block.builtinCallArgSrc(extra.node, 0); + + const operand = sema.resolveInst(extra.rhs); + + const round_op: Zir.Inst.RoundOp = @enumFromInt(extended.small); + const mode: IntFromFloatMode = switch (round_op) { + .round => .round, + .floor => .floor, + .ceil => .ceil, + .trunc => .truncate, + }; + + const dest_ty = (try sema.resolveTypeOrPoison(block, src, extra.lhs) orelse switch (mode) { + // zig fmt: off + .round => return sema.unaryMath(block, operand_src, operand, .round, Value.round), + .floor => return sema.unaryMath(block, operand_src, operand, .floor, Value.floor), + .ceil => return sema.unaryMath(block, operand_src, operand, .ceil, Value.ceil), + .truncate => return sema.unaryMath(block, operand_src, operand, .trunc_float, Value.trunc), + // zig fmt: on + .exact => unreachable, + }).optEuBaseType(zcu); + + const operand_ty = sema.typeOf(operand); + + try sema.checkVectorizableBinaryOperands(block, operand_src, dest_ty, operand_ty, src, operand_src); + + const dest_scalar_ty = dest_ty.scalarType(zcu); + const operand_scalar_ty = operand_ty.scalarType(zcu); + + switch (operand_scalar_ty.zigTypeTag(zcu)) { + .comptime_float, .float => {}, + else => return sema.fail( + block, + operand_src, + "expected float or vector type, found '{f}'", + .{operand_ty.fmt(pt)}, + ), + } + + switch (dest_scalar_ty.zigTypeTag(zcu)) { + .float, .comptime_float => { + const coerced_operand = try sema.coerce(block, dest_ty, operand, operand_src); + + const result_ref = switch (mode) { + .round => try sema.maybeConstantUnaryMath(coerced_operand, dest_ty, Value.round), + .floor => try sema.maybeConstantUnaryMath(coerced_operand, dest_ty, Value.floor), + .ceil => try sema.maybeConstantUnaryMath(coerced_operand, dest_ty, Value.ceil), + .truncate => try sema.maybeConstantUnaryMath(coerced_operand, dest_ty, Value.trunc), + .exact => unreachable, + }; + + if (result_ref) |ref| return ref; + + const air_tag: Air.Inst.Tag = switch (mode) { + .round => .round, + .floor => .floor, + .ceil => .ceil, + .truncate => .trunc_float, + .exact => unreachable, + }; + + try sema.requireRuntimeBlock(block, operand_src, null); + return block.addUnOp(air_tag, coerced_operand); + }, + .int, .comptime_int => {}, + else => return sema.fail( + block, + src, + "expected integer, float, or vector of either integers or floats, found '{f}'", + .{dest_ty.fmt(pt)}, + ), + } + + if (sema.resolveValue(operand)) |operand_val| { + const result_val = try sema.intFromFloat(block, operand_src, operand_val, operand_ty, dest_ty, mode); + return .fromValue(result_val); + } else if (dest_scalar_ty.zigTypeTag(zcu) == .comptime_int) { + return sema.failWithNeededComptime(block, operand_src, .{ .simple = .casted_to_comptime_int }); + } + + try sema.requireRuntimeBlock(block, src, operand_src); + + if (dest_scalar_ty.intInfo(zcu).bits == 0) { + if (block.wantSafety()) { + const abs_ref = try block.addTyOp(.abs, operand_ty, operand); + const is_vector = dest_ty.zigTypeTag(zcu) == .vector; + const max_abs_ref = if (is_vector) try block.addReduce(abs_ref, .Max) else abs_ref; + const one_ref = Air.internedToRef((try pt.floatValue(operand_scalar_ty, 1.0)).toIntern()); + const ok_ref = try block.addBinOp(.cmp_lt, max_abs_ref, one_ref); + try sema.addSafetyCheck(block, src, ok_ref, .integer_part_out_of_bounds); + } + const scalar_val = try pt.intValue(dest_scalar_ty, 0); + return Air.internedToRef((try sema.splat(dest_ty, scalar_val)).toIntern()); + } + + const safe = block.wantSafety(); + + if (safe) { + try sema.preparePanicId(src, .integer_part_out_of_bounds); + } + + const uncasted_result: Air.Inst.Ref = switch (mode) { + .truncate => operand, + .round => try block.addUnOp(.round, operand), + .floor => try block.addUnOp(.floor, operand), + .ceil => try block.addUnOp(.ceil, operand), + .exact => unreachable, + }; + const air_cast_tag: Air.Inst.Tag = switch (block.float_mode) { + .optimized => if (safe) .int_from_float_optimized_safe else .int_from_float_optimized, + .strict => if (safe) .int_from_float_safe else .int_from_float, + }; + return block.addTyOp(air_cast_tag, dest_ty, uncasted_result); +} + fn zirFloatFromInt(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref { const pt = sema.pt; const zcu = pt.zcu; @@ -25078,6 +25216,23 @@ fn zirFloatOpResultType(sema: *Sema, block: *Block, extended: Zir.Inst.Extended. return .fromType(float_ty); } +fn zirRoundOpType(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!Air.Inst.Ref { + const pt = sema.pt; + const zcu = pt.zcu; + const extra = sema.code.extraData(Zir.Inst.UnNode, extended.operand).data; + const operand_src = block.builtinCallArgSrc(extra.node, 0); + + const dest_ty = try sema.resolveTypeOrPoison(block, operand_src, extra.operand) orelse { + return .generic_poison_type; + }; + + const float_ty = dest_ty.optEuBaseType(zcu); + switch (float_ty.scalarType(zcu).zigTypeTag(zcu)) { + .float, .comptime_float => return .fromType(float_ty), + else => return .comptime_float_type, + } +} + fn requireRuntimeBlock(sema: *Sema, block: *Block, src: LazySrcLoc, runtime_src: ?LazySrcLoc) !void { if (block.isComptime()) { const msg, const fail_block = msg: { @@ -33467,7 +33622,7 @@ fn structFieldIndex( return sema.failWithBadStructFieldAccess(block, struct_ty, struct_type, field_src, field_name); } -const IntFromFloatMode = enum { exact, truncate }; +const IntFromFloatMode = enum { exact, truncate, round, floor, ceil }; fn intFromFloat( sema: *Sema, @@ -33505,7 +33660,14 @@ fn intFromFloatScalar( if (val.isUndef(zcu)) return sema.failWithUseOfUndef(block, src, vec_idx); - const float = val.toFloat(f128, zcu); + var float = val.toFloat(f128, zcu); + switch (mode) { + .round => float = @round(float), + .floor => float = @floor(float), + .ceil => float = @ceil(float), + .truncate, .exact => {}, + } + if (std.math.isNan(float)) { return sema.fail(block, src, "float value NaN cannot be stored in integer type '{f}'", .{ int_ty.fmt(pt), @@ -33530,7 +33692,7 @@ fn intFromFloatScalar( "fractional component prevents float value '{f}' from coercion to type '{f}'", .{ val.fmtValueSema(pt, sema), int_ty.fmt(pt) }, ), - .truncate => {}, + .truncate, .round, .floor, .ceil => {}, }, .exact => {}, } diff --git a/src/print_zir.zig b/src/print_zir.zig @@ -570,6 +570,7 @@ const Writer = struct { .float_op_result_ty, .reify_tuple, .reify_pointer_sentinel_ty, + .round_op_ty, => { const inst_data = self.code.extraData(Zir.Inst.UnNode, extended.operand).data; try self.writeInstRef(stream, inst_data.operand); @@ -593,6 +594,17 @@ const Writer = struct { try self.writeSrcNode(stream, inst_data.node); }, + .round_op => { + const round_op: Zir.Inst.RoundOp = @enumFromInt(extended.small); + const inst_data = self.code.extraData(Zir.Inst.BinNode, extended.operand).data; + try stream.print("{s}, ", .{@tagName(round_op)}); + try self.writeInstRef(stream, inst_data.lhs); + try stream.writeAll(", "); + try self.writeInstRef(stream, inst_data.rhs); + try stream.writeAll(")) "); + try self.writeSrcNode(stream, inst_data.node); + }, + .reify_slice_arg_ty => { const reify_slice_arg_info: Zir.Inst.ReifySliceArgInfo = @enumFromInt(extended.small); const extra = self.code.extraData(Zir.Inst.UnNode, extended.operand).data; diff --git a/test/behavior/cast.zig b/test/behavior/cast.zig @@ -116,6 +116,10 @@ test "@floatFromInt" { const f = @as(f32, @floatFromInt(k)); const i = @as(i32, @intFromFloat(f)); try expect(i == k); + try expect(@as(i32, @round(f)) == k); + try expect(@as(i32, @floor(f)) == k); + try expect(@as(i32, @ceil(f)) == k); + try expect(@as(i32, @trunc(f)) == k); } }; try S.doTheTest(); @@ -139,6 +143,10 @@ test "@floatFromInt(f80)" { const f = @as(f80, @floatFromInt(k)); const i = @as(Int, @intFromFloat(f)); try expect(i == k); + try expect(@as(Int, @round(f)) == k); + try expect(@as(Int, @floor(f)) == k); + try expect(@as(Int, @ceil(f)) == k); + try expect(@as(Int, @trunc(f)) == k); } }; try S.doTheTest(i31); @@ -167,6 +175,10 @@ test "type coercion from int to float" { try std.testing.expectEqual(int, @as(Int, @intFromFloat(float))); try std.testing.expectEqual(int, @as(Int, @intFromFloat(@ceil(float)))); try std.testing.expectEqual(int, @as(Int, @intFromFloat(@floor(float)))); + try std.testing.expectEqual(int, @as(Int, @round(float))); + try std.testing.expectEqual(int, @as(Int, @ceil(float))); + try std.testing.expectEqual(int, @as(Int, @floor(float))); + try std.testing.expectEqual(int, @as(Int, @trunc(float))); } // Exhaustively check that all possible values of the integer type can @@ -230,12 +242,54 @@ fn testIntFromFloats() !void { try expectIntFromFloat(f32, 255.1, u8, 255); try expectIntFromFloat(f32, 127.2, i8, 127); try expectIntFromFloat(f32, -128.2, i8, -128); + + try expectRoundCast(f32, 255.1, u8, 255); + try expectFloorCast(f32, 255.1, u8, 255); + try expectTruncCast(f32, 255.1, u8, 255); + + try expectRoundCast(f32, 127.2, i8, 127); + try expectFloorCast(f32, 127.2, i8, 127); + try expectTruncCast(f32, 127.2, i8, 127); + + try expectRoundCast(f32, -128.2, i8, -128); + try expectCeilCast(f32, -128.2, i8, -128); + try expectTruncCast(f32, -128.2, i8, -128); +} + +test "rounding builtins with anytype and context propagation" { + const S = struct { + const x: i32 = 10; + fn check(expected: anytype, actual: anytype) !void { + try expectEqual(expected, actual); + } + }; + try expectEqual(@as(f32, 1.0), @round(@as(f32, 1.4))); + try S.check(@as(f32, 1.0), @round(@as(f32, 1.4))); + + const y: f64 = @floor(@floatFromInt(S.x)); + try expect(y == 10.0); + + try expectEqual(1.0, @round(1.4)); + try S.check(1.0, @round(1.4)); } fn expectIntFromFloat(comptime F: type, f: F, comptime I: type, i: I) !void { try expect(@as(I, @intFromFloat(f)) == i); } +fn expectRoundCast(comptime F: type, f: F, comptime I: type, i: I) !void { + try expect(@as(I, @round(f)) == i); +} +fn expectFloorCast(comptime F: type, f: F, comptime I: type, i: I) !void { + try expect(@as(I, @floor(f)) == i); +} +fn expectCeilCast(comptime F: type, f: F, comptime I: type, i: I) !void { + try expect(@as(I, @ceil(f)) == i); +} +fn expectTruncCast(comptime F: type, f: F, comptime I: type, i: I) !void { + try expect(@as(I, @trunc(f)) == i); +} + test "implicitly cast indirect pointer to maybe-indirect pointer" { if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO if (builtin.zig_backend == .stage2_spirv) return error.SkipZigTest; @@ -1309,6 +1363,12 @@ test "comptime float casts" { try expectIntFromFloat(comptime_int, 1234, i16, 1234); try expectIntFromFloat(comptime_float, 12.3, comptime_int, 12); + + try expectRoundCast(comptime_float, 12.3, comptime_int, 12); + + try expectFloorCast(comptime_float, 12.3, comptime_int, 12); + try expectCeilCast(comptime_float, 12.3, comptime_int, 13); + try expectTruncCast(comptime_float, 12.3, comptime_int, 12); } test "pointer reinterpret const float to int" { @@ -1756,6 +1816,10 @@ test "intFromFloat to zero-bit int" { const a: f32 = 0.0; try comptime std.testing.expect(@as(u0, @intFromFloat(a)) == 0); + try comptime std.testing.expect(@as(u0, @round(a)) == 0); + try comptime std.testing.expect(@as(u0, @floor(a)) == 0); + try comptime std.testing.expect(@as(u0, @ceil(a)) == 0); + try comptime std.testing.expect(@as(u0, @trunc(a)) == 0); } test "peer type resolution of function pointer and function body" {