diff --git a/cpl/inc/tree.h b/cpl/inc/tree.h index 0394d95..dc29a0f 100644 --- a/cpl/inc/tree.h +++ b/cpl/inc/tree.h @@ -426,39 +426,40 @@ OutIt dijkstra(const std::size_t n, InIt first, InIt last, OutIt dest, const std return dest; } -template -OutIt boruvka(const std::size_t n, InIt first, InIt last, OutIt dest, Pr1 preferred = Pr1{}, Pr2 tie_break = Pr2{}) { +template +OutIt boruvka(const std::size_t n, InIt first, InIt last, OutIt dest, Pred pred = Pred{}) { + using edge_ptr = decltype(&*first); + DisjointSet ds(n); - std::vector cheapest(n, std::numeric_limits::max()); - std::vector cheapest_edge(n, std::numeric_limits::max()); + std::vector cheapest(n); - std::size_t mst_size = 0; - while (mst_size < n - 1) { - for (std::size_t i = 0; i < n; ++i) { - cheapest[i] = std::numeric_limits::max(); - cheapest_edge[i] = std::numeric_limits::max(); - } + for (std::size_t components = n; components > 1;) { + std::fill(cheapest.begin(), cheapest.end(), nullptr); for (auto it = first; it != last; ++it) { auto set1 = ds.find(it->from); auto set2 = ds.find(it->to); - - if (set1 == set2) { - continue; - } - - if (preferred(*it, cheapest[set1]) || (tie_break(*it, cheapest[set1]) && it->weight == cheapest[set1])) { - cheapest[set1] = it->weight; - cheapest_edge[set1] = it->to; + if (set1 != set2) { + if (!cheapest[set1] || pred(*it, *cheapest[set1])) { + cheapest[set1] = &*it; + } + + if (!cheapest[set2] || pred(*it, *cheapest[set2])) { + cheapest[set2] = &*it; + } } } for (std::size_t i = 0; i < n; ++i) { - if (cheapest[i] != std::numeric_limits::max()) { - ds.union_rank(cheapest_edge[i], i); - *dest = Edge{cheapest_edge[i], i, cheapest[i]}; - ++mst_size; - ++dest; + if (!cheapest[i]) { + continue; + } + + auto set1 = ds.find(cheapest[i]->from), set2 = ds.find(cheapest[i]->to); + if (set1 != set2) { + ds.union_rank(set1, set2); + *dest++ = *cheapest[i]; + --components; } } } diff --git a/tests/cpl/boruvka_mst/test.cpp b/tests/cpl/boruvka_mst/test.cpp new file mode 100644 index 0000000..b64b197 --- /dev/null +++ b/tests/cpl/boruvka_mst/test.cpp @@ -0,0 +1,81 @@ +// Copyright (c) Brandon Pacewic +// SPDX-License-Identifier: MIT + +#include +#include + +#include "minimum_spanning_tree_test_cases.hpp" +#include "tree.h" + +template +auto total_weight(const EdgeContainer& edges) { + using weight_type = decltype(edges[0].weight); + weight_type sum = 0; + for (const auto& e : edges) { + sum += e.weight; + } + return sum; +} + +int main() { + using namespace std; + using namespace cpl; + + { + auto [input, expected] = small_test_case(); + vector> mst; + boruvka(9, input.begin(), input.end(), back_inserter(mst)); + + assert(mst.size() == 8); + + auto expected_weight = total_weight(expected); + auto actual_weight = total_weight(mst); + assert(actual_weight == expected_weight); + } + { + auto [input, expected] = single_edge_test_case(); + vector> mst; + boruvka(2, input.begin(), input.end(), back_inserter(mst)); + + assert(mst.size() == 1); + assert(total_weight(mst) == total_weight(expected)); + } + { + vector> input = { + {0, 1, 1}, + {1, 2, 2}, + {0, 2, 3}, + }; + vector> mst; + boruvka(3, input.begin(), input.end(), back_inserter(mst)); + + assert(mst.size() == 2); + assert(total_weight(mst) == 3); + } + { + auto [input, expected] = same_weight_test_case(); + vector> mst; + boruvka(4, input.begin(), input.end(), back_inserter(mst)); + + assert(mst.size() == 3); + assert(total_weight(mst) == total_weight(expected)); + } + { + auto [input, expected] = large_test_case(); + vector> mst; + boruvka(100, input.begin(), input.end(), back_inserter(mst)); + + assert(mst.size() == 99); + assert(total_weight(mst) == total_weight(expected)); + } + { + auto [input, expected] = large_sparse_test_case(); + vector> mst; + boruvka(1000, input.begin(), input.end(), back_inserter(mst)); + + assert(mst.size() == 999); + assert(total_weight(mst) == total_weight(expected)); + } + + return 0; +}