洛谷 P3369 【模板】普通平衡树

Solution 1: Treap
#include <iostream>
#include <cstdlib>
#include <ctime>
using namespace std;
// Capacity of the node pool tr[] and the sentinel magnitude.
// Index 0 is the null node and build() allocates two sentinel nodes, so 100000
// insertions of distinct keys use indices up to 100002. The previous capacity of
// 100001 overflowed tr[] in that worst case; 100005 leaves room for it.
// NOTE(review): the 100000-operation bound is assumed from the original sizing —
// confirm against the problem statement.
// INF must be strictly larger in magnitude than every key that is ever inserted.
const int N = 100005, INF = 1e8;
// n: number of operations still to read; op / x: current operation code and its
// argument; root: index of the treap root; idx: index of the last allocated node.
int n, op, x, root, idx;
// One treap node. Nodes live in the static pool tr[] and reference each other by
// index; index 0 is never allocated and serves as the null node (all fields zero).
struct Node {
    int l;     // index of the left child, 0 if absent
    int r;     // index of the right child, 0 if absent
    int key;   // value stored in this node (BST order)
    int fix;   // random priority (max-heap order)
    int cnt;   // number of copies of key held by this node
    int size;  // total number of copies stored in this node's subtree
} tr[N];
void pushup(int p) {
tr[p].size = tr[tr[p].l].size + tr[tr[p].r].size + tr[p].cnt;
}
int get_node(int key) {
tr[++idx].key = key;
tr[idx].fix = rand();
tr[idx].cnt = tr[idx].size = 1;
return idx;
}
// Initialises the treap with two sentinel nodes, -INF and +INF, so that every
// later query has a lower and an upper neighbour.
//
// The node indices are taken from get_node() instead of being hard-coded as 1 and
// 2, and the sentinel with the larger priority becomes the root so that the
// max-heap property on `fix` already holds for the initial tree.
void build() {
    const int lo = get_node(-INF);
    const int hi = get_node(INF);
    if (tr[lo].fix >= tr[hi].fix) {
        // -INF on top, +INF hangs to its right.
        root = lo;
        tr[lo].r = hi;
    } else {
        // +INF on top, -INF hangs to its left.
        root = hi;
        tr[hi].l = lo;
    }
    pushup(root);
}
// Right rotation: the left child of p becomes the new subtree root and the old
// root becomes its right child. p is updated in place to the new root's index.
void zig(int &p) {
    const int down = p;            // old root, moves down to the right
    const int up = tr[down].l;     // left child, moves up
    tr[down].l = tr[up].r;
    tr[up].r = down;
    p = up;
    // The lowered node must be refreshed before the node that now sits above it.
    pushup(down);
    pushup(up);
}
// Left rotation: the right child of p becomes the new subtree root and the old
// root becomes its left child. p is updated in place to the new root's index.
void zag(int &p) {
    const int down = p;            // old root, moves down to the left
    const int up = tr[down].r;     // right child, moves up
    tr[down].r = tr[up].l;
    tr[up].l = down;
    p = up;
    // The lowered node must be refreshed before the node that now sits above it.
    pushup(down);
    pushup(up);
}
// Inserts one copy of key into the subtree whose root index is stored in p.
// An existing node only has its multiplicity increased; a new node is created at
// the leaf position and rotated upwards while its priority exceeds its parent's.
void insert(int &p, int key) {
    if (p == 0) {
        p = get_node(key);
    } else if (key == tr[p].key) {
        tr[p].cnt += 1;
    } else if (key < tr[p].key) {
        insert(tr[p].l, key);
        const int child = tr[p].l;
        if (tr[child].fix > tr[p].fix) {
            zig(p);
        }
    } else {
        insert(tr[p].r, key);
        const int child = tr[p].r;
        if (tr[child].fix > tr[p].fix) {
            zag(p);
        }
    }
    // Sizes along the search path changed, so refresh on the way back up.
    pushup(p);
}
// Removes one copy of key from the subtree whose root index is stored in p.
// Does nothing if the key is not present. A node holding the last copy is rotated
// downwards (always lifting the child with the larger priority) until it is a
// leaf, and is then unlinked.
void remove(int &p, int key) {
    if (p == 0) {
        return;
    }
    if (key < tr[p].key) {
        remove(tr[p].l, key);
    } else if (key > tr[p].key) {
        remove(tr[p].r, key);
    } else if (tr[p].cnt > 1) {
        tr[p].cnt -= 1;
    } else if (tr[p].l == 0 && tr[p].r == 0) {
        // Leaf: detach it from its parent.
        p = 0;
    } else {
        const int left = tr[p].l;
        const int right = tr[p].r;
        const bool lift_left = (right == 0) || (tr[left].fix > tr[right].fix);
        if (lift_left) {
            zig(p);
            remove(tr[p].r, key);
        } else {
            zag(p);
            remove(tr[p].l, key);
        }
    }
    // Also runs when p has just become 0; the null node's size stays 0.
    pushup(p);
}
// Returns the rank of key inside the subtree rooted at p, defined as
// (number of stored copies strictly smaller than key) + 1.
//
// The definition now also holds for keys that are NOT stored in the tree.
// Previously the search returned just the count of smaller copies when it fell
// off the tree, which was one less than the result for a stored key at the same
// position. Sentinel nodes are counted like any other node, so a caller working
// on a tree that contains a -INF sentinel has to subtract one.
int get_rank_by_key(int p, int key) {
    int smaller = 0;
    while (p) {
        if (key < tr[p].key) {
            p = tr[p].l;
            continue;
        }
        // Everything in the left subtree is smaller than key.
        smaller += tr[tr[p].l].size;
        if (key == tr[p].key) {
            break;
        }
        // key is larger than this node: its copies are smaller too.
        smaller += tr[p].cnt;
        p = tr[p].r;
    }
    return smaller + 1;
}
// Returns the key occupying position `rank` (1-based, duplicates counted
// individually) in the subtree rooted at p, or INF if the search runs off the
// tree.
int get_key_by_rank(int p, int rank) {
    while (p != 0) {
        const int left_size = tr[tr[p].l].size;
        if (rank <= left_size) {
            p = tr[p].l;
        } else if (rank <= left_size + tr[p].cnt) {
            return tr[p].key;
        } else {
            // Skip the left subtree and this node's copies.
            rank -= left_size + tr[p].cnt;
            p = tr[p].r;
        }
    }
    return INF;
}
// Returns the largest stored key strictly smaller than key in the subtree rooted
// at p, or -INF if there is none.
int get_prev(int p, int key) {
    int best = -INF;
    while (p != 0) {
        if (tr[p].key < key) {
            // Candidate found; anything better can only be to the right.
            if (tr[p].key > best) {
                best = tr[p].key;
            }
            p = tr[p].r;
        } else {
            p = tr[p].l;
        }
    }
    return best;
}
// Returns the smallest stored key strictly larger than key in the subtree rooted
// at p, or INF if there is none.
int get_next(int p, int key) {
    int best = INF;
    while (p != 0) {
        if (tr[p].key > key) {
            // Candidate found; anything better can only be to the left.
            if (tr[p].key < best) {
                best = tr[p].key;
            }
            p = tr[p].l;
        } else {
            p = tr[p].r;
        }
    }
    return best;
}
// Reads the operation count followed by (op, x) pairs and answers each query on
// the treap:
//   1 insert x          2 remove one copy of x   3 rank of x
//   4 key at rank x     5 predecessor of x       6 (or anything else) successor of x
// Answers are written with '\n' rather than endl, and cin is untied from cout, so
// the output is flushed once at exit instead of after every query.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    srand(time(nullptr));
    build();
    cin >> n;
    while (n--) {
        cin >> op >> x;
        if (op == 1) insert(root, x);
        else if (op == 2) remove(root, x);
        // The -INF sentinel is counted by the rank functions, hence -1 / +1.
        else if (op == 3) cout << get_rank_by_key(root, x) - 1 << '\n';
        else if (op == 4) cout << get_key_by_rank(root, x + 1) << '\n';
        else if (op == 5) cout << get_prev(root, x) << '\n';
        else cout << get_next(root, x) << '\n';
    }
    return 0;
}
Solution 2: Splay
#include <iostream>
using namespace std;
// Capacity of the node arrays. Node indices start at 1 and are never reused, so
// this bounds the total number of nodes ever created.
const int N = 100001;
// n: number of operations still to read; op / x: current operation and argument.
int n, op, x;
// rt: index of the root (0 = empty tree); tot: index of the last created node.
// fa[i]: parent of i; ch[i][0/1]: left/right child; val[i]: key of node i;
// cnt[i]: number of copies of val[i]; sz[i]: number of copies in i's subtree.
// Index 0 is the null node and keeps all-zero fields.
int rt, tot, fa[N], ch[N][2], val[N], cnt[N], sz[N];

// Splay tree with duplicate counting, operating on the global arrays above.
struct Splay {
    // Recomputes sz[x] from the children of x and its own multiplicity.
    void maintain(int x) { sz[x] = sz[ch[x][0]] + sz[ch[x][1]] + cnt[x]; }

    // Returns true if x is the right child of its parent.
    bool get(int x) { return x == ch[fa[x]][1]; }

    // Resets every field of node x to zero (used after unlinking a node).
    void clear(int x) { ch[x][0] = ch[x][1] = fa[x] = val[x] = sz[x] = cnt[x] = 0; }

    // Rotates x above its parent y, keeping the in-order sequence unchanged.
    void rotate(int x) {
        int y = fa[x], z = fa[y], chk = get(x);
        ch[y][chk] = ch[x][chk ^ 1];
        if (ch[x][chk ^ 1]) fa[ch[x][chk ^ 1]] = y;
        ch[x][chk ^ 1] = y;
        fa[y] = x;
        fa[x] = z;
        if (z) ch[z][y == ch[z][1]] = x;
        // y is now a child of x, so y has to be refreshed first; doing it the
        // other way round computes sz[x] from the stale sz[y].
        maintain(y);
        maintain(x);
    }

    // Moves node x to the root using zig / zig-zig / zig-zag steps.
    // Calling it with the null node is a no-op.
    void splay(int x) {
        if (!x) return;
        for (int f = fa[x]; f; f = fa[x]) {
            // With a grandparent present, rotate the parent first when x and its
            // parent are on the same side, otherwise rotate x twice.
            if (fa[f]) rotate(get(x) == get(f) ? f : x);
            rotate(x);
        }
        rt = x;
    }

    // Inserts one copy of k and splays the node holding k to the root.
    void ins(int k) {
        if (!rt) {
            val[++tot] = k;
            cnt[tot]++;
            rt = tot;
            maintain(rt);
            return;
        }
        int cur = rt, f = 0;
        while (1) {
            if (val[cur] == k) {
                cnt[cur]++;
                maintain(cur);
                if (f) maintain(f);
                splay(cur);
                break;
            }
            f = cur;
            cur = ch[cur][val[cur] < k];
            if (!cur) {
                val[++tot] = k;
                cnt[tot]++;
                fa[tot] = f;
                ch[f][val[f] < k] = tot;
                maintain(tot);
                maintain(f);
                splay(tot);
                break;
            }
        }
    }

    // Returns the rank of k: (number of stored copies strictly smaller than k) + 1.
    // If k is stored, its node is splayed to the root. If k is absent (or the tree
    // is empty) the search now terminates: the last visited node is splayed
    // instead and the rank k would have is returned. Previously this case either
    // never terminated or reset the root to 0.
    int rk(int k) {
        int res = 0, cur = rt, last = 0;
        while (cur) {
            last = cur;
            if (k < val[cur]) {
                cur = ch[cur][0];
            } else {
                res += sz[ch[cur][0]];
                if (k == val[cur]) {
                    splay(cur);
                    return res + 1;
                }
                res += cnt[cur];
                cur = ch[cur][1];
            }
        }
        splay(last);
        return res + 1;
    }

    // Returns the k-th smallest stored value (1-based, duplicates counted
    // individually) and splays its node to the root. k <= 0 yields the minimum;
    // k larger than the number of stored copies yields the maximum instead of
    // looping forever; an empty tree yields 0.
    int kth(int k) {
        int cur = rt, last = 0;
        while (cur) {
            last = cur;
            if (ch[cur][0] && k <= sz[ch[cur][0]]) {
                cur = ch[cur][0];
            } else {
                k -= cnt[cur] + sz[ch[cur][0]];
                if (k <= 0) {
                    splay(cur);
                    return val[cur];
                }
                cur = ch[cur][1];
            }
        }
        if (!last) return 0;
        splay(last);  // fell off to the right: last is the maximum
        return val[last];
    }

    // Returns the index of the in-order predecessor of the current root and
    // splays it to the root, or returns 0 (tree unchanged) if the root has no
    // left subtree.
    int pre() {
        int cur = ch[rt][0];
        if (!cur) return cur;
        while (ch[cur][1]) cur = ch[cur][1];
        splay(cur);
        return cur;
    }

    // Returns the index of the in-order successor of the current root and
    // splays it to the root, or returns 0 (tree unchanged) if the root has no
    // right subtree.
    int nxt() {
        int cur = ch[rt][1];
        if (!cur) return cur;
        while (ch[cur][0]) cur = ch[cur][0];
        splay(cur);
        return cur;
    }

    // Removes one copy of k. Does nothing if k is not stored; previously the
    // node that happened to be at the root was removed in that case.
    void del(int k) {
        rk(k);  // brings the node holding k to the root when it exists
        if (!rt || val[rt] != k) return;
        if (cnt[rt] > 1) {
            cnt[rt]--;
            maintain(rt);
            return;
        }
        if (!ch[rt][0] && !ch[rt][1]) {
            clear(rt);
            rt = 0;
            return;
        }
        if (!ch[rt][0]) {
            int cur = rt;
            rt = ch[rt][1];
            fa[rt] = 0;
            clear(cur);
            return;
        }
        if (!ch[rt][1]) {
            int cur = rt;
            rt = ch[rt][0];
            fa[rt] = 0;
            clear(cur);
            return;
        }
        // Two children: splay the predecessor to the root. The node being
        // removed then sits directly to its right without a left child, so its
        // right subtree can be re-attached to the new root.
        int cur = rt;
        int pred = pre();
        fa[ch[cur][1]] = pred;
        ch[pred][1] = ch[cur][1];
        clear(cur);
        maintain(rt);
    }
} tree;
// Reads the operation count followed by (op, x) pairs and answers each query on
// the splay tree:
//   1 insert x          2 remove one copy of x   3 rank of x
//   4 x-th smallest     5 predecessor of x       6 (or anything else) successor of x
// Predecessor / successor queries temporarily insert x so that it sits at the
// root, look at its neighbour, then remove the temporary copy again.
// Answers are written with '\n' rather than endl, and cin is untied from cout, so
// the output is flushed once at exit instead of after every query.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n;
    while (n--) {
        cin >> op >> x;
        if (op == 1) tree.ins(x);
        else if (op == 2) tree.del(x);
        else if (op == 3) cout << tree.rk(x) << '\n';
        else if (op == 4) cout << tree.kth(x) << '\n';
        else if (op == 5) tree.ins(x), cout << val[tree.pre()] << '\n', tree.del(x);
        else tree.ins(x), cout << val[tree.nxt()] << '\n', tree.del(x);
    }
    return 0;
}