commit fad995834de6c8825b1c55463a810aaae7a7df3b (tree)
parent d1d9a92855bc0d596162b1f84f86614e7b2c8c19
Author: Motiejus Jakštys <motiejus@jakstys.lt>
Date: Tue, 24 Feb 2026 05:12:40 +0000
sema: cross-module generic monomorphization, AIR rollback (91 sema tests)
- Cross-module generic function body analysis with findStringInZirBytes
for name lookup across ZIR modules.
- Two-phase parameter mapping: comptime params mapped before return type
resolution, then runtime params create ARG instructions.
- call_arg_types: pass call-site types directly for generic parameters
to avoid evaluating cross-module ZIR type bodies.
- AIR rollback on comptime-returned inline calls (ported from Sema.zig
air_instructions.shrinkRetainingCapacity).
- Add sema tests: generic_fn_with_clz and generic_fn_with_shl_assign.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat:
4 files changed, 184 insertions(+), 82 deletions(-)
diff --git a/stage0/corpus.zig b/stage0/corpus.zig
@@ -203,7 +203,7 @@ pub const files = [_][]const u8{
"lib/std/math/expo2.zig", // 995
};
-pub const num_sema_passing: usize = 89;
+pub const num_sema_passing: usize = 91;
pub const sema_unit_tests = [_][]const u8{
"stage0/sema_tests/empty.zig",
@@ -295,5 +295,7 @@ pub const sema_unit_tests = [_][]const u8{
"stage0/sema_tests/max_float.zig",
"stage0/sema_tests/min_float.zig",
"stage0/sema_tests/f64_div.zig",
+ "stage0/sema_tests/generic_fn_with_clz.zig",
+ "stage0/sema_tests/generic_fn_with_shl_assign.zig",
};
diff --git a/stage0/sema.c b/stage0/sema.c
@@ -471,7 +471,7 @@ static uint8_t analyzeBodyRuntimeBreak(
static uint16_t floatBits(TypeIndex ty);
static void analyzeFuncBodyAndRecord(Sema* sema, SemaBlock* block,
uint32_t func_inst, uint32_t name_idx, AirInstRef* call_args,
- uint32_t call_args_len);
+ uint32_t call_args_len, TypeIndex* call_arg_types);
// getParamBody: extract param body from a param_block ZIR instruction.
// Ported from lib/std/zig/Zir.zig getParamBody.
@@ -2268,6 +2268,18 @@ static uint32_t findDeclInstByNameInZir(
return UINT32_MAX;
}
+// findStringInZirBytes: find a null-terminated string in ZIR string_bytes.
+// Returns the index, or 0 if not found.
+static uint32_t findStringInZirBytes(const Zir* zir, const char* target) {
+ size_t tlen = strlen(target);
+ for (uint32_t i = 1; i + tlen <= zir->string_bytes_len; i++) {
+ if (zir->string_bytes[i - 1] == 0
+ && memcmp(&zir->string_bytes[i], target, tlen + 1) == 0)
+ return i;
+ }
+ return 0;
+}
+
// findFuncInstInZir: find a func/func_fancy instruction by name in a ZIR.
// Scans the ZIR's struct_decl (inst 0) for a declaration matching the
@@ -4469,11 +4481,31 @@ static AirInstRef zirCall(
// Ported from Sema.zig analyzeCall: ensureFuncBodyAnalysisQueued.
// For generic functions, pass call args so comptime params are
// mapped to their resolved values (monomorphization).
- // Skip for cross-module calls: their function bodies belong to
- // the imported module's AIR, not the current file's.
- if (!is_cross_module)
- analyzeFuncBodyAndRecord(sema, block, func_inst, callee_name_idx,
- is_generic ? arg_refs : NULL, is_generic ? args_len : 0);
+ // For cross-module calls, only analyze generic (monomorphized)
+ // function bodies — non-generic cross-module functions belong to
+ // their own module's AIR, not the current file's.
+ if (!is_cross_module || is_generic) {
+ uint32_t body_name_idx = callee_name_idx;
+ if (is_cross_module && callee_name_idx != 0) {
+ // callee_name_idx is from the pre-swap ZIR's string table.
+ // Find the same name in the imported ZIR's string table.
+ const char* name
+ = (const char*)&saved_code.string_bytes[callee_name_idx];
+ body_name_idx = findStringInZirBytes(&sema->code, name);
+ }
+ // Compute call-site arg types before entering body analysis
+ // (which saves/resets AIR state, making semaTypeOf unavailable).
+ TypeIndex arg_type_buf[16];
+ TypeIndex* arg_types_ptr = NULL;
+ if (is_generic && args_len > 0) {
+ for (uint32_t ai = 0; ai < args_len && ai < 16; ai++)
+ arg_type_buf[ai] = semaTypeOf(sema, arg_refs[ai]);
+ arg_types_ptr = arg_type_buf;
+ }
+ analyzeFuncBodyAndRecord(sema, block, func_inst, body_name_idx,
+ is_generic ? arg_refs : NULL, is_generic ? args_len : 0,
+ arg_types_ptr);
+ }
// Clean up cross-module state.
if (is_cross_module) {
@@ -4550,6 +4582,12 @@ static AirInstRef zirCall(
AirInstTag block_tag
= need_debug_scope ? AIR_INST_DBG_INLINE_BLOCK : AIR_INST_BLOCK;
+ // Save AIR state for rollback if the inline call returns at comptime.
+ // Ported from Sema.zig:
+ // air_instructions.shrinkRetainingCapacity(block_inst)
+ uint32_t saved_air_inst_len = sema->air_inst_len;
+ uint32_t saved_block_inst_len = block->instructions_len;
+
// Reserve the block instruction (data filled later).
uint32_t block_inst_idx = semaAddInstAsIndex(sema, block_tag,
(AirInstData) { .ty_pl = { .ty_ref = 0, .payload = 0 } });
@@ -4688,9 +4726,13 @@ static AirInstRef zirCall(
// ComptimeReturn: the inline function returned at comptime.
// Skip block finalization — no AIR instructions emitted.
// Roll back AIR state to discard the reserved block and any body
- // instructions (they are dead). Ported from src/Sema.zig line 7872.
+ // instructions (they are dead). Ported from src/Sema.zig line 7872:
+ // sema.air_instructions.shrinkRetainingCapacity(block_inst)
+ // block.instructions.shrinkRetainingCapacity(block_index)
if (inlining.comptime_returned) {
AirInstRef ct_result = inlining.comptime_result;
+ sema->air_inst_len = saved_air_inst_len;
+ block->instructions_len = saved_block_inst_len;
// Cache comptime results for memoization.
if (all_comptime && AIR_REF_IS_IP(ct_result)) {
@@ -4890,7 +4932,7 @@ static AirInstRef zirCall(
// (from zirCall). name_idx is a string_bytes index for the function name.
static void analyzeFuncBodyAndRecord(Sema* sema, SemaBlock* block,
uint32_t func_inst, uint32_t name_idx, AirInstRef* call_args,
- uint32_t call_args_len) {
+ uint32_t call_args_len, TypeIndex* call_arg_types) {
if (!sema->func_air_list)
return;
FuncZirInfo fi = parseFuncZir(sema, func_inst);
@@ -4960,6 +5002,49 @@ static void analyzeFuncBodyAndRecord(Sema* sema, SemaBlock* block,
// Reserve extra[0] for main_block index.
semaAddExtra(sema, 0);
+ // --- Parse parameters ---
+ // For generic functions, comptime params must be mapped BEFORE return
+ // type resolution since the return type may reference them (e.g., T).
+ // Ported from src/Sema.zig finishFuncInstance which resolves comptime
+ // args before analyzing the return type.
+ uint32_t param_block_inst
+ = sema->code.extra[payload_index + fi.param_block_pi];
+ const uint32_t* param_body;
+ uint32_t param_body_len;
+ getParamBody(sema, param_block_inst, ¶m_body, ¶m_body_len);
+
+ uint32_t total_params = 0;
+ for (uint32_t p = 0; p < param_body_len; p++) {
+ ZirInstTag ptag = sema->code.inst_tags[param_body[p]];
+ if (ptag == ZIR_INST_PARAM || ptag == ZIR_INST_PARAM_COMPTIME
+ || ptag == ZIR_INST_PARAM_ANYTYPE
+ || ptag == ZIR_INST_PARAM_ANYTYPE_COMPTIME)
+ total_params++;
+ }
+ if (total_params > 0)
+ instMapEnsureSpaceForBody(&sema->inst_map, param_body, param_body_len);
+
+ // Phase 1: Map comptime params to call_args values.
+ // This enables return type resolution and runtime param type bodies
+ // that reference comptime params (e.g., *T where T is comptime).
+ if (call_args) {
+ uint32_t ct_arg_index = 0;
+ for (uint32_t p = 0; p < param_body_len; p++) {
+ uint32_t param_inst = param_body[p];
+ ZirInstTag ptag = sema->code.inst_tags[param_inst];
+ if (ptag != ZIR_INST_PARAM && ptag != ZIR_INST_PARAM_COMPTIME
+ && ptag != ZIR_INST_PARAM_ANYTYPE
+ && ptag != ZIR_INST_PARAM_ANYTYPE_COMPTIME)
+ continue;
+ bool is_ct = (ptag == ZIR_INST_PARAM_COMPTIME
+ || ptag == ZIR_INST_PARAM_ANYTYPE_COMPTIME);
+ if (is_ct && ct_arg_index < call_args_len)
+ instMapPut(
+ &sema->inst_map, param_inst, call_args[ct_arg_index]);
+ ct_arg_index++;
+ }
+ }
+
// Resolve the function return type.
if (fi.ret_ty_body_len == 0) {
sema->fn_ret_ty = IP_INDEX_VOID_TYPE;
@@ -5002,24 +5087,8 @@ static void analyzeFuncBodyAndRecord(Sema* sema, SemaBlock* block,
fn_block.want_safety = false;
fn_block.want_safety_set = true;
- // --- Process parameters ---
- uint32_t param_block_inst
- = sema->code.extra[payload_index + fi.param_block_pi];
- const uint32_t* param_body;
- uint32_t param_body_len;
- getParamBody(sema, param_block_inst, ¶m_body, ¶m_body_len);
-
- uint32_t total_params = 0;
- for (uint32_t p = 0; p < param_body_len; p++) {
- ZirInstTag ptag = sema->code.inst_tags[param_body[p]];
- if (ptag == ZIR_INST_PARAM || ptag == ZIR_INST_PARAM_COMPTIME
- || ptag == ZIR_INST_PARAM_ANYTYPE
- || ptag == ZIR_INST_PARAM_ANYTYPE_COMPTIME)
- total_params++;
- }
- if (total_params > 0)
- instMapEnsureSpaceForBody(&sema->inst_map, param_body, param_body_len);
-
+ // Phase 2: Process runtime parameters (create ARG instructions).
+ // Comptime params were already mapped in Phase 1.
uint32_t arg_index = 0; // index into call_args (all params)
uint32_t runtime_param_index = 0;
for (uint32_t p = 0; p < param_body_len; p++) {
@@ -5033,67 +5102,72 @@ static void analyzeFuncBodyAndRecord(Sema* sema, SemaBlock* block,
bool is_ct = (ptag == ZIR_INST_PARAM_COMPTIME
|| ptag == ZIR_INST_PARAM_ANYTYPE_COMPTIME);
- // For generic function monomorphization: map comptime params
- // to their resolved values from the call site instead of
- // creating ARG instructions.
- // Ported from src/Sema.zig finishFuncInstance (lines 7503-7570).
+ // Comptime params already mapped in Phase 1.
if (call_args && is_ct) {
- if (arg_index < call_args_len)
- instMapPut(&sema->inst_map, param_inst, call_args[arg_index]);
arg_index++;
continue;
}
- uint32_t param_payload
- = sema->code.inst_datas[param_inst].pl_tok.payload_index;
- uint32_t type_packed = sema->code.extra[param_payload + 1];
- uint32_t type_body_len_p = type_packed & 0x7FFFFFFF;
-
+ // Resolve parameter type.
+ // For generic monomorphization with call_arg_types, use the
+ // call-site type directly — the ZIR type body may reference
+ // cross-module comptime expressions we can't evaluate.
TypeIndex param_ty = IP_INDEX_VOID_TYPE;
- if (type_body_len_p == 1) {
- uint32_t type_inst = sema->code.extra[param_payload + 2];
- ZirInstTag type_tag = sema->code.inst_tags[type_inst];
- assert(type_tag == ZIR_INST_BREAK_INLINE);
- ZirInstRef type_ref
- = sema->code.inst_datas[type_inst].break_data.operand;
- AirInstRef type_air = resolveInst(sema, type_ref);
- param_ty = AIR_REF_TO_IP(type_air);
- } else if (type_body_len_p == 2) {
- uint32_t first_inst = sema->code.extra[param_payload + 2];
- ZirInstTag first_tag = sema->code.inst_tags[first_inst];
- if (first_tag == ZIR_INST_PTR_TYPE) {
- uint8_t zir_flags
- = sema->code.inst_datas[first_inst].ptr_type.flags;
- uint8_t zir_size
- = sema->code.inst_datas[first_inst].ptr_type.size;
- uint32_t pi
- = sema->code.inst_datas[first_inst].ptr_type.payload_index;
- ZirInstRef elem_ty_ref = sema->code.extra[pi];
- AirInstRef elem_air = resolveInst(sema, elem_ty_ref);
- TypeIndex elem_ty = AIR_REF_TO_IP(elem_air);
- uint32_t ip_flags = (uint32_t)zir_size & PTR_FLAGS_SIZE_MASK;
- if (!(zir_flags & 0x02))
- ip_flags |= PTR_FLAGS_IS_CONST;
- InternPoolKey key;
- memset(&key, 0, sizeof(key));
- key.tag = IP_KEY_PTR_TYPE;
- key.data.ptr_type.child = elem_ty;
- key.data.ptr_type.flags = ip_flags;
- param_ty = ipIntern(sema->ip, key);
+ if (call_arg_types && arg_index < call_args_len) {
+ param_ty = call_arg_types[arg_index];
+ } else {
+ uint32_t param_payload
+ = sema->code.inst_datas[param_inst].pl_tok.payload_index;
+ uint32_t type_packed = sema->code.extra[param_payload + 1];
+ uint32_t type_body_len_p = type_packed & 0x7FFFFFFF;
+
+ if (type_body_len_p == 1) {
+ uint32_t type_inst = sema->code.extra[param_payload + 2];
+ ZirInstTag type_tag = sema->code.inst_tags[type_inst];
+ assert(type_tag == ZIR_INST_BREAK_INLINE);
+ ZirInstRef type_ref
+ = sema->code.inst_datas[type_inst].break_data.operand;
+ AirInstRef type_air = resolveInst(sema, type_ref);
+ param_ty = AIR_REF_TO_IP(type_air);
+ } else if (type_body_len_p == 2) {
+ uint32_t first_inst = sema->code.extra[param_payload + 2];
+ ZirInstTag first_tag = sema->code.inst_tags[first_inst];
+ if (first_tag == ZIR_INST_PTR_TYPE) {
+ uint8_t zir_flags
+ = sema->code.inst_datas[first_inst].ptr_type.flags;
+ uint8_t zir_size
+ = sema->code.inst_datas[first_inst].ptr_type.size;
+ uint32_t pi = sema->code.inst_datas[first_inst]
+ .ptr_type.payload_index;
+ ZirInstRef elem_ty_ref = sema->code.extra[pi];
+ AirInstRef elem_air = resolveInst(sema, elem_ty_ref);
+ TypeIndex elem_ty = AIR_REF_TO_IP(elem_air);
+ uint32_t ip_flags
+ = (uint32_t)zir_size & PTR_FLAGS_SIZE_MASK;
+ if (!(zir_flags & 0x02))
+ ip_flags |= PTR_FLAGS_IS_CONST;
+ InternPoolKey key;
+ memset(&key, 0, sizeof(key));
+ key.tag = IP_KEY_PTR_TYPE;
+ key.data.ptr_type.child = elem_ty;
+ key.data.ptr_type.flags = ip_flags;
+ param_ty = ipIntern(sema->ip, key);
+ }
+ } else if (type_body_len_p > 2) {
+ const uint32_t* type_body
+ = &sema->code.extra[param_payload + 2];
+ instMapEnsureSpaceForBody(
+ &sema->inst_map, type_body, type_body_len_p);
+ fn_block.is_comptime = true;
+ (void)analyzeBodyInner(
+ sema, &fn_block, type_body, type_body_len_p);
+ fn_block.is_comptime = false;
+ ZirInstRef type_operand
+ = sema->code.inst_datas[sema->comptime_break_inst]
+ .break_data.operand;
+ AirInstRef type_air = resolveInst(sema, type_operand);
+ param_ty = AIR_REF_TO_IP(type_air);
}
- } else if (type_body_len_p > 2) {
- const uint32_t* type_body = &sema->code.extra[param_payload + 2];
- instMapEnsureSpaceForBody(
- &sema->inst_map, type_body, type_body_len_p);
- fn_block.is_comptime = true;
- (void)analyzeBodyInner(
- sema, &fn_block, type_body, type_body_len_p);
- fn_block.is_comptime = false;
- ZirInstRef type_operand
- = sema->code.inst_datas[sema->comptime_break_inst]
- .break_data.operand;
- AirInstRef type_air = resolveInst(sema, type_operand);
- param_ty = AIR_REF_TO_IP(type_air);
}
AirInstData arg_data;
@@ -5207,7 +5281,8 @@ static void zirFunc(Sema* sema, SemaBlock* block, uint32_t inst) {
if (!is_exported)
return;
- analyzeFuncBodyAndRecord(sema, block, inst, sema->cur_decl_name, NULL, 0);
+ analyzeFuncBodyAndRecord(
+ sema, block, inst, sema->cur_decl_name, NULL, 0, NULL);
}
// zirStructDecl: process struct_decl extended instruction.
diff --git a/stage0/sema_tests/generic_fn_with_clz.zig b/stage0/sema_tests/generic_fn_with_clz.zig
@@ -0,0 +1,11 @@
+fn doClz(comptime T: type, p: *T) i32 {
+ return @clz(p.*);
+}
+inline fn addf3(comptime T: type, a: T) T {
+ var x: T = a;
+ _ = doClz(T, &x);
+ return x;
+}
+export fn f(a: u32) u32 {
+ return addf3(u32, a);
+}
diff --git a/stage0/sema_tests/generic_fn_with_shl_assign.zig b/stage0/sema_tests/generic_fn_with_shl_assign.zig
@@ -0,0 +1,14 @@
+fn normalize(comptime T: type, sig: *T) i32 {
+ const integerBit: T = 1 << 10;
+ const shift: i32 = @clz(sig.*) - @clz(integerBit);
+ sig.* <<= @intCast(@as(u32, @bitCast(shift)));
+ return @as(i32, 1) - shift;
+}
+inline fn addf3(comptime T: type, a: T) T {
+ var x: T = a;
+ _ = normalize(T, &x);
+ return x;
+}
+export fn f(a: u16) u16 {
+ return addf3(u16, a);
+}