Using Union-Find to Trace Connections in C++

May 23, 2021 · DS&A Adventures

Union-find, also called the disjoint-set data structure, is a classic algorithm for tracking which nodes of an undirected graph are connected, and it is also used to detect cycles in various applications. In this post, we discuss the merits and uses of union-find and demonstrate its implementation in C++.

The Problem

Remember those connect-the-dots worksheets that you used to do as a kid to help you learn how to count? Those worksheets are a good starting point for our discussion.

[Figure: a connect-the-dots worksheet with fourteen dots]

Think of each point on the worksheet above as a node in the ether that may or may not be connected to other nodes. Now connect a few dots at random, and then a few more. It should be obvious which nodes are connected to each other, just by looking at the picture. After all, there are only fourteen nodes to keep track of.

Now try the same thing with the picture below.

[Figure: a much larger connect-the-dots puzzle of sharks]

Having fun yet? (It's supposed to be a picture of sharks, by the way.) Once we have connected a few hundred dots at random, it becomes impractical to manually determine which points are connected to each other. This example, though seemingly trivial, hints at larger real-world problems.

Consider a social network where each node represents an individual member. Each member is (hopefully) connected to some number of friends, and those friends are connected to other friends, in turn. So how does a social network determine if two members are connected to each other? A similar question could be asked of a public utility like electricity: how does one determine if a given home is connected to power? You have seen how difficult it can be to answer this question with a few hundred nodes. Now imagine how that difficulty increases when our node count reaches the billions.

Fortunately, we can employ the union-find algorithm to answer this question.

Defining the Application

Our data set, largeUF.txt, consists of one million nodes and millions of connections. The first number in the file is the number of nodes. Each following pair of numbers represents a connection between the two specified nodes.

Our goal is to marshal this data into a structure that can tell us whether any two nodes are connected. The client application will look like this:

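#include <cstdlib>
#include <iostream>

// The UnionFind class developed below must be defined (or #included) above main.

int main(int argc, char* argv[])
{
  if (argc < 3)
  {
    std::cerr << "usage: union-find <node_a> <node_b> < data_file" << std::endl;
    return 1;
  }

  int node_count = 0;
  if (!(std::cin >> node_count) || node_count <= 0)
  {
    std::cerr << "Could not read a positive node count." << std::endl;
    return 1;
  }
  std::cout << "Size: " << node_count << std::endl;

  UnionFind uf(node_count);

  // Read pairs until extraction fails. Testing eof() before reading would
  // run the loop body once more with a pair that was never actually read.
  int a, b;
  while (std::cin >> a >> b)
  {
    // Skip pairs that are already connected.
    if (uf.connected(a, b))
    {
      continue;
    }

    uf.unify(a, b);
  }

  std::cout << "Data set contains " << uf.set_count() << " sets." << std::endl;

  // stdin holds the data set, so the query nodes come from the command line.
  const int query_a = std::atoi(argv[1]);
  const int query_b = std::atoi(argv[2]);
  if (query_a < 0 || query_a >= node_count || query_b < 0 || query_b >= node_count)
  {
    std::cerr << "Nodes must be between 0 and " << node_count - 1 << "." << std::endl;
    return 1;
  }

  std::cout << "The provided nodes " << query_a << " and " << query_b << " are "
            << (uf.connected(query_a, query_b) ? "connected." : "not connected.") << std::endl;
}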

The test application is simple enough. We start by reading in the node_count and initializing a UnionFind instance. We then iterate over pairs from the data set, processing connections as we go. If a pair is already connected, we move on to the next one, since unifying it again would incorrectly reduce the set count.

Once all of the data has been processed, we output the set count. A set in this context is a group of nodes that are all connected to one another, either directly or through other nodes. If there are no connections, every node will be in its own set. And if every node is connected to every other, there will be only one set. Two nodes are connected if they are in the same set.

Finally, we take two nodes provided on the command line by the user and tell them whether those nodes are connected or not.

Implementing UnionFind

From the application described above, we can determine a rough interface for our UnionFind class:

  • A constructor that accepts a node_count
  • A connected function that returns true if two nodes are connected
  • A unify function that connects two nodes
  • A set_count function that returns the number of sets in the data

Construction and Storage

Let us begin with the constructor. It should accept the node_count, and it needs backing storage to represent our nodes and their connections. Our storage in this case will be an array of node_count integers, one for each node. The integer stored for a node is the index of the node that it links to.

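class UnionFind
{
  int* nodes_;     // nodes_[i] holds the index of the node that i links to
  int set_count_;  // number of disjoint sets

public:
  explicit UnionFind(int node_count)
  {
    set_count_ = node_count;
    nodes_ = new int[node_count];
    for (int i = 0; i < node_count; ++i)
    {
      nodes_[i] = i;  // every node starts out linking to itself
    }
  }

  ~UnionFind()
  {
    delete[] nodes_;
  }

  // The class owns a raw array, so copying is disabled to avoid a double delete.
  UnionFind(const UnionFind&) = delete;
  UnionFind& operator=(const UnionFind&) = delete;

  // find, connected, unify, and set_count are member functions added below.
};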

We create a set_count_ variable and initialize it to the node count, because no connection data has been analyzed yet and so every node is in its own set. Then we dynamically allocate the nodes_ array and ensure that we clean it up in our destructor. Finally, we initialize every node to connect to itself.

Find Node Set

A set in this algorithm does not have its own unique identifier. Instead, a set is identified by its root node: the one node in the set that points to itself. With this understanding in mind, it should now make sense why set_count_ is equal to node_count when UnionFind is first created: every node points to itself, so every node is a root.
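Because every set has exactly one root, counting sets amounts to counting self-pointing entries. The sketch below illustrates that invariant with a hypothetical free function, count_roots, over a plain std::vector<int> of links. It is not part of the post's class, which tracks set_count_ incrementally instead of rescanning the array.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper (not part of UnionFind): counts root nodes, i.e. entries
// that point to themselves. Each root identifies exactly one set.
int count_roots(const std::vector<int>& parent)
{
  const int n = static_cast<int>(parent.size());
  int roots = 0;
  for (int i = 0; i < n; ++i)
  {
    assert(parent[i] >= 0 && parent[i] < n);  // links must stay in range
    if (parent[i] == i)
    {
      ++roots;
    }
  }
  return roots;
}
```

For a freshly initialized array such as {0, 1, 2, 3, 4} it returns 5. For {1, 1, 1, 4, 4}, where nodes 0 and 2 link to root 1 and node 3 links to root 4, it returns 2.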

Given an arbitrary node id, we can determine the set it belongs to by following the node to its root. This informs the implementation of our find operation, a utility that unify and connected will rely on.

int find(int node) const
{
  while (node != nodes_[node])
  {
    node = nodes_[node];
  }

  return node;
}

We keep following links from one node to the next until we reach a node that refers to itself (the root node). We return the root node id, which identifies the set to which the parameter node belongs.
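To see the loop in action, here is a sketch that runs the same logic over a hand-built array. find_with_steps is a hypothetical free-function variant of find that also reports how many links it followed; that step count is exactly what the optimizations later in the post try to keep small.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical free-function version of find: follows links from `node`
// until it reaches a node that points to itself (the root).
// Returns {root, number of links followed}.
std::pair<int, int> find_with_steps(const std::vector<int>& parent, int node)
{
  assert(node >= 0 && node < static_cast<int>(parent.size()));
  int steps = 0;
  while (node != parent[node])
  {
    node = parent[node];
    ++steps;
  }
  return {node, steps};
}
```

With parent = {1, 2, 3, 3, 3}, searching from node 0 follows 0 → 1 → 2 → 3 and returns {3, 3}, while searching from node 4 reaches root 3 in a single step.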

Determine Connection Status

Our connected function is quite simple. It just checks whether the set ids for the queried nodes, obtained using find, are the same. If both nodes are in the same set, meaning they have the same root node, they are connected.

bool connected(int node_a, int node_b) const
{
  return find(node_a) == find(node_b);
}

Unify Nodes

At this point, we are able to find root nodes (and therefore sets) and determine whether any two nodes are connected. But no nodes start out connected, so we provide that functionality through our unify operation. Let us see the implementation first, and then walk through it.

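void unify(const int node_a, const int node_b)
{
  const int root_a = find(node_a);
  const int root_b = find(node_b);

  // Hang tree a beneath the root of tree b. Callers must not pass nodes that
  // are already connected (root_a == root_b), or set_count_ would drop
  // without any trees being merged.
  nodes_[root_a] = root_b;

  --set_count_;
}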

At the most basic level, this function builds trees of nodes. We find the root node of each tree and make the first root point to the second, which merges the two trees into one. Each tree is a set, so the number of sets equals the number of trees. Accordingly, we decrement set_count_ every time a connection is made, because two trees have been merged into one.
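A short trace makes the tree building concrete. The sketch below is a minimal, vector-based stand-in for the class (the name QuickUnion is hypothetical) with the same find and the same naive unify as above, exposing its link array so the result can be inspected.

```cpp
#include <cassert>
#include <vector>

// Hypothetical vector-based stand-in for the naive UnionFind above.
struct QuickUnion
{
  std::vector<int> parent;  // parent[i] is the node that i links to
  int set_count;

  explicit QuickUnion(int n) : parent(n), set_count(n)
  {
    for (int i = 0; i < n; ++i)
    {
      parent[i] = i;  // every node starts as its own root
    }
  }

  int find(int node) const
  {
    assert(node >= 0 && node < static_cast<int>(parent.size()));
    while (node != parent[node])
    {
      node = parent[node];
    }
    return node;
  }

  bool connected(int a, int b) const { return find(a) == find(b); }

  // Same rule as the naive unify: a's root goes beneath b's root.
  // As in the post, callers skip pairs that are already connected.
  void unify(int a, int b)
  {
    const int root_a = find(a);
    const int root_b = find(b);
    parent[root_a] = root_b;
    --set_count;
  }
};
```

Starting from five nodes, the calls unify(0, 1), unify(2, 3), and unify(0, 3) leave the link array as {1, 3, 3, 3, 4}: nodes 0 through 3 form one tree rooted at 3, node 4 is still alone, and two sets remain.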

Optimizing Unify

The implementation of unify above is useful, but it can result in very deep trees. In the worst case, a tree degenerates into a chain, so a single find has to walk past every node in it, and its cost grows linearly with the number of nodes. However, a few small changes can significantly improve performance.
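To make the worst case concrete: if the existing tree is repeatedly linked beneath a fresh single node, every unify pushes the whole tree one level deeper, and the result is a chain. The sketch below, built around a hypothetical helper named naive_chain_depth, applies the naive rule to the pairs (0, 1), (0, 2), and so on, then measures how far node 0 sits from its root.

```cpp
#include <cassert>
#include <vector>

// Follows links from `node` to its root.
int find_root(const std::vector<int>& parent, int node)
{
  while (node != parent[node])
  {
    node = parent[node];
  }
  return node;
}

// Number of links between `node` and its root.
int depth_of(const std::vector<int>& parent, int node)
{
  int depth = 0;
  while (node != parent[node])
  {
    node = parent[node];
    ++depth;
  }
  return depth;
}

// Hypothetical demo: applies the naive rule parent[root_a] = root_b to the
// pairs (0, 1), (0, 2), ..., (0, n - 1) and returns the depth of node 0.
int naive_chain_depth(int n)
{
  assert(n > 0);
  std::vector<int> parent(n);
  for (int i = 0; i < n; ++i)
  {
    parent[i] = i;  // every node starts as its own root
  }
  for (int b = 1; b < n; ++b)
  {
    const int root_a = find_root(parent, 0);
    const int root_b = find_root(parent, b);  // b is still a single node
    parent[root_a] = root_b;                  // the whole tree goes beneath b
  }
  return depth_of(parent, 0);
}
```

naive_chain_depth(4) returns 3, for the chain 0 → 1 → 2 → 3, and in general n nodes produce a depth of n - 1. Because each find from node 0 walks the entire chain built so far, processing all n - 1 pairs this way costs on the order of n²/2 steps.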

In the original implementation, the first tree is connected to the second tree, regardless of how big each tree is. Instead, we want the smaller tree to always be attached beneath the root of the bigger one, where a tree's size is the number of nodes it contains. To do this, we need to start tracking the size of each tree.

We start by adding a sizes_ member to our class and initializing every entry in it to 1. (As you will recall from a moment ago, every node is in its own set to begin with.)

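int* sizes_;  // new member: sizes_[r] is the node count of the tree rooted at r (valid while r is a root)

explicit UnionFind(int node_count)
{
  // ... set_count_ and nodes_ are initialized as before ...
  sizes_ = new int[node_count];
  for (int i = 0; i < node_count; ++i)
  {
    sizes_[i] = 1;  // every node starts in a tree of its own
  }
}

~UnionFind()
{
  // ... nodes_ is deleted as before ...
  delete[] sizes_;
}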

Then in unify, we rework the tree combination to take into account which tree is bigger.

void unify(const int node_a, const int node_b)
{
  const int root_a = find(node_a);
  const int root_b = find(node_b);

  const int size_a = sizes_[root_a];
  const int size_b = sizes_[root_b];

  // Attach the smaller tree beneath the root of the larger one.
  // On a tie, tree a goes beneath tree b.
  if (size_a > size_b)
  {
    nodes_[root_b] = root_a;
    sizes_[root_a] += size_b;
  }
  else
  {
    nodes_[root_a] = root_b;
    sizes_[root_b] += size_a;
  }

  --set_count_;
}

If tree a is bigger than tree b, we attach b's root beneath a's root and add b's size to a's. Otherwise (tree b is at least as big as tree a), we do the reverse. In this way, our tree sizes are kept up to date and the larger tree always absorbs the smaller one.

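The beauty of this approach is that it keeps every tree shallow. A node only gets deeper when its tree is attached beneath a tree at least as large, which at least doubles the size of the tree containing it. A tree can double only so many times before it holds all n nodes, so no node can end up more than log2(n) links from its root; for our one million nodes, that means fewer than 20 steps per find. This logarithmic worst case, compared with the linear one above, is what turns massive computing problems like the social network one into solvable ones.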
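The bound can be checked empirically. The sketch below is a hypothetical vector-based WeightedUnion that uses the size rule above; unlike the post's version, its unify returns early when both nodes already share a root, so it does not rely on the caller to skip connected pairs. The helper max_depth_after_pairwise_merges (also hypothetical) merges equal-sized trees pairwise (1 + 1, 2 + 2, 4 + 4, ...), which is the pattern that makes weighted trees as deep as they can get.

```cpp
#include <cassert>
#include <vector>

// Hypothetical vector-based sketch of the size-weighted unify.
struct WeightedUnion
{
  std::vector<int> parent;
  std::vector<int> sizes;  // sizes[r] is meaningful only while r is a root
  int set_count;

  explicit WeightedUnion(int n) : parent(n), sizes(n, 1), set_count(n)
  {
    for (int i = 0; i < n; ++i)
    {
      parent[i] = i;
    }
  }

  int find(int node) const
  {
    assert(node >= 0 && node < static_cast<int>(parent.size()));
    while (node != parent[node])
    {
      node = parent[node];
    }
    return node;
  }

  void unify(int a, int b)
  {
    const int root_a = find(a);
    const int root_b = find(b);
    if (root_a == root_b)
    {
      return;  // already in the same set (guard added in this sketch)
    }
    if (sizes[root_a] > sizes[root_b])
    {
      parent[root_b] = root_a;
      sizes[root_a] += sizes[root_b];
    }
    else
    {
      parent[root_a] = root_b;
      sizes[root_b] += sizes[root_a];
    }
    --set_count;
  }

  // Largest number of links between any node and its root.
  int max_depth() const
  {
    int result = 0;
    for (int i = 0; i < static_cast<int>(parent.size()); ++i)
    {
      int depth = 0;
      for (int node = i; node != parent[node]; node = parent[node])
      {
        ++depth;
      }
      if (depth > result)
      {
        result = depth;
      }
    }
    return result;
  }
};

// Hypothetical demo: merges equal-sized trees pairwise and reports the
// maximum depth reached across all n nodes.
int max_depth_after_pairwise_merges(int n)
{
  WeightedUnion wu(n);
  for (int step = 1; step < n; step *= 2)
  {
    for (int i = 0; i + step < n; i += 2 * step)
    {
      wu.unify(i, i + step);
    }
  }
  return wu.max_depth();
}
```

For 1,024 nodes this adversarial pattern reaches a depth of exactly 10, which equals log2(1024). Feeding the chain-building pairs (0, 1), (0, 2), ... that produced depth n - 1 under the naive rule leaves every tree at depth 1, because each new single node is attached beneath the existing root.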

Get Set Count

We have covered almost everything. All that is left to do is create a getter for the set count.

int set_count() const
{
  return set_count_;
}
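For reference, here is one way the snippets above might be assembled into a single weighted class. It is a sketch, not necessarily how the original source is laid out: it keeps the post's raw arrays and member names, adds public:, marks the read-only functions const, and disables copying because the class owns its arrays. The node_count_ member is a hypothetical addition used only for bounds assertions.

```cpp
#include <cassert>

// One possible assembly of the weighted UnionFind from the snippets above.
class UnionFind
{
  int* nodes_;      // nodes_[i]: the node that i links to; roots link to themselves
  int* sizes_;      // sizes_[r]: number of nodes in the tree rooted at r
  int set_count_;   // number of disjoint sets
  int node_count_;  // hypothetical addition, used only for bounds assertions

public:
  explicit UnionFind(int node_count)
    : nodes_(new int[node_count]), sizes_(new int[node_count]),
      set_count_(node_count), node_count_(node_count)
  {
    for (int i = 0; i < node_count; ++i)
    {
      nodes_[i] = i;
      sizes_[i] = 1;
    }
  }

  ~UnionFind()
  {
    delete[] nodes_;
    delete[] sizes_;
  }

  UnionFind(const UnionFind&) = delete;
  UnionFind& operator=(const UnionFind&) = delete;

  int find(int node) const
  {
    assert(node >= 0 && node < node_count_);
    while (node != nodes_[node])
    {
      node = nodes_[node];
    }
    return node;
  }

  bool connected(int node_a, int node_b) const
  {
    return find(node_a) == find(node_b);
  }

  // As in the post, callers should skip pairs that are already connected.
  void unify(int node_a, int node_b)
  {
    const int root_a = find(node_a);
    const int root_b = find(node_b);
    if (sizes_[root_a] > sizes_[root_b])
    {
      nodes_[root_b] = root_a;
      sizes_[root_a] += sizes_[root_b];
    }
    else
    {
      nodes_[root_a] = root_b;
      sizes_[root_b] += sizes_[root_a];
    }
    --set_count_;
  }

  int set_count() const
  {
    return set_count_;
  }
};
```

With six nodes, unifying (0, 1), (1, 2), and (3, 4) leaves three sets: {0, 1, 2}, {3, 4}, and {5}.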

Expected Output

If we build and run our application now, with the data set contained in largeUF.txt and the following command-line arguments

union-find 0 999999 < largeUF.txt

we should see output that tells us that the data set contains six sets and that the nodes 0 and 999999 are connected to one another.

Size: 1000000
Data set contains 6 sets.
The provided nodes 0 and 999999 are connected.

Closing Thoughts

If you want to optimize this algorithm even further, look into path compression, which results in even flatter trees. If you need something faster than that, you're on your own.
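As a starting point, here is a minimal sketch of find with full path compression, written as a hypothetical free function over a std::vector<int> of links. It is a different technique from the post's find: after locating the root, it rewires every node on the path to point directly at that root. Because it modifies the array, a class version of find could no longer be const, and neither could connected. Combined with union by size, the amortized cost per operation is known to grow extremely slowly (it is bounded by the inverse Ackermann function).

```cpp
#include <cassert>
#include <vector>

// Hypothetical sketch: find with full path compression on a link array.
// The first pass locates the root; the second pass points every node on the
// path directly at the root, so later finds from those nodes take one step.
int find_compress(std::vector<int>& parent, int node)
{
  assert(node >= 0 && node < static_cast<int>(parent.size()));
  int root = node;
  while (root != parent[root])
  {
    root = parent[root];
  }
  while (node != root)
  {
    const int next = parent[node];
    parent[node] = root;
    node = next;
  }
  return root;
}
```

Given the chain {1, 2, 3, 4, 4}, a single call find_compress(parent, 0) returns 4 and flattens the array to {4, 4, 4, 4, 4}.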



© 2021 Mustafa Moiz.