amxfp8intrin.h (8839B) - Raw
1 /*===------------- amxfp8intrin.h - AMX intrinsics -*- C++ -*----------------=== 2 * 3 * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 * See https://llvm.org/LICENSE.txt for license information. 5 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 * 7 *===------------------------------------------------------------------------=== 8 */ 9 10 #ifndef __IMMINTRIN_H 11 #error "Never use <amxfp8intrin.h> directly; include <immintrin.h> instead." 12 #endif /* __IMMINTRIN_H */ 13 14 #ifndef __AMXFP8INTRIN_H 15 #define __AMXFP8INTRIN_H 16 #ifdef __x86_64__ 17 18 #define __DEFAULT_FN_ATTRS_FP8 \ 19 __attribute__((__always_inline__, __nodebug__, __target__("amx-fp8"))) 20 21 static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8 22 _tile_dpbf8ps_internal(unsigned short m, unsigned short n, unsigned short k, 23 _tile1024i dst, _tile1024i src1, _tile1024i src2) { 24 return __builtin_ia32_tdpbf8ps_internal(m, n, k, dst, src1, src2); 25 } 26 27 /// Perform the dot product of a BF8 value \a src1 by a BF8 value \a src2 28 /// accumulating into a Single Precision (FP32) source/dest \a dst. 29 /// 30 /// \headerfile <immintrin.h> 31 /// 32 /// \code 33 /// void __tile_dpbf8ps (__tile1024i *dst, __tile1024i src1, __tile1024i src2) 34 /// \endcode 35 /// 36 /// \code{.operation} 37 /// FOR m := 0 TO dst.rows - 1 38 /// temp1[(dst.colsb / 4 - 1) : 0] = 0 39 /// FOR k := 0 TO src1.colsb / 4 - 1 40 /// FOR n := 0 TO dst.colsb / 4 - 1 41 /// temp1[n] += 42 /// INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0]) 43 /// + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1]) 44 /// + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2]) 45 /// + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3]) 46 /// ENDFOR 47 /// ENDFOR 48 /// FOR n := 0 TO dst.colsb / 4 - 1 49 /// tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n]) 50 /// ENDFOR 51 /// write_row_and_zero(dst, m, tmp, dst.colsb) 52 /// zero_upper_rows(dst, dst.rows) 53 /// zero_tileconfig_start() 54 /// \endcode 55 /// 56 /// This intrinsic corresponds to the \c TDPBF8PS instruction. 57 /// 58 /// \param dst 59 /// The destination tile. Max size is 1024 Bytes. 60 /// \param src1 61 /// The 1st source tile. Max size is 1024 Bytes. 62 /// \param src2 63 /// The 2nd source tile. Max size is 1024 Bytes. 64 __DEFAULT_FN_ATTRS_FP8 static void 65 __tile_dpbf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) { 66 dst->tile = _tile_dpbf8ps_internal(src1.row, src2.col, src1.col, dst->tile, 67 src1.tile, src2.tile); 68 } 69 70 static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8 71 _tile_dpbhf8ps_internal(unsigned short m, unsigned short n, unsigned short k, 72 _tile1024i dst, _tile1024i src1, _tile1024i src2) { 73 return __builtin_ia32_tdpbhf8ps_internal(m, n, k, dst, src1, src2); 74 } 75 76 /// Perform the dot product of a BF8 value \a src1 by an HF8 value \a src2 77 /// accumulating into a Single Precision (FP32) source/dest \a dst. 78 /// 79 /// \headerfile <immintrin.h> 80 /// 81 /// \code 82 /// void __tile_dpbhf8ps (__tile1024i dst, __tile1024i src1, __tile1024i src2) 83 /// \endcode 84 /// 85 /// \code{.operation} 86 /// FOR m := 0 TO dst.rows - 1 87 /// temp1[(dst.colsb / 4 - 1) : 0] = 0 88 /// FOR k := 0 TO src1.colsb / 4 - 1 89 /// FOR n := 0 TO dst.colsb / 4 - 1 90 /// temp1[n] += 91 /// INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0]) 92 /// + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1]) 93 /// + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2]) 94 /// + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3]) 95 /// ENDFOR 96 /// ENDFOR 97 /// FOR n := 0 TO dst.colsb / 4 - 1 98 /// tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n]) 99 /// ENDFOR 100 /// write_row_and_zero(dst, m, tmp, dst.colsb) 101 /// zero_upper_rows(dst, dst.rows) 102 /// zero_tileconfig_start() 103 /// \endcode 104 /// 105 /// This intrinsic corresponds to the \c TDPBHF8PS instruction. 106 /// 107 /// \param dst 108 /// The destination tile. Max size is 1024 Bytes. 109 /// \param src1 110 /// The 1st source tile. Max size is 1024 Bytes. 111 /// \param src2 112 /// The 2nd source tile. Max size is 1024 Bytes. 113 __DEFAULT_FN_ATTRS_FP8 static void 114 __tile_dpbhf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) { 115 dst->tile = _tile_dpbhf8ps_internal(src1.row, src2.col, src1.col, dst->tile, 116 src1.tile, src2.tile); 117 } 118 119 static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8 120 _tile_dphbf8ps_internal(unsigned short m, unsigned short n, unsigned short k, 121 _tile1024i dst, _tile1024i src1, _tile1024i src2) { 122 return __builtin_ia32_tdphbf8ps_internal(m, n, k, dst, src1, src2); 123 } 124 125 /// Perform the dot product of an HF8 value \a src1 by a BF8 value \a src2 126 /// accumulating into a Single Precision (FP32) source/dest \a dst. 127 /// 128 /// \headerfile <immintrin.h> 129 /// 130 /// \code 131 /// void __tile_dphbf8ps (__tile1024i dst, __tile1024i src1, __tile1024i src2) 132 /// \endcode 133 /// 134 /// \code{.operation} 135 /// FOR m := 0 TO dst.rows - 1 136 /// temp1[(dst.colsb / 4 - 1) : 0] = 0 137 /// FOR k := 0 TO src1.colsb / 4 - 1 138 /// FOR n := 0 TO dst.colsb / 4 - 1 139 /// temp1[n] += 140 /// INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0]) 141 /// + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1]) 142 /// + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2]) 143 /// + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3]) 144 /// ENDFOR 145 /// ENDFOR 146 /// FOR n := 0 TO dst.colsb / 4 - 1 147 /// tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n]) 148 /// ENDFOR 149 /// write_row_and_zero(dst, m, tmp, dst.colsb) 150 /// zero_upper_rows(dst, dst.rows) 151 /// zero_tileconfig_start() 152 /// \endcode 153 /// 154 /// This intrinsic corresponds to the \c TDPHBF8PS instruction. 155 /// 156 /// \param dst 157 /// The destination tile. Max size is 1024 Bytes. 158 /// \param src1 159 /// The 1st source tile. Max size is 1024 Bytes. 160 /// \param src2 161 /// The 2nd source tile. Max size is 1024 Bytes. 162 163 __DEFAULT_FN_ATTRS_FP8 static void 164 __tile_dphbf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) { 165 dst->tile = _tile_dphbf8ps_internal(src1.row, src2.col, src1.col, dst->tile, 166 src1.tile, src2.tile); 167 } 168 169 static __inline__ _tile1024i __DEFAULT_FN_ATTRS_FP8 170 _tile_dphf8ps_internal(unsigned short m, unsigned short n, unsigned short k, 171 _tile1024i dst, _tile1024i src1, _tile1024i src2) { 172 return __builtin_ia32_tdphf8ps_internal(m, n, k, dst, src1, src2); 173 } 174 175 /// Perform the dot product of an HF8 value \a src1 by an HF8 value \a src2 176 /// accumulating into a Single Precision (FP32) source/dest \a dst. 177 /// 178 /// \headerfile <immintrin.h> 179 /// 180 /// \code 181 /// void __tile_dphf8ps (__tile1024i dst, __tile1024i src1, __tile1024i src2) 182 /// \endcode 183 /// 184 /// \code{.operation} 185 /// FOR m := 0 TO dst.rows - 1 186 /// temp1[(dst.colsb / 4 - 1) : 0] = 0 187 /// FOR k := 0 TO src1.colsb / 4 - 1 188 /// FOR n := 0 TO dst.colsb / 4 - 1 189 /// temp1[n] += 190 /// INT64(src1.row[m].float8[4*k+0]) * INT64(src2.row[k].float8[4*n+0]) 191 /// + INT64(src1.row[m].float8[4*k+1]) * INT64(src2.row[k].float8[4*n+1]) 192 /// + INT64(src1.row[m].float8[4*k+2]) * INT64(src2.row[k].float8[4*n+2]) 193 /// + INT64(src1.row[m].float8[4*k+3]) * INT64(src2.row[k].float8[4*n+3]) 194 /// ENDFOR 195 /// ENDFOR 196 /// FOR n := 0 TO dst.colsb / 4 - 1 197 /// tmp.row[m].fp32[n] = dst.row[m].fp32[n] + FP32(temp1[n]) 198 /// ENDFOR 199 /// write_row_and_zero(dst, m, tmp, dst.colsb) 200 /// zero_upper_rows(dst, dst.rows) 201 /// zero_tileconfig_start() 202 /// \endcode 203 /// 204 /// This intrinsic corresponds to the \c TDPHF8PS instruction. 205 /// 206 /// \param dst 207 /// The destination tile. Max size is 1024 Bytes. 208 /// \param src1 209 /// The 1st source tile. Max size is 1024 Bytes. 210 /// \param src2 211 /// The 2nd source tile. Max size is 1024 Bytes. 212 __DEFAULT_FN_ATTRS_FP8 static void 213 __tile_dphf8ps(__tile1024i *dst, __tile1024i src1, __tile1024i src2) { 214 dst->tile = _tile_dphf8ps_internal(src1.row, src2.col, src1.col, dst->tile, 215 src1.tile, src2.tile); 216 } 217 218 #define _tile_dpbf8ps(dst, src1, src2) \ 219 __builtin_ia32_tdpbf8ps((dst), (src1), (src2)) 220 #define _tile_dpbhf8ps(dst, src1, src2) \ 221 __builtin_ia32_tdpbhf8ps((dst), (src1), (src2)) 222 #define _tile_dphbf8ps(dst, src1, src2) \ 223 __builtin_ia32_tdphbf8ps((dst), (src1), (src2)) 224 #define _tile_dphf8ps(dst, src1, src2) \ 225 __builtin_ia32_tdphf8ps((dst), (src1), (src2)) 226 227 #undef __DEFAULT_FN_ATTRS_FP8 228 229 #endif /* __x86_64__ */ 230 #endif /* __AMXFP8INTRIN_H */