commit cbe468a787e93d57ca64799e715ea7493f4a6ebb (tree)
parent b6aebc41177dec5d7dddcc7594e5435d54f17963
Author: Andrew Kelley <andrew@ziglang.org>
Date: Sat, 11 Apr 2026 16:44:41 +0200
Merge pull request 'Sema: allow @round, @floor, @ceil, and @trunc to coerce to integer types' (#30906) from adria/zig:sema-rounding-casts into master
Reviewed-on: https://codeberg.org/ziglang/zig/pulls/30906
Diffstat:
5 files changed, 307 insertions(+), 15 deletions(-)
diff --git a/lib/std/zig/AstGen.zig b/lib/std/zig/AstGen.zig
@@ -9266,10 +9266,10 @@ fn builtinCall(
.log => return floatUnOp(gz, scope, ri, node, params[0], .log),
.log2 => return floatUnOp(gz, scope, ri, node, params[0], .log2),
.log10 => return floatUnOp(gz, scope, ri, node, params[0], .log10),
- .floor => return floatUnOp(gz, scope, ri, node, params[0], .floor),
- .ceil => return floatUnOp(gz, scope, ri, node, params[0], .ceil),
- .trunc => return floatUnOp(gz, scope, ri, node, params[0], .trunc),
- .round => return floatUnOp(gz, scope, ri, node, params[0], .round),
+ .floor => return floatRoundOp(gz, scope, ri, node, params[0], .floor),
+ .ceil => return floatRoundOp(gz, scope, ri, node, params[0], .ceil),
+ .trunc => return floatRoundOp(gz, scope, ri, node, params[0], .trunc),
+ .round => return floatRoundOp(gz, scope, ri, node, params[0], .round),
.int_from_float => return typeCast(gz, scope, ri, node, params[0], .int_from_float, builtin_name),
.float_from_int => return typeCast(gz, scope, ri, node, params[0], .float_from_int, builtin_name),
@@ -9819,6 +9819,43 @@ fn simpleUnOp(
return rvalue(gz, ri, result, node);
}
+fn floatRoundOp(
+ gz: *GenZir,
+ scope: *Scope,
+ ri: ResultInfo,
+ node: Ast.Node.Index,
+ operand_node: Ast.Node.Index,
+ float_tag: Zir.Inst.Tag,
+) InnerError!Zir.Inst.Ref {
+ if (try ri.rl.resultType(gz, node)) |dest_type| {
+ const cursor = maybeAdvanceSourceCursorToMainToken(gz, node);
+
+ const operand_ty_inst = try gz.addExtendedPayload(.round_op_ty, Zir.Inst.UnNode{
+ .node = gz.nodeIndexToRelative(node),
+ .operand = dest_type,
+ });
+
+ const operand = try expr(gz, scope, .{ .rl = .{ .coerced_ty = operand_ty_inst } }, operand_node);
+
+ try emitDbgStmt(gz, cursor);
+ const round_op: Zir.Inst.RoundOp = switch (float_tag) {
+ .round => .round,
+ .floor => .floor,
+ .ceil => .ceil,
+ .trunc => .trunc,
+ else => unreachable,
+ };
+ const result = try gz.addExtendedPayloadSmall(.round_op, @intFromEnum(round_op), Zir.Inst.BinNode{
+ .node = gz.nodeIndexToRelative(node),
+ .lhs = dest_type,
+ .rhs = operand,
+ });
+ return rvalue(gz, ri, result, node);
+ } else {
+ return floatUnOp(gz, scope, ri, node, operand_node, float_tag);
+ }
+}
+
fn floatUnOp(
gz: *GenZir,
scope: *Scope,
diff --git a/lib/std/zig/Zir.zig b/lib/std/zig/Zir.zig
@@ -2009,6 +2009,14 @@ pub const Inst = struct {
/// `operand` is payload index to `BinNode`.
/// `small` is unused.
shl_with_overflow,
+ /// `@round`, `@floor`, `@ceil`, or `@trunc`, with a result type.
+ /// `operand` is payload index to `BinNode`.
+ /// `small` is a `RoundOp` representing the specific operation being performed.
+ round_op,
+ /// Returns the type for the operand of a rounding op.
+ /// `operand` is `UnNode`.
+ /// `small` is unused.
+ round_op_ty,
/// `operand` is payload index to `UnNode`.
c_undef,
/// `operand` is payload index to `UnNode`.
@@ -3233,6 +3241,13 @@ pub const Inst = struct {
string_to_union_field_attrs,
};
+ pub const RoundOp = enum(u16) {
+ round,
+ floor,
+ ceil,
+ trunc,
+ };
+
pub const UnNode = struct {
node: Ast.Node.Offset,
operand: Ref,
@@ -4344,6 +4359,7 @@ fn findTrackableInner(
.sub_with_overflow,
.mul_with_overflow,
.shl_with_overflow,
+ .round_op,
.c_undef,
.c_include,
.c_define,
@@ -4385,6 +4401,7 @@ fn findTrackableInner(
.dbg_empty_stmt,
.astgen_error,
.float_op_result_ty,
+ .round_op_ty,
=> return,
// `@TypeOf` has a body.
diff --git a/src/Sema.zig b/src/Sema.zig
@@ -1405,6 +1405,8 @@ fn analyzeBodyInner(
.@"asm" => try sema.zirAsm( block, extended, false),
.asm_expr => try sema.zirAsm( block, extended, true),
.typeof_peer => try sema.zirTypeofPeer( block, extended, inst),
+ .round_op => try sema.zirRoundCast( block, extended),
+ .round_op_ty => try sema.zirRoundOpType( block, extended),
.compile_log => try sema.zirCompileLog( block, extended),
.min_multi => try sema.zirMinMaxMulti( block, extended, .min),
.max_multi => try sema.zirMinMaxMulti( block, extended, .max),
@@ -19685,21 +19687,16 @@ fn maybeConstantUnaryMath(
return null;
}
-fn zirUnaryMath(
+fn unaryMath(
sema: *Sema,
block: *Block,
- inst: Zir.Inst.Index,
+ operand_src: LazySrcLoc,
+ operand: Air.Inst.Ref,
air_tag: Air.Inst.Tag,
comptime eval: fn (Value, Type, Allocator, Zcu.PerThread) Allocator.Error!Value,
) CompileError!Air.Inst.Ref {
- const tracy = trace(@src());
- defer tracy.end();
-
const pt = sema.pt;
const zcu = pt.zcu;
- const inst_data = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node;
- const operand = sema.resolveInst(inst_data.operand);
- const operand_src = block.builtinCallArgSrc(inst_data.src_node, 0);
const operand_ty = sema.typeOf(operand);
const scalar_ty = operand_ty.scalarType(zcu);
@@ -19719,6 +19716,23 @@ fn zirUnaryMath(
};
}
+fn zirUnaryMath(
+ sema: *Sema,
+ block: *Block,
+ inst: Zir.Inst.Index,
+ air_tag: Air.Inst.Tag,
+ comptime eval: fn (Value, Type, Allocator, Zcu.PerThread) Allocator.Error!Value,
+) CompileError!Air.Inst.Ref {
+ const tracy = trace(@src());
+ defer tracy.end();
+
+ const inst_data = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node;
+ const operand = sema.resolveInst(inst_data.operand);
+ const operand_src = block.builtinCallArgSrc(inst_data.src_node, 0);
+
+ return sema.unaryMath(block, operand_src, operand, air_tag, eval);
+}
+
fn zirTagName(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
const inst_data = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node;
const operand_src = block.builtinCallArgSrc(inst_data.src_node, 0);
@@ -20900,6 +20914,130 @@ fn zirIntFromFloat(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileErro
}, dest_ty, operand);
}
+fn zirRoundCast(
+ sema: *Sema,
+ block: *Block,
+ extended: Zir.Inst.Extended.InstData,
+) CompileError!Air.Inst.Ref {
+ const pt = sema.pt;
+ const zcu = pt.zcu;
+ const extra = sema.code.extraData(Zir.Inst.BinNode, extended.operand).data;
+ const src = block.nodeOffset(extra.node);
+ const operand_src = block.builtinCallArgSrc(extra.node, 0);
+
+ const operand = sema.resolveInst(extra.rhs);
+
+ const round_op: Zir.Inst.RoundOp = @enumFromInt(extended.small);
+ const mode: IntFromFloatMode = switch (round_op) {
+ .round => .round,
+ .floor => .floor,
+ .ceil => .ceil,
+ .trunc => .truncate,
+ };
+
+ const dest_ty = (try sema.resolveTypeOrPoison(block, src, extra.lhs) orelse switch (mode) {
+ // zig fmt: off
+ .round => return sema.unaryMath(block, operand_src, operand, .round, Value.round),
+ .floor => return sema.unaryMath(block, operand_src, operand, .floor, Value.floor),
+ .ceil => return sema.unaryMath(block, operand_src, operand, .ceil, Value.ceil),
+ .truncate => return sema.unaryMath(block, operand_src, operand, .trunc_float, Value.trunc),
+ // zig fmt: on
+ .exact => unreachable,
+ }).optEuBaseType(zcu);
+
+ const operand_ty = sema.typeOf(operand);
+
+ try sema.checkVectorizableBinaryOperands(block, operand_src, dest_ty, operand_ty, src, operand_src);
+
+ const dest_scalar_ty = dest_ty.scalarType(zcu);
+ const operand_scalar_ty = operand_ty.scalarType(zcu);
+
+ switch (operand_scalar_ty.zigTypeTag(zcu)) {
+ .comptime_float, .float => {},
+ else => return sema.fail(
+ block,
+ operand_src,
+ "expected float or vector type, found '{f}'",
+ .{operand_ty.fmt(pt)},
+ ),
+ }
+
+ switch (dest_scalar_ty.zigTypeTag(zcu)) {
+ .float, .comptime_float => {
+ const coerced_operand = try sema.coerce(block, dest_ty, operand, operand_src);
+
+ const result_ref = switch (mode) {
+ .round => try sema.maybeConstantUnaryMath(coerced_operand, dest_ty, Value.round),
+ .floor => try sema.maybeConstantUnaryMath(coerced_operand, dest_ty, Value.floor),
+ .ceil => try sema.maybeConstantUnaryMath(coerced_operand, dest_ty, Value.ceil),
+ .truncate => try sema.maybeConstantUnaryMath(coerced_operand, dest_ty, Value.trunc),
+ .exact => unreachable,
+ };
+
+ if (result_ref) |ref| return ref;
+
+ const air_tag: Air.Inst.Tag = switch (mode) {
+ .round => .round,
+ .floor => .floor,
+ .ceil => .ceil,
+ .truncate => .trunc_float,
+ .exact => unreachable,
+ };
+
+ try sema.requireRuntimeBlock(block, operand_src, null);
+ return block.addUnOp(air_tag, coerced_operand);
+ },
+ .int, .comptime_int => {},
+ else => return sema.fail(
+ block,
+ src,
+ "expected integer, float, or vector of either integers or floats, found '{f}'",
+ .{dest_ty.fmt(pt)},
+ ),
+ }
+
+ if (sema.resolveValue(operand)) |operand_val| {
+ const result_val = try sema.intFromFloat(block, operand_src, operand_val, operand_ty, dest_ty, mode);
+ return .fromValue(result_val);
+ } else if (dest_scalar_ty.zigTypeTag(zcu) == .comptime_int) {
+ return sema.failWithNeededComptime(block, operand_src, .{ .simple = .casted_to_comptime_int });
+ }
+
+ try sema.requireRuntimeBlock(block, src, operand_src);
+
+ if (dest_scalar_ty.intInfo(zcu).bits == 0) {
+ if (block.wantSafety()) {
+ const abs_ref = try block.addTyOp(.abs, operand_ty, operand);
+ const is_vector = dest_ty.zigTypeTag(zcu) == .vector;
+ const max_abs_ref = if (is_vector) try block.addReduce(abs_ref, .Max) else abs_ref;
+ const one_ref = Air.internedToRef((try pt.floatValue(operand_scalar_ty, 1.0)).toIntern());
+ const ok_ref = try block.addBinOp(.cmp_lt, max_abs_ref, one_ref);
+ try sema.addSafetyCheck(block, src, ok_ref, .integer_part_out_of_bounds);
+ }
+ const scalar_val = try pt.intValue(dest_scalar_ty, 0);
+ return Air.internedToRef((try sema.splat(dest_ty, scalar_val)).toIntern());
+ }
+
+ const safe = block.wantSafety();
+
+ if (safe) {
+ try sema.preparePanicId(src, .integer_part_out_of_bounds);
+ }
+
+ const uncasted_result: Air.Inst.Ref = switch (mode) {
+ .truncate => operand,
+ .round => try block.addUnOp(.round, operand),
+ .floor => try block.addUnOp(.floor, operand),
+ .ceil => try block.addUnOp(.ceil, operand),
+ .exact => unreachable,
+ };
+ const air_cast_tag: Air.Inst.Tag = switch (block.float_mode) {
+ .optimized => if (safe) .int_from_float_optimized_safe else .int_from_float_optimized,
+ .strict => if (safe) .int_from_float_safe else .int_from_float,
+ };
+ return block.addTyOp(air_cast_tag, dest_ty, uncasted_result);
+}
+
fn zirFloatFromInt(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!Air.Inst.Ref {
const pt = sema.pt;
const zcu = pt.zcu;
@@ -25078,6 +25216,23 @@ fn zirFloatOpResultType(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.
return .fromType(float_ty);
}
+fn zirRoundOpType(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData) CompileError!Air.Inst.Ref {
+ const pt = sema.pt;
+ const zcu = pt.zcu;
+ const extra = sema.code.extraData(Zir.Inst.UnNode, extended.operand).data;
+ const operand_src = block.builtinCallArgSrc(extra.node, 0);
+
+ const dest_ty = try sema.resolveTypeOrPoison(block, operand_src, extra.operand) orelse {
+ return .generic_poison_type;
+ };
+
+ const float_ty = dest_ty.optEuBaseType(zcu);
+ switch (float_ty.scalarType(zcu).zigTypeTag(zcu)) {
+ .float, .comptime_float => return .fromType(float_ty),
+ else => return .comptime_float_type,
+ }
+}
+
fn requireRuntimeBlock(sema: *Sema, block: *Block, src: LazySrcLoc, runtime_src: ?LazySrcLoc) !void {
if (block.isComptime()) {
const msg, const fail_block = msg: {
@@ -33467,7 +33622,7 @@ fn structFieldIndex(
return sema.failWithBadStructFieldAccess(block, struct_ty, struct_type, field_src, field_name);
}
-const IntFromFloatMode = enum { exact, truncate };
+const IntFromFloatMode = enum { exact, truncate, round, floor, ceil };
fn intFromFloat(
sema: *Sema,
@@ -33505,7 +33660,14 @@ fn intFromFloatScalar(
if (val.isUndef(zcu)) return sema.failWithUseOfUndef(block, src, vec_idx);
- const float = val.toFloat(f128, zcu);
+ var float = val.toFloat(f128, zcu);
+ switch (mode) {
+ .round => float = @round(float),
+ .floor => float = @floor(float),
+ .ceil => float = @ceil(float),
+ .truncate, .exact => {},
+ }
+
if (std.math.isNan(float)) {
return sema.fail(block, src, "float value NaN cannot be stored in integer type '{f}'", .{
int_ty.fmt(pt),
@@ -33530,7 +33692,7 @@ fn intFromFloatScalar(
"fractional component prevents float value '{f}' from coercion to type '{f}'",
.{ val.fmtValueSema(pt, sema), int_ty.fmt(pt) },
),
- .truncate => {},
+ .truncate, .round, .floor, .ceil => {},
},
.exact => {},
}
diff --git a/src/print_zir.zig b/src/print_zir.zig
@@ -570,6 +570,7 @@ const Writer = struct {
.float_op_result_ty,
.reify_tuple,
.reify_pointer_sentinel_ty,
+ .round_op_ty,
=> {
const inst_data = self.code.extraData(Zir.Inst.UnNode, extended.operand).data;
try self.writeInstRef(stream, inst_data.operand);
@@ -593,6 +594,17 @@ const Writer = struct {
try self.writeSrcNode(stream, inst_data.node);
},
+ .round_op => {
+ const round_op: Zir.Inst.RoundOp = @enumFromInt(extended.small);
+ const inst_data = self.code.extraData(Zir.Inst.BinNode, extended.operand).data;
+ try stream.print("{s}, ", .{@tagName(round_op)});
+ try self.writeInstRef(stream, inst_data.lhs);
+ try stream.writeAll(", ");
+ try self.writeInstRef(stream, inst_data.rhs);
+ try stream.writeAll(")) ");
+ try self.writeSrcNode(stream, inst_data.node);
+ },
+
.reify_slice_arg_ty => {
const reify_slice_arg_info: Zir.Inst.ReifySliceArgInfo = @enumFromInt(extended.small);
const extra = self.code.extraData(Zir.Inst.UnNode, extended.operand).data;
diff --git a/test/behavior/cast.zig b/test/behavior/cast.zig
@@ -116,6 +116,10 @@ test "@floatFromInt" {
const f = @as(f32, @floatFromInt(k));
const i = @as(i32, @intFromFloat(f));
try expect(i == k);
+ try expect(@as(i32, @round(f)) == k);
+ try expect(@as(i32, @floor(f)) == k);
+ try expect(@as(i32, @ceil(f)) == k);
+ try expect(@as(i32, @trunc(f)) == k);
}
};
try S.doTheTest();
@@ -139,6 +143,10 @@ test "@floatFromInt(f80)" {
const f = @as(f80, @floatFromInt(k));
const i = @as(Int, @intFromFloat(f));
try expect(i == k);
+ try expect(@as(Int, @round(f)) == k);
+ try expect(@as(Int, @floor(f)) == k);
+ try expect(@as(Int, @ceil(f)) == k);
+ try expect(@as(Int, @trunc(f)) == k);
}
};
try S.doTheTest(i31);
@@ -167,6 +175,10 @@ test "type coercion from int to float" {
try std.testing.expectEqual(int, @as(Int, @intFromFloat(float)));
try std.testing.expectEqual(int, @as(Int, @intFromFloat(@ceil(float))));
try std.testing.expectEqual(int, @as(Int, @intFromFloat(@floor(float))));
+ try std.testing.expectEqual(int, @as(Int, @round(float)));
+ try std.testing.expectEqual(int, @as(Int, @ceil(float)));
+ try std.testing.expectEqual(int, @as(Int, @floor(float)));
+ try std.testing.expectEqual(int, @as(Int, @trunc(float)));
}
// Exhaustively check that all possible values of the integer type can
@@ -230,12 +242,54 @@ fn testIntFromFloats() !void {
try expectIntFromFloat(f32, 255.1, u8, 255);
try expectIntFromFloat(f32, 127.2, i8, 127);
try expectIntFromFloat(f32, -128.2, i8, -128);
+
+ try expectRoundCast(f32, 255.1, u8, 255);
+ try expectFloorCast(f32, 255.1, u8, 255);
+ try expectTruncCast(f32, 255.1, u8, 255);
+
+ try expectRoundCast(f32, 127.2, i8, 127);
+ try expectFloorCast(f32, 127.2, i8, 127);
+ try expectTruncCast(f32, 127.2, i8, 127);
+
+ try expectRoundCast(f32, -128.2, i8, -128);
+ try expectCeilCast(f32, -128.2, i8, -128);
+ try expectTruncCast(f32, -128.2, i8, -128);
+}
+
+test "rounding builtins with anytype and context propagation" {
+ const S = struct {
+ const x: i32 = 10;
+ fn check(expected: anytype, actual: anytype) !void {
+ try expectEqual(expected, actual);
+ }
+ };
+ try expectEqual(@as(f32, 1.0), @round(@as(f32, 1.4)));
+ try S.check(@as(f32, 1.0), @round(@as(f32, 1.4)));
+
+ const y: f64 = @floor(@floatFromInt(S.x));
+ try expect(y == 10.0);
+
+ try expectEqual(1.0, @round(1.4));
+ try S.check(1.0, @round(1.4));
}
fn expectIntFromFloat(comptime F: type, f: F, comptime I: type, i: I) !void {
try expect(@as(I, @intFromFloat(f)) == i);
}
+fn expectRoundCast(comptime F: type, f: F, comptime I: type, i: I) !void {
+ try expect(@as(I, @round(f)) == i);
+}
+fn expectFloorCast(comptime F: type, f: F, comptime I: type, i: I) !void {
+ try expect(@as(I, @floor(f)) == i);
+}
+fn expectCeilCast(comptime F: type, f: F, comptime I: type, i: I) !void {
+ try expect(@as(I, @ceil(f)) == i);
+}
+fn expectTruncCast(comptime F: type, f: F, comptime I: type, i: I) !void {
+ try expect(@as(I, @trunc(f)) == i);
+}
+
test "implicitly cast indirect pointer to maybe-indirect pointer" {
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_spirv) return error.SkipZigTest;
@@ -1309,6 +1363,12 @@ test "comptime float casts" {
try expectIntFromFloat(comptime_int, 1234, i16, 1234);
try expectIntFromFloat(comptime_float, 12.3, comptime_int, 12);
+
+ try expectRoundCast(comptime_float, 12.3, comptime_int, 12);
+
+ try expectFloorCast(comptime_float, 12.3, comptime_int, 12);
+ try expectCeilCast(comptime_float, 12.3, comptime_int, 13);
+ try expectTruncCast(comptime_float, 12.3, comptime_int, 12);
}
test "pointer reinterpret const float to int" {
@@ -1756,6 +1816,10 @@ test "intFromFloat to zero-bit int" {
const a: f32 = 0.0;
try comptime std.testing.expect(@as(u0, @intFromFloat(a)) == 0);
+ try comptime std.testing.expect(@as(u0, @round(a)) == 0);
+ try comptime std.testing.expect(@as(u0, @floor(a)) == 0);
+ try comptime std.testing.expect(@as(u0, @ceil(a)) == 0);
+ try comptime std.testing.expect(@as(u0, @trunc(a)) == 0);
}
test "peer type resolution of function pointer and function body" {