HIP: Heterogeneous-computing Interface for Portability
Loading...
Searching...
No Matches
amd_hip_fp8.h
Go to the documentation of this file.
1
30#ifndef _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
31#define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
32
/* Per-target fp8 configuration.
 *   gfx940/941/942 device code: hardware conversions, FNUZ encodings only.
 *   gfx1200/1201 device code:   hardware conversions, OCP encodings only.
 *   host and every other target: software conversions, both encodings. */
#if __HIP_DEVICE_COMPILE__ && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#define HIP_FP8_CVT_FAST_PATH 1
#define HIP_FP8_TYPE_OCP 0
#define HIP_FP8_TYPE_FNUZ 1
#elif __HIP_DEVICE_COMPILE__ && (defined(__gfx1200__) || defined(__gfx1201__))
#define HIP_FP8_CVT_FAST_PATH 1
#define HIP_FP8_TYPE_OCP 1
#define HIP_FP8_TYPE_FNUZ 0
#else
#define HIP_FP8_CVT_FAST_PATH 0
#define HIP_FP8_TYPE_FNUZ 1
#define HIP_FP8_TYPE_OCP 1
#endif

/* HIPRTC builds mirror the encoding availability above as 0/1 flags. */
#if defined(__HIPCC_RTC__) && HIP_FP8_TYPE_FNUZ
#define ENABLE_FNUZ_HIPRTC 1
#elif defined(__HIPCC_RTC__)
#define ENABLE_FNUZ_HIPRTC 0
#endif
#if defined(__HIPCC_RTC__) && HIP_FP8_TYPE_OCP
#define ENABLE_OCP_HIPRTC 1
#elif defined(__HIPCC_RTC__)
#define ENABLE_OCP_HIPRTC 0
#endif
64
65// Include it explicitly for HIPRTC
66#include "amd_hip_bf16.h"
67
68#if !defined(__HIPCC_RTC__)
69#include <hip/amd_detail/amd_hip_common.h>
70#include <climits>
71
72#include "host_defines.h" // __hip_internal::
73#include "amd_hip_vector_types.h" // float2 etc
74#include "amd_hip_fp16.h" // __half_raw
75#include "math_fwd.h" // ocml device functions
76#include "hip_assert.h" // hip assertions
77#define __HIP_SCHAR_MAX SCHAR_MAX
78#define __HIP_SCHAR_MIN SCHAR_MIN
79#define __HIP_UCHAR_MAX UCHAR_MAX
80#define __HIP_SHRT_MIN SHRT_MIN
81#define __HIP_SHRT_MAX SHRT_MAX
82#define __HIP_CHAR_MIN CHAR_MIN
83#define __HIP_CHAR_MAX CHAR_MAX
84#else
85// fp8 header uses all this, since we do not include standard header, we include this
86#define __HIP_SCHAR_MAX __SCHAR_MAX__
87#define __HIP_SCHAR_MIN (-__SCHAR_MAX__ - 1)
88#define __HIP_UCHAR_MAX (__SCHAR_MAX__ * 2 + 1)
89#define __HIP_SHRT_MIN (-__SHRT_MAX__ - 1)
90#define __HIP_SHRT_MAX __SHRT_MAX__
91#ifdef __CHAR_UNSIGNED__ /* -funsigned-char */
92#define __HIP_CHAR_MIN 0
93#define __HIP_CHAR_MAX __HIP_UCHAR_MAX
94#else
95#define __HIP_CHAR_MIN __HIP_SCHAR_MIN
96#define __HIP_CHAR_MAX __SCHAR_MAX__
97#endif
98#endif // !defined(__HIPCC_RTC__)
99
// Function-space qualifiers for the fp8 API. HIPRTC builds get device-only
// qualifiers (and plain `static`); regular builds get host+device `static inline`.
#if !defined(__HIPCC_RTC__)
#define __FP8_HOST_DEVICE__ __host__ __device__
#define __FP8_HOST_DEVICE_STATIC__ __FP8_HOST_DEVICE__ static inline
#else
#define __FP8_HOST_DEVICE__ __device__
#define __FP8_HOST_DEVICE_STATIC__ __FP8_HOST_DEVICE__ static
#endif  // __HIPCC_RTC__

// Host-only qualifiers, for declarations that are host-only in some configurations.
#define __FP8_HOST__ __host__
#define __FP8_HOST_STATIC__ __FP8_HOST__ static inline
110
111
#if !defined(__HIPCC_RTC__)
static_assert(CHAR_BIT == 8, "byte size should be of 8 bits");
#endif
// The fp8 storage typedefs rely on these exact widths. Every assertion carries a
// message: the single-argument form of static_assert is only valid from C++17 (C23),
// and older dialects reject it or warn about it.
static_assert(sizeof(unsigned char) == 1, "unsigned char must be 1 byte");
static_assert(sizeof(unsigned short int) == 2, "unsigned short int must be 2 bytes");
static_assert(sizeof(unsigned int) == 4, "unsigned int must be 4 bytes");
118
128
136
/* Raw storage for a single fp8 value: one byte holding the bit pattern. */
typedef unsigned char __hip_fp8_storage_t;

/* Raw storage for two packed fp8 values; the first element is in the low byte. */
typedef unsigned short int __hip_fp8x2_storage_t;

/* Raw storage for four packed fp8 values (4 bytes). */
typedef unsigned int __hip_fp8x4_storage_t;
156
157
158namespace internal {
159
// Assertions to check for supported conversion types.
// Each macro fires __hip_assert when `interp` is not one of the encodings of its family
// (OCP: __HIP_E4M3/__HIP_E5M2, FNUZ: __HIP_E4M3_FNUZ/__HIP_E5M2_FNUZ).
// They expand to a single do { } while (0) statement, so they are safe in an unbraced
// if/else. The argument is parenthesized so that compound expressions compare correctly.
#define __assert_ocp_support(interp)                                                    \
  do {                                                                                  \
    if ((interp) != __HIP_E4M3 && (interp) != __HIP_E5M2) {                             \
      __hip_assert(false && "type is unsupported by current target device");            \
    }                                                                                   \
  } while (0)
#define __assert_fnuz_support(interp)                                                   \
  do {                                                                                  \
    if ((interp) != __HIP_E4M3_FNUZ && (interp) != __HIP_E5M2_FNUZ) {                   \
      __hip_assert(false && "type is unsupported by current target device");            \
    }                                                                                   \
  } while (0)
173
174__FP8_HOST_DEVICE_STATIC__ void __is_interpret_supported(__hip_fp8_interpretation_t interp) {
175#if __HIP_DEVICE_COMPILE__
176#if HIP_FP8_TYPE_OCP
177 __assert_ocp_support(interp);
178#endif
179#if HIP_FP8_TYPE_FNUZ
180 __assert_fnuz_support(interp);
181#endif
182#endif
183}
184
185// The conversion function is from rocblas
186// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39
187// This has been modified to add double types conversion as well
188template <typename T, bool is_fnuz>
189__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t cast_to_f8(T _x, int wm, int we, bool clip = false,
190 bool stoch = false,
191 unsigned int rng = 0) {
192 constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
193 constexpr bool is_float = __hip_internal::is_same<T, float>::value;
194 constexpr bool is_double = __hip_internal::is_same<T, double>::value;
195 static_assert(is_half || is_float || is_double, "Only half, float and double can be cast to f8");
196
197 const int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
198 unsigned long long x;
199
200 if (sizeof(T) == 8)
201 x = reinterpret_cast<unsigned long long&>(_x);
202 else if (sizeof(T) == 4)
203 x = reinterpret_cast<unsigned int&>(_x);
204 else
205 x = reinterpret_cast<unsigned short int&>(_x);
206
207
208 unsigned long long head, mantissa;
209 int exponent, bias;
210 unsigned int sign;
211 unsigned long long fInf, mask;
212
213 if (sizeof(T) == 8) {
214 head = x & 0xFFF0000000000000ull;
215 mantissa = x & 0xFFFFFFFFFFFFFull;
216 exponent = (head >> 52) & 0x7FF;
217 sign = head >> 63;
218 bias = 1023;
219 fInf = 0x7FF0000000000000ull;
220 mask = 0x7FFFFFFFFFFFFFFFull;
221 } else if (sizeof(T) == 4) {
222 head = x & 0xFF800000;
223 mantissa = x & 0x7FFFFF;
224 exponent = (head >> 23) & 0xFF;
225 sign = head >> 31;
226 bias = 127;
227 fInf = 0x7F800000;
228 mask = 0x7FFFFFFF;
229 } else {
230 head = x & 0xFC00;
231 mantissa = x & 0x3FF;
232 exponent = (head >> 10) & 0x1F;
233 sign = head >> 15;
234 bias = 15;
235 fInf = 0x7C00;
236 mask = 0x7FFF;
237 }
238 unsigned int signed_inf = 0;
239 unsigned int nan = 0;
240 if (is_fnuz) {
241 signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
242 nan = 0x80;
243 } else {
244 if (we == 4) { // e4m3
245 signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
246 } else { // e5m2
247 signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
248 }
249 nan = (sign << 7) + 0x7f;
250 }
251 // Max values
252 unsigned long long ifmax = 0;
253 if (sizeof(T) == 8) {
254 if (we == 5) { // 57344
255 ifmax = 0x40EC000000000000ull;
256 } else {
257 if (is_fnuz) { // 240
258 ifmax = 0x406E000000000000ull;
259 } else { // 448
260 ifmax = 0x407C000000000000ull;
261 }
262 }
263 } else if (sizeof(T) == 4) {
264 if (we == 5) {
265 ifmax = 0x47600000;
266 } else {
267 if (is_fnuz) {
268 ifmax = 0x43700000;
269 } else {
270 ifmax = 0x43E00000;
271 }
272 }
273 } else {
274 if (we == 5) {
275 ifmax = 0x7B00;
276 } else {
277 if (is_fnuz) {
278 ifmax = 0x5B80;
279 } else {
280 ifmax = 0x5F00;
281 }
282 }
283 }
284 // Deal with inf and NaNs
285 if ((x & fInf) == fInf) {
286 if (is_fnuz) return signed_inf;
287 return mantissa != 0 ? nan : signed_inf;
288 }
289
290 if ((x & mask) > ifmax) {
291 return signed_inf;
292 }
293
294 if (x == 0) {
295 return 0;
296 }
297
298 // First need to check if it is normal or denorm as there is a difference of implict 1
299 // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
300 // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
301 // RNE, no need to add rng. Then probably need to check whether there is carry and adjust
302 // exponent and mantissa again
303
304 // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
305 const int f8_bias = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
306 const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
307 // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
308 // f8_exponent is the converted f8 exponent with bias encoding
309 // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
310 // the difference needs to be adjusted and mantissa shifted
311 int act_exponent, f8_exponent, exponent_diff;
312
313 if (exponent == 0) { // fp32/fp16 is in denormal.
314 /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
315here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has
316exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in
317fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
318where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. In
319this case, the fp16 mantissa should be shift left by 1 */
320 act_exponent = exponent - bias + 1;
321 exponent_diff = f8_denormal_act_exponent -
322 act_exponent; // actual exponent is exponent-bias+1 as it is denormal
323 } else { // fp32/fp16 is normal with implicit 1
324 act_exponent = exponent - bias;
325 if (act_exponent <= f8_denormal_act_exponent) {
326 /* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
327For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
328actual exponent is -7, it is actually larger due to the implict 1,
329Therefore it needs to be adjust to -6 and mantissa shift right by 1.
330So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
331 exponent_diff = f8_denormal_act_exponent - act_exponent;
332 } else { // both fp32/fp16 and f8 are in normal range
333 exponent_diff = 0; // exponent_diff=0 does not mean there is no difference for this case,
334 // act_exponent could be larger. Just that it does not need shift mantissa
335 }
336 mantissa += (1ull << mfmt); // Add the implicit 1 into mantissa
337 }
338
339 bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
340 (1ull << (mfmt - wm + exponent_diff - 1));
341 /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we shift
342right as shift right could rip off some residual part and make something not midpoint look like
343midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than midpoint, but
344after shift right by 4 bits, it would look like midpoint.
345*/
346
347 if (exponent_diff > 0)
348 mantissa >>= exponent_diff;
349 else if (exponent_diff == -1)
350 mantissa <<= -exponent_diff;
351 bool implicit_one = mantissa & (1ull << mfmt);
352 // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
353 f8_exponent =
354 (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
355
356 // Now we have the exponent and mantissa adjusted
357 unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
358 bool odd =
359 mantissa & (1ull << (mfmt - wm)); // if the least significant bit that is not truncated is 1
360 mantissa +=
361 (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
362
363 // Now we deal with overflow
364 if (f8_exponent == 0) {
365 if ((1ull << mfmt) & mantissa) {
366 f8_exponent = 1; // denormal overflow to become normal, promote exponent
367 }
368 } else {
369 if ((1ull << (mfmt + 1)) & mantissa) {
370 mantissa >>= 1;
371 f8_exponent++;
372 }
373 }
374
375 mantissa >>= (mfmt - wm);
376
377 // above range: quantize to maximum possible float of the same sign
378 const int max_exp = (1 << we) - 1;
379 if (f8_exponent > max_exp) {
380 if (clip) {
381 mantissa = (1 << wm) - 1;
382 f8_exponent = max_exp;
383 } else {
384 return signed_inf;
385 }
386 }
387
388 if (f8_exponent == 0 && mantissa == 0) return is_fnuz ? 0 : (sign << 7);
389 mantissa &= (1 << wm) - 1;
390 return (sign << 7) | (f8_exponent << wm) | mantissa;
391}
392// The conversion function is from rocblas
393// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220
394// This has been modified to handle double types as well
395template <typename T, bool is_fnuz>
396__FP8_HOST_DEVICE_STATIC__ T cast_from_f8(__hip_fp8_storage_t x, int wm, int we,
397 bool clip = false) {
398 constexpr bool is_half = __hip_internal::is_same<T, _Float16>::value;
399 constexpr bool is_float = __hip_internal::is_same<T, float>::value;
400 constexpr bool is_double = __hip_internal::is_same<T, double>::value;
401 static_assert(is_half || is_float || is_double, "only half, float and double are supported");
402
403 constexpr int weo = is_half ? 5 : (is_float ? 8 : 11);
404 constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52);
405
406 T fInf, fNegInf, fNaN, fNeg0, fmax, fmin;
407 if (is_half) {
408 const unsigned short int ihInf = 0x7C00;
409 const unsigned short int ihNegInf = 0xFC00;
410 const unsigned short int ihNaN = 0x7C01;
411 const unsigned short int ihNeg0 = 0x8000;
412 /* Max number in e5m2 57344*/
413 const unsigned short int ifmax = 0x7B00;
414 const unsigned short int ifmin = 0xFB00;
415 fInf = reinterpret_cast<const _Float16&>(ihInf);
416 fNegInf = reinterpret_cast<const _Float16&>(ihNegInf);
417 fNaN = reinterpret_cast<const _Float16&>(ihNaN);
418 fNeg0 = reinterpret_cast<const _Float16&>(ihNeg0);
419 fmax = reinterpret_cast<const _Float16&>(ifmax);
420 fmin = reinterpret_cast<const _Float16&>(ifmin);
421 } else if (is_float) {
422 const unsigned int ifInf = 0x7F800000;
423 const unsigned int ifNegInf = 0xFF800000;
424 const unsigned int ifNaN = 0x7F800001;
425 const unsigned int ifNeg0 = 0x80000000;
426 /* Max number in e5m2 57344*/
427 const unsigned int ifmax = 0x47600000;
428 const unsigned int ifmin = 0xC7600000;
429 fInf = reinterpret_cast<const float&>(ifInf);
430 fNegInf = reinterpret_cast<const float&>(ifNegInf);
431 fNaN = reinterpret_cast<const float&>(ifNaN);
432 fNeg0 = reinterpret_cast<const float&>(ifNeg0);
433 fmax = reinterpret_cast<const float&>(ifmax);
434 fmin = reinterpret_cast<const float&>(ifmin);
435 } else if (is_double) {
436 const unsigned long long ifInf = 0x7FF0000000000000ull;
437 const unsigned long long ifNegInf = 0xFFF0000000000000ull;
438 const unsigned long long ifNaN = 0x7FF0000000000001ull;
439 const unsigned long long ifNeg0 = 0x8000000000000000ull;
440 /* Max number in e5m2 57344*/
441 const unsigned long long ifmax = 0x40EC000000000000ull;
442 const unsigned long long ifmin = 0xC0EC000000000000ull;
443 fInf = reinterpret_cast<const double&>(ifInf);
444 fNegInf = reinterpret_cast<const double&>(ifNegInf);
445 fNaN = reinterpret_cast<const double&>(ifNaN);
446 fNeg0 = reinterpret_cast<const double&>(ifNeg0);
447 fmax = reinterpret_cast<const double&>(ifmax);
448 fmin = reinterpret_cast<const double&>(ifmin);
449 }
450
451 if (x == 0) {
452 return 0;
453 }
454
455 unsigned long long sign = x >> 7;
456 unsigned long long mantissa = x & ((1 << wm) - 1);
457 int exponent = (x & 0x7F) >> wm;
458 if (is_fnuz) {
459 if (x == 0x80) {
460 return fNaN;
461 }
462 } else {
463 if (x == 0x80) {
464 return fNeg0;
465 }
466 if (we == 4) { // e4m3
467 if ((x & 0x7F) == 0x7F) {
468 return fNaN;
469 }
470 } else if ((x & 0x7C) == 0x7C) { // e5m2
471 if ((x & 0x3) == 0) {
472 if (clip) {
473 return sign ? fmin : fmax;
474 }
475 return sign ? fNegInf : fInf;
476 }
477 return fNaN;
478 }
479 }
480
481 typename __hip_internal::conditional<
482 sizeof(T) == 2, unsigned short int,
483 typename __hip_internal::conditional<sizeof(T) == 4, unsigned int,
484 unsigned long long>::type>::type retval;
485
486 if (we == 5 && is_half && !is_fnuz) {
487 retval = x << 8;
488 return reinterpret_cast<const T&>(retval);
489 }
490
491 const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (is_fnuz ? 1 : 0);
492
493 // subnormal input
494 if (exponent == 0) {
495#if __HIP_DEVICE_COMPILE__
496 // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
497 int sh = 1 + __clz(mantissa) - (32 - wm);
498#else
499 int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
500#endif
501 mantissa <<= sh;
502 exponent += 1 - sh;
503 mantissa &= ((1ull << wm) - 1);
504 }
505 exponent += exp_low_cutoff - 1;
506 mantissa <<= wmo - wm;
507
508 // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
509 if (exponent <= 0) {
510 mantissa |= 1ull << wmo;
511 mantissa >>= 1 - exponent;
512 exponent = 0;
513 }
514
515 if (sizeof(T) == 2)
516 retval = (sign << 15) | (exponent << 10) | mantissa;
517 else if (sizeof(T) == 4)
518 retval = (sign << 31) | (exponent << 23) | mantissa;
519 else
520 retval = (sign << 63) | (static_cast<unsigned long long>(exponent) << 52) | mantissa;
521 return reinterpret_cast<const T&>(retval);
522}
523
524#if HIP_FP8_CVT_FAST_PATH
525// The conversion function is from rocblas
526// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79
527template <bool stochastic_rounding = false>
528static __device__ __hip_fp8_storage_t cast_to_f8_from_f32(float v, bool saturate,
530 unsigned int rng = 0) {
531 __hip_fp8_storage_t i8data;
532 union {
533 float fval;
534 unsigned int i32val;
535 unsigned char i8val[4]; // NOTE: not endian independent
536 } val;
537
538 unsigned int ival = 0;
539 val.fval = v;
540
541 if (saturate) {
542 if (interpret == __HIP_E4M3_FNUZ) {
543 if ((val.i32val & 0x7F800000) != 0x7F800000) {
544 val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
545 }
546 } else if (interpret == __HIP_E4M3) { // OCP type
547 if ((val.i32val & 0x7F800000) != 0x7F800000) {
548 val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
549 }
550 } else {
551 if ((val.i32val & 0x7F800000) != 0x7F800000) {
552 val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
553 }
554 }
555 }
556
557 if (stochastic_rounding) {
558 ival = (interpret == __HIP_E4M3_FNUZ) || (interpret == __HIP_E4M3)
559 ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
560 : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
561 val.i32val = ival;
562 i8data = val.i8val[0]; // little endian
563 } else { // RNE CVT
564 ival = (interpret == __HIP_E4M3_FNUZ) || (interpret == __HIP_E4M3)
565 ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
566 : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0
567 val.i32val = ival;
568 i8data = val.i8val[0];
569 }
570 return i8data;
571}
572
573static __device__ __hip_fp8x2_storage_t
574cast_to_f8x2_from_f32x2(float2 v, bool saturate, __hip_fp8_interpretation_t interpret) {
575 union {
576 static_assert(sizeof(float2) == sizeof(unsigned int[2]), "size mismatch");
577 static_assert(sizeof(float2) == sizeof(unsigned short[4]), "size mismatch");
578 float2 fval;
579 unsigned int i32val[2];
580 unsigned short i16val[4];
581 } f2val;
582
583 f2val.fval = v;
584
585 if (saturate) {
586 if (interpret == __HIP_E4M3_FNUZ) {
587 if ((f2val.i32val[0] & 0x7F800000) != 0x7F800000) {
588 f2val.fval.x = __builtin_amdgcn_fmed3f(f2val.fval.x, 240.0, -240.0);
589 }
590 if ((f2val.i32val[1] & 0x7F800000) != 0x7F800000) {
591 f2val.fval.y = __builtin_amdgcn_fmed3f(f2val.fval.x, 240.0, -240.0);
592 }
593 } else if (interpret == __HIP_E4M3) {
594 if ((f2val.i32val[0] & 0x7F800000) != 0x7F800000) {
595 f2val.fval.x = __builtin_amdgcn_fmed3f(f2val.fval.x, 448.0, -448.0);
596 }
597 if ((f2val.i32val[1] & 0x7F800000) != 0x7F800000) {
598 f2val.fval.y = __builtin_amdgcn_fmed3f(f2val.fval.x, 448.0, -448.0);
599 }
600 } else {
601 if ((f2val.i32val[0] & 0x7F800000) != 0x7F800000) {
602 f2val.fval.x = __builtin_amdgcn_fmed3f(f2val.fval.x, 57344.0, -57344.0);
603 }
604 if ((f2val.i32val[1] & 0x7F800000) != 0x7F800000) {
605 f2val.fval.y = __builtin_amdgcn_fmed3f(f2val.fval.x, 57344.0, -57344.0);
606 }
607 }
608 }
609
610 f2val.i32val[0] = (interpret == __HIP_E4M3_FNUZ) || (interpret == __HIP_E4M3)
611 ? __builtin_amdgcn_cvt_pk_fp8_f32(v.x, v.y, 0, false)
612 : __builtin_amdgcn_cvt_pk_bf8_f32(v.x, v.y, 0, false);
613
614 return static_cast<__hip_fp8x2_storage_t>(f2val.i16val[0]);
615}
616
617static __device__ float cast_to_f32_from_f8(__hip_fp8_storage_t v,
618 __hip_fp8_interpretation_t interpret) {
619 union {
620 unsigned int i32val;
621 unsigned char i8val[4];
622 } val;
623 val.i8val[0] = v;
624
625 float fval = (interpret == __HIP_E4M3_FNUZ) || (interpret == __HIP_E4M3)
626 ? __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0)
627 : __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
628 return fval;
629}
630
631static __device__ float2 cast_to_f32x2_from_f8x2(__hip_fp8x2_storage_t v,
632 __hip_fp8_interpretation_t interpret) {
633 union {
634 unsigned int i32val;
635 unsigned short i16val[2];
636 } val;
637 val.i16val[0] = v;
638
639 auto f2 = (interpret == __HIP_E4M3_FNUZ) || (interpret == __HIP_E4M3)
640 ? __builtin_amdgcn_cvt_pk_f32_fp8(val.i32val, false)
641 : __builtin_amdgcn_cvt_pk_f32_bf8(val.i32val, false);
642 return float2{f2[0], f2[1]};
643}
644#endif // HIP_FP8_CVT_FAST_PATH
645
646/* For fp8 fnuz types, finite and NaN values are supported. Zero is unsigned.
647Inf are not supported. This gives us one additional number to represent.
648NaN are represented by 1-0000-000 or 1-00000-00 */
649__FP8_HOST_DEVICE_STATIC__ bool hip_fp8_fnuz_is_nan(__hip_fp8_storage_t a) {
650 return static_cast<unsigned char>(a) == 0x80;
651}
652
653__FP8_HOST_DEVICE_STATIC__ bool hip_fp8_ocp_is_nan(__hip_fp8_storage_t a,
654 const __hip_fp8_interpretation_t type) {
655 return (type == __HIP_E4M3) ? ((a & 0x7f) == 0x7f)
656 : (type == __HIP_E5M2) ? ((a & 0x7f) > 0x7c)
657 : false;
658}
659
660__FP8_HOST_DEVICE_STATIC__ bool hip_fp8_ocp_is_inf(__hip_fp8_storage_t a,
661 const __hip_fp8_interpretation_t type) {
662 return (type == __HIP_E5M2) ? (a & 0x7f) == 0x7c : false;
663}
664
665} // namespace internal
666
675#if HIP_FP8_CVT_FAST_PATH
676__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_float_to_fp8(
677 const float f, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) {
678 internal::__is_interpret_supported(interp);
679 return internal::cast_to_f8_from_f32<false>(f, sat == __HIP_SATFINITE, interp);
680#else
681#if HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
683 const float f, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) {
684#else
686 const float f, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) {
687#endif
688 if (interp == __HIP_E4M3_FNUZ || interp == __HIP_E5M2_FNUZ) {
689 int we = interp == __HIP_E4M3_FNUZ ? 4 : 5;
690 int wm = interp == __HIP_E4M3_FNUZ ? 3 : 2;
691 return internal::cast_to_f8<float, true>(f, wm, we, sat == __HIP_SATFINITE);
692 } else {
693 int we = interp == __HIP_E4M3 ? 4 : 5;
694 int wm = interp == __HIP_E4M3 ? 3 : 2;
695 return internal::cast_to_f8<float, false>(f, wm, we, sat == __HIP_SATFINITE);
696 }
697#endif // HIP_FP8_CVT_FAST_PATH
698}
699
700
709#if HIP_FP8_CVT_FAST_PATH
710__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_float2_to_fp8x2(
711 const float2 f2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) {
712 internal::__is_interpret_supported(interp);
713 return internal::cast_to_f8x2_from_f32x2(f2, sat == __HIP_SATFINITE, interp);
714#else
715#if HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
717 const float2 f2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) {
718#else
720 const float2 f2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) {
721#endif
722 return static_cast<__hip_fp8x2_storage_t>(
723 static_cast<unsigned short int>(__hip_cvt_float_to_fp8(f2.y, sat, interp)) << 8 |
724 static_cast<unsigned short int>(__hip_cvt_float_to_fp8(f2.x, sat, interp)));
725#endif // HIP_FP8_CVT_FAST_PATH
726}
727
736#if HIP_FP8_CVT_FAST_PATH
737__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_double_to_fp8(
738 const double d, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) {
739 internal::__is_interpret_supported(interp);
740#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
742 const double d, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) {
743#else
745 const double d, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) {
746#endif
747 if (interp == __HIP_E4M3_FNUZ || interp == __HIP_E5M2_FNUZ) {
748 int we = interp == __HIP_E4M3_FNUZ ? 4 : 5;
749 int wm = interp == __HIP_E4M3_FNUZ ? 3 : 2;
750 return internal::cast_to_f8<double, true>(d, wm, we, sat == __HIP_SATFINITE);
751 } else {
752 int we = interp == __HIP_E4M3 ? 4 : 5;
753 int wm = interp == __HIP_E4M3 ? 3 : 2;
754 return internal::cast_to_f8<double, false>(d, wm, we, sat == __HIP_SATFINITE);
755 }
756}
757
766#if HIP_FP8_CVT_FAST_PATH
767__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_double2_to_fp8x2(
768 const double2 d2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) {
769 internal::__is_interpret_supported(interp);
770#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
772 const double2 d2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) {
773#else
775 const double2 d2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) {
776#endif
777 return static_cast<__hip_fp8x2_storage_t>(
778 static_cast<unsigned short int>(__hip_cvt_double_to_fp8(d2.y, sat, interp)) << 8 |
779 static_cast<unsigned short int>(__hip_cvt_double_to_fp8(d2.x, sat, interp)));
780}
781
790#if HIP_FP8_CVT_FAST_PATH
791__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t
792__hip_cvt_bfloat16raw_to_fp8(const __hip_bfloat16_raw hr, const __hip_saturation_t sat,
793 const __hip_fp8_interpretation_t interp) {
794 internal::__is_interpret_supported(interp);
795#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
796__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t
797__hip_cvt_bfloat16raw_to_fp8(const __hip_bfloat16_raw hr, const __hip_saturation_t sat,
798 const __hip_fp8_interpretation_t interp) {
799#else
800__FP8_HOST_STATIC__ __hip_fp8_storage_t
801__hip_cvt_bfloat16raw_to_fp8(const __hip_bfloat16_raw hr, const __hip_saturation_t sat,
802 const __hip_fp8_interpretation_t interp) {
803#endif
804 float fval = __hip_bfloat16(hr);
805 return __hip_cvt_float_to_fp8(fval, sat, interp);
806}
807
816#if HIP_FP8_CVT_FAST_PATH
817__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t
818__hip_cvt_bfloat16raw2_to_fp8x2(const __hip_bfloat162_raw hr, const __hip_saturation_t sat,
819 const __hip_fp8_interpretation_t interp) {
820 internal::__is_interpret_supported(interp);
821#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
822__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t
823__hip_cvt_bfloat16raw2_to_fp8x2(const __hip_bfloat162_raw hr, const __hip_saturation_t sat,
824 const __hip_fp8_interpretation_t interp) {
825#else
826__FP8_HOST_STATIC__ __hip_fp8x2_storage_t
827__hip_cvt_bfloat16raw2_to_fp8x2(const __hip_bfloat162_raw hr, const __hip_saturation_t sat,
828 const __hip_fp8_interpretation_t interp) {
829#endif
830 float2 f2 = __hip_bfloat162(hr);
831 return __hip_cvt_float2_to_fp8x2(f2, sat, interp);
832}
833
841#if HIP_FP8_CVT_FAST_PATH
842__FP8_HOST_DEVICE_STATIC__ __half_raw
844 internal::__is_interpret_supported(interp);
845#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
846__FP8_HOST_DEVICE_STATIC__ __half_raw
848#else
850 const __hip_fp8_interpretation_t interp) {
851#endif
852 if (interp == __HIP_E4M3_FNUZ || interp == __HIP_E5M2_FNUZ) {
853 unsigned int we = interp == __HIP_E4M3_FNUZ ? 4 : 5;
854 unsigned int wm = interp == __HIP_E4M3_FNUZ ? 3 : 2;
855 return __half_raw{internal::cast_from_f8<_Float16, true>(x, wm, we)};
856 } else {
857 unsigned int we = interp == __HIP_E4M3 ? 4 : 5;
858 unsigned int wm = interp == __HIP_E4M3 ? 3 : 2;
859 return __half_raw{internal::cast_from_f8<_Float16, false>(x, wm, we)};
860 }
861}
862
870#if HIP_FP8_CVT_FAST_PATH
871__FP8_HOST_DEVICE_STATIC__ __half2_raw __hip_cvt_fp8x2_to_halfraw2(
873 internal::__is_interpret_supported(interp);
874#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
875__FP8_HOST_DEVICE_STATIC__ __half2_raw __hip_cvt_fp8x2_to_halfraw2(
877#else
878__FP8_HOST_STATIC__ __half2_raw __hip_cvt_fp8x2_to_halfraw2(
880#endif
881 __half2 ret(static_cast<__half>(
882 __hip_cvt_fp8_to_halfraw(static_cast<__hip_fp8_storage_t>(x & 0xFF), interp)),
883 static_cast<__half>(
884 __hip_cvt_fp8_to_halfraw(static_cast<__hip_fp8_storage_t>(x >> 8), interp)));
885 return static_cast<__half2_raw>(ret);
886}
887
896#if HIP_FP8_CVT_FAST_PATH
897__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_halfraw_to_fp8(
898 const __half_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) {
899 internal::__is_interpret_supported(interp);
900#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
902 const __half_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) {
903#else
905 const __half_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) {
906#endif
907 return __hip_cvt_float_to_fp8(__half2float(__half(x)), sat, interp);
908}
909
918#if HIP_FP8_CVT_FAST_PATH
919__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_halfraw2_to_fp8x2(
920 const __half2_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) {
921 internal::__is_interpret_supported(interp);
922#elif HIP_FP8_TYPE_OCP && HIP_FP8_TYPE_FNUZ
924 const __half2_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) {
925#else
927 const __half2_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp) {
928#endif
929 return __hip_cvt_float2_to_fp8x2(__half22float2(__half2(x)), sat, interp);
930}
931
937#if !defined(ENABLE_FNUZ_HIPRTC) || ENABLE_FNUZ_HIPRTC
941 constexpr static __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ;
942 constexpr static unsigned int __we = 4;
943 constexpr static unsigned int __wm = 3;
944
945 // TODO: SWDEV-452411
946 // Add cast from unsigned long long, long long to fp8
947
949#if HIP_FP8_TYPE_FNUZ
950 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const long int val)
951#else
952 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const long int val)
953#endif
954 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
955 __default_interpret)) {
956 }
957
959#if HIP_FP8_TYPE_FNUZ
960 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const int val)
961#else
962 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const int val)
963#endif
964 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
965 __default_interpret)) {
966 }
967
969#if HIP_FP8_TYPE_FNUZ
970 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const short int val)
971#else
972 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const short int val)
973#endif
974 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
975 __default_interpret)) {
976 }
977
979#if HIP_FP8_TYPE_FNUZ
980 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned long int val)
981#else
982 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const unsigned long int val)
983#endif
984 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
985 __default_interpret)) {
986 }
987
989#if HIP_FP8_TYPE_FNUZ
990 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned int val)
991#else
992 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const unsigned int val)
993#endif
994 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
995 __default_interpret)) {
996 }
997
999#if HIP_FP8_TYPE_FNUZ
1000 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned short int val)
1001#else
1002 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const unsigned short int val)
1003#endif
1004 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1005 __default_interpret)) {
1006 }
1007
1009#if HIP_FP8_TYPE_FNUZ
1010 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const double f)
1011#else
1012 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const double f)
1013#endif
1014 : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {
1015 }
1016
1018#if HIP_FP8_TYPE_FNUZ
1019 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const float f)
1020#else
1021 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const float f)
1022#endif
1023 : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {
1024 }
1025
1027#if HIP_FP8_TYPE_FNUZ
1028 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __hip_bfloat16 f)
1029#else
1030 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const __hip_bfloat16 f)
1031#endif
1032 : __x(__hip_cvt_float_to_fp8(static_cast<float>(f), __default_saturation,
1033 __default_interpret)) {
1034 }
1035
1037#if HIP_FP8_TYPE_FNUZ
1038 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __half f)
1039#else
1040 __FP8_HOST__ __hip_fp8_e4m3_fnuz(const __half f)
1041#endif
1043 __default_interpret)) {
1044 }
1045
1047#if HIP_FP8_TYPE_FNUZ
1048 __FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz() = default;
1049#else
1050 __FP8_HOST__ __hip_fp8_e4m3_fnuz() = default;
1051#endif
1052
1054#if HIP_FP8_TYPE_FNUZ
1055 __FP8_HOST_DEVICE__ operator __half() const {
1056#else
1057 __FP8_HOST__ operator __half() const {
1058#endif
1059 return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret));
1060 }
1061
1063#if HIP_FP8_TYPE_FNUZ
1064 __FP8_HOST_DEVICE__ operator __hip_bfloat16() const {
1065#else
1066 __FP8_HOST__ operator __hip_bfloat16() const {
1067#endif
1068 float f = *this;
1069 return __hip_bfloat16(f);
1070 }
1071
1073#if HIP_FP8_TYPE_FNUZ
1074 __FP8_HOST_DEVICE__ operator bool() const {
1075#else
1076 __FP8_HOST__ operator bool() const {
1077#endif
1078 // it can be 0x00 (+0.0) since 0x80 will be nan
1079 return !(static_cast<unsigned short>(__x) == 0);
1080 }
1081
1083#if HIP_FP8_TYPE_FNUZ
1084 __FP8_HOST_DEVICE__ operator char() const {
1085#else
1086 __FP8_HOST__ operator char() const {
1087#endif
1088 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1089 return 0;
1090 }
1091
1092 auto fval = internal::cast_from_f8<float, true>(__x, __wm, __we);
1093 auto llval = static_cast<long long>(fval);
1094 if (llval <= __HIP_CHAR_MIN) {
1095 return __HIP_CHAR_MIN;
1096 } else if (llval >= __HIP_CHAR_MAX) {
1097 return __HIP_CHAR_MAX;
1098 }
1099 return static_cast<char>(fval);
1100 }
1101
1103#if HIP_FP8_TYPE_FNUZ
1104 __FP8_HOST_DEVICE__ operator double() const {
1105#else
1106 __FP8_HOST__ operator double() const {
1107#endif
1108 return internal::cast_from_f8<double, true>(__x, __wm, __we);
1109 }
1110
1112#if HIP_FP8_TYPE_FNUZ
1113 __FP8_HOST_DEVICE__ operator float() const {
1114#else
1115 __FP8_HOST__ operator float() const {
1116#endif
1117#if HIP_FP8_CVT_FAST_PATH
1118 return internal::cast_to_f32_from_f8(__x, __default_interpret);
1119#else
1120 return internal::cast_from_f8<float, true>(__x, __wm, __we);
1121#endif
1122 }
1123
1125#if HIP_FP8_TYPE_FNUZ
1126 __FP8_HOST_DEVICE__ operator int() const {
1127#else
1128 __FP8_HOST__ operator int() const {
1129#endif
1130 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1131 return 0;
1132 }
1133
1134 float fval = *this;
1135 return static_cast<int>(fval);
1136 }
1137
1139#if HIP_FP8_TYPE_FNUZ
1140 __FP8_HOST_DEVICE__ operator long int() const {
1141#else
1142 __FP8_HOST__ operator long int() const {
1143#endif
1144 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1145 return 0;
1146 }
1147
1148 float fval = *this;
1149 return static_cast<long>(fval);
1150 }
1151
1153#if HIP_FP8_TYPE_FNUZ
1154 __FP8_HOST_DEVICE__ operator long long int() const {
1155#else
1156 __FP8_HOST__ operator long long int() const {
1157#endif
1158 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1159 return 0;
1160 }
1161
1162 float fval = *this;
1163 return static_cast<long long>(fval);
1164 }
1165
1167#if HIP_FP8_TYPE_FNUZ
1168 __FP8_HOST_DEVICE__ operator short int() const {
1169#else
1170 __FP8_HOST__ operator short int() const {
1171#endif
1172 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1173 return 0;
1174 }
1175
1176 float fval = *this;
1177 auto llval = static_cast<long long>(fval);
1178 if (llval <= __HIP_SHRT_MIN) {
1179 return __HIP_SHRT_MIN;
1180 } else if (llval >= __HIP_SHRT_MAX) {
1181 return __HIP_SHRT_MAX;
1182 }
1183 return static_cast<short>(fval);
1184 }
1185
1187#if HIP_FP8_TYPE_FNUZ
1188 __FP8_HOST_DEVICE__ operator signed char() const {
1189#else
1190 __FP8_HOST__ operator signed char() const {
1191#endif
1192 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1193 return 0;
1194 }
1195
1196 float fval = *this;
1197 auto llval = static_cast<long long>(fval);
1198 if (llval <= __HIP_SCHAR_MIN) {
1199 return __HIP_SCHAR_MIN;
1200 } else if (llval >= __HIP_SCHAR_MAX) {
1201 return __HIP_SCHAR_MAX;
1202 }
1203 return static_cast<signed char>(fval);
1204 }
1205
1207#if HIP_FP8_TYPE_FNUZ
1208 __FP8_HOST_DEVICE__ operator unsigned char() const {
1209#else
1210 __FP8_HOST__ operator unsigned char() const {
1211#endif
1212 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1213 return 0;
1214 }
1215
1216 float fval = *this;
1217 auto llval = static_cast<long long>(fval);
1218 if (llval <= 0) {
1219 return 0;
1220 } else if (llval >= __HIP_UCHAR_MAX) {
1221 return __HIP_UCHAR_MAX;
1222 }
1223 return static_cast<unsigned char>(fval);
1224 }
1225
1227#if HIP_FP8_TYPE_FNUZ
1228 __FP8_HOST_DEVICE__ operator unsigned int() const {
1229#else
1230 __FP8_HOST__ operator unsigned int() const {
1231#endif
1232 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1233 return 0;
1234 }
1235
1236 float fval = *this;
1237 auto llval = static_cast<long long>(fval);
1238 if (llval <= 0) {
1239 return 0;
1240 }
1241 return static_cast<unsigned int>(fval);
1242 }
1243
1245#if HIP_FP8_TYPE_FNUZ
1246 __FP8_HOST_DEVICE__ operator unsigned long int() const {
1247#else
1248 __FP8_HOST__ operator unsigned long int() const {
1249#endif
1250 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1251 return 0;
1252 }
1253
1254 float fval = *this;
1255 auto llval = static_cast<long long>(fval);
1256 if (llval <= 0) {
1257 return 0;
1258 }
1259 return static_cast<unsigned long>(fval);
1260 }
1261
1263#if HIP_FP8_TYPE_FNUZ
1264 __FP8_HOST_DEVICE__ operator unsigned long long int() const {
1265#else
1266 __FP8_HOST__ operator unsigned long long int() const {
1267#endif
1268 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1269 return 0;
1270 }
1271
1272 float fval = *this;
1273 auto llval = static_cast<long long>(fval);
1274 if (llval <= 0) {
1275 return 0;
1276 }
1277 return static_cast<unsigned long long>(fval);
1278 }
1279
1281#if HIP_FP8_TYPE_FNUZ
1282 __FP8_HOST_DEVICE__ operator unsigned short int() const {
1283#else
1284 __FP8_HOST__ operator unsigned short int() const {
1285#endif
1286 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1287 return 0;
1288 }
1289
1290 float fval = *this;
1291 auto llval = static_cast<long long>(fval);
1292 if (llval <= 0) {
1293 return 0;
1294 }
1295 return static_cast<unsigned short>(fval);
1296 }
1297};
1298
1305 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1306 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ;
1307 static constexpr unsigned int __we = 4;
1308 static constexpr unsigned int __wm = 3;
1309
1311#if HIP_FP8_TYPE_FNUZ
1312 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const double2 val)
1313#else
1314 __FP8_HOST__ __hip_fp8x2_e4m3_fnuz(const double2 val)
1315#endif
1316 : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {
1317 }
1318
1320#if HIP_FP8_TYPE_FNUZ
1321 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const float2 val)
1322#else
1323 __FP8_HOST__ __hip_fp8x2_e4m3_fnuz(const float2 val)
1324#endif
1325 : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {
1326 }
1327
1329#if HIP_FP8_TYPE_FNUZ
1330 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __hip_bfloat162 val)
1331#else
1332 __FP8_HOST__ __hip_fp8x2_e4m3_fnuz(const __hip_bfloat162 val)
1333#endif
1334 : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {
1335 }
1336
1338#if HIP_FP8_TYPE_FNUZ
1339 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __half2 val)
1340#else
1341 __FP8_HOST__ __hip_fp8x2_e4m3_fnuz(const __half2 val)
1342#endif
1343 : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {
1344 }
1345
1347#if HIP_FP8_TYPE_FNUZ
1348 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz() = default;
1349#else
1350 __FP8_HOST__ __hip_fp8x2_e4m3_fnuz() = default;
1351#endif
1352
1354#if HIP_FP8_TYPE_FNUZ
1355 __FP8_HOST_DEVICE__ operator __half2() const {
1356#else
1357 __FP8_HOST__ operator __half2() const {
1358#endif
1359 return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret));
1360 }
1361
1363#if HIP_FP8_TYPE_FNUZ
1364 __FP8_HOST_DEVICE__ operator float2() const {
1365#else
1366 __FP8_HOST__ operator float2() const {
1367#endif
1368#if HIP_FP8_CVT_FAST_PATH
1369 return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
1370#else
1371 return float2(internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x & 0xFF),
1372 __wm, __we),
1373 internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x >> 8),
1374 __wm, __we));
1375#endif
1376 }
1377};
1378
1385 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1386 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3_FNUZ;
1387 static constexpr unsigned int __we = 4;
1388 static constexpr unsigned int __wm = 3;
1389
1391#if HIP_FP8_TYPE_FNUZ
1392 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const double4 val)
1393#else
1394 __FP8_HOST__ __hip_fp8x4_e4m3_fnuz(const double4 val)
1395#endif
1396 : __x{reinterpret_cast<__hip_fp8x4_storage_t>(
1397 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1398 val.x, __default_saturation, __default_interpret)) |
1399 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1400 val.y, __default_saturation, __default_interpret))
1401 << 8 |
1402 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1403 val.z, __default_saturation, __default_interpret))
1404 << 16 |
1405 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1406 val.w, __default_saturation, __default_interpret))
1407 << 24))} {
1408 }
1409
1411#if HIP_FP8_TYPE_FNUZ
1412 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const float4 val)
1413#else
1414 __FP8_HOST__ __hip_fp8x4_e4m3_fnuz(const float4 val)
1415#endif
1416 : __x{reinterpret_cast<__hip_fp8x4_storage_t>(
1417 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1418 val.x, __default_saturation, __default_interpret)) |
1419 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1420 val.y, __default_saturation, __default_interpret))
1421 << 8 |
1422 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1423 val.z, __default_saturation, __default_interpret))
1424 << 16 |
1425 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1426 val.w, __default_saturation, __default_interpret))
1427 << 24))} {
1428 }
1429
1431#if HIP_FP8_TYPE_FNUZ
1432 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
1433#else
1434 __FP8_HOST__ __hip_fp8x4_e4m3_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
1435#endif
1436 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast<unsigned int>(
1437 reinterpret_cast<unsigned short>(
1438 __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) |
1439 reinterpret_cast<unsigned short>(
1440 __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret))
1441 << 16))) {
1442 }
1443
1445#if HIP_FP8_TYPE_FNUZ
1446 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __half2 low, const __half2 high)
1447#else
1448 __FP8_HOST__ __hip_fp8x4_e4m3_fnuz(const __half2 low, const __half2 high)
1449#endif
1450 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
1451 static_cast<unsigned int>(reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
1452 high, __default_saturation, __default_interpret)) |
1453 reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
1454 low, __default_saturation, __default_interpret))
1455 << 16))) {
1456 }
1457
1459#if HIP_FP8_TYPE_FNUZ
1460 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz() = default;
1461#else
1462 __FP8_HOST__ __hip_fp8x4_e4m3_fnuz() = default;
1463#endif
1464
1466#if HIP_FP8_TYPE_FNUZ
1467 __FP8_HOST_DEVICE__ operator float4() const {
1468#else
1469 __FP8_HOST__ operator float4() const {
1470#endif
1471 auto x = __x; // bypass const
1472 auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E
1473 auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1);
1474#if HIP_FP8_CVT_FAST_PATH
1475 float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
1476 float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
1477#else
1478 float2 high = float2(internal::cast_from_f8<float, true>(
1479 static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we),
1480 internal::cast_from_f8<float, true>(
1481 static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we));
1482 float2 low = float2(internal::cast_from_f8<float, true>(
1483 static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we),
1484 internal::cast_from_f8<float, true>(
1485 static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we));
1486#endif
1487 return float4(low.x, low.y, high.x, high.y);
1488 }
1489};
1490
1497 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1498 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ;
1499 static constexpr unsigned int __we = 5;
1500 static constexpr unsigned int __wm = 2;
1501
1502
1503 // TODO: SWDEV-452411
1504 // Add cast from unsigned long long, long long to fp8
1505
1507#if HIP_FP8_TYPE_FNUZ
1508 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const long int val)
1509#else
1510 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const long int val)
1511#endif
1512 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1513 __default_interpret)) {
1514 }
1515
1517#if HIP_FP8_TYPE_FNUZ
1518 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const int val)
1519#else
1520 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const int val)
1521#endif
1522 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1523 __default_interpret)) {
1524 }
1525
1527#if HIP_FP8_TYPE_FNUZ
1528 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const short int val)
1529#else
1530 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const short int val)
1531#endif
1532 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1533 __default_interpret)) {
1534 }
1535
1537#if HIP_FP8_TYPE_FNUZ
1538 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned long int val)
1539#else
1540 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const unsigned long int val)
1541#endif
1542 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1543 __default_interpret)) {
1544 }
1545
1547#if HIP_FP8_TYPE_FNUZ
1548 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned int val)
1549#else
1550 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const unsigned int val)
1551#endif
1552 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1553 __default_interpret)) {
1554 }
1555
1557#if HIP_FP8_TYPE_FNUZ
1558 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned short int val)
1559#else
1560 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const unsigned short int val)
1561#endif
1562 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
1563 __default_interpret)) {
1564 }
1565
1567#if HIP_FP8_TYPE_FNUZ
1568 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const double f)
1569#else
1570 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const double f)
1571#endif
1572 : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {
1573 }
1574
1576#if HIP_FP8_TYPE_FNUZ
1577 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const float f)
1578#else
1579 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const float f)
1580#endif
1581 : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {
1582 }
1583
1585#if HIP_FP8_TYPE_FNUZ
1586 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __hip_bfloat16 f)
1587#else
1588 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const __hip_bfloat16 f)
1589#endif
1590 : __x(__hip_cvt_float_to_fp8(static_cast<float>(f), __default_saturation,
1591 __default_interpret)) {
1592 }
1593
1595#if HIP_FP8_TYPE_FNUZ
1596 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __half f)
1597#else
1598 __FP8_HOST__ __hip_fp8_e5m2_fnuz(const __half f)
1599#endif
1600 : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation,
1601 __default_interpret)) {
1602 }
1603
1605#if HIP_FP8_TYPE_FNUZ
1606 __FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz() = default;
1607#else
1608 __FP8_HOST__ __hip_fp8_e5m2_fnuz() = default;
1609#endif
1610
1612#if HIP_FP8_TYPE_FNUZ
1613 __FP8_HOST_DEVICE__ operator float() const {
1614#else
1615 __FP8_HOST__ operator float() const {
1616#endif
1617#if HIP_FP8_CVT_FAST_PATH
1618 return internal::cast_to_f32_from_f8(__x, __default_interpret);
1619#else
1620 return internal::cast_from_f8<float, true>(__x, __wm, __we);
1621#endif
1622 }
1623
1625#if HIP_FP8_TYPE_FNUZ
1626 __FP8_HOST_DEVICE__ operator __half() const {
1627#else
1628 __FP8_HOST__ operator __half() const {
1629#endif
1630 return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret));
1631 }
1632
1634#if HIP_FP8_TYPE_FNUZ
1635 __FP8_HOST_DEVICE__ operator __hip_bfloat16() const {
1636#else
1637 __FP8_HOST__ operator __hip_bfloat16() const {
1638#endif
1639 float f = *this;
1640 return __hip_bfloat16(f);
1641 }
1642
1644#if HIP_FP8_TYPE_FNUZ
1645 __FP8_HOST_DEVICE__ operator bool() const {
1646#else
1647 __FP8_HOST__ operator bool() const {
1648#endif
1649 // it can be 0x00 (+0.0) since 0x80 will be nan
1650 return !(static_cast<unsigned short>(__x) == 0);
1651 }
1652
1654#if HIP_FP8_TYPE_FNUZ
1655 __FP8_HOST_DEVICE__ operator char() const {
1656#else
1657 __FP8_HOST__ operator char() const {
1658#endif
1659 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1660 return 0;
1661 }
1662
1663 float fval = *this;
1664 auto llval = static_cast<long long>(fval);
1665 if (llval <= __HIP_CHAR_MIN) {
1666 return __HIP_CHAR_MIN;
1667 } else if (llval >= __HIP_CHAR_MAX) {
1668 return __HIP_CHAR_MAX;
1669 }
1670 return static_cast<char>(fval);
1671 }
1672
1674#if HIP_FP8_TYPE_FNUZ
1675 __FP8_HOST_DEVICE__ operator double() const {
1676#else
1677 __FP8_HOST__ operator double() const {
1678#endif
1679 return internal::cast_from_f8<double, true>(__x, __wm, __we);
1680 }
1681
1683#if HIP_FP8_TYPE_FNUZ
1684 __FP8_HOST_DEVICE__ operator int() const {
1685#else
1686 __FP8_HOST__ operator int() const {
1687#endif
1688 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1689 return 0;
1690 }
1691
1692 float fval = *this;
1693 return static_cast<int>(fval);
1694 }
1695
1697#if HIP_FP8_TYPE_FNUZ
1698 __FP8_HOST_DEVICE__ operator long int() const {
1699#else
1700 __FP8_HOST__ operator long int() const {
1701#endif
1702 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1703 return 0;
1704 }
1705
1706 float fval = *this;
1707 return static_cast<long>(fval);
1708 }
1709
1711#if HIP_FP8_TYPE_FNUZ
1712 __FP8_HOST_DEVICE__ operator long long int() const {
1713#else
1714 __FP8_HOST__ operator long long int() const {
1715#endif
1716 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1717 return 0;
1718 }
1719
1720 float fval = *this;
1721 return static_cast<long long>(fval);
1722 }
1723
1725#if HIP_FP8_TYPE_FNUZ
1726 __FP8_HOST_DEVICE__ operator short int() const {
1727#else
1728 __FP8_HOST__ operator short int() const {
1729#endif
1730 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1731 return 0;
1732 }
1733
1734 float fval = *this;
1735 auto llval = static_cast<long long>(fval);
1736 if (llval <= __HIP_SHRT_MIN) {
1737 return __HIP_SHRT_MIN;
1738 } else if (llval >= __HIP_SHRT_MAX) {
1739 return __HIP_SHRT_MAX;
1740 }
1741 return static_cast<short>(fval);
1742 }
1743
1745#if HIP_FP8_TYPE_FNUZ
1746 __FP8_HOST_DEVICE__ operator signed char() const {
1747#else
1748 __FP8_HOST__ operator signed char() const {
1749#endif
1750 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1751 return 0;
1752 }
1753
1754 float fval = *this;
1755 auto llval = static_cast<long long>(fval);
1756 if (llval <= __HIP_SCHAR_MIN) {
1757 return __HIP_SCHAR_MIN;
1758 } else if (llval >= __HIP_SCHAR_MAX) {
1759 return __HIP_SCHAR_MAX;
1760 }
1761 return static_cast<signed char>(fval);
1762 }
1763
1765#if HIP_FP8_TYPE_FNUZ
1766 __FP8_HOST_DEVICE__ operator unsigned char() const {
1767#else
1768 __FP8_HOST__ operator unsigned char() const {
1769#endif
1770 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1771 return 0;
1772 }
1773
1774 float fval = *this;
1775 auto llval = static_cast<long long>(fval);
1776 if (llval <= 0) {
1777 return 0;
1778 } else if (llval >= __HIP_UCHAR_MAX) {
1779 return __HIP_UCHAR_MAX;
1780 }
1781 return static_cast<unsigned char>(fval);
1782 }
1783
1785#if HIP_FP8_TYPE_FNUZ
1786 __FP8_HOST_DEVICE__ operator unsigned int() const {
1787#else
1788 __FP8_HOST__ operator unsigned int() const {
1789#endif
1790 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1791 return 0;
1792 }
1793
1794 float fval = *this;
1795 auto llval = static_cast<long long>(fval);
1796 if (llval <= 0) {
1797 return 0;
1798 }
1799 return static_cast<unsigned int>(fval);
1800 }
1801
1803#if HIP_FP8_TYPE_FNUZ
1804 __FP8_HOST_DEVICE__ operator unsigned long int() const {
1805#else
1806 __FP8_HOST__ operator unsigned long int() const {
1807#endif
1808 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1809 return 0;
1810 }
1811
1812 float fval = *this;
1813 auto llval = static_cast<long long>(fval);
1814 if (llval <= 0) {
1815 return 0;
1816 }
1817 return static_cast<unsigned long>(fval);
1818 }
1819
1821#if HIP_FP8_TYPE_FNUZ
1822 __FP8_HOST_DEVICE__ operator unsigned long long int() const {
1823#else
1824 __FP8_HOST__ operator unsigned long long int() const {
1825#endif
1826 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1827 return 0;
1828 }
1829
1830 float fval = *this;
1831 auto llval = static_cast<long long>(fval);
1832 if (llval <= 0) {
1833 return 0;
1834 }
1835 return static_cast<unsigned long long>(fval);
1836 }
1837
1839#if HIP_FP8_TYPE_FNUZ
1840 __FP8_HOST_DEVICE__ operator unsigned short int() const {
1841#else
1842 __FP8_HOST__ operator unsigned short int() const {
1843#endif
1844 if (internal::hip_fp8_fnuz_is_nan(__x)) {
1845 return 0;
1846 }
1847
1848 float fval = *this;
1849 auto llval = static_cast<long long>(fval);
1850 if (llval <= 0) {
1851 return 0;
1852 }
1853 return static_cast<unsigned short>(fval);
1854 }
1855};
1856
1863 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1864 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ;
1865 static constexpr unsigned int __we = 5;
1866 static constexpr unsigned int __wm = 2;
1867
1869#if HIP_FP8_TYPE_FNUZ
1870 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const double2 val)
1871#else
1872 __FP8_HOST__ __hip_fp8x2_e5m2_fnuz(const double2 val)
1873#endif
1874 : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {
1875 }
1876
1878#if HIP_FP8_TYPE_FNUZ
1879 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const float2 val)
1880#else
1881 __FP8_HOST__ __hip_fp8x2_e5m2_fnuz(const float2 val)
1882#endif
1883 : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {
1884 }
1885
1887#if HIP_FP8_TYPE_FNUZ
1888 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __hip_bfloat162 val)
1889#else
1890 __FP8_HOST__ __hip_fp8x2_e5m2_fnuz(const __hip_bfloat162 val)
1891#endif
1892 : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {
1893 }
1894
1896#if HIP_FP8_TYPE_FNUZ
1897 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __half2 val)
1898#else
1899 __FP8_HOST__ __hip_fp8x2_e5m2_fnuz(const __half2 val)
1900#endif
1901 : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {
1902 }
1903
1905#if HIP_FP8_TYPE_FNUZ
1906 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz() = default;
1907#else
1908 __FP8_HOST__ __hip_fp8x2_e5m2_fnuz() = default;
1909#endif
1910
1912#if HIP_FP8_TYPE_FNUZ
1913 __FP8_HOST_DEVICE__ operator __half2() const {
1914#else
1915 __FP8_HOST__ operator __half2() const {
1916#endif
1917 return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret));
1918 }
1919
1921#if HIP_FP8_TYPE_FNUZ
1922 __FP8_HOST_DEVICE__ operator float2() const {
1923#else
1924 __FP8_HOST__ operator float2() const {
1925#endif
1926#if HIP_FP8_CVT_FAST_PATH
1927 return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
1928#else
1929 return float2(internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x & 0xFF),
1930 __wm, __we),
1931 internal::cast_from_f8<float, true>(static_cast<__hip_fp8_storage_t>(__x >> 8),
1932 __wm, __we));
1933#endif
1934 }
1935};
1936
1943 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
1944 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2_FNUZ;
1945 static constexpr unsigned int __we = 5;
1946 static constexpr unsigned int __wm = 2;
1947
1949#if HIP_FP8_TYPE_FNUZ
1950 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const double4 val)
1951#else
1952 __FP8_HOST__ __hip_fp8x4_e5m2_fnuz(const double4 val)
1953#endif
1954 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
1955 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1956 val.x, __default_saturation, __default_interpret)) |
1957 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1958 val.y, __default_saturation, __default_interpret))
1959 << 8 |
1960 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1961 val.z, __default_saturation, __default_interpret))
1962 << 16 |
1963 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
1964 val.w, __default_saturation, __default_interpret))
1965 << 24))) {
1966 }
1967
1969#if HIP_FP8_TYPE_FNUZ
1970 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const float4 val)
1971#else
1972 __FP8_HOST__ __hip_fp8x4_e5m2_fnuz(const float4 val)
1973#endif
1974 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
1975 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1976 val.x, __default_saturation, __default_interpret)) |
1977 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1978 val.y, __default_saturation, __default_interpret))
1979 << 8 |
1980 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1981 val.z, __default_saturation, __default_interpret))
1982 << 16 |
1983 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
1984 val.w, __default_saturation, __default_interpret))
1985 << 24))) {
1986 }
1987
1989#if HIP_FP8_TYPE_FNUZ
1990 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
1991#else
1992 __FP8_HOST__ __hip_fp8x4_e5m2_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
1993#endif
1994 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast<unsigned int>(
1995 reinterpret_cast<unsigned short>(
1996 __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) |
1997 reinterpret_cast<unsigned short>(
1998 __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret))
1999 << 16))) {
2000 }
2001
2003#if HIP_FP8_TYPE_FNUZ
2004 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __half2 low, const __half2 high)
2005#else
2006 __FP8_HOST__ __hip_fp8x4_e5m2_fnuz(const __half2 low, const __half2 high)
2007#endif
2008 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
2009 static_cast<unsigned int>(reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
2010 high, __default_saturation, __default_interpret)) |
2011 reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
2012 low, __default_saturation, __default_interpret))
2013 << 16))) {
2014 }
2015
2016 /* default construct fp8x4 e5m2 */
2017#if HIP_FP8_TYPE_FNUZ
2018 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz() = default;
2019#else
2020 __FP8_HOST__ __hip_fp8x4_e5m2_fnuz() = default;
2021#endif
2022
2024#if HIP_FP8_TYPE_FNUZ
2025 __FP8_HOST_DEVICE__ operator float4() const {
2026#else
2027 __FP8_HOST__ operator float4() const {
2028#endif
2029 auto x = __x; // bypass const
2030 auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E
2031 auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1);
2032#if HIP_FP8_CVT_FAST_PATH
2033 float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
2034 float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
2035#else
2036 float2 high = float2(internal::cast_from_f8<float, true>(
2037 static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we),
2038 internal::cast_from_f8<float, true>(
2039 static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we));
2040 float2 low = float2(internal::cast_from_f8<float, true>(
2041 static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we),
2042 internal::cast_from_f8<float, true>(
2043 static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we));
2044#endif
2045 return float4(low.x, low.y, high.x, high.y);
2046 }
2047};
2048
2049#endif // ENABLE_FNUZ_HIPRTC
2050
2056#if !defined(ENABLE_OCP_HIPRTC) || ENABLE_OCP_HIPRTC
2057
2060 constexpr static __hip_saturation_t __default_saturation = __HIP_SATFINITE;
2061 constexpr static __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3;
2062 constexpr static unsigned int __we = 4;
2063 constexpr static unsigned int __wm = 3;
2064
2065 // TODO: SWDEV-452411
2066 // Add cast from unsigned long long, long long to fp8
2067
2069#if HIP_FP8_TYPE_OCP
2070 __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const long int val)
2071#else
2072 __FP8_HOST__ __hip_fp8_e4m3(const long int val)
2073#endif
2074 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
2075 __default_interpret)) {
2076 }
2077
2079#if HIP_FP8_TYPE_OCP
2080 __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const int val)
2081#else
2082 __FP8_HOST__ __hip_fp8_e4m3(const int val)
2083#endif
2084 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
2085 __default_interpret)) {
2086 }
2087
2089 __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const short int val)
2090 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
2091 __default_interpret)) {}
2092
2094#if HIP_FP8_TYPE_OCP
2095 __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned long int val)
2096#else
2097 __FP8_HOST__ __hip_fp8_e4m3(const unsigned long int val)
2098#endif
2099 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
2100 __default_interpret)) {
2101 }
2102
2104#if HIP_FP8_TYPE_OCP
2105 __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned int val)
2106#else
2107 __FP8_HOST__ __hip_fp8_e4m3(const unsigned int val)
2108#endif
2109 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
2110 __default_interpret)) {
2111 }
2112
2114#if HIP_FP8_TYPE_OCP
2115 __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned short int val)
2116#else
2117 __FP8_HOST__ __hip_fp8_e4m3(const unsigned short int val)
2118#endif
2119 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
2120 __default_interpret)) {
2121 }
2122
2124#if HIP_FP8_TYPE_OCP
2125 __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const double f)
2126#else
2127 __FP8_HOST__ __hip_fp8_e4m3(const double f)
2128#endif
2129 : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {
2130 }
2131
2133#if HIP_FP8_TYPE_OCP
2134 __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const float f)
2135#else
2136 __FP8_HOST__ __hip_fp8_e4m3(const float f)
2137#endif
2138 : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {
2139 }
2140
2142#if HIP_FP8_TYPE_OCP
2143 __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const __hip_bfloat16 f)
2144#else
2145 __FP8_HOST__ __hip_fp8_e4m3(const __hip_bfloat16 f)
2146#endif
2147 : __x(__hip_cvt_float_to_fp8(static_cast<float>(f), __default_saturation,
2148 __default_interpret)) {
2149 }
2150
2152#if HIP_FP8_TYPE_OCP
2153 __FP8_HOST_DEVICE__ __hip_fp8_e4m3(const __half f)
2154#else
2155 __FP8_HOST__ __hip_fp8_e4m3(const __half f)
2156#endif
2157 : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation,
2158 __default_interpret)) {
2159 }
2160
2162#if HIP_FP8_TYPE_OCP
2163 __FP8_HOST_DEVICE__ __hip_fp8_e4m3() = default;
2164#else
2165 __FP8_HOST__ __hip_fp8_e4m3() = default;
2166#endif
2167
2170#if HIP_FP8_TYPE_OCP
2171 __FP8_HOST_DEVICE__ operator __half() const {
2172#else
2173 __FP8_HOST__ operator __half() const {
2174#endif
2175 return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret));
2176 }
2177
2179#if HIP_FP8_TYPE_OCP
2180 __FP8_HOST_DEVICE__ operator __hip_bfloat16() const {
2181#else
2182 __FP8_HOST__ operator __hip_bfloat16() const {
2183#endif
2184 float f = *this;
2185 return __hip_bfloat16(f);
2186 }
2187
2189#if HIP_FP8_TYPE_OCP
2190 __FP8_HOST_DEVICE__ operator bool() const {
2191#else
2192 __FP8_HOST__ operator bool() const {
2193#endif
2194 // it can be 0x00 (+0.0) since 0x80 will be nan
2195 return !(static_cast<unsigned short>(__x) == 0 || static_cast<unsigned short>(__x) == 0x80);
2196 }
2197
2199#if HIP_FP8_TYPE_OCP
2200 __FP8_HOST_DEVICE__ operator char() const {
2201#else
2202 __FP8_HOST__ operator char() const {
2203#endif
2204 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2205 return 0;
2206 }
2207
2208 auto fval = internal::cast_from_f8<float, false>(__x, __wm, __we);
2209 auto llval = static_cast<long long>(fval);
2210 if (llval <= __HIP_CHAR_MIN) {
2211 return __HIP_CHAR_MIN;
2212 } else if (llval >= __HIP_CHAR_MAX) {
2213 return __HIP_CHAR_MAX;
2214 }
2215 return static_cast<char>(fval);
2216 }
2217
2219#if HIP_FP8_TYPE_OCP
2220 __FP8_HOST_DEVICE__ operator double() const {
2221#else
2222 __FP8_HOST__ operator double() const {
2223#endif
2224 return internal::cast_from_f8<double, false>(__x, __wm, __we);
2225 }
2226
2228#if HIP_FP8_TYPE_OCP
2229 __FP8_HOST_DEVICE__ operator float() const {
2230#else
2231 __FP8_HOST__ operator float() const {
2232#endif
2233#if HIP_FP8_CVT_FAST_PATH
2234 return internal::cast_to_f32_from_f8(__x, __default_interpret);
2235#else
2236 return internal::cast_from_f8<float, false>(__x, __wm, __we);
2237#endif
2238 }
2239
2241#if HIP_FP8_TYPE_OCP
2242 __FP8_HOST_DEVICE__ operator int() const {
2243#else
2244 __FP8_HOST__ operator int() const {
2245#endif
2246 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2247 return 0;
2248 }
2249
2250 float fval = *this;
2251 return static_cast<int>(fval);
2252 }
2253
2255#if HIP_FP8_TYPE_OCP
2256 __FP8_HOST_DEVICE__ operator long int() const {
2257#else
2258 __FP8_HOST__ operator long int() const {
2259#endif
2260 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2261 return 0;
2262 }
2263
2264 float fval = *this;
2265 return static_cast<long>(fval);
2266 }
2267
2269#if HIP_FP8_TYPE_OCP
2270 __FP8_HOST_DEVICE__ operator long long int() const {
2271#else
2272 __FP8_HOST__ operator long long int() const {
2273#endif
2274 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2275 return 0;
2276 }
2277
2278 float fval = *this;
2279 return static_cast<long long>(fval);
2280 }
2281
2283#if HIP_FP8_TYPE_OCP
2284 __FP8_HOST_DEVICE__ operator short int() const {
2285#else
2286 __FP8_HOST__ operator short int() const {
2287#endif
2288 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2289 return 0;
2290 }
2291
2292 float fval = *this;
2293 auto llval = static_cast<long long>(fval);
2294 if (llval <= __HIP_SHRT_MIN) {
2295 return __HIP_SHRT_MIN;
2296 } else if (llval >= __HIP_SHRT_MAX) {
2297 return __HIP_SHRT_MAX;
2298 }
2299 return static_cast<short>(fval);
2300 }
2301
2303#if HIP_FP8_TYPE_OCP
2304 __FP8_HOST_DEVICE__ operator signed char() const {
2305#else
2306 __FP8_HOST__ operator signed char() const {
2307#endif
2308 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2309 return 0;
2310 }
2311
2312 float fval = *this;
2313 auto llval = static_cast<long long>(fval);
2314 if (llval <= __HIP_SCHAR_MIN) {
2315 return __HIP_SCHAR_MIN;
2316 } else if (llval >= __HIP_SCHAR_MAX) {
2317 return __HIP_SCHAR_MAX;
2318 }
2319 return static_cast<signed char>(fval);
2320 }
2321
2323#if HIP_FP8_TYPE_OCP
2324 __FP8_HOST_DEVICE__ operator unsigned char() const {
2325#else
2326 __FP8_HOST__ operator unsigned char() const {
2327#endif
2328 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2329 return 0;
2330 }
2331
2332 float fval = *this;
2333 auto llval = static_cast<long long>(fval);
2334 if (llval <= 0) {
2335 return 0;
2336 } else if (llval >= __HIP_UCHAR_MAX) {
2337 return __HIP_UCHAR_MAX;
2338 }
2339 return static_cast<unsigned char>(fval);
2340 }
2341
2343#if HIP_FP8_TYPE_OCP
2344 __FP8_HOST_DEVICE__ operator unsigned int() const {
2345#else
2346 __FP8_HOST__ operator unsigned int() const {
2347#endif
2348 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2349 return 0;
2350 }
2351
2352 float fval = *this;
2353 auto llval = static_cast<long long>(fval);
2354 if (llval <= 0) {
2355 return 0;
2356 }
2357 return static_cast<unsigned int>(fval);
2358 }
2359
2361#if HIP_FP8_TYPE_OCP
2362 __FP8_HOST_DEVICE__ operator unsigned long int() const {
2363#else
2364 __FP8_HOST__ operator unsigned long int() const {
2365#endif
2366 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2367 return 0;
2368 }
2369
2370 float fval = *this;
2371 auto llval = static_cast<long long>(fval);
2372 if (llval <= 0) {
2373 return 0;
2374 }
2375 return static_cast<unsigned long>(fval);
2376 }
2377
2379#if HIP_FP8_TYPE_OCP
2380 __FP8_HOST_DEVICE__ operator unsigned long long int() const {
2381#else
2382 __FP8_HOST__ operator unsigned long long int() const {
2383#endif
2384 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2385 return 0;
2386 }
2387
2388 float fval = *this;
2389 auto llval = static_cast<long long>(fval);
2390 if (llval <= 0) {
2391 return 0;
2392 }
2393 return static_cast<unsigned long long>(fval);
2394 }
2395
2397#if HIP_FP8_TYPE_OCP
2398 __FP8_HOST_DEVICE__ operator unsigned short int() const {
2399#else
2400 __FP8_HOST__ operator unsigned short int() const {
2401#endif
2402 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2403 return 0;
2404 }
2405
2406 float fval = *this;
2407 auto llval = static_cast<long long>(fval);
2408 if (llval <= 0) {
2409 return 0;
2410 }
2411 return static_cast<unsigned short>(fval);
2412 }
2413};
2414
2421 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
2422 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3;
2423 static constexpr unsigned int __we = 4;
2424 static constexpr unsigned int __wm = 3;
2425
2428#if HIP_FP8_TYPE_OCP
2429 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const double2 val)
2430#else
2431 __FP8_HOST__ __hip_fp8x2_e4m3(const double2 val)
2432#endif
2433 : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {
2434 }
2435
2437#if HIP_FP8_TYPE_OCP
2438 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const float2 val)
2439#else
2440 __FP8_HOST__ __hip_fp8x2_e4m3(const float2 val)
2441#endif
2442 : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {
2443 }
2444
2446#if HIP_FP8_TYPE_OCP
2447 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const __hip_bfloat162 val)
2448#else
2449 __FP8_HOST__ __hip_fp8x2_e4m3(const __hip_bfloat162 val)
2450#endif
2451 : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {
2452 }
2453
2455#if HIP_FP8_TYPE_OCP
2456 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const __half2 val)
2457#else
2458 __FP8_HOST__ __hip_fp8x2_e4m3(const __half2 val)
2459#endif
2460 : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {
2461 }
2462
2464#if HIP_FP8_TYPE_OCP
2465 __FP8_HOST_DEVICE__ __hip_fp8x2_e4m3() = default;
2466#else
2467 __FP8_HOST__ __hip_fp8x2_e4m3() = default;
2468#endif
2469
2471#if HIP_FP8_TYPE_OCP
2472 __FP8_HOST_DEVICE__ operator __half2() const {
2473#else
2474 __FP8_HOST__ operator __half2() const {
2475#endif
2476 return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret));
2477 }
2478
2480#if HIP_FP8_TYPE_OCP
2481 __FP8_HOST_DEVICE__ operator float2() const {
2482#else
2483 __FP8_HOST__ operator float2() const {
2484#endif
2485#if HIP_FP8_CVT_FAST_PATH
2486 return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
2487#else
2488 return float2(internal::cast_from_f8<float, false>(static_cast<__hip_fp8_storage_t>(__x & 0xFF),
2489 __wm, __we),
2490 internal::cast_from_f8<float, false>(static_cast<__hip_fp8_storage_t>(__x >> 8),
2491 __wm, __we));
2492#endif
2493 }
2494};
2495
2502 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
2503 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E4M3;
2504 static constexpr unsigned int __we = 4;
2505 static constexpr unsigned int __wm = 3;
2506
2509#if HIP_FP8_TYPE_OCP
2510 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const double4 val)
2511#else
2512 __FP8_HOST__ __hip_fp8x4_e4m3(const double4 val)
2513#endif
2514 : __x{reinterpret_cast<__hip_fp8x4_storage_t>(
2515 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
2516 val.x, __default_saturation, __default_interpret)) |
2517 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
2518 val.y, __default_saturation, __default_interpret))
2519 << 8 |
2520 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
2521 val.z, __default_saturation, __default_interpret))
2522 << 16 |
2523 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
2524 val.w, __default_saturation, __default_interpret))
2525 << 24))} {
2526 }
2527
2529#if HIP_FP8_TYPE_OCP
2530 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const float4 val)
2531#else
2532 __FP8_HOST__ __hip_fp8x4_e4m3(const float4 val)
2533#endif
2534 : __x{reinterpret_cast<__hip_fp8x4_storage_t>(
2535 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
2536 val.x, __default_saturation, __default_interpret)) |
2537 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
2538 val.y, __default_saturation, __default_interpret))
2539 << 8 |
2540 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
2541 val.z, __default_saturation, __default_interpret))
2542 << 16 |
2543 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
2544 val.w, __default_saturation, __default_interpret))
2545 << 24))} {
2546 }
2547
2549#if HIP_FP8_TYPE_OCP
2550 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const __hip_bfloat162 low, const __hip_bfloat162 high)
2551#else
2552 __FP8_HOST__ __hip_fp8x4_e4m3(const __hip_bfloat162 low, const __hip_bfloat162 high)
2553#endif
2554 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast<unsigned int>(
2555 reinterpret_cast<unsigned short>(
2556 __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) |
2557 reinterpret_cast<unsigned short>(
2558 __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret))
2559 << 16))) {
2560 }
2561
2563#if HIP_FP8_TYPE_OCP
2564 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const __half2 low, const __half2 high)
2565#else
2566 __FP8_HOST__ __hip_fp8x4_e4m3(const __half2 low, const __half2 high)
2567#endif
2568 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
2569 static_cast<unsigned int>(reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
2570 high, __default_saturation, __default_interpret)) |
2571 reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
2572 low, __default_saturation, __default_interpret))
2573 << 16))) {
2574 }
2575
2577#if HIP_FP8_TYPE_OCP
2578 __FP8_HOST_DEVICE__ __hip_fp8x4_e4m3() = default;
2579#else
2580 __FP8_HOST__ __hip_fp8x4_e4m3() = default;
2581#endif
2582
2584#if HIP_FP8_TYPE_OCP
2585 __FP8_HOST_DEVICE__ operator float4() const {
2586#else
2587 __FP8_HOST__ operator float4() const {
2588#endif
2589 auto x = __x; // bypass const
2590 auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E
2591 auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1);
2592#if HIP_FP8_CVT_FAST_PATH
2593 float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
2594 float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
2595#else
2596 float2 high = float2(internal::cast_from_f8<float, false>(
2597 static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we),
2598 internal::cast_from_f8<float, false>(
2599 static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8), __wm, __we));
2600 float2 low = float2(internal::cast_from_f8<float, false>(
2601 static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we),
2602 internal::cast_from_f8<float, false>(
2603 static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm, __we));
2604#endif
2605 return float4(low.x, low.y, high.x, high.y);
2606 }
2607};
2608
2615 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
2616 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2;
2617 static constexpr unsigned int __we = 5;
2618 static constexpr unsigned int __wm = 2;
2619
2620
2621 // TODO: SWDEV-452411
2622 // Add cast from unsigned long long, long long to fp8
2623
2626#if HIP_FP8_TYPE_OCP
2627 __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const long int val)
2628#else
2629 __FP8_HOST__ __hip_fp8_e5m2(const long int val)
2630#endif
2631 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
2632 __default_interpret)) {
2633 }
2634
2636#if HIP_FP8_TYPE_OCP
2637 __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const int val)
2638#else
2639 __FP8_HOST__ __hip_fp8_e5m2(const int val)
2640#endif
2641 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
2642 __default_interpret)) {
2643 }
2644
2646#if HIP_FP8_TYPE_OCP
2647 __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const short int val)
2648#else
2649 __FP8_HOST__ __hip_fp8_e5m2(const short int val)
2650#endif
2651 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
2652 __default_interpret)) {
2653 }
2654
2656#if HIP_FP8_TYPE_OCP
2657 __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned long int val)
2658#else
2659 __FP8_HOST__ __hip_fp8_e5m2(const unsigned long int val)
2660#endif
2661 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
2662 __default_interpret)) {
2663 }
2664
2666#if HIP_FP8_TYPE_OCP
2667 __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned int val)
2668#else
2669 __FP8_HOST__ __hip_fp8_e5m2(const unsigned int val)
2670#endif
2671 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
2672 __default_interpret)) {
2673 }
2674
2676#if HIP_FP8_TYPE_OCP
2677 __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned short int val)
2678#else
2679 __FP8_HOST__ __hip_fp8_e5m2(const unsigned short int val)
2680#endif
2681 : __x(__hip_cvt_float_to_fp8(static_cast<float>(val), __default_saturation,
2682 __default_interpret)) {
2683 }
2684
2686#if HIP_FP8_TYPE_OCP
2687 __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const double f)
2688#else
2689 __FP8_HOST__ __hip_fp8_e5m2(const double f)
2690#endif
2691 : __x(__hip_cvt_double_to_fp8(f, __default_saturation, __default_interpret)) {
2692 }
2693
2695#if HIP_FP8_TYPE_OCP
2696 __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const float f)
2697#else
2698 __FP8_HOST__ __hip_fp8_e5m2(const float f)
2699#endif
2700 : __x(__hip_cvt_float_to_fp8(f, __default_saturation, __default_interpret)) {
2701 }
2702
2704#if HIP_FP8_TYPE_OCP
2705 __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const __hip_bfloat16 f)
2706#else
2707 __FP8_HOST__ __hip_fp8_e5m2(const __hip_bfloat16 f)
2708#endif
2709 : __x(__hip_cvt_float_to_fp8(static_cast<float>(f), __default_saturation,
2710 __default_interpret)) {
2711 }
2712
2714#if HIP_FP8_TYPE_OCP
2715 __FP8_HOST_DEVICE__ __hip_fp8_e5m2(const __half f)
2716#else
2717 __FP8_HOST__ __hip_fp8_e5m2(const __half f)
2718#endif
2719 : __x(__hip_cvt_halfraw_to_fp8(static_cast<__half_raw>(f), __default_saturation,
2720 __default_interpret)) {
2721 }
2722
2724#if HIP_FP8_TYPE_OCP
2725 __FP8_HOST_DEVICE__ __hip_fp8_e5m2() = default;
2726#else
2727 __FP8_HOST__ __hip_fp8_e5m2() = default;
2728#endif
2729
2731#if HIP_FP8_TYPE_OCP
2732 __FP8_HOST_DEVICE__ operator float() const {
2733#else
2734 __FP8_HOST__ operator float() const {
2735#endif
2736#if HIP_FP8_CVT_FAST_PATH
2737 return internal::cast_to_f32_from_f8(__x, __default_interpret);
2738#else
2739 return internal::cast_from_f8<float, false>(__x, __wm, __we,
2740 __default_saturation == __HIP_SATFINITE);
2741#endif
2742 }
2743
2745#if HIP_FP8_TYPE_OCP
2746 __FP8_HOST_DEVICE__ operator __half() const {
2747#else
2748 __FP8_HOST__ operator __half() const {
2749#endif
2750 return __half(__hip_cvt_fp8_to_halfraw(__x, __default_interpret));
2751 }
2752
2754#if HIP_FP8_TYPE_OCP
2755 __FP8_HOST_DEVICE__ operator __hip_bfloat16() const {
2756#else
2757 __FP8_HOST__ operator __hip_bfloat16() const {
2758#endif
2759 float f = *this;
2760 return __hip_bfloat16(f);
2761 }
2762
2764#if HIP_FP8_TYPE_OCP
2765 __FP8_HOST_DEVICE__ operator bool() const {
2766#else
2767 __FP8_HOST__ operator bool() const {
2768#endif
2769 // it can be 0x00 (+0.0) since 0x80 will be nan
2770 return !(static_cast<unsigned short>(__x) == 0 || static_cast<unsigned short>(__x) == 0x80);
2771 }
2772
2774#if HIP_FP8_TYPE_OCP
2775 __FP8_HOST_DEVICE__ operator char() const {
2776#else
2777 __FP8_HOST__ operator char() const {
2778#endif
2779 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2780 return 0;
2781 }
2782
2783 float fval = *this;
2784 auto llval = static_cast<long long>(fval);
2785 if (llval <= __HIP_CHAR_MIN) {
2786 return __HIP_CHAR_MIN;
2787 } else if (llval >= __HIP_CHAR_MAX) {
2788 return __HIP_CHAR_MAX;
2789 }
2790 return static_cast<char>(fval);
2791 }
2792
2794#if HIP_FP8_TYPE_OCP
2795 __FP8_HOST_DEVICE__ operator double() const {
2796#else
2797 __FP8_HOST__ operator double() const {
2798#endif
2799 return internal::cast_from_f8<double, false>(__x, __wm, __we,
2800 __default_saturation == __HIP_SATFINITE);
2801 }
2802
2804#if HIP_FP8_TYPE_OCP
2805 __FP8_HOST_DEVICE__ operator int() const {
2806#else
2807 __FP8_HOST__ operator int() const {
2808#endif
2809 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2810 return 0;
2811 }
2812
2813 float fval = *this;
2814 return static_cast<int>(fval);
2815 }
2816
2818#if HIP_FP8_TYPE_OCP
2819 __FP8_HOST_DEVICE__ operator long int() const {
2820#else
2821 __FP8_HOST__ operator long int() const {
2822#endif
2823 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2824 return 0;
2825 }
2826
2827 float fval = *this;
2828 return static_cast<long>(fval);
2829 }
2830
2832#if HIP_FP8_TYPE_OCP
2833 __FP8_HOST_DEVICE__ operator long long int() const {
2834#else
2835 __FP8_HOST__ operator long long int() const {
2836#endif
2837 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2838 return 0;
2839 }
2840
2841 float fval = *this;
2842 return static_cast<long long>(fval);
2843 }
2844
2846#if HIP_FP8_TYPE_OCP
2847 __FP8_HOST_DEVICE__ operator short int() const {
2848#else
2849 __FP8_HOST__ operator short int() const {
2850#endif
2851 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2852 return 0;
2853 }
2854
2855 float fval = *this;
2856 auto llval = static_cast<long long>(fval);
2857 if (llval <= __HIP_SHRT_MIN) {
2858 return __HIP_SHRT_MIN;
2859 } else if (llval >= __HIP_SHRT_MAX) {
2860 return __HIP_SHRT_MAX;
2861 }
2862 return static_cast<short>(fval);
2863 }
2864
2866#if HIP_FP8_TYPE_OCP
2867 __FP8_HOST_DEVICE__ operator signed char() const {
2868#else
2869 __FP8_HOST__ operator signed char() const {
2870#endif
2871 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2872 return 0;
2873 }
2874
2875 float fval = *this;
2876 auto llval = static_cast<long long>(fval);
2877 if (llval <= __HIP_SCHAR_MIN) {
2878 return __HIP_SCHAR_MIN;
2879 } else if (llval >= __HIP_SCHAR_MAX) {
2880 return __HIP_SCHAR_MAX;
2881 }
2882 return static_cast<signed char>(fval);
2883 }
2884
2886#if HIP_FP8_TYPE_OCP
2887 __FP8_HOST_DEVICE__ operator unsigned char() const {
2888#else
2889 __FP8_HOST__ operator unsigned char() const {
2890#endif
2891 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2892 return 0;
2893 }
2894
2895 float fval = *this;
2896 auto llval = static_cast<long long>(fval);
2897 if (llval <= 0) {
2898 return 0;
2899 } else if (llval >= __HIP_UCHAR_MAX) {
2900 return __HIP_UCHAR_MAX;
2901 }
2902 return static_cast<unsigned char>(fval);
2903 }
2904
2906#if HIP_FP8_TYPE_OCP
2907 __FP8_HOST_DEVICE__ operator unsigned int() const {
2908#else
2909 __FP8_HOST__ operator unsigned int() const {
2910#endif
2911 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2912 return 0;
2913 }
2914
2915 float fval = *this;
2916 auto llval = static_cast<long long>(fval);
2917 if (llval <= 0) {
2918 return 0;
2919 }
2920 return static_cast<unsigned int>(fval);
2921 }
2922
2924#if HIP_FP8_TYPE_OCP
2925 __FP8_HOST_DEVICE__ operator unsigned long int() const {
2926#else
2927 __FP8_HOST__ operator unsigned long int() const {
2928#endif
2929 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2930 return 0;
2931 }
2932
2933 float fval = *this;
2934 auto llval = static_cast<long long>(fval);
2935 if (llval <= 0) {
2936 return 0;
2937 }
2938 return static_cast<unsigned long>(fval);
2939 }
2940
2942#if HIP_FP8_TYPE_OCP
2943 __FP8_HOST_DEVICE__ operator unsigned long long int() const {
2944#else
2945 __FP8_HOST__ operator unsigned long long int() const {
2946#endif
2947 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2948 return 0;
2949 }
2950
2951 float fval = *this;
2952 auto llval = static_cast<long long>(fval);
2953 if (llval <= 0) {
2954 return 0;
2955 }
2956 return static_cast<unsigned long long>(fval);
2957 }
2958
2960#if HIP_FP8_TYPE_OCP
2961 __FP8_HOST_DEVICE__ operator unsigned short int() const {
2962#else
2963 __FP8_HOST__ operator unsigned short int() const {
2964#endif
2965 if (internal::hip_fp8_ocp_is_nan(__x, __default_interpret)) {
2966 return 0;
2967}
2968
2969 float fval = *this;
2970 auto llval = static_cast<long long>(fval);
2971 if (llval <= 0) {
2972 return 0;
2973 }
2974 return static_cast<unsigned short>(fval);
2975 }
2976};
2977
2984 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
2985 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2;
2986 static constexpr unsigned int __we = 5;
2987 static constexpr unsigned int __wm = 2;
2988
2991#if HIP_FP8_TYPE_OCP
2992 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const double2 val)
2993#else
2994 __FP8_HOST__ __hip_fp8x2_e5m2(const double2 val)
2995#endif
2996 : __x(__hip_cvt_double2_to_fp8x2(val, __default_saturation, __default_interpret)) {
2997 }
2998
3000#if HIP_FP8_TYPE_OCP
3001 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const float2 val)
3002#else
3003 __FP8_HOST__ __hip_fp8x2_e5m2(const float2 val)
3004#endif
3005 : __x(__hip_cvt_float2_to_fp8x2(val, __default_saturation, __default_interpret)) {
3006 }
3007
3009#if HIP_FP8_TYPE_OCP
3010 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const __hip_bfloat162 val)
3011#else
3012 __FP8_HOST__ __hip_fp8x2_e5m2(const __hip_bfloat162 val)
3013#endif
3014 : __x(__hip_cvt_bfloat16raw2_to_fp8x2(val, __default_saturation, __default_interpret)) {
3015 }
3016
3018#if HIP_FP8_TYPE_OCP
3019 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const __half2 val)
3020#else
3021 __FP8_HOST__ __hip_fp8x2_e5m2(const __half2 val)
3022#endif
3023 : __x(__hip_cvt_halfraw2_to_fp8x2(val, __default_saturation, __default_interpret)) {
3024 }
3025
3027#if HIP_FP8_TYPE_OCP
3028 __FP8_HOST_DEVICE__ __hip_fp8x2_e5m2() = default;
3029#else
3030 __FP8_HOST__ __hip_fp8x2_e5m2() = default;
3031#endif
3032
3034#if HIP_FP8_TYPE_OCP
3035 __FP8_HOST_DEVICE__ operator __half2() const {
3036#else
3037 __FP8_HOST__ operator __half2() const {
3038#endif
3039 return __half2(__hip_cvt_fp8x2_to_halfraw2(__x, __default_interpret));
3040 }
3041
3043#if HIP_FP8_TYPE_OCP
3044 __FP8_HOST_DEVICE__ operator float2() const {
3045#else
3046 __FP8_HOST__ operator float2() const {
3047#endif
3048#if HIP_FP8_CVT_FAST_PATH
3049 return internal::cast_to_f32x2_from_f8x2(__x, __default_interpret);
3050#else
3051 return float2(
3052 internal::cast_from_f8<float, false>(static_cast<__hip_fp8_storage_t>(__x & 0xFF), __wm,
3053 __we, __default_saturation == __HIP_SATFINITE),
3054 internal::cast_from_f8<float, false>(static_cast<__hip_fp8_storage_t>(__x >> 8), __wm, __we,
3055 __default_saturation == __HIP_SATFINITE));
3056#endif
3057 }
3058};
3059
3066 static constexpr __hip_saturation_t __default_saturation = __HIP_SATFINITE;
3067 static constexpr __hip_fp8_interpretation_t __default_interpret = __HIP_E5M2;
3068 static constexpr unsigned int __we = 5;
3069 static constexpr unsigned int __wm = 2;
3070
3072#if HIP_FP8_TYPE_OCP
3073 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const double4 val)
3074#else
3075 __FP8_HOST__ __hip_fp8x4_e5m2(const double4 val)
3076#endif
3077 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
3078 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
3079 val.x, __default_saturation, __default_interpret)) |
3080 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
3081 val.y, __default_saturation, __default_interpret))
3082 << 8 |
3083 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
3084 val.z, __default_saturation, __default_interpret))
3085 << 16 |
3086 reinterpret_cast<unsigned char>(__hip_cvt_double_to_fp8(
3087 val.w, __default_saturation, __default_interpret))
3088 << 24))) {
3089 }
3090
3092#if HIP_FP8_TYPE_OCP
3093 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const float4 val)
3094#else
3095 __FP8_HOST__ __hip_fp8x4_e5m2(const float4 val)
3096#endif
3097 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
3098 static_cast<unsigned int>(reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
3099 val.x, __default_saturation, __default_interpret)) |
3100 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
3101 val.y, __default_saturation, __default_interpret))
3102 << 8 |
3103 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
3104 val.z, __default_saturation, __default_interpret))
3105 << 16 |
3106 reinterpret_cast<unsigned char>(__hip_cvt_float_to_fp8(
3107 val.w, __default_saturation, __default_interpret))
3108 << 24))) {
3109 }
3110
3112#if HIP_FP8_TYPE_OCP
3113 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const __hip_bfloat162 low, const __hip_bfloat162 high)
3114#else
3115 __FP8_HOST__ __hip_fp8x4_e5m2(const __hip_bfloat162 low, const __hip_bfloat162 high)
3116#endif
3117 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(static_cast<unsigned int>(
3118 reinterpret_cast<unsigned short>(
3119 __hip_cvt_bfloat16raw2_to_fp8x2(high, __default_saturation, __default_interpret)) |
3120 reinterpret_cast<unsigned short>(
3121 __hip_cvt_bfloat16raw2_to_fp8x2(low, __default_saturation, __default_interpret))
3122 << 16))) {
3123 }
3124
3126#if HIP_FP8_TYPE_OCP
3127 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const __half2 low, const __half2 high)
3128#else
3129 __FP8_HOST__ __hip_fp8x4_e5m2(const __half2 low, const __half2 high)
3130#endif
3131 : __x(reinterpret_cast<__hip_fp8x4_storage_t>(
3132 static_cast<unsigned int>(reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
3133 high, __default_saturation, __default_interpret)) |
3134 reinterpret_cast<unsigned short>(__hip_cvt_halfraw2_to_fp8x2(
3135 low, __default_saturation, __default_interpret))
3136 << 16))) {
3137 }
3138
3139 /* default construct fp8x4 e5m2 */
3140#if HIP_FP8_TYPE_OCP
3141 __FP8_HOST_DEVICE__ __hip_fp8x4_e5m2() = default;
3142#else
3143 __FP8_HOST__ __hip_fp8x4_e5m2() = default;
3144#endif
3145
3147#if HIP_FP8_TYPE_OCP
3148 __FP8_HOST_DEVICE__ operator float4() const {
3149#else
3150 __FP8_HOST__ operator float4() const {
3151#endif
3152 auto x = __x; // bypass const
3153 auto fp8x2_low = *reinterpret_cast<__hip_fp8x2_storage_t*>(&x); // Little E
3154 auto fp8x2_high = *(reinterpret_cast<__hip_fp8x2_storage_t*>(&x) + 1);
3155#if HIP_FP8_CVT_FAST_PATH
3156 float2 high = internal::cast_to_f32x2_from_f8x2(fp8x2_high, __default_interpret);
3157 float2 low = internal::cast_to_f32x2_from_f8x2(fp8x2_low, __default_interpret);
3158#else
3159 float2 high = float2(
3160 internal::cast_from_f8<float, false>(
3161 static_cast<__hip_fp8_storage_t>((fp8x2_high << 8) >> 8), __wm, __we,
3162 __default_saturation == __HIP_SATFINITE),
3163 internal::cast_from_f8<float, false>(static_cast<__hip_fp8_storage_t>(fp8x2_high >> 8),
3164 __wm, __we, __default_saturation == __HIP_SATFINITE));
3165 float2 low = float2(
3166 internal::cast_from_f8<float, false>(
3167 static_cast<__hip_fp8_storage_t>((fp8x2_low << 8) >> 8), __wm, __we,
3168 __default_saturation == __HIP_SATFINITE),
3169 internal::cast_from_f8<float, false>(static_cast<__hip_fp8_storage_t>(fp8x2_low >> 8), __wm,
3170 __we, __default_saturation == __HIP_SATFINITE));
3171#endif
3172 return float4(low.x, low.y, high.x, high.y);
3173 }
3174};
3175#endif // ENABLE_OCP_HIPRTC
3176#endif // _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP8_H_
__hip_saturation_t
Describes saturation behavior.
Definition amd_hip_fp8.h:132
@ __HIP_SATFINITE
Definition amd_hip_fp8.h:134
@ __HIP_NOSAT
Definition amd_hip_fp8.h:133
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_double_to_fp8(const double d, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp)
convert double to __hip_fp8_storage_t
Definition amd_hip_fp8.h:741
__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_bfloat16raw2_to_fp8x2(const __hip_bfloat162_raw hr, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp)
convert __hip_bfloat162_raw to __hip_fp8x2_storage_t
Definition amd_hip_fp8.h:823
__hip_fp8_interpretation_t
Describes FP8 interpretation.
Definition amd_hip_fp8.h:122
@ __HIP_E4M3_FNUZ
Definition amd_hip_fp8.h:125
@ __HIP_E5M2
Definition amd_hip_fp8.h:124
@ __HIP_E4M3
Definition amd_hip_fp8.h:123
@ __HIP_E5M2_FNUZ
Definition amd_hip_fp8.h:126
__FP8_HOST_DEVICE_STATIC__ __half_raw __hip_cvt_fp8_to_halfraw(const __hip_fp8_storage_t x, const __hip_fp8_interpretation_t interp)
convert __hip_fp8_storage_t to __half_raw
Definition amd_hip_fp8.h:847
__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_double2_to_fp8x2(const double2 d2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp)
convert double2 to __hip_fp8x2_storage_t
Definition amd_hip_fp8.h:771
__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_float2_to_fp8x2(const float2 f2, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp)
convert float2 to __hip_fp8x2_storage_t
Definition amd_hip_fp8.h:716
unsigned short int __hip_fp8x2_storage_t
type to store two fp8 numbers
Definition amd_hip_fp8.h:148
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_bfloat16raw_to_fp8(const __hip_bfloat16_raw hr, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp)
convert __hip_bfloat16_raw to __hip_fp8_storage_t
Definition amd_hip_fp8.h:797
__FP8_HOST_DEVICE_STATIC__ __hip_fp8x2_storage_t __hip_cvt_halfraw2_to_fp8x2(const __half2_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp)
convert __half2_raw to __hip_fp8x2_storage_t
Definition amd_hip_fp8.h:923
unsigned int __hip_fp8x4_storage_t
type to store four fp8 numbers
Definition amd_hip_fp8.h:155
__FP8_HOST_DEVICE_STATIC__ __half2_raw __hip_cvt_fp8x2_to_halfraw2(const __hip_fp8x2_storage_t x, const __hip_fp8_interpretation_t interp)
convert __hip_fp8x2_storage_t to __half2_raw
Definition amd_hip_fp8.h:875
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_halfraw_to_fp8(const __half_raw x, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp)
convert __half_raw to __hip_fp8_storage_t
Definition amd_hip_fp8.h:901
unsigned char __hip_fp8_storage_t
type to store single fp8 number
Definition amd_hip_fp8.h:141
__FP8_HOST_DEVICE_STATIC__ __hip_fp8_storage_t __hip_cvt_float_to_fp8(const float f, const __hip_saturation_t sat, const __hip_fp8_interpretation_t interp)
convert float to __hip_fp8_storage_t
Definition amd_hip_fp8.h:682
hip_bf16.h provides struct for __hip_bfloat16 types
struct representing single fp8 number with e4m3 interpretation
Definition amd_hip_fp8.h:938
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned int val)
Definition amd_hip_fp8.h:990
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz()=default
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const double f)
Definition amd_hip_fp8.h:1010
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __half f)
Definition amd_hip_fp8.h:1038
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const short int val)
Definition amd_hip_fp8.h:970
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const float f)
Definition amd_hip_fp8.h:1019
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const __hip_bfloat16 f)
Definition amd_hip_fp8.h:1028
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const long int val)
Definition amd_hip_fp8.h:950
static constexpr __hip_saturation_t __default_saturation
raw storage of fp8 number
Definition amd_hip_fp8.h:940
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned long int val)
Definition amd_hip_fp8.h:980
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const unsigned short int val)
Definition amd_hip_fp8.h:1000
__FP8_HOST_DEVICE__ __hip_fp8_e4m3_fnuz(const int val)
Definition amd_hip_fp8.h:960
struct representing two fp8 numbers with e4m3 interpretation
Definition amd_hip_fp8.h:1303
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const float2 val)
Definition amd_hip_fp8.h:1321
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const double2 val)
Definition amd_hip_fp8.h:1312
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __half2 val)
Definition amd_hip_fp8.h:1339
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz()=default
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3_fnuz(const __hip_bfloat162 val)
Definition amd_hip_fp8.h:1330
struct representing four fp8 numbers with e4m3 interpretation
Definition amd_hip_fp8.h:1383
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
Definition amd_hip_fp8.h:1432
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz()=default
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const double4 val)
Definition amd_hip_fp8.h:1392
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const __half2 low, const __half2 high)
Definition amd_hip_fp8.h:1446
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3_fnuz(const float4 val)
Definition amd_hip_fp8.h:1412
struct representing one fp8 number with e5m2 interpretation
Definition amd_hip_fp8.h:1495
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned short int val)
Definition amd_hip_fp8.h:1558
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz()=default
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __hip_bfloat16 f)
Definition amd_hip_fp8.h:1586
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const long int val)
Definition amd_hip_fp8.h:1508
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned int val)
Definition amd_hip_fp8.h:1548
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const __half f)
Definition amd_hip_fp8.h:1596
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const int val)
Definition amd_hip_fp8.h:1518
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const double f)
Definition amd_hip_fp8.h:1568
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const short int val)
Definition amd_hip_fp8.h:1528
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const float f)
Definition amd_hip_fp8.h:1577
__FP8_HOST_DEVICE__ __hip_fp8_e5m2_fnuz(const unsigned long int val)
Definition amd_hip_fp8.h:1538
struct representing two fp8 numbers with e5m2 interpretation
Definition amd_hip_fp8.h:1861
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const float2 val)
Definition amd_hip_fp8.h:1879
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __half2 val)
Definition amd_hip_fp8.h:1897
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const __hip_bfloat162 val)
Definition amd_hip_fp8.h:1888
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz(const double2 val)
Definition amd_hip_fp8.h:1870
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2_fnuz()=default
struct representing four fp8 numbers with e5m2 interpretation
Definition amd_hip_fp8.h:1941
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __hip_bfloat162 low, const __hip_bfloat162 high)
Definition amd_hip_fp8.h:1990
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const float4 val)
Definition amd_hip_fp8.h:1970
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const __half2 low, const __half2 high)
Definition amd_hip_fp8.h:2004
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2_fnuz(const double4 val)
Definition amd_hip_fp8.h:1950
struct representing ocp fp8 numbers with e4m3 interpretation
Definition amd_hip_fp8.h:2058
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const __hip_bfloat16 f)
Definition amd_hip_fp8.h:2143
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const long int val)
Definition amd_hip_fp8.h:2070
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned short int val)
Definition amd_hip_fp8.h:2115
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const float f)
Definition amd_hip_fp8.h:2134
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const short int val)
Definition amd_hip_fp8.h:2089
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const __half f)
Definition amd_hip_fp8.h:2153
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned long int val)
Definition amd_hip_fp8.h:2095
__FP8_HOST_DEVICE__ __hip_fp8_e4m3()=default
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const int val)
Definition amd_hip_fp8.h:2080
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const unsigned int val)
Definition amd_hip_fp8.h:2105
__FP8_HOST_DEVICE__ __hip_fp8_e4m3(const double f)
Definition amd_hip_fp8.h:2125
struct representing two ocp fp8 numbers with e4m3 interpretation
Definition amd_hip_fp8.h:2419
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const float2 val)
Definition amd_hip_fp8.h:2438
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3()=default
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const __half2 val)
Definition amd_hip_fp8.h:2456
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const __hip_bfloat162 val)
Definition amd_hip_fp8.h:2447
__FP8_HOST_DEVICE__ __hip_fp8x2_e4m3(const double2 val)
Definition amd_hip_fp8.h:2429
struct representing four ocp fp8 numbers with e4m3 interpretation
Definition amd_hip_fp8.h:2500
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const double4 val)
Definition amd_hip_fp8.h:2510
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const float4 val)
Definition amd_hip_fp8.h:2530
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3()=default
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const __half2 low, const __half2 high)
Definition amd_hip_fp8.h:2564
__FP8_HOST_DEVICE__ __hip_fp8x4_e4m3(const __hip_bfloat162 low, const __hip_bfloat162 high)
Definition amd_hip_fp8.h:2550
struct representing ocp fp8 numbers with e5m2 interpretation
Definition amd_hip_fp8.h:2613
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const int val)
Definition amd_hip_fp8.h:2637
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const short int val)
Definition amd_hip_fp8.h:2647
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned int val)
Definition amd_hip_fp8.h:2667
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const float f)
Definition amd_hip_fp8.h:2696
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned short int val)
Definition amd_hip_fp8.h:2677
__FP8_HOST_DEVICE__ __hip_fp8_e5m2()=default
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const unsigned long int val)
Definition amd_hip_fp8.h:2657
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const long int val)
Definition amd_hip_fp8.h:2627
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const double f)
Definition amd_hip_fp8.h:2687
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const __hip_bfloat16 f)
Definition amd_hip_fp8.h:2705
__FP8_HOST_DEVICE__ __hip_fp8_e5m2(const __half f)
Definition amd_hip_fp8.h:2715
struct representing two ocp fp8 numbers with e5m2 interpretation
Definition amd_hip_fp8.h:2982
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const __half2 val)
Definition amd_hip_fp8.h:3019
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const double2 val)
Definition amd_hip_fp8.h:2992
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2()=default
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const float2 val)
Definition amd_hip_fp8.h:3001
__FP8_HOST_DEVICE__ __hip_fp8x2_e5m2(const __hip_bfloat162 val)
Definition amd_hip_fp8.h:3010
struct representing four ocp fp8 numbers with e5m2 interpretation
Definition amd_hip_fp8.h:3064
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const __hip_bfloat162 low, const __hip_bfloat162 high)
Definition amd_hip_fp8.h:3113
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const double4 val)
Definition amd_hip_fp8.h:3073
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const __half2 low, const __half2 high)
Definition amd_hip_fp8.h:3127
__FP8_HOST_DEVICE__ __hip_fp8x4_e5m2(const float4 val)
Definition amd_hip_fp8.h:3093
Definition amd_hip_vector_types.h:2035
Definition amd_hip_vector_types.h:2042
Definition amd_hip_vector_types.h:2072
Definition amd_hip_vector_types.h:2079
Definition hip_fp16_gcc.h:7
Definition hip_fp16_gcc.h:11