May 23, 2021
DS&A Adventures
The union-find, sometimes referred to as the disjoint-set, is a classic algorithm that is used to identify connections in undirected graphs and to detect cycles in various applications. In this post, we discuss the merits and uses of the union-find algorithm and demonstrate its implementation in C++.
Remember those connect-the-dots worksheets that you used to do as a kid to help you learn how to count? Those worksheets are a good starting point for our discussion.
Think of each point on the worksheet above as a node in the ether that may or may not be connected to other nodes. Now connect a few dots at random, and then a few more. It should be obvious which nodes are connected to each other, just by looking at the picture. After all, there are only fourteen nodes to keep track of.
Now try the same thing with the picture below.
Having fun yet? (It's supposed to be a picture of sharks, by the way.) Once we have connected a few hundred dots at random, it becomes impractical to manually determine which points are connected to each other. This example, though seemingly trivial, hints at larger real-world problems.
Consider a social network where each node represents an individual member. Each member is (hopefully) connected to some number of friends, and those friends are connected to other friends in turn. So how does a social network determine whether two members are connected to each other? A similar question could be asked of a public utility like electricity: how does one determine whether a given home is connected to power? You have seen how difficult it can be to answer this question with a few hundred nodes. Now imagine how that difficulty increases when our node count reaches the billions.
Fortunately, we can employ the union-find algorithm to answer this question.
Our data set (source) consists of one million nodes and millions of connections. The first number in the data set indicates the number of nodes. Each following pair of numbers represents a connection between the two specified nodes.
Our goal is to marshal this data into a structure that can tell us whether any two nodes are connected. The client application will look like this:
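#include <cstdlib>
#include <iostream>
int main(int argc, char* argv[])
{
    // The two nodes to query are passed on the command line.
    if (argc < 3)
    {
        std::cerr << "Usage: " << argv[0] << " <node_a> <node_b> < data_file" << std::endl;
        return 1;
    }
    int node_count;
    std::cin >> node_count;
    std::cout << "Size: " << node_count << std::endl;
    UnionFind uf(node_count);
    // Testing the stream after each read stops the loop cleanly at end of
    // input instead of processing one extra, uninitialized pair.
    int a, b;
    while (std::cin >> a >> b)
    {
        if (uf.connected(a, b))
        {
            continue; // already in the same set; nothing to merge
        }
        uf.unify(a, b);
    }
    std::cout << "Data set contains " << uf.set_count() << " sets." << std::endl;
    const int query_a = std::atoi(argv[1]);
    const int query_b = std::atoi(argv[2]);
    if (query_a < 0 || query_a >= node_count || query_b < 0 || query_b >= node_count)
    {
        std::cerr << "Nodes must be between 0 and " << node_count - 1 << "." << std::endl;
        return 1;
    }
    std::cout << "The provided nodes " << query_a << " and " << query_b << " are "
              << (uf.connected(query_a, query_b) ? "connected." : "not connected.") << std::endl;
}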
The test application is simple enough. We start by reading in the node_count and initializing a UnionFind instance. We then iterate over pairs from the data set, processing connections as we go. If a pair is already connected, we move on to the next one, since merging the two nodes again would be redundant.
Once all of the data has been processed, we output the set_count. A set in this context means a group of nodes that are all connected to one another, directly or indirectly. If there are no connections, every node is in its own set; if every node is connected to every other, there is only one set. Two nodes are connected if and only if they are in the same set.
Finally, we take two nodes provided on the command line by the user and tell them whether those nodes are connected or not.
From the application described above, we can determine a rough interface for our UnionFind class:
- a constructor that accepts node_count
- connected, a function that returns true if two nodes are connected
- unify, a function that connects two nodes
- set_count, a function that returns the number of sets in the data
Let us begin with the constructor. The constructor should accept the node_count, and it needs backing storage to represent our various nodes and their connections. Our storage in this case will be an array of integers of size node_count, one entry per node. The integer stored for a node is the index of another node it is linked to (its parent); as we will see shortly, a node may also point to itself.
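class UnionFind
{
public:
    explicit UnionFind(int node_count)
    {
        set_count_ = node_count;
        nodes_ = new int[node_count];
        for (int i = 0; i < node_count; ++i)
        {
            nodes_[i] = i; // every node starts as its own root
        }
    }

    ~UnionFind()
    {
        delete[] nodes_;
    }

    // The class owns a raw array, so copying is disabled to prevent a double delete[].
    UnionFind(const UnionFind&) = delete;
    UnionFind& operator=(const UnionFind&) = delete;

private:
    int* nodes_;     // nodes_[i] is the parent of node i
    int set_count_;  // number of disjoint sets
};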
We create a set_count_ variable and initialize it to the node count, because no connection data has been analyzed yet and so every node is in its own set. Then we dynamically allocate the nodes_ array and ensure that we clean it up in our destructor. Finally, we initialize every node to connect to itself.
A set in this algorithm does not have its own unique identifier. Instead, a set is identified by its root node, which is a node that points to itself. With this understanding in mind, it should now make sense why set_count_ is equal to node_count when UnionFind is first created: every node points to itself, so every node is a root.
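Because a set is identified by its root, counting sets amounts to counting nodes that point to themselves. The sketch below illustrates this on a std::vector parent array rather than the class's raw array; count_roots and make_initial_parents are hypothetical helper names introduced only for illustration, not part of the post's UnionFind.

```cpp
#include <vector>

// Hypothetical helper: counts sets by counting roots,
// i.e. nodes whose stored parent is themselves.
int count_roots(const std::vector<int>& parents)
{
    int roots = 0;
    for (int i = 0; i < static_cast<int>(parents.size()); ++i)
    {
        if (parents[i] == i)
        {
            ++roots;
        }
    }
    return roots;
}

// Hypothetical helper: the initial state described above,
// where every node is its own parent (and therefore its own root).
std::vector<int> make_initial_parents(int node_count)
{
    std::vector<int> parents(node_count);
    for (int i = 0; i < node_count; ++i)
    {
        parents[i] = i;
    }
    return parents;
}
```

On a fresh array of fourteen nodes, count_roots returns 14. Pointing node 3 at node 5 (attaching one root beneath another) leaves 13 roots, and therefore 13 sets.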
Given an arbitrary node id, we can determine the set it belongs to by following parent links from the node up to its root. This informs the implementation of our find operation, a utility that unify and connected will rely on.
int find(int node) const
{
    while (node != nodes_[node])
    {
        node = nodes_[node];
    }
    return node;
}
We loop, following each node to its parent, until we reach a node that refers to itself (the root node). We return the root node id, which identifies the set to which the parameter node belongs.
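To watch find walk a tree, here is a standalone sketch of the same loop over a hand-built std::vector parent array. The function name find_root and its optional steps counter are hypothetical additions for illustration only.

```cpp
#include <vector>

// Standalone version of find: follow parent links until reaching a node
// that points to itself (the root). Optionally reports how many links
// were followed, which is the node's depth in its tree.
int find_root(const std::vector<int>& parents, int node, int* steps = nullptr)
{
    int hops = 0;
    while (node != parents[node])
    {
        node = parents[node];
        ++hops;
    }
    if (steps != nullptr)
    {
        *steps = hops;
    }
    return node;
}
```

With parents = {0, 0, 1, 2, 4}, node 3 leads to 2, then 1, then root 0, so find_root(parents, 3) returns 0 after three hops. Node 4 points to itself, so it is the root of a separate set, and nodes 3 and 4 are not connected.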
Our connected function is quite simple. It just checks whether the set ids for the queried nodes, obtained using find, are the same. If both nodes are in the same set, meaning they have the same root node, they are connected.
bool connected(int node_a, int node_b) const
{
    return find(node_a) == find(node_b);
}
At this point, we are able to find root nodes (and therefore sets) and determine whether any two nodes are connected. But no two nodes are connected yet, because every node starts out in its own set. We add connections through our unify operation. Let us see the implementation first, and then walk through it.
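void unify(const int node_a, const int node_b)
{
    const int root_a = find(node_a);
    const int root_b = find(node_b);
    if (root_a == root_b)
    {
        return; // already in the same set; decrementing here would corrupt set_count_
    }
    nodes_[root_a] = root_b; // attach tree a beneath root b
    --set_count_;
}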
At the most basic level, this function is creating trees of nodes. We find the root node of each tree and then point the first root at the second, combining the trees. Each tree is a set, and the number of sets is equal to the number of trees that exist. Accordingly, we decrement set_count_ every time two distinct trees are merged into one.
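As a self-contained illustration, here is a minimal sketch of this simple (unweighted) scheme. The struct name QuickUnion and its public members are hypothetical, and std::vector stands in for the raw array. The sketch includes an early return when both nodes already share a root, so that repeated pairs leave the set count unchanged.

```cpp
#include <vector>

// Hypothetical standalone sketch of the simple, unweighted union-find.
struct QuickUnion
{
    std::vector<int> parents;
    int sets;

    explicit QuickUnion(int node_count) : parents(node_count), sets(node_count)
    {
        for (int i = 0; i < node_count; ++i)
        {
            parents[i] = i; // every node starts as its own root
        }
    }

    int find(int node) const
    {
        while (node != parents[node])
        {
            node = parents[node];
        }
        return node;
    }

    void unify(int a, int b)
    {
        const int root_a = find(a);
        const int root_b = find(b);
        if (root_a == root_b)
        {
            return; // same tree already: no merge, no change to the count
        }
        parents[root_a] = root_b; // hang tree a beneath root b
        --sets;
    }
};
```

With five nodes, unify(0, 1), unify(1, 2) and unify(3, 4) leave two sets, {0, 1, 2} and {3, 4}. A further unify(0, 2) finds a single shared root and leaves the count at two.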
The implementation of unify above works, but it can produce very deep trees. In the worst case, a tree degenerates into a single chain, so the cost of find (and therefore of connected and unify) grows linearly with the number of nodes. However, a few small changes can significantly improve performance.
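To make the worst case concrete, here is a sketch that feeds the simple rule an unlucky but perfectly legal sequence of pairs, 0 1, 1 2, 2 3 and so on, and measures how far node 0 ends up from its root. The function name naive_chain_depth is hypothetical.

```cpp
#include <vector>

// Applies the simple rule parents[root_a] = root_b for the pairs
// (0, 1), (1, 2), ..., (n - 2, n - 1) and returns the number of links
// find must follow from node 0 to reach its root.
int naive_chain_depth(int node_count)
{
    std::vector<int> parents(node_count);
    for (int i = 0; i < node_count; ++i)
    {
        parents[i] = i;
    }

    auto find = [&parents](int node) {
        while (node != parents[node])
        {
            node = parents[node];
        }
        return node;
    };

    for (int i = 0; i + 1 < node_count; ++i)
    {
        const int root_a = find(i);
        const int root_b = find(i + 1);
        if (root_a != root_b)
        {
            parents[root_a] = root_b; // always hang the first tree under the second
        }
    }

    // Walk from node 0 to its root, counting links.
    int depth = 0;
    for (int node = 0; node != parents[node]; node = parents[node])
    {
        ++depth;
    }
    return depth;
}
```

naive_chain_depth(1000) returns 999: every pair extends a single chain, so node 0 sits at the bottom of a path through every other node, and each find from it touches the whole data set.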
In the implementation above, the first tree is always attached to the second, regardless of how big each tree is. Instead, we want the smaller tree to be attached beneath the root of the bigger tree in every case. To do this, we need to start tracking the size of each tree.
We start by adding a sizes_ member to our class and initializing every entry in it to 1. (As you will recall from a moment ago, every node is in its own set to begin with.)
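int* sizes_; // new member, declared next to nodes_

UnionFind(int node_count)
{
    // ... set_count_ and nodes_ initialized as before ...
    sizes_ = new int[node_count];
    for (int i = 0; i < node_count; ++i)
    {
        sizes_[i] = 1; // every node starts as a tree of size one
    }
}

~UnionFind()
{
    // ... delete[] nodes_ as before ...
    delete[] sizes_;
}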
Then in unify, we rework the tree combination to take into account which tree is bigger.
void unify(const int node_a, const int node_b)
{
    const int root_a = find(node_a);
    const int root_b = find(node_b);
    if (root_a == root_b)
    {
        return; // already in the same set
    }
    const int size_a = sizes_[root_a];
    const int size_b = sizes_[root_b];
    if (size_a > size_b)
    {
        nodes_[root_b] = root_a; // smaller tree b goes beneath root a
        sizes_[root_a] += size_b;
    }
    else
    {
        nodes_[root_a] = root_b; // tree a is no bigger, so it goes beneath root b
        sizes_[root_b] += size_a;
    }
    --set_count_;
}
If tree a is bigger than tree b, we attach root b beneath root a and add the size of b to the size of a. Otherwise (when tree b is at least as big as tree a), the roles are swapped. In this way, our tree sizes are kept up to date and the larger tree always absorbs the smaller tree.
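Here is a standalone sketch of the size-weighted rule, useful for tracing how sizes move between roots. The struct name WeightedUnion is hypothetical, and std::vector replaces the raw arrays. Note that a size entry is only meaningful while its node is still a root.

```cpp
#include <vector>

// Hypothetical standalone sketch of union by size.
struct WeightedUnion
{
    std::vector<int> parents;
    std::vector<int> sizes; // sizes[r] is the tree size while r is a root

    explicit WeightedUnion(int node_count) : parents(node_count), sizes(node_count, 1)
    {
        for (int i = 0; i < node_count; ++i)
        {
            parents[i] = i;
        }
    }

    int find(int node) const
    {
        while (node != parents[node])
        {
            node = parents[node];
        }
        return node;
    }

    void unify(int a, int b)
    {
        const int root_a = find(a);
        const int root_b = find(b);
        if (root_a == root_b)
        {
            return;
        }
        if (sizes[root_a] > sizes[root_b])
        {
            parents[root_b] = root_a; // smaller tree b goes under a
            sizes[root_a] += sizes[root_b];
        }
        else
        {
            parents[root_a] = root_b; // a is smaller or equal: it goes under b
            sizes[root_b] += sizes[root_a];
        }
    }
};
```

Tracing six nodes: unify(0, 1) is a tie, so 0 goes under 1 and root 1 has size 2. After unify(2, 3), root 3 also has size 2. unify(4, 1) attaches the single node 4 under root 1 (size 3), and unify(1, 3) then attaches root 3 under the larger root 1, giving a tree of five nodes while node 5 stays on its own.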
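The beauty of this approach is that it keeps every tree shallow. A node only gets deeper when its tree is attached beneath a tree at least as large, which at least doubles the size of the tree containing it. That can happen at most log2(node_count) times, so no node is ever more than log2(node_count) links from its root; for our one million nodes, that means at most 19 links. This translates to very fast find operations, because it takes few iterations to reach the root from any node, and the cost grows only logarithmically in the worst case. This improvement turns massive computing problems like the social network one into solvable ones.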
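To see the logarithmic bound in action, the sketch below merges equal-sized trees round by round (0 with 1, 2 with 3, ..., then 0 with 2, 4 with 6, ..., and so on). That pattern pushes the weighted rule to its deepest possible trees, and the function reports the depth of the deepest node. The function name weighted_tournament_depth is hypothetical.

```cpp
#include <vector>

// Applies union by size in doubling rounds and returns the largest
// number of links between any node and its root.
int weighted_tournament_depth(int node_count)
{
    std::vector<int> parents(node_count);
    std::vector<int> sizes(node_count, 1);
    for (int i = 0; i < node_count; ++i)
    {
        parents[i] = i;
    }

    auto find = [&parents](int node) {
        while (node != parents[node])
        {
            node = parents[node];
        }
        return node;
    };

    for (int step = 1; step < node_count; step *= 2)
    {
        for (int i = 0; i + step < node_count; i += 2 * step)
        {
            const int root_a = find(i);
            const int root_b = find(i + step);
            if (root_a == root_b)
            {
                continue;
            }
            if (sizes[root_a] > sizes[root_b])
            {
                parents[root_b] = root_a;
                sizes[root_a] += sizes[root_b];
            }
            else
            {
                parents[root_a] = root_b;
                sizes[root_b] += sizes[root_a];
            }
        }
    }

    // Measure the depth of every node and keep the maximum.
    int max_depth = 0;
    for (int node = 0; node < node_count; ++node)
    {
        int depth = 0;
        for (int cur = node; cur != parents[cur]; cur = parents[cur])
        {
            ++depth;
        }
        if (depth > max_depth)
        {
            max_depth = depth;
        }
    }
    return max_depth;
}
```

With 1,024 nodes the deepest node ends up exactly 10 links from its root, matching log2(1024). With 1,000 nodes it cannot exceed 9, since log2(1000) is just under 10.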
We have covered almost everything. All that is left to do is create a getter for the set count.
int set_count() const
{
    return set_count_;
}
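Pulling the fragments together, here is one way the finished class might look. This is a sketch, not the post's exact code. It swaps the raw new[]/delete[] arrays for std::vector, so no destructor or copy handling is needed. It also makes find private and keeps an early return in unify so that repeated pairs never change the set count.

```cpp
#include <vector>

// One possible assembly of the pieces above (std::vector instead of raw arrays).
class UnionFind
{
public:
    explicit UnionFind(int node_count)
        : nodes_(node_count), sizes_(node_count, 1), set_count_(node_count)
    {
        for (int i = 0; i < node_count; ++i)
        {
            nodes_[i] = i; // every node starts as the root of its own set
        }
    }

    bool connected(int node_a, int node_b) const
    {
        return find(node_a) == find(node_b);
    }

    void unify(int node_a, int node_b)
    {
        const int root_a = find(node_a);
        const int root_b = find(node_b);
        if (root_a == root_b)
        {
            return; // already in the same set
        }
        // Attach the smaller tree beneath the root of the larger one.
        if (sizes_[root_a] > sizes_[root_b])
        {
            nodes_[root_b] = root_a;
            sizes_[root_a] += sizes_[root_b];
        }
        else
        {
            nodes_[root_a] = root_b;
            sizes_[root_b] += sizes_[root_a];
        }
        --set_count_;
    }

    int set_count() const
    {
        return set_count_;
    }

private:
    int find(int node) const
    {
        while (node != nodes_[node])
        {
            node = nodes_[node];
        }
        return node;
    }

    std::vector<int> nodes_; // nodes_[i] is the parent of node i
    std::vector<int> sizes_; // sizes_[r] is the tree size while r is a root
    int set_count_;          // number of disjoint sets
};
```

For example, on ten nodes, the pairs 4 3, 3 8, 6 5, 9 4, 2 1, 8 9, 5 0, 7 2 and 6 1 leave two sets, {0, 1, 2, 5, 6, 7} and {3, 4, 8, 9}. The pair 8 9 is already connected when it arrives, so it does not change the count.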
If we build and run our application now, with the data set contained in largeUF.txt and the following command line input
union-find 0 999999 < largeUF.txt
we should see output that tells us that the data set contains six sets and that the nodes 0 and 999999 are connected to one another.
Size: 1000000
Data set contains 6 sets.
The provided nodes 0 and 999999 are connected.
If you want to optimize this algorithm even further, look into path compression, which results in even flatter trees. If you need something faster than that, you're on your own.
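For reference, here is a minimal sketch of one common form of path compression (two-pass, or full, compression), written as a free function over a std::vector parent array. The name find_compress is hypothetical. Because it rewrites parent links, it cannot be a const member the way the post's find is. Combined with union by size, path compression is known to bring the amortized cost per operation down to the inverse Ackermann function, which is effectively constant for any realistic input.

```cpp
#include <vector>

// Sketch of find with full path compression. Unlike the post's const
// find, this modifies the parent array.
int find_compress(std::vector<int>& parents, int node)
{
    // First pass: locate the root.
    int root = node;
    while (root != parents[root])
    {
        root = parents[root];
    }
    // Second pass: point every node on the path directly at the root.
    while (node != root)
    {
        const int next = parents[node];
        parents[node] = root;
        node = next;
    }
    return root;
}
```

On a chain where node 4 leads to 3, 2, 1 and finally root 0, a single call to find_compress(parents, 4) returns 0 and leaves every node on that path pointing straight at 0, so later finds from any of them take one step.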