HIP: Heterogeneous-computing Interface for Portability
Loading...
Searching...
No Matches
amd_warp_sync_functions.h
1/*
2Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
3
4Permission is hereby granted, free of charge, to any person obtaining a copy
5of this software and associated documentation files (the "Software"), to deal
6in the Software without restriction, including without limitation the rights
7to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8copies of the Software, and to permit persons to whom the Software is
9furnished to do so, subject to the following conditions:
10
11The above copyright notice and this permission notice shall be included in
12all copies or substantial portions of the Software.
13
14THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20THE SOFTWARE.
21*/
22
23#pragma once
24
25// Warp sync builtins (with explicit mask argument) introduced in ROCm 6.2 as a
26// preview to allow end-users to adapt to the new interface involving 64-bit
27// masks. These are disabled by default, and can be enabled by setting the macro
28// below. The builtins will be enabled unconditionally in ROCm 6.3.
29//
30// This arrangement also applies to the __activemask() builtin defined in
31// amd_warp_functions.h.
32#ifdef HIP_ENABLE_WARP_SYNC_BUILTINS
33
34#if !defined(__HIPCC_RTC__)
35#include "amd_warp_functions.h"
36#include "hip_assert.h"
37#endif
38
39template <typename T>
40__device__ inline
41T __hip_readfirstlane(T val) {
42 // In theory, behaviour is undefined when reading from a union member other
43 // than the member that was last assigned to, but it works in practice because
44 // we rely on the compiler to do the reasonable thing.
45 union {
46 unsigned long long l;
47 T d;
48 } u;
49 u.d = val;
50 // NOTE: The builtin returns int, so we first cast it to unsigned int and only
51 // then extend it to 64 bits.
52 unsigned long long lower = (unsigned)__builtin_amdgcn_readfirstlane(u.l);
53 unsigned long long upper =
54 (unsigned)__builtin_amdgcn_readfirstlane(u.l >> 32);
55 u.l = (upper << 32) | lower;
56 return u.d;
57}
58
// When running in wave32 mode (warpSize == 32), ignore the upper half of the
// 64-bit mask. MASK must be a modifiable lvalue; it is updated in place.
// The argument is parenthesized so that any lvalue expression expands safely.
#define __hip_adjust_mask_for_wave32(MASK) \
  do { \
    if (warpSize == 32) (MASK) &= 0xFFFFFFFFull; \
  } while (0)
64
65// We use a macro to expand each builtin into a waterfall that implements the
66// mask semantics:
67//
68// 1. The mask argument may be divergent.
69// 2. Each active thread must have its own bit set in its own mask value.
70// 3. For a given mask value, all threads that are mentioned in the mask must
71// execute the same static instance of the builtin with the same mask.
72// 4. The union of all mask values supplied at a static instance must be equal
73// to the activemask at the program point.
74//
75// Thus, the mask argument partitions the set of currently active threads in the
76// wave into disjoint subsets that cover all active threads.
77//
78// Implementation notes:
79// ---------------------
80//
81// We implement this as a waterfall loop that executes the builtin for each
82// subset separately. The return value is a divergent value across the active
83// threads. The value for inactive threads is defined by each builtin
84// separately.
85//
86// As long as every mask value is non-zero, we don't need to check if a lane
87// specifies itself in the mask; that is done by the later assertion where all
88// chosen lanes must be in the chosen mask.
89
// Validate the mask semantics without performing any operation.
//
// Every active lane joins a waterfall over the distinct mask values. In each
// iteration, the lanes whose mask equals the first pending lane's mask assert
// that they are exactly the lanes named by that mask (which also proves that
// each lane names itself).
//
// MASK is evaluated exactly once into a local copy, so any expression may be
// passed. Internal names carry a __hip_ prefix so they cannot capture
// identifiers from the calling function.
#define __hip_check_mask(MASK) \
  do { \
    const auto __hip_sync_mask = (MASK); \
    __hip_assert(__hip_sync_mask && "mask must be non-zero"); \
    bool __hip_sync_done = false; \
    while (__any(!__hip_sync_done)) { \
      if (!__hip_sync_done) { \
        auto __hip_chosen_mask = __hip_readfirstlane(__hip_sync_mask); \
        if (__hip_sync_mask == __hip_chosen_mask) { \
          __hip_assert(__hip_sync_mask == __ballot(true) && \
                       "all threads specified in the mask" \
                       " must execute the same operation with the same mask"); \
          __hip_sync_done = true; \
        } \
      } \
    } \
  } while (0)
106
// Execute FUNC(__VA_ARGS__) once per distinct mask value and store each lane's
// result in RETVAL.
//
// Lanes are processed as a waterfall. In each iteration, only the lanes whose
// mask equals the first pending lane's mask are active inside the branch. They
// assert that they are exactly the lanes named by that mask before calling
// FUNC.
//
// MASK is evaluated exactly once into a local copy. RETVAL must be an
// assignable lvalue declared by the caller. Internal names carry a __hip_
// prefix so they cannot collide with identifiers used in __VA_ARGS__.
#define __hip_do_sync(RETVAL, FUNC, MASK, ...) \
  do { \
    const auto __hip_sync_mask = (MASK); \
    __hip_assert(__hip_sync_mask && "mask must be non-zero"); \
    bool __hip_sync_done = false; \
    while (__any(!__hip_sync_done)) { \
      if (!__hip_sync_done) { \
        auto __hip_chosen_mask = __hip_readfirstlane(__hip_sync_mask); \
        if (__hip_sync_mask == __hip_chosen_mask) { \
          __hip_assert(__hip_sync_mask == __ballot(true) && \
                       "all threads specified in the mask" \
                       " must execute the same operation with the same mask"); \
          (RETVAL) = FUNC(__VA_ARGS__); \
          __hip_sync_done = true; \
        } \
      } \
    } \
  } while (0)
124
125// __all_sync, __any_sync, __ballot_sync
126
127template <typename MaskT>
128__device__ inline
129unsigned long long __ballot_sync(MaskT mask, int predicate) {
130 static_assert(
131 __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
132 "The mask must be a 64-bit integer. "
133 "Implicitly promoting a smaller integer is almost always an error.");
134 __hip_adjust_mask_for_wave32(mask);
135 __hip_check_mask(mask);
136 return __ballot(predicate) & mask;
137}
138
139template <typename MaskT>
140__device__ inline
141int __all_sync(MaskT mask, int predicate) {
142 static_assert(
143 __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
144 "The mask must be a 64-bit integer. "
145 "Implicitly promoting a smaller integer is almost always an error.");
146 __hip_adjust_mask_for_wave32(mask);
147 return __ballot_sync(mask, predicate) == mask;
148}
149
150template <typename MaskT>
151__device__ inline
152int __any_sync(MaskT mask, int predicate) {
153 static_assert(
154 __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
155 "The mask must be a 64-bit integer. "
156 "Implicitly promoting a smaller integer is almost always an error.");
157 __hip_adjust_mask_for_wave32(mask);
158 return __ballot_sync(mask, predicate) != 0;
159}
160
161// __match_any, __match_all and sync variants
162
163template <typename T>
164__device__ inline
165unsigned long long __match_any(T value) {
166 static_assert(
167 (__hip_internal::is_integral<T>::value || __hip_internal::is_floating_point<T>::value) &&
168 (sizeof(T) == 4 || sizeof(T) == 8),
169 "T can be int, unsigned int, long, unsigned long, long long, unsigned "
170 "long long, float or double.");
171 bool done = false;
172 unsigned long long retval = 0;
173
174 while (__any(!done)) {
175 if (!done) {
176 T chosen = __hip_readfirstlane(value);
177 if (chosen == value) {
178 retval = __activemask();
179 done = true;
180 }
181 }
182 }
183
184 return retval;
185}
186
187template <typename MaskT, typename T>
188__device__ inline
189unsigned long long __match_any_sync(MaskT mask, T value) {
190 static_assert(
191 __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
192 "The mask must be a 64-bit integer. "
193 "Implicitly promoting a smaller integer is almost always an error.");
194 __hip_adjust_mask_for_wave32(mask);
195 __hip_check_mask(mask);
196 return __match_any(value) & mask;
197}
198
199template <typename T>
200__device__ inline
201unsigned long long __match_all(T value, int* pred) {
202 static_assert(
203 (__hip_internal::is_integral<T>::value || __hip_internal::is_floating_point<T>::value) &&
204 (sizeof(T) == 4 || sizeof(T) == 8),
205 "T can be int, unsigned int, long, unsigned long, long long, unsigned "
206 "long long, float or double.");
207 T first = __hip_readfirstlane(value);
208 if (__all(first == value)) {
209 *pred = true;
210 return __activemask();
211 } else {
212 *pred = false;
213 return 0;
214 }
215}
216
217template <typename MaskT, typename T>
218__device__ inline
219unsigned long long __match_all_sync(MaskT mask, T value, int* pred) {
220 static_assert(
221 __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
222 "The mask must be a 64-bit integer. "
223 "Implicitly promoting a smaller integer is almost always an error.");
224 MaskT retval = 0;
225 __hip_adjust_mask_for_wave32(mask);
226 __hip_do_sync(retval, __match_all, mask, value, pred);
227 return retval;
228}
229
230// various variants of shfl
231
232template <typename MaskT, typename T>
233__device__ inline
234T __shfl_sync(MaskT mask, T var, int srcLane,
235 int width = __AMDGCN_WAVEFRONT_SIZE) {
236 static_assert(
237 __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
238 "The mask must be a 64-bit integer. "
239 "Implicitly promoting a smaller integer is almost always an error.");
240 __hip_adjust_mask_for_wave32(mask);
241 __hip_check_mask(mask);
242 return __shfl(var, srcLane, width);
243}
244
245template <typename MaskT, typename T>
246__device__ inline
247T __shfl_up_sync(MaskT mask, T var, unsigned int delta,
248 int width = __AMDGCN_WAVEFRONT_SIZE) {
249 static_assert(
250 __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
251 "The mask must be a 64-bit integer. "
252 "Implicitly promoting a smaller integer is almost always an error.");
253 __hip_adjust_mask_for_wave32(mask);
254 __hip_check_mask(mask);
255 return __shfl_up(var, delta, width);
256}
257
258template <typename MaskT, typename T>
259__device__ inline
260T __shfl_down_sync(MaskT mask, T var, unsigned int delta,
261 int width = __AMDGCN_WAVEFRONT_SIZE) {
262 static_assert(
263 __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
264 "The mask must be a 64-bit integer. "
265 "Implicitly promoting a smaller integer is almost always an error.");
266 __hip_adjust_mask_for_wave32(mask);
267 __hip_check_mask(mask);
268 return __shfl_down(var, delta, width);
269}
270
271template <typename MaskT, typename T>
272__device__ inline
273T __shfl_xor_sync(MaskT mask, T var, int laneMask,
274 int width = __AMDGCN_WAVEFRONT_SIZE) {
275 static_assert(
276 __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
277 "The mask must be a 64-bit integer. "
278 "Implicitly promoting a smaller integer is almost always an error.");
279 __hip_adjust_mask_for_wave32(mask);
280 __hip_check_mask(mask);
281 return __shfl_xor(var, laneMask, width);
282}
283
284#undef __hip_do_sync
285#undef __hip_check_mask
286#undef __hip_adjust_mask_for_wave32
287
288#endif // HIP_ENABLE_WARP_SYNC_BUILTINS