Alexandria 2.31.0
SDC-CH common library for the Euclid project
Loading...
Searching...
No Matches
GridInterpolation.icpp
Go to the documentation of this file.
1#ifndef GRIDINTERPOLATION_IMPL
2#error Please, include "MathUtils/interpolation/GridInterpolation.h"
3#endif
4
5#include "AlexandriaKernel/Tuples.h"
6#include "MathUtils/interpolation/interpolation.h"
7
8namespace Euclid {
9namespace MathUtils {
10
/**
 * Per-axis interpolation strategy, selected from the coordinate type of the axis.
 *
 * This primary template is only declared, never defined: the @c Enable parameter exists so that
 * partial specializations guarded by std::enable_if (floating point, integral and non-arithmetic
 * coordinate types) can provide the actual implementation.
 */
template <class T, class Enable = void>
struct InterpolationImpl;
13
14/**
15 * Trait for continuous types
16 */
17template <typename T>
18struct InterpolationImpl<T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
19 static double interpolate(const T x, const std::vector<T>& knots, const std::vector<double>& values,
20 bool extrapolate) {
21 return simple_interpolation(x, knots, values, extrapolate);
22 }
23
24 template <typename... Rest>
25 static double interpolate(const T x, const std::vector<T>& knots,
26 const std::vector<std::unique_ptr<InterpN<Rest...>>>& interpolators, bool extrapolate,
27 const Rest... rest) {
28 // If no extrapolation, and the value if out-of-bounds, just clip at 0
29 if ((x < knots.front() || x > knots.back()) && !extrapolate) {
30 return 0.;
31 }
32
33 if (knots.size() == 1) {
34 return (*interpolators[0])(rest...);
35 }
36
37 std::size_t x2i = std::lower_bound(knots.begin(), knots.end(), x) - knots.begin();
38 if (x2i == 0) {
39 ++x2i;
40 } else if (x2i == knots.size()) {
41 --x2i;
42 }
43 std::size_t x1i = x2i - 1;
44
45 double y1 = (*interpolators[x1i])(rest...);
46 double y2 = (*interpolators[x2i])(rest...);
47
48 return simple_interpolation(x, knots[x1i], knots[x2i], y1, y2, extrapolate);
49 }
50
51 static void checkOrder(const std::vector<T>& knots) {
52 if (!std::is_sorted(knots.begin(), knots.end())) {
53 throw InterpolationException("coordinates must be sorted");
54 }
55 }
56};
57
58template <typename T>
59struct InterpolationImpl<T, typename std::enable_if<std::is_integral<T>::value>::type> {
60 static double interpolate(const T x, const std::vector<T>& knots, const std::vector<double>& values,
61 bool /*extrapolate*/) {
62 if (x < knots.front() || x > knots.back())
63 return 0.;
64 return values[x];
65 }
66
67 template <typename... Rest>
68 static double interpolate(const T x, const std::vector<T>& knots,
69 const std::vector<std::unique_ptr<InterpN<Rest...>>>& interpolators, bool,
70 const Rest... rest) {
71 if (x < knots.front() || x > knots.back())
72 return 0.;
73 return (*interpolators[x])(rest...);
74 }
75
76 static void checkOrder(const std::vector<T>& knots) {
77 if (knots.front() != 0) {
78 throw InterpolationException("int axis must start at 0");
79 }
80 for (auto b = knots.begin() + 1; b != knots.end(); ++b) {
81 if (*b - *(b - 1) != 1) {
82 throw InterpolationException("int values must be contiguous");
83 }
84 }
85 }
86};
87
88/**
89 * Trait for discrete types
90 */
91template <typename T>
92struct InterpolationImpl<T, typename std::enable_if<!std::is_arithmetic<T>::value>::type> {
93 static double interpolate(const T x, const std::vector<T>& knots, const std::vector<double>& values,
94 bool /*extrapolate*/) {
95 std::size_t i = std::find(knots.begin(), knots.end(), x) - knots.begin();
96 if (i >= knots.size() || knots[i] != x)
97 return 0.;
98 return values[i];
99 }
100
101 template <typename... Rest>
102 static double interpolate(const T x, const std::vector<T>& knots,
103 const std::vector<std::unique_ptr<InterpN<Rest...>>>& interpolators, bool,
104 const Rest... rest) {
105 std::size_t i = std::find(knots.begin(), knots.end(), x) - knots.begin();
106 if (i >= knots.size() || knots[i] != x)
107 return 0.;
108 return (*interpolators[i])(rest...);
109 }
110
111 static void checkOrder(const std::vector<T>&) {
112 // Discrete axes do not need to be in order
113 }
114};
115
116/**
117 * Specialization (and end of the recursion) for a 1-dimensional interpolation.
118 */
119template <typename T>
120class InterpN<T> {
121public:
122 /**
123 * Constructor
124 * @param grid
125 * A 1-dimensional grid
126 * @param values
127 * @param type
128 * @param extrapolate
129 */
130 InterpN(const std::tuple<std::vector<T>>& grid, const NdArray::NdArray<double>& values, bool extrapolate)
131 : m_knots(std::get<0>(grid)), m_values(values.begin(), values.end()), m_extrapolate(extrapolate) {
132 if (values.shape().size() != 1) {
133 throw InterpolationException() << "values and coordinates dimensionalities must match: " << values.shape().size()
134 << " != 1";
135 }
136 if (m_knots.size() != values.size()) {
137 throw InterpolationException() << "The size of the grid and the size of the values do not match: "
138 << m_knots.size() << " != " << m_values.size();
139 }
140 }
141
142 /**
143 * Call as a function
144 * @param x
145 * Coordinate value
146 * @return
147 * Interpolated value
148 */
149 double operator()(const T x) const {
150 return InterpolationImpl<T>::interpolate(x, m_knots, m_values, m_extrapolate);
151 }
152
153 /// Copy constructor
154 InterpN(const InterpN&) = default;
155
156 /// Move constructor
157 InterpN(InterpN&&) = default;
158
159private:
160 std::vector<T> m_knots;
161 std::vector<double> m_values;
162 bool m_extrapolate;
163};
164
165/**
166 * Recursive specialization of an N-Dimensional interpolator
167 * @tparam N Dimensionality (N > 1)
168 * @tparam F The first element of the index sequence
169 * @tparam Rest The rest of the elements from the index sequence
170 */
171template <typename T, typename... Rest>
172class InterpN<T, Rest...> {
173public:
174 /**
175 * Constructor
176 * @param grid
177 * @param values
178 * @param type
179 * @param extrapolate
180 */
181 InterpN(const std::tuple<std::vector<T>, std::vector<Rest>...>& grid, const NdArray::NdArray<double>& values,
182 bool extrapolate)
183 : m_extrapolate(extrapolate) {
184 constexpr std::size_t N = sizeof...(Rest) + 1;
185
186 if (values.shape().size() != N) {
187 throw InterpolationException() << "values and coordinates dimensionality must match: " << values.shape().size()
188 << " != " << N;
189 }
190 m_knots = std::get<0>(grid);
191 InterpolationImpl<T>::checkOrder(m_knots);
192 if (m_knots.size() != values.shape().back()) {
193 throw InterpolationException("coordinates and value sizes must match");
194 }
195 // Build nested interpolators
196 auto subgrid = Tuple::Tail(std::move(grid));
197 m_interpolators.resize(m_knots.size());
198 for (size_t i = 0; i < m_knots.size(); ++i) {
199 auto subvalues = values.rslice(i);
200 m_interpolators[i].reset(new InterpN<Rest...>(subgrid, subvalues, extrapolate));
201 }
202 }
203
204 /**
205 * Call as a function
206 * @param x Value for the axis for the first dimension
207 * @param rest Values for the next set of axes
208 * @return The interpolated value
209 * @details
210 * Doubles<Rest>... is used to expand into (N-1) doubles
211 * x is used to find the interpolators for x1 and x2 s.t. x1 <= x <=x2
212 * Those two interpolators are used to compute y1 for x1, and y2 for x2 (based on the rest of the parameters)
213 * A final linear interpolator is used to get the value of y at the position x
214 */
215 double operator()(T x, Rest... rest) const {
216 return InterpolationImpl<T>::interpolate(x, m_knots, m_interpolators, m_extrapolate, rest...);
217 }
218
219 /// Copy constructor
220 InterpN(const InterpN& other) : m_knots(other.m_knots), m_extrapolate(other.m_extrapolate) {
221 m_interpolators.resize(m_knots.size());
222 for (size_t i = 0; i < m_interpolators.size(); ++i) {
223 m_interpolators[i].reset(new InterpN<Rest...>(*other.m_interpolators[i]));
224 }
225 }
226
227private:
228 std::vector<T> m_knots;
229 std::vector<std::unique_ptr<InterpN<Rest...>>> m_interpolators;
230 bool m_extrapolate;
231};
232
233} // namespace MathUtils
234} // namespace Euclid