zig

fork of https://codeberg.org/ziglang/zig
Log | Files | Refs | README | LICENSE

amxtf32intrin.h (3562B) - Raw


      1 /*===------------- amxtf32intrin.h - AMX_TF32 intrinsics -*- C++ -*---------===
      2  *
      3  * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4  * See https://llvm.org/LICENSE.txt for license information.
      5  * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6  *
      7  *===------------------------------------------------------------------------===
      8  */
      9 
     10 #ifndef __IMMINTRIN_H
     11 #error "Never use <amxtf32intrin.h> directly; include <immintrin.h> instead."
     12 #endif // __IMMINTRIN_H
     13 
     14 #ifndef __AMX_TF32INTRIN_H
     15 #define __AMX_TF32INTRIN_H
     16 #ifdef __x86_64__
     17 
     18 #define __DEFAULT_FN_ATTRS_TF32                                                \
     19   __attribute__((__always_inline__, __nodebug__, __target__("amx-tf32")))
     20 
     21 /// Do Matrix Multiplication of \a a and \a b, and then do Matrix Plus
     22 /// with \a srcdst.
     23 /// All the calculation is base on float32 but with the lower 13-bit set to 0.
     24 ///
     25 /// \headerfile <immintrin.h>
     26 ///
     27 /// \code
     28 /// void _tile_mmultf32ps(constexpr int srcdst, constexpr int a, \
     29 ///                       constexpr int b);
     30 /// \endcode
     31 ///
     32 /// This intrinsic corresponds to the <c> TMMULTF32PS </c> instruction.
     33 ///
     34 /// \param srcdst
     35 /// 	The destination tile. Max size is 1024 Bytes.
     36 /// \param a
     37 /// 	The 1st source tile. Max size is 1024 Bytes.
     38 /// \param b
     39 /// 	The 2nd source tile. Max size is 1024 Bytes.
     40 ///
     41 /// \code{.operation}
     42 /// DEFINE zero_lower_mantissa_bits_fp32(x[31:0]) {
     43 ///	dword[12:0] := 0
     44 ///	dword[31:13] := x[31:13]
     45 ///	return dword
     46 /// }
     47 ///
     48 /// DEFINE silence_snan_fp32(x[31:0]) {
     49 /// 	IF (x.exponent == 255 and x.fraction != 0 and x.fraction[22] == 0)
     50 /// 		x.fraction[22] := 1
     51 /// 	return x
     52 /// }
     53 ///
     54 /// elements_a := a.colsb / 4
     55 /// elements_dest := srcdst.colsb / 4
     56 ///
     57 /// FOR m = 0 TO (srcdst.rows-1)
     58 /// 	tmp[511:0] := 0
     59 /// 	FOR k = 0 TO (elements_a-1)
     60 /// 		FOR n = 0 TO (elements_dest-1)
     61 /// 			af := silence_snan_fp32(a.row[m].fp32[k])
     62 /// 			bf := silence_snan_fp32(b.row[k].fp32[n])
     63 /// 			tmp.fp32[n] += zero_lower_mantissa_bits_fp32(af)
     64 /// 					* zero_lower_mantissa_bits_fp32(bf)
     65 /// 		ENDFOR
     66 /// 	ENDFOR
     67 ///
     68 /// 	FOR n = 0 TO (elements_dest-1)
     69 /// 		tmp.fp32[n] += srcdst.row[m].fp32[n]
     70 /// 	ENDFOR
     71 ///	write_row_and_zero(srcdst, m, tmp, srcdst.colsb)
     72 ///
     73 /// ENDFOR
     74 ///
     75 /// zero_upper_rows(srcdst, srcdst.rows)
     76 /// zero_tileconfig_start()
     77 /// \endcode
     78 #define _tile_mmultf32ps(srcdst, a, b)                                         \
     79   __builtin_ia32_tmmultf32ps((srcdst), (a), (b))
     80 
     81 static __inline__ _tile1024i __DEFAULT_FN_ATTRS_TF32
     82 _tile_mmultf32ps_internal(unsigned short m, unsigned short n, unsigned short k,
     83                           _tile1024i dst, _tile1024i src1, _tile1024i src2) {
     84   return __builtin_ia32_tmmultf32ps_internal(m, n, k, dst, src1, src2);
     85 }
     86 
     87 /// Do Matrix Multiplication of src0 and src1, and then do Matrix Plus with dst.
     88 /// All the calculation is base on float32 but with the lower 13-bit set to 0.
     89 ///
     90 /// \headerfile <immintrin.h>
     91 ///
     92 /// This intrinsic corresponds to the <c> TMMULTF32PS </c> instruction.
     93 ///
     94 /// \param dst
     95 ///    The destination tile. Max size is 1024 Bytes.
     96 /// \param src0
     97 ///    The 1st source tile. Max size is 1024 Bytes.
     98 /// \param src1
     99 ///    The 2nd source tile. Max size is 1024 Bytes.
    100 __DEFAULT_FN_ATTRS_TF32
    101 static void __tile_mmultf32ps(__tile1024i *dst, __tile1024i src0,
    102                               __tile1024i src1) {
    103   dst->tile = _tile_mmultf32ps_internal(src0.row, src1.col, src0.col, dst->tile,
    104                                         src0.tile, src1.tile);
    105 }
    106 
    107 #endif // __x86_64__
    108 #endif // __AMX_TF32INTRIN_H