sema: extract analyzeArithmetic; split zirCmp/zirCmpEq

- Add analyzeArithmetic() worker matching upstream decomposition
- zirArithmetic() becomes a thin wrapper (extract operands + delegate)
- zirNegate/zirNegateWrap call analyzeArithmetic() directly
- zirNegate gains float branch (emit AIR_INST_NEG for float operands)
- Add zirCmp() for ordered comparisons (lt/lte/gt/gte)
- Add zirCmpEq() for equality comparisons (eq/neq)
- Update dispatch table to call zirCmp/zirCmpEq instead of zirArithmetic

Matches src/Sema.zig function decomposition: zirArithmetic → analyzeArithmetic,
zirCmp, zirCmpEq as separate entry points.

Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-07 07:25:36 +00:00
parent b90553615a
commit 3991c4bf41

View File

@@ -852,7 +852,9 @@ static AirInstRef zirLoad(Sema* sema, SemaBlock* block, uint32_t inst) {
return semaAddInst(block, AIR_INST_LOAD, data);
}
// Forward declarations for comptime helpers used by zirNegate.
// Forward declarations for comptime helpers used by analyzeArithmetic.
static void add128(uint64_t a_lo, uint64_t a_hi, uint64_t b_lo, uint64_t b_hi,
uint64_t* r_lo, uint64_t* r_hi);
static void sub128(uint64_t a_lo, uint64_t a_hi, uint64_t b_lo, uint64_t b_hi,
uint64_t* r_lo, uint64_t* r_hi);
static bool isComptimeIntWide(
@@ -860,25 +862,96 @@ static bool isComptimeIntWide(
static AirInstRef internComptimeInt(
Sema* sema, TypeIndex ty, uint64_t lo, uint64_t hi);
// analyzeArithmetic: inner worker for arithmetic binary operations.
// Ported from src/Sema.zig analyzeArithmetic (arithmetic subset).
// Handles add/sub/mul and their wrap/sat variants.
//
// sema/block: analysis context and the block that receives any emitted
//             instruction.
// air_tag:    AIR opcode to fold or emit.
// lhs/rhs:    already-resolved operands.
//
// Returns an interned comptime integer when both operands are comptime
// integers and air_tag is one of the add/sub/mul tags; otherwise returns
// the ref of a newly emitted bin_op instruction with both operands coerced
// to their peer type.
//
// NOTE(review): the sat variants are folded with the same 128-bit
// wrap-around arithmetic as the plain variants, and the sign flags
// reported by isComptimeIntWide are fetched but not consulted here --
// confirm this is intended.
static AirInstRef analyzeArithmetic(Sema* sema, SemaBlock* block,
    AirInstTag air_tag, AirInstRef lhs, AirInstRef rhs) {
    // Comptime folding: if both operands are comptime integers,
    // compute the result at comptime (128-bit).
    uint64_t lhs_lo, lhs_hi, rhs_lo, rhs_hi;
    bool lhs_neg, rhs_neg;
    bool lhs_ct = isComptimeIntWide(sema, lhs, &lhs_lo, &lhs_hi, &lhs_neg);
    bool rhs_ct = isComptimeIntWide(sema, rhs, &rhs_lo, &rhs_hi, &rhs_neg);
    if (lhs_ct && rhs_ct) {
        uint64_t r_lo, r_hi;
        TypeIndex lhs_ty = semaTypeOf(sema, lhs);
        TypeIndex rhs_ty = semaTypeOf(sema, rhs);
        // A comptime_int operand adopts the other operand's type.
        TypeIndex result_ty
            = (lhs_ty == IP_INDEX_COMPTIME_INT_TYPE) ? rhs_ty : lhs_ty;
        // Coerce both operands to result_ty, matching Zig's
        // analyzeArithmetic (Sema.zig:16094-16095). Creates typed
        // intermediates (e.g. u16(2) when coercing comptime_int(2) -> u16)
        // as IP side effects.
        if (result_ty != IP_INDEX_COMPTIME_INT_TYPE) {
            lhs = semaCoerce(sema, block, result_ty, lhs);
            rhs = semaCoerce(sema, block, result_ty, rhs);
        }
        switch (air_tag) {
        case AIR_INST_ADD:
        case AIR_INST_ADD_WRAP:
        case AIR_INST_ADD_SAT:
            add128(lhs_lo, lhs_hi, rhs_lo, rhs_hi, &r_lo, &r_hi);
            break;
        case AIR_INST_SUB:
        case AIR_INST_SUB_WRAP:
        case AIR_INST_SUB_SAT:
            sub128(lhs_lo, lhs_hi, rhs_lo, rhs_hi, &r_lo, &r_hi);
            break;
        case AIR_INST_MUL:
        case AIR_INST_MUL_WRAP:
        case AIR_INST_MUL_SAT: {
            // 128-bit product truncated to 128 bits:
            //   r_lo = low64(lhs_lo * rhs_lo)
            //   r_hi = lhs_hi*rhs_lo + lhs_lo*rhs_hi + high64(lhs_lo*rhs_lo)
            // high64 is built from four 32x32->64 partial products so that
            // no intermediate sum can overflow and no carry is lost.
            uint64_t a_lo = lhs_lo & 0xFFFFFFFFU;
            uint64_t a_hi = lhs_lo >> 32;
            uint64_t b_lo = rhs_lo & 0xFFFFFFFFU;
            uint64_t b_hi = rhs_lo >> 32;
            uint64_t p_ll = a_lo * b_lo;
            uint64_t p_lh = a_lo * b_hi;
            uint64_t p_hl = a_hi * b_lo;
            uint64_t p_hh = a_hi * b_hi;
            // Middle column: at most 3 * (2^32 - 1), so it fits in 64 bits;
            // its upper half is the carry into the high word.
            uint64_t mid = (p_ll >> 32) + (p_lh & 0xFFFFFFFFU)
                + (p_hl & 0xFFFFFFFFU);
            uint64_t lo_lo_high
                = p_hh + (p_lh >> 32) + (p_hl >> 32) + (mid >> 32);
            r_lo = lhs_lo * rhs_lo;
            r_hi = lhs_hi * rhs_lo + lhs_lo * rhs_hi + lo_lo_high;
            break;
        }
        default:
            // Not a foldable tag: fall back to emitting a runtime
            // instruction (operands may already have been coerced above).
            goto emit_runtime;
        }
        return internComptimeInt(sema, result_ty, r_lo, r_hi);
    }

emit_runtime:;
    // Runtime path: unify operand types, then emit a bin_op instruction.
    TypeIndex peer_ty = semaResolvePeerTypes(sema, lhs, rhs);
    lhs = semaCoerce(sema, block, peer_ty, lhs);
    rhs = semaCoerce(sema, block, peer_ty, rhs);
    AirInstData data;
    memset(&data, 0, sizeof(data));
    data.bin_op.lhs = lhs;
    data.bin_op.rhs = rhs;
    return semaAddInst(block, air_tag, data);
}
// zirNegate: handle negate ZIR instruction (unary -).
// Ported from src/Sema.zig zirNegate.
// Lowers to sub(0, operand).
// For floats: emits AIR_INST_NEG. For ints: lowers to sub(0, operand).
static AirInstRef zirNegate(Sema* sema, SemaBlock* block, uint32_t inst) {
ZirInstRef operand_ref = sema->code.inst_datas[inst].un_node.operand;
AirInstRef operand = resolveInst(sema, operand_ref);
TypeIndex ty = semaTypeOf(sema, operand);
// Comptime folding: negate comptime integer via sub(0, value).
uint64_t val_lo, val_hi;
bool val_neg;
if (isComptimeIntWide(sema, operand, &val_lo, &val_hi, &val_neg)) {
TypeIndex ty = semaTypeOf(sema, operand);
uint64_t r_lo, r_hi;
sub128(0, 0, val_lo, val_hi, &r_lo, &r_hi);
return internComptimeInt(sema, ty, r_lo, r_hi);
// Float negation: emit AIR_INST_NEG directly.
// Ported from src/Sema.zig zirNegate float branch.
if (floatBits(ty) > 0) {
AirInstData data;
memset(&data, 0, sizeof(data));
data.ty_op.ty_ref = AIR_REF_FROM_IP(ty);
data.ty_op.operand = operand;
return semaAddInst(block, AIR_INST_NEG, data);
}
TypeIndex ty = semaTypeOf(sema, operand);
// Create a zero value of the operand type.
// Int negation: splat(rhs_ty, 0) then sub(0, operand).
// Ported from src/Sema.zig zirNegate int branch.
InternPoolKey key;
memset(&key, 0, sizeof(key));
key.tag = IP_KEY_INT;
@@ -886,11 +959,7 @@ static AirInstRef zirNegate(Sema* sema, SemaBlock* block, uint32_t inst) {
key.data.int_val.value_lo = 0;
key.data.int_val.is_negative = false;
AirInstRef zero = AIR_REF_FROM_IP(ipIntern(sema->ip, key));
AirInstData data;
memset(&data, 0, sizeof(data));
data.bin_op.lhs = zero;
data.bin_op.rhs = operand;
return semaAddInst(block, AIR_INST_SUB, data);
return analyzeArithmetic(sema, block, AIR_INST_SUB, zero, operand);
}
// zirNegateWrap: handle negate_wrap ZIR instruction (wrapping unary -).
@@ -907,11 +976,7 @@ static AirInstRef zirNegateWrap(Sema* sema, SemaBlock* block, uint32_t inst) {
key.data.int_val.value_lo = 0;
key.data.int_val.is_negative = false;
AirInstRef zero = AIR_REF_FROM_IP(ipIntern(sema->ip, key));
AirInstData data;
memset(&data, 0, sizeof(data));
data.bin_op.lhs = zero;
data.bin_op.rhs = operand;
return semaAddInst(block, AIR_INST_SUB_WRAP, data);
return analyzeArithmetic(sema, block, AIR_INST_SUB_WRAP, zero, operand);
}
// analyzeBitNot: inner worker for bitwise NOT.
@@ -1221,69 +1286,79 @@ static AirInstRef zirByteSwap(Sema* sema, SemaBlock* block, uint32_t inst) {
return semaAddInst(block, AIR_INST_BYTE_SWAP, data);
}
// zirArithmetic: handle add/sub ZIR instructions.
// Ported from src/Sema.zig zirArithmetic.
static AirInstRef zirArithmetic(
// zirCmpEq: handle cmp_eq/cmp_neq ZIR instructions.
// Ported from src/Sema.zig zirCmpEq.
static AirInstRef zirCmpEq(
Sema* sema, SemaBlock* block, uint32_t inst, AirInstTag air_tag) {
uint32_t payload_index = sema->code.inst_datas[inst].pl_node.payload_index;
ZirInstRef zir_lhs = sema->code.extra[payload_index];
ZirInstRef zir_rhs = sema->code.extra[payload_index + 1];
AirInstRef lhs = resolveInst(sema, zir_lhs);
AirInstRef rhs = resolveInst(sema, zir_rhs);
AirInstRef lhs = resolveInst(sema, sema->code.extra[payload_index]);
AirInstRef rhs = resolveInst(sema, sema->code.extra[payload_index + 1]);
// Comptime folding: if both operands are comptime integers,
// compute the result at comptime (128-bit).
// Comptime folding for integer operands.
uint64_t lhs_lo, lhs_hi, rhs_lo, rhs_hi;
bool lhs_neg, rhs_neg;
bool lhs_ct = isComptimeIntWide(sema, lhs, &lhs_lo, &lhs_hi, &lhs_neg);
bool rhs_ct = isComptimeIntWide(sema, rhs, &rhs_lo, &rhs_hi, &rhs_neg);
if (lhs_ct && rhs_ct) {
uint64_t r_lo, r_hi;
if (isComptimeIntWide(sema, lhs, &lhs_lo, &lhs_hi, &lhs_neg)
&& isComptimeIntWide(sema, rhs, &rhs_lo, &rhs_hi, &rhs_neg)) {
TypeIndex lhs_ty = semaTypeOf(sema, lhs);
TypeIndex rhs_ty = semaTypeOf(sema, rhs);
TypeIndex result_ty
= (lhs_ty == IP_INDEX_COMPTIME_INT_TYPE) ? rhs_ty : lhs_ty;
if (result_ty != IP_INDEX_COMPTIME_INT_TYPE) {
lhs = semaCoerce(sema, block, result_ty, lhs);
rhs = semaCoerce(sema, block, result_ty, rhs);
}
bool eq = (lhs_lo == rhs_lo && lhs_hi == rhs_hi);
return AIR_REF_FROM_IP((eq == (air_tag == AIR_INST_CMP_EQ))
? IP_INDEX_BOOL_TRUE
: IP_INDEX_BOOL_FALSE);
}
// Comptime equality for non-integer IP values (e.g. enum_tag,
// enum_literal). Coerce through peer types first so that enum_literal
// → enum coercion creates the enum_tag IP entry (matching the Zig
// compiler's analyzeCmp → resolvePeerTypes → coerce path).
if (AIR_REF_IS_IP(lhs) && AIR_REF_IS_IP(rhs)) {
TypeIndex peer_ty = semaResolvePeerTypes(sema, lhs, rhs);
lhs = semaCoerce(sema, block, peer_ty, lhs);
rhs = semaCoerce(sema, block, peer_ty, rhs);
bool eq = (AIR_REF_TO_IP(lhs) == AIR_REF_TO_IP(rhs));
return AIR_REF_FROM_IP((eq == (air_tag == AIR_INST_CMP_EQ))
? IP_INDEX_BOOL_TRUE
: IP_INDEX_BOOL_FALSE);
}
TypeIndex peer_ty = semaResolvePeerTypes(sema, lhs, rhs);
lhs = semaCoerce(sema, block, peer_ty, lhs);
rhs = semaCoerce(sema, block, peer_ty, rhs);
AirInstData data;
memset(&data, 0, sizeof(data));
data.bin_op.lhs = lhs;
data.bin_op.rhs = rhs;
return semaAddInst(block, air_tag, data);
}
// zirCmp: handle cmp_lt/cmp_lte/cmp_gte/cmp_gt ZIR instructions.
// Ported from src/Sema.zig zirCmp → analyzeCmp → cmpNumeric.
static AirInstRef zirCmp(
Sema* sema, SemaBlock* block, uint32_t inst, AirInstTag air_tag) {
uint32_t payload_index = sema->code.inst_datas[inst].pl_node.payload_index;
AirInstRef lhs = resolveInst(sema, sema->code.extra[payload_index]);
AirInstRef rhs = resolveInst(sema, sema->code.extra[payload_index + 1]);
// Comptime folding for integer operands.
uint64_t lhs_lo, lhs_hi, rhs_lo, rhs_hi;
bool lhs_neg, rhs_neg;
if (isComptimeIntWide(sema, lhs, &lhs_lo, &lhs_hi, &lhs_neg)
&& isComptimeIntWide(sema, rhs, &rhs_lo, &rhs_hi, &rhs_neg)) {
TypeIndex lhs_ty = semaTypeOf(sema, lhs);
TypeIndex rhs_ty = semaTypeOf(sema, rhs);
TypeIndex result_ty
= (lhs_ty == IP_INDEX_COMPTIME_INT_TYPE) ? rhs_ty : lhs_ty;
// Coerce both operands to result_ty, matching Zig's analyzeArithmetic
// (Sema.zig:16094-16095). Creates typed intermediates (e.g. u16(2)
// when coercing comptime_int(2) → u16) as IP side effects.
if (result_ty != IP_INDEX_COMPTIME_INT_TYPE) {
lhs = semaCoerce(sema, block, result_ty, lhs);
rhs = semaCoerce(sema, block, result_ty, rhs);
}
switch (air_tag) {
case AIR_INST_ADD:
case AIR_INST_ADD_WRAP:
add128(lhs_lo, lhs_hi, rhs_lo, rhs_hi, &r_lo, &r_hi);
break;
case AIR_INST_SUB:
case AIR_INST_SUB_WRAP:
sub128(lhs_lo, lhs_hi, rhs_lo, rhs_hi, &r_lo, &r_hi);
break;
case AIR_INST_MUL:
case AIR_INST_MUL_WRAP:
// Simple 128-bit mul: only lo*lo (sufficient for small values).
r_lo = lhs_lo * rhs_lo;
r_hi = lhs_hi * rhs_lo + lhs_lo * rhs_hi;
// Add the high part of lo*lo.
{
uint64_t a_lo = lhs_lo & 0xFFFFFFFFU;
uint64_t a_hi2 = lhs_lo >> 32;
uint64_t b_lo = rhs_lo & 0xFFFFFFFFU;
uint64_t b_hi2 = rhs_lo >> 32;
uint64_t cross = a_lo * b_hi2 + a_hi2 * b_lo;
r_hi += a_hi2 * b_hi2 + (cross >> 32);
uint64_t mid = (a_lo * b_lo >> 32) + (cross & 0xFFFFFFFFU);
(void)mid; // carry already in r_hi via full product
}
break;
case AIR_INST_CMP_EQ:
return AIR_REF_FROM_IP((lhs_lo == rhs_lo && lhs_hi == rhs_hi)
? IP_INDEX_BOOL_TRUE
: IP_INDEX_BOOL_FALSE);
case AIR_INST_CMP_NEQ:
return AIR_REF_FROM_IP((lhs_lo != rhs_lo || lhs_hi != rhs_hi)
? IP_INDEX_BOOL_TRUE
: IP_INDEX_BOOL_FALSE);
case AIR_INST_CMP_LT:
return AIR_REF_FROM_IP(
(lhs_hi < rhs_hi || (lhs_hi == rhs_hi && lhs_lo < rhs_lo))
@@ -1305,42 +1380,24 @@ static AirInstRef zirArithmetic(
? IP_INDEX_BOOL_TRUE
: IP_INDEX_BOOL_FALSE);
default:
goto emit_runtime;
break;
}
return internComptimeInt(sema, result_ty, r_lo, r_hi);
}
// Comptime equality for non-integer IP values (e.g. enum_tag,
// enum_literal). Coerce through peer types first so that enum_literal
// → enum coercion creates the enum_tag IP entry (matching the Zig
// compiler's analyzeCmp → resolvePeerTypes → coerce path).
if (AIR_REF_IS_IP(lhs) && AIR_REF_IS_IP(rhs)
&& (air_tag == AIR_INST_CMP_EQ || air_tag == AIR_INST_CMP_NEQ)) {
TypeIndex peer_ty = semaResolvePeerTypes(sema, lhs, rhs);
lhs = semaCoerce(sema, block, peer_ty, lhs);
rhs = semaCoerce(sema, block, peer_ty, rhs);
bool eq = (AIR_REF_TO_IP(lhs) == AIR_REF_TO_IP(rhs));
return AIR_REF_FROM_IP((eq == (air_tag == AIR_INST_CMP_EQ))
? IP_INDEX_BOOL_TRUE
: IP_INDEX_BOOL_FALSE);
}
emit_runtime:;
// Ported from Sema.zig cmpNumeric → compareIntsOnlyPossibleResult
// (line 32299-32325): when one operand is comptime-known and the
// other is a runtime integer, intern the type's min/max bounds.
// These entries are side effects of the bounds check; even if the
// comparison can't be folded, the boundary values must be interned
// to match the Zig compiler's IP entry sequence.
if (air_tag >= AIR_INST_CMP_LT && air_tag <= AIR_INST_CMP_GT) {
{
TypeIndex lhs_ty = semaTypeOf(sema, lhs);
TypeIndex rhs_ty = semaTypeOf(sema, rhs);
bool lhs_is_ct = (lhs_ty == IP_INDEX_COMPTIME_INT_TYPE);
bool rhs_is_ct = (rhs_ty == IP_INDEX_COMPTIME_INT_TYPE);
if ((lhs_is_ct != rhs_is_ct) && !lhs_is_ct
&& sema->ip->items[lhs_ty].tag == IP_KEY_INT_TYPE) {
// RHS is comptime, LHS is runtime int → intern LHS
// type bounds.
// RHS is comptime, LHS is runtime int → intern LHS type bounds.
uint16_t bits = sema->ip->items[lhs_ty].data.int_type.bits;
bool is_signed = sema->ip->items[lhs_ty].data.int_type.signedness;
uint64_t min_val = 0;
@@ -1362,8 +1419,7 @@ emit_runtime:;
(void)ipIntern(sema->ip, mk);
} else if ((lhs_is_ct != rhs_is_ct) && !rhs_is_ct
&& sema->ip->items[rhs_ty].tag == IP_KEY_INT_TYPE) {
// LHS is comptime, RHS is runtime int → intern RHS
// type bounds.
// LHS is comptime, RHS is runtime int → intern RHS type bounds.
uint16_t bits = sema->ip->items[rhs_ty].data.int_type.bits;
bool is_signed = sema->ip->items[rhs_ty].data.int_type.signedness;
uint64_t min_val = 0;
@@ -1396,6 +1452,16 @@ emit_runtime:;
return semaAddInst(block, air_tag, data);
}
// zirArithmetic: entry point for the add/sub/mul/sat ZIR instructions.
// Ported from src/Sema.zig zirArithmetic.
// Reads the two operand refs stored at the instruction's pl_node payload
// index, resolves them (lhs first, then rhs), and delegates the actual
// folding/emission to analyzeArithmetic with the given air_tag.
static AirInstRef zirArithmetic(
    Sema* sema, SemaBlock* block, uint32_t inst, AirInstTag air_tag) {
    uint32_t extra_at = sema->code.inst_datas[inst].pl_node.payload_index;
    ZirInstRef zir_lhs = sema->code.extra[extra_at];
    ZirInstRef zir_rhs = sema->code.extra[extra_at + 1];
    AirInstRef resolved_lhs = resolveInst(sema, zir_lhs);
    AirInstRef resolved_rhs = resolveInst(sema, zir_rhs);
    return analyzeArithmetic(
        sema, block, air_tag, resolved_lhs, resolved_rhs);
}
// zirDiv: handle div ZIR instruction (/ operator).
// Ported from src/Sema.zig zirDiv.
// For floats: emits AIR_INST_DIV_FLOAT (strict mode).
@@ -11224,35 +11290,35 @@ bool analyzeBodyInner(
i++;
continue;
// Comparisons: same binary pattern as arithmetic.
// Comparisons: zirCmp for ordered, zirCmpEq for equality.
case ZIR_INST_CMP_LT:
instMapPut(&sema->inst_map, inst,
zirArithmetic(sema, block, inst, AIR_INST_CMP_LT));
zirCmp(sema, block, inst, AIR_INST_CMP_LT));
i++;
continue;
case ZIR_INST_CMP_LTE:
instMapPut(&sema->inst_map, inst,
zirArithmetic(sema, block, inst, AIR_INST_CMP_LTE));
zirCmp(sema, block, inst, AIR_INST_CMP_LTE));
i++;
continue;
case ZIR_INST_CMP_EQ:
instMapPut(&sema->inst_map, inst,
zirArithmetic(sema, block, inst, AIR_INST_CMP_EQ));
zirCmpEq(sema, block, inst, AIR_INST_CMP_EQ));
i++;
continue;
case ZIR_INST_CMP_GTE:
instMapPut(&sema->inst_map, inst,
zirArithmetic(sema, block, inst, AIR_INST_CMP_GTE));
zirCmp(sema, block, inst, AIR_INST_CMP_GTE));
i++;
continue;
case ZIR_INST_CMP_GT:
instMapPut(&sema->inst_map, inst,
zirArithmetic(sema, block, inst, AIR_INST_CMP_GT));
zirCmp(sema, block, inst, AIR_INST_CMP_GT));
i++;
continue;
case ZIR_INST_CMP_NEQ:
instMapPut(&sema->inst_map, inst,
zirArithmetic(sema, block, inst, AIR_INST_CMP_NEQ));
zirCmpEq(sema, block, inst, AIR_INST_CMP_NEQ));
i++;
continue;