Actual source code: sfkok.kokkos.cxx
1: #include <../src/vec/is/sf/impls/basic/sfpack.h>
3: #include <petsc_kokkos.hpp>
4: #include <petsc/private/kokkosimpl.hpp>
6: using DeviceExecutionSpace = Kokkos::DefaultExecutionSpace;
8: typedef Kokkos::View<char *, DefaultMemorySpace> deviceBuffer_t;
9: typedef Kokkos::View<char *, HostMirrorMemorySpace> HostBuffer_t;
11: typedef Kokkos::View<const char *, DefaultMemorySpace> deviceConstBuffer_t;
12: typedef Kokkos::View<const char *, HostMirrorMemorySpace> HostConstBuffer_t;
14: /*====================================================================================*/
15: /* Regular operations */
16: /*====================================================================================*/
17: template <typename Type>
18: struct Insert {
19: KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
20: {
21: Type old = x;
22: x = y;
23: return old;
24: }
25: };
26: template <typename Type>
27: struct Add {
28: KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
29: {
30: Type old = x;
31: x += y;
32: return old;
33: }
34: };
35: template <typename Type>
36: struct Mult {
37: KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
38: {
39: Type old = x;
40: x *= y;
41: return old;
42: }
43: };
44: template <typename Type>
45: struct Min {
46: KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
47: {
48: Type old = x;
49: x = PetscMin(x, y);
50: return old;
51: }
52: };
53: template <typename Type>
54: struct Max {
55: KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
56: {
57: Type old = x;
58: x = PetscMax(x, y);
59: return old;
60: }
61: };
62: template <typename Type>
63: struct LAND {
64: KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
65: {
66: Type old = x;
67: x = x && y;
68: return old;
69: }
70: };
71: template <typename Type>
72: struct LOR {
73: KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
74: {
75: Type old = x;
76: x = x || y;
77: return old;
78: }
79: };
80: template <typename Type>
81: struct LXOR {
82: KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
83: {
84: Type old = x;
85: x = !x != !y;
86: return old;
87: }
88: };
89: template <typename Type>
90: struct BAND {
91: KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
92: {
93: Type old = x;
94: x = x & y;
95: return old;
96: }
97: };
98: template <typename Type>
99: struct BOR {
100: KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
101: {
102: Type old = x;
103: x = x | y;
104: return old;
105: }
106: };
107: template <typename Type>
108: struct BXOR {
109: KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const
110: {
111: Type old = x;
112: x = x ^ y;
113: return old;
114: }
115: };
116: template <typename PairType>
117: struct Minloc {
118: KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const
119: {
120: PairType old = x;
121: if (y.first < x.first) x = y;
122: else if (y.first == x.first) x.second = PetscMin(x.second, y.second);
123: return old;
124: }
125: };
126: template <typename PairType>
127: struct Maxloc {
128: KOKKOS_INLINE_FUNCTION PairType operator()(PairType &x, PairType y) const
129: {
130: PairType old = x;
131: if (y.first > x.first) x = y;
132: else if (y.first == x.first) x.second = PetscMin(x.second, y.second); /* See MPI MAXLOC */
133: return old;
134: }
135: };
137: /*====================================================================================*/
138: /* Atomic operations */
139: /*====================================================================================*/
140: template <typename Type>
141: struct AtomicInsert {
142: KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_store(&x, y); }
143: };
144: template <typename Type>
145: struct AtomicAdd {
146: KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_add(&x, y); }
147: };
148: template <typename Type>
149: struct AtomicBAND {
150: KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_and(&x, y); }
151: };
152: template <typename Type>
153: struct AtomicBOR {
154: KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_or(&x, y); }
155: };
156: template <typename Type>
157: struct AtomicBXOR {
158: KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_xor(&x, y); }
159: };
160: template <typename Type>
161: struct AtomicLAND {
162: KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const
163: {
164: const Type zero = 0, one = ~0;
165: Kokkos::atomic_and(&x, y ? one : zero);
166: }
167: };
168: template <typename Type>
169: struct AtomicLOR {
170: KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const
171: {
172: const Type zero = 0, one = 1;
173: Kokkos::atomic_or(&x, y ? one : zero);
174: }
175: };
176: template <typename Type>
177: struct AtomicMult {
178: KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_mul(&x, y); }
179: };
180: template <typename Type>
181: struct AtomicMin {
182: KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_min(&x, y); }
183: };
184: template <typename Type>
185: struct AtomicMax {
186: KOKKOS_INLINE_FUNCTION void operator()(Type &x, Type y) const { Kokkos::atomic_fetch_max(&x, y); }
187: };
188: /* TODO: struct AtomicLXOR */
189: template <typename Type>
190: struct AtomicFetchAdd {
191: KOKKOS_INLINE_FUNCTION Type operator()(Type &x, Type y) const { return Kokkos::atomic_fetch_add(&x, y); }
192: };
194: /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */
195: static KOKKOS_INLINE_FUNCTION PetscInt MapTidToIndex(const PetscInt *opt, PetscInt tid)
196: {
197: PetscInt i, j, k, m, n, r;
198: const PetscInt *offset, *start, *dx, *dy, *X, *Y;
200: n = opt[0];
201: offset = opt + 1;
202: start = opt + n + 2;
203: dx = opt + 2 * n + 2;
204: dy = opt + 3 * n + 2;
205: X = opt + 5 * n + 2;
206: Y = opt + 6 * n + 2;
207: for (r = 0; r < n; r++) {
208: if (tid < offset[r + 1]) break;
209: }
210: m = (tid - offset[r]);
211: k = m / (dx[r] * dy[r]);
212: j = (m - k * dx[r] * dy[r]) / dx[r];
213: i = m - k * dx[r] * dy[r] - j * dx[r];
215: return start[r] + k * X[r] * Y[r] + j * X[r] + i;
216: }
218: /*====================================================================================*/
219: /* Wrappers for Pack/Unpack/Scatter kernels. Function pointers are stored in 'link' */
220: /*====================================================================================*/
222: /* Suppose user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI data type made of 16 PetscReals, then
223: <Type> is PetscReal, which is the primitive type we operate on.
224: <bs> is 16, which says <unit> contains 16 primitive types.
225: <BS> is 8, which is the maximal SIMD width we will try to vectorize operations on <unit>.
226: <EQ> is 0, which is (bs == BS ? 1 : 0)
228: If instead, <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, rendering MBS below to a compile time constant.
229: For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, the inner for-loops below will be totally unrolled.
230: */
231: template <typename Type, PetscInt BS, PetscInt EQ>
232: static PetscErrorCode Pack(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, const void *data_, void *buf_)
233: {
234: const PetscInt *iopt = opt ? opt->array : NULL;
235: const PetscInt M = EQ ? 1 : link->bs / BS, MBS = M * BS; /* If EQ, then MBS will be a compile-time const */
236: const Type *data = static_cast<const Type *>(data_);
237: Type *buf = static_cast<Type *>(buf_);
238: DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();
240: PetscFunctionBegin;
241: Kokkos::parallel_for(
242: Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
243: /* iopt != NULL ==> idx == NULL, i.e., the indices have patterns but not contiguous;
244: iopt == NULL && idx == NULL ==> the indices are contiguous;
245: */
246: PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
247: PetscInt s = tid * MBS;
248: for (int i = 0; i < MBS; i++) buf[s + i] = data[t + i];
249: });
250: PetscFunctionReturn(PETSC_SUCCESS);
251: }
253: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
254: static PetscErrorCode UnpackAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data_, const void *buf_)
255: {
256: Op op;
257: const PetscInt *iopt = opt ? opt->array : NULL;
258: const PetscInt M = EQ ? 1 : link->bs / BS, MBS = M * BS;
259: Type *data = static_cast<Type *>(data_);
260: const Type *buf = static_cast<const Type *>(buf_);
261: DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();
263: PetscFunctionBegin;
264: Kokkos::parallel_for(
265: Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
266: PetscInt t = (iopt ? MapTidToIndex(iopt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
267: PetscInt s = tid * MBS;
268: for (int i = 0; i < MBS; i++) op(data[t + i], buf[s + i]);
269: });
270: PetscFunctionReturn(PETSC_SUCCESS);
271: }
273: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
274: static PetscErrorCode FetchAndOp(PetscSFLink link, PetscInt count, PetscInt start, PetscSFPackOpt opt, const PetscInt *idx, void *data, void *buf)
275: {
276: Op op;
277: const PetscInt *ropt = opt ? opt->array : NULL;
278: const PetscInt M = EQ ? 1 : link->bs / BS, MBS = M * BS;
279: Type *rootdata = static_cast<Type *>(data), *leafbuf = static_cast<Type *>(buf);
280: DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();
282: PetscFunctionBegin;
283: Kokkos::parallel_for(
284: Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
285: PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (idx ? idx[tid] : start + tid)) * MBS;
286: PetscInt l = tid * MBS;
287: for (int i = 0; i < MBS; i++) leafbuf[l + i] = op(rootdata[r + i], leafbuf[l + i]);
288: });
289: PetscFunctionReturn(PETSC_SUCCESS);
290: }
292: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
293: static PetscErrorCode ScatterAndOp(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_)
294: {
295: PetscInt srcx = 0, srcy = 0, srcX = 0, srcY = 0, dstx = 0, dsty = 0, dstX = 0, dstY = 0;
296: const PetscInt M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
297: const Type *src = static_cast<const Type *>(src_);
298: Type *dst = static_cast<Type *>(dst_);
299: DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();
301: PetscFunctionBegin;
302: /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use CUDA 3D grid and block */
303: if (srcOpt) {
304: srcx = srcOpt->dx[0];
305: srcy = srcOpt->dy[0];
306: srcX = srcOpt->X[0];
307: srcY = srcOpt->Y[0];
308: srcStart = srcOpt->start[0];
309: srcIdx = NULL;
310: } else if (!srcIdx) {
311: srcx = srcX = count;
312: srcy = srcY = 1;
313: }
315: if (dstOpt) {
316: dstx = dstOpt->dx[0];
317: dsty = dstOpt->dy[0];
318: dstX = dstOpt->X[0];
319: dstY = dstOpt->Y[0];
320: dstStart = dstOpt->start[0];
321: dstIdx = NULL;
322: } else if (!dstIdx) {
323: dstx = dstX = count;
324: dsty = dstY = 1;
325: }
327: Kokkos::parallel_for(
328: Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
329: PetscInt i, j, k, s, t;
330: Op op;
331: if (!srcIdx) { /* src is in 3D */
332: k = tid / (srcx * srcy);
333: j = (tid - k * srcx * srcy) / srcx;
334: i = tid - k * srcx * srcy - j * srcx;
335: s = srcStart + k * srcX * srcY + j * srcX + i;
336: } else { /* src is contiguous */
337: s = srcIdx[tid];
338: }
340: if (!dstIdx) { /* 3D */
341: k = tid / (dstx * dsty);
342: j = (tid - k * dstx * dsty) / dstx;
343: i = tid - k * dstx * dsty - j * dstx;
344: t = dstStart + k * dstX * dstY + j * dstX + i;
345: } else { /* contiguous */
346: t = dstIdx[tid];
347: }
349: s *= MBS;
350: t *= MBS;
351: for (i = 0; i < MBS; i++) op(dst[t + i], src[s + i]);
352: });
353: PetscFunctionReturn(PETSC_SUCCESS);
354: }
356: /* Specialization for Insert since we may use memcpy */
357: template <typename Type, PetscInt BS, PetscInt EQ>
358: static PetscErrorCode ScatterAndInsert(PetscSFLink link, PetscInt count, PetscInt srcStart, PetscSFPackOpt srcOpt, const PetscInt *srcIdx, const void *src_, PetscInt dstStart, PetscSFPackOpt dstOpt, const PetscInt *dstIdx, void *dst_)
359: {
360: const Type *src = static_cast<const Type *>(src_);
361: Type *dst = static_cast<Type *>(dst_);
362: DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();
364: PetscFunctionBegin;
365: if (!count) PetscFunctionReturn(PETSC_SUCCESS);
366: /*src and dst are contiguous */
367: if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) {
368: size_t sz = count * link->unitbytes;
369: deviceBuffer_t dbuf(reinterpret_cast<char *>(dst + dstStart * link->bs), sz);
370: deviceConstBuffer_t sbuf(reinterpret_cast<const char *>(src + srcStart * link->bs), sz);
371: Kokkos::deep_copy(exec, dbuf, sbuf);
372: } else {
373: PetscCall(ScatterAndOp<Type, Insert<Type>, BS, EQ>(link, count, srcStart, srcOpt, srcIdx, src, dstStart, dstOpt, dstIdx, dst));
374: }
375: PetscFunctionReturn(PETSC_SUCCESS);
376: }
378: template <typename Type, class Op, PetscInt BS, PetscInt EQ>
379: static PetscErrorCode FetchAndOpLocal(PetscSFLink link, PetscInt count, PetscInt rootstart, PetscSFPackOpt rootopt, const PetscInt *rootidx, void *rootdata_, PetscInt leafstart, PetscSFPackOpt leafopt, const PetscInt *leafidx, const void *leafdata_, void *leafupdate_)
380: {
381: Op op;
382: const PetscInt M = (EQ) ? 1 : link->bs / BS, MBS = M * BS;
383: const PetscInt *ropt = rootopt ? rootopt->array : NULL;
384: const PetscInt *lopt = leafopt ? leafopt->array : NULL;
385: Type *rootdata = static_cast<Type *>(rootdata_), *leafupdate = static_cast<Type *>(leafupdate_);
386: const Type *leafdata = static_cast<const Type *>(leafdata_);
387: DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();
389: PetscFunctionBegin;
390: Kokkos::parallel_for(
391: Kokkos::RangePolicy<DeviceExecutionSpace>(exec, 0, count), KOKKOS_LAMBDA(PetscInt tid) {
392: PetscInt r = (ropt ? MapTidToIndex(ropt, tid) : (rootidx ? rootidx[tid] : rootstart + tid)) * MBS;
393: PetscInt l = (lopt ? MapTidToIndex(lopt, tid) : (leafidx ? leafidx[tid] : leafstart + tid)) * MBS;
394: for (int i = 0; i < MBS; i++) leafupdate[l + i] = op(rootdata[r + i], leafdata[l + i]);
395: });
396: PetscFunctionReturn(PETSC_SUCCESS);
397: }
399: /*====================================================================================*/
400: /* Init various types and instantiate pack/unpack function pointers */
401: /*====================================================================================*/
402: template <typename Type, PetscInt BS, PetscInt EQ>
403: static void PackInit_RealType(PetscSFLink link)
404: {
405: /* Pack/unpack for remote communication */
406: link->d_Pack = Pack<Type, BS, EQ>;
407: link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
408: link->d_UnpackAndAdd = UnpackAndOp<Type, Add<Type>, BS, EQ>;
409: link->d_UnpackAndMult = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
410: link->d_UnpackAndMin = UnpackAndOp<Type, Min<Type>, BS, EQ>;
411: link->d_UnpackAndMax = UnpackAndOp<Type, Max<Type>, BS, EQ>;
412: link->d_FetchAndAdd = FetchAndOp<Type, Add<Type>, BS, EQ>;
413: /* Scatter for local communication */
414: link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>; /* Has special optimizations */
415: link->d_ScatterAndAdd = ScatterAndOp<Type, Add<Type>, BS, EQ>;
416: link->d_ScatterAndMult = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
417: link->d_ScatterAndMin = ScatterAndOp<Type, Min<Type>, BS, EQ>;
418: link->d_ScatterAndMax = ScatterAndOp<Type, Max<Type>, BS, EQ>;
419: link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
420: /* Atomic versions when there are data-race possibilities */
421: link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
422: link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
423: link->da_UnpackAndMult = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
424: link->da_UnpackAndMin = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
425: link->da_UnpackAndMax = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
426: link->da_FetchAndAdd = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
428: link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
429: link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
430: link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
431: link->da_ScatterAndMin = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
432: link->da_ScatterAndMax = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
433: link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
434: }
436: template <typename Type, PetscInt BS, PetscInt EQ>
437: static void PackInit_IntegerType(PetscSFLink link)
438: {
439: link->d_Pack = Pack<Type, BS, EQ>;
440: link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
441: link->d_UnpackAndAdd = UnpackAndOp<Type, Add<Type>, BS, EQ>;
442: link->d_UnpackAndMult = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
443: link->d_UnpackAndMin = UnpackAndOp<Type, Min<Type>, BS, EQ>;
444: link->d_UnpackAndMax = UnpackAndOp<Type, Max<Type>, BS, EQ>;
445: link->d_UnpackAndLAND = UnpackAndOp<Type, LAND<Type>, BS, EQ>;
446: link->d_UnpackAndLOR = UnpackAndOp<Type, LOR<Type>, BS, EQ>;
447: link->d_UnpackAndLXOR = UnpackAndOp<Type, LXOR<Type>, BS, EQ>;
448: link->d_UnpackAndBAND = UnpackAndOp<Type, BAND<Type>, BS, EQ>;
449: link->d_UnpackAndBOR = UnpackAndOp<Type, BOR<Type>, BS, EQ>;
450: link->d_UnpackAndBXOR = UnpackAndOp<Type, BXOR<Type>, BS, EQ>;
451: link->d_FetchAndAdd = FetchAndOp<Type, Add<Type>, BS, EQ>;
453: link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
454: link->d_ScatterAndAdd = ScatterAndOp<Type, Add<Type>, BS, EQ>;
455: link->d_ScatterAndMult = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
456: link->d_ScatterAndMin = ScatterAndOp<Type, Min<Type>, BS, EQ>;
457: link->d_ScatterAndMax = ScatterAndOp<Type, Max<Type>, BS, EQ>;
458: link->d_ScatterAndLAND = ScatterAndOp<Type, LAND<Type>, BS, EQ>;
459: link->d_ScatterAndLOR = ScatterAndOp<Type, LOR<Type>, BS, EQ>;
460: link->d_ScatterAndLXOR = ScatterAndOp<Type, LXOR<Type>, BS, EQ>;
461: link->d_ScatterAndBAND = ScatterAndOp<Type, BAND<Type>, BS, EQ>;
462: link->d_ScatterAndBOR = ScatterAndOp<Type, BOR<Type>, BS, EQ>;
463: link->d_ScatterAndBXOR = ScatterAndOp<Type, BXOR<Type>, BS, EQ>;
464: link->d_FetchAndAddLocal = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
466: link->da_UnpackAndInsert = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
467: link->da_UnpackAndAdd = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
468: link->da_UnpackAndMult = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
469: link->da_UnpackAndMin = UnpackAndOp<Type, AtomicMin<Type>, BS, EQ>;
470: link->da_UnpackAndMax = UnpackAndOp<Type, AtomicMax<Type>, BS, EQ>;
471: link->da_UnpackAndLAND = UnpackAndOp<Type, AtomicLAND<Type>, BS, EQ>;
472: link->da_UnpackAndLOR = UnpackAndOp<Type, AtomicLOR<Type>, BS, EQ>;
473: link->da_UnpackAndBAND = UnpackAndOp<Type, AtomicBAND<Type>, BS, EQ>;
474: link->da_UnpackAndBOR = UnpackAndOp<Type, AtomicBOR<Type>, BS, EQ>;
475: link->da_UnpackAndBXOR = UnpackAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
476: link->da_FetchAndAdd = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
478: link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;
479: link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;
480: link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;
481: link->da_ScatterAndMin = ScatterAndOp<Type, AtomicMin<Type>, BS, EQ>;
482: link->da_ScatterAndMax = ScatterAndOp<Type, AtomicMax<Type>, BS, EQ>;
483: link->da_ScatterAndLAND = ScatterAndOp<Type, AtomicLAND<Type>, BS, EQ>;
484: link->da_ScatterAndLOR = ScatterAndOp<Type, AtomicLOR<Type>, BS, EQ>;
485: link->da_ScatterAndBAND = ScatterAndOp<Type, AtomicBAND<Type>, BS, EQ>;
486: link->da_ScatterAndBOR = ScatterAndOp<Type, AtomicBOR<Type>, BS, EQ>;
487: link->da_ScatterAndBXOR = ScatterAndOp<Type, AtomicBXOR<Type>, BS, EQ>;
488: link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
489: }
#if defined(PETSC_HAVE_COMPLEX)
/* Install the device kernels for a complex-valued unit into <link>.
   Only Insert, Add, Mult and FetchAndAdd are provided; no Min/Max or logical/bitwise ops are installed.
*/
template <typename Type, PetscInt BS, PetscInt EQ>
static void PackInit_ComplexType(PetscSFLink link)
{
  link->d_Pack = Pack<Type, BS, EQ>;

  /* Insert */
  link->d_UnpackAndInsert   = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
  link->d_ScatterAndInsert  = ScatterAndInsert<Type, BS, EQ>;
  link->da_UnpackAndInsert  = UnpackAndOp<Type, AtomicInsert<Type>, BS, EQ>;
  link->da_ScatterAndInsert = ScatterAndOp<Type, AtomicInsert<Type>, BS, EQ>;

  /* Add */
  link->d_UnpackAndAdd   = UnpackAndOp<Type, Add<Type>, BS, EQ>;
  link->d_ScatterAndAdd  = ScatterAndOp<Type, Add<Type>, BS, EQ>;
  link->da_UnpackAndAdd  = UnpackAndOp<Type, AtomicAdd<Type>, BS, EQ>;
  link->da_ScatterAndAdd = ScatterAndOp<Type, AtomicAdd<Type>, BS, EQ>;

  /* Mult */
  link->d_UnpackAndMult   = UnpackAndOp<Type, Mult<Type>, BS, EQ>;
  link->d_ScatterAndMult  = ScatterAndOp<Type, Mult<Type>, BS, EQ>;
  link->da_UnpackAndMult  = UnpackAndOp<Type, AtomicMult<Type>, BS, EQ>;
  link->da_ScatterAndMult = ScatterAndOp<Type, AtomicMult<Type>, BS, EQ>;

  /* Fetch-and-add */
  link->d_FetchAndAdd       = FetchAndOp<Type, Add<Type>, BS, EQ>;
  link->d_FetchAndAddLocal  = FetchAndOpLocal<Type, Add<Type>, BS, EQ>;
  link->da_FetchAndAdd      = FetchAndOp<Type, AtomicFetchAdd<Type>, BS, EQ>;
  link->da_FetchAndAddLocal = FetchAndOpLocal<Type, AtomicFetchAdd<Type>, BS, EQ>;
}
#endif
518: template <typename Type>
519: static void PackInit_PairType(PetscSFLink link)
520: {
521: link->d_Pack = Pack<Type, 1, 1>;
522: link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, 1, 1>;
523: link->d_UnpackAndMaxloc = UnpackAndOp<Type, Maxloc<Type>, 1, 1>;
524: link->d_UnpackAndMinloc = UnpackAndOp<Type, Minloc<Type>, 1, 1>;
526: link->d_ScatterAndInsert = ScatterAndOp<Type, Insert<Type>, 1, 1>;
527: link->d_ScatterAndMaxloc = ScatterAndOp<Type, Maxloc<Type>, 1, 1>;
528: link->d_ScatterAndMinloc = ScatterAndOp<Type, Minloc<Type>, 1, 1>;
529: /* Atomics for pair types are not implemented yet */
530: }
532: template <typename Type, PetscInt BS, PetscInt EQ>
533: static void PackInit_DumbType(PetscSFLink link)
534: {
535: link->d_Pack = Pack<Type, BS, EQ>;
536: link->d_UnpackAndInsert = UnpackAndOp<Type, Insert<Type>, BS, EQ>;
537: link->d_ScatterAndInsert = ScatterAndInsert<Type, BS, EQ>;
538: /* Atomics for dumb types are not implemented yet */
539: }
/*
  Kokkos::DefaultExecutionSpace(stream) is a reference-counted pointer object. It has a bug:
  one is not able to repeatedly create and destroy the object. SF's original design was that each
  SFLink has a stream (NULL or not) and hence an execution space object. The bug prevents us from
  destroying multiple SFLinks with NULL stream and the default execution space object. To avoid
  memory leaks, SF_Kokkos only supports the NULL stream, which is also PETSc's default scheme. SF_Kokkos
  does not do its own new/delete. It just uses Kokkos::DefaultExecutionSpace(), which is a singleton
  object in Kokkos.
*/
550: /*
551: static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSFLink link)
552: {
553: PetscFunctionBegin;
554: PetscFunctionReturn(PETSC_SUCCESS);
555: }
556: */
558: /* Some device-specific utilities */
559: static PetscErrorCode PetscSFLinkSyncDevice_Kokkos(PetscSFLink PETSC_UNUSED link)
560: {
561: PetscFunctionBegin;
562: Kokkos::fence();
563: PetscFunctionReturn(PETSC_SUCCESS);
564: }
566: static PetscErrorCode PetscSFLinkSyncStream_Kokkos(PetscSFLink PETSC_UNUSED link)
567: {
568: DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();
570: PetscFunctionBegin;
571: exec.fence();
572: PetscFunctionReturn(PETSC_SUCCESS);
573: }
575: static PetscErrorCode PetscSFLinkMemcpy_Kokkos(PetscSFLink PETSC_UNUSED link, PetscMemType dstmtype, void *dst, PetscMemType srcmtype, const void *src, size_t n)
576: {
577: DeviceExecutionSpace exec = PetscGetKokkosExecutionSpace();
579: PetscFunctionBegin;
580: if (!n) PetscFunctionReturn(PETSC_SUCCESS);
581: if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) { // H2H
582: PetscCallCXX(exec.fence()); // make sure async kernels on src are finished, in case of unified memory as on AMD MI300A.
583: PetscCall(PetscMemcpy(dst, src, n));
584: } else {
585: if (PetscMemTypeDevice(dstmtype) && PetscMemTypeHost(srcmtype)) { // H2D
586: deviceBuffer_t dbuf(static_cast<char *>(dst), n);
587: HostConstBuffer_t sbuf(static_cast<const char *>(src), n);
588: PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf));
589: PetscCall(PetscLogCpuToGpu(n));
590: } else if (PetscMemTypeHost(dstmtype) && PetscMemTypeDevice(srcmtype)) { // D2H
591: HostBuffer_t dbuf(static_cast<char *>(dst), n);
592: deviceConstBuffer_t sbuf(static_cast<const char *>(src), n);
593: PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf));
594: PetscCallCXX(exec.fence()); // make sure dbuf is ready for use immediately on host
595: PetscCall(PetscLogGpuToCpu(n));
596: } else if (PetscMemTypeDevice(dstmtype) && PetscMemTypeDevice(srcmtype)) { // D2D
597: deviceBuffer_t dbuf(static_cast<char *>(dst), n);
598: deviceConstBuffer_t sbuf(static_cast<const char *>(src), n);
599: PetscCallCXX(Kokkos::deep_copy(exec, dbuf, sbuf));
600: }
601: }
602: PetscFunctionReturn(PETSC_SUCCESS);
603: }
605: PetscErrorCode PetscSFMalloc_Kokkos(PetscMemType mtype, size_t size, void **ptr)
606: {
607: PetscFunctionBegin;
608: if (PetscMemTypeHost(mtype)) PetscCall(PetscMalloc(size, ptr));
609: else if (PetscMemTypeDevice(mtype)) {
610: if (!PetscKokkosInitialized) PetscCall(PetscKokkosInitializeCheck());
611: PetscCallCXX(*ptr = Kokkos::kokkos_malloc<DefaultMemorySpace>(size));
612: } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
613: PetscFunctionReturn(PETSC_SUCCESS);
614: }
616: PetscErrorCode PetscSFFree_Kokkos(PetscMemType mtype, void *ptr)
617: {
618: PetscFunctionBegin;
619: if (PetscMemTypeHost(mtype)) PetscCall(PetscFree(ptr));
620: else if (PetscMemTypeDevice(mtype)) {
621: PetscCallCXX(Kokkos::kokkos_free<DefaultMemorySpace>(ptr));
622: } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Wrong PetscMemType %d", (int)mtype);
623: PetscFunctionReturn(PETSC_SUCCESS);
624: }
626: /* Destructor when the link uses MPI for communication */
627: static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSF sf, PetscSFLink link)
628: {
629: PetscFunctionBegin;
630: for (int i = PETSCSF_LOCAL; i <= PETSCSF_REMOTE; i++) {
631: PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
632: PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_DEVICE, link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE]));
633: }
634: PetscFunctionReturn(PETSC_SUCCESS);
635: }
637: /* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what needed on device */
638: PetscErrorCode PetscSFLinkSetUp_Kokkos(PetscSF PETSC_UNUSED sf, PetscSFLink link, MPI_Datatype unit)
639: {
640: PetscInt nSignedChar = 0, nUnsignedChar = 0, nInt = 0, nPetscInt = 0, nPetscReal = 0;
641: PetscBool is2Int, is2PetscInt;
642: #if defined(PETSC_HAVE_COMPLEX)
643: PetscInt nPetscComplex = 0;
644: #endif
646: PetscFunctionBegin;
647: if (link->deviceinited) PetscFunctionReturn(PETSC_SUCCESS);
648: PetscCall(PetscKokkosInitializeCheck());
649: PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_SIGNED_CHAR, &nSignedChar));
650: PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_UNSIGNED_CHAR, &nUnsignedChar));
651: /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */
652: PetscCall(MPIPetsc_Type_compare_contig(unit, MPI_INT, &nInt));
653: PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_INT, &nPetscInt));
654: PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_REAL, &nPetscReal));
655: #if defined(PETSC_HAVE_COMPLEX)
656: PetscCall(MPIPetsc_Type_compare_contig(unit, MPIU_COMPLEX, &nPetscComplex));
657: #endif
658: PetscCall(MPIPetsc_Type_compare(unit, MPI_2INT, &is2Int));
659: PetscCall(MPIPetsc_Type_compare(unit, MPIU_2INT, &is2PetscInt));
661: if (is2Int) {
662: PackInit_PairType<Kokkos::pair<int, int>>(link);
663: } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */
664: PackInit_PairType<Kokkos::pair<PetscInt, PetscInt>>(link);
665: } else if (nPetscReal) {
666: #if !defined(PETSC_HAVE_DEVICE) /* Skip the unimportant stuff to speed up SF device compilation time */
667: if (nPetscReal == 8) PackInit_RealType<PetscReal, 8, 1>(link);
668: else if (nPetscReal % 8 == 0) PackInit_RealType<PetscReal, 8, 0>(link);
669: else if (nPetscReal == 4) PackInit_RealType<PetscReal, 4, 1>(link);
670: else if (nPetscReal % 4 == 0) PackInit_RealType<PetscReal, 4, 0>(link);
671: else if (nPetscReal == 2) PackInit_RealType<PetscReal, 2, 1>(link);
672: else if (nPetscReal % 2 == 0) PackInit_RealType<PetscReal, 2, 0>(link);
673: else if (nPetscReal == 1) PackInit_RealType<PetscReal, 1, 1>(link);
674: else if (nPetscReal % 1 == 0)
675: #endif
676: PackInit_RealType<PetscReal, 1, 0>(link);
677: } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) {
678: #if !defined(PETSC_HAVE_DEVICE)
679: if (nPetscInt == 8) PackInit_IntegerType<llint, 8, 1>(link);
680: else if (nPetscInt % 8 == 0) PackInit_IntegerType<llint, 8, 0>(link);
681: else if (nPetscInt == 4) PackInit_IntegerType<llint, 4, 1>(link);
682: else if (nPetscInt % 4 == 0) PackInit_IntegerType<llint, 4, 0>(link);
683: else if (nPetscInt == 2) PackInit_IntegerType<llint, 2, 1>(link);
684: else if (nPetscInt % 2 == 0) PackInit_IntegerType<llint, 2, 0>(link);
685: else if (nPetscInt == 1) PackInit_IntegerType<llint, 1, 1>(link);
686: else if (nPetscInt % 1 == 0)
687: #endif
688: PackInit_IntegerType<llint, 1, 0>(link);
689: } else if (nInt) {
690: #if !defined(PETSC_HAVE_DEVICE)
691: if (nInt == 8) PackInit_IntegerType<int, 8, 1>(link);
692: else if (nInt % 8 == 0) PackInit_IntegerType<int, 8, 0>(link);
693: else if (nInt == 4) PackInit_IntegerType<int, 4, 1>(link);
694: else if (nInt % 4 == 0) PackInit_IntegerType<int, 4, 0>(link);
695: else if (nInt == 2) PackInit_IntegerType<int, 2, 1>(link);
696: else if (nInt % 2 == 0) PackInit_IntegerType<int, 2, 0>(link);
697: else if (nInt == 1) PackInit_IntegerType<int, 1, 1>(link);
698: else if (nInt % 1 == 0)
699: #endif
700: PackInit_IntegerType<int, 1, 0>(link);
701: } else if (nSignedChar) {
702: #if !defined(PETSC_HAVE_DEVICE)
703: if (nSignedChar == 8) PackInit_IntegerType<char, 8, 1>(link);
704: else if (nSignedChar % 8 == 0) PackInit_IntegerType<char, 8, 0>(link);
705: else if (nSignedChar == 4) PackInit_IntegerType<char, 4, 1>(link);
706: else if (nSignedChar % 4 == 0) PackInit_IntegerType<char, 4, 0>(link);
707: else if (nSignedChar == 2) PackInit_IntegerType<char, 2, 1>(link);
708: else if (nSignedChar % 2 == 0) PackInit_IntegerType<char, 2, 0>(link);
709: else if (nSignedChar == 1) PackInit_IntegerType<char, 1, 1>(link);
710: else if (nSignedChar % 1 == 0)
711: #endif
712: PackInit_IntegerType<char, 1, 0>(link);
713: } else if (nUnsignedChar) {
714: #if !defined(PETSC_HAVE_DEVICE)
715: if (nUnsignedChar == 8) PackInit_IntegerType<unsigned char, 8, 1>(link);
716: else if (nUnsignedChar % 8 == 0) PackInit_IntegerType<unsigned char, 8, 0>(link);
717: else if (nUnsignedChar == 4) PackInit_IntegerType<unsigned char, 4, 1>(link);
718: else if (nUnsignedChar % 4 == 0) PackInit_IntegerType<unsigned char, 4, 0>(link);
719: else if (nUnsignedChar == 2) PackInit_IntegerType<unsigned char, 2, 1>(link);
720: else if (nUnsignedChar % 2 == 0) PackInit_IntegerType<unsigned char, 2, 0>(link);
721: else if (nUnsignedChar == 1) PackInit_IntegerType<unsigned char, 1, 1>(link);
722: else if (nUnsignedChar % 1 == 0)
723: #endif
724: PackInit_IntegerType<unsigned char, 1, 0>(link);
725: #if defined(PETSC_HAVE_COMPLEX)
726: } else if (nPetscComplex) {
727: #if !defined(PETSC_HAVE_DEVICE)
728: if (nPetscComplex == 8) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 1>(link);
729: else if (nPetscComplex % 8 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 8, 0>(link);
730: else if (nPetscComplex == 4) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 1>(link);
731: else if (nPetscComplex % 4 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 4, 0>(link);
732: else if (nPetscComplex == 2) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 1>(link);
733: else if (nPetscComplex % 2 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>, 2, 0>(link);
734: else if (nPetscComplex == 1) PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 1>(link);
735: else if (nPetscComplex % 1 == 0)
736: #endif
737: PackInit_ComplexType<Kokkos::complex<PetscReal>, 1, 0>(link);
738: #endif
739: } else {
740: MPI_Aint nbyte;
742: PetscCall(PetscSFGetDatatypeSize_Internal(PETSC_COMM_SELF, unit, &nbyte));
743: if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */
744: #if !defined(PETSC_HAVE_DEVICE)
745: if (nbyte == 4) PackInit_DumbType<char, 4, 1>(link);
746: else if (nbyte % 4 == 0) PackInit_DumbType<char, 4, 0>(link);
747: else if (nbyte == 2) PackInit_DumbType<char, 2, 1>(link);
748: else if (nbyte % 2 == 0) PackInit_DumbType<char, 2, 0>(link);
749: else if (nbyte == 1) PackInit_DumbType<char, 1, 1>(link);
750: else if (nbyte % 1 == 0)
751: #endif
752: PackInit_DumbType<char, 1, 0>(link);
753: } else {
754: PetscCall(PetscIntCast(nbyte / sizeof(int), &nInt));
755: #if !defined(PETSC_HAVE_DEVICE)
756: if (nInt == 8) PackInit_DumbType<int, 8, 1>(link);
757: else if (nInt % 8 == 0) PackInit_DumbType<int, 8, 0>(link);
758: else if (nInt == 4) PackInit_DumbType<int, 4, 1>(link);
759: else if (nInt % 4 == 0) PackInit_DumbType<int, 4, 0>(link);
760: else if (nInt == 2) PackInit_DumbType<int, 2, 1>(link);
761: else if (nInt % 2 == 0) PackInit_DumbType<int, 2, 0>(link);
762: else if (nInt == 1) PackInit_DumbType<int, 1, 1>(link);
763: else if (nInt % 1 == 0)
764: #endif
765: PackInit_DumbType<int, 1, 0>(link);
766: }
767: }
769: link->SyncDevice = PetscSFLinkSyncDevice_Kokkos;
770: link->SyncStream = PetscSFLinkSyncStream_Kokkos;
771: link->Memcpy = PetscSFLinkMemcpy_Kokkos;
772: link->Destroy = PetscSFLinkDestroy_Kokkos;
773: link->deviceinited = PETSC_TRUE;
774: PetscFunctionReturn(PETSC_SUCCESS);
775: }