Alexandria 2.31.0
SDC-CH common library for the Euclid project
Loading...
Searching...
No Matches
KdTree.icpp
Go to the documentation of this file.
1/** Copyright © 2021 Université de Genève, LMU Munich - Faculty of Physics, IAP-CNRS/Sorbonne Université
2 *
3 * This library is free software; you can redistribute it and/or modify it under
4 * the terms of the GNU Lesser General Public License as published by the Free
5 * Software Foundation; either version 3.0 of the License, or (at your option)
6 * any later version.
7 *
8 * This library is distributed in the hope that it will be useful, but WITHOUT
9 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
10 * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
11 * details.
12 *
13 * You should have received a copy of the GNU Lesser General Public License
14 * along with this library; if not, write to the Free Software Foundation, Inc.,
15 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
16 */
17
#include <algorithm>
#include <cmath>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>
19
20namespace KdTree {
21
22template <typename T, typename DistanceMethod>
23class KdTree<T, DistanceMethod>::Node {
24public:
25 virtual void findPointsWithinRadius(const T& coord, double radius, std::vector<T>& selection) const = 0;
26 virtual std::size_t countPointsWithinRadius(const T& coord, double radius) const = 0;
27 virtual ~Node() = default;
28};
29
30template <typename T, typename DistanceMethod>
31class KdTree<T, DistanceMethod>::Leaf : public KdTree::Node {
32public:
33 explicit Leaf(const std::vector<T>&& data) : m_data(data) {}
34 virtual ~Leaf() = default;
35
36 void findPointsWithinRadius(const T& coord, double radius, std::vector<T>& selection) const override {
37 selection.reserve(selection.size() + m_data.size());
38 for (auto& entry : m_data) {
39 if (DistanceMethod::isCloserThan(entry, coord, radius)) {
40 selection.emplace_back(entry);
41 }
42 }
43 }
44
45 std::size_t countPointsWithinRadius(const T& coord, double radius) const override {
46 std::size_t count = 0;
47 for (auto& entry : m_data) {
48 if (DistanceMethod::isCloserThan(entry, coord, radius)) {
49 ++count;
50 }
51 }
52 return count;
53 };
54
55private:
56 const std::vector<T> m_data;
57};
58
59template <typename T, typename DistanceMethod>
60class KdTree<T, DistanceMethod>::Split : public KdTree::Node {
61public:
62 virtual ~Split() = default;
63 explicit Split(std::size_t dimensionality, std::size_t leaf_size, std::vector<T> data, size_t axis) : m_axis(axis) {
64 std::sort(data.begin(), data.end(),
65 [axis](const T& a, const T& b) -> bool { return Traits::getCoord(a, axis) < Traits::getCoord(b, axis); });
66
67 double a = Traits::getCoord(data.at(data.size() / 2 - 1), axis);
68 double b = Traits::getCoord(data.at(data.size() / 2), axis);
69
70 if (a == b) {
71 // avoid a possible rounding issue
72 m_split_value = a;
73 } else {
74 m_split_value = (a + b) / 2.0;
75 }
76
77 std::vector<T> left(data.begin(), data.begin() + data.size() / 2);
78 std::vector<T> right(data.begin() + data.size() / 2, data.end());
79
80 if (left.size() > leaf_size) {
81 m_left_child = std::make_shared<Split>(dimensionality, leaf_size, std::move(left), (axis + 1) % dimensionality);
82 } else {
83 m_left_child = std::make_shared<Leaf>(std::move(left));
84 }
85 if (right.size() > leaf_size) {
86 m_right_child = std::make_shared<Split>(dimensionality, leaf_size, std::move(right), (axis + 1) % dimensionality);
87 } else {
88 m_right_child = std::make_shared<Leaf>(std::move(right));
89 }
90 }
91
92 void findPointsWithinRadius(const T& coord, double radius, std::vector<T>& selection) const override {
93 if (Traits::getCoord(coord, m_axis) + radius < m_split_value) {
94 m_left_child->findPointsWithinRadius(coord, radius, selection);
95 } else if (Traits::getCoord(coord, m_axis) - radius > m_split_value) {
96 m_right_child->findPointsWithinRadius(coord, radius, selection);
97 } else {
98 m_left_child->findPointsWithinRadius(coord, radius, selection);
99 m_right_child->findPointsWithinRadius(coord, radius, selection);
100 }
101 }
102
103 std::size_t countPointsWithinRadius(const T& coord, double radius) const override {
104 if (Traits::getCoord(coord, m_axis) + radius < m_split_value) {
105 return m_left_child->countPointsWithinRadius(coord, radius);
106 } else if (Traits::getCoord(coord, m_axis) - radius > m_split_value) {
107 return m_right_child->countPointsWithinRadius(coord, radius);
108 } else {
109 return m_left_child->countPointsWithinRadius(coord, radius) +
110 m_right_child->countPointsWithinRadius(coord, radius);
111 }
112 }
113
114private:
115 size_t m_axis;
116 double m_split_value;
117
118 std::shared_ptr<Node> m_left_child;
119 std::shared_ptr<Node> m_right_child;
120};
121
122template <typename T, typename DistanceMethod>
123KdTree<T, DistanceMethod>::KdTree(const std::vector<T>& data, std::size_t leaf_size) {
124 if (!data.empty()) {
125 m_dimensionality = Traits::getDimensions(data.front());
126 } else {
127 m_dimensionality = 0;
128 }
129
130 if (data.size() > leaf_size) {
131 m_root = std::make_shared<Split>(m_dimensionality, leaf_size, data, 0);
132 } else {
133 std::vector<T> data_copy(data);
134 m_root = std::make_shared<Leaf>(std::move(data_copy));
135 }
136}
137
138template <typename T, typename DistanceMethod>
139std::vector<T> KdTree<T, DistanceMethod>::findPointsWithinRadius(const T& coord, double radius) const {
140 std::vector<T> output;
141 m_root->findPointsWithinRadius(coord, radius, output);
142 return output;
143}
144
145template <typename T, typename DistanceMethod>
146std::size_t KdTree<T, DistanceMethod>::countPointsWithinRadius(const T& coord, double radius) const {
147 return m_root->countPointsWithinRadius(coord, radius);
148}
149
150template <typename T>
151bool EuclideanDistance<T>::isCloserThan(const T& a, const T& b, double distance) {
152 using Traits = KdTreeTraits<T>;
153 double square_dist = 0.0;
154 const std::size_t dim = Traits::getDimensions(a);
155 for (std::size_t i = 0; i < dim; i++) {
156 double delta = Traits::getCoord(a, i) - Traits::getCoord(b, i);
157 square_dist += delta * delta;
158 }
159 return square_dist < distance * distance;
160}
161
162template <typename T>
163bool ChebyshevDistance<T>::isCloserThan(const T& a, const T& b, double distance) {
164 using Traits = KdTreeTraits<T>;
165 double max_d = 0.;
166 const std::size_t dim = Traits::getDimensions(a);
167 for (std::size_t i = 0; i < dim; ++i) {
168 double delta = std::abs(Traits::getCoord(a, i) - Traits::getCoord(b, i));
169 if (delta > max_d) {
170 max_d = delta;
171 }
172 }
173 return max_d <= distance;
174}
175
176} // namespace KdTree