zig

fork of https://codeberg.org/ziglang/zig
Log | Files | Refs | README | LICENSE

amxfp8intrin.h (8839B) - Raw


      1 /*===------------- amxfp8intrin.h - AMX intrinsics -*- C++ -*----------------===
      2  *
      3  * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4  * See https://llvm.org/LICENSE.txt for license information.
      5  * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6  *
      7  *===------------------------------------------------------------------------===
      8  */
      9 
     10 #ifndef __IMMINTRIN_H
     11 #error "Never use <amxfp8intrin.h> directly; include <immintrin.h> instead."
     12 #endif /* __IMMINTRIN_H */
     13 
     14 #ifndef __AMXFP8INTRIN_H
     15 #define __AMXFP8INTRIN_H
     16 #ifdef __x86_64__
     17 
/* Attribute set shared by every function in this header: force inlining,
   omit debug info for the wrapper, and compile the body with the "amx-fp8"
   target feature enabled. The macro is #undef'd at the end of the header so
   it does not leak into user code. */
#define __DEFAULT_FN_ATTRS_FP8                                                 \
  __attribute__((__always_inline__, __nodebug__, __target__("amx-fp8")))
     20 
// Shape-carrying helper used by __tile_dpbf8ps: forwards the shape operands
// (m, n, k), the accumulator tile dst and the two source tiles unchanged to
// __builtin_ia32_tdpbf8ps_internal and returns the resulting tile.
// NOTE(review): the _internal suffix suggests this is not public API — confirm.
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8
_tile_dpbf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
                       _tile1024i dst, _tile1024i src1, _tile1024i src2) {
  return __builtin_ia32_tdpbf8ps_internal(m, n, k, dst, src1, src2);
}
     26 
     27 /// Perform the dot product of a BF8 value \a src1 by a BF8 value \a src2
     28 /// accumulating into a Single Precision (FP32) source/dest \a dst.
     29 ///
     30 /// \headerfile <immintrin.h>
     31 ///
     32 /// \code
     33 /// void __tile_dpbf8ps (__tile1024i *dst, __tile1024i src1, __tile1024i src2)
     34 /// \endcode
     35 ///
     36 /// \code{.operation}
     37 /// FOR m := 0 TO dst.rows - 1
     38 ///   temp1[(dst.colsb / 4 - 1) : 0] = 0
     39 ///   FOR k := 0 TO src1.colsb / 4 - 1
     40 ///     FOR n := 0 TO dst.colsb / 4 - 1
     41 ///       temp1[n] +=
     42 ///         INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
     43 ///         + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
     44 ///         + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
     45 ///         + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
     46 ///     ENDFOR
     47 ///   ENDFOR
     48 ///   FOR n := 0 TO dst.colsb / 4 - 1
     49 ///     tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n])
     50 ///   ENDFOR
     51 /// write_row_and_zero(dst, m, tmp, dst.colsb)
     52 /// zero_upper_rows(dst, dst.rows)
     53 /// zero_tileconfig_start()
     54 /// \endcode
     55 ///
     56 /// This intrinsic corresponds to the \c TDPBF8PS instruction.
     57 ///
     58 /// \param dst
     59 ///    The destination tile. Max size is 1024 Bytes.
     60 /// \param src1
     61 ///    The 1st source tile. Max size is 1024 Bytes.
     62 /// \param src2
     63 ///    The 2nd source tile. Max size is 1024 Bytes.
     64 __DEFAULT_FN_ATTRS_FP8 static void
     65 __tile_dpbf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) {
     66   dst->tile = _tile_dpbf8ps_internal(src1.row, src2.col, src1.col, dst->tile,
     67                                      src1.tile, src2.tile);
     68 }
     69 
// Shape-carrying helper used by __tile_dpbhf8ps: forwards the shape operands
// (m, n, k), the accumulator tile dst and the two source tiles unchanged to
// __builtin_ia32_tdpbhf8ps_internal and returns the resulting tile.
// NOTE(review): the _internal suffix suggests this is not public API — confirm.
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8
_tile_dpbhf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
                        _tile1024i dst, _tile1024i src1, _tile1024i src2) {
  return __builtin_ia32_tdpbhf8ps_internal(m, n, k, dst, src1, src2);
}
     75 
     76 /// Perform the dot product of a BF8 value \a src1 by an HF8 value \a src2
     77 /// accumulating into a Single Precision (FP32) source/dest \a dst.
     78 ///
     79 /// \headerfile <immintrin.h>
     80 ///
     81 /// \code
     82 /// void __tile_dpbhf8ps (__tile1024i dst, __tile1024i src1, __tile1024i src2)
     83 /// \endcode
     84 ///
     85 /// \code{.operation}
     86 /// FOR m := 0 TO dst.rows - 1
     87 ///   temp1[(dst.colsb / 4 - 1) : 0] = 0
     88 ///   FOR k := 0 TO src1.colsb / 4 - 1
     89 ///     FOR n := 0 TO dst.colsb / 4 - 1
     90 ///       temp1[n] +=
     91 ///         INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
     92 ///         + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
     93 ///         + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
     94 ///         + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
     95 ///     ENDFOR
     96 ///   ENDFOR
     97 ///   FOR n := 0 TO dst.colsb / 4 - 1
     98 ///     tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n])
     99 ///   ENDFOR
    100 /// write_row_and_zero(dst, m, tmp, dst.colsb)
    101 /// zero_upper_rows(dst, dst.rows)
    102 /// zero_tileconfig_start()
    103 /// \endcode
    104 ///
    105 /// This intrinsic corresponds to the \c TDPBHF8PS instruction.
    106 ///
    107 /// \param dst
    108 ///    The destination tile. Max size is 1024 Bytes.
    109 /// \param src1
    110 ///    The 1st source tile. Max size is 1024 Bytes.
    111 /// \param src2
    112 ///    The 2nd source tile. Max size is 1024 Bytes.
    113 __DEFAULT_FN_ATTRS_FP8 static void
    114 __tile_dpbhf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) {
    115   dst->tile = _tile_dpbhf8ps_internal(src1.row, src2.col, src1.col, dst->tile,
    116                                       src1.tile, src2.tile);
    117 }
    118 
// Shape-carrying helper used by __tile_dphbf8ps: forwards the shape operands
// (m, n, k), the accumulator tile dst and the two source tiles unchanged to
// __builtin_ia32_tdphbf8ps_internal and returns the resulting tile.
// NOTE(review): the _internal suffix suggests this is not public API — confirm.
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8
_tile_dphbf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
                        _tile1024i dst, _tile1024i src1, _tile1024i src2) {
  return __builtin_ia32_tdphbf8ps_internal(m, n, k, dst, src1, src2);
}
    124 
    125 /// Perform the dot product of an HF8 value \a src1 by a BF8 value \a src2
    126 /// accumulating into a Single Precision (FP32) source/dest \a dst.
    127 ///
    128 /// \headerfile <immintrin.h>
    129 ///
    130 /// \code
    131 /// void __tile_dphbf8ps (__tile1024i dst, __tile1024i src1, __tile1024i src2)
    132 /// \endcode
    133 ///
    134 /// \code{.operation}
    135 /// FOR m := 0 TO dst.rows - 1
    136 ///   temp1[(dst.colsb / 4 - 1) : 0] = 0
    137 ///   FOR k := 0 TO src1.colsb / 4 - 1
    138 ///     FOR n := 0 TO dst.colsb / 4 - 1
    139 ///       temp1[n] +=
    140 ///         INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
    141 ///         + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
    142 ///         + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
    143 ///         + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
    144 ///     ENDFOR
    145 ///   ENDFOR
    146 ///   FOR n := 0 TO dst.colsb / 4 - 1
    147 ///     tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n])
    148 ///   ENDFOR
    149 /// write_row_and_zero(dst, m, tmp, dst.colsb)
    150 /// zero_upper_rows(dst, dst.rows)
    151 /// zero_tileconfig_start()
    152 /// \endcode
    153 ///
    154 /// This intrinsic corresponds to the \c TDPHBF8PS instruction.
    155 ///
    156 /// \param dst
    157 ///    The destination tile. Max size is 1024 Bytes.
    158 /// \param src1
    159 ///    The 1st source tile. Max size is 1024 Bytes.
    160 /// \param src2
    161 ///    The 2nd source tile. Max size is 1024 Bytes.
    162 
    163 __DEFAULT_FN_ATTRS_FP8 static void
    164 __tile_dphbf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) {
    165   dst->tile = _tile_dphbf8ps_internal(src1.row, src2.col, src1.col, dst->tile,
    166                                       src1.tile, src2.tile);
    167 }
    168 
// Shape-carrying helper used by __tile_dphf8ps: forwards the shape operands
// (m, n, k), the accumulator tile dst and the two source tiles unchanged to
// __builtin_ia32_tdphf8ps_internal and returns the resulting tile.
// NOTE(review): the _internal suffix suggests this is not public API — confirm.
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8
_tile_dphf8ps_internal(unsigned short m, unsigned short n, unsigned short k,
                       _tile1024i dst, _tile1024i src1, _tile1024i src2) {
  return __builtin_ia32_tdphf8ps_internal(m, n, k, dst, src1, src2);
}
    174 
    175 /// Perform the dot product of an HF8 value \a src1 by an HF8 value \a src2
    176 /// accumulating into a Single Precision (FP32) source/dest \a dst.
    177 ///
    178 /// \headerfile <immintrin.h>
    179 ///
    180 /// \code
    181 /// void __tile_dphf8ps (__tile1024i dst, __tile1024i src1, __tile1024i src2)
    182 /// \endcode
    183 ///
    184 /// \code{.operation}
    185 /// FOR m := 0 TO dst.rows - 1
    186 ///   temp1[(dst.colsb / 4 - 1) : 0] = 0
    187 ///   FOR k := 0 TO src1.colsb / 4 - 1
    188 ///     FOR n := 0 TO dst.colsb / 4 - 1
    189 ///       temp1[n] +=
    190 ///         INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0])
    191 ///         + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1])
    192 ///         + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2])
    193 ///         + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3])
    194 ///     ENDFOR
    195 ///   ENDFOR
    196 ///   FOR n := 0 TO dst.colsb / 4 - 1
    197 ///     tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n])
    198 ///   ENDFOR
    199 /// write_row_and_zero(dst, m, tmp, dst.colsb)
    200 /// zero_upper_rows(dst, dst.rows)
    201 /// zero_tileconfig_start()
    202 /// \endcode
    203 ///
    204 /// This intrinsic corresponds to the \c TDPHF8PS instruction.
    205 ///
    206 /// \param dst
    207 ///    The destination tile. Max size is 1024 Bytes.
    208 /// \param src1
    209 ///    The 1st source tile. Max size is 1024 Bytes.
    210 /// \param src2
    211 ///    The 2nd source tile. Max size is 1024 Bytes.
    212 __DEFAULT_FN_ATTRS_FP8 static void
    213 __tile_dphf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) {
    214   dst->tile = _tile_dphf8ps_internal(src1.row, src2.col, src1.col, dst->tile,
    215                                      src1.tile, src2.tile);
    216 }
    217 
/* Register-form intrinsics. Each macro forwards its three operands, fully
   parenthesized, straight to the matching __builtin_ia32_tdp*8ps builtin.
   Unlike the __tile_* wrapper functions above, no shape information is
   passed here.
   NOTE(review): dst/src1/src2 are presumably tile register numbers that must
   be compile-time constants — confirm against the builtin definitions. */
#define _tile_dpbf8ps(dst, src1, src2)                                         \
  __builtin_ia32_tdpbf8ps((dst), (src1), (src2))
#define _tile_dpbhf8ps(dst, src1, src2)                                        \
  __builtin_ia32_tdpbhf8ps((dst), (src1), (src2))
#define _tile_dphbf8ps(dst, src1, src2)                                        \
  __builtin_ia32_tdphbf8ps((dst), (src1), (src2))
#define _tile_dphf8ps(dst, src1, src2)                                         \
  __builtin_ia32_tdphf8ps((dst), (src1), (src2))

/* Keep the attribute helper macro private to this header. */
#undef __DEFAULT_FN_ATTRS_FP8
    228 
    229 #endif /* __x86_64__ */
    230 #endif /* __AMXFP8INTRIN_H */