/** Copyright © 2021 Université de Genève, LMU Munich - Faculty of Physics, IAP-CNRS/Sorbonne Université
 *
 * This library is free software; you can redistribute it and/or modify it under
 * the terms of the GNU Lesser General Public License as published by the Free
 * Software Foundation; either version 3.0 of the License, or (at your option)
 * any later version.
 *
 * This library is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
 * details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; if not, write to the Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */
22template <typename T, typename DistanceMethod>
23class KdTree<T, DistanceMethod>::Node {
25 virtual void findPointsWithinRadius(const T& coord, double radius, std::vector<T>& selection) const = 0;
26 virtual std::size_t countPointsWithinRadius(const T& coord, double radius) const = 0;
27 virtual ~Node() = default;
30template <typename T, typename DistanceMethod>
31class KdTree<T, DistanceMethod>::Leaf : public KdTree::Node {
33 explicit Leaf(const std::vector<T>&& data) : m_data(data) {}
34 virtual ~Leaf() = default;
36 void findPointsWithinRadius(const T& coord, double radius, std::vector<T>& selection) const override {
37 selection.reserve(selection.size() + m_data.size());
38 for (auto& entry : m_data) {
39 if (DistanceMethod::isCloserThan(entry, coord, radius)) {
40 selection.emplace_back(entry);
45 std::size_t countPointsWithinRadius(const T& coord, double radius) const override {
46 std::size_t count = 0;
47 for (auto& entry : m_data) {
48 if (DistanceMethod::isCloserThan(entry, coord, radius)) {
56 const std::vector<T> m_data;
59template <typename T, typename DistanceMethod>
60class KdTree<T, DistanceMethod>::Split : public KdTree::Node {
62 virtual ~Split() = default;
63 explicit Split(std::size_t dimensionality, std::size_t leaf_size, std::vector<T> data, size_t axis) : m_axis(axis) {
64 std::sort(data.begin(), data.end(),
65 [axis](const T& a, const T& b) -> bool { return Traits::getCoord(a, axis) < Traits::getCoord(b, axis); });
67 double a = Traits::getCoord(data.at(data.size() / 2 - 1), axis);
68 double b = Traits::getCoord(data.at(data.size() / 2), axis);
71 // avoid a possible rounding issue
74 m_split_value = (a + b) / 2.0;
77 std::vector<T> left(data.begin(), data.begin() + data.size() / 2);
78 std::vector<T> right(data.begin() + data.size() / 2, data.end());
80 if (left.size() > leaf_size) {
81 m_left_child = std::make_shared<Split>(dimensionality, leaf_size, std::move(left), (axis + 1) % dimensionality);
83 m_left_child = std::make_shared<Leaf>(std::move(left));
85 if (right.size() > leaf_size) {
86 m_right_child = std::make_shared<Split>(dimensionality, leaf_size, std::move(right), (axis + 1) % dimensionality);
88 m_right_child = std::make_shared<Leaf>(std::move(right));
92 void findPointsWithinRadius(const T& coord, double radius, std::vector<T>& selection) const override {
93 if (Traits::getCoord(coord, m_axis) + radius < m_split_value) {
94 m_left_child->findPointsWithinRadius(coord, radius, selection);
95 } else if (Traits::getCoord(coord, m_axis) - radius > m_split_value) {
96 m_right_child->findPointsWithinRadius(coord, radius, selection);
98 m_left_child->findPointsWithinRadius(coord, radius, selection);
99 m_right_child->findPointsWithinRadius(coord, radius, selection);
103 std::size_t countPointsWithinRadius(const T& coord, double radius) const override {
104 if (Traits::getCoord(coord, m_axis) + radius < m_split_value) {
105 return m_left_child->countPointsWithinRadius(coord, radius);
106 } else if (Traits::getCoord(coord, m_axis) - radius > m_split_value) {
107 return m_right_child->countPointsWithinRadius(coord, radius);
109 return m_left_child->countPointsWithinRadius(coord, radius) +
110 m_right_child->countPointsWithinRadius(coord, radius);
116 double m_split_value;
118 std::shared_ptr<Node> m_left_child;
119 std::shared_ptr<Node> m_right_child;
122template <typename T, typename DistanceMethod>
123KdTree<T, DistanceMethod>::KdTree(const std::vector<T>& data, std::size_t leaf_size) {
125 m_dimensionality = Traits::getDimensions(data.front());
127 m_dimensionality = 0;
130 if (data.size() > leaf_size) {
131 m_root = std::make_shared<Split>(m_dimensionality, leaf_size, data, 0);
133 std::vector<T> data_copy(data);
134 m_root = std::make_shared<Leaf>(std::move(data_copy));
138template <typename T, typename DistanceMethod>
139std::vector<T> KdTree<T, DistanceMethod>::findPointsWithinRadius(const T& coord, double radius) const {
140 std::vector<T> output;
141 m_root->findPointsWithinRadius(coord, radius, output);
145template <typename T, typename DistanceMethod>
146std::size_t KdTree<T, DistanceMethod>::countPointsWithinRadius(const T& coord, double radius) const {
147 return m_root->countPointsWithinRadius(coord, radius);
151bool EuclideanDistance<T>::isCloserThan(const T& a, const T& b, double distance) {
152 using Traits = KdTreeTraits<T>;
153 double square_dist = 0.0;
154 const std::size_t dim = Traits::getDimensions(a);
155 for (std::size_t i = 0; i < dim; i++) {
156 double delta = Traits::getCoord(a, i) - Traits::getCoord(b, i);
157 square_dist += delta * delta;
159 return square_dist < distance * distance;
163bool ChebyshevDistance<T>::isCloserThan(const T& a, const T& b, double distance) {
164 using Traits = KdTreeTraits<T>;
166 const std::size_t dim = Traits::getDimensions(a);
167 for (std::size_t i = 0; i < dim; ++i) {
168 double delta = std::abs(Traits::getCoord(a, i) - Traits::getCoord(b, i));
173 return max_d <= distance;