Alexandria 2.31.0
SDC-CH common library for the Euclid project
Loading...
Searching...
No Matches
NdSampler.icpp
Go to the documentation of this file.
1#ifdef NDSAMPLER_IMPL
2
#include "ElementsKernel/Real.h"
#include "MathUtils/function/function_tools.h"
#include "MathUtils/helpers/InverseCumulative.h"
#include "MathUtils/interpolation/interpolation.h"
#include "NdArray/Operations.h"
#include <cmath>
#include <cstddef>
#include <limits>
#include <memory>
#include <random>
#include <tuple>
#include <type_traits>
#include <vector>
10
11namespace Euclid {
12namespace MathUtils {
13
14/**
15 * Integrate the marginal
16 */
17template <typename T, typename Enabled = void>
18struct TrapzHelper {};
19
20/**
21 * Specialization for continuous
22 */
23template <typename T>
24struct TrapzHelper<T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
25 static NdArray::NdArray<double> trapz(const NdArray::NdArray<double>& grid, const std::vector<T>& knots, size_t axis) {
26 if (knots.size() == 1) {
27 return NdArray::sum(grid, axis);
28 }
29 return NdArray::trapz(grid, knots.begin(), knots.end(), axis);
30 }
31};
32
33/**
34 * Specialization for discrete
35 */
36template <typename T>
37struct TrapzHelper<T, typename std::enable_if<!std::is_floating_point<T>::value>::type> {
38 static NdArray::NdArray<double> trapz(const NdArray::NdArray<double>& grid, const std::vector<T>&, size_t axis) {
39 return NdArray::sum(grid, axis);
40 }
41};
42
43/**
44 * Specialization for a 1D distribution
45 * Having the knots values and their probabilities, we compute the un-normalized cumulative distribution,
46 * pick a value at uniform between the minimum and maximum cumulative values, and return as a sample the
47 * corresponding value on the X axis (linearly interpolated)
48 */
49template <typename TKnot>
50class NdSampler<TKnot> {
51public:
52 virtual ~NdSampler() = default;
53
54 NdSampler(std::vector<TKnot> knots, const NdArray::NdArray<double>& grid)
55 : m_inv_cumulative(std::move(knots), std::vector<double>(grid.begin(), grid.end())) {
56 if (grid.shape().size() != 1) {
57 throw Elements::Exception() << "Grid with " << grid.shape().size() << " axes passed to a 1D sampler";
58 }
59 }
60
61 NdSampler(std::tuple<std::vector<TKnot>>&& knots, NdArray::NdArray<double>&& grid)
62 : NdSampler(std::move(std::get<0>(knots)), grid) {}
63
64 template <typename Generator, typename... OKnots>
65 void draw(std::size_t ndraws, Generator& rng, std::vector<std::tuple<OKnots...>>& output) const {
66 constexpr std::size_t this_n = sizeof...(OKnots) - 1;
67 static_assert(sizeof...(OKnots) >= 1, "The output tuple must have at least one element");
68
69 if (output.size() != ndraws) {
70 throw Elements::Exception() << "Output area does not match the required shape: expected at least " << ndraws << ", got "
71 << output.size();
72 }
73 // The std::nextafter is required so the interval is closed
74 std::uniform_real_distribution<> uniform(0, std::nextafter(1, std::numeric_limits<double>::max()));
75 // Draw samples
76
77 for (auto& row : output) {
78 auto p = uniform(rng);
79 std::get<this_n>(row) = m_inv_cumulative(p);
80 }
81 }
82
83 template <typename Generator>
84 std::vector<std::tuple<TKnot>> draw(std::size_t ndraws, Generator& rng) const {
85 std::vector<std::tuple<TKnot>> samples(ndraws);
86 draw(ndraws, rng, samples);
87 return samples;
88 }
89
90private:
91 const InverseCumulative<TKnot> m_inv_cumulative;
92};
93
94/**
95 * Helper class to call the interpolation function
96 */
97template <std::size_t Start, typename Seq>
98struct _CallHelper {};
99
100/**
101 * Unwrap the access to the parameters
102 */
103template <std::size_t Start, std::size_t... Is>
104struct _CallHelper<Start, _index_sequence<Is...>> {
105
106 /**
107 * @tparam Func
108 * Type of the interpolated function
109 * @param func
110 * Interpolated function
111 * @param x0
112 * Value for the 0th dimension
113 * @param xs
114 * Rest (fixed) values for the 1,...,N-1 dimensions
115 */
116 template <typename Func, typename TKnot0, typename... TKnotN>
117 static double call(Func& func, const TKnot0 x0, const std::tuple<TKnotN...>& xs) {
118 return func(x0, std::get<Start + 1 + Is>(xs)...);
119 }
120};
121
122/**
123 * General case
124 */
125template <typename TKnot0, typename... TKnotN>
126class NdSampler<TKnot0, TKnotN...> {
127public:
128 /**
129 * Constructor
130 * @param knots
131 * @param values
132 */
133 NdSampler(std::tuple<std::vector<TKnot0>, std::vector<TKnotN>...> knots, const NdArray::NdArray<double>& values) {
134 constexpr std::size_t N = sizeof...(TKnotN) + 1;
135
136 if (values.shape().size() != N) {
137 throw Elements::Exception() << "Grid with " << values.shape().size() << " axes passed to a " << N << "D sampler";
138 }
139
140 // Interpolate this dimension
141 m_interpolation = Euclid::make_unique<InterpN<TKnot0, TKnotN...>>(knots, values, false);
142
143 // This axis knots
144 m_knots0 = std::move(std::get<0>(knots));
145
146 // Compute the marginal of the first nested dimension
147 // i.e. for N=2, compute the marginal of N=1
148 auto marginal = TrapzHelper<TKnot0>::trapz(values, m_knots0, -1);
149
150 // Nested sampler
151 auto rest_knots = Tuple::Tail(std::move(knots));
152 m_subsampler = Euclid::make_unique<NdSampler<TKnotN...>>(std::move(rest_knots), std::move(marginal));
153 }
154
155 /**
156 * Get a sample
157 * @tparam Generator
158 * @param ndraws
159 * @param rng
160 * @return
161 */
162 template <typename Generator>
163 std::vector<std::tuple<TKnot0, TKnotN...>> draw(std::size_t ndraws, Generator& rng) const {
164 std::vector<std::tuple<TKnot0, TKnotN...>> output(ndraws);
165 draw(ndraws, rng, output);
166 return output;
167 }
168
169 /**
170 * Get a sample into an user-provided output area
171 * @tparam Generator
172 * @param ndraws
173 * @param rng
174 * @param output
175 */
176 template <typename Generator, typename... OutputKnots>
177 void draw(std::size_t ndraws, Generator& rng, std::vector<std::tuple<OutputKnots...>>& output) const {
178 constexpr std::size_t this_n = sizeof...(OutputKnots) - sizeof...(TKnotN) - 1;
179
180 if (output.size() != ndraws) {
181 throw Elements::Exception() << "Output area does not match the required shape: expected at least " << ndraws << ", got "
182 << output.size();
183 }
184
185 // The std::nextafter is required so the interval is closed and there is a chance of getting 1.
186 std::uniform_real_distribution<> uniform(0, std::nextafter(1, std::numeric_limits<double>::max()));
187
188 // Sample from x1..x_{N-1}
189 m_subsampler->draw(ndraws, rng, output);
190
191 // For each sample
192 for (std::size_t draw_i = 0; draw_i < ndraws; ++draw_i) {
193 std::vector<double> pdf(m_knots0.size());
194 auto& subsample = output[draw_i];
195 // Evaluate the PDF
196 for (std::size_t i = 0; i < m_knots0.size(); ++i) {
197 pdf[i] = _CallHelper<this_n, _make_index_sequence<sizeof...(TKnotN)>>::call(*m_interpolation, m_knots0[i], subsample);
198 }
199
200 auto p = uniform(rng);
201 InverseCumulative<TKnot0> inv_cumulative(m_knots0, std::move(pdf));
202 std::get<this_n>(subsample) = inv_cumulative(p);
203 }
204 }
205
206private:
207 std::unique_ptr<InterpN<TKnot0, TKnotN...>> m_interpolation;
208 std::vector<TKnot0> m_knots0;
209 std::unique_ptr<NdSampler<TKnotN...>> m_subsampler;
210};
211
212} // namespace MathUtils
213} // namespace Euclid
214
215#endif