这不是我们 [NOI2014]购票 的梗吗?下次引用记得标注出处。
题意不赘述,题目中说的商路总和最大其实是个幌子。不难发现本题是一个树上 DP ,当且仅当所有节点的商路价值最大才能达到总和最大的要求。而每一个 $i$ 都是从他子树中的一个点继承。也就是从子树中寻找一个点,然后自己连一个 $a_1a_2$ 这样的边过去,形成自己节点的商路,使其最大。
对于 $i$ 节点,其商路是 $dp_i=\max\limits_{j\in \text{subtree}(i)} dp_j + v_i-(f_i+dis_i-dis_j)^2$ ,其中 $dis_i$ 表示 $i$ 其到 $1$ 号节点的距离,对于所有叶子节点 $v$ 有 $dp_v=0$。
看到平方项,一眼斜率优化。
$$ dp_i=dp_j+v_i-((f_i+dis_i)^2+dis_j^2-2(f_i+dis_i)dis_j) \\= 2dis_j(dis_i+f_i)+(dp_j-dis_j^2)+(v_i-(f_i+dis_i)^2) $$
然后就可以把每个点的直线 $y_j=k_jx+b_j$ 求出来,$k_j=2dis_j,b_j=dp_j-dis_j^2$ 。然后利用李超树存储求 $x=f_i+dis_i $ 处的最大值。横坐标需要离散化。
然后因为是树上,要保证 $j\in \text{subtree}(i)$ ,所以用线段树套李超树即可,外层维护的是整个树的 dfs 序。跑一遍 dfs 即可,同时需要得到每个点的子树规模,然后把每个节点对应的直线放在 dfs 序的位置上即可。
然后注意一点,非叶子节点也可以不连接商路,让答案为 $0$ 即可。所以每次查询结果还要对 $0$ 取 $\max$ 作为实际值。
然后注意多测清空。时间复杂度 $O(n\log ^2 n)$ ,空间复杂度 $O(n\log n)$。官方数据峰值时间不到 1.7s ,空间占用接近 30MB 。
const int N = 100010, LOG_N = 18;
const i64 INF = 1145141919810114514ll;
const i64 mod = 1000000000000000000;
inline i64 add(i64 a, i64 b) { return (a + b >= mod) ? (a + b - mod) : (a + b); }
int n;
i64 fa[N], s[N], v[N], f[N], dis[N], sz[N];
i64 X[N], xcnt;
i64 dp[N];
int dfn[N], dfs_t;
namespace graph
{
int to[N], nxt[N], h[N], ecnt;
inline void add(int u, int v) { to[++ecnt] = v, nxt[ecnt] = h[u], h[u] = ecnt; }
inline void clr() { memset(h, 0, sizeof(h)), ecnt = 0; }
inline void dfs(int u)
{
dfn[++dfs_t] = u, sz[u] = 1;
for (int i = h[u]; i; i = nxt[i])
dfs(to[i]), sz[u] += sz[to[i]];
}
}
struct line
{
i64 k, b;
inline line(i64 _k = 0, i64 _b = -INF) : k(_k), b(_b) {}
inline i64 f(const i64& x) const { return k * x + b; }
} lin[N];
struct node_in
{
int lid, lc, rc;
} tri[N * LOG_N];
int icnt;
struct node_out
{
int rt, lc, rc;
} tro[N << 1];
int ocnt;
inline void clr()
{
for (int i = 1; i <= n; ++i)
lin[i] = line();
memset(tri, 0, (icnt + 1) * sizeof(node_in));
memset(tro, 0, (ocnt + 1) * sizeof(node_out));
icnt = ocnt = 0;
graph::clr(), memset(dfn, 0, (n + 1) << 2), dfs_t = 0;
memset(dp, 0, (n + 1) << 3), memset(X, 0, (n + 1) << 3), xcnt = 0;
memset(dis, 0, (n + 1) << 3), memset(sz, 0, (n + 1) << 3);
}
inline void modify_in(int& u, int l, int r, int lid)
{
if (!u)
return (void)(u = ++icnt, tri[u].lid = lid);
int m = (l + r) >> 1;
if (lin[tri[u].lid].f(X[m]) < lin[lid].f(X[m]))
std::swap(tri[u].lid, lid);
if (lin[tri[u].lid].f(X[l]) < lin[lid].f(X[l]))
modify_in(tri[u].lc, l, m, lid);
else if (lin[tri[u].lid].f(X[r]) < lin[lid].f(X[r]))
modify_in(tri[u].rc, m + 1, r, lid);
}
inline i64 query_in(const int rt, const int& x)
{
i64 ret = lin[tri[rt].lid].f(X[x]);
int l = 1, r = xcnt, m = 0, u = rt;
while (l < r)
{
m = (l + r) >> 1;
(x <= m) ? (u = tri[u].lc, r = m) : (u = tri[u].rc, l = m + 1);
if (!u)
return ret;
ret = std::max(ret, lin[tri[u].lid].f(X[x]));
}
return ret;
}
inline int build_out(int l, int r)
{
int u = ++ocnt, m = (l + r) >> 1;
if (l ^ r)
tro[u].lc = build_out(l, m), tro[u].rc = build_out(m + 1, r);
return u;
}
inline void modify_out(int u, int l, int r, int pos, int lid)
{
modify_in(tro[u].rt, 1, xcnt, lid);
if (l == r)
return;
int m = (l + r) >> 1;
(pos <= m) ? (modify_out(tro[u].lc, l, m, pos, lid)) : (modify_out(tro[u].rc, m + 1, r, pos, lid));
}
inline i64 query_out(int u, int L, int R, int l, int r, int x)
{
if (l > R || r < L)
return -INF;
if (l <= L && R <= r)
return query_in(tro[u].rt, x);
int M = (L + R) >> 1;
return std::max(query_out(tro[u].lc, L, M, l, r, x), query_out(tro[u].rc, M + 1, R, l, r, x));
}
inline void solve()
{
n = rd();
for (int i = 1; i <= n; ++i)
fa[i] = rd(), graph::add(fa[i], i), s[i] = rd(), v[i] = rd(), f[i] = rd(), dis[i] = dis[fa[i]] + s[i], X[i] = f[i] + dis[i];
std::sort(X + 1, X + n + 1), xcnt = std::unique(X + 1, X + n + 1) - X - 1;
graph::dfs(1), build_out(1, n);
for (int i = n; i; --i)
{
int u = dfn[i];
if (sz[u] > 1)
{
int x = std::lower_bound(X + 1, X + xcnt + 1, f[u] + dis[u]) - X;
dp[u] = std::max(query_out(1, 1, n, i + 1, i + sz[u] - 1, x) + v[u] - X[x] * X[x], 0ll);
}
lin[u] = line(2 * dis[u], dp[u] - dis[u] * dis[u]);
modify_out(1, 1, n, i, u);
}
i64 ans = 0;
for (int i = 1; i <= n; ++i)
ans = add(ans, dp[i]);
wr(ans), putchar('\n');
}
int main()
{
int T = rd();
while (T--)
solve(), clr();
}