ACM_JLINE.

set维护树上问题

字数统计: 3.3k 阅读时长: 19 min
2019/10/13 Share

9月提高组模拟赛题三

越写越好玩,不得不佩服出题人!!!官方题解写得特别用心!!!

set维护树上问题

题目:

给定一棵初始有 n 个点的有根树,点编号为 1 … n,其中 1 是树根,每个点 i 有权值 $a_i$。

对于每个点 u,设既不是 u 的祖先、也不是 u 的后代(u 本身也不算在内)的点构成的集合为 E。把 E 中的点按权值从小到大排序,求相邻两项权值之差的绝对值的最大值。如果集合 E 的点小于等于 2,直接输出 -1。(注:下面的暴力代码以及基于 set 的各份代码,实际判定都是"E 中不同的权值少于 2 个时输出 -1",相同的权值只计一次。)

输入格式:

第一行一个正整数 n,表示树上节点的数量。

第二行 n 个正整数,第 i 个正整数表示 $a_i$。

接下来的 n - 1 行,每行两个整数 u, v,表示编号为 u 的点和编号为 v 的点之间有一条边。

数据范围:

对于前 20% 的数据,n ≤ 1e3

另有 16% 的数据,n ≤ 2e4,树高不超过 25。

另有 16% 的数据,存在两个不同的整数 x, y,使得对于任意的 1 ≤ i ≤ n,$a_i = x$ 或 $a_i = y$ 成立。

对于 100% 的数据,n ≤ 5e4, |ai| ≤ 1e9 。

解题思路:

20% 直接暴力

参考代码:
#include <iostream>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cstring>
#include <cstdio>
#define ll long long
using namespace std;
// Array capacity; also large enough for the 2 * (n - 1) directed edges.
const int N = 1e5 + 5;

// Edge lists: head[u] is u's first edge, ver[e] the edge's endpoint,
// Next[e] the following edge of the same node.
int head[N], ver[N], Next[N];
int tot;   // number of edges stored so far
int cnt;   // DFS counter used to assign dfn
// vis[u] == 1 while u lies on the current root -> x path in dfs2.
int vis[N];
int dfn[N];   // DFS entry order
int sz[N];    // subtree sizes
int a[N];     // node weights
int ans[N];   // per-node answers
int n;        // number of nodes

// Reads one (optionally negative) decimal integer from stdin.
// Separators before the number are skipped; a '-' seen while skipping makes
// the value negative. (Previously the sign was flipped on *every* skipped
// character, so any number preceded by a space/newline came back negated.)
// Stops at EOF instead of looping forever; returns 0 if no digits follow.
inline int read(){
    int s = 0, w = 1;
    int ch = getchar();   // int, so EOF is distinguishable from a real char
    while(ch != EOF && (ch < '0' || ch > '9')){
        if(ch == '-') w = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9'){
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return s * w;
}

// Pushes the directed edge x -> y onto the front of x's edge list.
void add(int x, int y){
    tot++;
    ver[tot] = y;          // endpoint of the new edge
    Next[tot] = head[x];   // chain to x's previous first edge
    head[x] = tot;         // the new edge is now x's first edge
}

// Numbers x's subtree in DFS entry order (dfn) and records subtree sizes (sz).
// pre is x's parent, 0 for the root.
void dfs(int x, int pre){
    cnt++;
    dfn[x] = cnt;
    int total = 1;
    for(int e = head[x]; e != 0; e = Next[e]){
        int child = ver[e];
        if(child != pre){
            dfs(child, x);
            total += sz[child];
        }
    }
    sz[x] = total;
}

void dfs2(int x, int pre){
vector <int> V;
V.clear();
vis[x] = 1;
for(int i = 1; i <= n; i++){
if(vis[i] == 1 || (dfn[i] >= dfn[x] && dfn[i] < dfn[x] + sz[x])){
continue;
}
V.push_back(a[i]);
}
sort(V.begin(), V.end());
V.erase(unique(V.begin(), V.end()), V.end());
if(V.size() <= 1){
ans[x] = -1;
}
else{
ans[x] = -2e9;
for(int i = 0; i + 1 < V.size(); i++){
ans[x] = max(ans[x], V[i + 1] - V[i]);
}
}
for(int i = head[x]; i; i = Next[i]){
int y = ver[i];
if(y == pre) continue;
dfs2(y, x);
}
vis[x] = 0;
}

int main()
{
n = read();
for(int i = 1; i <= n; i++){
int x = read();
a[i] = x;
}
for(int i = 1; i < n; i++){
int x = read(), y = read();
add(x, y); add(y, x);
}
dfs(1, 0);
dfs2(1, 0);

for(int i = 1; i <= n; i++){
printf("%d\n", ans[i]);
}
return 0;
}

另外16%

用一个 set 维护集合 E, 另外一个 set T 维护相邻两个数的绝对值。首先把所有点都加入集合 E,同时维护好 T。

关键在于点进出集合不会超过 50n 次。

定义 d[i] 表示点 i 到 1 号点路径上的点数,sz[i] 表示点 i 子树中点的个数。

因为 $\sum_{i=1}^{n} d_i = \sum_{i=1}^{n} sz_i$(两边数的都是"祖先—后代"点对,包括点自身与自身),而树高不超过 25 时 $\sum_{i=1}^{n} d_i \le 25n$,所以 $\sum_{i=1}^{n} (sz_i + d_i) \le 50n$。时间复杂度 $O(nh\log n)$,其中 $h$ 为树高。

  1. 两个集合版本
参考代码:
#include <iostream>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cstring>
#include <cstdio>
#include <set>
#include <queue>
#pragma GCC optimize(2)
#define ll long long
using namespace std;
// Array capacity; also large enough for the 2 * (n - 1) directed edges.
const int N = 1e5 + 5;
// Sentinel outside the weight range allowed by the statement (|a_i| <= 1e9).
const int Max = 2e9 + 10;

int head[N], ver[N], Next[N];   // edge lists
int tot, cnt;                   // edge count, DFS counter
int vis[N];
int dfn[N], rdfn[N];            // DFS order and its inverse
int sz[N], fa[N];               // subtree size, parent (0 for the root)
int a[N], ans[N];
int id[N];                      // 1-based rank of a[i] in sorted V
int num[N];                     // how many nodes of each rank are currently in E
int n;
vector <int> V;                 // sorted copy of all weights
set <int> s, t;                 // s: distinct weights currently in E

// Reads one (optionally negative) decimal integer from stdin.
// Separators before the number are skipped; a '-' seen while skipping makes
// the value negative. (Previously the sign was flipped on *every* skipped
// character, so any number preceded by a space/newline came back negated.)
// Stops at EOF instead of looping forever; returns 0 if no digits follow.
inline int read(){
    int s = 0, w = 1;
    int ch = getchar();   // int, so EOF is distinguishable from a real char
    while(ch != EOF && (ch < '0' || ch > '9')){
        if(ch == '-') w = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9'){
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return s * w;
}

// Pushes the directed edge x -> y onto the front of x's edge list.
void add(int x, int y){
    tot++;
    ver[tot] = y;          // endpoint of the new edge
    Next[tot] = head[x];   // chain to x's previous first edge
    head[x] = tot;         // the new edge is now x's first edge
}

// Pre-order numbering of x's subtree: dfn[x] is x's position, rdfn maps a
// position back to its node, fa[x] = pre, sz[x] = subtree size.
void dfs(int x, int pre){
    cnt++;
    dfn[x] = cnt;
    rdfn[cnt] = x;
    fa[x] = pre;
    sz[x] = 1;
    for(int e = head[x]; e; e = Next[e]){
        int child = ver[e];
        if(child == pre) continue;
        dfs(child, x);
        sz[x] += sz[child];
    }
}

// Largest element of s strictly below x, or -Max when there is none.
int getPre(int x){
    set <int> :: iterator it = s.lower_bound(x);
    if(it != s.begin()){
        --it;
        return *it;
    }
    return -Max;
}

// Smallest element of s that is >= x, or Max when there is none.
int getNxt(int x){
    set <int> :: iterator it = s.lower_bound(x);
    return it != s.end() ? *it : Max;
}

// Multiset of the gaps between neighbouring distinct weights of E.
// The original kept these in `set<int> t`, which collapses equal gaps:
// erasing one of two equal gaps then dropped both, and the reported maximum
// became too small (e.g. removing 0 from {0,5,10,11} lost the gap 5 -> 10).
// A multiset keeps one entry per gap, and subT erases exactly one of them.
multiset <int> gaps;

// Records a new gap of size x.
void addT(int x){
    gaps.insert(x);
}

// Forgets one gap of size x (no-op if none is recorded).
void subT(int x){
    multiset <int> :: iterator it = gaps.find(x);
    if(it != gaps.end()) gaps.erase(it);
}

// Largest current gap, or -1 when E has fewer than two distinct weights.
int solve(){
    if(gaps.empty()) return -1;
    return *gaps.rbegin();
}

void addE(int u){
int x = V[u - 1];
int l = getPre(x), r = getNxt(x);
if(l != -Max && r != Max) subT(r - l);
if(l != -Max) addT(x - l);
if(r != Max) addT(r - x);
s.insert(x);
}

void subE(int u){
int x = V[u - 1];
s.erase(x);
int l = getPre(x), r = getNxt(x);
if(l != -Max && r != Max) addT(r - l);
if(l != -Max) subT(x - l);
if(r != Max) subT(r - x);
}

int main()
{
n = read();
for(int i = 1; i <= n; i++){
int x = read();
a[i] = x;
V.push_back(x);
}
sort(V.begin(), V.end());
for(int i = 1; i <= n; i++){
id[i] = lower_bound(V.begin(), V.end(), a[i]) - V.begin() + 1;
}

for(int i = 1; i < n; i++){
int x = read(), y = read();
add(x, y); add(y, x);
}
dfs(1, 0);

for(int i = 1; i <= n; i++){
if(num[id[i]] == 0){
addE(id[i]);
}
num[id[i]]++;
}

for(int i = 1; i <= n; i++){
int x = i;
while(fa[x]){
x = fa[x];
--num[id[x]];
if(num[id[x]] == 0){
subE(id[x]);
}
}
for(int j = dfn[i]; j < dfn[i] + sz[i]; j++){
int x = rdfn[j];
--num[id[x]];
if(num[id[x]] == 0){
subE(id[x]);
}
}

printf("%d\n", solve());
x = i;

while(fa[x]){
x = fa[x];
if(num[id[x]] == 0){
addE(id[x]);
}
++num[id[x]];
}
for(int j = dfn[i]; j < dfn[i] + sz[i]; j++){
int x = rdfn[j];
if(num[id[x]] == 0){
addE(id[x]);
}
++num[id[x]];
}

}
return 0;
}
  2. 集合 T 改用优先队列懒删除,常数小很多(亲测)!
参考代码:
#include <iostream>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cstring>
#include <cstdio>
#include <set>
#include <queue>
#define ll long long
using namespace std;
// Array capacity; also large enough for the 2 * (n - 1) directed edges.
const int N = 1e5 + 5;
// Sentinel outside the weight range allowed by the statement (|a_i| <= 1e9).
const int Max = 2e9 + 10;

int head[N], ver[N], Next[N];   // edge lists
int tot, cnt;                   // edge count, DFS counter
int vis[N];
int dfn[N], rdfn[N];            // DFS order and its inverse
int sz[N], fa[N];               // subtree size, parent (0 for the root)
int a[N], ans[N];
int id[N];                      // 1-based rank of a[i] in sorted V
int num[N];                     // how many nodes of each rank are currently in E
int n;
vector <int> V;                 // sorted copy of all weights
set <int> s;                    // distinct weights currently in E
// q1 holds every gap ever added, q2 the gaps marked as removed;
// solve() cancels matching tops (lazy deletion).
priority_queue <int> q1, q2;

// Reads one (optionally negative) decimal integer from stdin.
// Separators before the number are skipped; a '-' seen while skipping makes
// the value negative. (Previously the sign was flipped on *every* skipped
// character, so any number preceded by a space/newline came back negated.)
// Stops at EOF instead of looping forever; returns 0 if no digits follow.
inline int read(){
    int s = 0, w = 1;
    int ch = getchar();   // int, so EOF is distinguishable from a real char
    while(ch != EOF && (ch < '0' || ch > '9')){
        if(ch == '-') w = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9'){
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return s * w;
}

// Pushes the directed edge x -> y onto the front of x's edge list.
void add(int x, int y){
    tot++;
    ver[tot] = y;          // endpoint of the new edge
    Next[tot] = head[x];   // chain to x's previous first edge
    head[x] = tot;         // the new edge is now x's first edge
}

// Pre-order numbering of x's subtree: dfn[x] is x's position, rdfn maps a
// position back to its node, fa[x] = pre, sz[x] = subtree size.
void dfs(int x, int pre){
    cnt++;
    dfn[x] = cnt;
    rdfn[cnt] = x;
    fa[x] = pre;
    sz[x] = 1;
    for(int e = head[x]; e; e = Next[e]){
        int child = ver[e];
        if(child == pre) continue;
        dfs(child, x);
        sz[x] += sz[child];
    }
}

// Largest element of s strictly below x, or -Max when there is none.
int getPre(int x){
    set <int> :: iterator it = s.lower_bound(x);
    if(it != s.begin()){
        --it;
        return *it;
    }
    return -Max;
}

// Smallest element of s that is >= x, or Max when there is none.
int getNxt(int x){
    set <int> :: iterator it = s.lower_bound(x);
    return it != s.end() ? *it : Max;
}

// Records a gap between two neighbouring weights of E.
void addT(int gap){
    q1.push(gap);
}

// Marks one copy of a gap as removed; solve() discards it lazily.
void subT(int gap){
    q2.push(gap);
}

// Largest live gap: first cancels tops of q1 that are pending deletion in q2.
// Returns -1 when no gap is live (fewer than two distinct weights in E).
int solve(){
    while(!q2.empty() && q2.top() == q1.top()){
        q1.pop();
        q2.pop();
    }
    return q1.empty() ? -1 : q1.top();
}

void addE(int u){
int x = V[u - 1];
int l = getPre(x), r = getNxt(x);
if(l != -Max && r != Max) subT(r - l);
if(l != -Max) addT(x - l);
if(r != Max) addT(r - x);
s.insert(x);
}

void subE(int u){
int x = V[u - 1];
s.erase(x);
int l = getPre(x), r = getNxt(x);
if(l != -Max && r != Max) addT(r - l);
if(l != -Max) subT(x - l);
if(r != Max) subT(r - x);
}

int main()
{
n = read();
for(int i = 1; i <= n; i++){
int x = read();
a[i] = x;
V.push_back(x);
}
sort(V.begin(), V.end());
for(int i = 1; i <= n; i++){
id[i] = lower_bound(V.begin(), V.end(), a[i]) - V.begin() + 1;
}

for(int i = 1; i < n; i++){
int x = read(), y = read();
add(x, y); add(y, x);
}
dfs(1, 0);

for(int i = 1; i <= n; i++){
if(num[id[i]] == 0){
addE(id[i]);
}
num[id[i]]++;
}

for(int i = 1; i <= n; i++){
int x = i;
while(fa[x]){
x = fa[x];
--num[id[x]];
if(num[id[x]] == 0){
subE(id[x]);
}
}
for(int j = dfn[i]; j < dfn[i] + sz[i]; j++){
int x = rdfn[j];
--num[id[x]];
if(num[id[x]] == 0){
subE(id[x]);
}
}

printf("%d\n", solve());
x = i;

while(fa[x]){
x = fa[x];
if(num[id[x]] == 0){
addE(id[x]);
}
++num[id[x]];
}
for(int j = dfn[i]; j < dfn[i] + sz[i]; j++){
int x = rdfn[j];
if(num[id[x]] == 0){
addE(id[x]);
}
++num[id[x]];
}

}
return 0;
}

另外16%

这个比较好实现,记录x, y的个数。dfs解决。

参考代码:
#include <iostream>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cstring>
#include <cstdio>
#include <set>
#include <queue>
#define ll long long
using namespace std;
// Array capacity; also large enough for the 2 * (n - 1) directed edges.
const int N = 1e5 + 5;
const int Max = 2e9 + 10;

int head[N], ver[N], Next[N], tot, cnt;   // edge lists, edge count, DFS counter
int vis[N], dfn[N], rdfn[N];
// Per-subtree counts: sz1 = nodes weighing a[1], sz2 = all other nodes.
int sz1[N], sz2[N];
int fa[N];
int a[N], ans[N], id[N], num[N];
int n;
int exx;          // a weight different from a[1] (0 if every weight equals a[1])
int num1, num2;   // totals of the two weight classes over the whole tree
int tmp1, tmp2;   // the same counts along the current root -> x path in dfs1

// Reads one (optionally negative) decimal integer from stdin.
// Separators before the number are skipped; a '-' seen while skipping makes
// the value negative. (Previously the sign was flipped on *every* skipped
// character, so any number preceded by a space/newline came back negated.)
// Stops at EOF instead of looping forever; returns 0 if no digits follow.
inline int read(){
    int s = 0, w = 1;
    int ch = getchar();   // int, so EOF is distinguishable from a real char
    while(ch != EOF && (ch < '0' || ch > '9')){
        if(ch == '-') w = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9'){
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return s * w;
}

// Pushes the directed edge x -> y onto the front of x's edge list.
void add(int x, int y){
    tot++;
    ver[tot] = y;          // endpoint of the new edge
    Next[tot] = head[x];   // chain to x's previous first edge
    head[x] = tot;         // the new edge is now x's first edge
}

// Records dfn/rdfn/fa for x's subtree and counts, per subtree, how many
// nodes weigh a[1] (sz1) and how many weigh something else (sz2).
void dfs(int x, int pre){
    cnt++;
    dfn[x] = cnt;
    rdfn[cnt] = x;
    fa[x] = pre;
    (a[x] == a[1] ? sz1 : sz2)[x] = 1;
    for(int e = head[x]; e; e = Next[e]){
        int child = ver[e];
        if(child == pre) continue;
        dfs(child, x);
        sz1[x] += sz1[child];
        sz2[x] += sz2[child];
    }
}

void dfs1(int x, int pre){
if(a[x] == a[1]) tmp1++;
else tmp2++;

int l1 = num1 - sz1[x] - tmp1, l2 = num2 - sz2[x] - tmp2;
if(a[x] == a[1]) l1++; else l2++;
//cout << "x : " << x << " " << l1 << " " << l2 << endl;
if(l1 + l2 <= 1){
ans[x] = -1;
}
else if(l1 == 0 || l2 == 0){
ans[x] = 0;
}
else ans[x] = abs(a[1] - exx);

for(int i = head[x]; i; i = Next[i]){
int y = ver[i];
if(y == pre) continue;
dfs1(y, x);
}
if(a[x] == a[1]) tmp1--;
else tmp2--;
}


int main()
{
n = read();
for(int i = 1; i <= n; i++){
int x = read();
a[i] = x;
if(a[i] != a[1]) exx = a[i];
}

for(int i = 1; i <= n; i++){
if(a[i] == a[1]){
num1++;
}
else num2++;
}

for(int i = 1; i < n; i++){
int x = read(), y = read();
add(x, y); add(y, x);
}
dfs(1, 0);
dfs1(1, 0);
for(int i = 1; i <= n; i++){
printf("%d\n", ans[i]);
}
return 0;
}

100%

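在方法二的基础上改进。点 i 属于 E(u),当且仅当 u 既不在 i 的子树里(含 i 本身),也不是 i 的祖先。用树链剖分把 i 的祖先链拆成 O(log n) 段 dfn 区间,连同 i 的子树区间一起挖掉,剩下的若干段区间就是 i 需要出现的 dfn 位置。再按 dfn 做扫描线,用方法二的 set + 懒删除堆维护答案。妙呀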

参考代码:
#include <iostream>
#include <cmath>
#include <algorithm>
#include <vector>
#include <cstring>
#include <cstdio>
#include <set>
#include <queue>
#define ll long long
#define pii pair<int, int>
using namespace std;
// Array capacity; also large enough for the 2 * (n - 1) directed edges.
const int N = 1e5 + 5;
// Sentinel outside the weight range allowed by the statement (|a_i| <= 1e9).
const int Max = 2e9 + 10;

int head[N], ver[N], Next[N], tot, cnt;   // edge lists, edge count, DFS counter
int vis[N];
int dfn[N], rdfn[N];     // heavy-first DFS order and its inverse
int sz[N], fa[N];        // subtree size, parent (0 for the root)
int son[N], top[N];      // heavy child, head of the node's heavy chain
int a[N];
int ans[N];              // answers indexed by dfn position
int id[N];               // 1-based rank of a[i] in sorted V
int num[N];              // live count per rank during the sweep
int n;
vector <int> V;          // sorted copy of all weights
// Per dfn position: ranks that enter (ans1) / leave (ans2) the sweep there.
vector <int> ans1[N], ans2[N];
set <int> s;             // distinct weights currently live
priority_queue <int> q1, q2;   // live gaps / lazily deleted gaps

// Reads one (optionally negative) decimal integer from stdin.
// Separators before the number are skipped; a '-' seen while skipping makes
// the value negative. (Previously the sign was flipped on *every* skipped
// character, so any number preceded by a space/newline came back negated.)
// Stops at EOF instead of looping forever; returns 0 if no digits follow.
inline int read(){
    int s = 0, w = 1;
    int ch = getchar();   // int, so EOF is distinguishable from a real char
    while(ch != EOF && (ch < '0' || ch > '9')){
        if(ch == '-') w = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9'){
        s = s * 10 + (ch - '0');
        ch = getchar();
    }
    return s * w;
}

// Pushes the directed edge x -> y onto the front of x's edge list.
void add(int x, int y){
    tot++;
    ver[tot] = y;          // endpoint of the new edge
    Next[tot] = head[x];   // chain to x's previous first edge
    head[x] = tot;         // the new edge is now x's first edge
}

// First pass of heavy-light decomposition: parent, subtree size and heavy
// child (largest subtree; the first such child wins ties) of every node.
void dfs(int x, int pre){
    fa[x] = pre;
    sz[x] = 1;
    for(int e = head[x]; e; e = Next[e]){
        int child = ver[e];
        if(child == pre) continue;
        dfs(child, x);
        sz[x] += sz[child];
        if(sz[child] > sz[son[x]]) son[x] = child;
    }
}

// Second pass: assigns dfn heavy-child-first, so each heavy chain (with head
// t) and each subtree occupy contiguous dfn ranges; records chain heads.
void dfs2(int x, int t){
    top[x] = t;
    dfn[x] = ++cnt;
    rdfn[cnt] = x;
    if(son[x] == 0) return;   // leaf: every neighbour is the parent
    dfs2(son[x], t);
    for(int e = head[x]; e; e = Next[e]){
        int child = ver[e];
        if(child != fa[x] && child != son[x]) dfs2(child, child);
    }
}

// Largest element of s strictly below x, or -Max when there is none.
int getPre(int x){
    set <int> :: iterator it = s.lower_bound(x);
    if(it != s.begin()){
        --it;
        return *it;
    }
    return -Max;
}

// Smallest element of s that is >= x, or Max when there is none.
int getNxt(int x){
    set <int> :: iterator it = s.lower_bound(x);
    return it != s.end() ? *it : Max;
}

// Records a gap between two neighbouring live weights.
void addT(int gap){
    q1.push(gap);
}

// Marks one copy of a gap as removed; solve() discards it lazily.
void subT(int gap){
    q2.push(gap);
}

// Largest live gap: first cancels tops of q1 that are pending deletion in q2.
// Returns -1 when no gap is live (fewer than two distinct live weights).
int solve(){
    while(!q2.empty() && q2.top() == q1.top()){
        q1.pop();
        q2.pop();
    }
    return q1.empty() ? -1 : q1.top();
}

void addE(int u){
int x = V[u - 1];
int l = getPre(x), r = getNxt(x);
if(l != -Max && r != Max) subT(r - l);
if(l != -Max) addT(x - l);
if(r != Max) addT(r - x);
s.insert(x);
}

void subE(int u){
int x = V[u - 1];
s.erase(x);
int l = getPre(x), r = getNxt(x);
if(l != -Max && r != Max) addT(r - l);
if(l != -Max) subT(x - l);
if(r != Max) subT(r - x);

}

// Appends the dfn ranges covering the path from x up to the root, one range
// per heavy chain crossed (from top[v] down to v).
void addVector(int x, vector<pii> &tmp){
    for(int v = x; v != 0; v = fa[top[v]]){
        tmp.push_back(pii(dfn[top[v]], dfn[v]));
    }
}

// Schedules rank x to be live for every dfn position in [l, r]: it enters
// the sweep at l and leaves at r + 1. Empty ranges are ignored.
void addSeg(int l, int r, int x){
    if(l <= r){
        ans1[l].push_back(x);
        ans2[r + 1].push_back(x);
    }
}

int main()
{
n = read();
for(int i = 1; i <= n; i++){
int x = read();
a[i] = x;
V.push_back(x);
}
sort(V.begin(), V.end());
for(int i = 1; i <= n; i++){
id[i] = lower_bound(V.begin(), V.end(), a[i]) - V.begin() + 1;
}


for(int i = 1; i < n; i++){
int x = read(), y = read();
add(x, y); add(y, x);
}
dfs(1, 0);
dfs2(1, 1);

for(int i = 1; i <= n; i++){
vector <pii> tmp;
tmp.clear();
if(fa[i]){
addVector(fa[i], tmp);
}
tmp.push_back(pii(dfn[i], dfn[i] + sz[i] - 1));
sort(tmp.begin(), tmp.end());
addSeg(1, tmp[0].first - 1, id[i]);
for(int j = 0; j + 1 < tmp.size(); j++){
addSeg(tmp[j].second + 1, tmp[j+1].first - 1, id[i]);
}
addSeg(tmp[tmp.size() - 1].second + 1, n, id[i]);
}

for(int i = 1; i <= n; i++){
for(int j = 0; j < ans1[i].size(); j++){
int x = ans1[i][j];
//cout << "i: " << i << "add: " << x << " ";
if(num[x] == 0){
addE(x);
}
num[x] ++;
}
//cout << endl;
for(int j = 0; j < ans2[i].size(); j++){
int x = ans2[i][j];
//cout << "i: " << i << "sub: " << x << " ";
num[x]--;
if(num[x] == 0){
subE(x);
}
}
//cout << endl;
ans[i] = solve();
}

for(int i = 1; i <= n; i++){
printf("%d\n", ans[dfn[i]]);
}

return 0;
}
CATALOG
  1. 1. 9月提高组模拟赛题三
    1. 1.1. set维护树上问题