HIP: Heterogeneous-computing Interface for Portability
amd_warp_sync_functions.h
1 /*
2 Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
3 
4 Permission is hereby granted, free of charge, to any person obtaining a copy
5 of this software and associated documentation files (the "Software"), to deal
6 in the Software without restriction, including without limitation the rights
7 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 copies of the Software, and to permit persons to whom the Software is
9 furnished to do so, subject to the following conditions:
10 
11 The above copyright notice and this permission notice shall be included in
12 all copies or substantial portions of the Software.
13 
14 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 THE SOFTWARE.
21 */
22 
23 #pragma once
24 
25 // Warp sync builtins (with explicit mask argument) introduced in ROCm 6.2 as a
26 // preview to allow end-users to adapt to the new interface involving 64-bit
27 // masks. These are disabled by default, and can be enabled by setting the macro
28 // below. The builtins will be enabled unconditionally in ROCm 6.3.
29 //
30 // This arrangement also applies to the __activemask() builtin defined in
31 // amd_warp_functions.h.
32 #ifdef HIP_ENABLE_WARP_SYNC_BUILTINS
33 
34 #if !defined(__HIPCC_RTC__)
35 #include "amd_warp_functions.h"
36 #include "hip_assert.h"
37 #endif
38 
39 template <typename T>
40 __device__ inline
41 T __hip_readfirstlane(T val) {
42  // In theory, behaviour is undefined when reading from a union member other
43  // than the member that was last assigned to, but it works in practice because
44  // we rely on the compiler to do the reasonable thing.
45  union {
46  unsigned long long l;
47  T d;
48  } u;
49  u.d = val;
50  // NOTE: The builtin returns int, so we first cast it to unsigned int and only
51  // then extend it to 64 bits.
52  unsigned long long lower = (unsigned)__builtin_amdgcn_readfirstlane(u.l);
53  unsigned long long upper =
54  (unsigned)__builtin_amdgcn_readfirstlane(u.l >> 32);
55  u.l = (upper << 32) | lower;
56  return u.d;
57 }
58 
// When compiling for wave32 mode, ignore the upper half of the 64-bit mask.
// MASK must be a modifiable lvalue; it is updated in place. The argument is
// parenthesized so that a compound lvalue expression (for example a
// conditional expression) is masked as a whole rather than only its last
// operand.
#define __hip_adjust_mask_for_wave32(MASK)   \
  do {                                       \
    if (warpSize == 32) (MASK) &= 0xFFFFFFFF; \
  } while (0)
64 
65 // We use a macro to expand each builtin into a waterfall that implements the
66 // mask semantics:
67 //
68 // 1. The mask argument may be divergent.
69 // 2. Each active thread must have its own bit set in its own mask value.
70 // 3. For a given mask value, all threads that are mentioned in the mask must
71 // execute the same static instance of the builtin with the same mask.
72 // 4. The union of all mask values supplied at a static instance must be equal
73 // to the activemask at the program point.
74 //
75 // Thus, the mask argument partitions the set of currently active threads in the
76 // wave into disjoint subsets that cover all active threads.
77 //
78 // Implementation notes:
79 // ---------------------
80 //
81 // We implement this as a waterfall loop that executes the builtin for each
82 // subset separately. The return value is a divergent value across the active
83 // threads. The value for inactive threads is defined by each builtin
84 // separately.
85 //
86 // As long as every mask value is non-zero, we don't need to check if a lane
87 // specifies itself in the mask; that is done by the later assertion where all
88 // chosen lanes must be in the chosen mask.
89 
// Validates the mask rules without performing any operation. Each iteration
// picks the mask of the first still-pending lane; the lanes whose mask equals
// it assert that they are exactly the lanes converged at this point, and then
// retire. The loop ends once every active lane has retired.
//
// MASK is evaluated once into a local. All macro-internal locals carry a
// __hip_ prefix so that they cannot capture identifiers used in the caller's
// MASK expression.
#define __hip_check_mask(MASK)                                               \
  do {                                                                       \
    const auto __hip_mask = (MASK);                                          \
    __hip_assert(__hip_mask && "mask must be non-zero");                     \
    bool __hip_done = false;                                                 \
    while (__any(!__hip_done)) {                                             \
      if (!__hip_done) {                                                     \
        auto __hip_chosen_mask = __hip_readfirstlane(__hip_mask);            \
        if (__hip_mask == __hip_chosen_mask) {                               \
          __hip_assert(__hip_mask == __ballot(true) &&                       \
                       "all threads specified in the mask"                   \
                       " must execute the same operation with the same mask"); \
          __hip_done = true;                                                 \
        }                                                                    \
      }                                                                      \
    }                                                                        \
  } while(0)
106 
// Waterfall helper. For each distinct mask value among the active lanes, it
// validates the mask and then evaluates FUNC(__VA_ARGS__) with exactly the
// lanes that supplied that mask value active. The result is stored into
// RETVAL, which must be an lvalue. RETVAL is assigned only in the iteration
// that handles the lane's own mask group.
//
// MASK is evaluated once into a local. All macro-internal locals carry a
// __hip_ prefix so that they cannot capture identifiers appearing in the
// caller's MASK, RETVAL, or FUNC arguments.
#define __hip_do_sync(RETVAL, FUNC, MASK, ...)                               \
  do {                                                                       \
    const auto __hip_mask = (MASK);                                          \
    __hip_assert(__hip_mask && "mask must be non-zero");                     \
    bool __hip_done = false;                                                 \
    while (__any(!__hip_done)) {                                             \
      if (!__hip_done) {                                                     \
        auto __hip_chosen_mask = __hip_readfirstlane(__hip_mask);            \
        if (__hip_mask == __hip_chosen_mask) {                               \
          __hip_assert(__hip_mask == __ballot(true) &&                       \
                       "all threads specified in the mask"                   \
                       " must execute the same operation with the same mask"); \
          (RETVAL) = FUNC(__VA_ARGS__);                                      \
          __hip_done = true;                                                 \
        }                                                                    \
      }                                                                      \
    }                                                                        \
  } while(0)
124 
125 // __all_sync, __any_sync, __ballot_sync
126 
127 template <typename MaskT>
128 __device__ inline
129 unsigned long long __ballot_sync(MaskT mask, int predicate) {
130  static_assert(
131  __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
132  "The mask must be a 64-bit integer. "
133  "Implicitly promoting a smaller integer is almost always an error.");
134  __hip_adjust_mask_for_wave32(mask);
135  __hip_check_mask(mask);
136  return __ballot(predicate) & mask;
137 }
138 
139 template <typename MaskT>
140 __device__ inline
141 int __all_sync(MaskT mask, int predicate) {
142  static_assert(
143  __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
144  "The mask must be a 64-bit integer. "
145  "Implicitly promoting a smaller integer is almost always an error.");
146  __hip_adjust_mask_for_wave32(mask);
147  return __ballot_sync(mask, predicate) == mask;
148 }
149 
150 template <typename MaskT>
151 __device__ inline
152 int __any_sync(MaskT mask, int predicate) {
153  static_assert(
154  __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
155  "The mask must be a 64-bit integer. "
156  "Implicitly promoting a smaller integer is almost always an error.");
157  __hip_adjust_mask_for_wave32(mask);
158  return __ballot_sync(mask, predicate) != 0;
159 }
160 
161 // __match_any, __match_all and sync variants
162 
163 template <typename T>
164 __device__ inline
165 unsigned long long __match_any(T value) {
166  static_assert(
167  (__hip_internal::is_integral<T>::value || __hip_internal::is_floating_point<T>::value) &&
168  (sizeof(T) == 4 || sizeof(T) == 8),
169  "T can be int, unsigned int, long, unsigned long, long long, unsigned "
170  "long long, float or double.");
171  bool done = false;
172  unsigned long long retval = 0;
173 
174  while (__any(!done)) {
175  if (!done) {
176  T chosen = __hip_readfirstlane(value);
177  if (chosen == value) {
178  retval = __activemask();
179  done = true;
180  }
181  }
182  }
183 
184  return retval;
185 }
186 
187 template <typename MaskT, typename T>
188 __device__ inline
189 unsigned long long __match_any_sync(MaskT mask, T value) {
190  static_assert(
191  __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
192  "The mask must be a 64-bit integer. "
193  "Implicitly promoting a smaller integer is almost always an error.");
194  __hip_adjust_mask_for_wave32(mask);
195  __hip_check_mask(mask);
196  return __match_any(value) & mask;
197 }
198 
199 template <typename T>
200 __device__ inline
201 unsigned long long __match_all(T value, int* pred) {
202  static_assert(
203  (__hip_internal::is_integral<T>::value || __hip_internal::is_floating_point<T>::value) &&
204  (sizeof(T) == 4 || sizeof(T) == 8),
205  "T can be int, unsigned int, long, unsigned long, long long, unsigned "
206  "long long, float or double.");
207  T first = __hip_readfirstlane(value);
208  if (__all(first == value)) {
209  *pred = true;
210  return __activemask();
211  } else {
212  *pred = false;
213  return 0;
214  }
215 }
216 
217 template <typename MaskT, typename T>
218 __device__ inline
219 unsigned long long __match_all_sync(MaskT mask, T value, int* pred) {
220  static_assert(
221  __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
222  "The mask must be a 64-bit integer. "
223  "Implicitly promoting a smaller integer is almost always an error.");
224  MaskT retval = 0;
225  __hip_adjust_mask_for_wave32(mask);
226  __hip_do_sync(retval, __match_all, mask, value, pred);
227  return retval;
228 }
229 
230 // various variants of shfl
231 
232 template <typename MaskT, typename T>
233 __device__ inline
234 T __shfl_sync(MaskT mask, T var, int srcLane,
235  int width = __AMDGCN_WAVEFRONT_SIZE) {
236  static_assert(
237  __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
238  "The mask must be a 64-bit integer. "
239  "Implicitly promoting a smaller integer is almost always an error.");
240  __hip_adjust_mask_for_wave32(mask);
241  __hip_check_mask(mask);
242  return __shfl(var, srcLane, width);
243 }
244 
245 template <typename MaskT, typename T>
246 __device__ inline
247 T __shfl_up_sync(MaskT mask, T var, unsigned int delta,
248  int width = __AMDGCN_WAVEFRONT_SIZE) {
249  static_assert(
250  __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
251  "The mask must be a 64-bit integer. "
252  "Implicitly promoting a smaller integer is almost always an error.");
253  __hip_adjust_mask_for_wave32(mask);
254  __hip_check_mask(mask);
255  return __shfl_up(var, delta, width);
256 }
257 
258 template <typename MaskT, typename T>
259 __device__ inline
260 T __shfl_down_sync(MaskT mask, T var, unsigned int delta,
261  int width = __AMDGCN_WAVEFRONT_SIZE) {
262  static_assert(
263  __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
264  "The mask must be a 64-bit integer. "
265  "Implicitly promoting a smaller integer is almost always an error.");
266  __hip_adjust_mask_for_wave32(mask);
267  __hip_check_mask(mask);
268  return __shfl_down(var, delta, width);
269 }
270 
271 template <typename MaskT, typename T>
272 __device__ inline
273 T __shfl_xor_sync(MaskT mask, T var, int laneMask,
274  int width = __AMDGCN_WAVEFRONT_SIZE) {
275  static_assert(
276  __hip_internal::is_integral<MaskT>::value && sizeof(MaskT) == 8,
277  "The mask must be a 64-bit integer. "
278  "Implicitly promoting a smaller integer is almost always an error.");
279  __hip_adjust_mask_for_wave32(mask);
280  __hip_check_mask(mask);
281  return __shfl_xor(var, laneMask, width);
282 }
283 
284 #undef __hip_do_sync
285 #undef __hip_check_mask
286 #undef __hip_adjust_mask_for_wave32
287 
288 #endif // HIP_ENABLE_WARP_SYNC_BUILTINS