zig

fork of https://codeberg.org/ziglang/zig
Log | Files | Refs | README | LICENSE

fma.zig (11575B) - Raw


      1 //! Ported from musl, which is MIT licensed:
      2 //! https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT
      3 //!
      4 //! https://git.musl-libc.org/cgit/musl/tree/src/math/fmal.c
      5 //! https://git.musl-libc.org/cgit/musl/tree/src/math/fmaf.c
      6 //! https://git.musl-libc.org/cgit/musl/tree/src/math/fma.c
      7 
      8 const std = @import("std");
      9 const math = std.math;
     10 const expect = std.testing.expect;
     11 const common = @import("common.zig");
     12 
     13 pub const panic = common.panic;
     14 
// Export each fused multiply-add entry point under its C symbol name, using
// the linkage and visibility configured in common.zig.
comptime {
    @export(&__fmah, .{ .name = "__fmah", .linkage = common.linkage, .visibility = common.visibility });
    @export(&fmaf, .{ .name = "fmaf", .linkage = common.linkage, .visibility = common.visibility });
    @export(&fma, .{ .name = "fma", .linkage = common.linkage, .visibility = common.visibility });
    @export(&__fmax, .{ .name = "__fmax", .linkage = common.linkage, .visibility = common.visibility });
    if (common.want_ppc_abi) {
        // When common.want_ppc_abi is set, the f128 routine is additionally
        // exported under the name "fmaf128".
        @export(&fmaq, .{ .name = "fmaf128", .linkage = common.linkage, .visibility = common.visibility });
    }
    @export(&fmaq, .{ .name = "fmaq", .linkage = common.linkage, .visibility = common.visibility });
    @export(&fmal, .{ .name = "fmal", .linkage = common.linkage, .visibility = common.visibility });
}
     26 
     27 pub fn __fmah(x: f16, y: f16, z: f16) callconv(.c) f16 {
     28     // TODO: more efficient implementation
     29     return @floatCast(fmaf(x, y, z));
     30 }
     31 
     32 pub fn fmaf(x: f32, y: f32, z: f32) callconv(.c) f32 {
     33     const xy = @as(f64, x) * y;
     34     const xy_z = xy + z;
     35     const u = @as(u64, @bitCast(xy_z));
     36     const e = (u >> 52) & 0x7FF;
     37 
     38     if ((u & 0x1FFFFFFF) != 0x10000000 or e == 0x7FF or (xy_z - xy == z and xy_z - z == xy)) {
     39         return @floatCast(xy_z);
     40     } else {
     41         // TODO: Handle inexact case with double-rounding
     42         return @floatCast(xy_z);
     43     }
     44 }
     45 
     46 /// NOTE: Upstream fma.c has been rewritten completely to raise fp exceptions more accurately.
     47 pub fn fma(x: f64, y: f64, z: f64) callconv(.c) f64 {
     48     if (!math.isFinite(x) or !math.isFinite(y)) {
     49         return x * y + z;
     50     }
     51     if (!math.isFinite(z)) {
     52         return z;
     53     }
     54     if (x == 0.0 or y == 0.0) {
     55         return x * y + z;
     56     }
     57     if (z == 0.0) {
     58         return x * y;
     59     }
     60 
     61     const x1 = math.frexp(x);
     62     const ex = x1.exponent;
     63     const xs = x1.significand;
     64     const x2 = math.frexp(y);
     65     const ey = x2.exponent;
     66     const ys = x2.significand;
     67     const x3 = math.frexp(z);
     68     const ez = x3.exponent;
     69     var zs = x3.significand;
     70 
     71     var spread = ex + ey - ez;
     72     if (spread <= 53 * 2) {
     73         zs = math.scalbn(zs, -spread);
     74     } else {
     75         zs = math.copysign(math.floatMin(f64), zs);
     76     }
     77 
     78     const xy = dd_mul(xs, ys);
     79     const r = dd_add(xy.hi, zs);
     80     spread = ex + ey;
     81 
     82     if (r.hi == 0.0) {
     83         return xy.hi + zs + math.scalbn(xy.lo, spread);
     84     }
     85 
     86     const adj = add_adjusted(r.lo, xy.lo);
     87     if (spread + math.ilogb(r.hi) > -1023) {
     88         return math.scalbn(r.hi + adj, spread);
     89     } else {
     90         return add_and_denorm(r.hi, adj, spread);
     91     }
     92 }
     93 
     94 pub fn __fmax(a: f80, b: f80, c: f80) callconv(.c) f80 {
     95     // TODO: more efficient implementation
     96     return @floatCast(fmaq(a, b, c));
     97 }
     98 
     99 /// Fused multiply-add: Compute x * y + z with a single rounding error.
    100 ///
    101 /// We use scaling to avoid overflow/underflow, along with the
    102 /// canonical precision-doubling technique adapted from:
    103 ///
    104 ///      Dekker, T.  A Floating-Point Technique for Extending the
    105 ///      Available Precision.  Numer. Math. 18, 224-242 (1971).
    106 pub fn fmaq(x: f128, y: f128, z: f128) callconv(.c) f128 {
    107     if (!math.isFinite(x) or !math.isFinite(y)) {
    108         return x * y + z;
    109     }
    110     if (!math.isFinite(z)) {
    111         return z;
    112     }
    113     if (x == 0.0 or y == 0.0) {
    114         return x * y + z;
    115     }
    116     if (z == 0.0) {
    117         return x * y;
    118     }
    119 
    120     const x1 = math.frexp(x);
    121     const ex = x1.exponent;
    122     const xs = x1.significand;
    123     const x2 = math.frexp(y);
    124     const ey = x2.exponent;
    125     const ys = x2.significand;
    126     const x3 = math.frexp(z);
    127     const ez = x3.exponent;
    128     var zs = x3.significand;
    129 
    130     var spread = ex + ey - ez;
    131     if (spread <= 113 * 2) {
    132         zs = math.scalbn(zs, -spread);
    133     } else {
    134         zs = math.copysign(math.floatMin(f128), zs);
    135     }
    136 
    137     const xy = dd_mul128(xs, ys);
    138     const r = dd_add128(xy.hi, zs);
    139     spread = ex + ey;
    140 
    141     if (r.hi == 0.0) {
    142         return xy.hi + zs + math.scalbn(xy.lo, spread);
    143     }
    144 
    145     const adj = add_adjusted128(r.lo, xy.lo);
    146     if (spread + math.ilogb(r.hi) > -16383) {
    147         return math.scalbn(r.hi + adj, spread);
    148     } else {
    149         return add_and_denorm128(r.hi, adj, spread);
    150     }
    151 }
    152 
    153 pub fn fmal(x: c_longdouble, y: c_longdouble, z: c_longdouble) callconv(.c) c_longdouble {
    154     switch (@typeInfo(c_longdouble).float.bits) {
    155         16 => return __fmah(x, y, z),
    156         32 => return fmaf(x, y, z),
    157         64 => return fma(x, y, z),
    158         80 => return __fmax(x, y, z),
    159         128 => return fmaq(x, y, z),
    160         else => @compileError("unreachable"),
    161     }
    162 }
    163 
/// A pair of f64 values used by the f64 helpers in this file to carry a
/// result together with the part that did not fit in a single f64.
const dd = struct {
    /// The rounded result of the operation.
    hi: f64,
    /// The correction term computed alongside `hi`.
    lo: f64,
};
    168 
    169 fn dd_add(a: f64, b: f64) dd {
    170     var ret: dd = undefined;
    171     ret.hi = a + b;
    172     const s = ret.hi - a;
    173     ret.lo = (a - (ret.hi - s)) + (b - s);
    174     return ret;
    175 }
    176 
    177 fn dd_mul(a: f64, b: f64) dd {
    178     var ret: dd = undefined;
    179     const split: f64 = 0x1.0p27 + 1.0;
    180 
    181     var p = a * split;
    182     var ha = a - p;
    183     ha += p;
    184     const la = a - ha;
    185 
    186     p = b * split;
    187     var hb = b - p;
    188     hb += p;
    189     const lb = b - hb;
    190 
    191     p = ha * hb;
    192     const q = ha * lb + la * hb;
    193 
    194     ret.hi = p + q;
    195     ret.lo = p - ret.hi + q + la * lb;
    196     return ret;
    197 }
    198 
    199 fn add_adjusted(a: f64, b: f64) f64 {
    200     var sum = dd_add(a, b);
    201     if (sum.lo != 0) {
    202         var uhii: u64 = @bitCast(sum.hi);
    203         if (uhii & 1 == 0) {
    204             // hibits += copysign(1.0, sum.hi, sum.lo)
    205             const uloi: u64 = @bitCast(sum.lo);
    206             uhii += 1 - ((uhii ^ uloi) >> 62);
    207             sum.hi = @bitCast(uhii);
    208         }
    209     }
    210     return sum.hi;
    211 }
    212 
    213 fn add_and_denorm(a: f64, b: f64, scale: i32) f64 {
    214     var sum = dd_add(a, b);
    215     if (sum.lo != 0) {
    216         var uhii: u64 = @bitCast(sum.hi);
    217         const bits_lost = -@as(i32, @intCast((uhii >> 52) & 0x7FF)) - scale + 1;
    218         if ((bits_lost != 1) == (uhii & 1 != 0)) {
    219             const uloi: u64 = @bitCast(sum.lo);
    220             uhii += 1 - (((uhii ^ uloi) >> 62) & 2);
    221             sum.hi = @bitCast(uhii);
    222         }
    223     }
    224     return math.scalbn(sum.hi, scale);
    225 }
    226 
/// A struct that represents a floating-point number with twice the precision
/// of f128.  We maintain the invariant that "hi" stores the high-order
/// bits of the result.
const dd128 = struct {
    /// High-order part of the value.
    hi: f128,
    /// Low-order part of the value.
    lo: f128,
};
    234 
    235 /// Compute a+b exactly, returning the exact result in a struct dd.  We assume
    236 /// that both a and b are finite, but make no assumptions about their relative
    237 /// magnitudes.
    238 fn dd_add128(a: f128, b: f128) dd128 {
    239     var ret: dd128 = undefined;
    240     ret.hi = a + b;
    241     const s = ret.hi - a;
    242     ret.lo = (a - (ret.hi - s)) + (b - s);
    243     return ret;
    244 }
    245 
    246 /// Compute a+b, with a small tweak:  The least significant bit of the
    247 /// result is adjusted into a sticky bit summarizing all the bits that
    248 /// were lost to rounding.  This adjustment negates the effects of double
    249 /// rounding when the result is added to another number with a higher
    250 /// exponent.  For an explanation of round and sticky bits, see any reference
    251 /// on FPU design, e.g.,
    252 ///
    253 ///     J. Coonen.  An Implementation Guide to a Proposed Standard for
    254 ///     Floating-Point Arithmetic.  Computer, vol. 13, no. 1, Jan 1980.
    255 fn add_adjusted128(a: f128, b: f128) f128 {
    256     var sum = dd_add128(a, b);
    257     if (sum.lo != 0) {
    258         var uhii: u128 = @bitCast(sum.hi);
    259         if (uhii & 1 == 0) {
    260             // hibits += copysign(1.0, sum.hi, sum.lo)
    261             const uloi: u128 = @bitCast(sum.lo);
    262             uhii += 1 - ((uhii ^ uloi) >> 126);
    263             sum.hi = @bitCast(uhii);
    264         }
    265     }
    266     return sum.hi;
    267 }
    268 
    269 /// Compute ldexp(a+b, scale) with a single rounding error. It is assumed
    270 /// that the result will be subnormal, and care is taken to ensure that
    271 /// double rounding does not occur.
    272 fn add_and_denorm128(a: f128, b: f128, scale: i32) f128 {
    273     var sum = dd_add128(a, b);
    274     // If we are losing at least two bits of accuracy to denormalization,
    275     // then the first lost bit becomes a round bit, and we adjust the
    276     // lowest bit of sum.hi to make it a sticky bit summarizing all the
    277     // bits in sum.lo. With the sticky bit adjusted, the hardware will
    278     // break any ties in the correct direction.
    279     //
    280     // If we are losing only one bit to denormalization, however, we must
    281     // break the ties manually.
    282     if (sum.lo != 0) {
    283         var uhii: u128 = @bitCast(sum.hi);
    284         const bits_lost = -@as(i32, @intCast((uhii >> 112) & 0x7FFF)) - scale + 1;
    285         if ((bits_lost != 1) == (uhii & 1 != 0)) {
    286             const uloi: u128 = @bitCast(sum.lo);
    287             uhii += 1 - (((uhii ^ uloi) >> 126) & 2);
    288             sum.hi = @bitCast(uhii);
    289         }
    290     }
    291     return math.scalbn(sum.hi, scale);
    292 }
    293 
    294 /// Compute a*b exactly, returning the exact result in a struct dd.  We assume
    295 /// that both a and b are normalized, so no underflow or overflow will occur.
    296 /// The current rounding mode must be round-to-nearest.
    297 fn dd_mul128(a: f128, b: f128) dd128 {
    298     var ret: dd128 = undefined;
    299     const split: f128 = 0x1.0p57 + 1.0;
    300 
    301     var p = a * split;
    302     var ha = a - p;
    303     ha += p;
    304     const la = a - ha;
    305 
    306     p = b * split;
    307     var hb = b - p;
    308     hb += p;
    309     const lb = b - hb;
    310 
    311     p = ha * hb;
    312     const q = ha * lb + la * hb;
    313 
    314     ret.hi = p + q;
    315     ret.lo = p - ret.hi + q + la * lb;
    316     return ret;
    317 }
    318 
    319 test "32" {
    320     const epsilon = 0.000001;
    321 
    322     try expect(math.approxEqAbs(f32, fmaf(0.0, 5.0, 9.124), 9.124, epsilon));
    323     try expect(math.approxEqAbs(f32, fmaf(0.2, 5.0, 9.124), 10.124, epsilon));
    324     try expect(math.approxEqAbs(f32, fmaf(0.8923, 5.0, 9.124), 13.5855, epsilon));
    325     try expect(math.approxEqAbs(f32, fmaf(1.5, 5.0, 9.124), 16.624, epsilon));
    326     try expect(math.approxEqAbs(f32, fmaf(37.45, 5.0, 9.124), 196.374004, epsilon));
    327     try expect(math.approxEqAbs(f32, fmaf(89.123, 5.0, 9.124), 454.739005, epsilon));
    328     try expect(math.approxEqAbs(f32, fmaf(123123.234375, 5.0, 9.124), 615625.295875, epsilon));
    329 }
    330 
    331 test "64" {
    332     const epsilon = 0.000001;
    333 
    334     try expect(math.approxEqAbs(f64, fma(0.0, 5.0, 9.124), 9.124, epsilon));
    335     try expect(math.approxEqAbs(f64, fma(0.2, 5.0, 9.124), 10.124, epsilon));
    336     try expect(math.approxEqAbs(f64, fma(0.8923, 5.0, 9.124), 13.5855, epsilon));
    337     try expect(math.approxEqAbs(f64, fma(1.5, 5.0, 9.124), 16.624, epsilon));
    338     try expect(math.approxEqAbs(f64, fma(37.45, 5.0, 9.124), 196.374, epsilon));
    339     try expect(math.approxEqAbs(f64, fma(89.123, 5.0, 9.124), 454.739, epsilon));
    340     try expect(math.approxEqAbs(f64, fma(123123.234375, 5.0, 9.124), 615625.295875, epsilon));
    341 }
    342 
    343 test "128" {
    344     const epsilon = 0.000001;
    345 
    346     try expect(math.approxEqAbs(f128, fmaq(0.0, 5.0, 9.124), 9.124, epsilon));
    347     try expect(math.approxEqAbs(f128, fmaq(0.2, 5.0, 9.124), 10.124, epsilon));
    348     try expect(math.approxEqAbs(f128, fmaq(0.8923, 5.0, 9.124), 13.5855, epsilon));
    349     try expect(math.approxEqAbs(f128, fmaq(1.5, 5.0, 9.124), 16.624, epsilon));
    350     try expect(math.approxEqAbs(f128, fmaq(37.45, 5.0, 9.124), 196.374, epsilon));
    351     try expect(math.approxEqAbs(f128, fmaq(89.123, 5.0, 9.124), 454.739, epsilon));
    352     try expect(math.approxEqAbs(f128, fmaq(123123.234375, 5.0, 9.124), 615625.295875, epsilon));
    353 }