HIP: Heterogenous-computing Interface for Portability
Loading...
Searching...
No Matches
amd_hip_fp8.h
Go to the documentation of this file.
1
30#ifndef _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
31#define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
32
33#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx1200__) || \
34 defined(__gfx1201__)) && \
35 __HIP_DEVICE_COMPILE__
36#define HIP_FP8_CVT_FAST_PATH 1
37#else
38#define HIP_FP8_CVT_FAST_PATH 0
39#endif
40
41#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) && __HIP_DEVICE_COMPILE__
42#define HIP_FP8_TYPE_OCP 0
43#define HIP_FP8_TYPE_FNUZ 1
44#elif (defined(__gfx1200__) || defined(__gfx1201__)) && __HIP_DEVICE_COMPILE__
45#define HIP_FP8_TYPE_OCP 1
46#define HIP_FP8_TYPE_FNUZ 0
47#elif __HIP_DEVICE_COMPILE__
48#define HIP_FP8_TYPE_FNUZ 0
49#define HIP_FP8_TYPE_OCP 0
50#else // Host
51#define HIP_FP8_TYPE_FNUZ 1
52#define HIP_FP8_TYPE_OCP 1
53#endif
54
55#if !defined(__HIPCC_RTC__)
56#include <hip/amd_detail/amd_hip_common.h>
57#include <climits>
58
59#include "host_defines.h" // __hip_internal::
60#include "amd_hip_vector_types.h" // float2 etc
61#include "amd_hip_fp16.h" // __half_raw
62#include "amd_hip_bf16.h" // bf16
63#include "math_fwd.h" // ocml device functions
64#endif // !defined(__HIPCC_RTC__)
65
66#if defined(__HIPCC_RTC__)
67#define __FP8_HOST_DEVICE__ __device__
68#define __FP8_HOST_DEVICE_STATIC__ __FP8_HOST_DEVICE__ static
69#else
70#define __FP8_HOST_DEVICE__ __host__ __device__
71#define __FP8_HOST_DEVICE_STATIC__ __FP8_HOST_DEVICE__ static inline
72#endif // __HIPCC_RTC__
73
74#define __FP8_HOST__ __host__
75
76#if !defined(__HIPCC_RTC__)
77static_assert(CHAR_BIT == 8, "byte size should be of 8 bits");
78#endif
79static_assert(sizeof(unsigned char) == 1);
80static_assert(sizeof(unsigned short int) == 2);
81static_assert(sizeof(unsigned int) == 4);
82
92
100
/** One byte of raw storage holding a single fp8 value. */
using __hip_fp8_storage_t = unsigned char;

/** Two bytes of raw storage; the x2 conversions in this file put element x in the low byte. */
using __hip_fp8x2_storage_t = unsigned short int;

/** Four bytes of raw storage; presumably four packed fp8 values (unused in this part of the file). */
using __hip_fp8x4_storage_t = unsigned int;
120
121namespace internal {
122
123// The conversion function is from rocblas
124// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39
125// This has been modified to add double types conversion as well
126template <typename T, bool is_fnuz>
127__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t cast_to_f8(T _x, int wm, int we, bool clip = false,
128 bool stoch = false, unsigned int rng = 0) {
129 constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
130 constexpr bool is_float = __hip_internal::is_same<T, float>::value;
131 constexpr bool is_double = __hip_internal::is_same<T, double>::value;
132 static_assert(is_half || is_float || is_double, "Only half, float and double can be cast to f8");
133
134 const int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
135 unsigned long long x;
136
137 if (sizeof(T) == 8)
138 x = reinterpret_cast<unsigned long long&>(_x);
139 else if (sizeof(T) == 4)
140 x = reinterpret_cast<unsigned int&>(_x);
141 else
142 x = reinterpret_cast<unsigned short int&>(_x);
143
144
145 unsigned long long head, mantissa;
146 int exponent, bias;
147 unsigned int sign;
148 unsigned long long fInf, mask;
149
150 if (sizeof(T) == 8) {
151 head = x & 0xFFF0000000000000ull;
152 mantissa = x & 0xFFFFFFFFFFFFFull;
153 exponent = (head >> 52) & 0x7FF;
154 sign = head >> 63;
155 bias = 1023;
156 fInf = 0x7FF0000000000000ull;
157 mask = 0x7FFFFFFFFFFFFFFFull;
158 } else if (sizeof(T) == 4) {
159 head = x & 0xFF800000;
160 mantissa = x & 0x7FFFFF;
161 exponent = (head >> 23) & 0xFF;
162 sign = head >> 31;
163 bias = 127;
164 fInf = 0x7F800000;
165 mask = 0x7FFFFFFF;
166 } else {
167 head = x & 0xFC00;
168 mantissa = x & 0x3FF;
169 exponent = (head >> 10) & 0x1F;
170 sign = head >> 15;
171 bias = 15;
172 fInf = 0x7C00;
173 mask = 0x7FFF;
174 }
175 unsigned int signed_inf = 0;
176 unsigned int nan = 0;
177 if (is_fnuz) {
178 signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
179 nan = 0x80;
180 } else {
181 if (we == 4) { // e4m3
182 signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
183 } else { // e5m2
184 signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
185 }
186 nan = (sign << 7) + 0x7f;
187 }
188 // Max values
189 unsigned long long ifmax = 0;
190 if (sizeof(T) == 8) {
191 if (we == 5) { // 57344
192 ifmax = 0x40EC000000000000ull;
193 } else {
194 if (is_fnuz) { // 240
195 ifmax = 0x406E000000000000ull;
196 } else { // 448
197 ifmax = 0x407C000000000000ull;
198 }
199 }
200 } else if (sizeof(T) == 4) {
201 if (we == 5) {
202 ifmax = 0x47600000;
203 } else {
204 if (is_fnuz) {
205 ifmax = 0x43700000;
206 } else {
207 ifmax = 0x43E00000;
208 }
209 }
210 } else {
211 if (we == 5) {
212 ifmax = 0x7B00;
213 } else {
214 if (is_fnuz) {
215 ifmax = 0x5B80;
216 } else {
217 ifmax = 0x5F00;
218 }
219 }
220 }
221 // Deal with inf and NaNs
222 if ((x & fInf) == fInf) {
223 if (is_fnuz) return signed_inf;
224 return mantissa != 0 ? nan : signed_inf;
225 }
226
227 if ((x & mask) > ifmax) {
228 return signed_inf;
229 }
230
231 if (x == 0) {
232 return 0;
233 }
234
235 // First need to check if it is normal or denorm as there is a difference of implict 1
236 // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
237 // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
238 // RNE, no need to add rng. Then probably need to check whether there is carry and adjust
239 // exponent and mantissa again
240
241 // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
242 const int f8_bias = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
243 const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
244 // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
245 // f8_exponent is the converted f8 exponent with bias encoding
246 // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
247 // the difference needs to be adjusted and mantissa shifted
248 int act_exponent, f8_exponent, exponent_diff;
249
250 if (exponent == 0) { // fp32/fp16 is in denormal.
251 /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
252here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
253exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
254fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
255where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. In
256this case, the fp16 mantissa should be shift left by 1 */
257 act_exponent = exponent - bias + 1;
258 exponent_diff = f8_denormal_act_exponent -
259 act_exponent; // actual exponent is exponent-bias+1 as it is denormal
260 } else { // fp32/fp16 is normal with implicit 1
261 act_exponent = exponent - bias;
262 if (act_exponent <= f8_denormal_act_exponent) {
263 /* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
264For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
265actual exponent is -7, it is actually larger due to the implict 1,
266Therefore it needs to be adjust to -6 and mantissa shift right by 1.
267So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
268 exponent_diff = f8_denormal_act_exponent - act_exponent;
269 } else { // both fp32/fp16 and f8 are in normal range
270 exponent_diff = 0; // exponent_diff=0 does not mean there is no difference for this case,
271 // act_exponent could be larger. Just that it does not need shift mantissa
272 }
273 mantissa += (1ull << mfmt); // Add the implicit 1 into mantissa
274 }
275
276 bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
277 (1ull << (mfmt - wm + exponent_diff - 1));
278 /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we shift
279right as shift right could rip off some residual part and make something not midpoint look like
280midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than midpoint, but
281after shift right by 4 bits, it would look like midpoint.
282*/
283
284 if (exponent_diff > 0)
285 mantissa >>= exponent_diff;
286 else if (exponent_diff == -1)
287 mantissa <<= -exponent_diff;
288 bool implicit_one = mantissa & (1ull << mfmt);
289 // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
290 f8_exponent =
291 (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
292
293 // Now we have the exponent and mantissa adjusted
294 unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
295 bool odd =
296 mantissa & (1ull << (mfmt - wm)); // if the least significant bit that is not truncated is 1
297 mantissa +=
298 (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
299
300 // Now we deal with overflow
301 if (f8_exponent == 0) {
302 if ((1ull << mfmt) & mantissa) {
303 f8_exponent = 1; // denormal overflow to become normal, promote exponent
304 }
305 } else {
306 if ((1ull << (mfmt + 1)) & mantissa) {
307 mantissa >>= 1;
308 f8_exponent++;
309 }
310 }
311
312 mantissa >>= (mfmt - wm);
313
314 // above range: quantize to maximum possible float of the same sign
315 const int max_exp = (1 << we) - 1;
316 if (f8_exponent > max_exp) {
317 if (clip) {
318 mantissa = (1 << wm) - 1;
319 f8_exponent = max_exp;
320 } else {
321 return signed_inf;
322 }
323 }
324
325 if (f8_exponent == 0 && mantissa == 0) return is_fnuz ? 0 : (sign << 7);
326 mantissa &= (1 << wm) - 1;
327 return (sign << 7) | (f8_exponent << wm) | mantissa;
328}
329// The conversion function is from rocblas
330// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220
331// This has been modified to handle double types as well
332template <typename T, bool is_fnuz>
333__FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, int we, bool clip = false) {
334 constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
335 constexpr bool is_float = __hip_internal::is_same<T, float>::value;
336 constexpr bool is_double = __hip_internal::is_same<T, double>::value;
337 static_assert(is_half || is_float || is_double, "only half, float and double are supported");
338
339 constexpr int weo = is_half ? 5 : (is_float ? 8 : 11);
340 constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52);
341
342 T fInf, fNegInf, fNaN, fNeg0, fmax, fmin;
343 if (is_half) {
344 const unsigned short int ihInf = 0x7C00;
345 const unsigned short int ihNegInf = 0xFC00;
346 const unsigned short int ihNaN = 0x7C01;
347 const unsigned short int ihNeg0 = 0x8000;
348 /* Max number in e5m2 57344*/
349 const unsigned short int ifmax = 0x7B00;
350 const unsigned short int ifmin = 0xFB00;
351 fInf = reinterpret_cast<const _Float16&>(ihInf);
352 fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
353 fNaN = reinterpret_cast<const _Float16&>(ihNaN);
354 fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
355 fmax = reinterpret_cast<const _Float16&>(ifmax);
356 fmin = reinterpret_cast<const _Float16&>(ifmin);
357 } else if (is_float) {
358 const unsigned int ifInf = 0x7F800000;
359 const unsigned int ifNegInf = 0xFF800000;
360 const unsigned int ifNaN = 0x7F800001;
361 const unsigned int ifNeg0 = 0x80000000;
362 /* Max number in e5m2 57344*/
363 const unsigned int ifmax = 0x47600000;
364 const unsigned int ifmin = 0xC7600000;
365 fInf = reinterpret_cast<const float&>(ifInf);
366 fNegInf = reinterpret_cast<const float&>(ifNegInf);
367 fNaN = reinterpret_cast<const float&>(ifNaN);
368 fNeg0 = reinterpret_cast<const float&>(ifNeg0);
369 fmax = reinterpret_cast<const float&>(ifmax);
370 fmin = reinterpret_cast<const float&>(ifmin);
371 } else if (is_double) {
372 const unsigned long long ifInf = 0x7FF0000000000000ull;
373 const unsigned long long ifNegInf = 0xFFF0000000000000ull;
374 const unsigned long long ifNaN = 0x7FF0000000000001ull;
375 const unsigned long long ifNeg0 = 0x8000000000000000ull;
376 /* Max number in e5m2 57344*/
377 const unsigned long long ifmax = 0x40EC000000000000ull;
378 const unsigned long long ifmin = 0xC0EC000000000000ull;
379 fInf = reinterpret_cast<const double&>(ifInf);
380 fNegInf = reinterpret_cast<const double&>(ifNegInf);
381 fNaN = reinterpret_cast<const double&>(ifNaN);
382 fNeg0 = reinterpret_cast<const double&>(ifNeg0);
383 fmax = reinterpret_cast<const double&>(ifmax);
384 fmin = reinterpret_cast<const double&>(ifmin);
385 }
386
387 if (x == 0) {
388 return 0;
389 }
390
391 unsigned long long sign = x >> 7;
392 unsigned long long mantissa = x & ((1 << wm) - 1);
393 int exponent = (x & 0x7F) >> wm;
394 if (is_fnuz) {
395 if (x == 0x80) {
396 return fNaN;
397 }
398 } else {
399 if (x == 0x80) {
400 return fNeg0;
401 }
402 if (we == 4) { // e4m3
403 if ((x & 0x7F) == 0x7F) {
404 return fNaN;
405 }
406 } else if ((x & 0x7C) == 0x7C) { // e5m2
407 if ((x & 0x3) == 0) {
408 if (clip) {
409 return sign ? fmin : fmax;
410 }
411 return sign ? fNegInf : fInf;
412 }
413 return fNaN;
414 }
415 }
416
417 typename __hip_internal::conditional<
418 sizeof(T) == 2, unsigned short int,
419 typename __hip_internal::conditional<sizeof(T) == 4, unsigned int,
420 unsigned long long>::type>::type retval;
421
422 if (we == 5 && is_half && !is_fnuz) {
423 retval = x << 8;
424 return reinterpret_cast<const T&>(retval);
425 }
426
427 const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (is_fnuz ? 1 : 0);
428
429 // subnormal input
430 if (exponent == 0) {
431#if __HIP_DEVICE_COMPILE__
432 // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
433 int sh = 1 + __clz(mantissa) - (32 - wm);
434#else
435 int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
436#endif
437 mantissa <<= sh;
438 exponent += 1 - sh;
439 mantissa &= ((1ull << wm) - 1);
440 }
441 exponent += exp_low_cutoff - 1;
442 mantissa <<= wmo - wm;
443
444 // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
445 if (exponent <= 0) {
446 mantissa |= 1 << wmo;
447 mantissa >>= 1 - exponent;
448 exponent = 0;
449 }
450
451 if (sizeof(T) == 2)
452 retval = (sign << 15) | (exponent << 10) | mantissa;
453 else if (sizeof(T) == 4)
454 retval = (sign << 31) | (exponent << 23) | mantissa;
455 else
456 retval = (sign << 63) | (static_cast<unsigned long long>(exponent) << 52) | mantissa;
457 return reinterpret_cast<const T&>(retval);
458}
459
460#if HIP_FP8_CVT_FAST_PATH
461// The conversion function is from rocblas
462// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
463template <bool stochastic_rounding = false>
464static __device__ __hip_fp8_storage_t cast_to_f8_from_f32(float v, bool saturate,
466 unsigned int rng = 0) {
467 __hip_fp8_storage_t i8data;
468 union {
469 float fval;
470 unsigned int i32val;
471 unsigned char i8val[4]; // NOTE: not endian independent
472 } val;
473
474 unsigned int ival = 0;
475 val.fval = v;
476
477 if (saturate) {
478 if (interpret == __HIP_E4M3_FNUZ) {
479 if ((val.i32val & 0x7F800000) != 0x7F800000) {
480 val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
481 }
482 } else if (interpret == __HIP_E4M3) { // OCP type
483 if ((val.i32val & 0x7F800000) != 0x7F800000) {
484 val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
485 }
486 } else {
487 if ((val.i32val & 0x7F800000) != 0x7F800000) {
488 val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
489 }
490 }
491 }
492
493 if (stochastic_rounding) {
494 ival = (interpret == __HIP_E4M3_FNUZ) || (interpret == __HIP_E4M3)
495 ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
496 : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
497 val.i32val = ival;
498 i8data = val.i8val[0]; // little endian
499 } else { // RNE CVT
500 ival = (interpret == __HIP_E4M3_FNUZ) || (interpret == __HIP_E4M3)
501 ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
502 : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
503 val.i32val = ival;
504 i8data = val.i8val[0];
505 }
506 return i8data;
507}
508
509static __device__ __hip_fp8x2_storage_t
510cast_to_f8x2_from_f32x2(float2 v, bool saturate, __hip_fp8_interpretation_t interpret) {
511 union {
512 static_assert(sizeof(float2) == sizeof(unsigned int[2]));
513 static_assert(sizeof(float2) == sizeof(unsigned short[4]));
514 float2 fval;
515 unsigned int i32val[2];
516 unsigned short i16val[4];
517 } f2val;
518
519 f2val.fval = v;
520
521 if (saturate) {
522 if (interpret == __HIP_E4M3_FNUZ) {
523 if ((f2val.i32val[0] & 0x7F800000) != 0x7F800000) {
524 f2val.fval.x = __builtin_amdgcn_fmed3f(f2val.fval.x, 240.0, -240.0);
525 }
526 if ((f2val.i32val[1] & 0x7F800000) != 0x7F800000) {
527 f2val.fval.y = __builtin_amdgcn_fmed3f(f2val.fval.x, 240.0, -240.0);
528 }
529 } else if (interpret == __HIP_E4M3) {
530 if ((f2val.i32val[0] & 0x7F800000) != 0x7F800000) {
531 f2val.fval.x = __builtin_amdgcn_fmed3f(f2val.fval.x, 448.0, -448.0);
532 }
533 if ((f2val.i32val[1] & 0x7F800000) != 0x7F800000) {
534 f2val.fval.y = __builtin_amdgcn_fmed3f(f2val.fval.x, 448.0, -448.0);
535 }
536 } else {
537 if ((f2val.i32val[0] & 0x7F800000) != 0x7F800000) {
538 f2val.fval.x = __builtin_amdgcn_fmed3f(f2val.fval.x, 57344.0, -57344.0);
539 }
540 if ((f2val.i32val[1] & 0x7F800000) != 0x7F800000) {
541 f2val.fval.y = __builtin_amdgcn_fmed3f(f2val.fval.x, 57344.0, -57344.0);
542 }
543 }
544 }
545
546 f2val.i32val[0] = (interpret == __HIP_E4M3_FNUZ) || (interpret == __HIP_E4M3)
547 ? __builtin_amdgcn_cvt_pk_fp8_f32(v.x, v.y, 0, false)
548 : __builtin_amdgcn_cvt_pk_bf8_f32(v.x, v.y, 0, false);
549
550 return static_cast<__hip_fp8x2_storage_t>(f2val.i16val[0]);
551}
552
553static __device__ float cast_to_f32_from_f8(__hip_fp8_storage_t v,
554 __hip_fp8_interpretation_t interpret) {
555 union {
556 unsigned int i32val;
557 unsigned char i8val[4];
558 } val;
559 val.i8val[0] = v;
560
561 float fval = (interpret == __HIP_E4M3_FNUZ) || (interpret == __HIP_E4M3)
562 ? __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0)
563 : __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
564 return fval;
565}
566
567static __device__ float2 cast_to_f32x2_from_f8x2(__hip_fp8x2_storage_t v,
568 __hip_fp8_interpretation_t interpret) {
569 union {
570 unsigned int i32val;
571 unsigned short i16val[2];
572 } val;
573 val.i16val[0] = v;
574
575 auto f2 = (interpret == __HIP_E4M3_FNUZ) || (interpret == __HIP_E4M3)
576 ? __builtin_amdgcn_cvt_pk_f32_fp8(val.i32val, false)
577 : __builtin_amdgcn_cvt_pk_f32_bf8(val.i32val, false);
578 return float2{f2[0], f2[1]};
579}
580#endif // HIP_FP8_CVT_FAST_PATH
581
582/* For fp8 fnuz types, finite and NaN values are supported. Zero is unsigned.
583Inf are not supported. This gives us one additional number to represent.
584NaN are represented by 1-0000-000 or 1-00000-00 */
585__FP8_HOST_DEVICE_STATIC__ bool hip_fp8_fnuz_is_nan(__hip_fp8_storage_t a) {
586 return static_cast<unsigned char>(a) == 0x80;
587}
588
589__FP8_HOST_DEVICE_STATIC__ bool hip_fp8_ocp_is_nan(__hip_fp8_storage_t a,
590 const __hip_fp8_interpretation_t type) {
591 return (type == __HIP_E4M3) ? ((a & 0x7f) == 0x7f)
592 : (type == __HIP_E5M2) ? ((a & 0x7f) > 0x7c)
593 : false;
594}
595
596__FP8_HOST_DEVICE_STATIC__ bool hip_fp8_ocp_is_inf(__hip_fp8_storage_t a,
597 const __hip_fp8_interpretation_t type) {
598 return (type == __HIP_E5M2) ? (a & 0x7f) == 0x7c : false;
599}
600
601} // namespace internal
602
612 const float f, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
613#if HIP_FP8_CVT_FAST_PATH
614 return internal::cast_to_f8_from_f32<false>(f, sat == __HIP_SATFINITE, type);
615#else // HIP_FP8_CVT_FAST_PATH
616 if (type == __HIP_E4M3_FNUZ || type == __HIP_E5M2_FNUZ) {
617 int we = type == __HIP_E4M3_FNUZ ? 4 : 5;
618 int wm = type == __HIP_E4M3_FNUZ ? 3 : 2;
619 return internal::cast_to_f8<float, true>(f, wm, we, sat == __HIP_SATFINITE);
620 }
621 if (type == __HIP_E4M3 || type == __HIP_E5M2) {
622 int we = type == __HIP_E4M3 ? 4 : 5;
623 int wm = type == __HIP_E4M3 ? 3 : 2;
624 return internal::cast_to_f8<float, false>(f, wm, we, sat == __HIP_SATFINITE);
625 }
626#endif // HIP_FP8_CVT_FAST_PATH
627}
628
629
639 const float2 f2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
640#if HIP_FP8_CVT_FAST_PATH
641 return internal::cast_to_f8x2_from_f32x2(f2, sat == __HIP_SATFINITE, type);
642#else
643 return static_cast<__hip_fp8x2_storage_t>(
644 static_cast<unsigned short int>(__hip_cvt_float_to_fp8(f2.y, sat, type)) << 8 |
645 static_cast<unsigned short int>(__hip_cvt_float_to_fp8(f2.x, sat, type)));
646#endif
647}
648
658 const double d, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
659 if (type == __HIP_E4M3_FNUZ || type == __HIP_E5M2_FNUZ) {
660 int we = type == __HIP_E4M3_FNUZ ? 4 : 5;
661 int wm = type == __HIP_E4M3_FNUZ ? 3 : 2;
662 return internal::cast_to_f8<double, true>(d, wm, we, sat == __HIP_SATFINITE);
663 }
664 if (type == __HIP_E4M3 || type == __HIP_E5M2) {
665 int we = type == __HIP_E4M3 ? 4 : 5;
666 int wm = type == __HIP_E4M3 ? 3 : 2;
667 return internal::cast_to_f8<double, false>(d, wm, we, sat == __HIP_SATFINITE);
668 }
669}
670
680 const double2 d2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
681 return static_cast<__hip_fp8x2_storage_t>(
682 static_cast<unsigned short int>(__hip_cvt_double_to_fp8(d2.y, sat, type)) << 8 |
683 static_cast<unsigned short int>(__hip_cvt_double_to_fp8(d2.x, sat, type)));
684}
685
694__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t
695__hip_cvt_bfloat16raw_to_fp8(const __hip_bfloat16_raw hr, const __hip_saturation_t sat,
696 const __hip_fp8_interpretation_t type) {
697 float fval = __hip_bfloat16(hr);
698 return __hip_cvt_float_to_fp8(fval, sat, type);
699}
700
709__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t
710__hip_cvt_bfloat16raw2_to_fp8x2(const __hip_bfloat162_raw hr, const __hip_saturation_t sat,
711 const __hip_fp8_interpretation_t type) {
712 float2 f2 = __hip_bfloat162(hr);
713 return __hip_cvt_float2_to_fp8x2(f2, sat, type);
714}
715
723__FP8_HOST_DEVICE_STATIC__ __half_raw
725 if (type == __HIP_E4M3_FNUZ || type == __HIP_E5M2_FNUZ) {
726 unsigned int we = type == __HIP_E4M3_FNUZ ? 4 : 5;
727 unsigned int wm = type == __HIP_E4M3_FNUZ ? 3 : 2;
728 return __half_raw{internal::cast_from_f8<_Float16, true>(x, wm, we)};
729 }
730 if (type == __HIP_E4M3 || type == __HIP_E5M2) {
731 unsigned int we = type == __HIP_E4M3 ? 4 : 5;
732 unsigned int wm = type == __HIP_E4M3 ? 3 : 2;
733 return __half_raw{internal::cast_from_f8<_Float16, false>(x, wm, we)};
734 }
735}
736
744__FP8_HOST_DEVICE_STATIC__ __half2_raw
746 __half2 ret(static_cast<__half>(
747 __hip_cvt_fp8_to_halfraw(static_cast<__hip_fp8_storage_t>(x & 0xFF), type)),
748 static_cast<__half>(
749 __hip_cvt_fp8_to_halfraw(static_cast<__hip_fp8_storage_t>(x >> 8), type)));
750 return static_cast<__half2_raw>(ret);
751}
752
762 const __half_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
763 return __hip_cvt_float_to_fp8(__half2float(__half(x)), sat, type);
764}
765
775 const __half2_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
776 return __hip_cvt_float2_to_fp8x2(__half22float2(__half2(x)), sat, type);
777}
778
786 constexpr static __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ;
787 constexpr static unsigned int __we = 4;
788 constexpr static unsigned int __wm = 3;
789
790 // TODO: SWDEV-452411
791 // Add cast from unsigned long long, long long to fp8
792
794#if HIP_FP8_TYPE_FNUZ
795 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const long int val)
796#else
797 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const long int val)
798#endif
799 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
800 __default_interpret)) {}
801
803#if HIP_FP8_TYPE_FNUZ
804 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const int val)
805#else
806 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const int val)
807#endif
808 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
809 __default_interpret)) {}
810
812#if HIP_FP8_TYPE_FNUZ
813 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const short int val)
814#else
815 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const short int val)
816#endif
817 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
818 __default_interpret)) {}
819
821#if HIP_FP8_TYPE_FNUZ
822 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned long int val)
823#else
824 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const unsigned long int val)
825#endif
826 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
827 __default_interpret)) {}
828
830#if HIP_FP8_TYPE_FNUZ
831 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned int val)
832#else
833 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const unsigned int val)
834#endif
835 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
836 __default_interpret)) {}
837
839#if HIP_FP8_TYPE_FNUZ
840 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned short int val)
841#else
842 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const unsigned short int val)
843#endif
844 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
845 __default_interpret)) {}
846
848#if HIP_FP8_TYPE_FNUZ
849 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const double f)
850#else
851 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const double f)
852#endif
853 : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {}
854
856#if HIP_FP8_TYPE_FNUZ
857 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const float f)
858#else
859 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const float f)
860#endif
861 : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {}
862
864#if HIP_FP8_TYPE_FNUZ
865 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __hip_bfloat16 f)
866#else
867 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const __hip_bfloat16 f)
868#endif
869 : __x(__hip_cvt_float_to_fp8(static_cast<float>(f), __default_saturation,
870 __default_interpret)) {}
871
873#if HIP_FP8_TYPE_FNUZ
874 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __half f)
875#else
876 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const __half f)
877#endif
879 __default_interpret)) {}
880
882#if HIP_FP8_TYPE_FNUZ
883 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz() = default;
884#else
885 __FP8_HOST__ __hip_fp8_e4m3_fnuz() = default;
886#endif
887
889#if HIP_FP8_TYPE_FNUZ
890 __FP8_HOST_DEVICE__ operator __half() const {
891#else
892 __FP8_HOST__ operator __half() const {
893#endif
894 return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret));
895 }
896
898#if HIP_FP8_TYPE_FNUZ
899 __FP8_HOST_DEVICE__ operator __hip_bfloat16() const {
900#else
901 __FP8_HOST__ operator __hip_bfloat16() const {
902#endif
903 float f = *this;
904 return __hip_bfloat16(f);
905 }
906
908#if HIP_FP8_TYPE_FNUZ
909 __FP8_HOST_DEVICE__ operator bool() const {
910#else
911 __FP8_HOST__ operator bool() const {
912#endif
913 // it can be 0x00 (+0.0) since 0x80 will be nan
914 return !(static_cast<unsigned short>(__x) == 0);
915 }
916
918#if HIP_FP8_TYPE_FNUZ
919 __FP8_HOST_DEVICE__ operator char() const {
920#else
921 __FP8_HOST__ operator char() const {
922#endif
923 if (internal::hip_fp8_fnuz_is_nan(__x)) {
924 return 0;
925 }
926
927 auto fval = internal::cast_from_f8<float, true>(__x, __wm, __we);
928 auto llval = static_cast<long long>(fval);
929 if (llval <= CHAR_MIN) {
930 return CHAR_MIN;
931 } else if (llval >= CHAR_MAX) {
932 return CHAR_MAX;
933 }
934 return static_cast<char>(fval);
935 }
936
938#if HIP_FP8_TYPE_FNUZ
939 __FP8_HOST_DEVICE__ operator double() const {
940#else
941 __FP8_HOST__ operator double() const {
942#endif
943 return internal::cast_from_f8<double, true>(__x, __wm, __we);
944 }
945
947#if HIP_FP8_TYPE_FNUZ
948 __FP8_HOST_DEVICE__ operator float() const {
949#else
950 __FP8_HOST__ operator float() const {
951#endif
952#if HIP_FP8_CVT_FAST_PATH
953 return internal::cast_to_f32_from_f8(__x, __default_interpret);
954#else
955 return internal::cast_from_f8<float, true>(__x, __wm, __we);
956#endif
957 }
958
960#if HIP_FP8_TYPE_FNUZ
961 __FP8_HOST_DEVICE__ operator int() const {
962#else
963 __FP8_HOST__ operator int() const {
964#endif
965 if (internal::hip_fp8_fnuz_is_nan(__x)) {
966 return 0;
967 }
968
969 float fval = *this;
970 return static_cast<int>(fval);
971 }
972
974#if HIP_FP8_TYPE_FNUZ
975 __FP8_HOST_DEVICE__ operator long int() const {
976#else
977 __FP8_HOST__ operator long int() const {
978#endif
979 if (internal::hip_fp8_fnuz_is_nan(__x)) {
980 return 0;
981 }
982
983 float fval = *this;
984 return static_cast<long>(fval);
985 }
986
988#if HIP_FP8_TYPE_FNUZ
989 __FP8_HOST_DEVICE__ operator long long int() const {
990#else
991 __FP8_HOST__ operator long long int() const {
992#endif
993 if (internal::hip_fp8_fnuz_is_nan(__x)) {
994 return 0;
995 }
996
997 float fval = *this;
998 return static_cast<long long>(fval);
999 }
1000
1002#if HIP_FP8_TYPE_FNUZ
1003 __FP8_HOST_DEVICE__ operator short int() const {
1004#else
1005 __FP8_HOST__ operator short int() const {
1006#endif
1007 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1008 return 0;
1009 }
1010
1011 float fval = *this;
1012 auto llval = static_cast<long long>(fval);
1013 if (llval <= SHRT_MIN) {
1014 return SHRT_MIN;
1015 } else if (llval >= SHRT_MAX) {
1016 return SHRT_MAX;
1017 }
1018 return static_cast<short>(fval);
1019 }
1020
1022#if HIP_FP8_TYPE_FNUZ
1023 __FP8_HOST_DEVICE__ operator signed char() const {
1024#else
1025 __FP8_HOST__ operator signed char() const {
1026#endif
1027 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1028 return 0;
1029 }
1030
1031 float fval = *this;
1032 auto llval = static_cast<long long>(fval);
1033 if (llval <= SCHAR_MIN) {
1034 return SCHAR_MIN;
1035 } else if (llval >= SCHAR_MAX) {
1036 return SCHAR_MAX;
1037 }
1038 return static_cast<signed char>(fval);
1039 }
1040
1042#if HIP_FP8_TYPE_FNUZ
1043 __FP8_HOST_DEVICE__ operator unsigned char() const {
1044#else
1045 __FP8_HOST__ operator unsigned char() const {
1046#endif
1047 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1048 return 0;
1049 }
1050
1051 float fval = *this;
1052 auto llval = static_cast<long long>(fval);
1053 if (llval <= 0) {
1054 return 0;
1055 } else if (llval >= UCHAR_MAX) {
1056 return UCHAR_MAX;
1057 }
1058 return static_cast<unsigned char>(fval);
1059 }
1060
1062#if HIP_FP8_TYPE_FNUZ
1063 __FP8_HOST_DEVICE__ operator unsigned int() const {
1064#else
1065 __FP8_HOST__ operator unsigned int() const {
1066#endif
1067 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1068 return 0;
1069 }
1070
1071 float fval = *this;
1072 auto llval = static_cast<long long>(fval);
1073 if (llval <= 0) {
1074 return 0;
1075 }
1076 return static_cast<unsigned int>(fval);
1077 }
1078
1080#if HIP_FP8_TYPE_FNUZ
1081 __FP8_HOST_DEVICE__ operator unsigned long int() const {
1082#else
1083 __FP8_HOST__ operator unsigned long int() const {
1084#endif
1085 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1086 return 0;
1087 }
1088
1089 float fval = *this;
1090 auto llval = static_cast<long long>(fval);
1091 if (llval <= 0) {
1092 return 0;
1093 }
1094 return static_cast<unsigned long>(fval);
1095 }
1096
1098#if HIP_FP8_TYPE_FNUZ
1099 __FP8_HOST_DEVICE__ operator unsigned long long int() const {
1100#else
1101 __FP8_HOST__ operator unsigned long long int() const {
1102#endif
1103 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1104 return 0;
1105 }
1106
1107 float fval = *this;
1108 auto llval = static_cast<long long>(fval);
1109 if (llval <= 0) {
1110 return 0;
1111 }
1112 return static_cast<unsigned long long>(fval);
1113 }
1114
1116#if HIP_FP8_TYPE_FNUZ
1117 __FP8_HOST_DEVICE__ operator unsigned short int() const {
1118#else
1119 __FP8_HOST__ operator unsigned short int() const {
1120#endif
1121 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1122 return 0;
1123 }
1124
1125 float fval = *this;
1126 auto llval = static_cast<long long>(fval);
1127 if (llval <= 0) {
1128 return 0;
1129 }
1130 return static_cast<unsigned short>(fval);
1131 }
1132};
1133
1140 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1141 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ;
1142 static constexpr unsigned int __we = 4;
1143 static constexpr unsigned int __wm = 3;
1144
1146#if HIP_FP8_TYPE_FNUZ
1147 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const double2 val)
1148#else
1149 __FP8_HOST__ __hip_fp8x2_e4m3_fnuz(const double2 val)
1150#endif
1151 : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1152
1154#if HIP_FP8_TYPE_FNUZ
1155 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const float2 val)
1156#else
1157 __FP8_HOST__ __hip_fp8x2_e4m3_fnuz(const float2 val)
1158#endif
1159 : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1160
1162#if HIP_FP8_TYPE_FNUZ
1163 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __hip_bfloat162 val)
1164#else
1165 __FP8_HOST__ __hip_fp8x2_e4m3_fnuz(const __hip_bfloat162 val)
1166#endif
1167 : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1168
1170#if HIP_FP8_TYPE_FNUZ
1171 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __half2 val)
1172#else
1173 __FP8_HOST__ __hip_fp8x2_e4m3_fnuz(const __half2 val)
1174#endif
1175 : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1176
1178#if HIP_FP8_TYPE_FNUZ
1179 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz() = default;
1180#else
1181 __FP8_HOST__ __hip_fp8x2_e4m3_fnuz() = default;
1182#endif
1183
1185#if HIP_FP8_TYPE_FNUZ
1186 __FP8_HOST_DEVICE__ operator __half2() const {
1187#else
1188 __FP8_HOST__ operator __half2() const {
1189#endif
1190 return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret));
1191 }
1192
1194#if HIP_FP8_TYPE_FNUZ
1195 __FP8_HOST_DEVICE__ operator float2() const {
1196#else
1197 __FP8_HOST__ operator float2() const {
1198#endif
1199#if HIP_FP8_CVT_FAST_PATH
1200 return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
1201#else
1202 return float2(internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x & 0xFF),
1203 __wm, __we),
1204 internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x >> 8),
1205 __wm, __we));
1206#endif
1207 }
1208};
1209
1216 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1217 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ;
1218 static constexpr unsigned int __we = 4;
1219 static constexpr unsigned int __wm = 3;
1220
1222#if HIP_FP8_TYPE_FNUZ
1223 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const double4 val)
1224#else
1225 __FP8_HOST__ __hip_fp8x4_e4m3_fnuz(const double4 val)
1226#endif
1227 : __x{reinterpret_cast<__hip_fp8x4_storage_t>(
1228 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1229 val.x, __default_saturation, __default_interpret)) |
1230 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1231 val.y, __default_saturation, __default_interpret))
1232 << 8 |
1233 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1234 val.z, __default_saturation, __default_interpret))
1235 << 16 |
1236 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1237 val.w, __default_saturation, __default_interpret))
1238 << 24))} {}
1239
1241#if HIP_FP8_TYPE_FNUZ
1242 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const float4 val)
1243#else
1244 __FP8_HOST__ __hip_fp8x4_e4m3_fnuz(const float4 val)
1245#endif
1246 : __x{reinterpret_cast<__hip_fp8x4_storage_t>(
1247 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1248 val.x, __default_saturation, __default_interpret)) |
1249 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1250 val.y, __default_saturation, __default_interpret))
1251 << 8 |
1252 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1253 val.z, __default_saturation, __default_interpret))
1254 << 16 |
1255 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1256 val.w, __default_saturation, __default_interpret))
1257 << 24))} {}
1258
1260#if HIP_FP8_TYPE_FNUZ
1261 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
1262#else
1263 __FP8_HOST__ __hip_fp8x4_e4m3_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
1264#endif
1265 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast<unsigned int>(
1266 reinterpret_cast<unsigned short>(
1267 __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) |
1268 reinterpret_cast<unsigned short>(
1269 __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret))
1270 << 16))) {}
1271
1273#if HIP_FP8_TYPE_FNUZ
1274 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __half2 low, const __half2 high)
1275#else
1276 __FP8_HOST__ __hip_fp8x4_e4m3_fnuz(const __half2 low, const __half2 high)
1277#endif
1278 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
1279 static_cast<unsigned int>(reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
1280 high, __default_saturation, __default_interpret)) |
1281 reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
1282 low, __default_saturation, __default_interpret))
1283 << 16))) {}
1284
1286#if HIP_FP8_TYPE_FNUZ
1287 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz() = default;
1288#else
1289 __FP8_HOST__ __hip_fp8x4_e4m3_fnuz() = default;
1290#endif
1291
1293#if HIP_FP8_TYPE_FNUZ
1294 __FP8_HOST_DEVICE__ operator float4() const {
1295#else
1296 __FP8_HOST__ operator float4() const {
1297#endif
1298 auto x = __x; // bypass const
1299 auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E
1300 auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1);
1301#if HIP_FP8_CVT_FAST_PATH
1302 float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
1303 float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
1304#else
1305 float2 high = float2(internal::cast_from_f8<float, true>(
1306 static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we),
1307 internal::cast_from_f8<float, true>(
1308 static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we));
1309 float2 low = float2(internal::cast_from_f8<float, true>(
1310 static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we),
1311 internal::cast_from_f8<float, true>(
1312 static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we));
1313#endif
1314 return float4(low.x, low.y, high.x, high.y);
1315 }
1316};
1317
1324 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1325 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ;
1326 static constexpr unsigned int __we = 5;
1327 static constexpr unsigned int __wm = 2;
1328
1329
1330 // TODO: SWDEV-452411
1331 // Add cast from unsigned long long, long long to fp8
1332
1334#if HIP_FP8_TYPE_FNUZ
1335 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const long int val)
1336#else
1337 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const long int val)
1338#endif
1339 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1340 __default_interpret)) {}
1341
1343#if HIP_FP8_TYPE_FNUZ
1344 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const int val)
1345#else
1346 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const int val)
1347#endif
1348 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1349 __default_interpret)) {}
1350
1352#if HIP_FP8_TYPE_FNUZ
1353 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const short int val)
1354#else
1355 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const short int val)
1356#endif
1357 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1358 __default_interpret)) {}
1359
1361#if HIP_FP8_TYPE_FNUZ
1362 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned long int val)
1363#else
1364 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const unsigned long int val)
1365#endif
1366 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1367 __default_interpret)) {}
1368
1370#if HIP_FP8_TYPE_FNUZ
1371 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned int val)
1372#else
1373 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const unsigned int val)
1374#endif
1375 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1376 __default_interpret)) {}
1377
1379#if HIP_FP8_TYPE_FNUZ
1380 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned short int val)
1381#else
1382 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const unsigned short int val)
1383#endif
1384 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1385 __default_interpret)) {}
1386
1388#if HIP_FP8_TYPE_FNUZ
1389 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const double f)
1390#else
1391 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const double f)
1392#endif
1393 : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {}
1394
1396#if HIP_FP8_TYPE_FNUZ
1397 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const float f)
1398#else
1399 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const float f)
1400#endif
1401 : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {}
1402
1404#if HIP_FP8_TYPE_FNUZ
1405 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __hip_bfloat16 f)
1406#else
1407 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const __hip_bfloat16 f)
1408#endif
1409 : __x(__hip_cvt_float_to_fp8(static_cast<float>(f), __default_saturation,
1410 __default_interpret)) {}
1411
1413#if HIP_FP8_TYPE_FNUZ
1414 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __half f)
1415#else
1416 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const __half f)
1417#endif
1418 : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation,
1419 __default_interpret)) {}
1420
1422#if HIP_FP8_TYPE_FNUZ
1423 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz() = default;
1424#else
1425 __FP8_HOST__ __hip_fp8_e5m2_fnuz() = default;
1426#endif
1427
1429#if HIP_FP8_TYPE_FNUZ
1430 __FP8_HOST_DEVICE__ operator float() const {
1431#else
1432 __FP8_HOST__ operator float() const {
1433#endif
1434#if HIP_FP8_CVT_FAST_PATH
1435 return internal::cast_to_f32_from_f8(__x, __default_interpret);
1436#else
1437 return internal::cast_from_f8<float, true>(__x, __wm, __we);
1438#endif
1439 }
1440
1442#if HIP_FP8_TYPE_FNUZ
1443 __FP8_HOST_DEVICE__ operator __half() const {
1444#else
1445 __FP8_HOST__ operator __half() const {
1446#endif
1447 return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret));
1448 }
1449
1451#if HIP_FP8_TYPE_FNUZ
1452 __FP8_HOST_DEVICE__ operator __hip_bfloat16() const {
1453#else
1454 __FP8_HOST__ operator __hip_bfloat16() const {
1455#endif
1456 float f = *this;
1457 return __hip_bfloat16(f);
1458 }
1459
1461#if HIP_FP8_TYPE_FNUZ
1462 __FP8_HOST_DEVICE__ operator bool() const {
1463#else
1464 __FP8_HOST__ operator bool() const {
1465#endif
1466 // it can be 0x00 (+0.0) since 0x80 will be nan
1467 return !(static_cast<unsigned short>(__x) == 0);
1468 }
1469
1471#if HIP_FP8_TYPE_FNUZ
1472 __FP8_HOST_DEVICE__ operator char() const {
1473#else
1474 __FP8_HOST__ operator char() const {
1475#endif
1476 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1477 return 0;
1478 }
1479
1480 float fval = *this;
1481 auto llval = static_cast<long long>(fval);
1482 if (llval <= CHAR_MIN) {
1483 return CHAR_MIN;
1484 } else if (llval >= CHAR_MAX) {
1485 return CHAR_MAX;
1486 }
1487 return static_cast<char>(fval);
1488 }
1489
1491#if HIP_FP8_TYPE_FNUZ
1492 __FP8_HOST_DEVICE__ operator double() const {
1493#else
1494 __FP8_HOST__ operator double() const {
1495#endif
1496 return internal::cast_from_f8<double, true>(__x, __wm, __we);
1497 }
1498
1500#if HIP_FP8_TYPE_FNUZ
1501 __FP8_HOST_DEVICE__ operator int() const {
1502#else
1503 __FP8_HOST__ operator int() const {
1504#endif
1505 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1506 return 0;
1507 }
1508
1509 float fval = *this;
1510 return static_cast<int>(fval);
1511 }
1512
1514#if HIP_FP8_TYPE_FNUZ
1515 __FP8_HOST_DEVICE__ operator long int() const {
1516#else
1517 __FP8_HOST__ operator long int() const {
1518#endif
1519 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1520 return 0;
1521 }
1522
1523 float fval = *this;
1524 return static_cast<long>(fval);
1525 }
1526
1528#if HIP_FP8_TYPE_FNUZ
1529 __FP8_HOST_DEVICE__ operator long long int() const {
1530#else
1531 __FP8_HOST__ operator long long int() const {
1532#endif
1533 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1534 return 0;
1535 }
1536
1537 float fval = *this;
1538 return static_cast<long long>(fval);
1539 }
1540
1542#if HIP_FP8_TYPE_FNUZ
1543 __FP8_HOST_DEVICE__ operator short int() const {
1544#else
1545 __FP8_HOST__ operator short int() const {
1546#endif
1547 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1548 return 0;
1549 }
1550
1551 float fval = *this;
1552 auto llval = static_cast<long long>(fval);
1553 if (llval <= SHRT_MIN) {
1554 return SHRT_MIN;
1555 } else if (llval >= SHRT_MAX) {
1556 return SHRT_MAX;
1557 }
1558 return static_cast<short>(fval);
1559 }
1560
1562#if HIP_FP8_TYPE_FNUZ
1563 __FP8_HOST_DEVICE__ operator signed char() const {
1564#else
1565 __FP8_HOST__ operator signed char() const {
1566#endif
1567 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1568 return 0;
1569 }
1570
1571 float fval = *this;
1572 auto llval = static_cast<long long>(fval);
1573 if (llval <= SCHAR_MIN) {
1574 return SCHAR_MIN;
1575 } else if (llval >= SCHAR_MAX) {
1576 return SCHAR_MAX;
1577 }
1578 return static_cast<signed char>(fval);
1579 }
1580
1582#if HIP_FP8_TYPE_FNUZ
1583 __FP8_HOST_DEVICE__ operator unsigned char() const {
1584#else
1585 __FP8_HOST__ operator unsigned char() const {
1586#endif
1587 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1588 return 0;
1589 }
1590
1591 float fval = *this;
1592 auto llval = static_cast<long long>(fval);
1593 if (llval <= 0) {
1594 return 0;
1595 } else if (llval >= UCHAR_MAX) {
1596 return UCHAR_MAX;
1597 }
1598 return static_cast<unsigned char>(fval);
1599 }
1600
1602#if HIP_FP8_TYPE_FNUZ
1603 __FP8_HOST_DEVICE__ operator unsigned int() const {
1604#else
1605 __FP8_HOST__ operator unsigned int() const {
1606#endif
1607 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1608 return 0;
1609 }
1610
1611 float fval = *this;
1612 auto llval = static_cast<long long>(fval);
1613 if (llval <= 0) {
1614 return 0;
1615 }
1616 return static_cast<unsigned int>(fval);
1617 }
1618
1620#if HIP_FP8_TYPE_FNUZ
1621 __FP8_HOST_DEVICE__ operator unsigned long int() const {
1622#else
1623 __FP8_HOST__ operator unsigned long int() const {
1624#endif
1625 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1626 return 0;
1627 }
1628
1629 float fval = *this;
1630 auto llval = static_cast<long long>(fval);
1631 if (llval <= 0) {
1632 return 0;
1633 }
1634 return static_cast<unsigned long>(fval);
1635 }
1636
1638#if HIP_FP8_TYPE_FNUZ
1639 __FP8_HOST_DEVICE__ operator unsigned long long int() const {
1640#else
1641 __FP8_HOST__ operator unsigned long long int() const {
1642#endif
1643 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1644 return 0;
1645 }
1646
1647 float fval = *this;
1648 auto llval = static_cast<long long>(fval);
1649 if (llval <= 0) {
1650 return 0;
1651 }
1652 return static_cast<unsigned long long>(fval);
1653 }
1654
1656#if HIP_FP8_TYPE_FNUZ
1657 __FP8_HOST_DEVICE__ operator unsigned short int() const {
1658#else
1659 __FP8_HOST__ operator unsigned short int() const {
1660#endif
1661 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1662 return 0;
1663 }
1664
1665 float fval = *this;
1666 auto llval = static_cast<long long>(fval);
1667 if (llval <= 0) {
1668 return 0;
1669 }
1670 return static_cast<unsigned short>(fval);
1671 }
1672};
1673
1680 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1681 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ;
1682 static constexpr unsigned int __we = 5;
1683 static constexpr unsigned int __wm = 2;
1684
1686#if HIP_FP8_TYPE_FNUZ
1687 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const double2 val)
1688#else
1689 __FP8_HOST__ __hip_fp8x2_e5m2_fnuz(const double2 val)
1690#endif
1691 : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1692
1694#if HIP_FP8_TYPE_FNUZ
1695 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const float2 val)
1696#else
1697 __FP8_HOST__ __hip_fp8x2_e5m2_fnuz(const float2 val)
1698#endif
1699 : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1700
1702#if HIP_FP8_TYPE_FNUZ
1703 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __hip_bfloat162 val)
1704#else
1705 __FP8_HOST__ __hip_fp8x2_e5m2_fnuz(const __hip_bfloat162 val)
1706#endif
1707 : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1708
1710#if HIP_FP8_TYPE_FNUZ
1711 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __half2 val)
1712#else
1713 __FP8_HOST__ __hip_fp8x2_e5m2_fnuz(const __half2 val)
1714#endif
1715 : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1716
1718#if HIP_FP8_TYPE_FNUZ
1719 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz() = default;
1720#else
1721 __FP8_HOST__ __hip_fp8x2_e5m2_fnuz() = default;
1722#endif
1723
1725#if HIP_FP8_TYPE_FNUZ
1726 __FP8_HOST_DEVICE__ operator __half2() const {
1727#else
1728 __FP8_HOST__ operator __half2() const {
1729#endif
1730 return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret));
1731 }
1732
1734#if HIP_FP8_TYPE_FNUZ
1735 __FP8_HOST_DEVICE__ operator float2() const {
1736#else
1737 __FP8_HOST__ operator float2() const {
1738#endif
1739#if HIP_FP8_CVT_FAST_PATH
1740 return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
1741#else
1742 return float2(internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x & 0xFF),
1743 __wm, __we),
1744 internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x >> 8),
1745 __wm, __we));
1746#endif
1747 }
1748};
1749
1756 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1757 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ;
1758 static constexpr unsigned int __we = 5;
1759 static constexpr unsigned int __wm = 2;
1760
1762#if HIP_FP8_TYPE_FNUZ
1763 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const double4 val)
1764#else
1765 __FP8_HOST__ __hip_fp8x4_e5m2_fnuz(const double4 val)
1766#endif
1767 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
1768 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1769 val.x, __default_saturation, __default_interpret)) |
1770 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1771 val.y, __default_saturation, __default_interpret))
1772 << 8 |
1773 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1774 val.z, __default_saturation, __default_interpret))
1775 << 16 |
1776 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1777 val.w, __default_saturation, __default_interpret))
1778 << 24))) {}
1779
1781#if HIP_FP8_TYPE_FNUZ
1782 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const float4 val)
1783#else
1784 __FP8_HOST__ __hip_fp8x4_e5m2_fnuz(const float4 val)
1785#endif
1786 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
1787 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1788 val.x, __default_saturation, __default_interpret)) |
1789 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1790 val.y, __default_saturation, __default_interpret))
1791 << 8 |
1792 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1793 val.z, __default_saturation, __default_interpret))
1794 << 16 |
1795 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1796 val.w, __default_saturation, __default_interpret))
1797 << 24))) {}
1798
1800#if HIP_FP8_TYPE_FNUZ
1801 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
1802#else
1803 __FP8_HOST__ __hip_fp8x4_e5m2_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
1804#endif
1805 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast<unsigned int>(
1806 reinterpret_cast<unsigned short>(
1807 __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) |
1808 reinterpret_cast<unsigned short>(
1809 __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret))
1810 << 16))) {}
1811
1813#if HIP_FP8_TYPE_FNUZ
1814 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __half2 low, const __half2 high)
1815#else
1816 __FP8_HOST__ __hip_fp8x4_e5m2_fnuz(const __half2 low, const __half2 high)
1817#endif
1818 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
1819 static_cast<unsigned int>(reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
1820 high, __default_saturation, __default_interpret)) |
1821 reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
1822 low, __default_saturation, __default_interpret))
1823 << 16))) {}
1824
1825 /* default construct fp8x4 e5m2 */
1826#if HIP_FP8_TYPE_FNUZ
1827 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz() = default;
1828#else
1829 __FP8_HOST__ __hip_fp8x4_e5m2_fnuz() = default;
1830#endif
1831
1833#if HIP_FP8_TYPE_FNUZ
1834 __FP8_HOST_DEVICE__ operator float4() const {
1835#else
1836 __FP8_HOST__ operator float4() const {
1837#endif
1838 auto x = __x; // bypass const
1839 auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E
1840 auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1);
1841#if HIP_FP8_CVT_FAST_PATH
1842 float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
1843 float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
1844#else
1845 float2 high = float2(internal::cast_from_f8<float, true>(
1846 static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we),
1847 internal::cast_from_f8<float, true>(
1848 static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we));
1849 float2 low = float2(internal::cast_from_f8<float, true>(
1850 static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we),
1851 internal::cast_from_f8<float, true>(
1852 static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we));
1853#endif
1854 return float4(low.x, low.y, high.x, high.y);
1855 }
1856};
1857
1864 constexpr static __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1865 constexpr static __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3;
1866 constexpr static unsigned int __we = 4;
1867 constexpr static unsigned int __wm = 3;
1868
1869 // TODO: SWDEV-452411
1870 // Add cast from unsigned long long, long long to fp8
1871
1873#if HIP_FP8_TYPE_OCP
1874__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const long int val)
1875#else
1876__FP8_HOST__ __hip_fp8_e4m3(const long int val)
1877#endif
1878 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1879 __default_interpret)) {}
1880
1882#if HIP_FP8_TYPE_OCP
1883__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const int val)
1884#else
1885__FP8_HOST__ __hip_fp8_e4m3(const int val)
1886#endif
1887 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1888 __default_interpret)) {}
1889
1891 __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const short int val)
1892 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1893 __default_interpret)) {}
1894
1896#if HIP_FP8_TYPE_OCP
1897__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned long int val)
1898#else
1899__FP8_HOST__ __hip_fp8_e4m3(const unsigned long int val)
1900#endif
1901 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1902 __default_interpret)) {}
1903
1905#if HIP_FP8_TYPE_OCP
1906__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned int val)
1907#else
1908__FP8_HOST__ __hip_fp8_e4m3(const unsigned int val)
1909#endif
1910 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1911 __default_interpret)) {}
1912
1914#if HIP_FP8_TYPE_OCP
1915__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned short int val)
1916#else
1917__FP8_HOST__ __hip_fp8_e4m3(const unsigned short int val)
1918#endif
1919 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1920 __default_interpret)) {}
1921
1923#if HIP_FP8_TYPE_OCP
1924__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const double f)
1925#else
1926__FP8_HOST__ __hip_fp8_e4m3(const double f)
1927#endif
1928 : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {}
1929
1931#if HIP_FP8_TYPE_OCP
1932__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const float f)
1933#else
1934__FP8_HOST__ __hip_fp8_e4m3(const float f)
1935#endif
1936 : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {}
1937
1939#if HIP_FP8_TYPE_OCP
1940__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const __hip_bfloat16 f)
1941#else
1942__FP8_HOST__ __hip_fp8_e4m3(const __hip_bfloat16 f)
1943#endif
1944 : __x(__hip_cvt_float_to_fp8(static_cast<float>(f), __default_saturation,
1945 __default_interpret)) {}
1946
1948#if HIP_FP8_TYPE_OCP
1949__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const __half f)
1950#else
1951__FP8_HOST__ __hip_fp8_e4m3(const __half f)
1952#endif
1953 : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation,
1954 __default_interpret)) {}
1955
1957#if HIP_FP8_TYPE_OCP
1958__FP8_HOST_DEVICE__ __hip_fp8_e4m3() = default;
1959#else
1960__FP8_HOST__ __hip_fp8_e4m3() = default;
1961#endif
1962
1965#if HIP_FP8_TYPE_OCP
1966__FP8_HOST_DEVICE__ operator __half() const {
1967#else
1968__FP8_HOST__ operator __half() const {
1969#endif
1970 return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret));
1971 }
1972
1974#if HIP_FP8_TYPE_OCP
1975__FP8_HOST_DEVICE__ operator __hip_bfloat16() const {
1976#else
1977__FP8_HOST__ operator __hip_bfloat16() const {
1978#endif
1979 float f = *this;
1980 return __hip_bfloat16(f);
1981 }
1982
1984#if HIP_FP8_TYPE_OCP
1985__FP8_HOST_DEVICE__ operator bool() const {
1986#else
1987__FP8_HOST__ operator bool() const {
1988#endif
1989 // it can be 0x00 (+0.0) since 0x80 will be nan
1990 return !(static_cast<unsigned short>(__x) == 0 || static_cast<unsigned short>(__x) == 0x80);
1991 }
1992
1994#if HIP_FP8_TYPE_OCP
1995__FP8_HOST_DEVICE__ operator char() const {
1996#else
1997__FP8_HOST__ operator char() const {
1998#endif
1999 if (internal::hip_fp8_ocp_is_nan(__x,__default_interpret)) {
2000 return 0;
2001 }
2002
2003 auto fval = internal::cast_from_f8<float, false>(__x, __wm, __we);
2004 auto llval = static_cast<long long>(fval);
2005 if (llval <= CHAR_MIN) {
2006 return CHAR_MIN;
2007 } else if (llval >= CHAR_MAX) {
2008 return CHAR_MAX;
2009 }
2010 return static_cast<char>(fval);
2011 }
2012
2014#if HIP_FP8_TYPE_OCP
2015__FP8_HOST_DEVICE__ operator double() const {
2016#else
2017__FP8_HOST__ operator double() const {
2018#endif
2019 return internal::cast_from_f8<double, false>(__x, __wm, __we);
2020 }
2021
2023#if HIP_FP8_TYPE_OCP
2024__FP8_HOST_DEVICE__ operator float() const {
2025#else
2026__FP8_HOST__ operator float() const {
2027#endif
2028#if HIP_FP8_CVT_FAST_PATH
2029 return internal::cast_to_f32_from_f8(__x, __default_interpret);
2030#else
2031 return internal::cast_from_f8<float, false>(__x, __wm, __we);
2032#endif
2033 }
2034
2036#if HIP_FP8_TYPE_OCP
2037__FP8_HOST_DEVICE__ operator int() const {
2038#else
2039__FP8_HOST__ operator int() const {
2040#endif
2041 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2042 return 0;
2043 }
2044
2045 float fval = *this;
2046 return static_cast<int>(fval);
2047 }
2048
2050#if HIP_FP8_TYPE_OCP
2051__FP8_HOST_DEVICE__ operator long int() const {
2052#else
2053__FP8_HOST__ operator long int() const {
2054#endif
2055 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2056 return 0;
2057 }
2058
2059 float fval = *this;
2060 return static_cast<long>(fval);
2061 }
2062
2064#if HIP_FP8_TYPE_OCP
2065__FP8_HOST_DEVICE__ operator long long int() const {
2066#else
2067__FP8_HOST__ operator long long int() const {
2068#endif
2069 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2070 return 0;
2071 }
2072
2073 float fval = *this;
2074 return static_cast<long long>(fval);
2075 }
2076
2078#if HIP_FP8_TYPE_OCP
2079__FP8_HOST_DEVICE__ operator short int() const {
2080#else
2081__FP8_HOST__ operator short int() const {
2082#endif
2083 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2084 return 0;
2085 }
2086
2087 float fval = *this;
2088 auto llval = static_cast<long long>(fval);
2089 if (llval <= SHRT_MIN) {
2090 return SHRT_MIN;
2091 } else if (llval >= SHRT_MAX) {
2092 return SHRT_MAX;
2093 }
2094 return static_cast<short>(fval);
2095 }
2096
2098#if HIP_FP8_TYPE_OCP
2099__FP8_HOST_DEVICE__ operator signed char() const {
2100#else
2101__FP8_HOST__ operator signed char() const {
2102#endif
2103 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2104 return 0;
2105 }
2106
2107 float fval = *this;
2108 auto llval = static_cast<long long>(fval);
2109 if (llval <= SCHAR_MIN) {
2110 return SCHAR_MIN;
2111 } else if (llval >= SCHAR_MAX) {
2112 return SCHAR_MAX;
2113 }
2114 return static_cast<signed char>(fval);
2115 }
2116
2118#if HIP_FP8_TYPE_OCP
2119__FP8_HOST_DEVICE__ operator unsigned char() const {
2120#else
2121__FP8_HOST__ operator unsigned char() const {
2122#endif
2123 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2124 return 0;
2125 }
2126
2127 float fval = *this;
2128 auto llval = static_cast<long long>(fval);
2129 if (llval <= 0) {
2130 return 0;
2131 } else if (llval >= UCHAR_MAX) {
2132 return UCHAR_MAX;
2133 }
2134 return static_cast<unsigned char>(fval);
2135 }
2136
2138#if HIP_FP8_TYPE_OCP
2139__FP8_HOST_DEVICE__ operator unsigned int() const {
2140#else
2141__FP8_HOST__ operator unsigned int() const {
2142#endif
2143 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2144 return 0;
2145 }
2146
2147 float fval = *this;
2148 auto llval = static_cast<long long>(fval);
2149 if (llval <= 0) {
2150 return 0;
2151 }
2152 return static_cast<unsigned int>(fval);
2153 }
2154
2156#if HIP_FP8_TYPE_OCP
2157__FP8_HOST_DEVICE__ operator unsigned long int() const {
2158#else
2159__FP8_HOST__ operator unsigned long int() const {
2160#endif
2161 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2162 return 0;
2163 }
2164
2165 float fval = *this;
2166 auto llval = static_cast<long long>(fval);
2167 if (llval <= 0) {
2168 return 0;
2169 }
2170 return static_cast<unsigned long>(fval);
2171 }
2172
2174#if HIP_FP8_TYPE_OCP
2175__FP8_HOST_DEVICE__ operator unsigned long long int() const {
2176#else
2177__FP8_HOST__ operator unsigned long long int() const {
2178#endif
2179 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2180 return 0;
2181 }
2182
2183 float fval = *this;
2184 auto llval = static_cast<long long>(fval);
2185 if (llval <= 0) {
2186 return 0;
2187 }
2188 return static_cast<unsigned long long>(fval);
2189 }
2190
2192#if HIP_FP8_TYPE_OCP
2193__FP8_HOST_DEVICE__ operator unsigned short int() const {
2194#else
2195__FP8_HOST__ operator unsigned short int() const {
2196#endif
2197 if (internal::hip_fp8_ocp_is_nan(__x,__default_interpret)) {
2198 return 0;
2199 }
2200
2201 float fval = *this;
2202 auto llval = static_cast<long long>(fval);
2203 if (llval <= 0) {
2204 return 0;
2205 }
2206 return static_cast<unsigned short>(fval);
2207 }
2208};
2209
2216 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
2217 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3;
2218 static constexpr unsigned int __we = 4;
2219 static constexpr unsigned int __wm = 3;
2220
2223#if HIP_FP8_TYPE_OCP
2224__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const double2 val)
2225#else
2226__FP8_HOST__ __hip_fp8x2_e4m3(const double2 val)
2227#endif
2228 : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
2229
/*! Construct from a float2: the pair is converted by __hip_cvt_float2_to_fp8x2 using the
 *  struct's default saturation and interpretation. */
2231#if HIP_FP8_TYPE_OCP
2232__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const float2 val)
2233#else
2234__FP8_HOST__ __hip_fp8x2_e4m3(const float2 val)
2235#endif
2236 : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
2237
/*! Construct from a __hip_bfloat162: the pair is converted by __hip_cvt_bfloat16raw2_to_fp8x2
 *  using the struct's default saturation and interpretation. */
2239#if HIP_FP8_TYPE_OCP
2240__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const __hip_bfloat162 val)
2241#else
2242__FP8_HOST__ __hip_fp8x2_e4m3(const __hip_bfloat162 val)
2243#endif
2244 : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
2245
/*! Construct from a __half2: the pair is converted by __hip_cvt_halfraw2_to_fp8x2 using the
 *  struct's default saturation and interpretation. */
2247#if HIP_FP8_TYPE_OCP
2248__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const __half2 val)
2249#else
2250__FP8_HOST__ __hip_fp8x2_e4m3(const __half2 val)
2251#endif
2252 : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
2253
/*! Defaulted default constructor (host+device under HIP_FP8_TYPE_OCP, host-only otherwise). */
2255#if HIP_FP8_TYPE_OCP
2256__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3() = default;
2257#else
2258__FP8_HOST__ __hip_fp8x2_e4m3() = default;
2259#endif
2260
/*! Convert the packed pair to __half2 via __hip_cvt_fp8x2_to_halfraw2 with the struct's
 *  default interpretation. */
2262#if HIP_FP8_TYPE_OCP
2263__FP8_HOST_DEVICE__ operator __half2() const {
2264#else
2265__FP8_HOST__ operator __half2() const {
2266#endif
2267 return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret));
2268 }
2269
/*! Convert the packed pair to float2.
 *  With HIP_FP8_CVT_FAST_PATH the whole pair goes through internal::cast_to_f32x2_from_f8x2;
 *  otherwise each byte is decoded by internal::cast_from_f8<float, false>: the low byte
 *  (__x & 0xFF) becomes .x and the high byte (__x >> 8) becomes .y. */
2271#if HIP_FP8_TYPE_OCP
2272__FP8_HOST_DEVICE__ operator float2() const {
2273#else
2274__FP8_HOST__ operator float2() const {
2275#endif
2276#if HIP_FP8_CVT_FAST_PATH
2277 return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
2278#else
2279 return float2(internal::cast_from_f8<float, false>(static_cast<__hip_fp8_storage_t>(__x & 0xFF), __wm, __we),
2280 internal::cast_from_f8<float, false>(static_cast<__hip_fp8_storage_t>(__x >> 8), __wm, __we));
2281#endif
2282 }
2283};
2284
/* Per-struct defaults used by the members below: the saturation mode and fp8 interpretation
 * handed to the __hip_cvt_* helpers, and the __we / __wm values (4 and 3 for E4M3) passed to
 * internal::cast_from_f8. */
2291 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
2292 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3;
2293 static constexpr unsigned int __we = 4;
2294 static constexpr unsigned int __wm = 3;
2295
2298#if HIP_FP8_TYPE_OCP
2299__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const double4 val)
2300#else
2301__FP8_HOST__ __hip_fp8x4_e4m3(const double4 val)
2302#endif
2303 : __x{reinterpret_cast<__hip_fp8x4_storage_t>(
2304 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
2305 val.x, __default_saturation, __default_interpret)) |
2306 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
2307 val.y, __default_saturation, __default_interpret))
2308 << 8 |
2309 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
2310 val.z, __default_saturation, __default_interpret))
2311 << 16 |
2312 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
2313 val.w, __default_saturation, __default_interpret))
2314 << 24))} {}
2315
2317#if HIP_FP8_TYPE_OCP
2318__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const float4 val)
2319#else
2320__FP8_HOST__ __hip_fp8x4_e4m3(const float4 val)
2321#endif
2322 : __x{reinterpret_cast<__hip_fp8x4_storage_t>(
2323 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
2324 val.x, __default_saturation, __default_interpret)) |
2325 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
2326 val.y, __default_saturation, __default_interpret))
2327 << 8 |
2328 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
2329 val.z, __default_saturation, __default_interpret))
2330 << 16 |
2331 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
2332 val.w, __default_saturation, __default_interpret))
2333 << 24))} {}
2334
2336#if HIP_FP8_TYPE_OCP
2337__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const __hip_bfloat162 low, const __hip_bfloat162 high)
2338#else
2339__FP8_HOST__ __hip_fp8x4_e4m3(const __hip_bfloat162 low, const __hip_bfloat162 high)
2340#endif
2341 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast<unsigned int>(
2342 reinterpret_cast<unsigned short>(
2343 __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) |
2344 reinterpret_cast<unsigned short>(
2345 __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret))
2346 << 16))) {}
2347
2349#if HIP_FP8_TYPE_OCP
2350__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const __half2 low, const __half2 high)
2351#else
2352__FP8_HOST__ __hip_fp8x4_e4m3(const __half2 low, const __half2 high)
2353#endif
2354 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
2355 static_cast<unsigned int>(reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
2356 high, __default_saturation, __default_interpret)) |
2357 reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
2358 low, __default_saturation, __default_interpret))
2359 << 16))) {}
2360
/*! Defaulted default constructor (host+device under HIP_FP8_TYPE_OCP, host-only otherwise). */
2362#if HIP_FP8_TYPE_OCP
2363__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3() = default;
2364#else
2365__FP8_HOST__ __hip_fp8x4_e4m3() = default;
2366#endif
2367
2369#if HIP_FP8_TYPE_OCP
2370__FP8_HOST_DEVICE__ operator float4() const {
2371#else
2372__FP8_HOST__ operator float4() const {
2373#endif
2374 auto x = __x; // bypass const
2375 auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E
2376 auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1);
2377#if HIP_FP8_CVT_FAST_PATH
2378 float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
2379 float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
2380#else
2381 float2 high = float2(internal::cast_from_f8<float, false>(
2382 static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we),
2383 internal::cast_from_f8<float, false>(
2384 static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we));
2385 float2 low = float2(internal::cast_from_f8<float, false>(
2386 static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we),
2387 internal::cast_from_f8<float, false>(
2388 static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we));
2389#endif
2390 return float4(low.x, low.y, high.x, high.y);
2391 }
2392};
2393
/* Per-struct defaults used by the members below: the saturation mode and fp8 interpretation
 * handed to the __hip_cvt_* helpers, and the __we / __wm values (5 and 2 for E5M2) passed to
 * internal::cast_from_f8. */
2400 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
2401 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2;
2402 static constexpr unsigned int __we = 5;
2403 static constexpr unsigned int __wm = 2;
2404
2405
2406 // TODO: SWDEV-452411
2407 // Add cast from unsigned long long, long long to fp8
2408
/*! Construct from long int: the value is cast to float, then converted by
 *  __hip_cvt_float_to_fp8 with the struct's default saturation and interpretation. */
2411#if HIP_FP8_TYPE_OCP
2412__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const long int val)
2413#else
2414__FP8_HOST__ __hip_fp8_e5m2(const long int val)
2415#endif
2416 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
2417 __default_interpret)) {}
2418
/*! Construct from int: the value is cast to float, then converted by
 *  __hip_cvt_float_to_fp8 with the struct's default saturation and interpretation. */
2420#if HIP_FP8_TYPE_OCP
2421__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const int val)
2422#else
2423__FP8_HOST__ __hip_fp8_e5m2(const int val)
2424#endif
2425 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
2426 __default_interpret)) {}
2427
/*! Construct from short int: the value is cast to float, then converted by
 *  __hip_cvt_float_to_fp8 with the struct's default saturation and interpretation. */
2429#if HIP_FP8_TYPE_OCP
2430__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const short int val)
2431#else
2432__FP8_HOST__ __hip_fp8_e5m2(const short int val)
2433#endif
2434 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
2435 __default_interpret)) {}
2436
/*! Construct from unsigned long int: the value is cast to float, then converted by
 *  __hip_cvt_float_to_fp8 with the struct's default saturation and interpretation. */
2438#if HIP_FP8_TYPE_OCP
2439__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned long int val)
2440#else
2441__FP8_HOST__ __hip_fp8_e5m2(const unsigned long int val)
2442#endif
2443 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
2444 __default_interpret)) {}
2445
/*! Construct from unsigned int: the value is cast to float, then converted by
 *  __hip_cvt_float_to_fp8 with the struct's default saturation and interpretation. */
2447#if HIP_FP8_TYPE_OCP
2448__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned int val)
2449#else
2450__FP8_HOST__ __hip_fp8_e5m2(const unsigned int val)
2451#endif
2452 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
2453 __default_interpret)) {}
2454
/*! Construct from unsigned short int: the value is cast to float, then converted by
 *  __hip_cvt_float_to_fp8 with the struct's default saturation and interpretation. */
2456#if HIP_FP8_TYPE_OCP
2457__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned short int val)
2458#else
2459__FP8_HOST__ __hip_fp8_e5m2(const unsigned short int val)
2460#endif
2461 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
2462 __default_interpret)) {}
2463
/*! Construct from double: converted directly by __hip_cvt_double_to_fp8 with the struct's
 *  default saturation and interpretation. */
2465#if HIP_FP8_TYPE_OCP
2466__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const double f)
2467#else
2468__FP8_HOST__ __hip_fp8_e5m2(const double f)
2469#endif
2470 : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {}
2471
/*! Construct from float: converted by __hip_cvt_float_to_fp8 with the struct's default
 *  saturation and interpretation. */
2473#if HIP_FP8_TYPE_OCP
2474__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const float f)
2475#else
2476__FP8_HOST__ __hip_fp8_e5m2(const float f)
2477#endif
2478 : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {}
2479
/*! Construct from __hip_bfloat16: the value is first cast to float, then converted by
 *  __hip_cvt_float_to_fp8 with the struct's default saturation and interpretation. */
2481#if HIP_FP8_TYPE_OCP
2482__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const __hip_bfloat16 f)
2483#else
2484__FP8_HOST__ __hip_fp8_e5m2(const __hip_bfloat16 f)
2485#endif
2486 : __x(__hip_cvt_float_to_fp8(static_cast<float>(f), __default_saturation,
2487 __default_interpret)) {}
2488
/*! Construct from __half: the value is cast to __half_raw and converted by
 *  __hip_cvt_halfraw_to_fp8 with the struct's default saturation and interpretation. */
2490#if HIP_FP8_TYPE_OCP
2491__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const __half f)
2492#else
2493__FP8_HOST__ __hip_fp8_e5m2(const __half f)
2494#endif
2495 : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation,
2496 __default_interpret)) {}
2497
/*! Defaulted default constructor (host+device under HIP_FP8_TYPE_OCP, host-only otherwise). */
2499#if HIP_FP8_TYPE_OCP
2500__FP8_HOST_DEVICE__ __hip_fp8_e5m2() = default;
2501#else
2502__FP8_HOST__ __hip_fp8_e5m2() = default;
2503#endif
2504
/*! Convert this fp8 e5m2 value to float.
 *  With HIP_FP8_CVT_FAST_PATH this calls internal::cast_to_f32_from_f8; otherwise it calls
 *  internal::cast_from_f8<float, false> with __wm/__we and a final flag that is true when the
 *  default saturation is __HIP_SATFINITE. */
2506#if HIP_FP8_TYPE_OCP
2507__FP8_HOST_DEVICE__ operator float() const {
2508#else
2509__FP8_HOST__ operator float() const {
2510#endif
2511#if HIP_FP8_CVT_FAST_PATH
2512 return internal::cast_to_f32_from_f8(__x, __default_interpret);
2513#else
2514 return internal::cast_from_f8<float, false>(__x, __wm, __we, __default_saturation == __HIP_SATFINITE);
2515#endif
2516 }
2517
/*! Convert this fp8 e5m2 value to __half via __hip_cvt_fp8_to_halfraw with the struct's
 *  default interpretation. */
2519#if HIP_FP8_TYPE_OCP
2520__FP8_HOST_DEVICE__ operator __half() const {
2521#else
2522__FP8_HOST__ operator __half() const {
2523#endif
2524 return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret));
2525 }
2526
/*! Convert this fp8 e5m2 value to __hip_bfloat16 by first converting to float (operator float)
 *  and constructing the bfloat16 from that float. */
2528#if HIP_FP8_TYPE_OCP
2529__FP8_HOST_DEVICE__ operator __hip_bfloat16() const {
2530#else
2531__FP8_HOST__ operator __hip_bfloat16() const {
2532#endif
2533 float f = *this;
2534 return __hip_bfloat16(f);
2535 }
2536
/*! Convert this fp8 e5m2 value to bool: false for the bit patterns 0x00 and 0x80, true for
 *  every other bit pattern. */
2538#if HIP_FP8_TYPE_OCP
2539__FP8_HOST_DEVICE__ operator bool() const {
2540#else
2541__FP8_HOST__ operator bool() const {
2542#endif
2543 // false only for the two zero encodings, 0x00 (+0.0) and 0x80 (-0.0 in OCP e5m2; it is NaN only in the FNUZ formats)
2544 return !(static_cast<unsigned short>(__x) == 0 || static_cast<unsigned short>(__x) == 0x80);
2545 }
2546
2548#if HIP_FP8_TYPE_OCP
2549__FP8_HOST_DEVICE__ operator char() const {
2550#else
2551__FP8_HOST__ operator char() const {
2552#endif
2553 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2554 return 0;
2555 }
2556
2557 float fval = *this;
2558 auto llval = static_cast<long long>(fval);
2559 if (llval <= CHAR_MIN) {
2560 return CHAR_MIN;
2561 } else if (llval >= CHAR_MAX) {
2562 return CHAR_MAX;
2563 }
2564 return static_cast<char>(fval);
2565 }
2566
/*! Convert this fp8 e5m2 value to double via internal::cast_from_f8<double, false> with
 *  __wm/__we and a final flag that is true when the default saturation is __HIP_SATFINITE. */
2568#if HIP_FP8_TYPE_OCP
2569__FP8_HOST_DEVICE__ operator double() const {
2570#else
2571__FP8_HOST__ operator double() const {
2572#endif
2573 return internal::cast_from_f8<double, false>(__x, __wm, __we, __default_saturation == __HIP_SATFINITE);
2574 }
2575
2577#if HIP_FP8_TYPE_OCP
2578__FP8_HOST_DEVICE__ operator int() const {
2579#else
2580__FP8_HOST__ operator int() const {
2581#endif
2582 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2583 return 0;
2584 }
2585
2586 float fval = *this;
2587 return static_cast<int>(fval);
2588 }
2589
2591#if HIP_FP8_TYPE_OCP
2592__FP8_HOST_DEVICE__ operator long int() const {
2593#else
2594__FP8_HOST__ operator long int() const {
2595#endif
2596 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2597 return 0;
2598 }
2599
2600 float fval = *this;
2601 return static_cast<long>(fval);
2602 }
2603
2605#if HIP_FP8_TYPE_OCP
2606__FP8_HOST_DEVICE__ operator long long int() const {
2607#else
2608__FP8_HOST__ operator long long int() const {
2609#endif
2610 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2611 return 0;
2612 }
2613
2614 float fval = *this;
2615 return static_cast<long long>(fval);
2616 }
2617
2619#if HIP_FP8_TYPE_OCP
2620__FP8_HOST_DEVICE__ operator short int() const {
2621#else
2622__FP8_HOST__ operator short int() const {
2623#endif
2624 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2625 return 0;
2626 }
2627
2628 float fval = *this;
2629 auto llval = static_cast<long long>(fval);
2630 if (llval <= SHRT_MIN) {
2631 return SHRT_MIN;
2632 } else if (llval >= SHRT_MAX) {
2633 return SHRT_MAX;
2634 }
2635 return static_cast<short>(fval);
2636 }
2637
2639#if HIP_FP8_TYPE_OCP
2640__FP8_HOST_DEVICE__ operator signed char() const {
2641#else
2642__FP8_HOST__ operator signed char() const {
2643#endif
2644 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2645 return 0;
2646 }
2647
2648 float fval = *this;
2649 auto llval = static_cast<long long>(fval);
2650 if (llval <= SCHAR_MIN) {
2651 return SCHAR_MIN;
2652 } else if (llval >= SCHAR_MAX) {
2653 return SCHAR_MAX;
2654 }
2655 return static_cast<signed char>(fval);
2656 }
2657
2659#if HIP_FP8_TYPE_OCP
2660__FP8_HOST_DEVICE__ operator unsigned char() const {
2661#else
2662__FP8_HOST__ operator unsigned char() const {
2663#endif
2664 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2665 return 0;
2666 }
2667
2668 float fval = *this;
2669 auto llval = static_cast<long long>(fval);
2670 if (llval <= 0) {
2671 return 0;
2672 } else if (llval >= UCHAR_MAX) {
2673 return UCHAR_MAX;
2674 }
2675 return static_cast<unsigned char>(fval);
2676 }
2677
2679#if HIP_FP8_TYPE_OCP
2680__FP8_HOST_DEVICE__ operator unsigned int() const {
2681#else
2682__FP8_HOST__ operator unsigned int() const {
2683#endif
2684 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2685 return 0;
2686 }
2687
2688 float fval = *this;
2689 auto llval = static_cast<long long>(fval);
2690 if (llval <= 0) {
2691 return 0;
2692 }
2693 return static_cast<unsigned int>(fval);
2694 }
2695
2697#if HIP_FP8_TYPE_OCP
2698__FP8_HOST_DEVICE__ operator unsigned long int() const {
2699#else
2700__FP8_HOST__ operator unsigned long int() const {
2701#endif
2702 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2703 return 0;
2704 }
2705
2706 float fval = *this;
2707 auto llval = static_cast<long long>(fval);
2708 if (llval <= 0) {
2709 return 0;
2710 }
2711 return static_cast<unsigned long>(fval);
2712 }
2713
2715#if HIP_FP8_TYPE_OCP
2716__FP8_HOST_DEVICE__ operator unsigned long long int() const {
2717#else
2718__FP8_HOST__ operator unsigned long long int() const {
2719#endif
2720 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2721 return 0;
2722 }
2723
2724 float fval = *this;
2725 auto llval = static_cast<long long>(fval);
2726 if (llval <= 0) {
2727 return 0;
2728 }
2729 return static_cast<unsigned long long>(fval);
2730 }
2731
2733#if HIP_FP8_TYPE_OCP
2734__FP8_HOST_DEVICE__ operator unsigned short int() const {
2735#else
2736__FP8_HOST__ operator unsigned short int() const {
2737#endif
2738 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2739 return 0;
2740 }
2741
2742 float fval = *this;
2743 auto llval = static_cast<long long>(fval);
2744 if (llval <= 0) {
2745 return 0;
2746 }
2747 return static_cast<unsigned short>(fval);
2748 }
2749};
2750
/* Per-struct defaults used by the members below: the saturation mode and fp8 interpretation
 * handed to the __hip_cvt_* helpers, and the __we / __wm values (5 and 2 for E5M2) passed to
 * internal::cast_from_f8. */
2757 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
2758 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2;
2759 static constexpr unsigned int __we = 5;
2760 static constexpr unsigned int __wm = 2;
2761
/*! Construct from a double2: the pair is converted by __hip_cvt_double2_to_fp8x2 using the
 *  struct's default saturation and interpretation. Host+device under HIP_FP8_TYPE_OCP,
 *  host-only otherwise. */
2764#if HIP_FP8_TYPE_OCP
2765__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const double2 val)
2766#else
2767__FP8_HOST__ __hip_fp8x2_e5m2(const double2 val)
2768#endif
2769 : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
2770
/*! Construct from a float2: the pair is converted by __hip_cvt_float2_to_fp8x2 using the
 *  struct's default saturation and interpretation. */
2772#if HIP_FP8_TYPE_OCP
2773__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const float2 val)
2774#else
2775__FP8_HOST__ __hip_fp8x2_e5m2(const float2 val)
2776#endif
2777 : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
2778
/*! Construct from a __hip_bfloat162: the pair is converted by __hip_cvt_bfloat16raw2_to_fp8x2
 *  using the struct's default saturation and interpretation. */
2780#if HIP_FP8_TYPE_OCP
2781__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const __hip_bfloat162 val)
2782#else
2783__FP8_HOST__ __hip_fp8x2_e5m2(const __hip_bfloat162 val)
2784#endif
2785 : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
2786
/*! Construct from a __half2: the pair is converted by __hip_cvt_halfraw2_to_fp8x2 using the
 *  struct's default saturation and interpretation. */
2788#if HIP_FP8_TYPE_OCP
2789__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const __half2 val)
2790#else
2791__FP8_HOST__ __hip_fp8x2_e5m2(const __half2 val)
2792#endif
2793 : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
2794
/*! Defaulted default constructor (host+device under HIP_FP8_TYPE_OCP, host-only otherwise). */
2796#if HIP_FP8_TYPE_OCP
2797__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2() = default;
2798#else
2799__FP8_HOST__ __hip_fp8x2_e5m2() = default;
2800#endif
2801
/*! Convert the packed pair to __half2 via __hip_cvt_fp8x2_to_halfraw2 with the struct's
 *  default interpretation. */
2803#if HIP_FP8_TYPE_OCP
2804__FP8_HOST_DEVICE__ operator __half2() const {
2805#else
2806__FP8_HOST__ operator __half2() const {
2807#endif
2808 return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret));
2809 }
2810
/*! Convert the packed pair to float2.
 *  With HIP_FP8_CVT_FAST_PATH the whole pair goes through internal::cast_to_f32x2_from_f8x2;
 *  otherwise each byte is decoded by internal::cast_from_f8<float, false>: the low byte
 *  (__x & 0xFF) becomes .x and the high byte (__x >> 8) becomes .y. The final argument is true
 *  when the default saturation is __HIP_SATFINITE. */
2812#if HIP_FP8_TYPE_OCP
2813__FP8_HOST_DEVICE__ operator float2() const {
2814#else
2815__FP8_HOST__ operator float2() const {
2816#endif
2817#if HIP_FP8_CVT_FAST_PATH
2818 return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
2819#else
2820 return float2(internal::cast_from_f8<float, false>(static_cast<__hip_fp8_storage_t>(__x & 0xFF), __wm, __we, __default_saturation == __HIP_SATFINITE),
2821 internal::cast_from_f8<float, false>(static_cast<__hip_fp8_storage_t>(__x >> 8), __wm, __we, __default_saturation == __HIP_SATFINITE));
2822#endif
2823 }
2824};
2825
/* Per-struct defaults used by the members below: the saturation mode and fp8 interpretation
 * handed to the __hip_cvt_* helpers, and the __we / __wm values (5 and 2 for E5M2) passed to
 * internal::cast_from_f8. */
2832 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
2833 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2;
2834 static constexpr unsigned int __we = 5;
2835 static constexpr unsigned int __wm = 2;
2836
2838#if HIP_FP8_TYPE_OCP
2839__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const double4 val)
2840#else
2841__FP8_HOST__ __hip_fp8x4_e5m2(const double4 val)
2842#endif
2843 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
2844 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
2845 val.x, __default_saturation, __default_interpret)) |
2846 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
2847 val.y, __default_saturation, __default_interpret))
2848 << 8 |
2849 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
2850 val.z, __default_saturation, __default_interpret))
2851 << 16 |
2852 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
2853 val.w, __default_saturation, __default_interpret))
2854 << 24))) {}
2855
2857#if HIP_FP8_TYPE_OCP
2858__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const float4 val)
2859#else
2860__FP8_HOST__ __hip_fp8x4_e5m2(const float4 val)
2861#endif
2862 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
2863 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
2864 val.x, __default_saturation, __default_interpret)) |
2865 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
2866 val.y, __default_saturation, __default_interpret))
2867 << 8 |
2868 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
2869 val.z, __default_saturation, __default_interpret))
2870 << 16 |
2871 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
2872 val.w, __default_saturation, __default_interpret))
2873 << 24))) {}
2874
2876#if HIP_FP8_TYPE_OCP
2877__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const __hip_bfloat162 low, const __hip_bfloat162 high)
2878#else
2879__FP8_HOST__ __hip_fp8x4_e5m2(const __hip_bfloat162 low, const __hip_bfloat162 high)
2880#endif
2881 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast<unsigned int>(
2882 reinterpret_cast<unsigned short>(
2883 __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) |
2884 reinterpret_cast<unsigned short>(
2885 __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret))
2886 << 16))) {}
2887
2889#if HIP_FP8_TYPE_OCP
2890__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const __half2 low, const __half2 high)
2891#else
2892__FP8_HOST__ __hip_fp8x4_e5m2(const __half2 low, const __half2 high)
2893#endif
2894 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
2895 static_cast<unsigned int>(reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
2896 high, __default_saturation, __default_interpret)) |
2897 reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
2898 low, __default_saturation, __default_interpret))
2899 << 16))) {}
2900
2901 /* Defaulted default constructor for fp8x4 e5m2 (host+device under HIP_FP8_TYPE_OCP, host-only otherwise). */
2902#if HIP_FP8_TYPE_OCP
2903__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2() = default;
2904#else
2905__FP8_HOST__ __hip_fp8x4_e5m2() = default;
2906#endif
2907
2907
2909#if HIP_FP8_TYPE_OCP
2910__FP8_HOST_DEVICE__ operator float4() const {
2911#else
2912__FP8_HOST__ operator float4() const {
2913#endif
2914 auto x = __x; // bypass const
2915 auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E
2916 auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1);
2917#if HIP_FP8_CVT_FAST_PATH
2918 float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
2919 float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
2920#else
2921 float2 high = float2(internal::cast_from_f8<float, false>(
2922 static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we, __default_saturation == __HIP_SATFINITE),
2923 internal::cast_from_f8<float, false>(
2924 static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we, __default_saturation == __HIP_SATFINITE));
2925 float2 low = float2(internal::cast_from_f8<float, false>(
2926 static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we, __default_saturation == __HIP_SATFINITE),
2927 internal::cast_from_f8<float, false>(
2928 static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we, __default_saturation == __HIP_SATFINITE));
2929#endif
2930 return float4(low.x, low.y, high.x, high.y);
2931 }
2932};
2933#endif // _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
hip_bf16.h provides struct for __hip_bfloat16 types
__hip_saturation_t
Describes saturation behavior.
Definition amd_hip_fp8.h:96
@ __HIP_SATFINITE
Definition amd_hip_fp8.h:98
@ __HIP_NOSAT
Definition amd_hip_fp8.h:97
__FP8_HOST_DEVICE_STATIC__ __half2_raw __hip_cvt_fp8x2_to_halfraw2(const __hip_fp8x2_storage_t x, const __hip_fp8_interpretation_t type)
convert __hip_fp8x2_storage_t to __half2_raw
Definition amd_hip_fp8.h:745
__hip_fp8_interpretation_t
Describes FP8 interpretation.
Definition amd_hip_fp8.h:86
@ __HIP_E4M3_FNUZ
Definition amd_hip_fp8.h:89
@ __HIP_E5M2
Definition amd_hip_fp8.h:88
@ __HIP_E4M3
Definition amd_hip_fp8.h:87
@ __HIP_E5M2_FNUZ
Definition amd_hip_fp8.h:90
__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_double2_to_fp8x2(const double2 d2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type)
convert double2 to __hip_fp8x2_storage_t
Definition amd_hip_fp8.h:679
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_double_to_fp8(const double d, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type)
convert double to __hip_fp8_storage_t
Definition amd_hip_fp8.h:657
__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_halfraw2_to_fp8x2(const __half2_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type)
convert __half2_raw to __hip_fp8x2_storage_t
Definition amd_hip_fp8.h:774
unsigned short int __hip_fp8x2_storage_t
type to store two fp8 numbers
Definition amd_hip_fp8.h:112
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_halfraw_to_fp8(const __half_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type)
convert __half_raw to __hip_fp8_storage_t
Definition amd_hip_fp8.h:761
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_float_to_fp8(const float f, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type)
convert float to __hip_fp8_storage_t
Definition amd_hip_fp8.h:611
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_bfloat16raw_to_fp8(const __hip_bfloat16_raw hr, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type)
convert __hip_bfloat16_raw to __hip_fp8_storage_t
Definition amd_hip_fp8.h:695
unsigned int __hip_fp8x4_storage_t
type to store four fp8 numbers
Definition amd_hip_fp8.h:119
__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_bfloat16raw2_to_fp8x2(const __hip_bfloat162_raw hr, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type)
convert __hip_bfloat162_raw to __hip_fp8x2_storage_t
Definition amd_hip_fp8.h:710
unsigned char __hip_fp8_storage_t
type to store single fp8 number
Definition amd_hip_fp8.h:105
__FP8_HOST_DEVICE_STATIC__ __half_raw __hip_cvt_fp8_to_halfraw(const __hip_fp8_storage_t x, const __hip_fp8_interpretation_t type)
convert __hip_fp8_storage_t to __half_raw
Definition amd_hip_fp8.h:724
__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_float2_to_fp8x2(const float2 f2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type)
convert float2 to __hip_fp8x2_storage_t
Definition amd_hip_fp8.h:638
struct representing single fp8 number with e4m3 interpretation
Definition amd_hip_fp8.h:783
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned int val)
Definition amd_hip_fp8.h:831
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz()=default
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const double f)
Definition amd_hip_fp8.h:849
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __half f)
Definition amd_hip_fp8.h:874
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const short int val)
Definition amd_hip_fp8.h:813
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const float f)
Definition amd_hip_fp8.h:857
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __hip_bfloat16 f)
Definition amd_hip_fp8.h:865
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const long int val)
Definition amd_hip_fp8.h:795
static constexpr __hip_saturation_t __default_saturation
default saturation behavior for this type (the "raw storage of fp8 number" text belongs to the __x member)
Definition amd_hip_fp8.h:785
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned long int val)
Definition amd_hip_fp8.h:822
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned short int val)
Definition amd_hip_fp8.h:840
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const int val)
Definition amd_hip_fp8.h:804
struct representing two fp8 numbers with e4m3 interpretation
Definition amd_hip_fp8.h:1138
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const float2 val)
Definition amd_hip_fp8.h:1155
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const double2 val)
Definition amd_hip_fp8.h:1147
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __half2 val)
Definition amd_hip_fp8.h:1171
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz()=default
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __hip_bfloat162 val)
Definition amd_hip_fp8.h:1163
struct representing four fp8 numbers with e4m3 interpretation
Definition amd_hip_fp8.h:1214
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
Definition amd_hip_fp8.h:1261
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz()=default
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const double4 val)
Definition amd_hip_fp8.h:1223
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __half2 low, const __half2 high)
Definition amd_hip_fp8.h:1274
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const float4 val)
Definition amd_hip_fp8.h:1242
struct representing one fp8 number with e5m2 interpretation
Definition amd_hip_fp8.h:1322
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned short int val)
Definition amd_hip_fp8.h:1380
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz()=default
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __hip_bfloat16 f)
Definition amd_hip_fp8.h:1405
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const long int val)
Definition amd_hip_fp8.h:1335
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned int val)
Definition amd_hip_fp8.h:1371
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __half f)
Definition amd_hip_fp8.h:1414
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const int val)
Definition amd_hip_fp8.h:1344
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const double f)
Definition amd_hip_fp8.h:1389
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const short int val)
Definition amd_hip_fp8.h:1353
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const float f)
Definition amd_hip_fp8.h:1397
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned long int val)
Definition amd_hip_fp8.h:1362
struct representing two fp8 numbers with e5m2 interpretation
Definition amd_hip_fp8.h:1678
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const float2 val)
Definition amd_hip_fp8.h:1695
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __half2 val)
Definition amd_hip_fp8.h:1711
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __hip_bfloat162 val)
Definition amd_hip_fp8.h:1703
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const double2 val)
Definition amd_hip_fp8.h:1687
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz()=default
struct representing four fp8 numbers with e5m2 interpretation
Definition amd_hip_fp8.h:1754
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
Definition amd_hip_fp8.h:1801
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const float4 val)
Definition amd_hip_fp8.h:1782
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __half2 low, const __half2 high)
Definition amd_hip_fp8.h:1814
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const double4 val)
Definition amd_hip_fp8.h:1763
struct representing ocp fp8 numbers with e4m3 interpretation
Definition amd_hip_fp8.h:1862
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const __hip_bfloat16 f)
Definition amd_hip_fp8.h:1940
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const long int val)
Definition amd_hip_fp8.h:1874
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned short int val)
Definition amd_hip_fp8.h:1915
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const float f)
Definition amd_hip_fp8.h:1932
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const short int val)
Definition amd_hip_fp8.h:1891
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const __half f)
Definition amd_hip_fp8.h:1949
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned long int val)
Definition amd_hip_fp8.h:1897
__FP8_HOST_DEVICE__ __hip_fp8_e4m3()=default
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const int val)
Definition amd_hip_fp8.h:1883
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned int val)
Definition amd_hip_fp8.h:1906
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const double f)
Definition amd_hip_fp8.h:1924
struct representing two ocp fp8 numbers with e4m3 interpretation
Definition amd_hip_fp8.h:2214
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const float2 val)
Definition amd_hip_fp8.h:2232
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3()=default
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const __half2 val)
Definition amd_hip_fp8.h:2248
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const __hip_bfloat162 val)
Definition amd_hip_fp8.h:2240
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const double2 val)
Definition amd_hip_fp8.h:2224
struct representing four ocp fp8 numbers with e4m3 interpretation
Definition amd_hip_fp8.h:2289
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const double4 val)
Definition amd_hip_fp8.h:2299
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const float4 val)
Definition amd_hip_fp8.h:2318
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3()=default
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const __half2 low, const __half2 high)
Definition amd_hip_fp8.h:2350
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const __hip_bfloat162 low, const __hip_bfloat162 high)
Definition amd_hip_fp8.h:2337
struct representing ocp fp8 numbers with e5m2 interpretation
Definition amd_hip_fp8.h:2398
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const int val)
Definition amd_hip_fp8.h:2421
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const short int val)
Definition amd_hip_fp8.h:2430
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned int val)
Definition amd_hip_fp8.h:2448
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const float f)
Definition amd_hip_fp8.h:2474
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned short int val)
Definition amd_hip_fp8.h:2457
__FP8_HOST_DEVICE__ __hip_fp8_e5m2()=default
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned long int val)
Definition amd_hip_fp8.h:2439
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const long int val)
Definition amd_hip_fp8.h:2412
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const double f)
Definition amd_hip_fp8.h:2466
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const __hip_bfloat16 f)
Definition amd_hip_fp8.h:2482
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const __half f)
Definition amd_hip_fp8.h:2491
struct representing two ocp fp8 numbers with e5m2 interpretation
Definition amd_hip_fp8.h:2755
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const __half2 val)
Definition amd_hip_fp8.h:2789
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const double2 val)
Definition amd_hip_fp8.h:2765
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2()=default
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const float2 val)
Definition amd_hip_fp8.h:2773
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const __hip_bfloat162 val)
Definition amd_hip_fp8.h:2781
struct representing four ocp fp8 numbers with e5m2 interpretation
Definition amd_hip_fp8.h:2830
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const __hip_bfloat162 low, const __hip_bfloat162 high)
Definition amd_hip_fp8.h:2877
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const double4 val)
Definition amd_hip_fp8.h:2839
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const __half2 low, const __half2 high)
Definition amd_hip_fp8.h:2890
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const float4 val)
Definition amd_hip_fp8.h:2858
Definition amd_hip_vector_types.h:2035
Definition amd_hip_vector_types.h:2042
Definition amd_hip_vector_types.h:2072
Definition amd_hip_vector_types.h:2079
Definition hip_fp16_gcc.h:7
Definition hip_fp16_gcc.h:11