amxtf32transposeintrin.h (3642B) - Raw
1 /*===--------- amxtf32transposeintrin.h - AMX-TF32 and AMX-TRANSPOSE --------=== 2 * 3 * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 * See https://llvm.org/LICENSE.txt for license information. 5 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 * 7 *===------------------------------------------------------------------------=== 8 */ 9 #ifndef __IMMINTRIN_H 10 #error \ 11 "Never use <amxtf32tranposeintrin.h> directly; include <immintrin.h> instead." 12 #endif // __IMMINTRIN_H 13 14 #ifndef __AMX_TF32TRANSPOSEINTRIN_H 15 #define __AMX_TF32TRANSPOSEINTRIN_H 16 #ifdef __x86_64__ 17 18 #define __DEFAULT_FN_ATTRS_TF32_TRANSPOSE \ 19 __attribute__((__always_inline__, __nodebug__, \ 20 __target__("amx-tf32,amx-transpose"))) 21 22 /// \code 23 /// void _tile_tmmultf32ps(constexpr int srcdst, constexpr int a, \ 24 /// constexpr int b); 25 /// \endcode 26 /// 27 /// This intrinsic corresponds to the <c> TTMMULTF32PS </c> instruction. 28 /// 29 /// \param srcdst 30 /// The destination tile. Max size is 1024 Bytes. 31 /// \param a 32 /// The 1st source tile. Max size is 1024 Bytes. 33 /// \param b 34 /// The 2nd source tile. Max size is 1024 Bytes. 35 /// 36 /// \code{.operation} 37 /// DEFINE zero_lower_mantissa_bits_fp32(x[31:0]) { 38 /// dword[12:0] := 0 39 /// dword[31:13] := x[31:13] 40 /// return dword 41 /// } 42 /// 43 /// DEFINE silence_snan_fp32(x[31:0]) { 44 /// IF (x.exponent == 255 and x.fraction != 0 and x.fraction[22] == 0) 45 /// x.fraction[22] := 1 46 /// return x 47 /// } 48 /// 49 /// elements_dest:= srcdst.colsb/4 50 /// 51 /// FOR m := 0 TO (srcdst.rows-1) 52 /// tmp[511:0] := 0 53 /// FOR k := 0 TO (a.rows-1) 54 /// FOR n := 0 TO (elements_dest-1) 55 /// a1e := silence_snan_fp32(a.row[k].fp32[m]) 56 /// a2e := silence_snan_fp32(b.row[k].fp32[n]) 57 /// s1e := zero_lower_mantissa_bits_fp32(a1e) 58 /// s2e := zero_lower_mantissa_bits_fp32(a2e) 59 /// tmp.fp32[n] += s1e * s2e 60 /// ENDFOR 61 /// ENDFOR 62 /// 63 /// FOR n := 0 TO (elements_dest-1) 64 /// tmp.fp32[n] += srcdst.row[m].fp32[n] 65 /// ENDFOR 66 /// write_row_and_zero(srcdst, m, tmp, srcdst.colsb) 67 /// 68 /// ENDFOR 69 /// 70 /// zero_upper_rows(srcdst, srcdst.rows) 71 /// zero_tileconfig_start() 72 /// \endcode 73 #define _tile_tmmultf32ps(srcdst, a, b) \ 74 __builtin_ia32_ttmmultf32ps((srcdst), (a), (b)) 75 76 // dst = m x n (srcdest), src1 = k x m, src2 = k x n 77 static __inline__ _tile1024i __DEFAULT_FN_ATTRS_TF32_TRANSPOSE 78 _tile_tmmultf32ps_internal(unsigned short m, unsigned short n, unsigned short k, 79 _tile1024i dst, _tile1024i src1, _tile1024i src2) { 80 return __builtin_ia32_ttmmultf32ps_internal(m, n, k, dst, src1, src2); 81 } 82 83 /// Compute transpose and do Matrix Multiplication of src0 and src1, and then do 84 /// Matrix Plus with dst. All the calculation is base on float32 but with the 85 /// lower 13-bit set to 0. 86 /// 87 /// \headerfile <immintrin.h> 88 /// 89 /// This intrinsic corresponds to the <c> TTMMULTF32PS </c> instruction. 90 /// 91 /// \param dst 92 /// The destination tile. Max size is 1024 Bytes. 93 /// \param src0 94 /// The 1st source tile. Max size is 1024 Bytes. 95 /// \param src1 96 /// The 2nd source tile. Max size is 1024 Bytes. 97 __DEFAULT_FN_ATTRS_TF32_TRANSPOSE 98 static void __tile_tmmultf32ps(__tile1024i *dst, __tile1024i src0, 99 __tile1024i src1) { 100 dst->tile = _tile_tmmultf32ps_internal(src0.row, src1.col, src0.col, 101 dst->tile, src0.tile, src1.tile); 102 } 103 104 #endif // __x86_64__ 105 #endif // __AMX_TF32TRANSPOSEINTRIN_H