OpenTTD Source 20260218-master-g2123fca5ea
kdtree.hpp
Go to the documentation of this file.
1/*
2 * This file is part of OpenTTD.
3 * OpenTTD is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, version 2.
4 * OpenTTD is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
5 * See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with OpenTTD. If not, see <https://www.gnu.org/licenses/old-licenses/gpl-2.0>.
6 */
7
9
10#ifndef KDTREE_HPP
11#define KDTREE_HPP
12
32template <typename T, typename TxyFunc, typename CoordT, typename DistT>
33class Kdtree {
35 struct node {
37 size_t left;
38 size_t right;
39
41 };
42
43 static const size_t INVALID_NODE = SIZE_MAX;
44 static const size_t MIN_REBALANCE_THRESHOLD = 8;
45
46 std::vector<node> nodes;
47 std::vector<size_t> free_list;
48 size_t root;
49 size_t unbalanced;
50
56 size_t AddNode(const T &element)
57 {
58 if (this->free_list.empty()) {
59 this->nodes.emplace_back(element);
60 return this->nodes.size() - 1;
61 } else {
62 size_t newidx = this->free_list.back();
63 this->free_list.pop_back();
64 this->nodes[newidx] = node{ element };
65 return newidx;
66 }
67 }
68
76 template <typename It>
77 CoordT SelectSplitCoord(It begin, It end, int level)
78 {
79 It mid = begin + (end - begin) / 2;
80 std::nth_element(begin, mid, end, [&](T a, T b) { return TxyFunc()(a, level % 2) < TxyFunc()(b, level % 2); });
81 return TxyFunc()(*mid, level % 2);
82 }
83
91 template <typename It>
92 size_t BuildSubtree(It begin, It end, int level)
93 {
94 ptrdiff_t count = end - begin;
95
96 if (count == 0) {
97 return INVALID_NODE;
98 } else if (count == 1) {
99 return this->AddNode(*begin);
100 } else if (count > 1) {
101 CoordT split_coord = this->SelectSplitCoord(begin, end, level);
102 It split = std::partition(begin, end, [&](T v) { return TxyFunc()(v, level % 2) < split_coord; });
103 size_t newidx = this->AddNode(*split);
104 this->nodes[newidx].left = this->BuildSubtree(begin, split, level + 1);
105 this->nodes[newidx].right = this->BuildSubtree(split + 1, end, level + 1);
106 return newidx;
107 } else {
108 NOT_REACHED();
109 }
110 }
111
118 bool Rebuild(const T *include_element, const T *exclude_element)
119 {
120 size_t initial_count = this->Count();
121 if (initial_count < MIN_REBALANCE_THRESHOLD) return false;
122
123 T root_element = this->nodes[this->root].element;
124 std::vector<T> elements = this->FreeSubtree(this->root);
125 elements.push_back(root_element);
126
127 if (include_element != nullptr) {
128 elements.push_back(*include_element);
129 initial_count++;
130 }
131 if (exclude_element != nullptr) {
132 typename std::vector<T>::iterator removed = std::remove(elements.begin(), elements.end(), *exclude_element);
133 elements.erase(removed, elements.end());
134 initial_count--;
135 }
136
137 this->Build(elements.begin(), elements.end());
138 assert(initial_count == this->Count());
139 return true;
140 }
141
148 void InsertRecursive(const T &element, size_t node_idx, int level)
149 {
150 /* Dimension index of current level */
151 int dim = level % 2;
152 /* Node reference */
153 node &n = this->nodes[node_idx];
154
155 /* Coordinate of element splitting at this node */
156 CoordT nc = TxyFunc()(n.element, dim);
157 /* Coordinate of the new element */
158 CoordT ec = TxyFunc()(element, dim);
159 /* Which side to insert on */
160 size_t &next = (ec < nc) ? n.left : n.right;
161
162 if (next == INVALID_NODE) {
163 /* New leaf */
164 size_t newidx = this->AddNode(element);
165 /* Vector may have been reallocated at this point, n and next are invalid */
166 node &nn = this->nodes[node_idx];
167 if (ec < nc) nn.left = newidx; else nn.right = newidx;
168 } else {
169 this->InsertRecursive(element, next, level + 1);
170 }
171 }
172
178 std::vector<T> FreeSubtree(size_t node_idx)
179 {
180 std::vector<T> subtree_elements;
181 node &n = this->nodes[node_idx];
182
183 /* We'll be appending items to the free_list, get index of our first item */
184 size_t first_free = this->free_list.size();
185 /* Prepare the descent with our children */
186 if (n.left != INVALID_NODE) this->free_list.push_back(n.left);
187 if (n.right != INVALID_NODE) this->free_list.push_back(n.right);
188 n.left = n.right = INVALID_NODE;
189
190 /* Recursively free the nodes being collected */
191 for (size_t i = first_free; i < this->free_list.size(); i++) {
192 node &fn = this->nodes[this->free_list[i]];
193 subtree_elements.push_back(fn.element);
194 if (fn.left != INVALID_NODE) this->free_list.push_back(fn.left);
195 if (fn.right != INVALID_NODE) this->free_list.push_back(fn.right);
196 fn.left = fn.right = INVALID_NODE;
197 }
198
199 return subtree_elements;
200 }
201
209 size_t RemoveRecursive(const T &element, size_t node_idx, int level)
210 {
211 /* Node reference */
212 node &n = this->nodes[node_idx];
213
214 if (n.element == element) {
215 /* Remove this one */
216 this->free_list.push_back(node_idx);
217 if (n.left == INVALID_NODE && n.right == INVALID_NODE) {
218 /* Simple case, leaf, new child node for parent is "none" */
219 return INVALID_NODE;
220 } else {
221 /* Complex case, rebuild the sub-tree */
222 std::vector<T> subtree_elements = this->FreeSubtree(node_idx);
223 return this->BuildSubtree(subtree_elements.begin(), subtree_elements.end(), level);;
224 }
225 } else {
226 /* Search in a sub-tree */
227 /* Dimension index of current level */
228 int dim = level % 2;
229 /* Coordinate of element splitting at this node */
230 CoordT nc = TxyFunc()(n.element, dim);
231 /* Coordinate of the element being removed */
232 CoordT ec = TxyFunc()(element, dim);
233 /* Which side to remove from */
234 size_t next = (ec < nc) ? n.left : n.right;
235 assert(next != INVALID_NODE); // node must exist somewhere and must be found before a leaf is reached
236 /* Descend */
237 size_t new_branch = this->RemoveRecursive(element, next, level + 1);
238 if (new_branch != next) {
239 /* Vector may have been reallocated at this point, n and next are invalid */
240 node &nn = this->nodes[node_idx];
241 if (ec < nc) nn.left = new_branch; else nn.right = new_branch;
242 }
243 return node_idx;
244 }
245 }
246
247
248 DistT ManhattanDistance(const T &element, CoordT x, CoordT y) const
249 {
250 return abs((DistT)TxyFunc()(element, 0) - (DistT)x) + abs((DistT)TxyFunc()(element, 1) - (DistT)y);
251 }
252
254 using node_distance = std::pair<T, DistT>;
262 {
263 if (a.second < b.second) return a;
264 if (b.second < a.second) return b;
265 if (a.first < b.first) return a;
266 if (b.first < a.first) return b;
267 NOT_REACHED(); // a.first == b.first: same element must not be inserted twice
268 }
269
277 node_distance FindNearestRecursive(CoordT xy[2], size_t node_idx, int level, DistT limit = std::numeric_limits<DistT>::max()) const
278 {
279 /* Dimension index of current level */
280 int dim = level % 2;
281 /* Node reference */
282 const node &n = this->nodes[node_idx];
283
284 /* Coordinate of element splitting at this node */
285 CoordT c = TxyFunc()(n.element, dim);
286 /* This node's distance to target */
287 DistT thisdist = this->ManhattanDistance(n.element, xy[0], xy[1]);
288 /* Assume this node is the best choice for now */
289 node_distance best = std::make_pair(n.element, thisdist);
290
291 /* Next node to visit */
292 size_t next = (xy[dim] < c) ? n.left : n.right;
293 if (next != INVALID_NODE) {
294 /* Check if there is a better node down the tree */
295 best = SelectNearestNodeDistance(best, this->FindNearestRecursive(xy, next, level + 1));
296 }
297
298 limit = std::min(best.second, limit);
299
300 /* Check if the distance from current best is worse than distance from target to splitting line,
301 * if it is we also need to check the other side of the split. */
302 size_t opposite = (xy[dim] >= c) ? n.left : n.right; // reverse of above
303 if (opposite != INVALID_NODE && limit >= abs((int)xy[dim] - (int)c)) {
304 node_distance other_candidate = this->FindNearestRecursive(xy, opposite, level + 1, limit);
305 best = SelectNearestNodeDistance(best, other_candidate);
306 }
307
308 return best;
309 }
310
311 template <typename Outputter>
312 void FindContainedRecursive(CoordT p1[2], CoordT p2[2], size_t node_idx, int level, const Outputter &outputter) const
313 {
314 /* Dimension index of current level */
315 int dim = level % 2;
316 /* Node reference */
317 const node &n = this->nodes[node_idx];
318
319 /* Coordinate of element splitting at this node */
320 CoordT ec = TxyFunc()(n.element, dim);
321 /* Opposite coordinate of element */
322 CoordT oc = TxyFunc()(n.element, 1 - dim);
323
324 /* Test if this element is within rectangle */
325 if (ec >= p1[dim] && ec < p2[dim] && oc >= p1[1 - dim] && oc < p2[1 - dim]) outputter(n.element);
326
327 /* Recurse left if part of rectangle is left of split */
328 if (p1[dim] < ec && n.left != INVALID_NODE) this->FindContainedRecursive(p1, p2, n.left, level + 1, outputter);
329
330 /* Recurse right if part of rectangle is right of split */
331 if (p2[dim] > ec && n.right != INVALID_NODE) this->FindContainedRecursive(p1, p2, n.right, level + 1, outputter);
332 }
333
340 size_t CountValue(const T &element, size_t node_idx) const
341 {
342 if (node_idx == INVALID_NODE) return 0;
343 const node &n = this->nodes[node_idx];
344 return this->CountValue(element, n.left) + this->CountValue(element, n.right) + ((n.element == element) ? 1 : 0);
345 }
346
347 void IncrementUnbalanced(size_t amount = 1)
348 {
349 this->unbalanced += amount;
350 }
351
356 bool IsUnbalanced() const
357 {
358 size_t count = this->Count();
359 if (count < MIN_REBALANCE_THRESHOLD) return false;
360 return this->unbalanced > count / 4;
361 }
362
372 void CheckInvariant(size_t node_idx, int level, CoordT min_x, CoordT max_x, CoordT min_y, CoordT max_y) const
373 {
374 if (node_idx == INVALID_NODE) return;
375
376 const node &n = this->nodes[node_idx];
377 CoordT cx = TxyFunc()(n.element, 0);
378 CoordT cy = TxyFunc()(n.element, 1);
379
380 assert(cx >= min_x);
381 assert(cx < max_x);
382 assert(cy >= min_y);
383 assert(cy < max_y);
384
385 if (level % 2 == 0) {
386 /* split in dimension 0 = x */
387 this->CheckInvariant(n.left, level + 1, min_x, cx, min_y, max_y);
388 this->CheckInvariant(n.right, level + 1, cx, max_x, min_y, max_y);
389 } else {
390 /* split in dimension 1 = y */
391 this->CheckInvariant(n.left, level + 1, min_x, max_x, min_y, cy);
392 this->CheckInvariant(n.right, level + 1, min_x, max_x, cy, max_y);
393 }
394 }
395
397 void CheckInvariant() const
398 {
399#ifdef KDTREE_DEBUG
400 this->CheckInvariant(this->root, 0, std::numeric_limits<CoordT>::min(), std::numeric_limits<CoordT>::max(), std::numeric_limits<CoordT>::min(), std::numeric_limits<CoordT>::max());
401#endif
402 }
403
404public:
407
414 template <typename It>
415 void Build(It begin, It end)
416 {
417 this->nodes.clear();
418 this->free_list.clear();
419 this->unbalanced = 0;
420 if (begin == end) return;
421 this->nodes.reserve(end - begin);
422
423 this->root = this->BuildSubtree(begin, end, 0);
424 this->CheckInvariant();
425 }
426
430 void Clear()
431 {
432 this->nodes.clear();
433 this->free_list.clear();
434 this->unbalanced = 0;
435 return;
436 }
437
441 void Rebuild()
442 {
443 this->Rebuild(nullptr, nullptr);
444 }
445
452 void Insert(const T &element)
453 {
454 if (this->Count() == 0) {
455 this->root = this->AddNode(element);
456 } else {
457 if (!this->IsUnbalanced() || !this->Rebuild(&element, nullptr)) {
458 this->InsertRecursive(element, this->root, 0);
459 this->IncrementUnbalanced();
460 }
461 this->CheckInvariant();
462 }
463 }
464
472 void Remove(const T &element)
473 {
474 size_t count = this->Count();
475 if (count == 0) return;
476 if (!this->IsUnbalanced() || !this->Rebuild(nullptr, &element)) {
477 /* If the removed element is the root node, this modifies this->root */
478 this->root = this->RemoveRecursive(element, this->root, 0);
479 this->IncrementUnbalanced();
480 }
481 this->CheckInvariant();
482 }
483
488 size_t Count() const
489 {
490 assert(this->free_list.size() <= this->nodes.size());
491 return this->nodes.size() - this->free_list.size();
492 }
493
502 T FindNearest(CoordT x, CoordT y) const
503 {
504 assert(this->Count() > 0);
505
506 CoordT xy[2] = { x, y };
507 return this->FindNearestRecursive(xy, this->root, 0).first;
508 }
509
519 template <typename Outputter>
520 void FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2, const Outputter &outputter) const
521 {
522 assert(x1 < x2);
523 assert(y1 < y2);
524
525 if (this->Count() == 0) return;
526
527 CoordT p1[2] = { x1, y1 };
528 CoordT p2[2] = { x2, y2 };
529 this->FindContainedRecursive(p1, p2, this->root, 0, outputter);
530 }
531
541 std::vector<T> FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2) const
542 {
543 std::vector<T> result;
544 this->FindContained(x1, y1, x2, y2, [&result](T e) {result.push_back(e); });
545 return result;
546 }
547};
548
549#endif
void Build(It begin, It end)
Clear and rebuild the tree from a new sequence of elements.
Definition kdtree.hpp:415
size_t Count() const
Get number of elements stored in tree.
Definition kdtree.hpp:488
bool IsUnbalanced() const
Check if the entire tree is in need of rebuilding.
Definition kdtree.hpp:356
void FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2, const Outputter &outputter) const
Find all items contained within the given rectangle.
Definition kdtree.hpp:520
bool Rebuild(const T *include_element, const T *exclude_element)
Rebuild the tree with all existing elements, optionally adding or removing one more.
Definition kdtree.hpp:118
node_distance FindNearestRecursive(CoordT xy[2], size_t node_idx, int level, DistT limit=std::numeric_limits< DistT >::max()) const
Search a sub-tree for the element nearest to a given point.
Definition kdtree.hpp:277
void InsertRecursive(const T &element, size_t node_idx, int level)
Insert one element in the tree somewhere below node_idx.
Definition kdtree.hpp:148
void Rebuild()
Reconstruct the tree with the same elements, letting it be fully balanced.
Definition kdtree.hpp:441
std::pair< T, DistT > node_distance
A data element and its distance to a searched-for point.
Definition kdtree.hpp:254
void Insert(const T &element)
Insert a single element in the tree.
Definition kdtree.hpp:452
Kdtree()
Construct a new, empty Kdtree.
Definition kdtree.hpp:406
std::vector< T > FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2) const
Find all items contained within the given rectangle.
Definition kdtree.hpp:541
static node_distance SelectNearestNodeDistance(const node_distance &a, const node_distance &b)
Ordering function for node_distance objects; elements with equal distance are ordered by less-than comparison.
Definition kdtree.hpp:261
void CheckInvariant(size_t node_idx, int level, CoordT min_x, CoordT max_x, CoordT min_y, CoordT max_y) const
Verify that the invariant is true for a sub-tree, assert if not.
Definition kdtree.hpp:372
void Remove(const T &element)
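Remove a single element from the tree; does nothing if the tree is empty. The element is expected to be present; removing an absent element from a non-empty tree trips an assertion.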
Definition kdtree.hpp:472
static const size_t INVALID_NODE
Index value indicating no-such-node.
Definition kdtree.hpp:43
size_t RemoveRecursive(const T &element, size_t node_idx, int level)
Find and remove one element from the tree.
Definition kdtree.hpp:209
size_t BuildSubtree(It begin, It end, int level)
Construct a subtree from elements between begin and end iterators.
Definition kdtree.hpp:92
size_t AddNode(const T &element)
Create one new node in the tree.
Definition kdtree.hpp:56
void CheckInvariant() const
Verify the invariant for the entire tree, does nothing unless KDTREE_DEBUG is defined.
Definition kdtree.hpp:397
T FindNearest(CoordT x, CoordT y) const
Find the element closest to given coordinate, in Manhattan distance.
Definition kdtree.hpp:502
CoordT SelectSplitCoord(It begin, It end, int level)
Find a coordinate value to split a range of elements at.
Definition kdtree.hpp:77
std::vector< T > FreeSubtree(size_t node_idx)
Free all children of the given node.
Definition kdtree.hpp:178
void Clear()
Clear the tree.
Definition kdtree.hpp:430
size_t CountValue(const T &element, size_t node_idx) const
Debugging function, counts number of occurrences of an element regardless of its correct position in the tree.
Definition kdtree.hpp:340
constexpr T abs(const T a)
Returns the absolute value of (scalar) variable.
Definition math_func.hpp:23
T element
Element stored at node.
Definition kdtree.hpp:36
size_t left
Index of node to the left, INVALID_NODE if none.
Definition kdtree.hpp:37
size_t right
Index of node to the right, INVALID_NODE if none.
Definition kdtree.hpp:38