背景🎣

上周六晚参加的LeetCode双周赛81场第二题2316. Count Unreachable Pairs of Nodes in an Undirected Graph提交了三次才过😓,主要手写查找函数的时候没有使用路径压缩(或者合并函数里面按秩合并也可以过的),看来自己对并查集还不熟悉,用得极少,特地写一篇博客梳理一下此算法。

原理📖

数学上的并查集(Disjoint sets)可以简单理解为若干个不相交的集合。而数据结构上的并查集(Disjoint-set data structure)一般是用森林来表示集合关系,每棵树是一个独立的集合,每个集合的标志就是树的root。底层存储可以用数组,其中每个元素的值指向其parent所在的位置,特别地,root的parent就是它自己

并查集主要实现了查找Find(找出某个元素属于哪个集合,即返回所在树的root)和合并Union(合并两个集合,即两棵树合并成一棵树)这两个基础操作,在此之上可以实现很多功能,如:

  • 判断两个元素e1e2是否属于一个集合

    Find(e1) == Find(e2)

  • 计算所有不相交集合的个数,即树的个数,类似于求总的连通分量的个数

    遍历所有的元素ei,计算不同Find(ei)结果的个数

    也可以动态维护一个变量

  • 计算某一个集合的元素个数,即树的大小,类似于求某个连通分量的大小

    遍历所有的元素ei,计算Find(ei) == target_root的个数

    也可以动态维护一个变量

实现💻

底层数据结构

用数组disjoint_sets来存储元素,i位置的元素ei=disjoint_sets[i]表示其parent的位置。

以上面说到的LeetCode 2316题的示例二做个解释:

LeetCode 2316 Example 2

对应的 disjoint_sets 可以为[0, 1, 0, 3, 0, 0, 1],其中0、2、4、5元素属于以0为root的集合,3元素属于以3为root的集合,1、6属于以1为root的集合。

初始化

新建数组的时候,直接把i位置的值设为i,表示每个元素的parent都是它自己,整个并查集包含n个互不相交的集合,整个森林有n棵只包含root的树。

1
2
3
void Initialize(int n){
for (int i = 0; i < n; i++) disjoint_sets[i] = i;
}

查找

对于一个i位置的元素ei,如果disjoint_sets[i] != i则说明它不是root,那么继续向上找它的parent,也就是disjoint_sets[i]位置的元素,直到找到为止。这里有递归和迭代两种写法。

递归写法

1
2
3
4
int Find(int i){
if (disjoint_sets[i] != i) return Find(disjoint_sets[i]);
else return i;
}

迭代写法

1
2
3
4
5
6
int Find(int i){
while (disjoint_sets[i] != i){
i = disjoint_sets[i];
}
return i;
}

合并

对于两个不属于同一集合的元素eiej,要将它们合并成一个集合。以把ej所属集合并进ei所属集合为例,只需要找到它们的root,再把ej的root指向ei的root即可。

1
2
3
void Union(int i, int j){
disjoint_sets[Find(j)] = Find(i);
}

路径压缩

上面的合并会产生一个问题,那就是在合并的过程中,整个树会越来越高,导致的后果是查找的时候效率会降低。如何解决这个问题呢?我们想要的效果是尽快找到一个元素所属的集合,也就是说树的高度越低越好。最理想的情况就是每个元素都指向其root,也就是说树的高度为2,其中root指向自己而其他元素直接指向root。

路径压缩就是在查找的过程中减小树高,尽量让元素都指向root。将查找中的两个函数简单修改一下即可实现路径压缩,代码如下:

递归写法

1
2
3
4
int Find(int i){
if (disjoint_sets[i] != i) disjoint_sets[i] = Find(disjoint_sets[i]);
return disjoint_sets[i];
}

迭代写法

1
2
3
4
5
6
7
int Find(int i){
while (disjoint_sets[i] != i){
disjoint_sets[i] = disjoint_sets[disjoint_sets[i]];
i = disjoint_sets[i];
}
return i;
}

这里有个点需要注意一下,递归写法保证了查找路径上所有的点的parent都被指向了root,但是迭代写法并没有,迭代写法是把当前节点的parent指向其parent的parent,也就是指向了祖父节点,这又被称为隔代压缩。

可见迭代写法的压缩程度是比不上递归写法的,但是考虑到递归的写法涉及到调用栈故性能开销是比迭代大的。当然,迭代写法也可以实现彻底压缩,那就是在找到root后重新遍历一遍把路径上所有的点都指向root。

按秩合并

上面的路径压缩在查找过程中压缩了树的高度,我们想一想在合并的过程中能不能也减少树高,假设元素e1属于树t1、元素e2属于树t2,现在要合并这两棵树,理想的做法是把高度较低的树合并到高度较高的树上,这样得到的新树高度较低。具体的做法就是先得到树t1t2的树高r1r2,假设r1<r2的话,那直接将树t1的root指向t2的root即可。除了把秩定义为树高外也有别的方法,比如说集合中点的个数。

到这里还我们缺少了一个计算树高的功能,一般可以开一个数组来存储每个元素所属树的高度从而实现动态更新,没有必要每次计算一遍,假设我们可以通过ranks[i]访问到树高(具体看下面的模板代码),下面先给出按秩合并的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
void Union(int i, int j){
int root_i = Find(i), root_j = Find(j);
if (root_i == root_j) return; // 特殊情况直接返回,否则下面ranks自增和count自减会有bug
if (ranks[root_i] > ranks[root_j]){
disjoint_sets[root_j] = root_i;
} else if (ranks[root_i] < ranks[root_j]){
disjoint_sets[root_i] = root_j;
} else{ // 相等的情况下怎么合并都可以,但是注意ranks存储的树高要加1
disjoint_sets[root_i] = root_j;
ranks[root_j]++; // 这里很容易写错,注意下是哪个树的树高加1
}
count--;
}

模板🧱

路径压缩版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
// Union Find Algorithm, [path compression] version
// Date: 2022.06.29
// Author: L. Bao
// Email: baoliay2008 [AT] gmail.com
// Website: lbao.site
class DisjointSets {
private:
vector<int> disjoint_sets;
int count;
public:
explicit DisjointSets(int n) {
count = n;
Initialize(n);
}

void Initialize(int n) {
disjoint_sets = vector<int>(n);
for (int i = 0; i < n; i++) disjoint_sets[i] = i;
}

int Find(int i) {
if (disjoint_sets[i] != i) disjoint_sets[i] = Find(disjoint_sets[i]);
return disjoint_sets[i];
}

void Union(int i, int j) {
int root_i = Find(i), root_j = Find(j);
if (root_i == root_j) return; // 特殊情况直接返回,否则下面count自减的时候会有bug
disjoint_sets[root_i] = root_j;
count--;
}

bool isConnected(int i, int j) {
return Find(i) == Find(j);
}

int getCount() const {
return count;
}
};
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class DisjointSets:
"""
Union Find Algorithm, [path compression] version
Date: 2022.06.29
Author: L. Bao
Email: baoliay2008 [AT] gmail.com
Website: lbao.site
"""
def __init__(self, n: int) -> None:
self._disjoint_sets = list(range(n))
self._count = n

def find(self, pos: int) -> int:
if (parent := self._disjoint_sets[pos]) != pos:
self._disjoint_sets[pos] = self.find(parent)
return self._disjoint_sets[pos]

def union(self, i: int, j: int) -> None:
if (root_i := self.find(i)) == (root_j := self.find(j)):
return
self._disjoint_sets[root_i] = root_j
self._count -= 1

def is_connected(self, i: int, j: int) -> bool:
return self.find(i) == self.find(j)

def get_count(self) -> int:
return self._count

1
待补充

按秩合并版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
// Union Find Algorithm, [union by rank] version
// Date: 2022.06.29
// Author: L. Bao
// Email: baoliay2008 [AT] gmail.com
// Website: lbao.site
class DisjointSetsWithRanks {
private:
vector<int> disjoint_sets, ranks;
int count;
public:
explicit DisjointSetsWithRanks(int n) {
count = n;
Initialize(n);
}

void Initialize(int n) {
disjoint_sets = vector<int>(n);
ranks = vector<int>(n, 1);
for (int i = 0; i < n; i++) disjoint_sets[i] = i;
}

int Find(int i) {
// 按秩合并版本Find函数可以不路径压缩,压缩过程中树高不好维护。
// 非要压缩也可以,那注意ranks秩的含义变了,不再是准确的树高,而是树高的上界。
while (disjoint_sets[i] != i) {
i = disjoint_sets[i];
}
return i;
}

void Union(int i, int j) {
int root_i = Find(i), root_j = Find(j);
if (root_i == root_j) return; // 特殊情况直接返回,否则下面ranks自增和count自减会有bug
if (ranks[root_i] > ranks[root_j]) {
disjoint_sets[root_j] = root_i;
} else if (ranks[root_i] < ranks[root_j]) {
disjoint_sets[root_i] = root_j;
} else { // 相等的情况下怎么合并都可以,但是注意ranks存储的树高要加1
disjoint_sets[root_i] = root_j;
ranks[root_j]++; // 这里很容易写错,注意下是哪个树的树高加1
}
count--;
}

bool isConnected(int i, int j) {
return Find(i) == Find(j);
}

int getCount() const {
return count;
}
};
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class DisjointSetsWithRanks:
"""
Union Find Algorithm, [union by rank] version
Date: 2022.06.29
Author: L. Bao
Email: baoliay2008 [AT] gmail.com
Website: lbao.site
"""
def __init__(self, n: int) -> None:
self._disjoint_sets = list(range(n))
self._rank = [1] * n
self._count = n

def find(self, pos: int) -> int:
while (parent := self._disjoint_sets[pos]) != pos:
pos = parent
return parent

def union(self, i: int, j: int) -> None:
if (root_i := self.find(i)) == (root_j := self.find(j)):
return
if self._rank[root_i] > self._rank[root_j]:
self._disjoint_sets[root_j] = root_i
elif self._rank[root_i] < self._rank[root_j]:
self._disjoint_sets[root_i] = root_j
else:
self._disjoint_sets[root_i] = root_j
self._rank[root_j] += 1
self._count -= 1

def is_connected(self, i: int, j: int) -> bool:
return self.find(i) == self.find(j)

def get_count(self) -> int:
return self._count

1
待补充

按大小合并版本

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
// Union Find Algorithm, [union by size] version
// Date: 2022.06.29
// Author: L. Bao
// Email: baoliay2008 [AT] gmail.com
// Website: lbao.site
class DisjointSetsWithSizes {
private:
vector<int> disjoint_sets, sizes;
int count;
public:
explicit DisjointSetsWithSizes(int n) {
count = n;
Initialize(n);
}

void Initialize(int n) {
disjoint_sets = vector<int>(n);
sizes = vector<int>(n, 1);
for (int i = 0; i < n; i++) disjoint_sets[i] = i;
}

int Find(int i) {
// 按秩合并版本Find函数可以不路径压缩
while (disjoint_sets[i] != i) {
i = disjoint_sets[i];
}
return i;
}

void Union(int i, int j) {
int root_i = Find(i), root_j = Find(j);
if (root_i == root_j) return; // 特殊情况直接返回
if (sizes[root_i] > sizes[root_j]) {
disjoint_sets[root_j] = root_i;
sizes[root_i] += sizes[root_j];
} else{ // 这里包含了相等的情况,怎么合并都可以
disjoint_sets[root_i] = root_j;
sizes[root_j] += sizes[root_i];
}
count--;
}

bool isConnected(int i, int j) {
return Find(i) == Find(j);
}

int getCount() const {
return count;
}

int getSize(int i) {
return sizes[Find(i)];
}
};
1
待补充
1
待补充

例题🌰

并查集的题太多了,官方题集请点击🔗LeetCode并查集,我这里就先以背景里面说的上周赛第二题为例子:

2316. Count Unreachable Pairs of Nodes in an Undirected Graph

题目问的是无法互相到达的点对数目,可以先求每个集合的大小,每个集合内部的点对都是可以相互到达的,总对数减去互达对数就是最后答案。所这个题转化成了求每棵树的大小,用一个哈希表遍历一下对不同root统计数量即可,很简单啦,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
// 省略并查集模板代码

class Solution {
public:
long long countPairs(int n, vector<vector<int>>& edges) {
// 初始化并查集,用按秩合并的模板也是可以的,这里换成 DisjointSetsWithRanks ds(n); 就行
DisjointSets ds(n);
// 哈希表计数
unordered_map<int, int> mp;
// 所有可能的点对数目
long long ans = (long long)n * (n - 1) >> 1;
// 合并集合
for (auto & edge: edges) ds.Union(edge[0], edge[1]);
// 每个集合的大小计数
for (int i = 0; i < n; i++) mp[ds.Find(i)]++;
// 减去每个集合内部可以互相到达的点对数目
for (auto &[k, v] : mp) ans -= (long long)v * (v - 1) >> 1 ;
return ans;
}
};
1
2
3
4
5
6
7
8
9
10
11
class Solution:
def countPairs(self, n: int, edges: List[List[int]]) -> int:
total = n * (n - 1) // 2
# 也可以用按秩合并版本的模板 ds = DisjointSetsWithRanks(n)
ds = DisjointSets(n)
for x, y in edges:
ds.union(x, y)
cnt = Counter(ds.find(i) for i in range(n))
for k1, v1 in cnt.items():
total -= v1 * (v1 - 1) // 2
return total
1
待补充

注意到,遇到求并查集大小的需求时,用按大小合并的方法更方便一点,因为sizes已经求出来了,下面用第三个模板实现本题:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
// 省略并查集模板代码

class Solution {
public:
long long countPairs(int n, vector<vector<int>>& edges) {
// 初始化并查集
DisjointSetsWithSizes ds(n);
// 所有可能的点对数目
long long ans = (long long)n * (n - 1) >> 1;
// 合并集合
for (auto & edge: edges) ds.Union(edge[0], edge[1]);
// 直接取根元素的size
for (int i = 0; i < n; i++)
if (ds.Find(i) == i) {
int v = ds.getSize(i);
ans -= (long long)v * (v - 1) >> 1 ;
}
return ans;
}
};
1
待补充
1
待补充