Sema: implement vector coercions

These used to be lowered elementwise in air, and now are a single air
instruction that can be lowered elementwise in the backend if necessary.
This commit is contained in:
Jacob Young
2024-02-15 10:37:52 +01:00
parent 2fdc9e6ae8
commit 2fcb2f5975
11 changed files with 552 additions and 254 deletions

View File

@@ -6109,41 +6109,48 @@ fn airFloatCast(f: *Function, inst: Air.Inst.Index) !CValue {
const ty_op = f.air.instructions.items(.data)[@intFromEnum(inst)].ty_op;
const inst_ty = f.typeOfIndex(inst);
const inst_scalar_ty = inst_ty.scalarType(mod);
const operand = try f.resolveInst(ty_op.operand);
try reap(f, inst, &.{ty_op.operand});
const operand_ty = f.typeOf(ty_op.operand);
const scalar_ty = operand_ty.scalarType(mod);
const target = f.object.dg.module.getTarget();
const operation = if (inst_ty.isRuntimeFloat() and operand_ty.isRuntimeFloat())
if (inst_ty.floatBits(target) < operand_ty.floatBits(target)) "trunc" else "extend"
else if (inst_ty.isInt(mod) and operand_ty.isRuntimeFloat())
if (inst_ty.isSignedInt(mod)) "fix" else "fixuns"
else if (inst_ty.isRuntimeFloat() and operand_ty.isInt(mod))
if (operand_ty.isSignedInt(mod)) "float" else "floatun"
const operation = if (inst_scalar_ty.isRuntimeFloat() and scalar_ty.isRuntimeFloat())
if (inst_scalar_ty.floatBits(target) < scalar_ty.floatBits(target)) "trunc" else "extend"
else if (inst_scalar_ty.isInt(mod) and scalar_ty.isRuntimeFloat())
if (inst_scalar_ty.isSignedInt(mod)) "fix" else "fixuns"
else if (inst_scalar_ty.isRuntimeFloat() and scalar_ty.isInt(mod))
if (scalar_ty.isSignedInt(mod)) "float" else "floatun"
else
unreachable;
const writer = f.object.writer();
const local = try f.allocLocal(inst, inst_ty);
const v = try Vectorize.start(f, inst, writer, operand_ty);
const a = try Assignment.start(f, writer, scalar_ty);
try f.writeCValue(writer, local, .Other);
try writer.writeAll(" = ");
if (inst_ty.isInt(mod) and operand_ty.isRuntimeFloat()) {
try v.elem(f, writer);
try a.assign(f, writer);
if (inst_scalar_ty.isInt(mod) and scalar_ty.isRuntimeFloat()) {
try writer.writeAll("zig_wrap_");
try f.object.dg.renderTypeForBuiltinFnName(writer, inst_ty);
try f.object.dg.renderTypeForBuiltinFnName(writer, inst_scalar_ty);
try writer.writeByte('(');
}
try writer.writeAll("zig_");
try writer.writeAll(operation);
try writer.writeAll(compilerRtAbbrev(operand_ty, mod));
try writer.writeAll(compilerRtAbbrev(inst_ty, mod));
try writer.writeAll(compilerRtAbbrev(scalar_ty, mod));
try writer.writeAll(compilerRtAbbrev(inst_scalar_ty, mod));
try writer.writeByte('(');
try f.writeCValue(writer, operand, .FunctionArgument);
try v.elem(f, writer);
try writer.writeByte(')');
if (inst_ty.isInt(mod) and operand_ty.isRuntimeFloat()) {
try f.object.dg.renderBuiltinInfo(writer, inst_ty, .bits);
if (inst_scalar_ty.isInt(mod) and scalar_ty.isRuntimeFloat()) {
try f.object.dg.renderBuiltinInfo(writer, inst_scalar_ty, .bits);
try writer.writeByte(')');
}
try writer.writeAll(";\n");
try a.end(f, writer);
try v.end(f, inst, writer);
return local;
}

View File

@@ -8648,8 +8648,6 @@ pub const FuncGen = struct {
const operand_ty = self.typeOf(ty_op.operand);
const dest_ty = self.typeOfIndex(inst);
const target = mod.getTarget();
const dest_bits = dest_ty.floatBits(target);
const src_bits = operand_ty.floatBits(target);
if (intrinsicsAllowed(dest_ty, target) and intrinsicsAllowed(operand_ty, target)) {
return self.wip.cast(.fpext, operand, try o.lowerType(dest_ty), "");
@@ -8657,11 +8655,19 @@ pub const FuncGen = struct {
const operand_llvm_ty = try o.lowerType(operand_ty);
const dest_llvm_ty = try o.lowerType(dest_ty);
const dest_bits = dest_ty.scalarType(mod).floatBits(target);
const src_bits = operand_ty.scalarType(mod).floatBits(target);
const fn_name = try o.builder.fmt("__extend{s}f{s}f2", .{
compilerRtFloatAbbrev(src_bits), compilerRtFloatAbbrev(dest_bits),
});
const libc_fn = try self.getLibcFunction(fn_name, &.{operand_llvm_ty}, dest_llvm_ty);
if (dest_ty.isVector(mod)) return self.buildElementwiseCall(
libc_fn,
&.{operand},
try o.builder.poisonValue(dest_llvm_ty),
dest_ty.vectorLen(mod),
);
return self.wip.call(
.normal,
.ccc,