amxcomplextransposeintrin.h (12061B) - Raw
1 /*===----- amxcomplextransposeintrin.h - AMX-COMPLEX and AMX-TRANSPOSE ------=== 2 * 3 * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 * See https://llvm.org/LICENSE.txt for license information. 5 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 * 7 *===------------------------------------------------------------------------=== 8 */ 9 10 #ifndef __IMMINTRIN_H 11 #error \ 12 "Never use <amxcomplextransposeintrin.h> directly; include <immintrin.h> instead." 13 #endif // __IMMINTRIN_H 14 15 #ifndef __AMX_COMPLEXTRANSPOSEINTRIN_H 16 #define __AMX_COMPLEXTRANSPOSEINTRIN_H 17 #ifdef __x86_64__ 18 19 #define __DEFAULT_FN_ATTRS \ 20 __attribute__((__always_inline__, __nodebug__, \ 21 __target__("amx-complex,amx-transpose"))) 22 23 /// Perform matrix multiplication of two tiles containing complex elements and 24 /// accumulate the results into a packed single precision tile. Each dword 25 /// element in input tiles \a a and \a b is interpreted as a complex number 26 /// with FP16 real part and FP16 imaginary part. 27 /// Calculates the imaginary part of the result. For each possible combination 28 /// of (transposed column of \a a, column of \a b), it performs a set of 29 /// multiplication and accumulations on all corresponding complex numbers 30 /// (one from \a a and one from \a b). The imaginary part of the \a a element 31 /// is multiplied with the real part of the corresponding \a b element, and 32 /// the real part of the \a a element is multiplied with the imaginary part 33 /// of the corresponding \a b elements. The two accumulated results are 34 /// added, and then accumulated into the corresponding row and column of 35 /// \a dst. 36 /// 37 /// \headerfile <x86intrin.h> 38 /// 39 /// \code 40 /// void _tile_tcmmimfp16ps(__tile dst, __tile a, __tile b); 41 /// \endcode 42 /// 43 /// \code{.operation} 44 /// FOR m := 0 TO dst.rows - 1 45 /// tmp := dst.row[m] 46 /// FOR k := 0 TO a.rows - 1 47 /// FOR n := 0 TO (dst.colsb / 4) - 1 48 /// tmp.fp32[n] += FP32(a.row[m].fp16[2*k+0]) * FP32(b.row[k].fp16[2*n+1]) 49 /// tmp.fp32[n] += FP32(a.row[m].fp16[2*k+1]) * FP32(b.row[k].fp16[2*n+0]) 50 /// ENDFOR 51 /// ENDFOR 52 /// write_row_and_zero(dst, m, tmp, dst.colsb) 53 /// ENDFOR 54 /// zero_upper_rows(dst, dst.rows) 55 /// zero_tileconfig_start() 56 /// \endcode 57 /// 58 /// This intrinsic corresponds to the \c TTCMMIMFP16PS instruction. 59 /// 60 /// \param dst 61 /// The destination tile. Max size is 1024 Bytes. 62 /// \param a 63 /// The 1st source tile. Max size is 1024 Bytes. 64 /// \param b 65 /// The 2nd source tile. Max size is 1024 Bytes. 66 #define _tile_tcmmimfp16ps(dst, a, b) \ 67 __builtin_ia32_ttcmmimfp16ps((dst), (a), (b)) 68 69 /// Perform matrix multiplication of two tiles containing complex elements and 70 /// accumulate the results into a packed single precision tile. Each dword 71 /// element in input tiles \a a and \a b is interpreted as a complex number 72 /// with FP16 real part and FP16 imaginary part. 73 /// Calculates the real part of the result. For each possible combination 74 /// of (rtransposed colum of \a a, column of \a b), it performs a set of 75 /// multiplication and accumulations on all corresponding complex numbers 76 /// (one from \a a and one from \a b). The real part of the \a a element is 77 /// multiplied with the real part of the corresponding \a b element, and the 78 /// negated imaginary part of the \a a element is multiplied with the 79 /// imaginary part of the corresponding \a b elements. The two accumulated 80 /// results are added, and then accumulated into the corresponding row and 81 /// column of \a dst. 82 /// 83 /// \headerfile <x86intrin.h> 84 /// 85 /// \code 86 /// void _tile_tcmmrlfp16ps(__tile dst, __tile a, __tile b); 87 /// \endcode 88 /// 89 /// \code{.operation} 90 /// FOR m := 0 TO dst.rows - 1 91 /// tmp := dst.row[m] 92 /// FOR k := 0 TO a.rows - 1 93 /// FOR n := 0 TO (dst.colsb / 4) - 1 94 /// tmp.fp32[n] += FP32(a.row[m].fp16[2*k+0]) * FP32(b.row[k].fp16[2*n+0]) 95 /// tmp.fp32[n] += FP32(-a.row[m].fp16[2*k+1]) * FP32(b.row[k].fp16[2*n+1]) 96 /// ENDFOR 97 /// ENDFOR 98 /// write_row_and_zero(dst, m, tmp, dst.colsb) 99 /// ENDFOR 100 /// zero_upper_rows(dst, dst.rows) 101 /// zero_tileconfig_start() 102 /// \endcode 103 /// 104 /// This intrinsic corresponds to the \c TTCMMIMFP16PS instruction. 105 /// 106 /// \param dst 107 /// The destination tile. Max size is 1024 Bytes. 108 /// \param a 109 /// The 1st source tile. Max size is 1024 Bytes. 110 /// \param b 111 /// The 2nd source tile. Max size is 1024 Bytes. 112 #define _tile_tcmmrlfp16ps(dst, a, b) \ 113 __builtin_ia32_ttcmmrlfp16ps((dst), (a), (b)) 114 115 /// Perform matrix conjugate transpose and multiplication of two tiles 116 /// containing complex elements and accumulate the results into a packed 117 /// single precision tile. Each dword element in input tiles \a a and \a b 118 /// is interpreted as a complex number with FP16 real part and FP16 imaginary 119 /// part. 120 /// Calculates the imaginary part of the result. For each possible combination 121 /// of (transposed column of \a a, column of \a b), it performs a set of 122 /// multiplication and accumulations on all corresponding complex numbers 123 /// (one from \a a and one from \a b). The negated imaginary part of the \a a 124 /// element is multiplied with the real part of the corresponding \a b 125 /// element, and the real part of the \a a element is multiplied with the 126 /// imaginary part of the corresponding \a b elements. The two accumulated 127 /// results are added, and then accumulated into the corresponding row and 128 /// column of \a dst. 129 /// 130 /// \headerfile <x86intrin.h> 131 /// 132 /// \code 133 /// void _tile_conjtcmmimfp16ps(__tile dst, __tile a, __tile b); 134 /// \endcode 135 /// 136 /// \code{.operation} 137 /// FOR m := 0 TO dst.rows - 1 138 /// tmp := dst.row[m] 139 /// FOR k := 0 TO a.rows - 1 140 /// FOR n := 0 TO (dst.colsb / 4) - 1 141 /// tmp.fp32[n] += FP32(a.row[m].fp16[2*k+0]) * FP32(b.row[k].fp16[2*n+1]) 142 /// tmp.fp32[n] += FP32(-a.row[m].fp16[2*k+1]) * FP32(b.row[k].fp16[2*n+0]) 143 /// ENDFOR 144 /// ENDFOR 145 /// write_row_and_zero(dst, m, tmp, dst.colsb) 146 /// ENDFOR 147 /// zero_upper_rows(dst, dst.rows) 148 /// zero_tileconfig_start() 149 /// \endcode 150 /// 151 /// This intrinsic corresponds to the \c TCONJTCMMIMFP16PS instruction. 152 /// 153 /// \param dst 154 /// The destination tile. Max size is 1024 Bytes. 155 /// \param a 156 /// The 1st source tile. Max size is 1024 Bytes. 157 /// \param b 158 /// The 2nd source tile. Max size is 1024 Bytes. 159 #define _tile_conjtcmmimfp16ps(dst, a, b) \ 160 __builtin_ia32_tconjtcmmimfp16ps((dst), (a), (b)) 161 162 /// Perform conjugate transpose of an FP16-pair of complex elements from \a a 163 /// and writes the result to \a dst. 164 /// 165 /// \headerfile <x86intrin.h> 166 /// 167 /// \code 168 /// void _tile_conjtfp16(__tile dst, __tile a); 169 /// \endcode 170 /// 171 /// \code{.operation} 172 /// FOR i := 0 TO dst.rows - 1 173 /// FOR j := 0 TO (dst.colsb / 4) - 1 174 /// tmp.fp16[2*j+0] := a.row[j].fp16[2*i+0] 175 /// tmp.fp16[2*j+1] := -a.row[j].fp16[2*i+1] 176 /// ENDFOR 177 /// write_row_and_zero(dst, i, tmp, dst.colsb) 178 /// ENDFOR 179 /// zero_upper_rows(dst, dst.rows) 180 /// zero_tileconfig_start() 181 /// \endcode 182 /// 183 /// This intrinsic corresponds to the \c TCONJTFP16 instruction. 184 /// 185 /// \param dst 186 /// The destination tile. Max size is 1024 Bytes. 187 /// \param a 188 /// The source tile. Max size is 1024 Bytes. 189 #define _tile_conjtfp16(dst, a) __builtin_ia32_tconjtfp16((dst), (a)) 190 191 static __inline__ _tile1024i __DEFAULT_FN_ATTRS _tile_tcmmimfp16ps_internal( 192 unsigned short m, unsigned short n, unsigned short k, _tile1024i dst, 193 _tile1024i src1, _tile1024i src2) { 194 return __builtin_ia32_ttcmmimfp16ps_internal(m, n, k, dst, src1, src2); 195 } 196 197 static __inline__ _tile1024i __DEFAULT_FN_ATTRS _tile_tcmmrlfp16ps_internal( 198 unsigned short m, unsigned short n, unsigned short k, _tile1024i dst, 199 _tile1024i src1, _tile1024i src2) { 200 return __builtin_ia32_ttcmmrlfp16ps_internal(m, n, k, dst, src1, src2); 201 } 202 203 static __inline__ _tile1024i __DEFAULT_FN_ATTRS _tile_conjtcmmimfp16ps_internal( 204 unsigned short m, unsigned short n, unsigned short k, _tile1024i dst, 205 _tile1024i src1, _tile1024i src2) { 206 return __builtin_ia32_tconjtcmmimfp16ps_internal(m, n, k, dst, src1, src2); 207 } 208 209 static __inline__ _tile1024i __DEFAULT_FN_ATTRS 210 _tile_conjtfp16_internal(unsigned short m, unsigned short n, _tile1024i src) { 211 return __builtin_ia32_tconjtfp16_internal(m, n, src); 212 } 213 214 /// Perform matrix multiplication of two tiles containing complex elements and 215 /// accumulate the results into a packed single precision tile. Each dword 216 /// element in input tiles src0 and src1 is interpreted as a complex number 217 /// with FP16 real part and FP16 imaginary part. 218 /// This function calculates the imaginary part of the result. 219 /// 220 /// \headerfile <immintrin.h> 221 /// 222 /// This intrinsic corresponds to the <c> TTCMMIMFP16PS </c> instruction. 223 /// 224 /// \param dst 225 /// The destination tile. Max size is 1024 Bytes. 226 /// \param src0 227 /// The 1st source tile. Max size is 1024 Bytes. 228 /// \param src1 229 /// The 2nd source tile. Max size is 1024 Bytes. 230 __DEFAULT_FN_ATTRS 231 static void __tile_tcmmimfp16ps(__tile1024i *dst, __tile1024i src0, 232 __tile1024i src1) { 233 dst->tile = _tile_tcmmimfp16ps_internal(src0.row, src1.col, src0.col, 234 dst->tile, src0.tile, src1.tile); 235 } 236 237 /// Perform matrix multiplication of two tiles containing complex elements and 238 /// accumulate the results into a packed single precision tile. Each dword 239 /// element in input tiles src0 and src1 is interpreted as a complex number 240 /// with FP16 real part and FP16 imaginary part. 241 /// This function calculates the real part of the result. 242 /// 243 /// \headerfile <immintrin.h> 244 /// 245 /// This intrinsic corresponds to the <c> TTCMMRLFP16PS </c> instruction. 246 /// 247 /// \param dst 248 /// The destination tile. Max size is 1024 Bytes. 249 /// \param src0 250 /// The 1st source tile. Max size is 1024 Bytes. 251 /// \param src1 252 /// The 2nd source tile. Max size is 1024 Bytes. 253 __DEFAULT_FN_ATTRS 254 static void __tile_tcmmrlfp16ps(__tile1024i *dst, __tile1024i src0, 255 __tile1024i src1) { 256 dst->tile = _tile_tcmmrlfp16ps_internal(src0.row, src1.col, src0.col, 257 dst->tile, src0.tile, src1.tile); 258 } 259 260 /// Perform matrix conjugate transpose and multiplication of two tiles 261 /// containing complex elements and accumulate the results into a packed 262 /// single precision tile. Each dword element in input tiles src0 and src1 263 /// is interpreted as a complex number with FP16 real part and FP16 imaginary 264 /// part. 265 /// This function calculates the imaginary part of the result. 266 /// 267 /// \headerfile <immintrin.h> 268 /// 269 /// This intrinsic corresponds to the <c> TCONJTCMMIMFP16PS </c> instruction. 270 /// 271 /// \param dst 272 /// The destination tile. Max size is 1024 Bytes. 273 /// \param src0 274 /// The 1st source tile. Max size is 1024 Bytes. 275 /// \param src1 276 /// The 2nd source tile. Max size is 1024 Bytes. 277 __DEFAULT_FN_ATTRS 278 static void __tile_conjtcmmimfp16ps(__tile1024i *dst, __tile1024i src0, 279 __tile1024i src1) { 280 dst->tile = _tile_conjtcmmimfp16ps_internal(src0.row, src1.col, src0.col, 281 dst->tile, src0.tile, src1.tile); 282 } 283 284 /// Perform conjugate transpose of an FP16-pair of complex elements from src and 285 /// writes the result to dst. 286 /// 287 /// \headerfile <immintrin.h> 288 /// 289 /// This intrinsic corresponds to the <c> TCONJTFP16 </c> instruction. 290 /// 291 /// \param dst 292 /// The destination tile. Max size is 1024 Bytes. 293 /// \param src 294 /// The source tile. Max size is 1024 Bytes. 295 __DEFAULT_FN_ATTRS 296 static void __tile_conjtfp16(__tile1024i *dst, __tile1024i src) { 297 dst->tile = _tile_conjtfp16_internal(src.row, src.col, src.tile); 298 } 299 300 #undef __DEFAULT_FN_ATTRS 301 302 #endif // __x86_64__ 303 #endif // __AMX_COMPLEXTRANSPOSEINTRIN_H