Actual source code: curand2.cu
#include <petsc/private/randomimpl.h>
#include <thrust/transform.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/zip_iterator.h>
#include <exception>
6: #if defined(PETSC_USE_COMPLEX)
/*
  Thrust functor used to rescale an array of PetscReal that stores complex numbers as
  interleaved (real, imaginary) pairs. Each element arrives zipped with its position:
    even position -> x * Re(width) + Re(low)
    odd  position -> x * Im(width) + Im(low)
*/
struct complexscalelw
#if PETSC_PKG_CUDA_VERSION_LT(12, 8, 0) // thrust::unary_function is deprecated from CUDA 12.8 on; only derive from it on older toolkits
  :
  public thrust::unary_function<thrust::tuple<PetscReal, size_t>, PetscReal>
#endif
{
  PetscReal rl, rw; /* real parts of low and width */
  PetscReal il, iw; /* imaginary parts of low and width */

  /* Split the complex interval bounds into real/imaginary components once, on the host */
  complexscalelw(PetscScalar low, PetscScalar width) : rl(PetscRealPart(low)), rw(PetscRealPart(width)), il(PetscImaginaryPart(low)), iw(PetscImaginaryPart(width)) { }

  /* x = (value, index); the index parity decides which component's scale/shift applies.
     const: the functor reads its members only, so it may be called through a const reference. */
  __host__ __device__ PetscReal operator()(thrust::tuple<PetscReal, size_t> x) const
  {
    const PetscReal value     = thrust::get<0>(x);
    const bool      imaginary = (thrust::get<1>(x) % 2) != 0;
    return imaginary ? value * iw + il : value * rw + rl;
  }
};
26: #endif
/*
  Thrust functor applying the affine map x -> x * w + l to each PetscReal entry.
*/
struct realscalelw
#if PETSC_PKG_CUDA_VERSION_LT(12, 8, 0) // To suppress the warning "thrust::THRUST_200700_860_NS::unary_function is deprecated"
  :
  public thrust::unary_function<PetscReal, PetscReal>
#endif
{
  PetscReal l, w; /* shift (low) and scale (width) */

  realscalelw(PetscReal low, PetscReal width) : l(low), w(width) { }

  /* const: the functor reads its members only, so it may be called through a const reference */
  __host__ __device__ PetscReal operator()(PetscReal x) const { return x * w + l; }
};
/*
  PetscRandomCurandScale_Private - Rescales, in place, n PetscReal values stored at the
  device address val using the interval stored in r: val[i] <- val[i] * width + low.

  Input Parameters:
+ r     - the random context; nothing is done unless r->iset is set
          (presumably set when an interval was supplied - confirm against PetscRandomSetInterval)
. n     - number of PetscReal entries in val
. val   - device pointer to the values (wrapped with thrust::device_pointer_cast)
- isneg - PETSC_TRUE selects the complex path, where val holds interleaved (real, imaginary)
          pairs and each component uses the matching part of low/width; only supported when
          PETSc is built with complex scalars

  Notes:
  thrust::transform reports CUDA failures by throwing C++ exceptions; they are caught here
  and turned into a PETSc error so they never propagate into (C) callers.
*/
PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom r, size_t n, PetscReal *val, PetscBool isneg)
{
  PetscFunctionBegin;
  if (!r->iset) PetscFunctionReturn(PETSC_SUCCESS);
  try {
    thrust::device_ptr<PetscReal> pval = thrust::device_pointer_cast(val);

    if (isneg) { /* complex case, need to scale differently */
#if defined(PETSC_USE_COMPLEX)
      /* zip each entry with its index so the functor can tell real (even) from imaginary (odd) slots */
      auto zibit = thrust::make_zip_iterator(thrust::make_tuple(pval, thrust::counting_iterator<size_t>(0)));
      thrust::transform(zibit, zibit + n, pval, complexscalelw(r->low, r->width));
#else
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Negative array size %" PetscInt_FMT, (PetscInt)n);
#endif
    } else {
      const PetscReal rl = PetscRealPart(r->low);
      const PetscReal rw = PetscRealPart(r->width);
      thrust::transform(pval, pval + n, pval, realscalelw(rl, rw));
    }
  } catch (const std::exception &e) {
    /* thrust::system_error (and any other std::exception) -> PETSc error code */
    SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "Error in Thrust %s", e.what());
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}