ACM_JLINE.

树上修改

字数统计: 2.1k阅读时长: 11 min
2019/10/09 Share

树链剖分专题

这该死的OI题,常数卡那么紧,没有O2死活过不了64%的数据点,一直T。不过真的是好题,收藏了!!!

树上修改

题目:

给定一棵初始有 n 个点的有根树,点编号为 1 … n,其中 1 是树根。

定义两点 u, v 间的距离为 u, v之间最短路径上的边数。

m 次操作,每个操作是以下两种之一:

  • 1 f 表示在树上添加一个编号为当前点数 +1 的点,这个点在树上的父亲是 f。保证 f 是在树上存在的点。
  • 2 u l r表示询问以 u 为根的子树中,所有编号在 [l, r]中的点到 u 的距离之和。

输入格式:

第一行包括两个正整数 nm,表示初始点数和操作数。

接下来的 n − 1 行,每行包含两个正整数 u, v, 表示初始时 uv 在树上有一条边。

接下来的 m 行,每行包含若干个整数。第一个整数 op 表示操作种类。如果 op = 1,随后会输入一个正整数 f;如果 op = 2,随后会输入三个正整数 u, l, r

保证 op = 2 的操作中的 uop = 1的操作中的 f 不超过当前已有的点数。

保证 op = 2 的操作中 1 ≤ lrn + m

输出格式:

对每个 op = 2 的操作,输出一行表示答案。

数据范围:

对于前 16% 的数据,n ≤ 1e3,m ≤ 1e3。

对于另外 32% 的数据,保证对于所有 op = 2 的操作, l = 1。

对于前 64% 的数据,n ≤ 2e5 , m ≤ 2e5。

对于 100% 的数据,n ≤ 7e5, m ≤ 7e5。

解题思路:

对于查询操作,对于 u, v两点的距离,直接dv - du。di表示 i 到根节点的距离。因为u, v属于父子关系。

离线操作。前 32% 的数据,给了很好的思路铺垫,对于每个点的加入,r = i 的查询就可以计算答案了。

用树链剖分 + 线段树。每加入一个点v,对于v的每一个父节点u S[u] += d[v],Size[u]+=1。对于查询只要S[u] - d[u] * Size[u]。

前64%的数据也就是 l - r。方法一样,用 [1, r]的答案减去 [1, l -1]就行。复杂度O(n logn logn)

100% 的数据直接用dfs序,树状数组维护维护区间和,单点更新。真的超级有意思。

复杂度O(n * logn)

64%

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <vector>
#pragma GCC optimize(2)
#define ll long long
#define pii pair<int, int>
using namespace std;
const int N = 1400000 + 100;
const int M = 1400100;
int n, m, x, y;
//vector <int> V[N];
vector <pii> e[N];
int op[N], U[N], L[N], R[N], X[N];
int cnt;
ll ans[N];

int head[M * 2], Next[M * 2], tot, ver[M * 2];

inline int read(){
int s=0;
char ch=getchar();
while(ch<'0'||ch>'9'){ch=getchar();}
while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
return s;
}

int Size[M], fa[M], deep[M], son[M];
int top[M], dfn[M], cnt1;
void dfs1(int u, int pre){
Size[u] = 1;
son[u] = 0;
fa[u] = pre;
deep[u] = deep[pre] + 1;
for(int i = head[u]; i; i = Next[i]){
int v = ver[i];
if(v == pre) continue;
dfs1(v, u);
Size[u] += Size[v];
if(Size[v] > Size[son[u]]){
son[u] = v;
}
}
}

void dfs2(int u, int t){
top[u] = t;
dfn[u] = ++cnt1;
if(son[u]) dfs2(son[u], t);
for(int i = head[u]; i; i = Next[i]){
int v = ver[i];
if(v != son[u] && v != fa[u]){
dfs2(v, v);
}
}
}

struct node{
int l, r;
ll sum, add;
}tree[M * 4][2];

void pushup(int x, int id){
tree[x][id].sum = tree[x * 2][id].sum + tree[x * 2 + 1][id].sum ;
}

void pushdown(int x, int id){
if(tree[x][id].add){
tree[x * 2][id].sum += tree[x][id].add * (tree[x * 2][id].r - tree[x * 2][id].l + 1);
tree[x * 2 + 1][id].sum += tree[x][id].add * (tree[x * 2 + 1][id].r - tree[x * 2 + 1][id].l + 1);
tree[x * 2][id].add += tree[x][id].add;
tree[x * 2 + 1][id].add += tree[x][id].add;
tree[x][id].add = 0;
}
}

void build(int x, int l, int r, int id){
tree[x][id].l = l, tree[x][id].r = r, tree[x][id].add = 0;
if(l == r) {
tree[x][id].sum = 0;
return;
}
int mid = (l + r) / 2;
build(x * 2, l, mid, id);
build(x * 2 + 1, mid + 1, r, id);
pushup(x, id);
}

void add(int x, int l, int r, ll c, int id){
if(l <= tree[x][id].l && r >= tree[x][id].r){
tree[x][id].sum += c * (tree[x][id].r - tree[x][id].l + 1);
tree[x][id].add += c;
return;
}
pushdown(x, id);
int mid = (tree[x][id].l + tree[x][id].r) / 2;
if(l <= mid) add(x * 2, l, r, c, id);
if(r > mid) add(x * 2 + 1, l, r, c, id);
}

ll ask(int x, int l, int id){
if(tree[x][id].l == tree[x][id].r) return tree[x][id].sum;
pushdown(x, id);
int mid = (tree[x][id].l + tree[x][id].r) / 2;
ll v = 0;
if(l <= mid) v += ask(x * 2, l, id);
else v += ask(x * 2 + 1, l, id);
return v;
}

void change(int x, int y, int c, int id){
while(top[x] != top[y]){
if(deep[top[x]] < deep[top[y]]){
swap(x, y);
}
add(1, dfn[top[x]], dfn[x], c, id);
x = fa[top[x]];

}
if(deep[x] > deep[y]) swap(x, y);
add(1, dfn[x], dfn[y], c, id);
}

ll queryAns(int x, int id){
ll ans = 0;
ans += ask(1, dfn[x], id);
return ans;
}

void solve(int u){
int len = e[u].size();
for(int i = 0; i < len; i++){
int x = e[u][i].first, xi = e[u][i].second;
int u = U[x];
ans[x] += (ll)xi * (queryAns(u, 0) - queryAns(u, 1) * (ll)deep[u]);
}

}

void add_edge(int x, int y){
ver[++tot] = y;
Next[tot] = head[x];
head[x] = tot;
swap(x, y);
ver[++tot] = y;
Next[tot] = head[x];
head[x] = tot;
}

int main()
{
n = read(); m = read();
for(int i = 1; i < n; i++){
x = read(); y = read();
add_edge(x, y);
}
cnt = n;
for(int i = 1; i <= m; i++){
op[i] = read();
if(op[i] == 1){
X[i] = read();
++cnt;
add_edge(X[i], cnt);
}
else{
U[i] = read(); L[i] = read(); R[i] = read();
if(L[i] > cnt || L[i] > R[i]) continue;
R[i] = min(cnt, R[i]);
e[R[i]].push_back(pii(i, 1));
e[L[i] - 1].push_back(pii(i, -1));
}
}
dfs1(1, 1); dfs2(1, 1);
build(1, 1, cnt, 0); build(1, 1, cnt, 1);solve(0);
for(int i = 1; i <= cnt; i++){
change(1, fa[i], deep[i], 0);
change(1, fa[i], 1, 1);
solve(i);
}

for(int i = 1; i <= m; i++){
if(op[i] == 2){
printf("%lld\n", ans[i]);
}
}


return 0;
}

// 5 6
// 1 3
// 2 3
// 3 5
// 4 5
// 2 3 1 5
// 2 1 1 4
// 2 3 1 6
// 1 5
// 2 3 1 6
// 2 5 1 4

100%

注意开 n + m。点最多有1.4 * 107!!!

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <vector>
#pragma GCC optimize(2)
#define ll long long
#define pii pair<int, int>
using namespace std;
const int N = 1400000 + 100;
int n, m, x, y;
//vector <int> V[N];
vector <pii> e[N];
int op[N], U[N], L[N], R[N], X[N];
int cnt;
ll ans[N];

int head[N * 2], Next[N * 2], tot, ver[N * 2];

inline int read(){
int s=0;
char ch=getchar();
while(ch<'0'||ch>'9'){ch=getchar();}
while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
return s;
}

int cnt1, dfn[N], Last[N], deep[N], fa[N];
void dfs(int x, int pre){
fa[x] = pre; deep[x] = deep[pre] + 1; dfn[x] = ++cnt1;
for(int i = head[x]; i; i = Next[i]){
int v = ver[i];
if(v == pre) continue;
dfs(v, x);
}
Last[x] = cnt1;
}

ll c[N], c2[N];

void add1(int x, int y){
while(x < N){
c[x] += y;
x += x & -x;
}
}

ll ask1(int x){
ll ans = 0;
while(x){
ans += c[x];
x -= x & -x;
}
return ans;
}

void add2(int x, int y){
while(x < N){
c2[x] += y;
x += x & -x;
}
}

ll ask2(int x){
ll ans = 0;
while(x){
ans += c2[x];
x -= x & -x;
}
return ans;
}

void solve(int u){
int len = e[u].size();
for(int i = 0; i < len; i++){
int x = e[u][i].first, xi = e[u][i].second;
int u = U[x];
ans[x] += (ll)xi * (ask1(Last[u]) - ask1(dfn[u]) - (ask2(Last[u]) - ask2(dfn[u])) * (ll)deep[u]);
}

}

void add_edge(int x, int y){
ver[++tot] = y;
Next[tot] = head[x];
head[x] = tot;
swap(x, y);
ver[++tot] = y;
Next[tot] = head[x];
head[x] = tot;
}

int main()
{
n = read(); m = read();
for(int i = 1; i < n; i++){
x = read(); y = read();
add_edge(x, y);
}
cnt = n;
for(int i = 1; i <= m; i++){
op[i] = read();
if(op[i] == 1){
X[i] = read();
++cnt;
add_edge(X[i], cnt);
}
else{
U[i] = read(); L[i] = read(); R[i] = read();
if(L[i] > cnt || L[i] > R[i]) continue;
R[i] = min(cnt, R[i]);
e[R[i]].push_back(pii(i, 1));
e[L[i] - 1].push_back(pii(i, -1));
}
}
dfs(1, 0); solve(0);
for(int i = 1; i <= cnt; i++){
add1(dfn[i], deep[i]);
add2(dfn[i], 1);
solve(i);
}

for(int i = 1; i <= m; i++){
if(op[i] == 2){
printf("%lld\n", ans[i]);
}
}
return 0;
}
CATALOG
  1. 1. 树链剖分专题
    1. 1.1. 树上修改