ir: Support shift left/right on vectors
This commit is contained in:
118
src/ir.cpp
118
src/ir.cpp
@@ -283,6 +283,8 @@ static IrInstGen *ir_analyze_union_init(IrAnalyze *ira, IrInst* source_instructi
|
||||
IrInstGen *result_loc);
|
||||
static IrInstGen *ir_analyze_struct_value_field_value(IrAnalyze *ira, IrInst* source_instr,
|
||||
IrInstGen *struct_operand, TypeStructField *field);
|
||||
static bool value_cmp_numeric_val_any(ZigValue *left, Cmp predicate, ZigValue *right);
|
||||
static bool value_cmp_numeric_val_all(ZigValue *left, Cmp predicate, ZigValue *right);
|
||||
|
||||
static void destroy_instruction_src(IrInstSrc *inst) {
|
||||
switch (inst->id) {
|
||||
@@ -16803,7 +16805,6 @@ static IrInstGen *ir_analyze_math_op(IrAnalyze *ira, IrInst* source_instr,
|
||||
ZigValue *scalar_op2_val = &op2_val->data.x_array.data.s_none.elements[i];
|
||||
ZigValue *scalar_out_val = &out_val->data.x_array.data.s_none.elements[i];
|
||||
assert(scalar_op1_val->type == scalar_type);
|
||||
assert(scalar_op2_val->type == scalar_type);
|
||||
assert(scalar_out_val->type == scalar_type);
|
||||
ErrorMsg *msg = ir_eval_math_op_scalar(ira, source_instr, scalar_type,
|
||||
scalar_op1_val, op_id, scalar_op2_val, scalar_out_val);
|
||||
@@ -16828,27 +16829,49 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in
|
||||
if (type_is_invalid(op1->value->type))
|
||||
return ira->codegen->invalid_inst_gen;
|
||||
|
||||
if (op1->value->type->id != ZigTypeIdInt && op1->value->type->id != ZigTypeIdComptimeInt) {
|
||||
ir_add_error(ira, &bin_op_instruction->op1->base,
|
||||
buf_sprintf("bit shifting operation expected integer type, found '%s'",
|
||||
buf_ptr(&op1->value->type->name)));
|
||||
return ira->codegen->invalid_inst_gen;
|
||||
}
|
||||
|
||||
IrInstGen *op2 = bin_op_instruction->op2->child;
|
||||
if (type_is_invalid(op2->value->type))
|
||||
return ira->codegen->invalid_inst_gen;
|
||||
|
||||
if (op2->value->type->id != ZigTypeIdInt && op2->value->type->id != ZigTypeIdComptimeInt) {
|
||||
ZigType *op1_type = op1->value->type;
|
||||
ZigType *op2_type = op2->value->type;
|
||||
|
||||
if (op1_type->id == ZigTypeIdVector && op2_type->id != ZigTypeIdVector) {
|
||||
ir_add_error(ira, &bin_op_instruction->op1->base,
|
||||
buf_sprintf("bit shifting operation expected vector type, found '%s'",
|
||||
buf_ptr(&op2_type->name)));
|
||||
return ira->codegen->invalid_inst_gen;
|
||||
}
|
||||
|
||||
if (op1_type->id != ZigTypeIdVector && op2_type->id == ZigTypeIdVector) {
|
||||
ir_add_error(ira, &bin_op_instruction->op1->base,
|
||||
buf_sprintf("bit shifting operation expected vector type, found '%s'",
|
||||
buf_ptr(&op1_type->name)));
|
||||
return ira->codegen->invalid_inst_gen;
|
||||
}
|
||||
|
||||
ZigType *op1_scalar_type = (op1_type->id == ZigTypeIdVector) ?
|
||||
op1_type->data.vector.elem_type : op1_type;
|
||||
ZigType *op2_scalar_type = (op2_type->id == ZigTypeIdVector) ?
|
||||
op2_type->data.vector.elem_type : op2_type;
|
||||
|
||||
if (op1_scalar_type->id != ZigTypeIdInt && op1_scalar_type->id != ZigTypeIdComptimeInt) {
|
||||
ir_add_error(ira, &bin_op_instruction->op1->base,
|
||||
buf_sprintf("bit shifting operation expected integer type, found '%s'",
|
||||
buf_ptr(&op1_scalar_type->name)));
|
||||
return ira->codegen->invalid_inst_gen;
|
||||
}
|
||||
|
||||
if (op2_scalar_type->id != ZigTypeIdInt && op2_scalar_type->id != ZigTypeIdComptimeInt) {
|
||||
ir_add_error(ira, &bin_op_instruction->op2->base,
|
||||
buf_sprintf("shift amount has to be an integer type, but found '%s'",
|
||||
buf_ptr(&op2->value->type->name)));
|
||||
buf_ptr(&op2_scalar_type->name)));
|
||||
return ira->codegen->invalid_inst_gen;
|
||||
}
|
||||
|
||||
IrInstGen *casted_op2;
|
||||
IrBinOp op_id = bin_op_instruction->op_id;
|
||||
if (op1->value->type->id == ZigTypeIdComptimeInt) {
|
||||
if (op1_scalar_type->id == ZigTypeIdComptimeInt) {
|
||||
// comptime_int has no finite bit width
|
||||
casted_op2 = op2;
|
||||
|
||||
@@ -16874,10 +16897,15 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in
|
||||
return ira->codegen->invalid_inst_gen;
|
||||
}
|
||||
} else {
|
||||
const unsigned bit_count = op1->value->type->data.integral.bit_count;
|
||||
const unsigned bit_count = op1_scalar_type->data.integral.bit_count;
|
||||
ZigType *shift_amt_type = get_smallest_unsigned_int_type(ira->codegen,
|
||||
bit_count > 0 ? bit_count - 1 : 0);
|
||||
|
||||
if (op1_type->id == ZigTypeIdVector) {
|
||||
shift_amt_type = get_vector_type(ira->codegen, op1_type->data.vector.len,
|
||||
shift_amt_type);
|
||||
}
|
||||
|
||||
casted_op2 = ir_implicit_cast(ira, op2, shift_amt_type);
|
||||
if (type_is_invalid(casted_op2->value->type))
|
||||
return ira->codegen->invalid_inst_gen;
|
||||
@@ -16888,10 +16916,10 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in
|
||||
if (op2_val == nullptr)
|
||||
return ira->codegen->invalid_inst_gen;
|
||||
|
||||
BigInt bit_count_value = {0};
|
||||
bigint_init_unsigned(&bit_count_value, bit_count);
|
||||
ZigValue bit_count_value;
|
||||
init_const_usize(ira->codegen, &bit_count_value, bit_count);
|
||||
|
||||
if (bigint_cmp(&op2_val->data.x_bigint, &bit_count_value) != CmpLT) {
|
||||
if (!value_cmp_numeric_val_all(op2_val, CmpLT, &bit_count_value)) {
|
||||
ErrorMsg* msg = ir_add_error(ira,
|
||||
&bin_op_instruction->base.base,
|
||||
buf_sprintf("RHS of shift is too large for LHS type"));
|
||||
@@ -16910,7 +16938,7 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in
|
||||
if (op2_val == nullptr)
|
||||
return ira->codegen->invalid_inst_gen;
|
||||
|
||||
if (bigint_cmp_zero(&op2_val->data.x_bigint) == CmpEQ)
|
||||
if (value_cmp_numeric_val_all(op2_val, CmpEQ, nullptr))
|
||||
return ir_analyze_cast(ira, &bin_op_instruction->base.base, op1->value->type, op1);
|
||||
}
|
||||
|
||||
@@ -16923,7 +16951,7 @@ static IrInstGen *ir_analyze_bit_shift(IrAnalyze *ira, IrInstSrcBinOp *bin_op_in
|
||||
if (op2_val == nullptr)
|
||||
return ira->codegen->invalid_inst_gen;
|
||||
|
||||
return ir_analyze_math_op(ira, &bin_op_instruction->base.base, op1->value->type, op1_val, op_id, op2_val);
|
||||
return ir_analyze_math_op(ira, &bin_op_instruction->base.base, op1_type, op1_val, op_id, op2_val);
|
||||
}
|
||||
|
||||
return ir_build_bin_op_gen(ira, &bin_op_instruction->base.base, op1->value->type,
|
||||
@@ -16991,31 +17019,53 @@ static bool is_pointer_arithmetic_allowed(ZigType *lhs_type, IrBinOp op) {
|
||||
zig_unreachable();
|
||||
}
|
||||
|
||||
static bool value_cmp_zero_any(ZigValue *value, Cmp predicate) {
|
||||
assert(value->special == ConstValSpecialStatic);
|
||||
static bool value_cmp_numeric_val(ZigValue *left, Cmp predicate, ZigValue *right, bool any) {
|
||||
assert(left->special == ConstValSpecialStatic);
|
||||
assert(right == nullptr || right->special == ConstValSpecialStatic);
|
||||
|
||||
switch (value->type->id) {
|
||||
switch (left->type->id) {
|
||||
case ZigTypeIdComptimeInt:
|
||||
case ZigTypeIdInt:
|
||||
return bigint_cmp_zero(&value->data.x_bigint) == predicate;
|
||||
case ZigTypeIdInt: {
|
||||
const Cmp result = right ?
|
||||
bigint_cmp(&left->data.x_bigint, &right->data.x_bigint) :
|
||||
bigint_cmp_zero(&left->data.x_bigint);
|
||||
return result == predicate;
|
||||
}
|
||||
case ZigTypeIdComptimeFloat:
|
||||
case ZigTypeIdFloat:
|
||||
if (float_is_nan(value))
|
||||
case ZigTypeIdFloat: {
|
||||
if (float_is_nan(left))
|
||||
return false;
|
||||
return float_cmp_zero(value) == predicate;
|
||||
if (right != nullptr && float_is_nan(right))
|
||||
return false;
|
||||
|
||||
const Cmp result = right ? float_cmp(left, right) : float_cmp_zero(left);
|
||||
return result == predicate;
|
||||
}
|
||||
case ZigTypeIdVector: {
|
||||
for (size_t i = 0; i < value->type->data.vector.len; i++) {
|
||||
ZigValue *scalar_val = &value->data.x_array.data.s_none.elements[i];
|
||||
if (!value_cmp_zero_any(scalar_val, predicate))
|
||||
return true;
|
||||
for (size_t i = 0; i < left->type->data.vector.len; i++) {
|
||||
ZigValue *scalar_val = &left->data.x_array.data.s_none.elements[i];
|
||||
const bool result = value_cmp_numeric_val(scalar_val, predicate, right, any);
|
||||
|
||||
if (any && result)
|
||||
return true; // This element satisfies the predicate
|
||||
else if (!any && !result)
|
||||
return false; // This element doesn't satisfy the predicate
|
||||
}
|
||||
return false;
|
||||
return any ? false : true;
|
||||
}
|
||||
default:
|
||||
zig_unreachable();
|
||||
}
|
||||
}
|
||||
|
||||
static bool value_cmp_numeric_val_any(ZigValue *left, Cmp predicate, ZigValue *right) {
|
||||
return value_cmp_numeric_val(left, predicate, right, true);
|
||||
}
|
||||
|
||||
static bool value_cmp_numeric_val_all(ZigValue *left, Cmp predicate, ZigValue *right) {
|
||||
return value_cmp_numeric_val(left, predicate, right, false);
|
||||
}
|
||||
|
||||
static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruction) {
|
||||
Error err;
|
||||
|
||||
@@ -17165,8 +17215,8 @@ static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruc
|
||||
return ira->codegen->invalid_inst_gen;
|
||||
|
||||
// Promote division with negative numbers to signed
|
||||
bool is_signed_div = value_cmp_zero_any(op1_val, CmpLT) ||
|
||||
value_cmp_zero_any(op2_val, CmpLT);
|
||||
bool is_signed_div = value_cmp_numeric_val_any(op1_val, CmpLT, nullptr) ||
|
||||
value_cmp_numeric_val_any(op2_val, CmpLT, nullptr);
|
||||
|
||||
if (op_id == IrBinOpDivUnspecified && is_int) {
|
||||
// Default to truncating division and check if it's valid for the
|
||||
@@ -17176,7 +17226,7 @@ static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruc
|
||||
if (is_signed_div) {
|
||||
bool ok = false;
|
||||
|
||||
if (value_cmp_zero_any(op2_val, CmpEQ)) {
|
||||
if (value_cmp_numeric_val_any(op2_val, CmpEQ, nullptr)) {
|
||||
// the division by zero error will be caught later, but we don't have a
|
||||
// division function ambiguity problem.
|
||||
ok = true;
|
||||
@@ -17215,7 +17265,7 @@ static IrInstGen *ir_analyze_bin_op_math(IrAnalyze *ira, IrInstSrcBinOp *instruc
|
||||
if (is_signed_div) {
|
||||
bool ok = false;
|
||||
|
||||
if (value_cmp_zero_any(op2_val, CmpEQ)) {
|
||||
if (value_cmp_numeric_val_any(op2_val, CmpEQ, nullptr)) {
|
||||
// the division by zero error will be caught later, but we don't have a
|
||||
// division function ambiguity problem.
|
||||
ok = true;
|
||||
|
||||
Reference in New Issue
Block a user