zig

fork of https://codeberg.org/ziglang/zig
Log | Files | Refs | README | LICENSE

amxcomplexintrin.h (6973B) - Raw


      1 /*===--------- amxcomplexintrin.h - AMXCOMPLEX intrinsics -*- C++ -*---------===
      2  *
      3  * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4  * See https://llvm.org/LICENSE.txt for license information.
      5  * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6  *
      7  *===------------------------------------------------------------------------===
      8  */
      9 
     10 #ifndef __IMMINTRIN_H
     11 #error "Never use <amxcomplexintrin.h> directly; include <immintrin.h> instead."
     12 #endif // __IMMINTRIN_H
     13 
     14 #ifndef __AMX_COMPLEXINTRIN_H
     15 #define __AMX_COMPLEXINTRIN_H
     16 #ifdef __x86_64__
     17 
     18 #define __DEFAULT_FN_ATTRS_COMPLEX                                             \
     19   __attribute__((__always_inline__, __nodebug__, __target__("amx-complex")))
     20 
     21 /// Perform matrix multiplication of two tiles containing complex elements and
     22 ///    accumulate the results into a packed single precision tile. Each dword
     23 ///    element in input tiles \a a and \a b is interpreted as a complex number
     24 ///    with FP16 real part and FP16 imaginary part.
     25 /// Calculates the imaginary part of the result. For each possible combination
     26 ///    of (row of \a a, column of \a b), it performs a set of multiplication
     27 ///    and accumulations on all corresponding complex numbers (one from \a a
     28 ///    and one from \a b). The imaginary part of the \a a element is multiplied
     29 ///    with the real part of the corresponding \a b element, and the real part
     30 ///    of the \a a element is multiplied with the imaginary part of the
     31 ///    corresponding \a b elements. The two accumulated results are added, and
     32 ///    then accumulated into the corresponding row and column of \a dst.
     33 ///
     34 /// \headerfile <x86intrin.h>
     35 ///
     36 /// \code
     37 /// void _tile_cmmimfp16ps(__tile dst, __tile a, __tile b);
     38 /// \endcode
     39 ///
     40 /// \code{.operation}
     41 /// FOR m := 0 TO dst.rows - 1
     42 ///	tmp := dst.row[m]
     43 ///	FOR k := 0 TO (a.colsb / 4) - 1
     44 ///		FOR n := 0 TO (dst.colsb / 4) - 1
     45 ///			tmp.fp32[n] += FP32(a.row[m].fp16[2*k+0]) * FP32(b.row[k].fp16[2*n+1])
     46 ///			tmp.fp32[n] += FP32(a.row[m].fp16[2*k+1]) * FP32(b.row[k].fp16[2*n+0])
     47 ///		ENDFOR
     48 ///	ENDFOR
     49 ///	write_row_and_zero(dst, m, tmp, dst.colsb)
     50 /// ENDFOR
     51 /// zero_upper_rows(dst, dst.rows)
     52 /// zero_tileconfig_start()
     53 /// \endcode
     54 ///
     55 /// This intrinsic corresponds to the \c TCMMIMFP16PS instruction.
     56 ///
     57 /// \param dst
     58 ///    The destination tile. Max size is 1024 Bytes.
     59 /// \param a
     60 ///    The 1st source tile. Max size is 1024 Bytes.
     61 /// \param b
     62 ///    The 2nd source tile. Max size is 1024 Bytes.
     63 #define _tile_cmmimfp16ps(dst, a, b) __builtin_ia32_tcmmimfp16ps(dst, a, b)
     64 
     65 /// Perform matrix multiplication of two tiles containing complex elements and
     66 ///    accumulate the results into a packed single precision tile. Each dword
     67 ///    element in input tiles \a a and \a b is interpreted as a complex number
     68 ///    with FP16 real part and FP16 imaginary part.
     69 /// Calculates the real part of the result. For each possible combination
     70 ///    of (row of \a a, column of \a b), it performs a set of multiplication
     71 ///    and accumulations on all corresponding complex numbers (one from \a a
     72 ///    and one from \a b). The real part of the \a a element is multiplied
     73 ///    with the real part of the corresponding \a b element, and the negated
     74 ///    imaginary part of the \a a element is multiplied with the imaginary
     75 ///    part of the corresponding \a b elements. The two accumulated results
     76 ///    are added, and then accumulated into the corresponding row and column
     77 ///    of \a dst.
     78 ///
     79 /// \headerfile <x86intrin.h>
     80 ///
     81 /// \code
     82 /// void _tile_cmmrlfp16ps(__tile dst, __tile a, __tile b);
     83 /// \endcode
     84 ///
     85 /// \code{.operation}
     86 /// FOR m := 0 TO dst.rows - 1
     87 ///	tmp := dst.row[m]
     88 ///	FOR k := 0 TO (a.colsb / 4) - 1
     89 ///		FOR n := 0 TO (dst.colsb / 4) - 1
     90 ///			tmp.fp32[n] += FP32(a.row[m].fp16[2*k+0]) * FP32(b.row[k].fp16[2*n+0])
     91 ///			tmp.fp32[n] += FP32(-a.row[m].fp16[2*k+1]) * FP32(b.row[k].fp16[2*n+1])
     92 ///		ENDFOR
     93 ///	ENDFOR
     94 ///	write_row_and_zero(dst, m, tmp, dst.colsb)
     95 /// ENDFOR
     96 /// zero_upper_rows(dst, dst.rows)
     97 /// zero_tileconfig_start()
     98 /// \endcode
     99 ///
    100 /// This intrinsic corresponds to the \c TCMMIMFP16PS instruction.
    101 ///
    102 /// \param dst
    103 ///    The destination tile. Max size is 1024 Bytes.
    104 /// \param a
    105 ///    The 1st source tile. Max size is 1024 Bytes.
    106 /// \param b
    107 ///    The 2nd source tile. Max size is 1024 Bytes.
    108 #define _tile_cmmrlfp16ps(dst, a, b) __builtin_ia32_tcmmrlfp16ps(dst, a, b)
    109 
    110 static __inline__ _tile1024i __DEFAULT_FN_ATTRS_COMPLEX
    111 _tile_cmmimfp16ps_internal(unsigned short m, unsigned short n, unsigned short k,
    112                            _tile1024i dst, _tile1024i src1, _tile1024i src2) {
    113   return __builtin_ia32_tcmmimfp16ps_internal(m, n, k, dst, src1, src2);
    114 }
    115 
    116 static __inline__ _tile1024i __DEFAULT_FN_ATTRS_COMPLEX
    117 _tile_cmmrlfp16ps_internal(unsigned short m, unsigned short n, unsigned short k,
    118                            _tile1024i dst, _tile1024i src1, _tile1024i src2) {
    119   return __builtin_ia32_tcmmrlfp16ps_internal(m, n, k, dst, src1, src2);
    120 }
    121 
    122 /// Perform matrix multiplication of two tiles containing complex elements and
    123 /// accumulate the results into a packed single precision tile. Each dword
    124 /// element in input tiles src0 and src1 is interpreted as a complex number with
    125 /// FP16 real part and FP16 imaginary part.
    126 /// This function calculates the imaginary part of the result.
    127 ///
    128 /// \headerfile <immintrin.h>
    129 ///
    130 /// This intrinsic corresponds to the <c> TCMMIMFP16PS </c> instruction.
    131 ///
    132 /// \param dst
    133 ///    The destination tile. Max size is 1024 Bytes.
    134 /// \param src0
    135 ///    The 1st source tile. Max size is 1024 Bytes.
    136 /// \param src1
    137 ///    The 2nd source tile. Max size is 1024 Bytes.
    138 __DEFAULT_FN_ATTRS_COMPLEX
    139 static void __tile_cmmimfp16ps(__tile1024i *dst, __tile1024i src0,
    140                                __tile1024i src1) {
    141   dst->tile = _tile_cmmimfp16ps_internal(src0.row, src1.col, src0.col,
    142                                          dst->tile, src0.tile, src1.tile);
    143 }
    144 
    145 /// Perform matrix multiplication of two tiles containing complex elements and
    146 /// accumulate the results into a packed single precision tile. Each dword
    147 /// element in input tiles src0 and src1 is interpreted as a complex number with
    148 /// FP16 real part and FP16 imaginary part.
    149 /// This function calculates the real part of the result.
    150 ///
    151 /// \headerfile <immintrin.h>
    152 ///
    153 /// This intrinsic corresponds to the <c> TCMMRLFP16PS </c> instruction.
    154 ///
    155 /// \param dst
    156 ///    The destination tile. Max size is 1024 Bytes.
    157 /// \param src0
    158 ///    The 1st source tile. Max size is 1024 Bytes.
    159 /// \param src1
    160 ///    The 2nd source tile. Max size is 1024 Bytes.
    161 __DEFAULT_FN_ATTRS_COMPLEX
    162 static void __tile_cmmrlfp16ps(__tile1024i *dst, __tile1024i src0,
    163                                __tile1024i src1) {
    164   dst->tile = _tile_cmmrlfp16ps_internal(src0.row, src1.col, src0.col,
    165                                          dst->tile, src0.tile, src1.tile);
    166 }
    167 
    168 #endif // __x86_64__
    169 #endif // __AMX_COMPLEXINTRIN_H