背景🎣 上周六晚参加的LeetCode双周赛81场 第二题2316. Count Unreachable Pairs of Nodes in an Undirected Graph 提交了三次才过😓,主要手写查找函数的时候没有使用路径压缩(或者合并函数里面按秩合并也可以过的),看来自己对并查集还不熟悉,用得极少,特地写一篇博客梳理一下此算法。
原理📖 数学上的并查集(Disjoint sets )可以简单理解为若干个不相交的集合。而数据结构上的并查集(Disjoint-set data structure )一般是用森林来表示集合关系,每棵树是一个独立的集合,每个集合的标志就是树的root。底层存储可以用数组,其中每个元素的值指向其parent所在的位置,特别地,root的parent就是它自己 。
并查集主要实现了查找Find
(找出某个元素属于哪个集合,即返回所在树的root)和合并Union
(合并两个集合,即两棵树合并成一棵树)这两个基础操作,在此之上可以实现很多功能,如:
实现💻 底层数据结构 用数组disjoint_sets
来存储元素,i
位置的元素ei=disjoint_sets[i]
表示其parent的位置。
以上面说到的LeetCode 2316题的示例二做个解释:
对应的 disjoint_sets
可以为[0, 1, 0, 3, 0, 0, 1]
,其中0、2、4、5元素属于以0为root的集合,3元素属于以3为root的集合,1、6属于以1为root的集合。
初始化 新建数组的时候,直接把i位置的值设为i,表示每个元素的parent都是它自己,整个并查集包含n个互不相交的集合,整个森林有n棵只包含root的树。
1 2 3 void Initialize (int n) { for (int i = 0 ; i < n; i++) disjoint_sets[i] = i; }
查找 对于一个i
位置的元素ei
,如果disjoint_sets[i] != i
则说明它不是root,那么继续向上找它的parent,也就是disjoint_sets[i]
位置的元素,直到找到为止。这里有递归和迭代两种写法。
递归写法
1 2 3 4 int Find (int i) { if (disjoint_sets[i] != i) return Find (disjoint_sets[i]); else return i; }
迭代写法
1 2 3 4 5 6 int Find (int i) { while (disjoint_sets[i] != i){ i = disjoint_sets[i]; } return i; }
合并 对于两个不属于同一集合的元素ei
和ej
,要将它们合并成一个集合。以把ej
所属集合并进ei
所属集合为例,只需要找到它们的root,再把ej
的root指向ei
的root即可。
1 2 3 void Union (int i, int j) { disjoint_sets[Find (j)] = Find (i); }
路径压缩 上面的合并会产生一个问题,那就是在合并的过程中,整个树会越来越高,导致的后果是查找的时候效率会降低。如何解决这个问题呢?我们想要的效果是尽快找到一个元素所属的集合,也就是说树的高度越低越好。最理想的情况就是每个元素都指向其root ,也就是说树的高度为2,其中root指向自己而其他元素直接指向root。
路径压缩就是在查找的过程中减小树高,尽量让元素都指向root。将查找中的两个函数简单修改一下即可实现路径压缩,代码如下:
递归写法
1 2 3 4 int Find (int i) { if (disjoint_sets[i] != i) disjoint_sets[i] = Find (disjoint_sets[i]); return disjoint_sets[i]; }
迭代写法
1 2 3 4 5 6 7 int Find (int i) { while (disjoint_sets[i] != i){ disjoint_sets[i] = disjoint_sets[disjoint_sets[i]]; i = disjoint_sets[i]; } return i; }
这里有个点需要注意一下,递归写法保证了查找路径上所有的点的parent都被指向了root,但是迭代写法并没有,迭代写法是把当前节点的parent指向其parent的parent,也就是指向了祖父节点,这又被称为隔代压缩。
可见迭代写法的压缩程度是比不上递归写法的,但是考虑到递归的写法涉及到调用栈故性能开销是比迭代大的。当然,迭代写法也可以实现彻底压缩,那就是在找到root后重新遍历一遍把路径上所有的点都指向root。
按秩合并 上面的路径压缩在查找 过程中压缩了树的高度,我们想一想在合并 的过程中能不能也减少树高,假设元素e1
属于树t1
、元素e2
属于树t2
,现在要合并这两棵树,理想的做法是把高度较低的树合并到高度较高的树上,这样得到的新树高度较低。具体的做法就是先得到树t1
、t2
的树高r1
、r2
,假设r1<r2
的话,那直接将树t1
的root指向t2
的root即可。除了把秩定义为树高外也有别的方法,比如说集合中点的个数。
到这里还我们缺少了一个计算树高的功能,一般可以开一个数组来存储每个元素所属树的高度从而实现动态更新,没有必要每次计算一遍,假设我们可以通过ranks[i]访问到树高(具体看下面的模板代码),下面先给出按秩合并的代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 void Union (int i, int j) { int root_i = Find (i), root_j = Find (j); if (root_i == root_j) return ; if (ranks[root_i] > ranks[root_j]){ disjoint_sets[root_j] = root_i; } else if (ranks[root_i] < ranks[root_j]){ disjoint_sets[root_i] = root_j; } else { disjoint_sets[root_i] = root_j; ranks[root_j]++; } count--; }
模板🧱 路径压缩版本 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 class DisjointSets {private : vector<int > disjoint_sets; int count; public : explicit DisjointSets (int n) { count = n; Initialize (n); } void Initialize (int n) { disjoint_sets = vector <int >(n); for (int i = 0 ; i < n; i++) disjoint_sets[i] = i; } int Find (int i) { if (disjoint_sets[i] != i) disjoint_sets[i] = Find (disjoint_sets[i]); return disjoint_sets[i]; } void Union (int i, int j) { int root_i = Find (i), root_j = Find (j); if (root_i == root_j) return ; disjoint_sets[root_i] = root_j; count--; } bool isConnected (int i, int j) { return Find (i) == Find (j); } int getCount () const { return count; } };
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 class DisjointSets : """ Union Find Algorithm, [path compression] version Date: 2022.06.29 Author: L. Bao Email: baoliay2008 [AT] gmail.com Website: lbao.site """ def __init__ (self, n: int ) -> None : self._disjoint_sets = list (range (n)) self._count = n def find (self, pos: int ) -> int : if (parent := self._disjoint_sets[pos]) != pos: self._disjoint_sets[pos] = self.find(parent) return self._disjoint_sets[pos] def union (self, i: int , j: int ) -> None : if (root_i := self.find(i)) == (root_j := self.find(j)): return self._disjoint_sets[root_i] = root_j self._count -= 1 def is_connected (self, i: int , j: int ) -> bool : return self.find(i) == self.find(j) def get_count (self ) -> int : return self._count
按秩合并版本 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 class DisjointSetsWithRanks {private : vector<int > disjoint_sets, ranks; int count; public : explicit DisjointSetsWithRanks (int n) { count = n; Initialize (n); } void Initialize (int n) { disjoint_sets = vector <int >(n); ranks = vector <int >(n, 1 ); for (int i = 0 ; i < n; i++) disjoint_sets[i] = i; } int Find (int i) { while (disjoint_sets[i] != i) { i = disjoint_sets[i]; } return i; } void Union (int i, int j) { int root_i = Find (i), root_j = Find (j); if (root_i == root_j) return ; if (ranks[root_i] > ranks[root_j]) { disjoint_sets[root_j] = root_i; } else if (ranks[root_i] < ranks[root_j]) { disjoint_sets[root_i] = root_j; } else { disjoint_sets[root_i] = root_j; ranks[root_j]++; } count--; } bool isConnected (int i, int j) { return Find (i) == Find (j); } int getCount () const { return count; } };
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 class DisjointSetsWithRanks : """ Union Find Algorithm, [union by rank] version Date: 2022.06.29 Author: L. Bao Email: baoliay2008 [AT] gmail.com Website: lbao.site """ def __init__ (self, n: int ) -> None : self._disjoint_sets = list (range (n)) self._rank = [1 ] * n self._count = n def find (self, pos: int ) -> int : while (parent := self._disjoint_sets[pos]) != pos: pos = parent return parent def union (self, i: int , j: int ) -> None : if (root_i := self.find(i)) == (root_j := self.find(j)): return if self._rank[root_i] > self._rank[root_j]: self._disjoint_sets[root_j] = root_i elif self._rank[root_i] < self._rank[root_j]: self._disjoint_sets[root_i] = root_j else : self._disjoint_sets[root_i] = root_j self._rank[root_j] += 1 self._count -= 1 def is_connected (self, i: int , j: int ) -> bool : return self.find(i) == self.find(j) def get_count (self ) -> int : return self._count
按大小合并版本 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 class DisjointSetsWithSizes {private : vector<int > disjoint_sets, sizes; int count; public : explicit DisjointSetsWithSizes (int n) { count = n; Initialize (n); } void Initialize (int n) { disjoint_sets = vector <int >(n); sizes = vector <int >(n, 1 ); for (int i = 0 ; i < n; i++) disjoint_sets[i] = i; } int Find (int i) { while (disjoint_sets[i] != i) { i = disjoint_sets[i]; } return i; } void Union (int i, int j) { int root_i = Find (i), root_j = Find (j); if (root_i == root_j) return ; if (sizes[root_i] > sizes[root_j]) { disjoint_sets[root_j] = root_i; sizes[root_i] += sizes[root_j]; } else { disjoint_sets[root_i] = root_j; sizes[root_j] += sizes[root_i]; } count--; } bool isConnected (int i, int j) { return Find (i) == Find (j); } int getCount () const { return count; } int getSize (int i) { return sizes[Find (i)]; } };
例题🌰 并查集的题太多了,官方题集请点击🔗LeetCode并查集 ,我这里就先以背景里面说的上周赛第二题为例子:
题目问的是无法互相到达的点对数目,可以先求每个集合的大小,每个集合内部的点对都是可以相互到达的,总对数减去互达对数就是最后答案。所这个题转化成了求每棵树的大小,用一个哈希表遍历一下对不同root统计数量即可,很简单啦,代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 class Solution {public : long long countPairs (int n, vector<vector<int >>& edges) { DisjointSets ds (n) ; unordered_map<int , int > mp; long long ans = (long long )n * (n - 1 ) >> 1 ; for (auto & edge: edges) ds.Union (edge[0 ], edge[1 ]); for (int i = 0 ; i < n; i++) mp[ds.Find (i)]++; for (auto &[k, v] : mp) ans -= (long long )v * (v - 1 ) >> 1 ; return ans; } };
1 2 3 4 5 6 7 8 9 10 11 class Solution : def countPairs (self, n: int , edges: List [List [int ]] ) -> int : total = n * (n - 1 ) // 2 ds = DisjointSets(n) for x, y in edges: ds.union(x, y) cnt = Counter(ds.find(i) for i in range (n)) for k1, v1 in cnt.items(): total -= v1 * (v1 - 1 ) // 2 return total
注意到,遇到求并查集大小的需求时,用按大小合并的方法更方便一点,因为sizes
已经求出来了,下面用第三个模板实现本题:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 class Solution {public : long long countPairs (int n, vector<vector<int >>& edges) { DisjointSetsWithSizes ds (n) ; long long ans = (long long )n * (n - 1 ) >> 1 ; for (auto & edge: edges) ds.Union (edge[0 ], edge[1 ]); for (int i = 0 ; i < n; i++) if (ds.Find (i) == i) { int v = ds.getSize (i); ans -= (long long )v * (v - 1 ) >> 1 ; } return ans; } };