OpenTTD Source 20241224-master-gf74b0cf984
kdtree.hpp
Go to the documentation of this file.
1/*
2 * This file is part of OpenTTD.
3 * OpenTTD is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, version 2.
4 * OpenTTD is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
5 * See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with OpenTTD. If not, see <http://www.gnu.org/licenses/>.
6 */
7
10#ifndef KDTREE_HPP
11#define KDTREE_HPP
12
13#include "../stdafx.h"
14
34template <typename T, typename TxyFunc, typename CoordT, typename DistT>
35class Kdtree {
/**
 * Type of a node in the tree.
 * NOTE(review): this listing shows neither the `T element` member nor the
 * constructor from `const T &` that AddNode relies on (nodes.emplace_back(element),
 * node{ element }); confirm against the real header.
 */
37 struct node {
39 size_t left; ///< Index of node to the left, INVALID_NODE if none.
40 size_t right; ///< Index of node to the right, INVALID_NODE if none.
41
43 };
44
45 static const size_t INVALID_NODE = SIZE_MAX; ///< Index value indicating no-such-node.
46 static const size_t MIN_REBALANCE_THRESHOLD = 8; ///< Arbitrary value for "not worth rebalancing": Rebuild() and IsUnbalanced() ignore trees smaller than this.
47
48 std::vector<node> nodes; ///< Pool of all nodes in the tree; freed slots stay in the pool and are listed in free_list.
49 std::vector<size_t> free_list; ///< List of dead indices in the nodes vector, recycled by AddNode().
50 size_t root; ///< Index of root node.
51 size_t unbalanced; ///< Number approximating how unbalanced the tree might be; bumped by incremental Insert/Remove, reset by Build().
52
54 size_t AddNode(const T &element)
55 {
56 if (this->free_list.empty()) {
57 this->nodes.emplace_back(element);
58 return this->nodes.size() - 1;
59 } else {
60 size_t newidx = this->free_list.back();
61 this->free_list.pop_back();
62 this->nodes[newidx] = node{ element };
63 return newidx;
64 }
65 }
66
68 template <typename It>
69 CoordT SelectSplitCoord(It begin, It end, int level)
70 {
71 It mid = begin + (end - begin) / 2;
72 std::nth_element(begin, mid, end, [&](T a, T b) { return TxyFunc()(a, level % 2) < TxyFunc()(b, level % 2); });
73 return TxyFunc()(*mid, level % 2);
74 }
75
77 template <typename It>
78 size_t BuildSubtree(It begin, It end, int level)
79 {
80 ptrdiff_t count = end - begin;
81
82 if (count == 0) {
83 return INVALID_NODE;
84 } else if (count == 1) {
85 return this->AddNode(*begin);
86 } else if (count > 1) {
87 CoordT split_coord = this->SelectSplitCoord(begin, end, level);
88 It split = std::partition(begin, end, [&](T v) { return TxyFunc()(v, level % 2) < split_coord; });
89 size_t newidx = this->AddNode(*split);
90 this->nodes[newidx].left = this->BuildSubtree(begin, split, level + 1);
91 this->nodes[newidx].right = this->BuildSubtree(split + 1, end, level + 1);
92 return newidx;
93 } else {
94 NOT_REACHED();
95 }
96 }
97
99 bool Rebuild(const T *include_element, const T *exclude_element)
100 {
101 size_t initial_count = this->Count();
102 if (initial_count < MIN_REBALANCE_THRESHOLD) return false;
103
104 T root_element = this->nodes[this->root].element;
105 std::vector<T> elements = this->FreeSubtree(this->root);
106 elements.push_back(root_element);
107
108 if (include_element != nullptr) {
109 elements.push_back(*include_element);
110 initial_count++;
111 }
112 if (exclude_element != nullptr) {
113 typename std::vector<T>::iterator removed = std::remove(elements.begin(), elements.end(), *exclude_element);
114 elements.erase(removed, elements.end());
115 initial_count--;
116 }
117
118 this->Build(elements.begin(), elements.end());
119 assert(initial_count == this->Count());
120 return true;
121 }
122
124 void InsertRecursive(const T &element, size_t node_idx, int level)
125 {
126 /* Dimension index of current level */
127 int dim = level % 2;
128 /* Node reference */
129 node &n = this->nodes[node_idx];
130
131 /* Coordinate of element splitting at this node */
132 CoordT nc = TxyFunc()(n.element, dim);
133 /* Coordinate of the new element */
134 CoordT ec = TxyFunc()(element, dim);
135 /* Which side to insert on */
136 size_t &next = (ec < nc) ? n.left : n.right;
137
138 if (next == INVALID_NODE) {
139 /* New leaf */
140 size_t newidx = this->AddNode(element);
141 /* Vector may have been reallocated at this point, n and next are invalid */
142 node &nn = this->nodes[node_idx];
143 if (ec < nc) nn.left = newidx; else nn.right = newidx;
144 } else {
145 this->InsertRecursive(element, next, level + 1);
146 }
147 }
148
153 std::vector<T> FreeSubtree(size_t node_idx)
154 {
155 std::vector<T> subtree_elements;
156 node &n = this->nodes[node_idx];
157
158 /* We'll be appending items to the free_list, get index of our first item */
159 size_t first_free = this->free_list.size();
160 /* Prepare the descent with our children */
161 if (n.left != INVALID_NODE) this->free_list.push_back(n.left);
162 if (n.right != INVALID_NODE) this->free_list.push_back(n.right);
163 n.left = n.right = INVALID_NODE;
164
165 /* Recursively free the nodes being collected */
166 for (size_t i = first_free; i < this->free_list.size(); i++) {
167 node &fn = this->nodes[this->free_list[i]];
168 subtree_elements.push_back(fn.element);
169 if (fn.left != INVALID_NODE) this->free_list.push_back(fn.left);
170 if (fn.right != INVALID_NODE) this->free_list.push_back(fn.right);
171 fn.left = fn.right = INVALID_NODE;
172 }
173
174 return subtree_elements;
175 }
176
184 size_t RemoveRecursive(const T &element, size_t node_idx, int level)
185 {
186 /* Node reference */
187 node &n = this->nodes[node_idx];
188
189 if (n.element == element) {
190 /* Remove this one */
191 this->free_list.push_back(node_idx);
192 if (n.left == INVALID_NODE && n.right == INVALID_NODE) {
193 /* Simple case, leaf, new child node for parent is "none" */
194 return INVALID_NODE;
195 } else {
196 /* Complex case, rebuild the sub-tree */
197 std::vector<T> subtree_elements = this->FreeSubtree(node_idx);
198 return this->BuildSubtree(subtree_elements.begin(), subtree_elements.end(), level);;
199 }
200 } else {
201 /* Search in a sub-tree */
202 /* Dimension index of current level */
203 int dim = level % 2;
204 /* Coordinate of element splitting at this node */
205 CoordT nc = TxyFunc()(n.element, dim);
206 /* Coordinate of the element being removed */
207 CoordT ec = TxyFunc()(element, dim);
208 /* Which side to remove from */
209 size_t next = (ec < nc) ? n.left : n.right;
210 assert(next != INVALID_NODE); // node must exist somewhere and must be found before a leaf is reached
211 /* Descend */
212 size_t new_branch = this->RemoveRecursive(element, next, level + 1);
213 if (new_branch != next) {
214 /* Vector may have been reallocated at this point, n and next are invalid */
215 node &nn = this->nodes[node_idx];
216 if (ec < nc) nn.left = new_branch; else nn.right = new_branch;
217 }
218 return node_idx;
219 }
220 }
221
222
223 DistT ManhattanDistance(const T &element, CoordT x, CoordT y) const
224 {
225 return abs((DistT)TxyFunc()(element, 0) - (DistT)x) + abs((DistT)TxyFunc()(element, 1) - (DistT)y);
226 }
227
229 using node_distance = std::pair<T, DistT>;
232 {
233 if (a.second < b.second) return a;
234 if (b.second < a.second) return b;
235 if (a.first < b.first) return a;
236 if (b.first < a.first) return b;
237 NOT_REACHED(); // a.first == b.first: same element must not be inserted twice
238 }
240 node_distance FindNearestRecursive(CoordT xy[2], size_t node_idx, int level, DistT limit = std::numeric_limits<DistT>::max()) const
241 {
242 /* Dimension index of current level */
243 int dim = level % 2;
244 /* Node reference */
245 const node &n = this->nodes[node_idx];
246
247 /* Coordinate of element splitting at this node */
248 CoordT c = TxyFunc()(n.element, dim);
249 /* This node's distance to target */
250 DistT thisdist = this->ManhattanDistance(n.element, xy[0], xy[1]);
251 /* Assume this node is the best choice for now */
252 node_distance best = std::make_pair(n.element, thisdist);
253
254 /* Next node to visit */
255 size_t next = (xy[dim] < c) ? n.left : n.right;
256 if (next != INVALID_NODE) {
257 /* Check if there is a better node down the tree */
258 best = SelectNearestNodeDistance(best, this->FindNearestRecursive(xy, next, level + 1));
259 }
260
261 limit = std::min(best.second, limit);
262
263 /* Check if the distance from current best is worse than distance from target to splitting line,
264 * if it is we also need to check the other side of the split. */
265 size_t opposite = (xy[dim] >= c) ? n.left : n.right; // reverse of above
266 if (opposite != INVALID_NODE && limit >= abs((int)xy[dim] - (int)c)) {
267 node_distance other_candidate = this->FindNearestRecursive(xy, opposite, level + 1, limit);
268 best = SelectNearestNodeDistance(best, other_candidate);
269 }
270
271 return best;
272 }
273
274 template <typename Outputter>
275 void FindContainedRecursive(CoordT p1[2], CoordT p2[2], size_t node_idx, int level, const Outputter &outputter) const
276 {
277 /* Dimension index of current level */
278 int dim = level % 2;
279 /* Node reference */
280 const node &n = this->nodes[node_idx];
281
282 /* Coordinate of element splitting at this node */
283 CoordT ec = TxyFunc()(n.element, dim);
284 /* Opposite coordinate of element */
285 CoordT oc = TxyFunc()(n.element, 1 - dim);
286
287 /* Test if this element is within rectangle */
288 if (ec >= p1[dim] && ec < p2[dim] && oc >= p1[1 - dim] && oc < p2[1 - dim]) outputter(n.element);
289
290 /* Recurse left if part of rectangle is left of split */
291 if (p1[dim] < ec && n.left != INVALID_NODE) this->FindContainedRecursive(p1, p2, n.left, level + 1, outputter);
292
293 /* Recurse right if part of rectangle is right of split */
294 if (p2[dim] > ec && n.right != INVALID_NODE) this->FindContainedRecursive(p1, p2, n.right, level + 1, outputter);
295 }
296
298 size_t CountValue(const T &element, size_t node_idx) const
299 {
300 if (node_idx == INVALID_NODE) return 0;
301 const node &n = this->nodes[node_idx];
302 return this->CountValue(element, n.left) + this->CountValue(element, n.right) + ((n.element == element) ? 1 : 0);
303 }
304
305 void IncrementUnbalanced(size_t amount = 1)
306 {
307 this->unbalanced += amount;
308 }
309
311 bool IsUnbalanced() const
312 {
313 size_t count = this->Count();
314 if (count < MIN_REBALANCE_THRESHOLD) return false;
315 return this->unbalanced > count / 4;
316 }
317
319 void CheckInvariant(size_t node_idx, int level, CoordT min_x, CoordT max_x, CoordT min_y, CoordT max_y) const
320 {
321 if (node_idx == INVALID_NODE) return;
322
323 const node &n = this->nodes[node_idx];
324 CoordT cx = TxyFunc()(n.element, 0);
325 CoordT cy = TxyFunc()(n.element, 1);
326
327 assert(cx >= min_x);
328 assert(cx < max_x);
329 assert(cy >= min_y);
330 assert(cy < max_y);
331
332 if (level % 2 == 0) {
333 // split in dimension 0 = x
334 this->CheckInvariant(n.left, level + 1, min_x, cx, min_y, max_y);
335 this->CheckInvariant(n.right, level + 1, cx, max_x, min_y, max_y);
336 } else {
337 // split in dimension 1 = y
338 this->CheckInvariant(n.left, level + 1, min_x, max_x, min_y, cy);
339 this->CheckInvariant(n.right, level + 1, min_x, max_x, cy, max_y);
340 }
341 }
342
344 void CheckInvariant() const
345 {
346#ifdef KDTREE_DEBUG
347 this->CheckInvariant(this->root, 0, std::numeric_limits<CoordT>::min(), std::numeric_limits<CoordT>::max(), std::numeric_limits<CoordT>::min(), std::numeric_limits<CoordT>::max());
348#endif
349 }
350
351public:
354
361 template <typename It>
362 void Build(It begin, It end)
363 {
364 this->nodes.clear();
365 this->free_list.clear();
366 this->unbalanced = 0;
367 if (begin == end) return;
368 this->nodes.reserve(end - begin);
369
370 this->root = this->BuildSubtree(begin, end, 0);
371 this->CheckInvariant();
372 }
373
377 void Clear()
378 {
379 this->nodes.clear();
380 this->free_list.clear();
381 this->unbalanced = 0;
382 return;
383 }
384
388 void Rebuild()
389 {
390 this->Rebuild(nullptr, nullptr);
391 }
392
398 void Insert(const T &element)
399 {
400 if (this->Count() == 0) {
401 this->root = this->AddNode(element);
402 } else {
403 if (!this->IsUnbalanced() || !this->Rebuild(&element, nullptr)) {
404 this->InsertRecursive(element, this->root, 0);
405 this->IncrementUnbalanced();
406 }
407 this->CheckInvariant();
408 }
409 }
410
417 void Remove(const T &element)
418 {
419 size_t count = this->Count();
420 if (count == 0) return;
421 if (!this->IsUnbalanced() || !this->Rebuild(nullptr, &element)) {
422 /* If the removed element is the root node, this modifies this->root */
423 this->root = this->RemoveRecursive(element, this->root, 0);
424 this->IncrementUnbalanced();
425 }
426 this->CheckInvariant();
427 }
428
430 size_t Count() const
431 {
432 assert(this->free_list.size() <= this->nodes.size());
433 return this->nodes.size() - this->free_list.size();
434 }
435
441 T FindNearest(CoordT x, CoordT y) const
442 {
443 assert(this->Count() > 0);
444
445 CoordT xy[2] = { x, y };
446 return this->FindNearestRecursive(xy, this->root, 0).first;
447 }
448
458 template <typename Outputter>
459 void FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2, const Outputter &outputter) const
460 {
461 assert(x1 < x2);
462 assert(y1 < y2);
463
464 if (this->Count() == 0) return;
465
466 CoordT p1[2] = { x1, y1 };
467 CoordT p2[2] = { x2, y2 };
468 this->FindContainedRecursive(p1, p2, this->root, 0, outputter);
469 }
470
475 std::vector<T> FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2) const
476 {
477 std::vector<T> result;
478 this->FindContained(x1, y1, x2, y2, [&result](T e) {result.push_back(e); });
479 return result;
480 }
481};
482
483#endif
K-dimensional tree, specialised for 2-dimensional space.
Definition kdtree.hpp:35
static const size_t MIN_REBALANCE_THRESHOLD
Arbitrary value for "not worth rebalancing".
Definition kdtree.hpp:46
void Build(It begin, It end)
Clear and rebuild the tree from a new sequence of elements.
Definition kdtree.hpp:362
size_t Count() const
Get number of elements stored in tree.
Definition kdtree.hpp:430
bool IsUnbalanced() const
Check if the entire tree is in need of rebuilding.
Definition kdtree.hpp:311
void FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2, const Outputter &outputter) const
Find all items contained within the given rectangle.
Definition kdtree.hpp:459
bool Rebuild(const T *include_element, const T *exclude_element)
Rebuild the tree with all existing elements, optionally adding or removing one more.
Definition kdtree.hpp:99
std::vector< size_t > free_list
List of dead indices in the nodes vector.
Definition kdtree.hpp:49
node_distance FindNearestRecursive(CoordT xy[2], size_t node_idx, int level, DistT limit=std::numeric_limits< DistT >::max()) const
Search a sub-tree for the element nearest to a given point.
Definition kdtree.hpp:240
void InsertRecursive(const T &element, size_t node_idx, int level)
Insert one element in the tree somewhere below node_idx.
Definition kdtree.hpp:124
void Rebuild()
Reconstruct the tree with the same elements, letting it be fully balanced.
Definition kdtree.hpp:388
std::pair< T, DistT > node_distance
A data element and its distance to a searched-for point.
Definition kdtree.hpp:229
void Insert(const T &element)
Insert a single element in the tree.
Definition kdtree.hpp:398
Kdtree()
Construct a new Kdtree; element coordinates are obtained by calling the TxyFunc template parameter.
Definition kdtree.hpp:353
std::vector< node > nodes
Pool of all nodes in the tree.
Definition kdtree.hpp:48
std::vector< T > FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2) const
Find all items contained within the given rectangle.
Definition kdtree.hpp:475
static node_distance SelectNearestNodeDistance(const node_distance &a, const node_distance &b)
Ordering function for node_distance objects, elements with equal distance are ordered by less-than comparison.
Definition kdtree.hpp:231
void CheckInvariant(size_t node_idx, int level, CoordT min_x, CoordT max_x, CoordT min_y, CoordT max_y) const
Verify that the invariant is true for a sub-tree, assert if not.
Definition kdtree.hpp:319
void Remove(const T &element)
Remove a single element from the tree, if it exists.
Definition kdtree.hpp:417
static const size_t INVALID_NODE
Index value indicating no-such-node.
Definition kdtree.hpp:45
size_t root
Index of root node.
Definition kdtree.hpp:50
size_t RemoveRecursive(const T &element, size_t node_idx, int level)
Find and remove one element from the tree.
Definition kdtree.hpp:184
size_t BuildSubtree(It begin, It end, int level)
Construct a subtree from elements between begin and end iterators, return index of root.
Definition kdtree.hpp:78
size_t AddNode(const T &element)
Create one new node in the tree, return its index in the pool.
Definition kdtree.hpp:54
void CheckInvariant() const
Verify the invariant for the entire tree, does nothing unless KDTREE_DEBUG is defined.
Definition kdtree.hpp:344
T FindNearest(CoordT x, CoordT y) const
Find the element closest to given coordinate, in Manhattan distance.
Definition kdtree.hpp:441
CoordT SelectSplitCoord(It begin, It end, int level)
Find a coordinate value to split a range of elements at.
Definition kdtree.hpp:69
std::vector< T > FreeSubtree(size_t node_idx)
Free all children of the given node and return the elements they held; the node itself is kept.
Definition kdtree.hpp:153
void Clear()
Clear the tree.
Definition kdtree.hpp:377
size_t CountValue(const T &element, size_t node_idx) const
Debugging function, counts number of occurrences of an element regardless of its correct position in the tree.
Definition kdtree.hpp:298
size_t unbalanced
Number approximating how unbalanced the tree might be.
Definition kdtree.hpp:51
constexpr T abs(const T a)
Returns the absolute value of (scalar) variable.
Definition math_func.hpp:23
Type of a node in the tree.
Definition kdtree.hpp:37
T element
Element stored at node.
Definition kdtree.hpp:38
size_t left
Index of node to the left, INVALID_NODE if none.
Definition kdtree.hpp:39
size_t right
Index of node to the right, INVALID_NODE if none.
Definition kdtree.hpp:40