commit 3991c4bf41d9aeb07d034f3f24c1bbb2b0c9f3a5 (tree)
parent b90553615a73c856f8aabe96c73fe7764f54462f
Author: Motiejus <motiejus@jakstys.lt>
Date: Sat, 7 Mar 2026 07:25:36 +0000
sema: extract analyzeArithmetic; split zirCmp/zirCmpEq
- Add analyzeArithmetic() worker matching upstream decomposition
- zirArithmetic() becomes a thin wrapper (extract operands + delegate)
- zirNegate/zirNegateWrap call analyzeArithmetic() directly
- zirNegate gains float branch (emit AIR_INST_NEG for float operands)
- Add zirCmp() for ordered comparisons (lt/lte/gt/gte)
- Add zirCmpEq() for equality comparisons (eq/neq)
- Update dispatch table to call zirCmp/zirCmpEq instead of zirArithmetic
Matches src/Sema.zig function decomposition: zirArithmetic → analyzeArithmetic,
zirCmp, zirCmpEq as separate entry points.
Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
Diffstat:
| M | stage0/sema.c | | | 268 | +++++++++++++++++++++++++++++++++++++++++++++++++------------------------------ |
1 file changed, 167 insertions(+), 101 deletions(-)
diff --git a/stage0/sema.c b/stage0/sema.c
@@ -852,7 +852,9 @@ static AirInstRef zirLoad(Sema* sema, SemaBlock* block, uint32_t inst) {
return semaAddInst(block, AIR_INST_LOAD, data);
}
-// Forward declarations for comptime helpers used by zirNegate.
+// Forward declarations for comptime helpers used by analyzeArithmetic.
+static void add128(uint64_t a_lo, uint64_t a_hi, uint64_t b_lo, uint64_t b_hi,
+ uint64_t* r_lo, uint64_t* r_hi);
static void sub128(uint64_t a_lo, uint64_t a_hi, uint64_t b_lo, uint64_t b_hi,
uint64_t* r_lo, uint64_t* r_hi);
static bool isComptimeIntWide(
@@ -860,25 +862,96 @@ static bool isComptimeIntWide(
static AirInstRef internComptimeInt(
Sema* sema, TypeIndex ty, uint64_t lo, uint64_t hi);
+// analyzeArithmetic: inner worker for arithmetic binary operations.
+// Ported from src/Sema.zig analyzeArithmetic (arithmetic subset).
+// Handles add/sub/mul and their wrap/sat variants.
+static AirInstRef analyzeArithmetic(Sema* sema, SemaBlock* block,
+ AirInstTag air_tag, AirInstRef lhs, AirInstRef rhs) {
+ // Comptime folding: if both operands are comptime integers,
+ // compute the result at comptime (128-bit).
+ uint64_t lhs_lo, lhs_hi, rhs_lo, rhs_hi;
+ bool lhs_neg, rhs_neg;
+ bool lhs_ct = isComptimeIntWide(sema, lhs, &lhs_lo, &lhs_hi, &lhs_neg);
+ bool rhs_ct = isComptimeIntWide(sema, rhs, &rhs_lo, &rhs_hi, &rhs_neg);
+ if (lhs_ct && rhs_ct) {
+ uint64_t r_lo, r_hi;
+ TypeIndex lhs_ty = semaTypeOf(sema, lhs);
+ TypeIndex rhs_ty = semaTypeOf(sema, rhs);
+ TypeIndex result_ty
+ = (lhs_ty == IP_INDEX_COMPTIME_INT_TYPE) ? rhs_ty : lhs_ty;
+ // Coerce both operands to result_ty, matching Zig's analyzeArithmetic
+ // (Sema.zig:16094-16095). Creates typed intermediates (e.g. u16(2)
+ // when coercing comptime_int(2) → u16) as IP side effects.
+ if (result_ty != IP_INDEX_COMPTIME_INT_TYPE) {
+ lhs = semaCoerce(sema, block, result_ty, lhs);
+ rhs = semaCoerce(sema, block, result_ty, rhs);
+ }
+ switch (air_tag) {
+ case AIR_INST_ADD:
+ case AIR_INST_ADD_WRAP:
+ case AIR_INST_ADD_SAT:
+ add128(lhs_lo, lhs_hi, rhs_lo, rhs_hi, &r_lo, &r_hi);
+ break;
+ case AIR_INST_SUB:
+ case AIR_INST_SUB_WRAP:
+ case AIR_INST_SUB_SAT:
+ sub128(lhs_lo, lhs_hi, rhs_lo, rhs_hi, &r_lo, &r_hi);
+ break;
+ case AIR_INST_MUL:
+ case AIR_INST_MUL_WRAP:
+ case AIR_INST_MUL_SAT:
+ // Simple 128-bit mul: only lo*lo (sufficient for small values).
+ r_lo = lhs_lo * rhs_lo;
+ r_hi = lhs_hi * rhs_lo + lhs_lo * rhs_hi;
+ // Add the high part of lo*lo.
+ {
+ uint64_t a_lo = lhs_lo & 0xFFFFFFFFU;
+ uint64_t a_hi2 = lhs_lo >> 32;
+ uint64_t b_lo = rhs_lo & 0xFFFFFFFFU;
+ uint64_t b_hi2 = rhs_lo >> 32;
+ uint64_t cross = a_lo * b_hi2 + a_hi2 * b_lo;
+ r_hi += a_hi2 * b_hi2 + (cross >> 32);
+ uint64_t mid = (a_lo * b_lo >> 32) + (cross & 0xFFFFFFFFU);
+ (void)mid; // carry already in r_hi via full product
+ }
+ break;
+ default:
+ goto emit_runtime;
+ }
+ return internComptimeInt(sema, result_ty, r_lo, r_hi);
+ }
+
+emit_runtime:;
+ TypeIndex peer_ty = semaResolvePeerTypes(sema, lhs, rhs);
+ lhs = semaCoerce(sema, block, peer_ty, lhs);
+ rhs = semaCoerce(sema, block, peer_ty, rhs);
+ AirInstData data;
+ memset(&data, 0, sizeof(data));
+ data.bin_op.lhs = lhs;
+ data.bin_op.rhs = rhs;
+ return semaAddInst(block, air_tag, data);
+}
+
// zirNegate: handle negate ZIR instruction (unary -).
// Ported from src/Sema.zig zirNegate.
-// Lowers to sub(0, operand).
+// For floats: emits AIR_INST_NEG. For ints: lowers to sub(0, operand).
static AirInstRef zirNegate(Sema* sema, SemaBlock* block, uint32_t inst) {
ZirInstRef operand_ref = sema->code.inst_datas[inst].un_node.operand;
AirInstRef operand = resolveInst(sema, operand_ref);
+ TypeIndex ty = semaTypeOf(sema, operand);
- // Comptime folding: negate comptime integer via sub(0, value).
- uint64_t val_lo, val_hi;
- bool val_neg;
- if (isComptimeIntWide(sema, operand, &val_lo, &val_hi, &val_neg)) {
- TypeIndex ty = semaTypeOf(sema, operand);
- uint64_t r_lo, r_hi;
- sub128(0, 0, val_lo, val_hi, &r_lo, &r_hi);
- return internComptimeInt(sema, ty, r_lo, r_hi);
+ // Float negation: emit AIR_INST_NEG directly.
+ // Ported from src/Sema.zig zirNegate float branch.
+ if (floatBits(ty) > 0) {
+ AirInstData data;
+ memset(&data, 0, sizeof(data));
+ data.ty_op.ty_ref = AIR_REF_FROM_IP(ty);
+ data.ty_op.operand = operand;
+ return semaAddInst(block, AIR_INST_NEG, data);
}
- TypeIndex ty = semaTypeOf(sema, operand);
- // Create a zero value of the operand type.
+ // Int negation: splat(rhs_ty, 0) then sub(0, operand).
+ // Ported from src/Sema.zig zirNegate int branch.
InternPoolKey key;
memset(&key, 0, sizeof(key));
key.tag = IP_KEY_INT;
@@ -886,11 +959,7 @@ static AirInstRef zirNegate(Sema* sema, SemaBlock* block, uint32_t inst) {
key.data.int_val.value_lo = 0;
key.data.int_val.is_negative = false;
AirInstRef zero = AIR_REF_FROM_IP(ipIntern(sema->ip, key));
- AirInstData data;
- memset(&data, 0, sizeof(data));
- data.bin_op.lhs = zero;
- data.bin_op.rhs = operand;
- return semaAddInst(block, AIR_INST_SUB, data);
+ return analyzeArithmetic(sema, block, AIR_INST_SUB, zero, operand);
}
// zirNegateWrap: handle negate_wrap ZIR instruction (wrapping unary -).
@@ -907,11 +976,7 @@ static AirInstRef zirNegateWrap(Sema* sema, SemaBlock* block, uint32_t inst) {
key.data.int_val.value_lo = 0;
key.data.int_val.is_negative = false;
AirInstRef zero = AIR_REF_FROM_IP(ipIntern(sema->ip, key));
- AirInstData data;
- memset(&data, 0, sizeof(data));
- data.bin_op.lhs = zero;
- data.bin_op.rhs = operand;
- return semaAddInst(block, AIR_INST_SUB_WRAP, data);
+ return analyzeArithmetic(sema, block, AIR_INST_SUB_WRAP, zero, operand);
}
// analyzeBitNot: inner worker for bitwise NOT.
@@ -1221,69 +1286,79 @@ static AirInstRef zirByteSwap(Sema* sema, SemaBlock* block, uint32_t inst) {
return semaAddInst(block, AIR_INST_BYTE_SWAP, data);
}
-// zirArithmetic: handle add/sub ZIR instructions.
-// Ported from src/Sema.zig zirArithmetic.
-static AirInstRef zirArithmetic(
+// zirCmpEq: handle cmp_eq/cmp_neq ZIR instructions.
+// Ported from src/Sema.zig zirCmpEq.
+static AirInstRef zirCmpEq(
Sema* sema, SemaBlock* block, uint32_t inst, AirInstTag air_tag) {
uint32_t payload_index = sema->code.inst_datas[inst].pl_node.payload_index;
- ZirInstRef zir_lhs = sema->code.extra[payload_index];
- ZirInstRef zir_rhs = sema->code.extra[payload_index + 1];
- AirInstRef lhs = resolveInst(sema, zir_lhs);
- AirInstRef rhs = resolveInst(sema, zir_rhs);
+ AirInstRef lhs = resolveInst(sema, sema->code.extra[payload_index]);
+ AirInstRef rhs = resolveInst(sema, sema->code.extra[payload_index + 1]);
- // Comptime folding: if both operands are comptime integers,
- // compute the result at comptime (128-bit).
+ // Comptime folding for integer operands.
uint64_t lhs_lo, lhs_hi, rhs_lo, rhs_hi;
bool lhs_neg, rhs_neg;
- bool lhs_ct = isComptimeIntWide(sema, lhs, &lhs_lo, &lhs_hi, &lhs_neg);
- bool rhs_ct = isComptimeIntWide(sema, rhs, &rhs_lo, &rhs_hi, &rhs_neg);
- if (lhs_ct && rhs_ct) {
- uint64_t r_lo, r_hi;
+ if (isComptimeIntWide(sema, lhs, &lhs_lo, &lhs_hi, &lhs_neg)
+ && isComptimeIntWide(sema, rhs, &rhs_lo, &rhs_hi, &rhs_neg)) {
+ TypeIndex lhs_ty = semaTypeOf(sema, lhs);
+ TypeIndex rhs_ty = semaTypeOf(sema, rhs);
+ TypeIndex result_ty
+ = (lhs_ty == IP_INDEX_COMPTIME_INT_TYPE) ? rhs_ty : lhs_ty;
+ if (result_ty != IP_INDEX_COMPTIME_INT_TYPE) {
+ lhs = semaCoerce(sema, block, result_ty, lhs);
+ rhs = semaCoerce(sema, block, result_ty, rhs);
+ }
+ bool eq = (lhs_lo == rhs_lo && lhs_hi == rhs_hi);
+ return AIR_REF_FROM_IP((eq == (air_tag == AIR_INST_CMP_EQ))
+ ? IP_INDEX_BOOL_TRUE
+ : IP_INDEX_BOOL_FALSE);
+ }
+
+ // Comptime equality for non-integer IP values (e.g. enum_tag,
+ // enum_literal). Coerce through peer types first so that enum_literal
+ // → enum coercion creates the enum_tag IP entry (matching the Zig
+ // compiler's analyzeCmp → resolvePeerTypes → coerce path).
+ if (AIR_REF_IS_IP(lhs) && AIR_REF_IS_IP(rhs)) {
+ TypeIndex peer_ty = semaResolvePeerTypes(sema, lhs, rhs);
+ lhs = semaCoerce(sema, block, peer_ty, lhs);
+ rhs = semaCoerce(sema, block, peer_ty, rhs);
+ bool eq = (AIR_REF_TO_IP(lhs) == AIR_REF_TO_IP(rhs));
+ return AIR_REF_FROM_IP((eq == (air_tag == AIR_INST_CMP_EQ))
+ ? IP_INDEX_BOOL_TRUE
+ : IP_INDEX_BOOL_FALSE);
+ }
+
+ TypeIndex peer_ty = semaResolvePeerTypes(sema, lhs, rhs);
+ lhs = semaCoerce(sema, block, peer_ty, lhs);
+ rhs = semaCoerce(sema, block, peer_ty, rhs);
+ AirInstData data;
+ memset(&data, 0, sizeof(data));
+ data.bin_op.lhs = lhs;
+ data.bin_op.rhs = rhs;
+ return semaAddInst(block, air_tag, data);
+}
+
+// zirCmp: handle cmp_lt/cmp_lte/cmp_gte/cmp_gt ZIR instructions.
+// Ported from src/Sema.zig zirCmp → analyzeCmp → cmpNumeric.
+static AirInstRef zirCmp(
+ Sema* sema, SemaBlock* block, uint32_t inst, AirInstTag air_tag) {
+ uint32_t payload_index = sema->code.inst_datas[inst].pl_node.payload_index;
+ AirInstRef lhs = resolveInst(sema, sema->code.extra[payload_index]);
+ AirInstRef rhs = resolveInst(sema, sema->code.extra[payload_index + 1]);
+
+ // Comptime folding for integer operands.
+ uint64_t lhs_lo, lhs_hi, rhs_lo, rhs_hi;
+ bool lhs_neg, rhs_neg;
+ if (isComptimeIntWide(sema, lhs, &lhs_lo, &lhs_hi, &lhs_neg)
+ && isComptimeIntWide(sema, rhs, &rhs_lo, &rhs_hi, &rhs_neg)) {
TypeIndex lhs_ty = semaTypeOf(sema, lhs);
TypeIndex rhs_ty = semaTypeOf(sema, rhs);
TypeIndex result_ty
= (lhs_ty == IP_INDEX_COMPTIME_INT_TYPE) ? rhs_ty : lhs_ty;
- // Coerce both operands to result_ty, matching Zig's analyzeArithmetic
- // (Sema.zig:16094-16095). Creates typed intermediates (e.g. u16(2)
- // when coercing comptime_int(2) → u16) as IP side effects.
if (result_ty != IP_INDEX_COMPTIME_INT_TYPE) {
lhs = semaCoerce(sema, block, result_ty, lhs);
rhs = semaCoerce(sema, block, result_ty, rhs);
}
switch (air_tag) {
- case AIR_INST_ADD:
- case AIR_INST_ADD_WRAP:
- add128(lhs_lo, lhs_hi, rhs_lo, rhs_hi, &r_lo, &r_hi);
- break;
- case AIR_INST_SUB:
- case AIR_INST_SUB_WRAP:
- sub128(lhs_lo, lhs_hi, rhs_lo, rhs_hi, &r_lo, &r_hi);
- break;
- case AIR_INST_MUL:
- case AIR_INST_MUL_WRAP:
- // Simple 128-bit mul: only lo*lo (sufficient for small values).
- r_lo = lhs_lo * rhs_lo;
- r_hi = lhs_hi * rhs_lo + lhs_lo * rhs_hi;
- // Add the high part of lo*lo.
- {
- uint64_t a_lo = lhs_lo & 0xFFFFFFFFU;
- uint64_t a_hi2 = lhs_lo >> 32;
- uint64_t b_lo = rhs_lo & 0xFFFFFFFFU;
- uint64_t b_hi2 = rhs_lo >> 32;
- uint64_t cross = a_lo * b_hi2 + a_hi2 * b_lo;
- r_hi += a_hi2 * b_hi2 + (cross >> 32);
- uint64_t mid = (a_lo * b_lo >> 32) + (cross & 0xFFFFFFFFU);
- (void)mid; // carry already in r_hi via full product
- }
- break;
- case AIR_INST_CMP_EQ:
- return AIR_REF_FROM_IP((lhs_lo == rhs_lo && lhs_hi == rhs_hi)
- ? IP_INDEX_BOOL_TRUE
- : IP_INDEX_BOOL_FALSE);
- case AIR_INST_CMP_NEQ:
- return AIR_REF_FROM_IP((lhs_lo != rhs_lo || lhs_hi != rhs_hi)
- ? IP_INDEX_BOOL_TRUE
- : IP_INDEX_BOOL_FALSE);
case AIR_INST_CMP_LT:
return AIR_REF_FROM_IP(
(lhs_hi < rhs_hi || (lhs_hi == rhs_hi && lhs_lo < rhs_lo))
@@ -1305,42 +1380,24 @@ static AirInstRef zirArithmetic(
? IP_INDEX_BOOL_TRUE
: IP_INDEX_BOOL_FALSE);
default:
- goto emit_runtime;
+ break;
}
- return internComptimeInt(sema, result_ty, r_lo, r_hi);
}
- // Comptime equality for non-integer IP values (e.g. enum_tag,
- // enum_literal). Coerce through peer types first so that enum_literal
- // → enum coercion creates the enum_tag IP entry (matching the Zig
- // compiler's analyzeCmp → resolvePeerTypes → coerce path).
- if (AIR_REF_IS_IP(lhs) && AIR_REF_IS_IP(rhs)
- && (air_tag == AIR_INST_CMP_EQ || air_tag == AIR_INST_CMP_NEQ)) {
- TypeIndex peer_ty = semaResolvePeerTypes(sema, lhs, rhs);
- lhs = semaCoerce(sema, block, peer_ty, lhs);
- rhs = semaCoerce(sema, block, peer_ty, rhs);
- bool eq = (AIR_REF_TO_IP(lhs) == AIR_REF_TO_IP(rhs));
- return AIR_REF_FROM_IP((eq == (air_tag == AIR_INST_CMP_EQ))
- ? IP_INDEX_BOOL_TRUE
- : IP_INDEX_BOOL_FALSE);
- }
-
-emit_runtime:;
// Ported from Sema.zig cmpNumeric → compareIntsOnlyPossibleResult
// (line 32299-32325): when one operand is comptime-known and the
// other is a runtime integer, intern the type's min/max bounds.
// These entries are side effects of the bounds check; even if the
// comparison can't be folded, the boundary values must be interned
// to match the Zig compiler's IP entry sequence.
- if (air_tag >= AIR_INST_CMP_LT && air_tag <= AIR_INST_CMP_GT) {
+ {
TypeIndex lhs_ty = semaTypeOf(sema, lhs);
TypeIndex rhs_ty = semaTypeOf(sema, rhs);
bool lhs_is_ct = (lhs_ty == IP_INDEX_COMPTIME_INT_TYPE);
bool rhs_is_ct = (rhs_ty == IP_INDEX_COMPTIME_INT_TYPE);
if ((lhs_is_ct != rhs_is_ct) && !lhs_is_ct
&& sema->ip->items[lhs_ty].tag == IP_KEY_INT_TYPE) {
- // RHS is comptime, LHS is runtime int → intern LHS
- // type bounds.
+ // RHS is comptime, LHS is runtime int → intern LHS type bounds.
uint16_t bits = sema->ip->items[lhs_ty].data.int_type.bits;
bool is_signed = sema->ip->items[lhs_ty].data.int_type.signedness;
uint64_t min_val = 0;
@@ -1362,8 +1419,7 @@ emit_runtime:;
(void)ipIntern(sema->ip, mk);
} else if ((lhs_is_ct != rhs_is_ct) && !rhs_is_ct
&& sema->ip->items[rhs_ty].tag == IP_KEY_INT_TYPE) {
- // LHS is comptime, RHS is runtime int → intern RHS
- // type bounds.
+ // LHS is comptime, RHS is runtime int → intern RHS type bounds.
uint16_t bits = sema->ip->items[rhs_ty].data.int_type.bits;
bool is_signed = sema->ip->items[rhs_ty].data.int_type.signedness;
uint64_t min_val = 0;
@@ -1396,6 +1452,16 @@ emit_runtime:;
return semaAddInst(block, air_tag, data);
}
+// zirArithmetic: handle add/sub/mul/sat ZIR instructions.
+// Ported from src/Sema.zig zirArithmetic.
+static AirInstRef zirArithmetic(
+ Sema* sema, SemaBlock* block, uint32_t inst, AirInstTag air_tag) {
+ uint32_t payload_index = sema->code.inst_datas[inst].pl_node.payload_index;
+ AirInstRef lhs = resolveInst(sema, sema->code.extra[payload_index]);
+ AirInstRef rhs = resolveInst(sema, sema->code.extra[payload_index + 1]);
+ return analyzeArithmetic(sema, block, air_tag, lhs, rhs);
+}
+
// zirDiv: handle div ZIR instruction (/ operator).
// Ported from src/Sema.zig zirDiv.
// For floats: emits AIR_INST_DIV_FLOAT (strict mode).
@@ -11224,35 +11290,35 @@ bool analyzeBodyInner(
i++;
continue;
- // Comparisons: same binary pattern as arithmetic.
+ // Comparisons: zirCmp for ordered, zirCmpEq for equality.
case ZIR_INST_CMP_LT:
instMapPut(&sema->inst_map, inst,
- zirArithmetic(sema, block, inst, AIR_INST_CMP_LT));
+ zirCmp(sema, block, inst, AIR_INST_CMP_LT));
i++;
continue;
case ZIR_INST_CMP_LTE:
instMapPut(&sema->inst_map, inst,
- zirArithmetic(sema, block, inst, AIR_INST_CMP_LTE));
+ zirCmp(sema, block, inst, AIR_INST_CMP_LTE));
i++;
continue;
case ZIR_INST_CMP_EQ:
instMapPut(&sema->inst_map, inst,
- zirArithmetic(sema, block, inst, AIR_INST_CMP_EQ));
+ zirCmpEq(sema, block, inst, AIR_INST_CMP_EQ));
i++;
continue;
case ZIR_INST_CMP_GTE:
instMapPut(&sema->inst_map, inst,
- zirArithmetic(sema, block, inst, AIR_INST_CMP_GTE));
+ zirCmp(sema, block, inst, AIR_INST_CMP_GTE));
i++;
continue;
case ZIR_INST_CMP_GT:
instMapPut(&sema->inst_map, inst,
- zirArithmetic(sema, block, inst, AIR_INST_CMP_GT));
+ zirCmp(sema, block, inst, AIR_INST_CMP_GT));
i++;
continue;
case ZIR_INST_CMP_NEQ:
instMapPut(&sema->inst_map, inst,
- zirArithmetic(sema, block, inst, AIR_INST_CMP_NEQ));
+ zirCmpEq(sema, block, inst, AIR_INST_CMP_NEQ));
i++;
continue;