/*
 * Copyright (C) 2012-2021 Euclid Science Ground Segment
 *
 * This library is free software; you can redistribute it and/or modify it under
 * the terms of the GNU Lesser General Public License as published by the Free
 * Software Foundation; either version 3.0 of the License, or (at your option)
 * any later version.
 *
 * This library is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
 * details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; if not, write to the Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */
#ifdef INVERSE_CUMULATIVE_IMPL

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "ElementsKernel/Logging.h"
#include "MathUtils/helpers/Solvers.h"
29static Elements::Logging loggerinverseCumul = Elements::Logging::getLogger("LoggerInverseCumulative");
32 * Specialization for floating point (linear interpolation of the PDF)
34template <typename TKnot>
35class InverseCumulative<TKnot, typename std::enable_if<std::is_floating_point<TKnot>::value>::type> {
37 InverseCumulative(std::vector<TKnot> knots, std::vector<double> pdf)
38 : m_knots(std::move(knots)), m_pdf(std::move(pdf)), m_cdf(m_pdf.size()) {
39 if (m_knots.size() != m_knots.size()) {
40 throw Elements::Exception() << "PDF and knots dimensionality do not match: " << m_knots.size() << " != " << knots.size();
42 if (!std::is_sorted(m_knots.begin(), m_knots.end())) {
43 throw Elements::Exception() << "Knots must be sorted";
47 for (std::size_t i = 1; i < m_cdf.size(); ++i) {
48 m_cdf[i] = (m_knots[i] - m_knots[i - 1]) * (m_pdf[i] + m_pdf[i - 1]) / 2.;
49 m_cdf[i] += m_cdf[i - 1];
52 // Remove trailing knots with no probability
53 while (m_cdf.size() > 1 && m_cdf.back() == *(m_cdf.end() - 2)) {
58 m_min = m_cdf.front();
59 m_range = m_cdf.back() - m_min;
62 double operator()(double p) const {
63 if (p < 0. || p > 1.) {
64 throw Elements::Exception() << "Cumulative::findInterpolatedValue : p parameter must be in the range [0,1]";
67 const double unnormed_p = p * m_range + m_min;
70 if (unnormed_p <= m_cdf.front()) {
71 return m_knots.front();
73 if (unnormed_p >= m_cdf.back()) {
74 return m_knots.back();
77 std::size_t i = std::upper_bound(m_cdf.begin(), m_cdf.end(), unnormed_p) - m_cdf.begin() - 1;
79 const double x0 = m_knots[i], x1 = m_knots[i + 1];
80 const double cdf0 = m_cdf[i], cdf1 = m_cdf[i + 1];
81 const double p0 = m_pdf[i], p1 = m_pdf[i + 1];
83 // If p0 == p1 we are on a uniform area
85 double interval_p = (unnormed_p - cdf0) / (cdf1 - cdf0);
86 return (x1 - x0) * interval_p + x0;
89 // If both cdf are the same, then the interval has 0 probability.
90 // If both x are the same, this is probably a discontinuity.
91 // In those cases, return the lower bound, which might be at least defined (or has a probability of exactly p).
92 if (x1 == x0 || cdf0 == cdf1 || cdf0 >= unnormed_p) {
96 // Since we assume a linear interpolation, we know that
99 // So we can solve for a and b
100 const double a = (p1 - p0) / (x1 - x0);
101 const double b = p0 - a * x0;
103 assert(std::abs(a * x0 + b - p0) < 1e-8);
104 assert(std::abs(a * x1 + b - p1) < 1e-8);
106 // The CDF is the integral, so we also know that
107 // cdf0 = a/2 * x0^2 + b * x0 + c
108 // cdf1 = a/2 * x1^2 + b * x1 + c
109 // We already know a and b, so it is easy to solve for c
110 const double c = cdf0 - (a / 2.) * (x0 * x0) - b * x0;
112 // Double check that the equation passes through the CDF at the limits
113 // assert(std::abs((a / 2.) * (x0 * x0) + b * x0 + c - cdf0) < 1e-4);
114 // assert(std::abs((a / 2.) * (x1 * x1) + b * x1 + c - cdf1) < 1e-4);
116 // We have the equation, so we now need to solve x for p
118 std::tie(s0, s1) = solveSquare(a / 2, b, c, unnormed_p);
120 // Pick the possible result that lies within [x0, x1]
121 if (s0 >= x0 - 1e-8 && s0 <= x1 + 1e-8) {
124 assert(s1 >= x0 - 1e-8 && s1 <= x1 + 1e-8);
126 // Worse case scenario: we got out of the double bound
127 if (! std::isfinite(s1)) {
128 loggerinverseCumul.warn()<<"Computation of the inverse Cumulative is not finite: use the median value.";
137 std::vector<TKnot> m_knots;
138 std::vector<double> m_pdf, m_cdf;
139 double m_min, m_range;
143 * Specialization for discrete types
145template <typename TKnot>
146class InverseCumulative<TKnot, typename std::enable_if<!std::is_floating_point<TKnot>::value>::type> {
148 InverseCumulative(std::vector<TKnot> knots, std::vector<double> pdf) : m_knots(std::move(knots)), m_cdf(std::move(pdf)) {
149 // Compute cumulative and normalize
150 for (std::size_t i = 1; i < m_cdf.size(); ++i) {
151 m_cdf[i] += m_cdf[i - 1];
153 for (auto& v : m_cdf) {
158 TKnot operator()(double p) const {
159 if (p < 0. || p > 1.) {
160 throw Elements::Exception() << "Cumulative::findInterpolatedValue : p parameter must be in the range [0,1]";
162 std::size_t i = std::lower_bound(m_cdf.begin(), m_cdf.end(), p) - m_cdf.begin();
167 std::vector<TKnot> m_knots;
168 std::vector<double> m_cdf;
171} // namespace MathUtils