18 August 2013

Union Find

I never understood the point of union find (not when I learned it in high school with Dr. Nevard, and not even after learning Kruskal's minimum spanning tree algorithm) until I saw this problem:

A journey to the Moon

The member states of the UN are planning to send two people to the Moon. But there is a problem. In line with their principles of global unity, they want to pair astronauts in such a way that the two are citizens of different countries.

There are N astronauts, numbered 0 to N-1. They are qualified and trained to be sent to the Moon. The trouble is that those in charge of the mission haven't been directly informed about the citizenship of each astronaut. The only information they have is that some particular pairs of astronauts belong to the same country.

Your task is to compute the number of ways they can pick a pair of astronauts satisfying the above criteria to send to the Moon. Assume that you are given enough pairs to identify the groups of astronauts, even though you might not know their countries directly. For instance, if 1, 2, 3 are astronauts from the same country, it is sufficient to say that (1, 2) and (2, 3) are pairs of astronauts from the same country, without also providing the third pair (1, 3).
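That last sentence is exactly what union find handles: "same country" is transitive, so merging (1, 2) and then (2, 3) has to put 1 and 3 in the same group even though (1, 3) is never mentioned. A minimal sketch of that, assuming a plain dict of parent pointers; the helper names `find` and `union` here are illustrative, not from the problem:

```python
def find(parent, x):
    # Follow parent links until reaching a node that is its own parent (the root).
    while parent[x] != x:
        x = parent[x]
    return x

def union(parent, a, b):
    # Hang a's root under b's root so both end up in the same set.
    parent[find(parent, a)] = find(parent, b)

# Astronauts 1, 2, 3 share a country; 0 is unrelated.
parent = {x: x for x in range(4)}
union(parent, 1, 2)
union(parent, 2, 3)

# (1, 3) was never given, but 1 and 3 now share a root.
print(find(parent, 1) == find(parent, 3))  # True
print(find(parent, 0) == find(parent, 1))  # False
```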

Input Format

The first line contains two integers, N and I, separated by a single space. I lines follow. Each line contains two integers, A and B, separated by a single space, such that

0 ≤ A, B ≤ N-1

and A and B are astronauts from the same country.
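Since the whole input is just whitespace-separated integers, one way to read it is to split everything into tokens and take the pairs two at a time. This is a sketch under assumptions: `parse_input` is a hypothetical helper, and it reads from a string rather than stdin so it is easy to try on the sample input given further down.

```python
def parse_input(text):
    # Hypothetical helper: reads the whole input from a string instead of stdin.
    tokens = [int(t) for t in text.split()]
    n, i = tokens[0], tokens[1]
    # The remaining 2*I integers are the A B pairs, one pair per line.
    pairs = [(tokens[2 + 2 * k], tokens[3 + 2 * k]) for k in range(i)]
    for a, b in pairs:
        # Enforce the stated bound 0 <= A, B <= N-1.
        if not (0 <= a < n and 0 <= b < n):
            raise ValueError(f"astronaut id out of range: {a} {b}")
    return n, pairs

n, pairs = parse_input("4 2\n0 1\n2 3\n")
print(n, pairs)  # 4 [(0, 1), (2, 3)]
```

Splitting on all whitespace ignores line boundaries, which is fine here because every line has a fixed number of integers.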

Output Format

A single integer: the number of permissible ways in which a pair of astronauts can be sent to the Moon.




Sample Input

4 2
0 1
2 3

Sample Output

4

Astronauts 0 and 1 belong to the same country, and so do astronauts 2 and 3. The UN can therefore choose one person out of {0, 1} and one out of {2, 3}, so the number of ways to choose a pair is 2 × 2 = 4.
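Once the group sizes s_1, ..., s_k are known, the same count can be reached from the other direction: take all N(N-1)/2 unordered pairs and subtract the s(s-1)/2 same-country pairs inside each group. For the sample that is 6 - 1 - 1 = 4. A small sketch comparing that complement formula with the direct "multiply sizes of every two groups" count; both function names are made up for illustration:

```python
from itertools import combinations

def pairs_from_sizes(sizes):
    # All unordered pairs of astronauts, minus the pairs that share a country.
    n = sum(sizes)
    total = n * (n - 1) // 2
    same = sum(s * (s - 1) // 2 for s in sizes)
    return total - same

def pairs_direct(sizes):
    # Sum |A| * |B| over every unordered pair of distinct groups.
    return sum(a * b for a, b in combinations(sizes, 2))

# Sample: groups {0, 1} and {2, 3}.
print(pairs_from_sizes([2, 2]))  # 4
print(pairs_direct([2, 2]))      # 4
```

The complement form needs only one pass over the sizes, which avoids the pairwise products entirely.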

This pre formatting sucks, but that's beside the point.

Here's my solution:

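class DisjointSet:
    def __init__(self, vals):
        # Every element starts out as the root of its own set.
        self.parents = {x: x for x in vals}

    def find(self, x):
        # Follow parent links up to the root (a node that is its own parent).
        root = x
        while self.parents[root] != root:
            root = self.parents[root]
        # Path compression: point every node on the path straight at the root.
        # Iterative, so long chains can't hit Python's recursion limit.
        while x != root:
            nxt = self.parents[x]
            self.parents[x] = root
            x = nxt
        return root

    def union(self, x, y):
        # Merge the two sets by hanging x's root under y's root.
        xRoot = self.find(x)
        yRoot = self.find(y)
        self.parents[xRoot] = yRoot

    def sets(self):
        # Group every element under its root: {root: set of members}.
        d = {}
        for child in self.parents:
            parent = self.find(child)
            if parent not in d:
                d[parent] = set()
            d[parent].add(child)
        return d

N, L = map(int, input().split())
ds = DisjointSet(range(N))
for _ in range(L):
    a, b = map(int, input().split())
    ds.union(a, b)

s = ds.sets()
acc = 0
if len(s) > 1:
    groups = list(s.values())
    # Suffix sums: cumsum[i] = total size of groups i, i+1, ..., last.
    cumsum = [len(x) for x in groups]
    for i in range(len(cumsum)-2, 0, -1):
        cumsum[i] += cumsum[i+1]
    # Pair every member of group i with every member of the groups after it.
    for i, x in enumerate(groups[:-1]):
        acc += len(x) * cumsum[i+1]
print(acc)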
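The `sets()` method builds every group explicitly, but the answer only depends on group sizes. A common refinement, sketched here as an assumption rather than the approach used above, is union by size: each root also stores its set's size, so no extra grouping pass is needed. `SizedDisjointSet` and `count_pairs` are hypothetical names, and this sketch swaps in path halving for full path compression.

```python
class SizedDisjointSet:
    # Union by size: each root also records how many members its set has.
    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n

    def find(self, x):
        # Path halving: each visited node skips to its grandparent.
        while self.parent[x] != x:
            self.parent[x] = self.parent[self.parent[x]]
            x = self.parent[x]
        return x

    def union(self, a, b):
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return
        # Attach the smaller tree under the larger one to keep trees shallow.
        if self.size[ra] < self.size[rb]:
            ra, rb = rb, ra
        self.parent[rb] = ra
        self.size[ra] += self.size[rb]

def count_pairs(n, pairs):
    ds = SizedDisjointSet(n)
    for a, b in pairs:
        ds.union(a, b)
    # Running total: each new group pairs with every astronaut counted before it.
    acc = 0
    seen = 0
    for x in range(n):
        if ds.find(x) == x:  # x is a root, so size[x] is its group's size
            acc += ds.size[x] * seen
            seen += ds.size[x]
    return acc

print(count_pairs(4, [(0, 1), (2, 3)]))  # 4
```

Combining union by size with path halving keeps each operation at near-constant amortized cost, and the running `seen` total replaces the separate suffix-sum array.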

Maybe someday I will get proper code formatting for my blog.
