HIP: Heterogeneous-computing Interface for Portability
Loading...
Searching...
No Matches
amd_hip_bfloat16.h
Go to the documentation of this file.
1
29#ifndef _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BFLOAT16_H_
30#define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BFLOAT16_H_
31
32#include "host_defines.h"
// Function qualifiers applied to every hip_bfloat16 routine in this header:
// device-only when __HIPCC_RTC__ is defined, host and device otherwise.
#ifdef __HIPCC_RTC__
#    define __HOST_DEVICE__ __device__
#else
#    define __HOST_DEVICE__ __host__ __device__
#endif
38
39#if __cplusplus < 201103L || !defined(__HIPCC__)
40
41// If this is a C compiler, C++ compiler below C++11, or a host-only compiler, we only
42// include a minimal definition of hip_bfloat16
43
44#include <stdint.h>
/*! \brief Struct to represent a 16 bit brain floating point number.
 *
 * Minimal definition used by C compilers, C++ compilers below C++11 and
 * host-only compilers: it exposes only the raw 16-bit storage.
 */
typedef struct
{
    uint16_t data; /* raw bfloat16 bit pattern */
} hip_bfloat16;
50
51#else // __cplusplus < 201103L || !defined(__HIPCC__)
52
53#include <hip/hip_runtime.h>
54
55#pragma clang diagnostic push
56#pragma clang diagnostic ignored "-Wshadow"
57struct hip_bfloat16
58{
59 __hip_uint16_t data;
60
61 enum truncate_t
62 {
63 truncate
64 };
65
66 __HOST_DEVICE__ hip_bfloat16() = default;
67
68 // round upper 16 bits of IEEE float to convert to bfloat16
69 explicit __HOST_DEVICE__ hip_bfloat16(float f)
70 : data(float_to_bfloat16(f))
71 {
72 }
73
74 explicit __HOST_DEVICE__ hip_bfloat16(float f, truncate_t)
75 : data(truncate_float_to_bfloat16(f))
76 {
77 }
78
79 // zero extend lower 16 bits of bfloat16 to convert to IEEE float
80 __HOST_DEVICE__ operator float() const
81 {
82 union
83 {
84 uint32_t int32;
85 float fp32;
86 } u = {uint32_t(data) << 16};
87 return u.fp32;
88 }
89
90 __HOST_DEVICE__ hip_bfloat16 &operator=(const float& f)
91 {
92 data = float_to_bfloat16(f);
93 return *this;
94 }
95
96 static __HOST_DEVICE__ hip_bfloat16 round_to_bfloat16(float f)
97 {
98 hip_bfloat16 output;
99 output.data = float_to_bfloat16(f);
100 return output;
101 }
102
103 static __HOST_DEVICE__ hip_bfloat16 round_to_bfloat16(float f, truncate_t)
104 {
105 hip_bfloat16 output;
106 output.data = truncate_float_to_bfloat16(f);
107 return output;
108 }
109
110private:
111 static __HOST_DEVICE__ __hip_uint16_t float_to_bfloat16(float f)
112 {
113 union
114 {
115 float fp32;
116 uint32_t int32;
117 } u = {f};
118 if(~u.int32 & 0x7f800000)
119 {
120 // When the exponent bits are not all 1s, then the value is zero, normal,
121 // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
122 // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
123 // This causes the bfloat16's mantissa to be incremented by 1 if the 16
124 // least significant bits of the float mantissa are greater than 0x8000,
125 // or if they are equal to 0x8000 and the least significant bit of the
126 // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
127 // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
128 // has the value 0x7f, then incrementing it causes it to become 0x00 and
129 // the exponent is incremented by one, which is the next higher FP value
130 // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
131 // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
132 // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
133 // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
134 // incrementing it causes it to become an exponent of 0xFF and a mantissa
135 // of 0x00, which is Inf, the next higher value to the unrounded value.
136 u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
137 }
138 else if(u.int32 & 0xffff)
139 {
140 // When all of the exponent bits are 1, the value is Inf or NaN.
141 // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
142 // mantissa bit. Quiet NaN is indicated by the most significant mantissa
143 // bit being 1. Signaling NaN is indicated by the most significant
144 // mantissa bit being 0 but some other bit(s) being 1. If any of the
145 // lower 16 bits of the mantissa are 1, we set the least significant bit
146 // of the bfloat16 mantissa, in order to preserve signaling NaN in case
147 // the bloat16's mantissa bits are all 0.
148 u.int32 |= 0x10000; // Preserve signaling NaN
149 }
150 return __hip_uint16_t(u.int32 >> 16);
151 }
152
153 // Truncate instead of rounding, preserving SNaN
154 static __HOST_DEVICE__ __hip_uint16_t truncate_float_to_bfloat16(float f)
155 {
156 union
157 {
158 float fp32;
159 uint32_t int32;
160 } u = {f};
161 return __hip_uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
162 }
163};
164#pragma clang diagnostic pop
165
166typedef struct
167{
168 __hip_uint16_t data;
169} hip_bfloat16_public;
170
171static_assert(__hip_internal::is_standard_layout<hip_bfloat16>{},
172 "hip_bfloat16 is not a standard layout type, and thus is "
173 "incompatible with C.");
174
175static_assert(__hip_internal::is_trivial<hip_bfloat16>{},
176 "hip_bfloat16 is not a trivial type, and thus is "
177 "incompatible with C.");
178#if !defined(__HIPCC_RTC__)
179static_assert(sizeof(hip_bfloat16) == sizeof(hip_bfloat16_public)
180 && offsetof(hip_bfloat16, data) == offsetof(hip_bfloat16_public, data),
181 "internal hip_bfloat16 does not match public hip_bfloat16");
182
183inline std::ostream& operator<<(std::ostream& os, const hip_bfloat16& bf16)
184{
185 return os << float(bf16);
186}
187#endif
188
189inline __HOST_DEVICE__ hip_bfloat16 operator+(hip_bfloat16 a)
190{
191 return a;
192}
193inline __HOST_DEVICE__ hip_bfloat16 operator-(hip_bfloat16 a)
194{
195 a.data ^= 0x8000;
196 return a;
197}
198inline __HOST_DEVICE__ hip_bfloat16 operator+(hip_bfloat16 a, hip_bfloat16 b)
199{
200 return hip_bfloat16(float(a) + float(b));
201}
202inline __HOST_DEVICE__ hip_bfloat16 operator-(hip_bfloat16 a, hip_bfloat16 b)
203{
204 return hip_bfloat16(float(a) - float(b));
205}
206inline __HOST_DEVICE__ hip_bfloat16 operator*(hip_bfloat16 a, hip_bfloat16 b)
207{
208 return hip_bfloat16(float(a) * float(b));
209}
210inline __HOST_DEVICE__ hip_bfloat16 operator/(hip_bfloat16 a, hip_bfloat16 b)
211{
212 return hip_bfloat16(float(a) / float(b));
213}
214inline __HOST_DEVICE__ bool operator<(hip_bfloat16 a, hip_bfloat16 b)
215{
216 return float(a) < float(b);
217}
218inline __HOST_DEVICE__ bool operator==(hip_bfloat16 a, hip_bfloat16 b)
219{
220 return float(a) == float(b);
221}
222inline __HOST_DEVICE__ bool operator>(hip_bfloat16 a, hip_bfloat16 b)
223{
224 return b < a;
225}
226inline __HOST_DEVICE__ bool operator<=(hip_bfloat16 a, hip_bfloat16 b)
227{
228 return !(a > b);
229}
230inline __HOST_DEVICE__ bool operator!=(hip_bfloat16 a, hip_bfloat16 b)
231{
232 return !(a == b);
233}
234inline __HOST_DEVICE__ bool operator>=(hip_bfloat16 a, hip_bfloat16 b)
235{
236 return !(a < b);
237}
238inline __HOST_DEVICE__ hip_bfloat16& operator+=(hip_bfloat16& a, hip_bfloat16 b)
239{
240 return a = a + b;
241}
242inline __HOST_DEVICE__ hip_bfloat16& operator-=(hip_bfloat16& a, hip_bfloat16 b)
243{
244 return a = a - b;
245}
246inline __HOST_DEVICE__ hip_bfloat16& operator*=(hip_bfloat16& a, hip_bfloat16 b)
247{
248 return a = a * b;
249}
250inline __HOST_DEVICE__ hip_bfloat16& operator/=(hip_bfloat16& a, hip_bfloat16 b)
251{
252 return a = a / b;
253}
254inline __HOST_DEVICE__ hip_bfloat16& operator++(hip_bfloat16& a)
255{
256 return a += hip_bfloat16(1.0f);
257}
258inline __HOST_DEVICE__ hip_bfloat16& operator--(hip_bfloat16& a)
259{
260 return a -= hip_bfloat16(1.0f);
261}
262inline __HOST_DEVICE__ hip_bfloat16 operator++(hip_bfloat16& a, int)
263{
264 hip_bfloat16 orig = a;
265 ++a;
266 return orig;
267}
268inline __HOST_DEVICE__ hip_bfloat16 operator--(hip_bfloat16& a, int)
269{
270 hip_bfloat16 orig = a;
271 --a;
272 return orig;
273}
274
275namespace std
276{
277 constexpr __HOST_DEVICE__ bool isinf(hip_bfloat16 a)
278 {
279 return !(~a.data & 0x7f80) && !(a.data & 0x7f);
280 }
281 constexpr __HOST_DEVICE__ bool isnan(hip_bfloat16 a)
282 {
283 return !(~a.data & 0x7f80) && +(a.data & 0x7f);
284 }
285 constexpr __HOST_DEVICE__ bool iszero(hip_bfloat16 a)
286 {
287 return !(a.data & 0x7fff);
288 }
289}
290
291#endif // __cplusplus < 201103L || !defined(__HIPCC__)
292
293#endif // _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BFLOAT16_H_
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 & operator-=(__hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to subtract-assign two __hip_bfloat16 numbers.
Definition amd_hip_bf16.h:1017
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator+(const __hip_bfloat16 &l)
Operator to unary+ on a __hip_bfloat16 number.
Definition amd_hip_bf16.h:940
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator/(const __hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to divide two __hip_bfloat16 numbers.
Definition amd_hip_bf16.h:1026
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator-(const __hip_bfloat16 &l)
Operator to negate a __hip_bfloat16 number.
Definition amd_hip_bf16.h:955
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 & operator/=(__hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to divide-assign two __hip_bfloat16 numbers.
Definition amd_hip_bf16.h:1035
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator*(const __hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to multiply two __hip_bfloat16 numbers.
Definition amd_hip_bf16.h:922
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 & operator*=(__hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to multiply-assign two __hip_bfloat16 numbers.
Definition amd_hip_bf16.h:931
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator++(__hip_bfloat16 &l, const int)
Operator to post increment a __hip_bfloat16 number.
Definition amd_hip_bf16.h:970
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 & operator+=(__hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to add-assign two __hip_bfloat16 numbers.
Definition amd_hip_bf16.h:1008
__BF16_HOST_DEVICE_STATIC__ __hip_bfloat16 operator--(__hip_bfloat16 &l, const int)
Operator to post decrement a __hip_bfloat16 number.
Definition amd_hip_bf16.h:989
__BF16_HOST_DEVICE_STATIC__ bool operator==(const __hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to perform an equal compare on two __hip_bfloat16 numbers.
Definition amd_hip_bf16.h:1494
__BF16_HOST_DEVICE_STATIC__ bool operator!=(const __hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to perform a not equal on two __hip_bfloat16 numbers.
Definition amd_hip_bf16.h:1502
__BF16_HOST_DEVICE_STATIC__ bool operator>(const __hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to perform a greater than on two __hip_bfloat16 numbers.
Definition amd_hip_bf16.h:1526
__BF16_HOST_DEVICE_STATIC__ bool operator<=(const __hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to perform a less than or equal on two __hip_bfloat16 numbers.
Definition amd_hip_bf16.h:1518
__BF16_HOST_DEVICE_STATIC__ bool operator<(const __hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to perform a less than on two __hip_bfloat16 numbers.
Definition amd_hip_bf16.h:1510
__BF16_HOST_DEVICE_STATIC__ bool operator>=(const __hip_bfloat16 &l, const __hip_bfloat16 &r)
Operator to perform a greater than or equal on two __hip_bfloat16 numbers.
Definition amd_hip_bf16.h:1534
Struct to represent a 16 bit brain floating point number.
Definition amd_hip_bfloat16.h:47