$树分治$
$问题描述$
给定一棵有 $n$ 个点的树,询问树上距离不超过 $k$ 的点对数量。
$点分治$
点分治用于树上路径问题
对于一棵树,先找出他的重心
所有路径被分成三类:
- 一端为重心
- 两端在同一子树内
- 两端在不同子树内
对于第二类,需要递归计数
对于一、三类,需要$dfs$搜索
例题:树
void add(int a, int b, int c)
{
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++ ;
}
int get_size(int u, int fa) // 求子树大小
{
if (st[u]) return 0;
int res = 1;
for (int i = h[u]; ~i; i = ne[i])
if (e[i] != fa)
res += get_size(e[i], u);
return res;
}
int get_wc(int u, int fa, int tot, int& wc) // 求重心
{
if (st[u]) return 0;
int sum = 1, ms = 0;
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa) continue;
int t = get_wc(j, u, tot, wc);
ms = max(ms, t);
sum += t;
}
ms = max(ms, tot - sum);
if (ms <= tot / 2) wc = u; // 保证复杂度即可
return sum;
}
void get_dist(int u, int fa, int dist, int& qt)
{
if (st[u]) return;
q[qt ++ ] = dist;
for (int i = h[u]; ~i; i = ne[i])
if (e[i] != fa)
get_dist(e[i], u, dist + w[i], qt);
}
int get(int a[], int k) // 计算点对数量
{
sort(a, a + k);
int res = 0;
for (int i = k - 1, j = -1; i >= 0; i -- )
{
while (j + 1 < i && a[j + 1] + a[i] <= m) j ++ ;
j = min(j, i - 1);
res += j + 1;
}
return res;
}
int calc(int u)
{
if (st[u]) return 0;
int res = 0;
get_wc(u, -1, get_size(u, -1), u); // 寻找重心
st[u] = true; // 删除重心
int pt = 0;
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i], qt = 0;
get_dist(j, -1, w[i], qt); // 加入所有点
res -= get(q, qt); // 减去同一子树内的
for (int k = 0; k < qt; k ++ )
{
if (q[k] <= m) res ++ ; // 第一类
p[pt ++ ] = q[k];
}
}
res += get(p, pt); // 第三类
for (int i = h[u]; ~i; i = ne[i]) res += calc(e[i]); // 第二类
return res;
}
$点分树$
点分树是通过更改原树形态使树的层数变为稳定 $\log n$ 的一种重构树。
常用于解决与树原形态无关的带修改问题。
将上一轮的重心与下一轮的重心连接起来,就可以形成一颗高度小于$\log n$层的树。
例题:权值
#include <bits/stdc++.h>
#define x first
#define y second
using namespace std;
typedef pair<int, int> PII;
const int N = 200010, M = N * 2, S = 1000010, INF = 0x3f3f3f3f;
int n, m;
int h[N], e[M], w[M], ne[M], idx;
int f[S], ans = INF;
PII p[N], q[N];
bool st[N];
void add(int a, int b, int c)
{
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx ++ ;
}
int get_size(int u, int fa)
{
if (st[u]) return 0;
int res = 1;
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa) continue;
res += get_size(j, u);
}
return res;
}
int get_wc(int u, int fa, int tot, int& wc)
{
if (st[u]) return 0;
int sum = 1, ms = 0;
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa) continue;
int t = get_wc(j, u, tot, wc);
ms = max(ms, t);
sum += t;
}
ms = max(ms, tot - sum);
if (ms <= tot / 2) wc = u;
return sum;
}
void get_dist(int u, int fa, int dist, int cnt, int& qt)
{
if (st[u] || dist > m) return;
q[qt ++ ] = {dist, cnt};
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa) continue;
get_dist(j, u, dist + w[i], cnt + 1, qt);
}
}
void calc(int u)
{
if (st[u]) return;
get_wc(u, -1, get_size(u, -1), u);
st[u] = true;
int pt = 0;
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i], qt = 0;
get_dist(j, u, w[i], 1, qt);
for (int k = 0; k < qt; k ++ )
{
auto& t = q[k];
if (t.x == m) ans = min(ans, t.y);
ans = min(ans, f[m - t.x] + t.y);
p[pt ++ ] = t;
}
for (int k = 0; k < qt; k ++ )
{
auto& t = q[k];
f[t.x] = min(f[t.x], t.y);
}
}
for (int i = 0; i < pt; i ++ )
f[p[i].x] = INF;
for (int i = h[u]; ~i; i = ne[i]) calc(e[i]);
}
int main()
{
scanf("%d%d", &n, &m);
memset(h, -1, sizeof h);
for (int i = 0; i < n - 1; i ++ )
{
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c), add(b, a, c);
}
memset(f, 0x3f, sizeof f);
calc(0);
if (ans == INF) ans = -1;
printf("%d\n", ans);
return 0;
}