Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 24 additions & 23 deletions cpl/inc/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -426,39 +426,40 @@ OutIt dijkstra(const std::size_t n, InIt first, InIt last, OutIt dest, const std
return dest;
}

template <class InIt, class OutIt, class Pr1 = EdgeLess, class Pr2 = EdgeLess>
OutIt boruvka(const std::size_t n, InIt first, InIt last, OutIt dest, Pr1 preferred = Pr1{}, Pr2 tie_break = Pr2{}) {
template <class InIt, class OutIt, class Pred = EdgeLess>
OutIt boruvka(const std::size_t n, InIt first, InIt last, OutIt dest, Pred pred = Pred{}) {
using edge_ptr = decltype(&*first);

DisjointSet<std::size_t> ds(n);
std::vector<std::size_t> cheapest(n, std::numeric_limits<std::size_t>::max());
std::vector<std::size_t> cheapest_edge(n, std::numeric_limits<std::size_t>::max());
std::vector<edge_ptr> cheapest(n);

std::size_t mst_size = 0;
while (mst_size < n - 1) {
for (std::size_t i = 0; i < n; ++i) {
cheapest[i] = std::numeric_limits<std::size_t>::max();
cheapest_edge[i] = std::numeric_limits<std::size_t>::max();
}
for (std::size_t components = n; components > 1;) {
std::fill(cheapest.begin(), cheapest.end(), nullptr);

for (auto it = first; it != last; ++it) {
auto set1 = ds.find(it->from);
auto set2 = ds.find(it->to);

if (set1 == set2) {
continue;
}

if (preferred(*it, cheapest[set1]) || (tie_break(*it, cheapest[set1]) && it->weight == cheapest[set1])) {
cheapest[set1] = it->weight;
cheapest_edge[set1] = it->to;
if (set1 != set2) {
if (!cheapest[set1] || pred(*it, *cheapest[set1])) {
cheapest[set1] = &*it;
}

if (!cheapest[set2] || pred(*it, *cheapest[set2])) {
cheapest[set2] = &*it;
}
}
}

for (std::size_t i = 0; i < n; ++i) {
if (cheapest[i] != std::numeric_limits<std::size_t>::max()) {
ds.union_rank(cheapest_edge[i], i);
*dest = Edge{cheapest_edge[i], i, cheapest[i]};
++mst_size;
++dest;
if (!cheapest[i]) {
continue;
}

auto set1 = ds.find(cheapest[i]->from), set2 = ds.find(cheapest[i]->to);
if (set1 != set2) {
ds.union_rank(set1, set2);
*dest++ = *cheapest[i];
--components;
}
}
}
Expand Down
81 changes: 81 additions & 0 deletions tests/cpl/boruvka_mst/test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright (c) Brandon Pacewic
// SPDX-License-Identifier: MIT

#include <cassert>
#include <vector>

#include "minimum_spanning_tree_test_cases.hpp"
#include "tree.h"

template <class EdgeContainer>
auto total_weight(const EdgeContainer& edges) {
using weight_type = decltype(edges[0].weight);
weight_type sum = 0;
for (const auto& e : edges) {
sum += e.weight;
}
return sum;
}

int main() {
using namespace std;
using namespace cpl;

{
auto [input, expected] = small_test_case();
vector<Edge<>> mst;
boruvka(9, input.begin(), input.end(), back_inserter(mst));

assert(mst.size() == 8);

auto expected_weight = total_weight(expected);
auto actual_weight = total_weight(mst);
assert(actual_weight == expected_weight);
}
{
auto [input, expected] = single_edge_test_case();
vector<Edge<>> mst;
boruvka(2, input.begin(), input.end(), back_inserter(mst));

assert(mst.size() == 1);
assert(total_weight(mst) == total_weight(expected));
}
{
vector<Edge<>> input = {
{0, 1, 1},
{1, 2, 2},
{0, 2, 3},
};
vector<Edge<>> mst;
boruvka(3, input.begin(), input.end(), back_inserter(mst));

assert(mst.size() == 2);
assert(total_weight(mst) == 3);
}
{
auto [input, expected] = same_weight_test_case();
vector<Edge<>> mst;
boruvka(4, input.begin(), input.end(), back_inserter(mst));

assert(mst.size() == 3);
assert(total_weight(mst) == total_weight(expected));
}
{
auto [input, expected] = large_test_case();
vector<Edge<>> mst;
boruvka(100, input.begin(), input.end(), back_inserter(mst));

assert(mst.size() == 99);
assert(total_weight(mst) == total_weight(expected));
}
{
auto [input, expected] = large_sparse_test_case();
vector<Edge<>> mst;
boruvka(1000, input.begin(), input.end(), back_inserter(mst));

assert(mst.size() == 999);
assert(total_weight(mst) == total_weight(expected));
}

return 0;
}