diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig
index 2f82f1f694..462393a150 100644
--- a/src/arch/wasm/CodeGen.zig
+++ b/src/arch/wasm/CodeGen.zig
@@ -215,8 +215,7 @@ fn buildOpcode(args: OpcodeBuildArguments) wasm.Opcode {
             16 => switch (args.valtype1.?) {
                 .i32 => if (args.signedness.? == .signed) return .i32_load16_s else return .i32_load16_u,
                 .i64 => if (args.signedness.? == .signed) return .i64_load16_s else return .i64_load16_u,
-                .f32 => return .f32_load,
-                .f64 => unreachable,
+                .f32, .f64 => unreachable,
             },
             32 => switch (args.valtype1.?) {
                 .i64 => if (args.signedness.? == .signed) return .i64_load32_s else return .i64_load32_u,
@@ -246,8 +245,7 @@ fn buildOpcode(args: OpcodeBuildArguments) wasm.Opcode {
             16 => switch (args.valtype1.?) {
                 .i32 => return .i32_store16,
                 .i64 => return .i64_store16,
-                .f32 => return .f32_store,
-                .f64 => unreachable,
+                .f32, .f64 => unreachable,
             },
             32 => switch (args.valtype1.?) {
                 .i64 => return .i64_store32,
@@ -725,7 +723,8 @@ fn typeToValtype(ty: Type, target: std.Target) wasm.Valtype {
     return switch (ty.zigTypeTag()) {
         .Float => blk: {
             const bits = ty.floatBits(target);
-            if (bits == 16 or bits == 32) break :blk wasm.Valtype.f32;
+            if (bits == 16) return wasm.Valtype.i32; // stored/loaded as u16
+            if (bits == 32) break :blk wasm.Valtype.f32;
             if (bits == 64) break :blk wasm.Valtype.f64;
             if (bits == 128) break :blk wasm.Valtype.i64;
             return wasm.Valtype.i32; // represented as pointer to stack
@@ -2013,6 +2012,10 @@ fn binOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WVa
         }
     }
 
+    if (ty.isAnyFloat() and ty.floatBits(self.target) == 16) {
+        return self.binOpFloat16(lhs, rhs, op);
+    }
+
     const opcode: wasm.Opcode = buildOpcode(.{
         .op = op,
         .valtype1 = typeToValtype(ty, self.target),
@@ -2029,6 +2032,23 @@ fn binOp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WVa
     return bin_local;
 }
 
+/// Lowers a binary operation on two f16 operands by widening both to f32
+/// via `fpext`, emitting the f32 opcode for `op`, and narrowing the result
+/// back to f16 via `fptrunc`.
+fn binOpFloat16(self: *Self, lhs: WValue, rhs: WValue, op: Op) InnerError!WValue {
+    const ext_lhs = try self.fpext(lhs, Type.f16, Type.f32);
+    const ext_rhs = try self.fpext(rhs, Type.f16, Type.f32);
+
+    const opcode: wasm.Opcode = buildOpcode(.{ .op = op, .valtype1 = .f32, .signedness = .unsigned });
+    try self.emitWValue(ext_lhs);
+    try self.emitWValue(ext_rhs);
+    try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
+
+    // re-use temporary local
+    try self.addLabel(.local_set, ext_lhs.local);
+    return self.fptrunc(ext_lhs, Type.f32, Type.f16);
+}
+
 fn binOpBigInt(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: Op) InnerError!WValue {
     if (ty.intInfo(self.target).bits > 128) {
         return self.fail("TODO: Implement binary operation for big integer", .{});
@@ -2310,8 +2330,9 @@ fn lowerConstant(self: *Self, val: Value, ty: Type) InnerError!WValue {
         },
         .Bool => return WValue{ .imm32 = @intCast(u32, val.toUnsignedInt(target)) },
         .Float => switch (ty.floatBits(self.target)) {
-            0...32 => return WValue{ .float32 = val.toFloat(f32) },
-            33...64 => return WValue{ .float64 = val.toFloat(f64) },
+            16 => return WValue{ .imm32 = @bitCast(u16, val.toFloat(f16)) },
+            32 => return WValue{ .float32 = val.toFloat(f32) },
+            64 => return WValue{ .float64 = val.toFloat(f64) },
             else => unreachable,
         },
         .Pointer => switch (val.tag()) {
@@ -2389,8 +2410,9 @@ fn emitUndefined(self: *Self, ty: Type) InnerError!WValue {
             else => unreachable,
         },
         .Float => switch (ty.floatBits(self.target)) {
-            0...32 => return WValue{ .float32 = @bitCast(f32, @as(u32, 0xaaaaaaaa)) },
-            33...64 => return WValue{ .float64 = @bitCast(f64, @as(u64, 0xaaaaaaaaaaaaaaaa)) },
+            16 => return WValue{ .imm32 = 0xaaaaaaaa },
+            32 => return WValue{ .float32 = @bitCast(f32, @as(u32, 0xaaaaaaaa)) },
+            64 => return WValue{ .float64 = @bitCast(f64, @as(u64, 0xaaaaaaaaaaaaaaaa)) },
             else => unreachable,
         },
         .Pointer => switch (self.arch()) {
@@ -2562,6 +2584,8 @@ fn cmp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: std.math.CompareOper
         }
     } else if (isByRef(ty, self.target)) {
         return self.cmpBigInt(lhs, rhs, ty, op);
+    } else if (ty.isAnyFloat() and ty.floatBits(self.target) == 16) {
+        return self.cmpFloat16(lhs, rhs, op);
     }
 
     // ensure that when we compare pointers, we emit
@@ -2595,6 +2619,34 @@ fn cmp(self: *Self, lhs: WValue, rhs: WValue, ty: Type, op: std.math.CompareOper
     return cmp_tmp;
 }
 
+/// Compares two f16 operands by widening both to f32 via `fpext` and
+/// emitting the f32 comparison opcode for `op`. Returns an i32 local
+/// holding the comparison result.
+fn cmpFloat16(self: *Self, lhs: WValue, rhs: WValue, op: std.math.CompareOperator) InnerError!WValue {
+    const ext_lhs = try self.fpext(lhs, Type.f16, Type.f32);
+    const ext_rhs = try self.fpext(rhs, Type.f16, Type.f32);
+
+    const opcode: wasm.Opcode = buildOpcode(.{
+        .op = switch (op) {
+            .lt => .lt,
+            .lte => .le,
+            .eq => .eq,
+            .neq => .ne,
+            .gte => .ge,
+            .gt => .gt,
+        },
+        .valtype1 = .f32,
+        .signedness = .unsigned,
+    });
+    try self.emitWValue(ext_lhs);
+    try self.emitWValue(ext_rhs);
+    try self.addTag(Mir.Inst.Tag.fromOpcode(opcode));
+
+    const result = try self.allocLocal(Type.initTag(.i32)); // bool is always i32
+    try self.addLabel(.local_set, result.local);
+    return result;
+}
+
 fn airCmpVector(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
     _ = inst;
     return self.fail("TODO implement airCmpVector for wasm", .{});
@@ -3934,19 +3986,47 @@ fn airFpext(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
 
     const ty_op = self.air.instructions.items(.data)[inst].ty_op;
     const dest_ty = self.air.typeOfIndex(inst);
-    const dest_bits = dest_ty.floatBits(self.target);
-    const src_bits = self.air.typeOf(ty_op.operand).floatBits(self.target);
     const operand = try self.resolveInst(ty_op.operand);
 
-    if (dest_bits == 64 and src_bits == 32) {
-        const result = try self.allocLocal(dest_ty);
+    return self.fpext(operand, self.air.typeOf(ty_op.operand), dest_ty);
+}
+
+/// Extends the float `operand` of type `given` to the wider float type `wanted`.
+/// f32 -> f64 uses `f64_promote_f32`; an f16 source is first widened to f32
+/// by calling `__extendhfsf2`. Any other combination fails with a TODO error.
+fn fpext(self: *Self, operand: WValue, given: Type, wanted: Type) InnerError!WValue {
+    const given_bits = given.floatBits(self.target);
+    const wanted_bits = wanted.floatBits(self.target);
+
+    if (wanted_bits == 64 and given_bits == 32) {
+        const result = try self.allocLocal(wanted);
         try self.emitWValue(operand);
         try self.addTag(.f64_promote_f32);
         try self.addLabel(.local_set, result.local);
         return result;
+    } else if (given_bits == 16) {
+        // call __extendhfsf2(f16) f32
+        const f32_result = try self.callIntrinsic(
+            "__extendhfsf2",
+            &.{Type.f16},
+            Type.f32,
+            &.{operand},
+        );
+
+        if (wanted_bits == 32) {
+            return f32_result;
+        }
+        if (wanted_bits == 64) {
+            const result = try self.allocLocal(wanted);
+            try self.emitWValue(f32_result);
+            try self.addTag(.f64_promote_f32);
+            try self.addLabel(.local_set, result.local);
+            return result;
+        }
+        return self.fail("TODO: Implement 'fpext' for floats with bitsize: {d}", .{wanted_bits});
     } else {
         // TODO: Emit a call to compiler-rt to extend the float. e.g. __extendhfsf2
-        return self.fail("TODO: Implement 'fpext' for floats with bitsize: {d}", .{dest_bits});
+        return self.fail("TODO: Implement 'fpext' for floats with bitsize: {d}", .{wanted_bits});
     }
 }
 
@@ -3955,18 +4035,46 @@ fn airFptrunc(self: *Self, inst: Air.Inst.Index) InnerError!WValue {
 
     const ty_op = self.air.instructions.items(.data)[inst].ty_op;
     const dest_ty = self.air.typeOfIndex(inst);
-    const dest_bits = dest_ty.floatBits(self.target);
-    const src_bits = self.air.typeOf(ty_op.operand).floatBits(self.target);
     const operand = try self.resolveInst(ty_op.operand);
+    return self.fptrunc(operand, self.air.typeOf(ty_op.operand), dest_ty);
+}
 
-    if (dest_bits == 32 and src_bits == 64) {
-        const result = try self.allocLocal(dest_ty);
+/// Truncates the float `operand` of type `given` to the narrower float type `wanted`.
+/// f64 -> f32 uses `f32_demote_f64`; an f16 result is produced by calling
+/// `__truncsfhf2` on an f32 (an f64 source is demoted to f32 first).
+/// Any other combination fails with a TODO error.
+fn fptrunc(self: *Self, operand: WValue, given: Type, wanted: Type) InnerError!WValue {
+    const given_bits = given.floatBits(self.target);
+    const wanted_bits = wanted.floatBits(self.target);
+
+    if (wanted_bits == 32 and given_bits == 64) {
+        const result = try self.allocLocal(wanted);
         try self.emitWValue(operand);
         try self.addTag(.f32_demote_f64);
         try self.addLabel(.local_set, result.local);
         return result;
+    } else if (wanted_bits == 16) {
+        // `__truncsfhf2` takes an f32, so bring the operand to f32 first.
+        const op: WValue = switch (given_bits) {
+            32 => operand,
+            64 => blk: {
+                // NOTE(review): demoting to f32 and then truncating rounds twice;
+                // a direct f64 -> f16 intrinsic may be preferable. Confirm.
+                const tmp = try self.allocLocal(Type.f32);
+                try self.emitWValue(operand);
+                try self.addTag(.f32_demote_f64);
+                try self.addLabel(.local_set, tmp.local);
+                break :blk tmp;
+            },
+            // Any other source width is not an f32 value and must not be
+            // passed to `__truncsfhf2` unchanged.
+            else => return self.fail("TODO: Implement 'fptrunc' to f16 for floats with bitsize: {d}", .{given_bits}),
+        };
+
+        // call __truncsfhf2(f32) f16
+        return self.callIntrinsic("__truncsfhf2", &.{Type.f32}, Type.f16, &.{op});
     } else {
         // TODO: Emit a call to compiler-rt to trunc the float. e.g. __truncdfhf2
-        return self.fail("TODO: Implement 'fptrunc' for floats with bitsize: {d}", .{dest_bits});
+        return self.fail("TODO: Implement 'fptrunc' for floats with bitsize: {d}", .{wanted_bits});
     }
 }