HIP: Heterogeneous-computing Interface for Portability
Loading...
Searching...
No Matches
hip_fp16_gcc.h
1#pragma once
2
3#if defined(__cplusplus)
4 #include <cstring>
5#endif
6
/* Plain-data carrier for one 16-bit half-precision bit pattern.
 * Declared outside the __cplusplus guard, so it must stay valid C. */
struct __half_raw {
    unsigned short x; /* the raw 16 bits */
};
10
/* Plain-data carrier for two 16-bit half-precision bit patterns.
 * Declared outside the __cplusplus guard, so it must stay valid C. */
struct __half2_raw {
    unsigned short x; /* raw bits of the first element */
    unsigned short y; /* raw bits of the second element */
};
15
16#if defined(__cplusplus)
17 struct __half;
18
19 __half __float2half(float);
20 float __half2float(__half);
21
22 // BEGIN STRUCT __HALF
23 struct __half {
24 protected:
25 unsigned short __x;
26 public:
27 // CREATORS
28 __half() = default;
29 __half(const __half_raw& x) : __x{x.x} {}
30 #if !defined(__HIP_NO_HALF_CONVERSIONS__)
31 __half(float x) : __x{__float2half(x).__x} {}
32 __half(double x) : __x{__float2half(x).__x} {}
33 #endif
34 __half(const __half&) = default;
35 __half(__half&&) = default;
36 ~__half() = default;
37
38 // MANIPULATORS
39 __half& operator=(const __half&) = default;
40 __half& operator=(__half&&) = default;
41 __half& operator=(const __half_raw& x) { __x = x.x; return *this; }
42 #if !defined(__HIP_NO_HALF_CONVERSIONS__)
43 __half& operator=(float x)
44 {
45 __x = __float2half(x).__x;
46 return *this;
47 }
48 __half& operator=(double x)
49 {
50 return *this = static_cast<float>(x);
51 }
52 #endif
53
54 // ACCESSORS
55 operator float() const { return __half2float(*this); }
56 operator __half_raw() const { return __half_raw{__x}; }
57 };
58 // END STRUCT __HALF
59
60 // BEGIN STRUCT __HALF2
61 struct __half2 {
62 public:
63 __half x;
64 __half y;
65
66 // CREATORS
67 __half2() = default;
68 __half2(const __half2_raw& ix)
69 :
70 x{reinterpret_cast<const __half&>(ix.x)},
71 y{reinterpret_cast<const __half&>(ix.y)}
72 {}
73 __half2(const __half& ix, const __half& iy) : x{ix}, y{iy} {}
74 __half2(const __half2&) = default;
75 __half2(__half2&&) = default;
76 ~__half2() = default;
77
78 // MANIPULATORS
79 __half2& operator=(const __half2&) = default;
80 __half2& operator=(__half2&&) = default;
81 __half2& operator=(const __half2_raw& ix)
82 {
83 x = reinterpret_cast<const __half_raw&>(ix.x);
84 y = reinterpret_cast<const __half_raw&>(ix.y);
85 return *this;
86 }
87
88 // ACCESSORS
89 operator __half2_raw() const
90 {
91 return __half2_raw{
92 reinterpret_cast<const unsigned short&>(x),
93 reinterpret_cast<const unsigned short&>(y)};
94 }
95 };
96 // END STRUCT __HALF2
97
// Truncating float -> half bit-pattern conversion shared by the rounding
// variants.
//
// Returns the 16-bit pattern with no rounding step applied. `sgn` receives
// the input's sign moved to bit 15 (0 or 0x8000). `rem` receives a value the
// caller inspects to decide whether to bump the result; what it holds
// depends on the branch taken (see below).
inline
unsigned short __internal_float2half(
    float flt, unsigned int& sgn, unsigned int& rem)
{
    unsigned int bits = 0;
    std::memcpy(&bits, &flt, sizeof(flt));

    const unsigned int magnitude = bits & 0x7fffffffU;
    sgn = (bits >> 16) & 0x8000U;

    unsigned int result = 0;
    if (magnitude >= 0x7f800000U) {
        // NaN/+Inf/-Inf: infinity keeps its sign, NaN becomes 0x7fff.
        rem = 0;
        result = (magnitude == 0x7f800000U) ? (sgn | 0x7c00U) : 0x7fffU;
    } else if (magnitude > 0x477fefffU) {
        // Overflows: largest finite pattern, with rem set to 0x80000000.
        rem = 0x80000000U;
        result = sgn | 0x7bffU;
    } else if (magnitude >= 0x38800000U) {
        // Normal numbers: the 13 dropped mantissa bits go to the top of rem.
        rem = magnitude << 19;
        result = sgn | ((magnitude - 0x38000000U) >> 13);
    } else if (magnitude < 0x33000001U) {
        // +0/-0 (and anything at or below 0x33000000): signed zero, with
        // the magnitude bits themselves reported through rem.
        rem = magnitude;
        result = sgn;
    } else {
        // Denormal numbers: restore the implicit leading 1, then shift the
        // significand right by an exponent-dependent amount.
        const unsigned int shift = 0x7eU - (magnitude >> 23);
        const unsigned int significand = (magnitude & 0x7fffffU) | 0x800000U;
        rem = significand << (32 - shift);
        result = sgn | (significand >> shift);
    }
    return static_cast<unsigned short>(result);
}
138
139 inline
140 __half __float2half(float x)
141 {
142 __half_raw r;
143 unsigned int sgn{};
144 unsigned int rem{};
145 r.x = __internal_float2half(x, sgn, rem);
146 if (rem > 0x80000000U || (rem == 0x80000000U && (r.x & 0x1))) ++r.x;
147
148 return r;
149 }
150
151 inline
152 __half __float2half_rn(float x) { return __float2half(x); }
153
154 inline
155 __half __float2half_rz(float x)
156 {
157 __half_raw r;
158 unsigned int sgn{};
159 unsigned int rem{};
160 r.x = __internal_float2half(x, sgn, rem);
161
162 return r;
163 }
164
165 inline
166 __half __float2half_rd(float x)
167 {
168 __half_raw r;
169 unsigned int sgn{};
170 unsigned int rem{};
171 r.x = __internal_float2half(x, sgn, rem);
172 if (rem && sgn) ++r.x;
173
174 return r;
175 }
176
177 inline
178 __half __float2half_ru(float x)
179 {
180 __half_raw r;
181 unsigned int sgn{};
182 unsigned int rem{};
183 r.x = __internal_float2half(x, sgn, rem);
184 if (rem && !sgn) ++r.x;
185
186 return r;
187 }
188
189 inline
190 __half2 __float2half2_rn(float x)
191 {
192 return __half2{__float2half_rn(x), __float2half_rn(x)};
193 }
194
195 inline
196 __half2 __floats2half2_rn(float x, float y)
197 {
198 return __half2{__float2half_rn(x), __float2half_rn(y)};
199 }
200
// Expands a 16-bit half-precision pattern into the float with the same
// value. Inf stays Inf with its sign; every NaN input yields the single
// pattern 0x7fffffff (sign cleared, mantissa all ones); half denormals are
// renormalized.
inline
float __internal_half2float(unsigned short x)
{
    unsigned int sign = ((x >> 15) & 1);
    unsigned int exponent = ((x >> 10) & 0x1f);
    unsigned int mantissa = ((x & 0x3ff) << 13);

    if (exponent == 0x1fU) { /* NaN or Inf */
        if (mantissa) {
            // NaN: drop the sign and saturate the mantissa.
            sign = 0;
            mantissa = 0x7fffffU;
        }
        exponent = 0xffU;
    } else if (!exponent) { /* Denorm or Zero */
        if (mantissa) {
            // Shift left until the leading 1 moves out of the 23-bit field,
            // lowering the exponent once per shift.
            unsigned int msb;
            exponent = 0x71U;
            do {
                msb = (mantissa & 0x400000U);
                mantissa <<= 1; /* normalize */
                --exponent;
            } while (!msb);
            mantissa &= 0x7fffffU; /* 1.mantissa is implicit */
        }
    } else {
        // Normal number: adjust the exponent by 0x70.
        exponent += 0x70U;
    }

    const unsigned int u = ((sign << 31) | (exponent << 23) | mantissa);
    float f;
    // Qualified call: <cstring> only guarantees the std:: name, and the
    // float -> half helper in this file already uses std::memcpy.
    std::memcpy(&f, &u, sizeof(u));

    return f;
}
231
232 inline
233 float __half2float(__half x)
234 {
235 return __internal_half2float(static_cast<__half_raw>(x).x);
236 }
237
238 inline
239 float __low2float(__half2 x)
240 {
241 return __internal_half2float(static_cast<__half2_raw>(x).x);
242 }
243
244 inline
245 float __high2float(__half2 x)
246 {
247 return __internal_half2float(static_cast<__half2_raw>(x).y);
248 }
249
250 #if !defined(HIP_NO_HALF)
251 using half = __half;
252 using half2 = __half2;
253 #endif
254#endif // defined(__cplusplus)
__BF16_HOST_DEVICE_STATIC__ float __low2float(const __hip_bfloat162 a)
Converts low 16 bits of __hip_bfloat162 to float and returns the result.
Definition amd_hip_bf16.h:637
__BF16_HOST_DEVICE_STATIC__ float __high2float(const __hip_bfloat162 a)
Converts high 16 bits of __hip_bfloat162 to float and returns the result.
Definition amd_hip_bf16.h:606
Definition hip_fp16_gcc.h:7
Definition hip_fp16_gcc.h:11