UOJ 261

题解

如果直接考虑每条路径,显然有 $O(nm)$ 的做法。
这样的复杂度太高,考虑每个点得到的贡献。
记 $r_i = lca(s_i, t_i)$,路径 $i$ 对其上升部分(即链 $s_i \rightarrow r_i$)内的点 $u$ 贡献的答案为
$$[depth(s_i) - depth(u) = w_u]$$
而当 $u$ 在路径 $i$ 的下降部分(即链 $r_i \rightarrow t_i$)上时,该贡献为
$$[depth(u) - depth(r_i) + (depth(s_i) - depth[r_i]) = w_u]$$
考虑到产生贡献时 $w_u$ 恰为路径 $s_i \rightarrow u$ 上的边数,上两式是显然的。
注意到前一个式子移项可得
$$[depth(u) + w_u = depth(s_i)]$$
此式左边是关于点 $u$ 的函数,所以容易通过计算点 $u$ 的该函数值,判断任何路径(只考虑上升部分时)对该点 $u$ 是否有贡献。暴力计算仍然是 $O(nm)$ 的,但可以注意到对于路径 $i$,只需要考虑路径 $s_i \rightarrow t_i$ 上的点是否得到贡献。则问题容易转化为 $m$ 次在路径 $s_i \rightarrow t_i$ 上的所有点处插入一个数 $depth(s_i)$ 的操作,最后对每个 $u$ 点查询该点处 $depth(u) + w_u$ 出现的次数。对于路径 $i$ 的下降部分上的点,有类似的做法,下文只讨论上升部分点的贡献。
这个问题可以通过树上差分解决。有一个树上差分常用的做法:
考虑一种树上操作:对于链 $u \rightarrow v$(设 $u$ 是 $v$ 的祖先)上的每个点 $i$,令 $a[i] += x$ (初始时所有 $a[i] = 0$)。可以维护一个差分数组 $c[i]$ (初始全为 $0$),每次操作时令 $c[v] += x, c[fa(u)] -= x$($u$ 为树根时只需进行前者),则任意次操作后,以 $i$ 为根的子树内差分数组值的和即表示 $a[i]$ 的值,即
$$a[i] = \sum_{u \in subtree(i)} c[u]$$
对于这个问题,考虑对每条路径在树上维护差分标记。路径 $i$ 可以分为上升部分 $[s_i, r_i]$ 与 下降部分 $(r_i, t_i]$,对于上升部分,可以在 $s_i$ 处添加标记 $(depth(s_i), 1)$(表示加入 $1$ 个 $depth(s_i)$),在 $fa(r_i)$ 处添加标记 $(depth(s_i), -1)$(若 $r_i$ 为树根则不添加),下降部分的标记类似。差分后对每个点计算一次答案,总的复杂度为 $O(n^2)$。这个过程中有大量的重复计算,考虑用树的 DFS 序列优化计算。
每个子树都在树的 DFS 序列里对应一段连续的区间,记以 $u$ 为根的子树对于 DFS 序列里的区间 $[begin(u), end(u)]$ 。扫描一遍 DFS 序列并处理每个位置处的标记。维护数组 $cnt[]$,对于标记 $(v, d)$ 令 $cnt[v] += d$,即维护目前扫描到的部分的差分标记的前缀和。受空间限制,不能存储每个位置的前缀和。因此可以对每个位置 $i$ 分别记录令 $begin(u) = i$ 和令 $end(u) = i$ 的子树 $u$,用扫描到 $begin(u)$ 和扫描到 $end(u)$ 时刻的 $cnt(depth(u) + w_u)$ 作差即可得到子树内 $depth(u) + w_u$ 差分标记的和,即 $depth(u) + w_u$ 的出现次数。下降部分做法类似。扫描时注意应先记录 $begin(u)$ 时刻的 $cnt$ 值再用标记更新 $cnt$,从而保证 $begin(u)$ 位置的标记值被算入区间和。如果先更新再记录,也可以用 $begin(u) - 1$ 时刻的 $cnt$ 来作差。
使用 Tarjan 算法求 $r_i$ 并进行 DFS 的复杂度是 $O(n + m)$,扫描时最多处理 $O(m)$ 个标记,总的复杂度是 $O(n + m)$。本题也可以树链剖分并用类似的一维差分解决。  

代码

式子符号推错,调了一晚上 + 一早上。。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#include <cstdio>
#include <vector>
const int MAXV = 3e5 + 10, MAXE = 6e5 + 10, MAXM = 3e5 + 5;
int n, adj[MAXV], w[MAXV], to[MAXE], next[MAXE];
inline void add_edge(int u, int v, int e) {
to[e] = v;
next[e] = adj[u];
adj[u] = e;
}
int pa[MAXV];
inline int find(int x) {
return pa[x] == x ? x : pa[x] = find(pa[x]);
}
int fa[MAXV], depth[MAXV], begin[MAXV], stamp = 0, ans1[MAXV], ans2[MAXV];
std::vector<int> as_begin_of_node[MAXV], as_end_of_node[MAXV];
struct Mark {
int v, d;
};
std::vector<Mark> mark1[MAXV], mark2[MAXV];
struct Query {
int v, id;
};
int s[MAXM], t[MAXM], lca[MAXM];
std::vector<Query> query[MAXV];
inline void dfs(int u, int deep) {
as_begin_of_node[++stamp].push_back(u);
begin[u] = stamp;
depth[u] = deep;
pa[u] = u;
for (Query *q = &query[u].front(); q && q <= &query[u].back(); ++q)
if (fa[q -> v] && !lca[q -> id])
lca[q -> id] = find(q -> v);
for (int e = adj[u], v; e; e = next[e]) {
v = to[e];
if (!fa[v]) {
fa[v] = u;
dfs(v, deep + 1);
pa[v] = u;
}
}
as_end_of_node[stamp].push_back(u);
// end[u] = stamp;
}
char ch;
inline void read(int &res) {
res = 0;
ch = 0;
while (ch < '0' || ch > '9')
ch = getchar();
while (ch >= '0' && ch <= '9')
res = res * 10 + ch - '0', ch = getchar();
}
int cnt1[MAXV * 4], cnt2[MAXV * 4];
int main() {
int i, m, u, v;
read(n), read(m);
for (i = 1; i < n; ++i) {
read(u), read(v);
add_edge(u, v, i);
add_edge(v, u, n + i - 1);
}
for (i = 1; i <= n; ++i)
read(w[i]);
for (i = 1; i <= m; ++i) {
read(s[i]), read(t[i]);
query[s[i]].push_back(Query{t[i], i});
query[t[i]].push_back(Query{s[i], i});
}
fa[1] = -1;
dfs(1, 0);
for (i = 1; i <= m; ++i) {
mark1[begin[s[i]]].push_back((Mark){depth[s[i]], 1});
if (fa[lca[i]] != -1)
mark1[begin[fa[lca[i]]]].push_back((Mark){depth[s[i]], -1});
if (t[i] != lca[i]) {
mark2[begin[t[i]]].push_back((Mark){2 * depth[lca[i]] - depth[s[i]], 1});
mark2[begin[lca[i]]].push_back((Mark){2 * depth[lca[i]] - depth[s[i]], -1});
}
}
int *pu;
Mark *pm;
for (i = 1; i <= n; ++i) {
for (pu = &as_begin_of_node[i].front(); pu && pu <= &as_begin_of_node[i].back(); ++pu) {
u = *pu;
ans1[u] = cnt1[depth[u] + w[u] + n];
ans2[u] = cnt2[depth[u] - w[u] + n];
// w[j] = depth[s_i] - depth[j] (1)
// w[j] - (depth[s_i] - depth[lca_i]) = depth[j] - depth[lca_i] (2)
}
for (pm = &mark1[i].front(); pm && pm <= &mark1[i].back(); ++pm) {
cnt1[pm -> v + n] += pm -> d;
}
for (pm = &mark2[i].front(); pm && pm <= &mark2[i].back(); ++pm) {
cnt2[pm -> v + n] += pm -> d;
}
for (pu = &as_end_of_node[i].front(); pu && pu <= &as_end_of_node[i].back(); ++pu) {
u = *pu;
ans1[u] = cnt1[depth[u] + w[u] + n] - ans1[u];
ans2[u] = cnt2[depth[u] - w[u] + n] - ans2[u];
}
}
for (i = 1; i <= n; ++i)
printf("%d ", ans1[i] + ans2[i]);
printf("\n");
return 0;
}