树上点分治

点分治

树上点分治 其实就是把序列的分治方法移到了树上操作。序列上每个点的后继只有一个,树上可以有很多,我们找一个分支最多的出去,在不考虑常数的的情况下,这种分治方法是非常划算的。

静态分治

关于点分治的思想我不再赘述,很多博客已经讲得很清楚了。
我们实现上面代码,主要为下面几个函数。

寻找重心

在树上找重心的操作其实相当于在序列上找中点。

1
int mid = (l + r) >> 1;

而我们在树上是这样操作的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
void getroot(int u, int f)
{
size[u] = 1;
mxs[u] = 0;
for (int it = head[u]; it != -1; it = edge[it].nxt)
{
int v = edge[it].v;
if (vis[v] || v == f)
continue;
getroot(v, u);
size[u] += size[v];
mxs[u] = max(mxs[u], size[v]);
}
mxs[u] = max(mxs[u], tsize - size[u]);
if (mxs[u] < mxs[root])
root = u;
}

(其实就是树上跑个dp)

分治

得到重心之后,我们分治重心的每个分支,最后相当于pushup上来。
注意在这个函数里要进行一些找重心的初始化。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
void solve(int u, int f)
{
vis[u] = 1;
ans += get(u, 0); //治
for (int it = head[u]; it != -1; it = edge[it].nxt) //分
{
int v = edge[it].v;
if (v == f || vis[v])
continue;
ans -= get(v, edge[it].d); //治
root = 0;
tsize = size[v];
getroot(v, 0);
solve(root, u);
}
}

计算每个分支的答案

这一部分就是我们随意发挥的地方了,我们可以在这里进行一些复杂度约为O(n)计算,达到和序列分治类似甚至更优的效果。

POJ1741

以此题为例,我们来搞一波树上点分治。
在大部分博客中,分治部分代码对小子树大小的计算方法我觉得有问题,于是后来自己在原来我的AC代码上做了一番修改,貌似还是有几十毫秒的改进的。
虽然原来方法复杂度也不是很高,但是在那样的基础上计算会影响我们对树上动态点分治的理解,因为那样每次不一定找到的是重心。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#include<cstdio>
#include<algorithm>
#include<iostream>
using namespace std;
typedef long long ll;

const int N = 1e4 + 5;
struct node
{
int v, nxt, d;
}edge[2 * N];

int head[N], tot;
int size[N], root, mxs[N], vis[N], tsize;
int cnt[N], k;

void init(int n)
{
for (int i = 1; i <= n; i++)
head[i] = -1, vis[i] = 0;
root = 0;
tot = 0;
tsize = n;
mxs[0] = 0x3f3f3f3f;
}

void addedge(int u, int v, int d)
{
edge[++tot].v = v;
edge[tot].nxt = head[u];
edge[tot].d = d;
head[u] = tot;
}

void getroot(int u, int f)
{
size[u] = 1;
mxs[u] = 0;
for (int it = head[u]; it != -1; it = edge[it].nxt)
{
int v = edge[it].v;
if (vis[v] || v == f)
continue;
getroot(v, u);
size[u] += size[v];
mxs[u] = max(mxs[u], size[v]);
}
mxs[u] = max(mxs[u], tsize - size[u]);
if (mxs[u] < mxs[root])
root = u;
}

int cur = 0;
ll ans = 0;

void calc(int u, int f, int d)
{
cnt[++cur] = d;
size[u] = 1;
for (int it = head[u]; it != -1; it = edge[it].nxt)
{
int v = edge[it].v;
if (v == f || vis[v])
continue;
calc(v, u, d + edge[it].d);
size[u] += size[v];
}
}

int get(int u, int d)
{
cur = 0;
calc(u, 0, d);
sort(cnt + 1, cnt + cur + 1);
int head = 1, tail = cur;
int ret = 0;
while (head < tail)
{
if (cnt[head] + cnt[tail] <= k)
{
ret += tail - head;
head++;
}
else
tail--;
}
return ret;
}

void solve(int u, int f)
{
// dbg(u);
vis[u] = 1;
ans += get(u, 0);
for (int it = head[u]; it != -1; it = edge[it].nxt)
{
int v = edge[it].v;
if (v == f || vis[v])
continue;
ans -= get(v, edge[it].d);
root = 0;
tsize = size[v];
// puts("getroot");
// dbg(v, tsize);
getroot(v, 0);
solve(root, u);
}
}

template<class T>
void read(T& ret)
{
ret = 0;
char c;
while ((c = getchar()) > '9' || c < '0');
while (c >= '0' && c <= '9')
{
ret = ret * 10 + c - '0';
c = getchar();
}
}

int main()
{
int n;
while (true)
{
read(n);
read(k);
if (n == 0 && k == 0)
break;
init(n);
for (int i = 1; i < n; i++)
{
int u, v, d;
read(u);
read(v);
read(d);
addedge(u, v, d);
addedge(v, u, d);
}
ans = 0;
root = 0;
getroot(1, 0);
solve(root, 0);
printf("%lld\n", ans);
}
return 0;
}

动态点分治

有了前面的基础,我们来了解一下动态点分治。

动态在哪里

我们之所以将这种点分治和之前加以区分,是因为这一类题目会要求可以做一些修改,而显而易见我们不能每次暴力修改然后查询的时候用O(nlog(n))时间去查询,这是十分爆炸的。

解决方案

还是考虑分治的时候,如果我们要求修改某一个点,那么我们可以像线段树那样每次取一半,看看要修改的点在哪个区间,然后向上层push。
我们这样只去修改会受到影响的log(n)级别个区间,查询类似。
为了方便快捷地知道我们后面该往哪里走,先预处理出一棵点分树,在这棵树上是我们寻找重心的顺序,后面将会用到这棵树。

点分树

这棵树是由我们原本的树建出来的,但是树形又和原来不同。
点分树的父子关系,是由原树的分治顺序决定的,也就是说,我们寻找重心的时候的顺序,其实应该是点分树上的遍历顺序。既然这样,那我们的数其实也是比较好建立了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
inline void get_tree(int u, int f)
{
//dbg(u);
vis[u] = 1;
fa[u] = f;
for (int it = head[u]; it != -1; it = edge[it].nxt)
{
int v = edge[it].v;
if (vis[v])
continue;
root = 0;
tsz = size[v];
get_root(v, u, 1);
get_tree(root, u);
}
}

使用点分树

我们要这样一棵树有什么用呢?其实理解树上点分治的话,就会觉得这棵树其实是很有必要的。我们要维护这棵树上的一些父子关系。
分治的过程,我们要知道当前这一部分要往哪个点合并,知道方向我们才好处理。为了不用每次都去get_root,我们预先把所有点父子关系存下来。
就像我们在归并的时候向上返回,总要做些处理,如果当前区间已经覆盖了要修改或查询的点,直接返回,否则分半去修改或查询。
使用方法需要随机应变的。

BZOJ3730 震波

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#ifndef ONLINE_JUDGE
#define dbg(x...) do{cout << "\033[33;1m" << #x << "->" ; err(x);} while (0)
void err(){cout << "\033[39;0m" << endl;}
template<template<typename...> class T, typename t, typename... A>
void err(T<t> a, A... x){for (auto v: a) cout << v << ' '; err(x...);}
template<typename T, typename... A>
void err(T a, A... x){cout << a << ' '; err(x...);}
#else
#define dbg(...)
#endif
#define inf 1ll << 50
#define lowbit(x) ((x)&-(x))
const int N = 1e5 + 5;
struct node
{
int v, nxt;
}edge[2 * N];
int tot, head[N], val[N], vis[N];
int dep[N];
inline void add_edge(int u, int v)
{
edge[++tot].v = v;
edge[tot].nxt = head[u];
head[u] = tot;
}

int root, tsz, mxs[N], size[N], fa[N], dis[N];
vector<ll> sum[2][N];

inline void init(int n)
{
for (int i = 1; i <= n; i++)
{
head[i] = -1;
vis[i] = 0;
sum[0][i].clear();
sum[1][i].clear();
}
root = 0;
mxs[0] = 0x3f3f3f3f;
tsz = n;
tot = 0;
}
/* //树链剖分求lcaT掉了,后来改了ST表
int Dep[N], son[N], Top[N], pa[N];
inline void dfs1(int u, int f, int d)
{
Dep[u] = d;
son[u] = -1;
size[u] = 1;
pa[u] = f;
for (int it = head[u]; it != -1; it = edge[it].nxt)
{
int v = edge[it].v;
if (v == f)
continue;
dfs1(v, u, d + 1);
size[u] += size[v];
if (son[u] == -1 || size[v] > size[u])
son[u] = v;
}
}

inline void dfs2(int u, int f, int t)
{
Top[u] = t;
if (son[u] != -1)
dfs2(son[u], u, t);
for (int it = head[u]; it != -1; it = edge[it].nxt)
{
int v = edge[it].v;
if (v == f || v == son[u])
continue;
dfs2(v, u, v);
}
}
*/
int st[2 * N], cur, Dep[N];
int first[N];
inline void dfs(int u, int f)
{
st[++cur] = u;
first[u] = cur;
Dep[u] = Dep[f] + 1;
for (int it = head[u]; it != -1; it = edge[it].nxt)
{
int v = edge[it].v;
if (v == f)
continue;
dfs(v, u);
st[++cur] = u;
}
}
int dp[N * 2][20];
void rmq_pre()
{
for (int i = 1; i <= cur; i++)
dp[i][0] = st[i];
for (int j = 1; (1 << j) <= cur; j++)
for (int i = 1; i + (1 << j) <= cur + 1; i++)
{
int d1 = Dep[dp[i][j - 1]], d2 = Dep[dp[i + (1 << (j - 1))][j - 1]];
if (d1 < d2)
dp[i][j] = dp[i][j - 1];
else
dp[i][j] = dp[i + (1 << (j - 1))][j - 1];
}
}

inline int rmq(int l, int r)
{
int k = 31 - __builtin_clz(r - l + 1);
int d1 = dp[l][k], d2 = dp[r - (1 << k) + 1][k];
if (d1 < d2)
return dp[l][k];
return dp[r - (1 << k) + 1][k];
}

inline int lca(int u, int v)
{
/*
while (Top[u] != Top[v])
{
if (Dep[Top[u]] < Dep[Top[v]])
swap(u, v);
u = pa[Top[u]];
}
if (Dep[u] < Dep[v])
return u;
else
return v;
*/
if (first[u] > first[v])
swap(u, v);
return rmq(first[u], first[v]);
}

inline int Dis(int u, int v)
{
int ca = lca(u, v);
return Dep[u] + Dep[v] - 2 * Dep[ca];
}

inline void get_root(int u, int f, int d)
{
// dbg(u, d);
size[u] = 1;
mxs[u] = 0;
dep[u] = d;
for (int it = head[u]; it != -1; it = edge[it].nxt)
{
int v = edge[it].v;
if (vis[v] || v == f)
continue;
get_root(v, u, d + 1);
size[u] += size[v];
mxs[u] = max(mxs[u], size[v]);
}
mxs[u] = max(mxs[u], tsz - size[u]);
if (mxs[u] < mxs[root])
root = u;
}
/*
inline int lowbit(int x)
{
return x & (-x);
}
*/
inline void update(int bla, int id, int x, int val)
{
while (x < sum[bla][id].size())
{
sum[bla][id][x] += val;
x += lowbit(x);
}
}

inline ll query(int bla, int id, int x)
{
ll ans = 0;
if (x >= sum[bla][id].size())
x = sum[bla][id].size() - 1;
// dbg(id, x);
while (x > 0)
{
ans += sum[bla][id][x];
x -= lowbit(x);
}
return ans;
}

inline pair<int, int> get_dep(int u, int f, int d)
{
pair<int, int> ans = make_pair(d, dep[u]);
for (int it = head[u]; it != -1; it = edge[it].nxt)
{
int v = edge[it].v;
if (vis[v] || v == f)
continue;
pair<int, int> tmp = get_dep(v, u, d + 1);
ans.first = max(ans.first, tmp.first);
ans.second = max(ans.second, tmp.second);
}
return ans;
}

inline void calc(int u, int f, int top, int d)
{
size[u] = 1;
update(0, top, d, val[u]);
//dbg(top, d, val[u], u);
for (int it = head[u]; it != -1; it = edge[it].nxt)
{
int v = edge[it].v;
if (vis[v] || v == f)
continue;
calc(v, u, top, d + 1);
size[u] += size[v];
}
}

inline void calc2(int u, int f, int top)
{
update(1, top, dep[u], val[u]);
for (int it = head[u]; it != -1; it = edge[it].nxt)
{
int v = edge[it].v;
if (vis[v] || v == f)
continue;
calc2(v, u, top);
}
}

inline void get_tree(int u, int f)
{
//dbg(u);
vis[u] = 1;
fa[u] = f;
pair<int, int> mxd = get_dep(u, 0, 0);
sum[0][u].resize(mxd.first + 1);
sum[1][u].resize(mxd.second + 1);
calc2(u, f, u);
// dbg(u, ans);
for (int it = head[u]; it != -1; it = edge[it].nxt)
{
int v = edge[it].v;
if (vis[v])
continue;
calc(v, u, u, 1);
//dbg(u, v, ans);
root = 0;
tsz = size[v];
get_root(v, u, 1);
get_tree(root, u);
}
// update(1, u, dep[u], val[u]);
}

inline ll s_que(int u, int k)
{
ll ans = val[u] + query(0, u, k);
//dbg(u, k, query(0, u, k), ans);
int tmp = u;
while (fa[tmp])
{
int d1 = Dis(fa[tmp], u);
if (k - d1 >= 0)
ans += query(0, fa[tmp], k - d1) - query(1, tmp, k - d1) + (k - d1 >= 0) * val[fa[tmp]];
//dbg(tmp, fa[tmp], d1, k - d1, k, query(0, fa[tmp], k - d1), query(1, tmp, k - d1), ans);
tmp = fa[tmp];
}
return ans;
}

inline void s_update(int u, int v)
{
ll del = v - val[u];
int tmp = u;
while (fa[tmp])
{
int dis = Dis(u, fa[tmp]);
update(0, fa[tmp], dis, del);
update(1, tmp, dis, del);
tmp = fa[tmp];
}
val[u] = v;
}

template<class T>
void read(T& ret)
{
ret = 0;
char c;
while ((c = getchar()) > '9' | c < '0');
while (c >= '0' && c <= '9')
{
ret = ret * 10 + c - '0';
c = getchar();
}
}
int main()
{
int n, m;
read(n), read(m);
init(n);
//puts("init over");
for (int i = 1; i <= n; i++)
read(val[i]);
for (int i = 1; i < n; i++)
{
int u, v;
read(u);
read(v);
add_edge(u, v);
add_edge(v, u);
}
/*
dfs1(1, 0, 1);
dfs2(1, 0, 1);
*/
Dep[0] = 0;
dfs(1, 0);
rmq_pre();
get_root(1, 0, 1);
get_tree(root, 0);
ll ans = 0;
while (m--)
{
int type;
read(type);
if (type == 0)
{
int u, k;
read(u);
read(k);
u ^= ans;
k ^= ans;
printf("%lld\n", ans = s_que(u, k));
}
else
{
int u, v;
read(u);
read(v);
u ^= ans;
v ^= ans;
s_update(u, v);
}
}
return 0;
}
0%