32#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H
33#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H
36#if !defined(__HIPCC_RTC__)
40namespace cooperative_groups {
// thread_group: common base for every cooperative-group handle. Stores the
// group's kind tag, its thread count, and (for coalesced/tiled groups) the
// lane-mask bookkeeping below.
// Constructor: records the group type; size and mask default to zero.
// NOTE(review): the size default is spelled static_cast<uint64_t>(0) although
// the parameter is uint32_t — harmless narrowing of a zero constant.
62 __CG_QUALIFIER__ thread_group(internal::group_type type, uint32_t size =
static_cast<uint64_t
>(0),
63 uint64_t mask =
static_cast<uint64_t
>(0)) {
// Rank of this tile among the tiles partitioning the parent, and the number
// of such tiles. NOTE(review): these appear to belong to a _tiled_info
// struct whose opening line was lost in extraction — confirm upstream.
72 unsigned int meta_group_rank;
73 unsigned int meta_group_size;
// Coalesced-group bookkeeping: mask of participating lanes plus tiling info.
76 struct _coalesced_info {
77 lane_mask member_mask;
79 struct _tiled_info tiled_info;
// tiled_partition() and thread_block need access to private state.
82 friend __CG_QUALIFIER__ thread_group tiled_partition(
const thread_group& parent,
83 unsigned int tile_size);
84 friend class thread_block;
// Number of threads in the group.
90 __CG_QUALIFIER__ uint32_t size()
const {
return _size; }
// The internal::group_type tag identifying this group's concrete kind.
91 __CG_QUALIFIER__
unsigned int cg_type()
const {
return _type; }
// Rank of the calling thread within the group; dispatched on _type in the
// out-of-line definition further down this file.
93 __CG_QUALIFIER__ uint32_t thread_rank()
const;
// Whether this handle is valid for the calling thread (defined out of line).
95 __CG_QUALIFIER__
bool is_valid()
const;
// Synchronizes all threads of the group (defined out of line).
97 __CG_QUALIFIER__
void sync()
const;
// multi_grid_group: all threads of all grids in a multi-device cooperative
// launch. Instances are created only through this_multi_grid().
122class multi_grid_group :
public thread_group {
125 friend __CG_QUALIFIER__ multi_grid_group this_multi_grid();
// Private constructor wrapping the runtime-reported multi-grid size.
129 explicit __CG_QUALIFIER__ multi_grid_group(uint32_t size)
130 : thread_group(internal::cg_multi_grid, size) {}
// Number of grids participating in the multi-grid launch.
135 __CG_QUALIFIER__ uint32_t num_grids() {
return internal::multi_grid::num_grids(); }
// Rank of the current grid among the participating grids.
138 __CG_QUALIFIER__ uint32_t grid_rank() {
return internal::multi_grid::grid_rank(); }
// Rank of the calling thread across the entire multi-grid.
139 __CG_QUALIFIER__ uint32_t thread_rank()
const {
return internal::multi_grid::thread_rank(); }
140 __CG_QUALIFIER__
bool is_valid()
const {
return internal::multi_grid::is_valid(); }
// Barrier across every thread of every participating grid.
141 __CG_QUALIFIER__
void sync()
const { internal::multi_grid::sync(); }
// Returns a handle to the multi-grid the calling thread belongs to.
153__CG_QUALIFIER__ multi_grid_group this_multi_grid() {
154 return multi_grid_group(internal::multi_grid::size());
// grid_group: every thread of the current single-device cooperative launch.
// Instances are created only through this_grid().
165class grid_group :
public thread_group {
168 friend __CG_QUALIFIER__ grid_group this_grid();
// Private constructor wrapping the runtime-reported grid size.
172 explicit __CG_QUALIFIER__ grid_group(uint32_t size) : thread_group(internal::cg_grid, size) {}
// Rank of the calling thread within the whole grid.
175 __CG_QUALIFIER__ uint32_t thread_rank()
const {
return internal::grid::thread_rank(); }
176 __CG_QUALIFIER__
bool is_valid()
const {
return internal::grid::is_valid(); }
// Grid-wide barrier.
177 __CG_QUALIFIER__
void sync()
const { internal::grid::sync(); }
189__CG_QUALIFIER__ grid_group this_grid() {
return grid_group(internal::grid::size()); }
// thread_block: all threads of the current workgroup (a.k.a. thread block).
200class thread_block :
public thread_group {
203 friend __CG_QUALIFIER__ thread_block this_thread_block();
204 friend __CG_QUALIFIER__ thread_group tiled_partition(
const thread_group& parent,
205 unsigned int tile_size);
206 friend __CG_QUALIFIER__ thread_group tiled_partition(
const thread_block& parent,
207 unsigned int tile_size);
// Private constructor wrapping the runtime-reported workgroup size.
210 explicit __CG_QUALIFIER__ thread_block(uint32_t size)
211 : thread_group(internal::cg_workgroup, size) {}
// Partitions this block into tiles of tile_size threads. tile_size must be
// a nonzero power of two no larger than __AMDGCN_WAVEFRONT_SIZE.
213 __CG_QUALIFIER__ thread_group new_tiled_group(
unsigned int tile_size)
const {
// Power-of-two check via the x & (x - 1) trick.
214 const bool pow2 = ((tile_size & (tile_size - 1)) == 0);
216 if (!tile_size || (tile_size > __AMDGCN_WAVEFRONT_SIZE) || !pow2) {
217 __hip_assert(
false &&
"invalid tile size");
220 thread_group tiledGroup = thread_group(internal::cg_tiled_group, tile_size);
221 tiledGroup.coalesced_info.tiled_info.size = tile_size;
222 tiledGroup.coalesced_info.tiled_info.is_tiled =
true;
// This tile's rank among all tiles of the block.
223 tiledGroup.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
// Ceiling division: number of tiles covering the block.
224 tiledGroup.coalesced_info.tiled_info.meta_group_size = (size() + tile_size - 1) / tile_size;
// 3-D index of this block within the grid.
230 __CG_STATIC_QUALIFIER__ dim3 group_index() {
return internal::workgroup::group_index(); }
// 3-D index of the calling thread within the block.
232 __CG_STATIC_QUALIFIER__ dim3 thread_index() {
return internal::workgroup::thread_index(); }
// Linear rank of the calling thread within the block.
233 __CG_STATIC_QUALIFIER__ uint32_t thread_rank() {
return internal::workgroup::thread_rank(); }
234 __CG_STATIC_QUALIFIER__ uint32_t size() {
return internal::workgroup::size(); }
235 __CG_STATIC_QUALIFIER__
bool is_valid() {
return internal::workgroup::is_valid(); }
// Block-wide barrier.
236 __CG_STATIC_QUALIFIER__
void sync() { internal::workgroup::sync(); }
// Dimensions of the block.
237 __CG_QUALIFIER__ dim3 group_dim() {
return internal::workgroup::block_dim(); }
// Returns a handle to the calling thread's workgroup.
249__CG_QUALIFIER__ thread_block this_thread_block() {
250 return thread_block(internal::workgroup::size());
// tiled_group: a power-of-two tile carved out of a thread block (or out of
// a larger tile) by tiled_partition().
261class tiled_group :
public thread_group {
263 friend __CG_QUALIFIER__ thread_group tiled_partition(
const thread_group& parent,
264 unsigned int tile_size);
265 friend __CG_QUALIFIER__ tiled_group tiled_partition(
const tiled_group& parent,
266 unsigned int tile_size);
// Re-partitions this tile into smaller tiles. tile_size must be a nonzero
// power of two no larger than the wavefront and strictly smaller than the
// current tile size.
268 __CG_QUALIFIER__ tiled_group new_tiled_group(
unsigned int tile_size)
const {
269 const bool pow2 = ((tile_size & (tile_size - 1)) == 0);
271 if (!tile_size || (tile_size > __AMDGCN_WAVEFRONT_SIZE) || !pow2) {
272 __hip_assert(
false &&
"invalid tile size");
// Cannot partition into tiles as large as (or larger than) this one.
275 if (size() <= tile_size) {
279 tiled_group tiledGroup = tiled_group(tile_size);
280 tiledGroup.coalesced_info.tiled_info.is_tiled =
true;
// Private constructor: tiles are created only via the friends above.
285 explicit __CG_QUALIFIER__ tiled_group(
unsigned int tileSize)
286 : thread_group(internal::cg_tiled_group, tileSize) {
287 coalesced_info.tiled_info.size = tileSize;
288 coalesced_info.tiled_info.is_tiled =
true;
292 __CG_QUALIFIER__
unsigned int size()
const {
return (coalesced_info.tiled_info.size); }
// Rank within the tile; the bitmask is equivalent to a modulo because the
// tile size is a power of two.
294 __CG_QUALIFIER__
unsigned int thread_rank()
const {
295 return (internal::workgroup::thread_rank() & (coalesced_info.tiled_info.size - 1));
// Barrier across the tile.
298 __CG_QUALIFIER__
void sync()
const {
299 internal::tiled_group::sync();
// coalesced_group: the set of lanes active at the point of creation (see
// coalesced_threads()), identified by coalesced_info.member_mask.
310class coalesced_group :
public thread_group {
312 friend __CG_QUALIFIER__ coalesced_group coalesced_threads();
313 friend __CG_QUALIFIER__ thread_group tiled_partition(
const thread_group& parent,
unsigned int tile_size);
314 friend __CG_QUALIFIER__ coalesced_group tiled_partition(
const coalesced_group& parent,
unsigned int tile_size);
// Partitions this group into tiles of tile_size lanes. On an invalid size
// it returns an empty group (mask 0) rather than asserting.
316 __CG_QUALIFIER__ coalesced_group new_tiled_group(
unsigned int tile_size)
const {
317 const bool pow2 = ((tile_size & (tile_size - 1)) == 0);
319 if (!tile_size || (tile_size > size()) || !pow2) {
320 return coalesced_group(0);
// Fast path: the group is already a contiguous tile, so the new member
// mask can be computed arithmetically from lane ids.
325 if (coalesced_info.tiled_info.is_tiled) {
// Rank of the first member of the caller's sub-tile.
326 unsigned int base_offset = (thread_rank() & (~(tile_size - 1)));
// The last tile may be cut short by the parent's size.
327 unsigned int masklength = min(
static_cast<unsigned int>(size()) - base_offset, tile_size);
// masklength consecutive set bits ...
328 lane_mask member_mask =
static_cast<lane_mask
>(-1) >> (__AMDGCN_WAVEFRONT_SIZE - masklength);
// ... shifted up to the tile's first physical lane.
330 member_mask <<= (__lane_id() & ~(tile_size - 1));
331 coalesced_group coalesced_tile = coalesced_group(member_mask);
332 coalesced_tile.coalesced_info.tiled_info.is_tiled =
true;
333 coalesced_tile.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
334 coalesced_tile.coalesced_info.tiled_info.meta_group_size = size() / tile_size;
335 return coalesced_tile;
// Slow path: membership is sparse; walk the wavefront and collect the
// tile_size active lanes that belong to the caller's tile.
339 lane_mask member_mask = 0;
340 unsigned int tile_rank = 0;
341 int lanes_to_skip = ((thread_rank()) / tile_size) * tile_size;
343 for (
unsigned int i = 0; i < __AMDGCN_WAVEFRONT_SIZE; i++) {
// NOTE(review): (1 << i) is an int shift; for a 64-wide wavefront and
// i >= 32 this looks like it needs a lane_mask-width shift — confirm
// against the upstream header.
344 lane_mask active = coalesced_info.member_mask & (1 << i);
347 if (lanes_to_skip <= 0 && tile_rank < tile_size) {
349 member_mask |= active;
355 coalesced_group coalesced_tile = coalesced_group(member_mask);
356 coalesced_tile.coalesced_info.tiled_info.meta_group_rank = thread_rank() / tile_size;
// Ceiling division: number of tiles covering the parent group.
357 coalesced_tile.coalesced_info.tiled_info.meta_group_size =
358 (size() + tile_size - 1) / tile_size;
359 return coalesced_tile;
361 return coalesced_group(0);
// Private constructor: built from an explicit lane mask by
// coalesced_threads() and tiled_partition().
366 explicit __CG_QUALIFIER__ coalesced_group(lane_mask member_mask)
367 : thread_group(internal::cg_coalesced_group) {
368 coalesced_info.member_mask = member_mask;
// Group size is the population count of the mask.
369 coalesced_info.size = __popcll(coalesced_info.member_mask);
370 coalesced_info.tiled_info.is_tiled =
false;
// A freshly coalesced group is its own (single) meta-group.
371 coalesced_info.tiled_info.meta_group_rank = 0;
372 coalesced_info.tiled_info.meta_group_size = 1;
376 __CG_QUALIFIER__
unsigned int size()
const {
377 return coalesced_info.size;
// Rank = number of active lanes below the calling lane in the mask.
380 __CG_QUALIFIER__
unsigned int thread_rank()
const {
381 return internal::coalesced_group::masked_bit_count(coalesced_info.member_mask);
384 __CG_QUALIFIER__
void sync()
const {
385 internal::coalesced_group::sync();
// This group's rank among the tiles partitioning its parent.
388 __CG_QUALIFIER__
unsigned int meta_group_rank()
const {
389 return coalesced_info.tiled_info.meta_group_rank;
// Number of tiles partitioning the parent (1 for a non-tiled group).
392 __CG_QUALIFIER__
unsigned int meta_group_size()
const {
393 return coalesced_info.tiled_info.meta_group_size;
// Reads var from the member whose group rank is srcRank (wrapped modulo
// the group size); __fns32/64 maps the rank to its physical lane.
397 __CG_QUALIFIER__ T shfl(T var,
int srcRank)
const {
398 static_assert(is_valid_type<T>::value,
"Neither an integer or float type.");
400 srcRank = srcRank %
static_cast<int>(size());
402 int lane = (size() == __AMDGCN_WAVEFRONT_SIZE) ? srcRank
403 : (__AMDGCN_WAVEFRONT_SIZE == 64) ? __fns64(coalesced_info.member_mask, 0, (srcRank + 1))
404 : __fns32(coalesced_info.member_mask, 0, (srcRank + 1));
406 return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
// Reads var from the member lane_delta active lanes above the caller.
410 __CG_QUALIFIER__ T shfl_down(T var,
unsigned int lane_delta)
const {
411 static_assert(is_valid_type<T>::value,
"Neither an integer or float type.");
// A full-wavefront group can use the hardware shuffle directly.
417 if (size() == __AMDGCN_WAVEFRONT_SIZE) {
418 return __shfl_down(var, lane_delta, __AMDGCN_WAVEFRONT_SIZE);
422 if (__AMDGCN_WAVEFRONT_SIZE == 64) {
423 lane = __fns64(coalesced_info.member_mask, __lane_id(), lane_delta + 1);
426 lane = __fns32(coalesced_info.member_mask, __lane_id(), lane_delta + 1);
433 return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
// Reads var from the member lane_delta active lanes below the caller
// (negative count makes __fns search downward).
437 __CG_QUALIFIER__ T shfl_up(T var,
unsigned int lane_delta)
const {
438 static_assert(is_valid_type<T>::value,
"Neither an integer or float type.");
444 if (size() == __AMDGCN_WAVEFRONT_SIZE) {
445 return __shfl_up(var, lane_delta, __AMDGCN_WAVEFRONT_SIZE);
449 if (__AMDGCN_WAVEFRONT_SIZE == 64) {
450 lane = __fns64(coalesced_info.member_mask, __lane_id(), -(lane_delta + 1));
452 else if (__AMDGCN_WAVEFRONT_SIZE == 32) {
453 lane = __fns32(coalesced_info.member_mask, __lane_id(), -(lane_delta + 1));
460 return __shfl(var, lane, __AMDGCN_WAVEFRONT_SIZE);
// Returns the group of lanes currently active, per the hardware EXEC mask.
471__CG_QUALIFIER__ coalesced_group coalesced_threads() {
472 return cooperative_groups::coalesced_group(__builtin_amdgcn_read_exec());
// Out-of-line thread_group::thread_rank(): dispatches on the stored group
// type tag and forwards to the derived class via a static downcast.
480__CG_QUALIFIER__ uint32_t thread_group::thread_rank()
const {
481 switch (this->_type) {
482 case internal::cg_multi_grid: {
483 return (
static_cast<const multi_grid_group*
>(
this)->thread_rank());
485 case internal::cg_grid: {
486 return (
static_cast<const grid_group*
>(
this)->thread_rank());
488 case internal::cg_workgroup: {
489 return (
static_cast<const thread_block*
>(
this)->thread_rank());
491 case internal::cg_tiled_group: {
492 return (
static_cast<const tiled_group*
>(
this)->thread_rank());
494 case internal::cg_coalesced_group: {
495 return (
static_cast<const coalesced_group*
>(
this)->thread_rank());
// Unknown tag: programming error.
498 __hip_assert(
false &&
"invalid cooperative group type");
// Out-of-line thread_group::is_valid(): same type-tag dispatch pattern as
// thread_rank() above.
508__CG_QUALIFIER__
bool thread_group::is_valid()
const {
509 switch (this->_type) {
510 case internal::cg_multi_grid: {
511 return (
static_cast<const multi_grid_group*
>(
this)->is_valid());
513 case internal::cg_grid: {
514 return (
static_cast<const grid_group*
>(
this)->is_valid());
516 case internal::cg_workgroup: {
517 return (
static_cast<const thread_block*
>(
this)->is_valid());
519 case internal::cg_tiled_group: {
520 return (
static_cast<const tiled_group*
>(
this)->is_valid());
522 case internal::cg_coalesced_group: {
523 return (
static_cast<const coalesced_group*
>(
this)->is_valid());
// Unknown tag: programming error.
526 __hip_assert(
false &&
"invalid cooperative group type");
// Out-of-line thread_group::sync(): barrier at the scope of the group's
// concrete kind, dispatched on the stored type tag.
536__CG_QUALIFIER__
void thread_group::sync()
const {
537 switch (this->_type) {
538 case internal::cg_multi_grid: {
539 static_cast<const multi_grid_group*
>(
this)->sync();
542 case internal::cg_grid: {
543 static_cast<const grid_group*
>(
this)->sync();
546 case internal::cg_workgroup: {
547 static_cast<const thread_block*
>(
this)->sync();
550 case internal::cg_tiled_group: {
551 static_cast<const tiled_group*
>(
this)->sync();
554 case internal::cg_coalesced_group: {
555 static_cast<const coalesced_group*
>(
this)->sync();
// Unknown tag: programming error.
559 __hip_assert(
false &&
"invalid cooperative group type");
570template <
class CGTy> __CG_QUALIFIER__ uint32_t group_size(CGTy
const& g) {
return g.size(); }
// Free-function form of g.thread_rank(): rank of the caller within g.
577template <
class CGTy> __CG_QUALIFIER__ uint32_t thread_rank(CGTy
const& g) {
578 return g.thread_rank();
586template <
class CGTy> __CG_QUALIFIER__
bool is_valid(CGTy
const& g) {
return g.is_valid(); }
593template <
class CGTy> __CG_QUALIFIER__
void sync(CGTy
const& g) { g.sync(); }
// tile_base: compile-time-sized tile; numThreads is the static tile size.
599template <
unsigned int tileSize>
class tile_base {
601 _CG_STATIC_CONST_DECL_
unsigned int numThreads = tileSize;
// Rank within the tile; the power-of-two mask stands in for a modulo.
// NOTE(review): _CG_STATIC_CONST_DECL_ on a function looks unusual —
// presumably the macro expands to a static/constexpr specifier; confirm.
605 _CG_STATIC_CONST_DECL_
unsigned int thread_rank() {
606 return (internal::workgroup::thread_rank() & (numThreads - 1));
// Static tile size.
610 __CG_STATIC_QUALIFIER__
unsigned int size() {
return numThreads; }
// thread_block_tile_base: static-size tile providing warp shuffle ops with
// the tile's compile-time width; size is checked statically.
617template <
unsigned int size>
class thread_block_tile_base :
public tile_base<size> {
618 static_assert(is_valid_tile_size<size>::value,
619 "Tile size is either not a power of 2 or greater than the wavefront size");
620 using tile_base<size>::numThreads;
// Barrier across the tile.
623 __CG_STATIC_QUALIFIER__
void sync() {
624 internal::tiled_group::sync();
// Shuffle primitives delegate to the hardware __shfl* intrinsics with the
// static tile width; T must be an integer or floating-point type.
627 template <
class T> __CG_QUALIFIER__ T shfl(T var,
int srcRank)
const {
628 static_assert(is_valid_type<T>::value,
"Neither an integer or float type.");
629 return (__shfl(var, srcRank, numThreads));
632 template <
class T> __CG_QUALIFIER__ T shfl_down(T var,
unsigned int lane_delta)
const {
633 static_assert(is_valid_type<T>::value,
"Neither an integer or float type.");
634 return (__shfl_down(var, lane_delta, numThreads));
637 template <
class T> __CG_QUALIFIER__ T shfl_up(T var,
unsigned int lane_delta)
const {
638 static_assert(is_valid_type<T>::value,
"Neither an integer or float type.");
639 return (__shfl_up(var, lane_delta, numThreads));
642 template <
class T> __CG_QUALIFIER__ T shfl_xor(T var,
unsigned int laneMask)
const {
643 static_assert(is_valid_type<T>::value,
"Neither an integer or float type.");
644 return (__shfl_xor(var, laneMask, numThreads));
// parent_group_info: meta-group coordinates computed statically from the
// parent group type (ParentCGTy) and the tile size.
649template <
unsigned int tileSize,
typename ParentCGTy>
650class parent_group_info {
// Rank of this tile among the tiles partitioning the parent.
654 __CG_STATIC_QUALIFIER__
unsigned int meta_group_rank() {
655 return ParentCGTy::thread_rank() / tileSize;
// Number of tiles covering the parent (ceiling division).
659 __CG_STATIC_QUALIFIER__
unsigned int meta_group_size() {
660 return (ParentCGTy::size() + tileSize - 1) / tileSize;
// thread_block_tile_type: concrete static tile bound to a parent group
// type; combines the shuffle ops with parent-derived meta-group info.
670template <
unsigned int tileSize,
class ParentCGTy>
671class thread_block_tile_type :
public thread_block_tile_base<tileSize>,
673 public parent_group_info<tileSize, ParentCGTy> {
674 _CG_STATIC_CONST_DECL_
unsigned int numThreads = tileSize;
675 typedef thread_block_tile_base<numThreads> tbtBase;
// NOTE(review): the constructor initializes tiled_group, implying this
// class also derives from tiled_group — the base list here appears
// truncated by extraction; confirm against the upstream header.
677 __CG_QUALIFIER__ thread_block_tile_type() : tiled_group(numThreads) {
678 coalesced_info.tiled_info.size = numThreads;
679 coalesced_info.tiled_info.is_tiled =
true;
// Expose the base's static thread_rank().
684 using tbtBase::thread_rank;
// Specialization for ParentCGTy = void: the parent type is erased, so the
// meta-group coordinates are carried at runtime in coalesced_info instead
// of being computed from the parent type.
688template <
unsigned int tileSize>
689class thread_block_tile_type<tileSize, void> :
public thread_block_tile_base<tileSize>,
692 _CG_STATIC_CONST_DECL_
unsigned int numThreads = tileSize;
694 typedef thread_block_tile_base<numThreads> tbtBase;
// Captures the meta-group coordinates handed over from a typed tile.
698 __CG_QUALIFIER__ thread_block_tile_type(
unsigned int meta_group_rank,
unsigned int meta_group_size)
699 : tiled_group(numThreads) {
700 coalesced_info.tiled_info.size = numThreads;
701 coalesced_info.tiled_info.is_tiled =
true;
702 coalesced_info.tiled_info.meta_group_rank = meta_group_rank;
703 coalesced_info.tiled_info.meta_group_size = meta_group_size;
// Expose the base's static thread_rank().
709 using tbtBase::thread_rank;
// Runtime meta-group accessors backed by coalesced_info.
711 __CG_QUALIFIER__
unsigned int meta_group_rank()
const {
712 return coalesced_info.tiled_info.meta_group_rank;
715 __CG_QUALIFIER__
unsigned int meta_group_size()
const {
716 return coalesced_info.tiled_info.meta_group_size;
// Runtime tiled_partition: dispatches on the parent's dynamic group kind
// and forwards to the matching new_tiled_group() implementation.
731__CG_QUALIFIER__ thread_group tiled_partition(
const thread_group& parent,
unsigned int tile_size) {
732 if (parent.cg_type() == internal::cg_tiled_group) {
733 const tiled_group* cg =
static_cast<const tiled_group*
>(&parent);
734 return cg->new_tiled_group(tile_size);
736 else if(parent.cg_type() == internal::cg_coalesced_group) {
737 const coalesced_group* cg =
static_cast<const coalesced_group*
>(&parent);
738 return cg->new_tiled_group(tile_size);
// Fallback: the parent is treated as a thread_block.
741 const thread_block* tb =
static_cast<const thread_block*
>(&parent);
742 return tb->new_tiled_group(tile_size);
// Overload for a thread_block parent: delegates to its new_tiled_group().
747__CG_QUALIFIER__ thread_group tiled_partition(
const thread_block& parent,
unsigned int tile_size) {
748 return (parent.new_tiled_group(tile_size));
// Overload for a tiled_group parent: re-partitions into smaller tiles.
751__CG_QUALIFIER__ tiled_group tiled_partition(
const tiled_group& parent,
unsigned int tile_size) {
752 return (parent.new_tiled_group(tile_size));
// Overload for a coalesced_group parent: partitions the active lanes.
756__CG_QUALIFIER__ coalesced_group tiled_partition(
const coalesced_group& parent,
unsigned int tile_size) {
757 return (parent.new_tiled_group(tile_size));
// Forward declarations of the static-size tile types defined below.
760template <
unsigned int size,
class ParentCGTy>
class thread_block_tile;
763template <
unsigned int size,
class ParentCGTy>
class thread_block_tile_internal;
// impl detail: bridges thread_block_tile to thread_block_tile_type and
// converts between tiles whose parent types differ.
765template <
unsigned int size,
class ParentCGTy>
766class thread_block_tile_internal :
public thread_block_tile_type<size, ParentCGTy> {
// Cross-parent copy: carries the runtime meta-group coordinates across.
768 template <
unsigned int tbtSize,
class tbtParentT>
769 __CG_QUALIFIER__ thread_block_tile_internal(
770 const thread_block_tile_internal<tbtSize, tbtParentT>& g)
771 : thread_block_tile_type<size, ParentCGTy>(g.meta_group_rank(), g.meta_group_size()) {}
// Construction from the owning thread_block (parent type known statically,
// so the block argument itself is not needed).
773 __CG_QUALIFIER__ thread_block_tile_internal(
const thread_block& g)
774 : thread_block_tile_type<size, ParentCGTy>() {}
// Public static-size tile bound to its parent group type.
778template <
unsigned int size,
class ParentCGTy>
779class thread_block_tile :
public impl::thread_block_tile_internal<size, ParentCGTy> {
781 __CG_QUALIFIER__ thread_block_tile(
const ParentCGTy& g)
782 : impl::thread_block_tile_internal<size, ParentCGTy>(g) {}
// Implicit conversion to the type-erased (ParentCGTy = void) tile.
785 __CG_QUALIFIER__
operator thread_block_tile<size, void>()
const {
786 return thread_block_tile<size, void>(*
this);
// Type-erased specialization; constructible only from a typed tile via the
// friend declaration below.
791template <
unsigned int size>
792class thread_block_tile<size, void> :
public impl::thread_block_tile_internal<size, void> {
793 template <
unsigned int,
class ParentCGTy>
friend class thread_block_tile;
797 template <
class ParentCGTy>
798 __CG_QUALIFIER__ thread_block_tile(
const thread_block_tile<size, ParentCGTy>& g)
799 : impl::thread_block_tile_internal<size, void>(g) {}
// Redeclaration supplying ParentCGTy's default of void, so users can write
// thread_block_tile<N> without naming a parent type.
802template <
unsigned int size,
class ParentCGTy =
void>
class thread_block_tile;
// impl detail: per-parent-type selector for the compile-time
// tiled_partition; specialized here for thread_block parents.
805template <
unsigned int size,
class ParentCGTy>
struct tiled_partition_internal;
807template <
unsigned int size>
808struct tiled_partition_internal<size, thread_block> :
public thread_block_tile<size, thread_block> {
809 __CG_QUALIFIER__ tiled_partition_internal(
const thread_block& g)
810 : thread_block_tile<size, thread_block>(g) {}
// Compile-time tiled_partition: the tile size is validated statically
// against the wavefront, then construction is forwarded to the impl
// selector above.
820template <
unsigned int size,
class ParentCGTy>
821__CG_QUALIFIER__ thread_block_tile<size, ParentCGTy> tiled_partition(
const ParentCGTy& g) {
822 static_assert(is_valid_tile_size<size>::value,
823 "Tiled partition with size > wavefront size. Currently not supported ");
824 return impl::tiled_partition_internal<size, ParentCGTy>(g);
// Device-side implementation of the cooperative groups feature.