[SPOJ 10628] Count on a tree

题目大意

给定一棵 $n$ 个节点的树,每个点有一个权值。对于 $m$ 个询问 $(u, \; v, \; k)$,你需要回答 $u \; \mathrm{xor} \; lastans$ 和 $v$ 这两个节点间第 $k$ 小的点权,其中 $lastans$ 是上一个询问的答案,初始为 $0$。

$1 \leqslant n, \; m \leqslant 100\,000$

题目链接

Count on a tree - Luogu 2633

SPOJ 10628 - COT

题解

主席树(可持久化线段树)。以 1 为根,每个节点的线段树由其父节点的线段树插入自身点权得到,因此它维护的是该节点到根路径上的全部点权。对于每一个询问,求出 $p = \mathrm{lca}(u, \; v)$,则 $u$ 到 $v$ 路径上的点权恰好对应线段树 $u + v - p - p.fa$,在其上二分即可。

求 lca 用倍增。

代码

有烦人的 PE(Presentation Error):最后一个答案后面不能输出换行……

#include <cstdio>
#include <climits>
#include <vector>
#include <queue>
#include <algorithm>
const int MAXN = 100005;  // vertex capacity: ids 1..n (n <= 100000) plus sentinel N[0]
const int MAXLOGN = 17;   // binary-lifting table width; 2^17 > 100000
struct PSegT *null;       // shared empty-tree sentinel, created by init()

// Persistent segment tree node over a value range [l, r].
// insert() never mutates an existing node: it copies the root-to-leaf path
// and shares every untouched subtree with the previous version.
struct PSegT {
    PSegT *lc, *rc;  // children; `null` stands for an empty subtree
    int cnt;         // how many values this subtree holds

    // Internal node whose count is the sum of its children's counts.
    PSegT(PSegT *lc, PSegT *rc) : lc(lc), rc(rc), cnt(lc->cnt + rc->cnt) {}
    // Node with an explicitly given count (leaves and the sentinel).
    PSegT(PSegT *lc, PSegT *rc, int cnt) : lc(lc), rc(rc), cnt(cnt) {}

    // Returns the root of a new version holding one more occurrence of x.
    // [l, r] is the value range covered by this node.
    PSegT *insert(int l, int r, int x) {
        if (l == r) {
            return new PSegT(null, null, cnt + 1);
        }
        const int mid = l + (r - l) / 2;  // midpoint without l + r overflow
        if (x > mid) {
            PSegT *newRight = rc->insert(mid + 1, r, x);
            return new PSegT(lc, newRight);
        }
        PSegT *newLeft = lc->insert(l, mid, x);
        return new PSegT(newLeft, rc);
    }
};
// Tree vertex. The global array N is indexed by vertex id (1..n); N[0] is a
// sentinel acting as the parent of the root, with `seg` set to `null` in build().
struct Node {
    std::vector<Node *> adj;  // neighbours (each undirected edge is stored on both ends)
    Node *fa;                 // parent in the BFS tree rooted at vertex 1
    int dep;                  // depth: root is 1, the zero-initialised sentinel is 0
    int w;                    // vertex weight
    bool vis;                 // BFS visited flag
    PSegT *seg;               // persistent tree of the weights on the root..this path
} N[MAXN];
void addEdge(int u, int v) {
N[u].adj.push_back(&N[v]);
N[v].adj.push_back(&N[u]);
}
void init() {
null = new PSegT(NULL, NULL, 0);
null->lc = null->rc = null;
}
int n, f[MAXN][MAXLOGN], logn;
void build() {
N[0].vis = true;
N[0].seg = null;
std::queue<Node *> q;
q.push(&N[1]);
N[1].vis = true;
N[1].dep = 1;
N[1].fa = &N[0];
while (!q.empty()) {
Node *u = q.front();
q.pop();
u->seg = u->fa->seg->insert(0, INT_MAX, u->w);
for (Node **p = &u->adj.front(), *v = *p; p <= &u->adj.back(); v = *++p) {
if (!v->vis) {
v->vis = true;
v->dep = u->dep + 1;
v->fa = u;
q.push(v);
}
}
}
while ((1 << (logn + 1)) <= n) logn++;
f[1][0] = 1;
for (int i = 2; i <= n; i++) f[i][0] = N[i].fa - N;
for (int j = 1; j <= logn; j++) {
for (int i = 1; i <= n; i++) {
f[i][j] = f[f[i][j - 1]][j - 1];
}
}
}
// Lowest common ancestor of u and v using the binary-lifting table f.
int lca(int u, int v) {
    if (N[u].dep < N[v].dep) std::swap(u, v);  // ensure u is not shallower than v
    const int targetDep = N[v].dep;
    if (N[u].dep > targetDep) {
        // Raise u to v's depth, taking the largest jumps first.
        for (int j = logn; j >= 0; j--) {
            const int up = f[u][j];
            if (N[up].dep >= targetDep) u = up;
        }
    }
    if (u == v) return u;
    // Raise both while their ancestors differ; they end just below the LCA.
    for (int j = logn; j >= 0; j--) {
        if (f[u][j] == f[v][j]) continue;
        u = f[u][j];
        v = f[v][j];
    }
    return f[u][0];
}
int query(int u, int v, int k) {
int p = lca(u, v);
PSegT *su = N[u].seg, *sv = N[v].seg, *sp = N[p].seg, *sf = N[p].fa->seg;
int l = 0, r = INT_MAX;
while (l < r) {
int mid = l + (r - l) / 2;
int s = su->lc->cnt + sv->lc->cnt - sp->lc->cnt - sf->lc->cnt;
if (k > s) {
k -= s;
l = mid + 1;
su = su->rc;
sv = sv->rc;
sp = sp->rc;
sf = sf->rc;
} else {
r = mid;
su = su->lc;
sv = sv->lc;
sp = sp->lc;
sf = sf->lc;
}
}
return l;
}
// Reads the tree, then answers the m online queries. Each query's u is
// decoded by XOR with the previous answer; the final answer is printed
// without a trailing newline.
int main() {
    int m;
    scanf("%d %d", &n, &m);
    for (int id = 1; id <= n; id++) scanf("%d", &N[id].w);
    for (int edges = n - 1; edges > 0; edges--) {
        int a, b;
        scanf("%d %d", &a, &b);
        addEdge(a, b);
    }
    init();
    build();
    int lastAns = 0;
    while (m--) {
        int u, v, k;
        scanf("%d %d %d", &u, &v, &k);
        lastAns = query(u ^ lastAns, v, k);
        const char *fmt = (m == 0) ? "%d" : "%d\n";
        printf(fmt, lastAns);
    }
    return 0;
}