stage1: add @sin @cos @exp @exp2 @ln @log2 @log10 @fabs @floor @ceil @trunc @round

and expand @sqrt

This revealed that ln is less accurate than the current algorithm in
musl and glibc, so it should be ported again.

v2: actually include tests
v3: fix reversal of in and out arguments on f128M_sqrt()
    add test for @sqrt on comptime_float
    do not include @nearbyInt() until it works on all targets.
This commit is contained in:
Shawn Landden
2019-06-21 16:18:59 -05:00
parent ebde2ff899
commit 71e014caec
11 changed files with 724 additions and 136 deletions

View File

@@ -991,8 +991,8 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionMarkErrRetTraceP
return IrInstructionIdMarkErrRetTracePtr;
}
static constexpr IrInstructionId ir_instruction_id(IrInstructionSqrt *) {
return IrInstructionIdSqrt;
static constexpr IrInstructionId ir_instruction_id(IrInstructionFloatOp *) {
return IrInstructionIdFloatOp;
}
static constexpr IrInstructionId ir_instruction_id(IrInstructionCheckRuntimeScope *) {
@@ -2312,6 +2312,59 @@ static IrInstruction *ir_build_overflow_op(IrBuilder *irb, Scope *scope, AstNode
return &instruction->base;
}
// TODO: Powi, Pow, minnum, maxnum, maximum, minimum, copysign,
//       lround, llround, lrint, llrint
// So far only the builtins with uncomplicated type signatures are handled.
//
// Returns the textual name of a float builtin.
// Two builtins have a second spelling that is selected when `llvm_name` is
// true: ln -> "log" and nearbyInt -> "nearbyint". Every other builtin has a
// single name regardless of `llvm_name`.
// Any BuiltinFnId not listed below reaches zig_unreachable().
const char *float_op_to_name(BuiltinFnId op, bool llvm_name) {
    switch (op) {
        // Spelling depends on llvm_name.
        case BuiltinFnIdLn:
            return llvm_name ? "log" : "ln";
        case BuiltinFnIdNearbyInt:
            return llvm_name ? "nearbyint" : "nearbyInt";

        // Single spelling.
        case BuiltinFnIdSqrt:
            return "sqrt";
        case BuiltinFnIdSin:
            return "sin";
        case BuiltinFnIdCos:
            return "cos";
        case BuiltinFnIdExp:
            return "exp";
        case BuiltinFnIdExp2:
            return "exp2";
        case BuiltinFnIdLog10:
            return "log10";
        case BuiltinFnIdLog2:
            return "log2";
        case BuiltinFnIdFabs:
            return "fabs";
        case BuiltinFnIdFloor:
            return "floor";
        case BuiltinFnIdCeil:
            return "ceil";
        case BuiltinFnIdTrunc:
            return "trunc";
        case BuiltinFnIdRound:
            return "round";

        default:
            zig_unreachable();
    }
}
// Builds an IrInstructionFloatOp that applies the float builtin `op` to `op1`.
// `type` may be nullptr; when it is present it is referenced from the
// builder's current basic block, as `op1` always is.
// Returns the base IrInstruction of the new instruction.
static IrInstruction *ir_build_float_op(IrBuilder *irb, Scope *scope, AstNode *source_node,
        IrInstruction *type, IrInstruction *op1, BuiltinFnId op)
{
    IrInstructionFloatOp *float_op = ir_build_instruction<IrInstructionFloatOp>(irb, scope, source_node);
    float_op->op = op;
    float_op->op1 = op1;
    float_op->type = type;

    if (type != nullptr) {
        ir_ref_instruction(type, irb->current_basic_block);
    }
    ir_ref_instruction(op1, irb->current_basic_block);

    return &float_op->base;
}
static IrInstruction *ir_build_mul_add(IrBuilder *irb, Scope *scope, AstNode *source_node,
IrInstruction *type_value, IrInstruction *op1, IrInstruction *op2, IrInstruction *op3) {
IrInstructionMulAdd *instruction = ir_build_instruction<IrInstructionMulAdd>(irb, scope, source_node);
@@ -3033,17 +3086,6 @@ static IrInstruction *ir_build_mark_err_ret_trace_ptr(IrBuilder *irb, Scope *sco
return &instruction->base;
}
// Builds an IrInstructionSqrt with operand `op`.
// `type` may be nullptr; when it is present it is referenced from the
// builder's current basic block, as `op` always is.
// Returns the base IrInstruction of the new instruction.
static IrInstruction *ir_build_sqrt(IrBuilder *irb, Scope *scope, AstNode *source_node,
        IrInstruction *type, IrInstruction *op)
{
    IrInstructionSqrt *sqrt_instr = ir_build_instruction<IrInstructionSqrt>(irb, scope, source_node);
    sqrt_instr->op = op;
    sqrt_instr->type = type;

    if (type != nullptr) {
        ir_ref_instruction(type, irb->current_basic_block);
    }
    ir_ref_instruction(op, irb->current_basic_block);

    return &sqrt_instr->base;
}
static IrInstruction *ir_build_has_decl(IrBuilder *irb, Scope *scope, AstNode *source_node,
IrInstruction *container, IrInstruction *name)
{
@@ -4400,6 +4442,19 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
return ir_lval_wrap(irb, scope, bin_op, lval);
}
case BuiltinFnIdSqrt:
case BuiltinFnIdSin:
case BuiltinFnIdCos:
case BuiltinFnIdExp:
case BuiltinFnIdExp2:
case BuiltinFnIdLn:
case BuiltinFnIdLog2:
case BuiltinFnIdLog10:
case BuiltinFnIdFabs:
case BuiltinFnIdFloor:
case BuiltinFnIdCeil:
case BuiltinFnIdTrunc:
case BuiltinFnIdNearbyInt:
case BuiltinFnIdRound:
{
AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
IrInstruction *arg0_value = ir_gen_node(irb, arg0_node, scope);
@@ -4411,7 +4466,7 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
if (arg1_value == irb->codegen->invalid_instruction)
return arg1_value;
IrInstruction *ir_sqrt = ir_build_sqrt(irb, scope, node, arg0_value, arg1_value);
IrInstruction *ir_sqrt = ir_build_float_op(irb, scope, node, arg0_value, arg1_value, builtin_fn->id);
return ir_lval_wrap(irb, scope, ir_sqrt, lval);
}
case BuiltinFnIdTruncate:
@@ -23214,70 +23269,248 @@ static IrInstruction *ir_analyze_instruction_mark_err_ret_trace_ptr(IrAnalyze *i
return result;
}
static IrInstruction *ir_analyze_instruction_sqrt(IrAnalyze *ira, IrInstructionSqrt *instruction) {
ZigType *float_type = ir_resolve_type(ira, instruction->type->child);
if (type_is_invalid(float_type))
// Evaluates the float builtin selected by source_instr->op on the constant
// `op` and stores the result in `out_val`.
//
// float_type must be a float type or comptime_float. comptime_float is handled
// as a 128-bit value whose payload lives in x_bigfloat.value.
// Coverage by width:
//   16, 128 (incl. comptime_float): only @sqrt; other float builtins zig_panic.
//   32, 64:                         every builtin, via the <math.h> function
//                                   of the matching precision.
// Only the data payload of out_val is written here.
static void ir_eval_float_op(IrAnalyze *ira, IrInstructionFloatOp *source_instr, ZigType *float_type,
        ConstExprValue *op, ConstExprValue *out_val) {
    assert(ira && source_instr && float_type && out_val && op);
    assert(float_type->id == ZigTypeIdFloat ||
           float_type->id == ZigTypeIdComptimeFloat);

    BuiltinFnId fop = source_instr->op;

    // Every path must assign `bits` before the switch below reads it; with
    // asserts compiled out, the final arm keeps an unexpected type id from
    // leaving it uninitialized.
    unsigned bits;
    if (float_type->id == ZigTypeIdComptimeFloat) {
        bits = 128;
    } else if (float_type->id == ZigTypeIdFloat) {
        bits = float_type->data.floating.bit_count;
    } else {
        zig_unreachable();
    }

    switch (bits) {
        case 16: {
            switch (fop) {
                case BuiltinFnIdSqrt:
                    out_val->data.x_f16 = f16_sqrt(op->data.x_f16);
                    break;
                case BuiltinFnIdSin:
                case BuiltinFnIdCos:
                case BuiltinFnIdExp:
                case BuiltinFnIdExp2:
                case BuiltinFnIdLn:
                case BuiltinFnIdLog10:
                case BuiltinFnIdLog2:
                case BuiltinFnIdFabs:
                case BuiltinFnIdFloor:
                case BuiltinFnIdCeil:
                case BuiltinFnIdTrunc:
                case BuiltinFnIdNearbyInt:
                case BuiltinFnIdRound:
                    zig_panic("unimplemented f16 builtin");
                default:
                    zig_unreachable();
            }
            break;
        }
        case 32: {
            switch (fop) {
                case BuiltinFnIdSqrt:
                    out_val->data.x_f32 = sqrtf(op->data.x_f32);
                    break;
                case BuiltinFnIdSin:
                    out_val->data.x_f32 = sinf(op->data.x_f32);
                    break;
                case BuiltinFnIdCos:
                    out_val->data.x_f32 = cosf(op->data.x_f32);
                    break;
                case BuiltinFnIdExp:
                    out_val->data.x_f32 = expf(op->data.x_f32);
                    break;
                case BuiltinFnIdExp2:
                    out_val->data.x_f32 = exp2f(op->data.x_f32);
                    break;
                case BuiltinFnIdLn:
                    out_val->data.x_f32 = logf(op->data.x_f32);
                    break;
                case BuiltinFnIdLog10:
                    out_val->data.x_f32 = log10f(op->data.x_f32);
                    break;
                case BuiltinFnIdLog2:
                    out_val->data.x_f32 = log2f(op->data.x_f32);
                    break;
                case BuiltinFnIdFabs:
                    out_val->data.x_f32 = fabsf(op->data.x_f32);
                    break;
                case BuiltinFnIdFloor:
                    out_val->data.x_f32 = floorf(op->data.x_f32);
                    break;
                case BuiltinFnIdCeil:
                    out_val->data.x_f32 = ceilf(op->data.x_f32);
                    break;
                case BuiltinFnIdTrunc:
                    out_val->data.x_f32 = truncf(op->data.x_f32);
                    break;
                case BuiltinFnIdNearbyInt:
                    out_val->data.x_f32 = nearbyintf(op->data.x_f32);
                    break;
                case BuiltinFnIdRound:
                    out_val->data.x_f32 = roundf(op->data.x_f32);
                    break;
                default:
                    zig_unreachable();
            }
            break;
        }
        case 64: {
            switch (fop) {
                case BuiltinFnIdSqrt:
                    out_val->data.x_f64 = sqrt(op->data.x_f64);
                    break;
                case BuiltinFnIdSin:
                    out_val->data.x_f64 = sin(op->data.x_f64);
                    break;
                case BuiltinFnIdCos:
                    out_val->data.x_f64 = cos(op->data.x_f64);
                    break;
                case BuiltinFnIdExp:
                    out_val->data.x_f64 = exp(op->data.x_f64);
                    break;
                case BuiltinFnIdExp2:
                    out_val->data.x_f64 = exp2(op->data.x_f64);
                    break;
                case BuiltinFnIdLn:
                    out_val->data.x_f64 = log(op->data.x_f64);
                    break;
                case BuiltinFnIdLog10:
                    out_val->data.x_f64 = log10(op->data.x_f64);
                    break;
                case BuiltinFnIdLog2:
                    out_val->data.x_f64 = log2(op->data.x_f64);
                    break;
                case BuiltinFnIdFabs:
                    out_val->data.x_f64 = fabs(op->data.x_f64);
                    break;
                case BuiltinFnIdFloor:
                    out_val->data.x_f64 = floor(op->data.x_f64);
                    break;
                case BuiltinFnIdCeil:
                    out_val->data.x_f64 = ceil(op->data.x_f64);
                    break;
                case BuiltinFnIdTrunc:
                    out_val->data.x_f64 = trunc(op->data.x_f64);
                    break;
                case BuiltinFnIdNearbyInt:
                    out_val->data.x_f64 = nearbyint(op->data.x_f64);
                    break;
                case BuiltinFnIdRound:
                    out_val->data.x_f64 = round(op->data.x_f64);
                    break;
                default:
                    zig_unreachable();
            }
            break;
        }
        case 128: {
            // comptime_float and f128 both use a float128_t payload, stored in
            // different members of the value's data.
            float128_t *out, *in;
            if (float_type->id == ZigTypeIdComptimeFloat) {
                out = &out_val->data.x_bigfloat.value;
                in = &op->data.x_bigfloat.value;
            } else {
                out = &out_val->data.x_f128;
                in = &op->data.x_f128;
            }
            switch (fop) {
                case BuiltinFnIdSqrt:
                    // Argument order is (input, output).
                    f128M_sqrt(in, out);
                    break;
                case BuiltinFnIdNearbyInt:
                case BuiltinFnIdSin:
                case BuiltinFnIdCos:
                case BuiltinFnIdExp:
                case BuiltinFnIdExp2:
                case BuiltinFnIdLn:
                case BuiltinFnIdLog10:
                case BuiltinFnIdLog2:
                case BuiltinFnIdFabs:
                case BuiltinFnIdFloor:
                case BuiltinFnIdCeil:
                case BuiltinFnIdTrunc:
                case BuiltinFnIdRound:
                    zig_panic("unimplemented f128 builtin");
                default:
                    zig_unreachable();
            }
            break;
        }
        default:
            zig_unreachable();
    }
}
// Analyzes a float builtin instruction (@sqrt, @sin, ...). Accepts float types
// and vectors of floats, folds comptime-known operands through
// ir_eval_float_op(), and otherwise emits a runtime IrInstructionFloatOp whose
// value type is the resolved expression type.
// NOTE(review): this span is diff output; lines of the removed
// ir_analyze_instruction_sqrt() are interleaved with the new function and are
// kept verbatim here.
static IrInstruction *ir_analyze_instruction_float_op(IrAnalyze *ira, IrInstructionFloatOp *instruction) {
IrInstruction *type = instruction->type->child;
if (type_is_invalid(type->value.type))
return ira->codegen->invalid_instruction;
ZigType *expr_type = ir_resolve_type(ira, type);
if (type_is_invalid(expr_type))
return ira->codegen->invalid_instruction;
IrInstruction *op = instruction->op->child;
if (type_is_invalid(op->value.type))
return ira->codegen->invalid_instruction;
bool ok_type = float_type->id == ZigTypeIdComptimeFloat || float_type->id == ZigTypeIdFloat;
if (!ok_type) {
ir_add_error(ira, instruction->type, buf_sprintf("@sqrt does not support type '%s'", buf_ptr(&float_type->name)));
// Only allow float types, and vectors of floats.
ZigType *float_type = (expr_type->id == ZigTypeIdVector) ? expr_type->data.vector.elem_type : expr_type;
if (float_type->id != ZigTypeIdFloat && float_type->id != ZigTypeIdComptimeFloat) {
ir_add_error(ira, instruction->type, buf_sprintf("@%s does not support type '%s'", float_op_to_name(instruction->op, false), buf_ptr(&float_type->name)));
return ira->codegen->invalid_instruction;
}
IrInstruction *casted_op = ir_implicit_cast(ira, op, float_type);
if (type_is_invalid(casted_op->value.type))
IrInstruction *op1 = instruction->op1->child;
if (type_is_invalid(op1->value.type))
return ira->codegen->invalid_instruction;
if (instr_is_comptime(casted_op)) {
ConstExprValue *val = ir_resolve_const(ira, casted_op, UndefBad);
if (!val)
IrInstruction *casted_op1 = ir_implicit_cast(ira, op1, float_type);
if (type_is_invalid(casted_op1->value.type))
return ira->codegen->invalid_instruction;
if (instr_is_comptime(casted_op1)) {
// Our comptime 16-bit and 128-bit support is quite limited.
if ((float_type->id == ZigTypeIdComptimeFloat ||
float_type->data.floating.bit_count == 16 ||
float_type->data.floating.bit_count == 128) &&
instruction->op != BuiltinFnIdSqrt) {
ir_add_error(ira, instruction->type, buf_sprintf("@%s does not support type '%s'", float_op_to_name(instruction->op, false), buf_ptr(&float_type->name)));
return ira->codegen->invalid_instruction;
IrInstruction *result = ir_const(ira, &instruction->base, float_type);
ConstExprValue *out_val = &result->value;
if (float_type->id == ZigTypeIdComptimeFloat) {
bigfloat_sqrt(&out_val->data.x_bigfloat, &val->data.x_bigfloat);
} else if (float_type->id == ZigTypeIdFloat) {
switch (float_type->data.floating.bit_count) {
case 16:
out_val->data.x_f16 = f16_sqrt(val->data.x_f16);
break;
case 32:
out_val->data.x_f32 = sqrtf(val->data.x_f32);
break;
case 64:
out_val->data.x_f64 = sqrt(val->data.x_f64);
break;
case 128:
f128M_sqrt(&val->data.x_f128, &out_val->data.x_f128);
break;
default:
zig_unreachable();
}
} else {
zig_unreachable();
}
ConstExprValue *op1_const = ir_resolve_const(ira, casted_op1, UndefBad);
if (!op1_const)
return ira->codegen->invalid_instruction;
IrInstruction *result = ir_const(ira, &instruction->base, expr_type);
ConstExprValue *out_val = &result->value;
if (expr_type->id == ZigTypeIdVector) {
expand_undef_array(ira->codegen, op1_const);
out_val->special = ConstValSpecialUndef;
expand_undef_array(ira->codegen, out_val);
size_t len = expr_type->data.vector.len;
for (size_t i = 0; i < len; i += 1) {
ConstExprValue *float_operand_op1 = &op1_const->data.x_array.data.s_none.elements[i];
ConstExprValue *float_out_val = &out_val->data.x_array.data.s_none.elements[i];
assert(float_operand_op1->type == float_type);
assert(float_out_val->type == float_type);
// Evaluate element i of the operand, not the whole vector constant.
ir_eval_float_op(ira, instruction, float_type,
float_operand_op1, float_out_val);
float_out_val->type = float_type;
}
out_val->type = expr_type;
out_val->special = ConstValSpecialStatic;
} else {
ir_eval_float_op(ira, instruction, float_type, op1_const, out_val);
}
return result;
}
ir_assert(float_type->id == ZigTypeIdFloat, &instruction->base);
if (float_type->data.floating.bit_count != 16 &&
float_type->data.floating.bit_count != 32 &&
float_type->data.floating.bit_count != 64) {
ir_add_error(ira, instruction->type, buf_sprintf("compiler TODO: add implementation of sqrt for '%s'", buf_ptr(&float_type->name)));
return ira->codegen->invalid_instruction;
}
IrInstruction *result = ir_build_sqrt(&ira->new_irb, instruction->base.scope,
instruction->base.source_node, nullptr, casted_op);
result->value.type = float_type;
IrInstruction *result = ir_build_float_op(&ira->new_irb, instruction->base.scope,
instruction->base.source_node, nullptr, casted_op1, instruction->op);
result->value.type = expr_type;
return result;
}
@@ -23762,8 +23995,8 @@ static IrInstruction *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructio
return ir_analyze_instruction_merge_err_ret_traces(ira, (IrInstructionMergeErrRetTraces *)instruction);
case IrInstructionIdMarkErrRetTracePtr:
return ir_analyze_instruction_mark_err_ret_trace_ptr(ira, (IrInstructionMarkErrRetTracePtr *)instruction);
case IrInstructionIdSqrt:
return ir_analyze_instruction_sqrt(ira, (IrInstructionSqrt *)instruction);
case IrInstructionIdFloatOp:
return ir_analyze_instruction_float_op(ira, (IrInstructionFloatOp *)instruction);
case IrInstructionIdMulAdd:
return ir_analyze_instruction_mul_add(ira, (IrInstructionMulAdd *)instruction);
case IrInstructionIdIntToErr:
@@ -24004,7 +24237,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
case IrInstructionIdCoroFree:
case IrInstructionIdCoroPromise:
case IrInstructionIdPromiseResultType:
case IrInstructionIdSqrt:
case IrInstructionIdFloatOp:
case IrInstructionIdMulAdd:
case IrInstructionIdAtomicLoad:
case IrInstructionIdIntCast: