zig

fork of https://codeberg.org/ziglang/zig
Log | Files | Refs | README | LICENSE

amxtf32transposeintrin.h (3642B) - Raw


      1 /*===--------- amxtf32transposeintrin.h - AMX-TF32 and AMX-TRANSPOSE --------===
      2  *
      3  * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4  * See https://llvm.org/LICENSE.txt for license information.
      5  * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6  *
      7  *===------------------------------------------------------------------------===
      8  */
      9 #ifndef __IMMINTRIN_H
     10 #error                                                                         \
     11     "Never use <amxtf32tranposeintrin.h> directly; include <immintrin.h> instead."
     12 #endif // __IMMINTRIN_H
     13 
     14 #ifndef __AMX_TF32TRANSPOSEINTRIN_H
     15 #define __AMX_TF32TRANSPOSEINTRIN_H
     16 #ifdef __x86_64__
     17 
     18 #define __DEFAULT_FN_ATTRS_TF32_TRANSPOSE                                      \
     19   __attribute__((__always_inline__, __nodebug__,                               \
     20                  __target__("amx-tf32,amx-transpose")))
     21 
     22 /// \code
     23 /// void _tile_tmmultf32ps(constexpr int srcdst, constexpr int a, \
     24 ///                        constexpr int b);
     25 /// \endcode
     26 ///
     27 /// This intrinsic corresponds to the <c> TTMMULTF32PS </c> instruction.
     28 ///
     29 /// \param srcdst
     30 /// 	The destination tile. Max size is 1024 Bytes.
     31 /// \param a
     32 /// 	The 1st source tile. Max size is 1024 Bytes.
     33 /// \param b
     34 /// 	The 2nd source tile. Max size is 1024 Bytes.
     35 ///
     36 /// \code{.operation}
     37 /// DEFINE zero_lower_mantissa_bits_fp32(x[31:0]) {
     38 /// 	dword[12:0] := 0
     39 /// 	dword[31:13] := x[31:13]
     40 /// 	return dword
     41 /// }
     42 ///
     43 /// DEFINE silence_snan_fp32(x[31:0]) {
     44 /// 	IF (x.exponent == 255 and x.fraction != 0 and x.fraction[22] == 0)
     45 /// 		x.fraction[22] := 1
     46 /// 	return x
     47 /// }
     48 ///
     49 /// elements_dest:= srcdst.colsb/4
     50 ///
     51 /// FOR m := 0 TO (srcdst.rows-1)
     52 /// 	tmp[511:0] := 0
     53 /// 	FOR k := 0 TO (a.rows-1)
     54 /// 		FOR n := 0 TO (elements_dest-1)
     55 /// 			a1e := silence_snan_fp32(a.row[k].fp32[m])
     56 /// 			a2e := silence_snan_fp32(b.row[k].fp32[n])
     57 /// 			s1e := zero_lower_mantissa_bits_fp32(a1e)
     58 /// 			s2e := zero_lower_mantissa_bits_fp32(a2e)
     59 /// 			tmp.fp32[n] += s1e * s2e
     60 /// 		ENDFOR
     61 /// 	ENDFOR
     62 ///
     63 /// 	FOR n := 0 TO (elements_dest-1)
     64 /// 		tmp.fp32[n] += srcdst.row[m].fp32[n]
     65 /// 	ENDFOR
     66 ///	write_row_and_zero(srcdst, m, tmp, srcdst.colsb)
     67 ///
     68 /// ENDFOR
     69 ///
     70 /// zero_upper_rows(srcdst, srcdst.rows)
     71 /// zero_tileconfig_start()
     72 /// \endcode
     73 #define _tile_tmmultf32ps(srcdst, a, b)                                        \
     74   __builtin_ia32_ttmmultf32ps((srcdst), (a), (b))
     75 
     76 // dst = m x n (srcdest), src1 = k x m, src2 = k x n
     77 static __inline__ _tile1024i __DEFAULT_FN_ATTRS_TF32_TRANSPOSE
     78 _tile_tmmultf32ps_internal(unsigned short m, unsigned short n, unsigned short k,
     79                            _tile1024i dst, _tile1024i src1, _tile1024i src2) {
     80   return __builtin_ia32_ttmmultf32ps_internal(m, n, k, dst, src1, src2);
     81 }
     82 
     83 /// Compute transpose and do Matrix Multiplication of src0 and src1, and then do
     84 /// Matrix Plus with dst. All the calculation is base on float32 but with the
     85 /// lower 13-bit set to 0.
     86 ///
     87 /// \headerfile <immintrin.h>
     88 ///
     89 /// This intrinsic corresponds to the <c> TTMMULTF32PS </c> instruction.
     90 ///
     91 /// \param dst
     92 ///    The destination tile. Max size is 1024 Bytes.
     93 /// \param src0
     94 ///    The 1st source tile. Max size is 1024 Bytes.
     95 /// \param src1
     96 ///    The 2nd source tile. Max size is 1024 Bytes.
     97 __DEFAULT_FN_ATTRS_TF32_TRANSPOSE
     98 static void __tile_tmmultf32ps(__tile1024i *dst, __tile1024i src0,
     99                                __tile1024i src1) {
    100   dst->tile = _tile_tmmultf32ps_internal(src0.row, src1.col, src0.col,
    101                                          dst->tile, src0.tile, src1.tile);
    102 }
    103 
    104 #endif // __x86_64__
    105 #endif // __AMX_TF32TRANSPOSEINTRIN_H