学习笔记 DSU on tree(树上启发式合并)

树上启发式合并(DSU on Tree)用于解决一类离线的、询问子树信息的问题,它拥有 $\mathcal{O} (n \log n)$ 的优秀时间复杂度,并且常数小,实现难度低。

本文将结合例题介绍 DSU on tree 的基本实现方法。

CF600E Lomsat gelral

Description

题意:给定一棵大小为 $n(n \le 10 ^ 5)$ 的树,每个节点有一个颜色,求每个子树中出现次数最多的颜色编号之和。

Solution

如果 $n \le 10 ^ 3$,那我们怎么做?当然是暴力啦

考虑暴力,我们对于每个点,遍历其子树并统计颜色,这样的复杂度是 $\mathcal{O} (n ^ 2)$ 的,瓶颈在于每次进入不同子树时都要重新统计,导致很多信息重复计算。当然我们可以使用线段树合并将其优化到 $\mathcal{O} (n \log n)$,不过线段树写起来比较麻烦,且空间复杂度也比较大。而 DSU on tree 可以在同样的时间复杂度下以 $O(n)$ 的空间复杂度解决这个问题。

前面说了,暴力的瓶颈在于信息的重复统计,如果我们能将子节点的一部分信息带到父节点那就很好了。不难发现,对于一个点最后一个进入的子节点,其信息我们可以保留;换句话说,这最后一个子节点重新统计的时间我们可以省下。我们要做的当然是取统计起来最耗时间的子节点作为最后一个点统计。这是什么?重儿子,也就是 $\text{size}$ 最大的儿子。

这样做的时间复杂度是多少?可以证明是 $\mathcal {O} (n \log n)$,因为没有看过详细严谨的证明,这里只提出一个口胡证法放在文末。

有了复杂度的保证后,我们便可以用这样一个 “优雅的暴力” 通过此题,这就是 DSU on tree。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include <vector>
#include <cstdio>
#include <cstring>
#include <algorithm>

const int N = 1e5;
int n, c[N + 5], f[N + 5], siz[N + 5], son[N + 5], cnt[N + 5];
long long maxcnt, sum, ans[N + 5];
bool vis[N + 5];
std::vector <int> E[N + 5];

void Dfs1(int u, int fa) {
f[u] = fa, siz[u] = 1;
for (const int &v : E[u]) {
if (v == fa) continue;
Dfs1(v, u);
siz[u] += siz[v];
if (siz[v] > siz[son[u]]) son[u] = v;
}
}

void Update(int u, int val) {
cnt[c[u]] += val;
if (val > 0 && cnt[c[u]] >= maxcnt) {
if (cnt[c[u]] > maxcnt) maxcnt = cnt[c[u]], sum = c[u];
else sum += c[u];
}
for (const int &v : E[u])
if (v != f[u] && !vis[v])
Update(v, val);
}

void Dfs2(int u, bool heavy) {
for (const int &v : E[u])
if (v != f[u] && v != son[u])
Dfs2(v, false);
if (son[u] != 0) {
Dfs2(son[u], true);
vis[son[u]] = true;
}
Update(u, 1);
ans[u] = sum;
if (son[u] != 0) vis[son[u]] = false;
if (heavy == false) Update(u, -1), maxcnt = sum = 0;
}

int main() {
scanf("%d", &n);
for (int i = 1; i <= n; ++i)
scanf("%d", &c[i]);
for (int i = 1, u, v; i < n; ++i) {
scanf("%d%d", &u, &v);
E[u].push_back(v);
E[v].push_back(u);
}
Dfs1(1, 0);
Dfs2(1, false);
for (int i = 1; i <= n; ++i) printf("%lld%c", ans[i], " \n"[i == n]);
return 0;
}

CF208E Blood Cousins

Description

题意:给你一片森林,每次询问一个点与多少个点拥有共同的 K 级祖先

Solution

这题没有强制在线,我们可以将询问离线,用类似上一题的办法解决。由于一个点的 $K$ 级祖先是一定的,我们将原问题转化为:询问一个点的子树中有多少个点的深度为 $d + K$,其中 $d$ 为 $K$ 级祖先的深度。

这不就是裸题了吗?对于每个深度开一个桶统计即可。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#include <vector>
#include <cstdio>
#include <cstring>
#include <algorithm>

const int N = 1e5;
int n, m, f[N + 5][22], son[N + 5], sz[N + 5], ans[N + 5], cnt[N + 5], dep[N + 5];
bool vis[N + 5];
std::vector<int> rt, E[N + 5];
std::vector<std::pair<int, int>> Q[N + 5];

void Dfs1(int u) {
sz[u] = 1;
dep[u] = dep[f[u][0]] + 1;
for (int i = 1; i <= 20; ++i)
f[u][i] = f[f[u][i - 1]][i - 1];
for (const int &v : E[u]) {
Dfs1(v);
sz[u] += sz[v];
if (sz[v] > sz[son[u]]) son[u] = v;
}
}

void Update(int u, int val) {
cnt[dep[u]] += val;
for (const int &v : E[u])
if (!vis[v])
Update(v, val);
}

void Dfs2(int u, bool heavy) {
for (const int &v : E[u])
if (v != son[u])
Dfs2(v, false);
if (son[u] != 0) {
Dfs2(son[u], true);
vis[son[u]] = true;
}
Update(u, 1);
for (const auto &p : Q[u])
ans[p.second] = cnt[dep[u] + p.first] - 1;
vis[son[u]] = false;
if (!heavy)
Update(u, -1);
}

int main() {
scanf("%d", &n, &m);
for (int i = 1; i <= n; ++i) {
scanf("%d", &f[i][0]);
if (!f[i][0]) rt.push_back(i);
else E[f[i][0]].push_back(i);
}
for (const int &u : rt)
Dfs1(u);
scanf("%d", &m);
for (int i = 1, u, k; i <= m; ++i) {
scanf("%d%d", &u, &k);
for (int j = 20; j >= 0; --j)
if ((k >> j) & 1)
u = f[u][j];
Q[u].push_back(std::make_pair(k, i));
}
for (const int &u : rt)
Dfs2(u, false);
for (int i = 1; i <= m; ++i)
printf("%d%c", ans[i], " \n"[i == m]);
return 0;
}

CF741D Arpa’s letter-marked tree and Mehrdad’s Dokhtar-kosh paths

Description

一棵根为1 的树,每条边上有一个字符(a-v共22种)。 一条简单路径被称为 syk 的当且仅当路径上的字符经过重排后可以变成一个回文串。 求每个子树中最长的 syk 路径的长度。

Solution

不难发现,一个字符串能在重排后变成回文串,当且仅当其中出现奇数次的字符个数不大于 $1$,若我们用二进制整数表示字符集中每个字符出现次数的奇偶性(奇数为 $1$,偶数为 $0$),那么一个状态是合法的吗,当且仅当其为 $0$ 或 $2$ 的自然数次幂。

考虑 $\mathcal{O} (n ^ 2)$ 暴力:对于每个点 $u$。我们钦定路径一定经过 $u$,那么枚举子节点,对于 $22$ 种可能状态,用合法的路径的最大长度更新 $u$ 的答案。

上述做法使用 DSU on tree 可优化到 $\mathcal{O}(22 \times n \log n)$。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#include <vector>
#include <cstdio>
#include <cstring>
#include <algorithm>

inline void Splay(int &x) {
x = 0; char c = getchar();
while (c < 48 || c > 57) c = getchar();
while (c >= 48 && c <= 57) x = x * 10 + (c ^ 48), c = getchar();
}

const int N = 5e5, S = 22;
int n, f[N + 5], fa[N + 5], dep[N + 5], siz[N + 5], son[N + 5], ans[N + 5];
int Max[1 << S + 1];
std::vector<std::pair<int, int>> E[N + 5];

void Dfs(int u) {
siz[u] = 1;
dep[u] = dep[fa[u]] + 1;
for (const auto &p : E[u]) {
int v = p.first;
f[v] = f[u] ^ (1 << p.second);
Dfs(v);
siz[u] += siz[v];
if (siz[v] > siz[son[u]]) son[u] = v;
}
}

void Update(int u, int val) {
if (val) Max[f[u]] = std::max(Max[f[u]], dep[u]);
else Max[f[u]] = 0;
for (const auto &p : E[u])
Update(p.first, val);
}

int Calc(int u) {
int tmp = Max[f[u]] ? Max[f[u]] + dep[u] : 0;
for (int i = 0; i < 22; ++i)
if (Max[f[u] ^ (1 << i)]) tmp = std::max(tmp, Max[f[u] ^ (1 << i)] + dep[u]);
for (const auto &p : E[u])
tmp = std::max(tmp, Calc(p.first));
return tmp;
}

void Dfs(int u, bool heavy) {
for (const auto &p : E[u]) {
int v = p.first;
if (v == son[u]) continue;
Dfs(v, false);
ans[u] = std::max(ans[u], ans[v]);
}
if (son[u]) {
Dfs(son[u], true);
ans[u] = std::max(ans[u], ans[son[u]]);
}
for (const auto &p : E[u]) {
int v = p.first;
if (v == son[u]) continue;
ans[u] = std::max(ans[u], Calc(v) - 2 * dep[u]);
Update(v, 1);
}
if (Max[f[u]]) ans[u] = std::max(ans[u], Max[f[u]] - dep[u]);
for (int i = 0; i < 22; ++i)
if (Max[f[u] ^ (1 << i)]) ans[u] = std::max(ans[u], Max[f[u] ^ (1 << i)] - dep[u]);
Max[f[u]] = std::max(Max[f[u]], dep[u]);
if (!heavy) Update(u, 0);
}

int main() {
Splay(n);
for (int i = 2; i <= n; ++i) {
Splay(fa[i]);
char ch = getchar();
while (ch < 'a' || ch > 'v') ch = getchar();
E[fa[i]].push_back(std::make_pair(i, (int)ch - 'a'));
}
Dfs(1);
Dfs(1, false);
for (int i = 1; i <= n; ++i)
printf("%d%c", ans[i], " \n"[i == n]);
return 0;
}

Extended:复杂度证明

  • 我们给出轻重边的定义:一条边上深度较大的点若是重儿子,那么这条边是重边,否则为轻边。
  • 有一条重要性质:每个点到其根节点路径上的轻边数不超过 $\log n$ 条。这个的证明不难,因为每个轻儿子的子树大小不超过其父节点子树大小的一半,每条轻边都会使大小增加一倍,那么边数为 $\log n$ 级别。
  • 一个节点的子树在计算自身答案时会被遍历一次,若是轻儿子那么会在其父亲统计答案时遍历一次,因此每个点最多被遍历的次数同样是 $\log n$ 级别的,累加起来可以得出总复杂度为 $\mathcal{O} (n \log n)$。