stage1: Implement @reduce builtin for vector types
The builtin folds a Vector(N,T) into a scalar T using a specified operator. Closes #2698
This commit is contained in:
@@ -2583,36 +2583,6 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutableGen *executable, Ir
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
enum class ScalarizePredicate {
|
||||
// Returns true iff all the elements in the vector are 1.
|
||||
// Equivalent to folding all the bits with `and`.
|
||||
All,
|
||||
// Returns true iff there's at least one element in the vector that is 1.
|
||||
// Equivalent to folding all the bits with `or`.
|
||||
Any,
|
||||
};
|
||||
|
||||
// Collapses a <N x i1> vector into a single i1 according to the given predicate
|
||||
static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val, ScalarizePredicate predicate) {
|
||||
assert(LLVMGetTypeKind(LLVMTypeOf(val)) == LLVMVectorTypeKind);
|
||||
LLVMTypeRef scalar_type = LLVMIntType(LLVMGetVectorSize(LLVMTypeOf(val)));
|
||||
LLVMValueRef casted = LLVMBuildBitCast(g->builder, val, scalar_type, "");
|
||||
|
||||
switch (predicate) {
|
||||
case ScalarizePredicate::Any: {
|
||||
LLVMValueRef all_zeros = LLVMConstNull(scalar_type);
|
||||
return LLVMBuildICmp(g->builder, LLVMIntNE, casted, all_zeros, "");
|
||||
}
|
||||
case ScalarizePredicate::All: {
|
||||
LLVMValueRef all_ones = LLVMConstAllOnes(scalar_type);
|
||||
return LLVMBuildICmp(g->builder, LLVMIntEQ, casted, all_ones, "");
|
||||
}
|
||||
}
|
||||
|
||||
zig_unreachable();
|
||||
}
|
||||
|
||||
|
||||
static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type,
|
||||
LLVMValueRef val1, LLVMValueRef val2)
|
||||
{
|
||||
@@ -2637,7 +2607,7 @@ static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type,
|
||||
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
|
||||
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
|
||||
if (operand_type->id == ZigTypeIdVector) {
|
||||
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
|
||||
ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
|
||||
}
|
||||
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
|
||||
|
||||
@@ -2668,7 +2638,7 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *operand_type,
|
||||
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
|
||||
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
|
||||
if (operand_type->id == ZigTypeIdVector) {
|
||||
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
|
||||
ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
|
||||
}
|
||||
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
|
||||
|
||||
@@ -2745,7 +2715,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
|
||||
}
|
||||
|
||||
if (operand_type->id == ZigTypeIdVector) {
|
||||
is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any);
|
||||
is_zero_bit = ZigLLVMBuildOrReduce(g->builder, is_zero_bit);
|
||||
}
|
||||
|
||||
LLVMBasicBlockRef div_zero_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroFail");
|
||||
@@ -2770,7 +2740,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
|
||||
LLVMValueRef den_is_neg_1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, neg_1_value, "");
|
||||
LLVMValueRef overflow_fail_bit = LLVMBuildAnd(g->builder, num_is_int_min, den_is_neg_1, "");
|
||||
if (operand_type->id == ZigTypeIdVector) {
|
||||
overflow_fail_bit = scalarize_cmp_result(g, overflow_fail_bit, ScalarizePredicate::Any);
|
||||
overflow_fail_bit = ZigLLVMBuildOrReduce(g->builder, overflow_fail_bit);
|
||||
}
|
||||
LLVMBuildCondBr(g->builder, overflow_fail_bit, overflow_fail_block, overflow_ok_block);
|
||||
|
||||
@@ -2795,7 +2765,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
|
||||
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
|
||||
LLVMValueRef ok_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, floored, result, "");
|
||||
if (operand_type->id == ZigTypeIdVector) {
|
||||
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
|
||||
ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
|
||||
}
|
||||
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
|
||||
|
||||
@@ -2812,7 +2782,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
|
||||
LLVMBasicBlockRef end_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivTruncEnd");
|
||||
LLVMValueRef ltz = LLVMBuildFCmp(g->builder, LLVMRealOLT, val1, zero, "");
|
||||
if (operand_type->id == ZigTypeIdVector) {
|
||||
ltz = scalarize_cmp_result(g, ltz, ScalarizePredicate::Any);
|
||||
ltz = ZigLLVMBuildOrReduce(g->builder, ltz);
|
||||
}
|
||||
LLVMBuildCondBr(g->builder, ltz, ltz_block, gez_block);
|
||||
|
||||
@@ -2864,7 +2834,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
|
||||
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
|
||||
LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, remainder_val, zero, "");
|
||||
if (operand_type->id == ZigTypeIdVector) {
|
||||
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
|
||||
ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
|
||||
}
|
||||
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
|
||||
|
||||
@@ -2928,7 +2898,7 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast
|
||||
}
|
||||
|
||||
if (operand_type->id == ZigTypeIdVector) {
|
||||
is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any);
|
||||
is_zero_bit = ZigLLVMBuildOrReduce(g->builder, is_zero_bit);
|
||||
}
|
||||
|
||||
LLVMBasicBlockRef rem_zero_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "RemZeroOk");
|
||||
@@ -2985,7 +2955,7 @@ static void gen_shift_rhs_check(CodeGen *g, ZigType *lhs_type, ZigType *rhs_type
|
||||
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckOk");
|
||||
LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, "");
|
||||
if (rhs_type->id == ZigTypeIdVector) {
|
||||
less_than_bit = scalarize_cmp_result(g, less_than_bit, ScalarizePredicate::Any);
|
||||
less_than_bit = ZigLLVMBuildOrReduce(g->builder, less_than_bit);
|
||||
}
|
||||
LLVMBuildCondBr(g->builder, less_than_bit, ok_block, fail_block);
|
||||
|
||||
@@ -5470,6 +5440,50 @@ static LLVMValueRef ir_render_cmpxchg(CodeGen *g, IrExecutableGen *executable, I
|
||||
return result_loc;
|
||||
}
|
||||
|
||||
static LLVMValueRef ir_render_reduce(CodeGen *g, IrExecutableGen *executable, IrInstGenReduce *instruction) {
|
||||
LLVMValueRef value = ir_llvm_value(g, instruction->value);
|
||||
|
||||
ZigType *value_type = instruction->value->value->type;
|
||||
assert(value_type->id == ZigTypeIdVector);
|
||||
ZigType *scalar_type = value_type->data.vector.elem_type;
|
||||
|
||||
LLVMValueRef result_val;
|
||||
switch (instruction->op) {
|
||||
case ReduceOp_and:
|
||||
assert(scalar_type->id == ZigTypeIdInt || scalar_type->id == ZigTypeIdBool);
|
||||
result_val = ZigLLVMBuildAndReduce(g->builder, value);
|
||||
break;
|
||||
case ReduceOp_or:
|
||||
assert(scalar_type->id == ZigTypeIdInt || scalar_type->id == ZigTypeIdBool);
|
||||
result_val = ZigLLVMBuildOrReduce(g->builder, value);
|
||||
break;
|
||||
case ReduceOp_xor:
|
||||
assert(scalar_type->id == ZigTypeIdInt || scalar_type->id == ZigTypeIdBool);
|
||||
result_val = ZigLLVMBuildXorReduce(g->builder, value);
|
||||
break;
|
||||
case ReduceOp_min: {
|
||||
if (scalar_type->id == ZigTypeIdInt) {
|
||||
const bool is_signed = scalar_type->data.integral.is_signed;
|
||||
result_val = ZigLLVMBuildIntMinReduce(g->builder, value, is_signed);
|
||||
} else if (scalar_type->id == ZigTypeIdFloat) {
|
||||
result_val = ZigLLVMBuildFPMinReduce(g->builder, value);
|
||||
} else zig_unreachable();
|
||||
} break;
|
||||
case ReduceOp_max: {
|
||||
if (scalar_type->id == ZigTypeIdInt) {
|
||||
const bool is_signed = scalar_type->data.integral.is_signed;
|
||||
result_val = ZigLLVMBuildIntMaxReduce(g->builder, value, is_signed);
|
||||
} else if (scalar_type->id == ZigTypeIdFloat) {
|
||||
result_val = ZigLLVMBuildFPMaxReduce(g->builder, value);
|
||||
} else zig_unreachable();
|
||||
} break;
|
||||
default:
|
||||
zig_unreachable();
|
||||
}
|
||||
|
||||
return result_val;
|
||||
}
|
||||
|
||||
static LLVMValueRef ir_render_fence(CodeGen *g, IrExecutableGen *executable, IrInstGenFence *instruction) {
|
||||
LLVMAtomicOrdering atomic_order = to_LLVMAtomicOrdering(instruction->order);
|
||||
LLVMBuildFence(g->builder, atomic_order, false, "");
|
||||
@@ -6674,6 +6688,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutableGen *executabl
|
||||
return ir_render_cmpxchg(g, executable, (IrInstGenCmpxchg *)instruction);
|
||||
case IrInstGenIdFence:
|
||||
return ir_render_fence(g, executable, (IrInstGenFence *)instruction);
|
||||
case IrInstGenIdReduce:
|
||||
return ir_render_reduce(g, executable, (IrInstGenReduce *)instruction);
|
||||
case IrInstGenIdTruncate:
|
||||
return ir_render_truncate(g, executable, (IrInstGenTruncate *)instruction);
|
||||
case IrInstGenIdBoolNot:
|
||||
@@ -8630,6 +8646,7 @@ static void define_builtin_fns(CodeGen *g) {
|
||||
create_builtin_fn(g, BuiltinFnIdWasmMemorySize, "wasmMemorySize", 1);
|
||||
create_builtin_fn(g, BuiltinFnIdWasmMemoryGrow, "wasmMemoryGrow", 2);
|
||||
create_builtin_fn(g, BuiltinFnIdSrc, "src", 0);
|
||||
create_builtin_fn(g, BuiltinFnIdReduce, "reduce", 2);
|
||||
}
|
||||
|
||||
static const char *bool_to_str(bool b) {
|
||||
|
||||
Reference in New Issue
Block a user