1#ifndef GRIDINTERPOLATION_IMPL
2#error Please, include "MathUtils/interpolation/GridInterpolation.h"
5#include "AlexandriaKernel/Tuples.h"
6#include "MathUtils/interpolation/interpolation.h"
// Primary template of the per-axis interpolation trait. It is only declared, never defined:
// the usable versions are the partial specializations selected through the Enable parameter
// with std::enable_if (floating-point, integral, and non-arithmetic knot types).
// Each specialization provides interpolate(...) overloads and a checkOrder(knots) validator.
11template <typename T, typename Enable = void>
12struct InterpolationImpl;
15 * Trait for continuous types
18struct InterpolationImpl<T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
19 static double interpolate(const T x, const std::vector<T>& knots, const std::vector<double>& values,
21 return simple_interpolation(x, knots, values, extrapolate);
24 template <typename... Rest>
25 static double interpolate(const T x, const std::vector<T>& knots,
26 const std::vector<std::unique_ptr<InterpN<Rest...>>>& interpolators, bool extrapolate,
28 // If extrapolation is disabled and the value is out-of-bounds, just clip at 0
29 if ((x < knots.front() || x > knots.back()) && !extrapolate) {
33 if (knots.size() == 1) {
34 return (*interpolators[0])(rest...);
37 std::size_t x2i = std::lower_bound(knots.begin(), knots.end(), x) - knots.begin();
40 } else if (x2i == knots.size()) {
43 std::size_t x1i = x2i - 1;
45 double y1 = (*interpolators[x1i])(rest...);
46 double y2 = (*interpolators[x2i])(rest...);
48 return simple_interpolation(x, knots[x1i], knots[x2i], y1, y2, extrapolate);
51 static void checkOrder(const std::vector<T>& knots) {
52 if (!std::is_sorted(knots.begin(), knots.end())) {
53 throw InterpolationException("coordinates must be sorted");
59struct InterpolationImpl<T, typename std::enable_if<std::is_integral<T>::value>::type> {
60 static double interpolate(const T x, const std::vector<T>& knots, const std::vector<double>& values,
61 bool /*extrapolate*/) {
62 if (x < knots.front() || x > knots.back())
67 template <typename... Rest>
68 static double interpolate(const T x, const std::vector<T>& knots,
69 const std::vector<std::unique_ptr<InterpN<Rest...>>>& interpolators, bool,
71 if (x < knots.front() || x > knots.back())
73 return (*interpolators[x])(rest...);
76 static void checkOrder(const std::vector<T>& knots) {
77 if (knots.front() != 0) {
78 throw InterpolationException("int axis must start at 0");
80 for (auto b = knots.begin() + 1; b != knots.end(); ++b) {
81 if (*b - *(b - 1) != 1) {
82 throw InterpolationException("int values must be contiguous");
89 * Trait for discrete types
92struct InterpolationImpl<T, typename std::enable_if<!std::is_arithmetic<T>::value>::type> {
93 static double interpolate(const T x, const std::vector<T>& knots, const std::vector<double>& values,
94 bool /*extrapolate*/) {
95 std::size_t i = std::find(knots.begin(), knots.end(), x) - knots.begin();
96 if (i >= knots.size() || knots[i] != x)
101 template <typename... Rest>
102 static double interpolate(const T x, const std::vector<T>& knots,
103 const std::vector<std::unique_ptr<InterpN<Rest...>>>& interpolators, bool,
104 const Rest... rest) {
105 std::size_t i = std::find(knots.begin(), knots.end(), x) - knots.begin();
106 if (i >= knots.size() || knots[i] != x)
108 return (*interpolators[i])(rest...);
111 static void checkOrder(const std::vector<T>&) {
112 // Discrete axes do not need to be in order
117 * Specialization (and end of the recursion) for a 1-dimensional interpolation.
125 * A 1-dimensional grid
130 InterpN(const std::tuple<std::vector<T>>& grid, const NdArray::NdArray<double>& values, bool extrapolate)
131 : m_knots(std::get<0>(grid)), m_values(values.begin(), values.end()), m_extrapolate(extrapolate) {
132 if (values.shape().size() != 1) {
133 throw InterpolationException() << "values and coordinates dimensionalities must match: " << values.shape().size()
136 if (m_knots.size() != values.size()) {
137 throw InterpolationException() << "The size of the grid and the size of the values do not match: "
138 << m_knots.size() << " != " << m_values.size();
149 double operator()(const T x) const {
150 return InterpolationImpl<T>::interpolate(x, m_knots, m_values, m_extrapolate);
154 InterpN(const InterpN&) = default;
157 InterpN(InterpN&&) = default;
160 std::vector<T> m_knots;
161 std::vector<double> m_values;
166 * Recursive specialization of an N-dimensional interpolator (N > 1)
167 * @tparam T Type of the knots along the first axis
168 * @tparam Rest Types of the knots along the remaining N-1 axes (so N = sizeof...(Rest) + 1);
169 *   every knot of the first axis owns a nested InterpN<Rest...> built from the matching slice of the values
171template <typename T, typename... Rest>
172class InterpN<T, Rest...> {
181 InterpN(const std::tuple<std::vector<T>, std::vector<Rest>...>& grid, const NdArray::NdArray<double>& values,
183 : m_extrapolate(extrapolate) {
184 constexpr std::size_t N = sizeof...(Rest) + 1;
186 if (values.shape().size() != N) {
187 throw InterpolationException() << "values and coordinates dimensionality must match: " << values.shape().size()
190 m_knots = std::get<0>(grid);
191 InterpolationImpl<T>::checkOrder(m_knots);
192 if (m_knots.size() != values.shape().back()) {
193 throw InterpolationException("coordinates and value sizes must match");
195 // Build nested interpolators
196 auto subgrid = Tuple::Tail(std::move(grid));
197 m_interpolators.resize(m_knots.size());
198 for (size_t i = 0; i < m_knots.size(); ++i) {
199 auto subvalues = values.rslice(i);
200 m_interpolators[i].reset(new InterpN<Rest...>(subgrid, subvalues, extrapolate));
206 * @param x Value for the axis for the first dimension
207 * @param rest Values for the next set of axes
208 * @return The interpolated value
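210 * rest... carries the coordinates for the remaining (N-1) axes and is forwarded to the nested interpolators
211 * For a floating-point axis, x is used to find the interpolators for x1 and x2 s.t. x1 <= x <= x2 (when x lies within the knots)
212 * Those two interpolators are used to compute y1 for x1, and y2 for x2 (based on the rest of the parameters)
213 * A final 1-D interpolation between (x1, y1) and (x2, y2) gives y at x; integral and discrete axes instead use the single nested interpolator selected by x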
215 double operator()(T x, Rest... rest) const {
216 return InterpolationImpl<T>::interpolate(x, m_knots, m_interpolators, m_extrapolate, rest...);
220 InterpN(const InterpN& other) : m_knots(other.m_knots), m_extrapolate(other.m_extrapolate) {
221 m_interpolators.resize(m_knots.size());
222 for (size_t i = 0; i < m_interpolators.size(); ++i) {
223 m_interpolators[i].reset(new InterpN<Rest...>(*other.m_interpolators[i]));
228 std::vector<T> m_knots;
229 std::vector<std::unique_ptr<InterpN<Rest...>>> m_interpolators;
233} // namespace MathUtils