/* amxbf16transposeintrin.h (3525 bytes, raw listing) */
1 /*===----- amxbf16transposeintrin.h - AMX-BF16 and AMX-TRANSPOSE ------------=== 2 * 3 * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 * See https://llvm.org/LICENSE.txt for license information. 5 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 * 7 *===------------------------------------------------------------------------=== 8 */ 9 10 #ifndef __IMMINTRIN_H 11 #error \ 12 "Never use <amxbf16transposeintrin.h> directly; use <immintrin.h> instead." 13 #endif /* __IMMINTRIN_H */ 14 15 #ifndef __AMX_BF16TRANSPOSEINTRIN_H 16 #define __AMX_BF16TRANSPOSEINTRIN_H 17 #ifdef __x86_64__ 18 19 /* Define the default attributes for the functions in this file. */ 20 #define __DEFAULT_FN_ATTRS \ 21 __attribute__((__always_inline__, __nodebug__, \ 22 __target__("amx-bf16,amx-transpose"))) 23 24 /// Compute transpose and dot-product of BF16 (16-bit) floating-point pairs in 25 /// tiles \a a and \a b, accumulating the intermediate single-precision 26 /// (32-bit) floating-point elements with elements in \a dst, and store the 27 /// 32-bit result back to tile \a dst. 28 /// 29 /// \headerfile <immintrin.h> 30 /// 31 /// \code 32 /// void _tile_tdpbf16ps (__tile dst, __tile a, __tile b) 33 /// \endcode 34 /// 35 /// \code{.operation} 36 /// FOR m := 0 TO dst.rows - 1 37 /// tmp := dst.row[m] 38 /// FOR k := 0 TO (a.colsb / 4) - 1 39 /// FOR n := 0 TO (dst.colsb / 4) - 1 40 /// tmp.bf32[n] += FP32(a.row[m].bf16[2*k+0]) * 41 /// FP32(b.row[k].bf16[2*n+0]) 42 /// tmp.bf32[n] += FP32(a.row[m].bf16[2*k+1]) * 43 /// FP32(b.row[k].bf16[2*n+1]) 44 /// ENDFOR 45 /// ENDFOR 46 /// write_row_and_zero(dst, m, tmp, dst.colsb) 47 /// ENDFOR 48 /// zero_upper_rows(dst, dst.rows) 49 /// zero_tileconfig_start() 50 /// \endcode 51 /// 52 /// This intrinsic corresponds to the \c TTDPBF16PS instruction. 53 /// 54 /// \param dst 55 /// The destination tile. Max size is 1024 Bytes. 56 /// \param a 57 /// The 1st source tile. Max size is 1024 Bytes. 
58 /// \param b 59 /// The 2nd source tile. Max size is 1024 Bytes. 60 #define _tile_tdpbf16ps(dst, a, b) __builtin_ia32_ttdpbf16ps((dst), (a), (b)) 61 62 /// This is internal intrinsic. C/C++ user should avoid calling it directly. 63 static __inline__ _tile1024i __DEFAULT_FN_ATTRS 64 _tile_tdpbf16ps_internal(unsigned short m, unsigned short n, unsigned short k, 65 _tile1024i dst, _tile1024i src1, _tile1024i src2) { 66 return __builtin_ia32_ttdpbf16ps_internal(m, n, k, dst, src1, src2); 67 } 68 69 /// Compute transpose and dot-product of BF16 (16-bit) floating-point pairs in 70 /// tiles src0 and src1, accumulating the intermediate single-precision 71 /// (32-bit) floating-point elements with elements in "dst", and store the 72 /// 32-bit result back to tile "dst". 73 /// 74 /// \headerfile <immintrin.h> 75 /// 76 /// This intrinsic corresponds to the <c> TTDPBF16PS </c> instruction. 77 /// 78 /// \param dst 79 /// The destination tile. Max size is 1024 Bytes. 80 /// \param src0 81 /// The 1st source tile. Max size is 1024 Bytes. 82 /// \param src1 83 /// The 2nd source tile. Max size is 1024 Bytes. 84 __DEFAULT_FN_ATTRS 85 static __inline__ void __tile_tdpbf16ps(__tile1024i *dst, __tile1024i src0, 86 __tile1024i src1) { 87 dst->tile = _tile_tdpbf16ps_internal(src0.row, src1.col, src0.col, dst->tile, 88 src0.tile, src1.tile); 89 } 90 91 #undef __DEFAULT_FN_ATTRS 92 93 #endif /* __x86_64__ */ 94 #endif /* __AMX_BF16TRANSPOSEINTRIN_H */