算法1
主席树（可持久化线段树）+ 倍增 LCA，时间复杂度 $O((n + m) \log n)$
C++ 代码
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
// Capacity for per-vertex arrays; vertices are numbered 1..n with n < N, and
// vertex values T[v] are inserted into a segment tree over the range 1..n.
const int N = 100010;

// Persistent segment tree node. Index 0 is never allocated (idx is
// pre-incremented), so tr[0] stays all-zero and serves as an empty node.
struct Node
{
    int lc, rc;  // child slots in tr[]
    int sum;     // how many inserted values fall in this node's range
};

// Node pool: about 2n slots for the skeleton made by build(), plus one slot per
// tree level (at most 18 for n < N) for every modify() call.
Node tr[N * 4 + N * 17];

// Adjacency lists (each undirected edge is stored twice).
int h[N * 2];   // first edge slot of each vertex, -1 if none
int e[N * 2];   // target vertex of each edge slot
int ne[N * 2];  // next edge slot from the same vertex, -1 terminates
int tdx;        // number of edge slots used

int n, m;       // vertex count, query count

int root[N];    // root[v]: persistent version built for vertex v
int idx;        // last tr[] slot handed out

int fa[N][17];  // fa[v][k]: 2^k-th ancestor of v (0 above the root)
int depth[N];   // depth of each vertex (root = 1, sentinel 0 = 0)

int q[N];       // BFS queue
int hh, tt;     // queue head / tail indices

int T[N];       // value attached to each vertex
// Appends the directed edge a -> b to a's adjacency list (new edges go to the
// front, so lists are walked in reverse insertion order).
void add(int a, int b)
{
    e[tdx] = b;
    ne[tdx] = h[a];
    h[a] = tdx;
    ++ tdx;
}
// Breadth-first search from vertex 1. Fills depth[] (vertex 1 has depth 1 and
// the sentinel vertex 0 has depth 0) and the binary-lifting table fa[][],
// where fa[v][k] is the 2^k-th ancestor of v, or 0 once that jumps past the root.
// Because BFS dequeues parents before children, fa[parent][*] is complete by
// the time a child's row is computed.
void bfs()
{
    memset(depth, 0x3f, sizeof depth);  // 0x3f3f3f3f marks "not visited yet"
    depth[0] = 0;
    depth[1] = 1;
    hh = 0;
    tt = 0;
    q[0] = 1;
    while(hh <= tt)
    {
        const int cur = q[hh];
        ++ hh;
        const int next_depth = depth[cur] + 1;
        for(int edge = h[cur]; edge != -1; edge = ne[edge])
        {
            const int nxt = e[edge];
            if(depth[nxt] <= next_depth) continue;  // already reached
            depth[nxt] = next_depth;
            fa[nxt][0] = cur;
            int k = 1;
            while(k <= 16)
            {
                fa[nxt][k] = fa[fa[nxt][k - 1]][k - 1];
                ++ k;
            }
            ++ tt;
            q[tt] = nxt;
        }
    }
}
// Lowest common ancestor of a and b by binary lifting over fa[][] and depth[]
// (both filled by bfs()). depth[0] == 0 acts as a sentinel, so a jump that
// would leave the tree never satisfies the depth test below.
int lca(int a, int b)
{
    int deep = a, shallow = b;
    if(depth[deep] < depth[shallow])
    {
        deep = b;
        shallow = a;
    }
    // Raise the deeper vertex to the shallower one's depth, largest jump first.
    int k = 17;
    while(k --)
        if(depth[fa[deep][k]] >= depth[shallow]) deep = fa[deep][k];
    if(deep == shallow) return deep;
    // Raise both while their 2^k-th ancestors differ; they stop just below the LCA.
    k = 17;
    while(k --)
        if(fa[deep][k] != fa[shallow][k])
        {
            deep = fa[deep][k];
            shallow = fa[shallow][k];
        }
    return fa[deep][0];
}
// Allocates the skeleton of a segment tree over [l, r] in pre-order (parent
// slot before its children) and returns the root slot. Counts are not written
// here; they rely on the slots still holding their zero-initialized values.
int build(int l, int r)
{
    const int node = ++ idx;
    if(l != r)
    {
        const int mid = (l + r) / 2;
        tr[node].lc = build(l, mid);
        tr[node].rc = build(mid + 1, r);
    }
    return node;
}
void pushup(int u)
{
tr[u].sum = tr[tr[u].lc].sum + tr[tr[u].rc].sum;
}
// Returns a new version of the tree rooted at p (covering [l, r]) in which the
// count at position k is one larger. Only the root-to-leaf path is copied; all
// other subtrees are shared with version p.
int modify(int p, int l, int r, int k)
{
    const int cur = ++ idx;
    tr[cur] = tr[p];  // start as a copy of the old node
    if(l == r)
    {
        ++ tr[cur].sum;
        return cur;
    }
    const int mid = (l + r) / 2;
    if(k > mid) tr[cur].rc = modify(tr[p].rc, mid + 1, r, k);
    else tr[cur].lc = modify(tr[p].lc, l, mid, k);
    // Refresh this node's count from its (possibly new) children.
    tr[cur].sum = tr[tr[cur].lc].sum + tr[tr[cur].rc].sum;
    return cur;
}
void dfs(int u, int p)
{
root[u] = modify(root[p], 1, n, T[u]);
for(int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if(j == p) continue;
dfs(j, u);
}
}
// Sums the counts stored in version p over positions [lq, rq], where node p
// covers [l, r]. Callers start at a root with (l, r) = (1, n). A child is only
// visited when its range overlaps the query.
int query(int p, int l, int r, int lq, int rq)
{
    if(l >= lq && r <= rq) return tr[p].sum;  // node range fully inside query
    const int mid = (l + r) / 2;
    const int left = lq <= mid ? query(tr[p].lc, l, mid, lq, rq) : 0;
    const int right = rq > mid ? query(tr[p].rc, mid + 1, r, lq, rq) : 0;
    return left + right;
}
// Reads n, m, the vertex values T[1..n] and n - 1 undirected edges, then answers
// m queries "a b c". For each query it prints '1' if some vertex on the path
// from a to b has value c and '0' otherwise, with no separators. The output is
// terminated by a newline.
int main()
{
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; ++ i) scanf("%d", &T[i]);
    memset(h, -1, sizeof h);  // -1 = empty adjacency list
    for(int i = 1; i < n; ++ i)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
    }
    bfs();
    root[0] = build(1, n);  // empty version; parent version of vertex 1
    dfs(1, 0);
    while(m --)
    {
        int a, b, c;
        scanf("%d%d%d", &a, &b, &c);
        const int p = lca(a, b);
        // Counts of c on the paths root..a and root..b both include root..p.
        // Removing root..parent(p) twice leaves p counted twice, so subtract
        // p's own contribution once. When p is the root, fa[p][0] == 0 and
        // root[0] is the empty version, so np is 0.
        const int na = query(root[a], 1, n, c, c);
        const int nb = query(root[b], 1, n, c, c);
        const int np = query(root[fa[p][0]], 1, n, c, c);
        const int on_path = na + nb - np * 2 - (T[p] == c);
        putchar(on_path > 0 ? '1' : '0');
    }
    putchar('\n');
    return 0;
}