zig

fork of https://codeberg.org/ziglang/zig
Log | Files | Refs | README | LICENSE

amxcomplextransposeintrin.h (12061B) - Raw


      1 /*===----- amxcomplextransposeintrin.h - AMX-COMPLEX and AMX-TRANSPOSE ------===
      2  *
      3  * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
      4  * See https://llvm.org/LICENSE.txt for license information.
      5  * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
      6  *
      7  *===------------------------------------------------------------------------===
      8  */
      9 
     10 #ifndef __IMMINTRIN_H
     11 #error                                                                         \
     12     "Never use <amxcomplextransposeintrin.h> directly; include <immintrin.h> instead."
     13 #endif // __IMMINTRIN_H
     14 
     15 #ifndef __AMX_COMPLEXTRANSPOSEINTRIN_H
     16 #define __AMX_COMPLEXTRANSPOSEINTRIN_H
     17 #ifdef __x86_64__
     18 
     19 #define __DEFAULT_FN_ATTRS                                                     \
     20   __attribute__((__always_inline__, __nodebug__,                               \
     21                  __target__("amx-complex,amx-transpose")))
     22 
     23 /// Perform matrix multiplication of two tiles containing complex elements and
     24 ///    accumulate the results into a packed single precision tile. Each dword
     25 ///    element in input tiles \a a and \a b is interpreted as a complex number
     26 ///    with FP16 real part and FP16 imaginary part.
     27 /// Calculates the imaginary part of the result. For each possible combination
     28 ///    of (transposed column of \a a, column of \a b), it performs a set of
     29 ///    multiplication and accumulations on all corresponding complex numbers
     30 ///    (one from \a a and one from \a b). The imaginary part of the \a a element
     31 ///    is multiplied with the real part of the corresponding \a b element, and
     32 ///    the real part of the \a a element is multiplied with the imaginary part
     33 ///    of the corresponding \a b elements. The two accumulated results are
     34 ///    added, and then accumulated into the corresponding row and column of
     35 ///    \a dst.
     36 ///
     37 /// \headerfile <x86intrin.h>
     38 ///
     39 /// \code
     40 /// void _tile_tcmmimfp16ps(__tile dst, __tile a, __tile b);
     41 /// \endcode
     42 ///
     43 /// \code{.operation}
     44 /// FOR m := 0 TO dst.rows - 1
     45 ///	tmp := dst.row[m]
     46 ///	FOR k := 0 TO a.rows - 1
     47 ///		FOR n := 0 TO (dst.colsb / 4) - 1
     48 ///			tmp.fp32[n] += FP32(a.row[m].fp16[2*k+0]) * FP32(b.row[k].fp16[2*n+1])
     49 ///			tmp.fp32[n] += FP32(a.row[m].fp16[2*k+1]) * FP32(b.row[k].fp16[2*n+0])
     50 ///		ENDFOR
     51 ///	ENDFOR
     52 ///	write_row_and_zero(dst, m, tmp, dst.colsb)
     53 /// ENDFOR
     54 /// zero_upper_rows(dst, dst.rows)
     55 /// zero_tileconfig_start()
     56 /// \endcode
     57 ///
     58 /// This intrinsic corresponds to the \c TTCMMIMFP16PS instruction.
     59 ///
     60 /// \param dst
     61 ///    The destination tile. Max size is 1024 Bytes.
     62 /// \param a
     63 ///    The 1st source tile. Max size is 1024 Bytes.
     64 /// \param b
     65 ///    The 2nd source tile. Max size is 1024 Bytes.
     66 #define _tile_tcmmimfp16ps(dst, a, b)                                          \
     67   __builtin_ia32_ttcmmimfp16ps((dst), (a), (b))
     68 
     69 /// Perform matrix multiplication of two tiles containing complex elements and
     70 ///    accumulate the results into a packed single precision tile. Each dword
     71 ///    element in input tiles \a a and \a b is interpreted as a complex number
     72 ///    with FP16 real part and FP16 imaginary part.
     73 /// Calculates the real part of the result. For each possible combination
     74 ///    of (transposed column of \a a, column of \a b), it performs a set of
     75 ///    multiplication and accumulations on all corresponding complex numbers
     76 ///    (one from \a a and one from \a b). The real part of the \a a element is
     77 ///    multiplied with the real part of the corresponding \a b element, and the
     78 ///    negated imaginary part of the \a a element is multiplied with the
     79 ///    imaginary part of the corresponding \a b elements. The two accumulated
     80 ///    results are added, and then accumulated into the corresponding row and
     81 ///    column of \a dst.
     82 ///
     83 /// \headerfile <x86intrin.h>
     84 ///
     85 /// \code
     86 /// void _tile_tcmmrlfp16ps(__tile dst, __tile a, __tile b);
     87 /// \endcode
     88 ///
     89 /// \code{.operation}
     90 /// FOR m := 0 TO dst.rows - 1
     91 ///	tmp := dst.row[m]
     92 ///	FOR k := 0 TO a.rows - 1
     93 ///		FOR n := 0 TO (dst.colsb / 4) - 1
     94 ///			tmp.fp32[n] += FP32(a.row[m].fp16[2*k+0]) * FP32(b.row[k].fp16[2*n+0])
     95 ///			tmp.fp32[n] += FP32(-a.row[m].fp16[2*k+1]) * FP32(b.row[k].fp16[2*n+1])
     96 ///		ENDFOR
     97 ///	ENDFOR
     98 ///	write_row_and_zero(dst, m, tmp, dst.colsb)
     99 /// ENDFOR
    100 /// zero_upper_rows(dst, dst.rows)
    101 /// zero_tileconfig_start()
    102 /// \endcode
    103 ///
    104 /// This intrinsic corresponds to the \c TTCMMRLFP16PS instruction.
    105 ///
    106 /// \param dst
    107 ///    The destination tile. Max size is 1024 Bytes.
    108 /// \param a
    109 ///    The 1st source tile. Max size is 1024 Bytes.
    110 /// \param b
    111 ///    The 2nd source tile. Max size is 1024 Bytes.
    112 #define _tile_tcmmrlfp16ps(dst, a, b)                                          \
    113   __builtin_ia32_ttcmmrlfp16ps((dst), (a), (b))
    114 
    115 /// Perform matrix conjugate transpose and multiplication of two tiles
    116 ///    containing complex elements and accumulate the results into a packed
    117 ///    single precision tile. Each dword element in input tiles \a a and \a b
    118 ///    is interpreted as a complex number with FP16 real part and FP16 imaginary
    119 ///    part.
    120 /// Calculates the imaginary part of the result. For each possible combination
    121 ///    of (transposed column of \a a, column of \a b), it performs a set of
    122 ///    multiplication and accumulations on all corresponding complex numbers
    123 ///    (one from \a a and one from \a b). The negated imaginary part of the \a a
    124 ///    element is multiplied with the real part of the corresponding \a b
    125 ///    element, and the real part of the \a a element is multiplied with the
    126 ///    imaginary part of the corresponding \a b elements. The two accumulated
    127 ///    results are added, and then accumulated into the corresponding row and
    128 ///    column of \a dst.
    129 ///
    130 /// \headerfile <x86intrin.h>
    131 ///
    132 /// \code
    133 /// void _tile_conjtcmmimfp16ps(__tile dst, __tile a, __tile b);
    134 /// \endcode
    135 ///
    136 /// \code{.operation}
    137 /// FOR m := 0 TO dst.rows - 1
    138 ///	tmp := dst.row[m]
    139 ///	FOR k := 0 TO a.rows - 1
    140 ///		FOR n := 0 TO (dst.colsb / 4) - 1
    141 ///			tmp.fp32[n] += FP32(a.row[m].fp16[2*k+0]) * FP32(b.row[k].fp16[2*n+1])
    142 ///			tmp.fp32[n] += FP32(-a.row[m].fp16[2*k+1]) * FP32(b.row[k].fp16[2*n+0])
    143 ///		ENDFOR
    144 ///	ENDFOR
    145 ///	write_row_and_zero(dst, m, tmp, dst.colsb)
    146 /// ENDFOR
    147 /// zero_upper_rows(dst, dst.rows)
    148 /// zero_tileconfig_start()
    149 /// \endcode
    150 ///
    151 /// This intrinsic corresponds to the \c TCONJTCMMIMFP16PS instruction.
    152 ///
    153 /// \param dst
    154 ///    The destination tile. Max size is 1024 Bytes.
    155 /// \param a
    156 ///    The 1st source tile. Max size is 1024 Bytes.
    157 /// \param b
    158 ///    The 2nd source tile. Max size is 1024 Bytes.
    159 #define _tile_conjtcmmimfp16ps(dst, a, b)                                      \
    160   __builtin_ia32_tconjtcmmimfp16ps((dst), (a), (b))
    161 
    162 /// Perform conjugate transpose of an FP16-pair of complex elements from \a a
    163 ///    and writes the result to \a dst.
    164 ///
    165 /// \headerfile <x86intrin.h>
    166 ///
    167 /// \code
    168 /// void _tile_conjtfp16(__tile dst, __tile a);
    169 /// \endcode
    170 ///
    171 /// \code{.operation}
    172 /// FOR i := 0 TO dst.rows - 1
    173 ///	FOR j := 0 TO (dst.colsb / 4) - 1
    174 ///		tmp.fp16[2*j+0] := a.row[j].fp16[2*i+0]
    175 ///		tmp.fp16[2*j+1] := -a.row[j].fp16[2*i+1]
    176 ///	ENDFOR
    177 ///	write_row_and_zero(dst, i, tmp, dst.colsb)
    178 /// ENDFOR
    179 /// zero_upper_rows(dst, dst.rows)
    180 /// zero_tileconfig_start()
    181 /// \endcode
    182 ///
    183 /// This intrinsic corresponds to the \c TCONJTFP16 instruction.
    184 ///
    185 /// \param dst
    186 ///    The destination tile. Max size is 1024 Bytes.
    187 /// \param a
    188 ///    The source tile. Max size is 1024 Bytes.
    189 #define _tile_conjtfp16(dst, a) __builtin_ia32_tconjtfp16((dst), (a))
    190 
    191 static __inline__ _tile1024i __DEFAULT_FN_ATTRS _tile_tcmmimfp16ps_internal(
    192     unsigned short m, unsigned short n, unsigned short k, _tile1024i dst,
    193     _tile1024i src1, _tile1024i src2) {
    194   return __builtin_ia32_ttcmmimfp16ps_internal(m, n, k, dst, src1, src2);
    195 }
    196 
    197 static __inline__ _tile1024i __DEFAULT_FN_ATTRS _tile_tcmmrlfp16ps_internal(
    198     unsigned short m, unsigned short n, unsigned short k, _tile1024i dst,
    199     _tile1024i src1, _tile1024i src2) {
    200   return __builtin_ia32_ttcmmrlfp16ps_internal(m, n, k, dst, src1, src2);
    201 }
    202 
    203 static __inline__ _tile1024i __DEFAULT_FN_ATTRS _tile_conjtcmmimfp16ps_internal(
    204     unsigned short m, unsigned short n, unsigned short k, _tile1024i dst,
    205     _tile1024i src1, _tile1024i src2) {
    206   return __builtin_ia32_tconjtcmmimfp16ps_internal(m, n, k, dst, src1, src2);
    207 }
    208 
    209 static __inline__ _tile1024i __DEFAULT_FN_ATTRS
    210 _tile_conjtfp16_internal(unsigned short m, unsigned short n, _tile1024i src) {
    211   return __builtin_ia32_tconjtfp16_internal(m, n, src);
    212 }
    213 
    214 /// Perform matrix multiplication of two tiles containing complex elements and
    215 ///    accumulate the results into a packed single precision tile. Each dword
    216 ///    element in input tiles src0 and src1 is interpreted as a complex number
    217 ///    with FP16 real part and FP16 imaginary part.
    218 ///    This function calculates the imaginary part of the result.
    219 ///
    220 /// \headerfile <immintrin.h>
    221 ///
    222 /// This intrinsic corresponds to the <c> TTCMMIMFP16PS </c> instruction.
    223 ///
    224 /// \param dst
    225 ///    The destination tile. Max size is 1024 Bytes.
    226 /// \param src0
    227 ///    The 1st source tile. Max size is 1024 Bytes.
    228 /// \param src1
    229 ///    The 2nd source tile. Max size is 1024 Bytes.
    230 __DEFAULT_FN_ATTRS
    231 static void __tile_tcmmimfp16ps(__tile1024i *dst, __tile1024i src0,
    232                                 __tile1024i src1) {
    233   dst->tile = _tile_tcmmimfp16ps_internal(src0.row, src1.col, src0.col,
    234                                           dst->tile, src0.tile, src1.tile);
    235 }
    236 
    237 /// Perform matrix multiplication of two tiles containing complex elements and
    238 ///    accumulate the results into a packed single precision tile. Each dword
    239 ///    element in input tiles src0 and src1 is interpreted as a complex number
    240 ///    with FP16 real part and FP16 imaginary part.
    241 ///    This function calculates the real part of the result.
    242 ///
    243 /// \headerfile <immintrin.h>
    244 ///
    245 /// This intrinsic corresponds to the <c> TTCMMRLFP16PS </c> instruction.
    246 ///
    247 /// \param dst
    248 ///    The destination tile. Max size is 1024 Bytes.
    249 /// \param src0
    250 ///    The 1st source tile. Max size is 1024 Bytes.
    251 /// \param src1
    252 ///    The 2nd source tile. Max size is 1024 Bytes.
    253 __DEFAULT_FN_ATTRS
    254 static void __tile_tcmmrlfp16ps(__tile1024i *dst, __tile1024i src0,
    255                                 __tile1024i src1) {
    256   dst->tile = _tile_tcmmrlfp16ps_internal(src0.row, src1.col, src0.col,
    257                                           dst->tile, src0.tile, src1.tile);
    258 }
    259 
    260 /// Perform matrix conjugate transpose and multiplication of two tiles
    261 ///    containing complex elements and accumulate the results into a packed
    262 ///    single precision tile. Each dword element in input tiles src0 and src1
    263 ///    is interpreted as a complex number with FP16 real part and FP16 imaginary
    264 ///    part.
    265 ///    This function calculates the imaginary part of the result.
    266 ///
    267 /// \headerfile <immintrin.h>
    268 ///
    269 /// This intrinsic corresponds to the <c> TCONJTCMMIMFP16PS </c> instruction.
    270 ///
    271 /// \param dst
    272 ///    The destination tile. Max size is 1024 Bytes.
    273 /// \param src0
    274 ///    The 1st source tile. Max size is 1024 Bytes.
    275 /// \param src1
    276 ///    The 2nd source tile. Max size is 1024 Bytes.
    277 __DEFAULT_FN_ATTRS
    278 static void __tile_conjtcmmimfp16ps(__tile1024i *dst, __tile1024i src0,
    279                                     __tile1024i src1) {
    280   dst->tile = _tile_conjtcmmimfp16ps_internal(src0.row, src1.col, src0.col,
    281                                               dst->tile, src0.tile, src1.tile);
    282 }
    283 
    284 /// Perform conjugate transpose of an FP16-pair of complex elements from src and
    285 ///    writes the result to dst.
    286 ///
    287 /// \headerfile <immintrin.h>
    288 ///
    289 /// This intrinsic corresponds to the <c> TCONJTFP16 </c> instruction.
    290 ///
    291 /// \param dst
    292 ///    The destination tile. Max size is 1024 Bytes.
    293 /// \param src
    294 ///    The source tile. Max size is 1024 Bytes.
    295 __DEFAULT_FN_ATTRS
    296 static void __tile_conjtfp16(__tile1024i *dst, __tile1024i src) {
    297   dst->tile = _tile_conjtfp16_internal(src.row, src.col, src.tile);
    298 }
    299 
    300 #undef __DEFAULT_FN_ATTRS
    301 
    302 #endif // __x86_64__
    303 #endif // __AMX_COMPLEXTRANSPOSEINTRIN_H