HIP: Heterogeneous-computing Interface for Portability
amd_hip_fp8.h
1 
30 #ifndef _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
31 #define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
32 
33 #if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) && __HIP_DEVICE_COMPILE__
34 #define HIP_FP8_CVT_FAST_PATH 1
35 #else
36 #define HIP_FP8_CVT_FAST_PATH 0
37 #endif
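// [Editorial note, not part of the original header] The fast path above is only enabled when
// compiling device code for gfx940/gfx941/gfx942, which expose hardware FP8/BF8 conversion
// through the __builtin_amdgcn_cvt_* builtins used later in this file; every other target
// falls back to the generic bit-manipulation routines internal::cast_to_f8 / cast_from_f8.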
38 
39 #if !defined(__HIPCC_RTC__)
40 #include <hip/amd_detail/amd_hip_common.h>
41 #include <climits>
42 
43 #include "host_defines.h" // __hip_internal::
44 #include "amd_hip_vector_types.h" // float2 etc
45 #include "amd_hip_fp16.h" // __half_raw
46 #include "amd_hip_bf16.h" // bf16
47 #include "math_fwd.h" // ocml device functions
48 #endif // !defined(__HIPCC_RTC__)
49 
50 #if defined(__HIPCC_RTC__)
51 #define __FP8_HOST_DEVICE__ __device__
52 #define __FP8_HOST_DEVICE_STATIC__ __FP8_HOST_DEVICE__ static
53 #else
54 #define __FP8_HOST_DEVICE__ __host__ __device__
55 #define __FP8_HOST_DEVICE_STATIC__ __FP8_HOST_DEVICE__ static inline
56 #endif // __HIPCC_RTC__
57 
58 #if !defined(__HIPCC_RTC__)
59 static_assert(CHAR_BIT == 8, "byte size should be of 8 bits");
60 #endif
61 static_assert(sizeof(unsigned char) == 1);
62 static_assert(sizeof(unsigned short int) == 2);
63 static_assert(sizeof(unsigned int) == 4);
64 
68 enum __hip_fp8_interpretation_t {
69  __HIP_E4M3_FNUZ = 0, /**< Standard FP8 */
70  __HIP_E5M2_FNUZ = 1, /**< BF8 */
71 };
72 
76 enum __hip_saturation_t {
77  __HIP_NOSAT = 0, /**< No saturation */
78  __HIP_SATFINITE = 1, /**< Saturate to finite */
79 };
80 
85 typedef unsigned char __hip_fp8_storage_t;
86 
87 
92 typedef unsigned short int __hip_fp8x2_storage_t;
93 
94 
99 typedef unsigned int __hip_fp8x4_storage_t;
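// [Editorial note, not part of the original header] The packed storage types are little-endian
// with respect to element order: element 0 of an fp8x2/fp8x4 value lives in the least
// significant byte, as the conversion helpers below assume.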
100 
101 namespace internal {
102 // The conversion function is from rocblas
103 // https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39
104 // This has been modified to add double types conversion as well
105 template <typename T, bool negative_zero_nan>
106 __FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t cast_to_f8(T _x, int wm, int we, bool clip = false,
107  bool stoch = false,
108  unsigned int rng = 0) {
109  constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
110  constexpr bool is_float = __hip_internal::is_same<T, float>::value;
111  constexpr bool is_double = __hip_internal::is_same<T, double>::value;
112  static_assert(is_half || is_float || is_double, "Only half, float and double can be cast to f8");
113 
114  const int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
115  unsigned long long x;
116 
117  if (sizeof(T) == 8)
118  x = reinterpret_cast<unsigned long long&>(_x);
119  else if (sizeof(T) == 4)
120  x = reinterpret_cast<unsigned int&>(_x);
121  else
122  x = reinterpret_cast<unsigned short int&>(_x);
123 
124 
125  unsigned long long head, mantissa;
126  int exponent, bias;
127  unsigned int sign;
128 
129  if (sizeof(T) == 8) {
130  head = x & 0xFFF0000000000000ull;
131  mantissa = x & 0xFFFFFFFFFFFFFull;
132  exponent = (head >> 52) & 0x7FF;
133  sign = head >> 63;
134  bias = 1023;
135  } else if (sizeof(T) == 4) {
136  head = x & 0xFF800000;
137  mantissa = x & 0x7FFFFF;
138  exponent = (head >> 23) & 0xFF;
139  sign = head >> 31;
140  bias = 127;
141  } else {
142  head = x & 0xFC00;
143  mantissa = x & 0x3FF;
144  exponent = (head >> 10) & 0x1F;
145  sign = head >> 15;
146  bias = 15;
147  }
148 
149  unsigned int signed_inf = (sign << 7) + (((1 << we) - 1) << wm);
150 
151  // Deal with inf and NaNs
152  if (negative_zero_nan) {
153  if (sizeof(T) == 8) {
154  if ((x & 0x7FF0000000000000ull) == 0x7FF0000000000000ull) return 0x80;
155  } else if (sizeof(T) == 4) {
156  if ((x & 0x7F800000) == 0x7F800000) return 0x80;
157  } else {
158  if ((x & 0x7C00) == 0x7C00) return 0x80;
159  }
160  } else {
161  if (sizeof(T) == 8) {
162  if ((x & 0x7FF0000000000000ull) == 0x7FF0000000000000ull)
163  return signed_inf + (mantissa != 0 ? 1 : 0);
164  } else if (sizeof(T) == 4) {
165  if ((x & 0x7F800000) == 0x7F800000) return signed_inf + (mantissa != 0 ? 1 : 0);
166  } else {
167  if ((x & 0x7C00) == 0x7C00) return signed_inf + (mantissa != 0 ? 1 : 0);
168  }
169  }
170 
171  if (x == 0) {
172  return 0;
173  }
174 
175  // First check whether the value is normal or denormal, since that decides the implicit 1.
176  // Then adjust the exponent to align with the F8 exponent, shifting the mantissa along the
177  // way. For stochastic rounding, add rng to the mantissa and truncate; for RNE nothing is
178  // added. Finally, check whether the rounding carried and, if so, adjust the exponent and
179  // mantissa again.
180 
181  // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
182  const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0);
183  const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
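 // [Editorial note] With negative_zero_nan (fnuz) semantics this gives, for example,
 // f8_bias = 8 and f8_denormal_act_exponent = -7 for e4m3 (we = 4), and f8_bias = 16
 // and f8_denormal_act_exponent = -15 for e5m2 (we = 5).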
184  // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
185  // f8_exponent is the converted f8 exponent with bias encoding
186  // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
187  // the difference needs to be adjusted and mantissa shifted
188  int act_exponent, f8_exponent, exponent_diff;
189 
190  if (exponent == 0) { // fp32/fp16 is in denormal.
191  /* fp32 denormals are below 2^-127, so they are usually not a concern here; the case that
192 matters is fp16. An fp16 denormal usually maps to an f8 denormal, but there are exceptions:
193 fp16 has exponent bias 15 while bf8 with NANOO has exponent bias 16, so some fp16 denormals
194 are bf8 (NANOO) normals - the smallest bf8 (NANOO) normal is 2^-15. fp16 numbers where
195 exponent==0 (actual exponent -14) and the highest mantissa bit is 1 are bf8 (NANOO) normal.
196 In that case the fp16 mantissa should be shifted left by 1 */
197  act_exponent = exponent - bias + 1;
198  exponent_diff = f8_denormal_act_exponent -
199  act_exponent; // actual exponent is exponent-bias+1 as it is denormal
200  } else { // fp32/fp16 is normal with implicit 1
201  act_exponent = exponent - bias;
202  if (act_exponent <= f8_denormal_act_exponent) {
203  /* This is the case where fp32/fp16 is normal but falls in the f8 denormal range.
204 For example, in fp8 nanoo mode the denormal exponent is -7, but if the fp32/fp16
205 actual exponent is -7, the value is actually larger due to the implicit 1, so the
206 exponent needs to be adjusted to -6 and the mantissa shifted right by 1.
207 So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
208  exponent_diff = f8_denormal_act_exponent - act_exponent;
209  } else { // both fp32/fp16 and f8 are in normal range
210  exponent_diff = 0; // exponent_diff=0 does not mean there is no difference for this case,
211  // act_exponent could be larger. Just that it does not need shift mantissa
212  }
213  mantissa += (1ull << mfmt); // Add the implicit 1 into mantissa
214  }
215 
216  bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
217  (1ull << (mfmt - wm + exponent_diff - 1));
218  /* This part is a bit tricky. Whether the value is a tie must be judged before we shift
219 right, because shifting right can strip off residual bits and make something that is not a
220 midpoint look like one. For example, the fp16 number 0x1002 (0 00100 0000000010) is larger
221 than the midpoint, but after a right shift by 4 bits it would look like the midpoint.
222 */
223 
224  if (exponent_diff > 0)
225  mantissa >>= exponent_diff;
226  else if (exponent_diff == -1)
227  mantissa <<= -exponent_diff;
228  bool implicit_one = mantissa & (1ull << mfmt);
229  // if there is no implicit 1, the f8 value is denormal and we need to adjust to the denormal exponent
230  f8_exponent =
231  (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
232 
233  // Now we have the exponent and mantissa adjusted
234  unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
235  bool odd =
236  mantissa & (1ull << (mfmt - wm)); // if the least significant bit that is not truncated is 1
237  mantissa +=
238  (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
239 
240  // Now we deal with overflow
241  if (f8_exponent == 0) {
242  if ((1ull << mfmt) & mantissa) {
243  f8_exponent = 1; // denormal overflow to become normal, promote exponent
244  }
245  } else {
246  if ((1ull << (mfmt + 1)) & mantissa) {
247  mantissa >>= 1;
248  f8_exponent++;
249  }
250  }
251 
252  mantissa >>= (mfmt - wm);
253 
254  // above range: quantize to maximum possible float of the same sign
255  const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2);
256  if (f8_exponent > max_exp) {
257  if (clip) {
258  mantissa = (1 << wm) - 1;
259  f8_exponent = max_exp;
260  } else {
261  return signed_inf;
262  }
263  }
264 
265  if (f8_exponent == 0 && mantissa == 0) return negative_zero_nan ? 0 : (sign << 7);
266  mantissa &= (1 << wm) - 1;
267  return (sign << 7) | (f8_exponent << wm) | mantissa;
268 }
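// [Editorial example, not part of the original header] A worked encoding with the e4m3 fnuz
// parameters used elsewhere in this file (wm = 3, we = 4, negative_zero_nan = true, so
// f8_bias = 8):
//   cast_to_f8<float, true>(1.0f, /*wm=*/3, /*we=*/4) -> 0x40  (sign 0, biased exponent 0b1000, mantissa 0b000)
//   cast_to_f8<float, true>(1.5f, /*wm=*/3, /*we=*/4) -> 0x44  (mantissa 0b100 encodes the fractional 0.5)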
269 
270 // The conversion function is from rocblas
271 // https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220
272 // This has been modified to handle double types as well
273 template <typename T, bool negative_zero_nan>
274 __FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, int we) {
275  constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
276  constexpr bool is_float = __hip_internal::is_same<T, float>::value;
277  constexpr bool is_double = __hip_internal::is_same<T, double>::value;
278  static_assert(is_half || is_float || is_double, "only half, float and double are supported");
279 
280  constexpr int weo = is_half ? 5 : (is_float ? 8 : 11);
281  constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52);
282 
283  T fInf, fNegInf, fNaN, fNeg0;
284  if (is_half) {
285  const unsigned short int ihInf = 0x7C00;
286  const unsigned short int ihNegInf = 0xFC00;
287  const unsigned short int ihNaN = 0x7C01;
288  const unsigned short int ihNeg0 = 0x8000;
289  fInf = reinterpret_cast<const _Float16&>(ihInf);
290  fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
291  fNaN = reinterpret_cast<const _Float16&>(ihNaN);
292  fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
293  } else if (is_float) {
294  const unsigned int ifInf = 0x7F800000;
295  const unsigned int ifNegInf = 0xFF800000;
296  const unsigned int ifNaN = 0x7F800001;
297  const unsigned int ifNeg0 = 0x80000000;
298  fInf = reinterpret_cast<const float&>(ifInf);
299  fNegInf = reinterpret_cast<const float&>(ifNegInf);
300  fNaN = reinterpret_cast<const float&>(ifNaN);
301  fNeg0 = reinterpret_cast<const float&>(ifNeg0);
302  } else if (is_double) {
303  const unsigned long long ifInf = 0x7FF0000000000000ull;
304  const unsigned long long ifNegInf = 0xFFF0000000000000ull;
305  const unsigned long long ifNaN = 0x7FF0000000000001ull;
306  const unsigned long long ifNeg0 = 0x8000000000000000ull;
307  fInf = reinterpret_cast<const double&>(ifInf);
308  fNegInf = reinterpret_cast<const double&>(ifNegInf);
309  fNaN = reinterpret_cast<const double&>(ifNaN);
310  fNeg0 = reinterpret_cast<const double&>(ifNeg0);
311  }
312 
313  if (x == 0) {
314  return 0;
315  }
316 
317  unsigned long long sign = x >> 7;
318  unsigned long long mantissa = x & ((1 << wm) - 1);
319  int exponent = (x & 0x7F) >> wm;
320  if (negative_zero_nan) {
321  if (x == 0x80) return fNaN;
322  } else {
323  if (x == 0x80) return fNeg0;
324  if (exponent == ((1 << we) - 1)) return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
325  }
326 
327  typename __hip_internal::conditional<
328  sizeof(T) == 2, unsigned short int,
329  typename __hip_internal::conditional<sizeof(T) == 4, unsigned int,
330  unsigned long long>::type>::type retval;
331 
332  if (we == 5 && is_half && !negative_zero_nan) {
333  retval = x << 8;
334  return reinterpret_cast<const T&>(retval);
335  }
336 
337  const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0);
338 
339  // subnormal input
340  if (exponent == 0) {
341 #if __HIP_DEVICE_COMPILE__
342  // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
343  int sh = 1 + __clz(mantissa) - (32 - wm);
344 #else
345  int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
346 #endif
347  mantissa <<= sh;
348  exponent += 1 - sh;
349  mantissa &= ((1ull << wm) - 1);
350  }
351  exponent += exp_low_cutoff - 1;
352  mantissa <<= wmo - wm;
353 
354  // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
355  if (exponent <= 0) {
356  mantissa |= 1 << wmo;
357  mantissa >>= 1 - exponent;
358  exponent = 0;
359  }
360 
361  if (sizeof(T) == 2)
362  retval = (sign << 15) | (exponent << 10) | mantissa;
363  else if (sizeof(T) == 4)
364  retval = (sign << 31) | (exponent << 23) | mantissa;
365  else
366  retval = (sign << 63) | (static_cast<unsigned long long>(exponent) << 52) | mantissa;
367  return reinterpret_cast<const T&>(retval);
368 }
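// [Editorial example, not part of the original header] Decoding with the same e4m3 fnuz
// parameters (wm = 3, we = 4): cast_from_f8<float, true>(0x40, 3, 4) returns 1.0f, and the
// smallest denormal cast_from_f8<float, true>(0x01, 3, 4) returns 2^-10 = 0.0009765625f
// (denormal exponent -7 times the mantissa step 2^-3).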
369 
370 #if HIP_FP8_CVT_FAST_PATH
371 // The conversion function is from rocblas
372 // https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
373 template <bool stochastic_rounding = false>
374 static __device__ __hip_fp8_storage_t cast_to_f8_from_f32(float v, bool saturate,
375  __hip_fp8_interpretation_t interpret,
376  unsigned int rng = 0) {
377  __hip_fp8_storage_t i8data;
378  union {
379  float fval;
380  unsigned int i32val;
381  unsigned char i8val[4]; // NOTE: not endian independent
382  } val;
383 
384  unsigned int ival = 0;
385  val.fval = v;
386 
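 // [Editorial note] 240.0 is the largest finite e4m3 fnuz value (0x7F) and 57344.0 is the
 // largest finite e5m2 fnuz value, so the fmed3 clamps below implement finite saturation
 // (__HIP_SATFINITE).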
387  if (saturate) {
388  if (interpret == __HIP_E4M3_FNUZ) {
389  if ((val.i32val & 0x7F800000) != 0x7F800000) {
390  val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
391  }
392  } else {
393  if ((val.i32val & 0x7F800000) != 0x7F800000) {
394  val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
395  }
396  }
397  }
398 
399  if (stochastic_rounding) {
400  ival = interpret == __HIP_E4M3_FNUZ
401  ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
402  : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
403  val.i32val = ival;
404  i8data = val.i8val[0]; // little endian
405  } else { // RNE CVT
406  ival = interpret == __HIP_E4M3_FNUZ
407  ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
408  : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
409  val.i32val = ival;
410  i8data = val.i8val[0];
411  }
412  return i8data;
413 }
414 
415 static __device__ __hip_fp8x2_storage_t
416 cast_to_f8x2_from_f32x2(float2 v, bool saturate, __hip_fp8_interpretation_t interpret) {
417  union {
418  static_assert(sizeof(float2) == sizeof(unsigned int[2]));
419  static_assert(sizeof(float2) == sizeof(unsigned short[4]));
420  float2 fval;
421  unsigned int i32val[2];
422  unsigned short i16val[4];
423  } f2val;
424 
425  f2val.fval = v;
426 
427  if (saturate) {
428  if ((f2val.i32val[0] & 0x7F800000) != 0x7F800000) {
429  f2val.fval.x = __builtin_amdgcn_fmed3f(f2val.fval.x, interpret == __HIP_E4M3_FNUZ ? 240.0f : 57344.0f, interpret == __HIP_E4M3_FNUZ ? -240.0f : -57344.0f);
430  }
431  if ((f2val.i32val[1] & 0x7F800000) != 0x7F800000) {
432  f2val.fval.y = __builtin_amdgcn_fmed3f(f2val.fval.y, interpret == __HIP_E4M3_FNUZ ? 240.0f : 57344.0f, interpret == __HIP_E4M3_FNUZ ? -240.0f : -57344.0f);
433  }
434  }
435 
436  f2val.i32val[0] = interpret == __HIP_E4M3_FNUZ
437  ? __builtin_amdgcn_cvt_pk_fp8_f32(f2val.fval.x, f2val.fval.y, 0, false)
438  : __builtin_amdgcn_cvt_pk_bf8_f32(f2val.fval.x, f2val.fval.y, 0, false);
439 
440  return static_cast<__hip_fp8x2_storage_t>(f2val.i16val[0]);
441 }
442 
443 static __device__ float cast_to_f32_from_f8(__hip_fp8_storage_t v,
444  __hip_fp8_interpretation_t interpret) {
445  union {
446  unsigned int i32val;
447  unsigned char i8val[4];
448  } val;
449  val.i8val[0] = v;
450 
451  float fval = interpret == __HIP_E4M3_FNUZ ? __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0)
452  : __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
453  return fval;
454 }
455 
456 static __device__ float2 cast_to_f32x2_from_f8x2(__hip_fp8x2_storage_t v,
457  __hip_fp8_interpretation_t interpret) {
458  union {
459  unsigned int i32val;
460  unsigned short i16val[2];
461  } val;
462  val.i16val[0] = v;
463 
464  auto f2 = interpret == __HIP_E4M3_FNUZ ? __builtin_amdgcn_cvt_pk_f32_fp8(val.i32val, false)
465  : __builtin_amdgcn_cvt_pk_f32_bf8(val.i32val, false);
466  return float2{f2[0], f2[1]};
467 }
468 #endif // HIP_FP8_CVT_FAST_PATH
469 
470 /* For fp8 fnuz types, finite and NaN values are supported. Zero is unsigned.
471 Inf is not supported. This gives us one additional number to represent.
472 NaN is represented by 1-0000-000 or 1-00000-00 */
473 __FP8_HOST_DEVICE_STATIC__ bool hip_fp8_fnuz_is_nan(__hip_fp8_storage_t a) {
474  return static_cast<unsigned char>(a) == 0x80;
475 }
476 } // namespace internal
477 
486 __FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_float_to_fp8(
487  const float f, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
488 #if HIP_FP8_CVT_FAST_PATH
489  return internal::cast_to_f8_from_f32<false>(f, sat == __HIP_SATFINITE, type);
490 #else // HIP_FP8_CVT_FAST_PATH
491  int we = type == __HIP_E4M3_FNUZ ? 4 : 5;
492  int wm = type == __HIP_E4M3_FNUZ ? 3 : 2;
493  return internal::cast_to_f8<float, true>(f, wm, we, sat == __HIP_SATFINITE);
494 #endif // HIP_FP8_CVT_FAST_PATH
495 }
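// [Editorial example, not part of the original header] Converting a float with saturation to
// e4m3 fnuz storage:
//   __hip_fp8_storage_t s = __hip_cvt_float_to_fp8(1.5f, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
//   // s == 0x44; out-of-range inputs such as 1.0e6f come back as 0x7F (240.0).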
496 
505 __FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_float2_to_fp8x2(
506  const float2 f2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
507 #if HIP_FP8_CVT_FAST_PATH
508  return internal::cast_to_f8x2_from_f32x2(f2, sat == __HIP_SATFINITE, type);
509 #else
510  return static_cast<__hip_fp8x2_storage_t>(
511  static_cast<unsigned short int>(__hip_cvt_float_to_fp8(f2.y, sat, type)) << 8 |
512  static_cast<unsigned short int>(__hip_cvt_float_to_fp8(f2.x, sat, type)));
513 #endif
514 }
515 
524 __FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_double_to_fp8(
525  const double d, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
526  int we = type == __HIP_E4M3_FNUZ ? 4 : 5;
527  int wm = type == __HIP_E4M3_FNUZ ? 3 : 2;
528  return internal::cast_to_f8<double, true>(d, wm, we, sat == __HIP_SATFINITE);
529 }
530 
539 __FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_double2_to_fp8x2(
540  const double2 d2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
541  return static_cast<__hip_fp8x2_storage_t>(
542  static_cast<unsigned short int>(__hip_cvt_double_to_fp8(d2.y, sat, type)) << 8 |
543  static_cast<unsigned short int>(__hip_cvt_double_to_fp8(d2.x, sat, type)));
544 }
545 
554 __FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t
555 __hip_cvt_bfloat16raw_to_fp8(const __hip_bfloat16_raw hr, const __hip_saturation_t sat,
556  const __hip_fp8_interpretation_t type) {
557  float fval = __hip_bfloat16(hr);
558  return __hip_cvt_float_to_fp8(fval, sat, type);
559 }
560 
569 __FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t
570 __hip_cvt_bfloat16raw2_to_fp8x2(const __hip_bfloat162_raw hr, const __hip_saturation_t sat,
571  const __hip_fp8_interpretation_t type) {
572  float2 f2 = __hip_bfloat162(hr);
573  return __hip_cvt_float2_to_fp8x2(f2, sat, type);
574 }
575 
583 __FP8_HOST_DEVICE_STATIC__ __half_raw
584 __hip_cvt_fp8_to_halfraw(const __hip_fp8_storage_t x, const __hip_fp8_interpretation_t type) {
585  unsigned int we = type == __HIP_E4M3_FNUZ ? 4 : 5;
586  unsigned int wm = type == __HIP_E4M3_FNUZ ? 3 : 2;
587  return __half_raw{internal::cast_from_f8<_Float16, true>(x, wm, we)};
588 }
589 
597 __FP8_HOST_DEVICE_STATIC__ __half2_raw
598 __hip_cvt_fp8x2_to_halfraw2(const __hip_fp8x2_storage_t x, const __hip_fp8_interpretation_t type) {
599  __half2 ret(static_cast<__half>(
600  __hip_cvt_fp8_to_halfraw(static_cast<__hip_fp8_storage_t>(x & 0xFF), type)),
601  static_cast<__half>(
602  __hip_cvt_fp8_to_halfraw(static_cast<__hip_fp8_storage_t>(x >> 8), type)));
603  return static_cast<__half2_raw>(ret);
604 }
605 
614 __FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_halfraw_to_fp8(
615  const __half_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
616  return __hip_cvt_float_to_fp8(__half2float(__half(x)), sat, type);
617 }
618 
627 __FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_halfraw2_to_fp8x2(
628  const __half2_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t type) {
629  return __hip_cvt_float2_to_fp8x2(__half22float2(__half2(x)), sat, type);
630 }
631 
636 struct __hip_fp8_e4m3_fnuz {
637  __hip_fp8_storage_t __x;
638  constexpr static __hip_saturation_t __default_saturation = __HIP_SATFINITE;
639  constexpr static __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ;
640  constexpr static unsigned int __we = 4;
641  constexpr static unsigned int __wm = 3;
642 
643  // TODO: SWDEV-452411
644  // Add cast from unsigned long long, long long to fp8
645 
647  __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const long int val)
648  : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
649  __default_interpret)) {}
650 
652  __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const int val)
653  : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
654  __default_interpret)) {}
655 
657  __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const short int val)
658  : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
659  __default_interpret)) {}
660 
662  __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned long int val)
663  : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
664  __default_interpret)) {}
665 
667  __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned int val)
668  : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
669  __default_interpret)) {}
670 
672  __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned short int val)
673  : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
674  __default_interpret)) {}
675 
677  __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const double f)
678  : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {}
679 
681  __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const float f)
682  : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {}
683 
685  __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __hip_bfloat16 f)
686  : __x(__hip_cvt_float_to_fp8(static_cast<float>(f), __default_saturation,
687  __default_interpret)) {}
688 
690  __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __half f)
691  : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation,
692  __default_interpret)) {}
693 
695  __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz() = default;
696 
698  __FP8_HOST_DEVICE__ operator __half() const {
699  return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret));
700  }
701 
703  __FP8_HOST_DEVICE__ operator __hip_bfloat16() const {
704  float f = *this;
705  return __hip_bfloat16(f);
706  }
707 
709  __FP8_HOST_DEVICE__ operator bool() const {
710  // it can be 0x00 (+0.0) since 0x80 will be nan
711  return !(static_cast<unsigned short>(__x) == 0);
712  }
713 
715  __FP8_HOST_DEVICE__ operator char() const {
716  if (internal::hip_fp8_fnuz_is_nan(__x)) {
717  return 0;
718  }
719 
720  auto fval = internal::cast_from_f8<float, true>(__x, __wm, __we);
721  auto llval = static_cast<long long>(fval);
722  if (llval <= CHAR_MIN) {
723  return CHAR_MIN;
724  } else if (llval >= CHAR_MAX) {
725  return CHAR_MAX;
726  }
727  return static_cast<char>(fval);
728  }
729 
731  __FP8_HOST_DEVICE__ operator double() const {
732  return internal::cast_from_f8<double, true>(__x, __wm, __we);
733  }
734 
736  __FP8_HOST_DEVICE__ operator float() const {
737 #if HIP_FP8_CVT_FAST_PATH
738  return internal::cast_to_f32_from_f8(__x, __default_interpret);
739 #else
740  return internal::cast_from_f8<float, true>(__x, __wm, __we);
741 #endif
742  }
743 
745  __FP8_HOST_DEVICE__ operator int() const {
746  if (internal::hip_fp8_fnuz_is_nan(__x)) {
747  return 0;
748  }
749 
750  float fval = *this;
751  return static_cast<int>(fval);
752  }
753 
755  __FP8_HOST_DEVICE__ operator long int() const {
756  if (internal::hip_fp8_fnuz_is_nan(__x)) {
757  return 0;
758  }
759 
760  float fval = *this;
761  return static_cast<long>(fval);
762  }
763 
765  __FP8_HOST_DEVICE__ operator long long int() const {
766  if (internal::hip_fp8_fnuz_is_nan(__x)) {
767  return 0;
768  }
769 
770  float fval = *this;
771  return static_cast<long long>(fval);
772  }
773 
775  __FP8_HOST_DEVICE__ operator short int() const {
776  if (internal::hip_fp8_fnuz_is_nan(__x)) {
777  return 0;
778  }
779 
780  float fval = *this;
781  auto llval = static_cast<long long>(fval);
782  if (llval <= SHRT_MIN) {
783  return SHRT_MIN;
784  } else if (llval >= SHRT_MAX) {
785  return SHRT_MAX;
786  }
787  return static_cast<short>(fval);
788  }
789 
791  __FP8_HOST_DEVICE__ operator signed char() const {
792  if (internal::hip_fp8_fnuz_is_nan(__x)) {
793  return 0;
794  }
795 
796  float fval = *this;
797  auto llval = static_cast<long long>(fval);
798  if (llval <= SCHAR_MIN) {
799  return SCHAR_MIN;
800  } else if (llval >= SCHAR_MAX) {
801  return SCHAR_MAX;
802  }
803  return static_cast<signed char>(fval);
804  }
805 
807  __FP8_HOST_DEVICE__ operator unsigned char() const {
808  if (internal::hip_fp8_fnuz_is_nan(__x)) {
809  return 0;
810  }
811 
812  float fval = *this;
813  auto llval = static_cast<long long>(fval);
814  if (llval <= 0) {
815  return 0;
816  } else if (llval >= UCHAR_MAX) {
817  return UCHAR_MAX;
818  }
819  return static_cast<unsigned char>(fval);
820  }
821 
823  __FP8_HOST_DEVICE__ operator unsigned int() const {
824  if (internal::hip_fp8_fnuz_is_nan(__x)) {
825  return 0;
826  }
827 
828  float fval = *this;
829  auto llval = static_cast<long long>(fval);
830  if (llval <= 0) {
831  return 0;
832  }
833  return static_cast<unsigned int>(fval);
834  }
835 
837  __FP8_HOST_DEVICE__ operator unsigned long int() const {
838  if (internal::hip_fp8_fnuz_is_nan(__x)) {
839  return 0;
840  }
841 
842  float fval = *this;
843  auto llval = static_cast<long long>(fval);
844  if (llval <= 0) {
845  return 0;
846  }
847  return static_cast<unsigned long>(fval);
848  }
849 
851  __FP8_HOST_DEVICE__ operator unsigned long long int() const {
852  if (internal::hip_fp8_fnuz_is_nan(__x)) {
853  return 0;
854  }
855 
856  float fval = *this;
857  auto llval = static_cast<long long>(fval);
858  if (llval <= 0) {
859  return 0;
860  }
861  return static_cast<unsigned long long>(fval);
862  }
863 
865  __FP8_HOST_DEVICE__ operator unsigned short int() const {
866  if (internal::hip_fp8_fnuz_is_nan(__x)) {
867  return 0;
868  }
869 
870  float fval = *this;
871  auto llval = static_cast<long long>(fval);
872  if (llval <= 0) {
873  return 0;
874  }
875  return static_cast<unsigned short>(fval);
876  }
877 };
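// [Editorial example, not part of the original header] Typical use of the class above:
//   __hip_fp8_e4m3_fnuz q(3.14f);  // rounds to the nearest representable value, 3.25
//   float f = q;                   // 3.25f
//   __half h = q;                  // same value as __half
//   bool nan = internal::hip_fp8_fnuz_is_nan(q.__x);  // false; NaN is stored as 0x80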
878 
883 struct __hip_fp8x2_e4m3_fnuz {
884  __hip_fp8x2_storage_t __x;
885  static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
886  static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ;
887  static constexpr unsigned int __we = 4;
888  static constexpr unsigned int __wm = 3;
889 
891  __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const double2 val)
892  : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
893 
895  __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const float2 val)
896  : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
897 
899  __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __hip_bfloat162 val)
900  : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
901 
903  __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __half2 val)
904  : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
905 
907  __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz() = default;
908 
910  __FP8_HOST_DEVICE__ operator __half2() const {
911  return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret));
912  }
913 
915  __FP8_HOST_DEVICE__ operator float2() const {
916 #if HIP_FP8_CVT_FAST_PATH
917  return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
918 #else
919  return float2(internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x & 0xFF),
920  __wm, __we),
921  internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x >> 8),
922  __wm, __we));
923 #endif
924  }
925 };
926 
931 struct __hip_fp8x4_e4m3_fnuz {
932  __hip_fp8x4_storage_t __x;
933  static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
934  static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ;
935  static constexpr unsigned int __we = 4;
936  static constexpr unsigned int __wm = 3;
937 
939  __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const double4 val)
940  : __x{reinterpret_cast<__hip_fp8x4_storage_t>(
941  static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
942  val.x, __default_saturation, __default_interpret)) |
943  reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
944  val.y, __default_saturation, __default_interpret))
945  << 8 |
946  reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
947  val.z, __default_saturation, __default_interpret))
948  << 16 |
949  reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
950  val.w, __default_saturation, __default_interpret))
951  << 24))} {}
952 
954  __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const float4 val)
955  : __x{reinterpret_cast<__hip_fp8x4_storage_t>(
956  static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
957  val.x, __default_saturation, __default_interpret)) |
958  reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
959  val.y, __default_saturation, __default_interpret))
960  << 8 |
961  reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
962  val.z, __default_saturation, __default_interpret))
963  << 16 |
964  reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
965  val.w, __default_saturation, __default_interpret))
966  << 24))} {}
967 
969  __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
970  : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast<unsigned int>(
971  reinterpret_cast<unsigned short>(
972  __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) |
973  reinterpret_cast<unsigned short>(
974  __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret))
975  << 16))) {}
976 
978  __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __half2 low, const __half2 high)
979  : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
980  static_cast<unsigned int>(reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
981  high, __default_saturation, __default_interpret)) |
982  reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
983  low, __default_saturation, __default_interpret))
984  << 16))) {}
985 
987  __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz() = default;
988 
990  __FP8_HOST_DEVICE__ operator float4() const {
991  auto x = __x; // bypass const
992  auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E
993  auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1);
994 #if HIP_FP8_CVT_FAST_PATH
995  float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
996  float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
997 #else
998  float2 high = float2(internal::cast_from_f8<float, true>(
999  static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we),
1000  internal::cast_from_f8<float, true>(
1001  static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we));
1002  float2 low = float2(internal::cast_from_f8<float, true>(
1003  static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we),
1004  internal::cast_from_f8<float, true>(
1005  static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we));
1006 #endif
1007  return float4(low.x, low.y, high.x, high.y);
1008  }
1009 };
1010 
1015 struct __hip_fp8_e5m2_fnuz {
1016  __hip_fp8_storage_t __x;
1017  static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1018  static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ;
1019  static constexpr unsigned int __we = 5;
1020  static constexpr unsigned int __wm = 2;
1021 
1022 
1023  // TODO: SWDEV-452411
1024  // Add cast from unsigned long long, long long to fp8
1025 
1027  __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const long int val)
1028  : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1029  __default_interpret)) {}
1030 
1032  __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const int val)
1033  : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1034  __default_interpret)) {}
1035 
1037  __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const short int val)
1038  : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1039  __default_interpret)) {}
1040 
1042  __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned long int val)
1043  : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1044  __default_interpret)) {}
1045 
1047  __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned int val)
1048  : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1049  __default_interpret)) {}
1050 
1052  __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned short int val)
1053  : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1054  __default_interpret)) {}
1055 
1057  __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const double f)
1058  : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {}
1059 
1061  __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const float f)
1062  : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {}
1063 
1065  __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __hip_bfloat16 f)
1066  : __x(__hip_cvt_float_to_fp8(static_cast<float>(f), __default_saturation,
1067  __default_interpret)) {}
1068 
1070  __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __half f)
1071  : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation,
1072  __default_interpret)) {}
1073 
1075  __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz() = default;
1076 
1078  __FP8_HOST_DEVICE__ operator float() const {
1079 #if HIP_FP8_CVT_FAST_PATH
1080  return internal::cast_to_f32_from_f8(__x, __default_interpret);
1081 #else
1082  return internal::cast_from_f8<float, true>(__x, __wm, __we);
1083 #endif
1084  }
1085 
1087  __FP8_HOST_DEVICE__ operator __half() const {
1088  return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret));
1089  }
1090 
1092  __FP8_HOST_DEVICE__ operator __hip_bfloat16() const {
1093  float f = *this;
1094  return __hip_bfloat16(f);
1095  }
1096 
1098  __FP8_HOST_DEVICE__ operator bool() const {
1099  // it can be 0x00 (+0.0) since 0x80 will be nan
1100  return !(static_cast<unsigned short>(__x) == 0);
1101  }
1102 
1104  __FP8_HOST_DEVICE__ operator char() const {
1105  if (internal::hip_fp8_fnuz_is_nan(__x)) {
1106  return 0;
1107  }
1108 
1109  float fval = *this;
1110  auto llval = static_cast<long long>(fval);
1111  if (llval <= CHAR_MIN) {
1112  return CHAR_MIN;
1113  } else if (llval >= CHAR_MAX) {
1114  return CHAR_MAX;
1115  }
1116  return static_cast<char>(fval);
1117  }
1118 
1120  __FP8_HOST_DEVICE__ operator double() const {
1121  return internal::cast_from_f8<double, true>(__x, __wm, __we);
1122  }
1123 
1125  __FP8_HOST_DEVICE__ operator int() const {
1126  if (internal::hip_fp8_fnuz_is_nan(__x)) {
1127  return 0;
1128  }
1129 
1130  float fval = *this;
1131  return static_cast<int>(fval);
1132  }
1133 
1135  __FP8_HOST_DEVICE__ operator long int() const {
1136  if (internal::hip_fp8_fnuz_is_nan(__x)) {
1137  return 0;
1138  }
1139 
1140  float fval = *this;
1141  return static_cast<long>(fval);
1142  }
1143 
1145  __FP8_HOST_DEVICE__ operator long long int() const {
1146  if (internal::hip_fp8_fnuz_is_nan(__x)) {
1147  return 0;
1148  }
1149 
1150  float fval = *this;
1151  return static_cast<long long>(fval);
1152  }
1153 
1155  __FP8_HOST_DEVICE__ operator short int() const {
1156  if (internal::hip_fp8_fnuz_is_nan(__x)) {
1157  return 0;
1158  }
1159 
1160  float fval = *this;
1161  auto llval = static_cast<long long>(fval);
1162  if (llval <= SHRT_MIN) {
1163  return SHRT_MIN;
1164  } else if (llval >= SHRT_MAX) {
1165  return SHRT_MAX;
1166  }
1167  return static_cast<short>(fval);
1168  }
1169 
1171  __FP8_HOST_DEVICE__ operator signed char() const {
1172  if (internal::hip_fp8_fnuz_is_nan(__x)) {
1173  return 0;
1174  }
1175 
1176  float fval = *this;
1177  auto llval = static_cast<long long>(fval);
1178  if (llval <= SCHAR_MIN) {
1179  return SCHAR_MIN;
1180  } else if (llval >= SCHAR_MAX) {
1181  return SCHAR_MAX;
1182  }
1183  return static_cast<signed char>(fval);
1184  }
1185 
1187  __FP8_HOST_DEVICE__ operator unsigned char() const {
1188  if (internal::hip_fp8_fnuz_is_nan(__x)) {
1189  return 0;
1190  }
1191 
1192  float fval = *this;
1193  auto llval = static_cast<long long>(fval);
1194  if (llval <= 0) {
1195  return 0;
1196  } else if (llval >= UCHAR_MAX) {
1197  return UCHAR_MAX;
1198  }
1199  return static_cast<unsigned char>(fval);
1200  }
1201 
1203  __FP8_HOST_DEVICE__ operator unsigned int() const {
1204  if (internal::hip_fp8_fnuz_is_nan(__x)) {
1205  return 0;
1206  }
1207 
1208  float fval = *this;
1209  auto llval = static_cast<long long>(fval);
1210  if (llval <= 0) {
1211  return 0;
1212  }
1213  return static_cast<unsigned int>(fval);
1214  }
1215 
1217  __FP8_HOST_DEVICE__ operator unsigned long int() const {
1218  if (internal::hip_fp8_fnuz_is_nan(__x)) {
1219  return 0;
1220  }
1221 
1222  float fval = *this;
1223  auto llval = static_cast<long long>(fval);
1224  if (llval <= 0) {
1225  return 0;
1226  }
1227  return static_cast<unsigned long>(fval);
1228  }
1229 
1231  __FP8_HOST_DEVICE__ operator unsigned long long int() const {
1232  if (internal::hip_fp8_fnuz_is_nan(__x)) {
1233  return 0;
1234  }
1235 
1236  float fval = *this;
1237  auto llval = static_cast<long long>(fval);
1238  if (llval <= 0) {
1239  return 0;
1240  }
1241  return static_cast<unsigned long long>(fval);
1242  }
1243 
1245  __FP8_HOST_DEVICE__ operator unsigned short int() const {
1246  if (internal::hip_fp8_fnuz_is_nan(__x)) {
1247  return 0;
1248  }
1249 
1250  float fval = *this;
1251  auto llval = static_cast<long long>(fval);
1252  if (llval <= 0) {
1253  return 0;
1254  }
1255  return static_cast<unsigned short>(fval);
1256  }
1257 };
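// [Editorial example, not part of the original header] The e5m2 variant trades mantissa bits
// for range: its largest finite value is 57344.0 and its smallest normal is 2^-15, versus
// 240.0 and 2^-7 for e4m3 fnuz above.
//   __hip_fp8_e5m2_fnuz r(300.0f);  // representable range extends well past 240
//   float f = r;                    // 320.0f, the nearest e5m2 fnuz value to 300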
1258 
1263 struct __hip_fp8x2_e5m2_fnuz {
1264  __hip_fp8x2_storage_t __x;
1265  static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1266  static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ;
1267  static constexpr unsigned int __we = 5;
1268  static constexpr unsigned int __wm = 2;
1269 
1271  __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const double2 val)
1272  : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1273 
1275  __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const float2 val)
1276  : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1277 
1279  __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __hip_bfloat162 val)
1280  : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1281 
1283  __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __half2 val)
1284  : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {}
1285 
1287  __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz() = default;
1288 
1290  __FP8_HOST_DEVICE__ operator __half2() const {
1291  return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret));
1292  }
1293 
1295  __FP8_HOST_DEVICE__ operator float2() const {
1296 #if HIP_FP8_CVT_FAST_PATH
1297  return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
1298 #else
1299  return float2(internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x & 0xFF),
1300  __wm, __we),
1301  internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x >> 8),
1302  __wm, __we));
1303 #endif
1304  }
1305 };
1306 
1311 struct __hip_fp8x4_e5m2_fnuz {
1312  __hip_fp8x4_storage_t __x;
1313  static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1314  static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ;
1315  static constexpr unsigned int __we = 5;
1316  static constexpr unsigned int __wm = 2;
1317 
1319  __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const double4 val)
1320  : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
1321  static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1322  val.x, __default_saturation, __default_interpret)) |
1323  reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1324  val.y, __default_saturation, __default_interpret))
1325  << 8 |
1326  reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1327  val.z, __default_saturation, __default_interpret))
1328  << 16 |
1329  reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1330  val.w, __default_saturation, __default_interpret))
1331  << 24))) {}
1332 
1334  __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const float4 val)
1335  : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
1336  static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1337  val.x, __default_saturation, __default_interpret)) |
1338  reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1339  val.y, __default_saturation, __default_interpret))
1340  << 8 |
1341  reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1342  val.z, __default_saturation, __default_interpret))
1343  << 16 |
1344  reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1345  val.w, __default_saturation, __default_interpret))
1346  << 24))) {}
1347 
1349  __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
1350  : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast<unsigned int>(
1351  reinterpret_cast<unsigned short>(
1352  __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) |
1353  reinterpret_cast<unsigned short>(
1354  __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret))
1355  << 16))) {}
1356 
1358  __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __half2 low, const __half2 high)
1359  : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
1360  static_cast<unsigned int>(reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
1361  high, __default_saturation, __default_interpret)) |
1362  reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
1363  low, __default_saturation, __default_interpret))
1364  << 16))) {}
1365 
1366  /* default construct fp8x4 e5m2 */
1367  __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz() = default;
1368 
1370  __FP8_HOST_DEVICE__ operator float4() const {
1371  auto x = __x; // bypass const
1372  auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E
1373  auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1);
1374 #if HIP_FP8_CVT_FAST_PATH
1375  float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
1376  float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
1377 #else
1378  float2 high = float2(internal::cast_from_f8<float, true>(
1379  static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we),
1380  internal::cast_from_f8<float, true>(
1381  static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we));
1382  float2 low = float2(internal::cast_from_f8<float, true>(
1383  static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we),
1384  internal::cast_from_f8<float, true>(
1385  static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we));
1386 #endif
1387  return float4(low.x, low.y, high.x, high.y);
1388  }
1389 };
1390 
1391 #endif // _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
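// ---------------------------------------------------------------------------
// [Editorial usage sketch, not part of the original header] A minimal host-side round trip,
// assuming the code is compiled with hipcc and this header is reachable either directly or
// through the public <hip/hip_fp8.h> wrapper (adjust the include to your tree):
//
//   #include <hip/amd_detail/amd_hip_fp8.h>
//
//   int main() {
//     // 0.5f is exactly representable in e4m3 fnuz (encoding 0x38).
//     __hip_fp8_storage_t s =
//         __hip_cvt_float_to_fp8(0.5f, __HIP_SATFINITE, __HIP_E4M3_FNUZ);
//     __hip_fp8_e4m3_fnuz q(0.5f);
//     float back = q;  // exactly 0.5f again
//     return (s == 0x38 && back == 0.5f) ? 0 : 1;
//   }
// ---------------------------------------------------------------------------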