题目描述
有 $n$ 辆车,分别在 $a_1, a_2, \ldots , a_n$ 位置和 $n$ 个加油站,分别在 $b_1, b_2, \ldots ,b_n$ 位置。
每个加油站只能支持一辆车的加油,所以你要把这些车开到不同的加油站加油。一个车从 $x$ 位置开到 $y$ 位置的代价为 $|x-y|$,问如何安排车辆,使得代价之和最小。
同时你有 $q$ 个操作,每次操作会修改第 $i$ 辆车的位置到 $x$,你要回答每次修改操作之后最优安排方案的总代价。
输入格式
第一行一个正整数 $n$ 。
接下来一行 $n$ 个整数 $a_1, a_2,\ldots,a_n$ 。
接下来一行 $n$ 个整数 $b_1, b_2,\ldots,b_n$ 。
接下来一行一个正整数 $q$ ,表示操作的个数。
接下来 $q$ 行,每行有两个整数 $i$($1\leq i\leq n$)和 $x$,表示将$i$这辆车开到 $x$ 位置的操作。
所有的车和加油站的范围一直在 $0$ 到 $10^9$ 之间。
输出格式
共 $q+1$ 行,第一行表示一开始的最优代价。
接下来 $q$ 行,第 $i$ 行表示操作 $i$ 之后的最优代价。
输入输出样例
输入 #1
2
1 2
3 4
1
1 3
输出 #1
4
2
说明/提示
【样例解释】
对于 $100\%$ 的数据,$1\leq n\leq 5\times 10^4$,$0\leq q\leq 5\times 10^4$ 。
首先显然将数组 $a,b$ 排序,此时 $\sum\limits_{i=1}^n|a_i-b_i|$ 一定是最优解。考虑怎么动态维护这个值。
这里用到了均分纸牌问题的证明思想。
先将数据中出现过的位置离散化,设 $w_i$ 表示位置 $i$ 到位置 $i+1$ 的距离;设 $numa_i,numb_i$ 分别表示数组 $a,b$ 中小于等于 $i$ 的元素数量,则会有 $|numa_i-numb_i|$ 数量的车经过位置 $i$ 和位置 $i+1$ ,即对答案的贡献为 $|numa_i-numb_i|\cdot w_i$ 。 因此答案为 $\sum\limits_{i=1}^n|numa_i-numb_i|\cdot w_i$ 。
对于修改操作,将 $a_i$ 修改为 $x$ 仅对上式中 $numa$ 造成影响:
- 如果 $a_i>x$ ,则 $numa_{x\sim a_i-1}$ 会增加 $1$ 。
- 如果 $a_i<x$ ,则 $numa_{a_i\sim x-1}$ 会减少 $1$ 。
因此需要使用一种数据结构维护 $numa_1$ 的区间加法和 $\sum\limits_{i=1}^n|numa_i-numb_i|\cdot w_i$ 的快速计算。
这里考虑按位置分块,假定块的大小为 $T$ 。
然后维护动态当前 $\sum\limits_{i=1}^n|numa_i-numb_i|\cdot w_i$ 的值 $ans$ 。
每次区间修改:
对于存在于区间中的完整的块 $x$ ,将标记 $add_x$ 加上修改的值。由于事先将块中的 $numa_i-numb_i$ 排序,并按排序后的顺序求出 $(numa_1-numb_i)\cdot w_i$ 和 $w_i$ 的前缀和 $sum_i,prew_i$ ,因此可以先通过二分以 $O(\log T)$ 复杂度快速找出 $suma_i+add_x-sumb_i$ 的正负分界线,从而计算出该块对答案的贡献值,更新 $ans$ 。由于最坏情况需要更新 $O(\frac{n}{T})$ 个块,因此该部分时间复杂度为 $O(\frac{n}{T}\log T)$ 。
对于不完整的块暴力计算更新 $numa_i$ 并更新 $ans$ ,然后在块中按照新的 $numa_i-numb_i$ 排序,从而更新 $sum$ 和 $prew$ 。该部分时间复杂度为 $O(T\log T)$ 。
总时间复杂度为 $O(n\log n+q\log T (\frac{n}{T}+T))$ ,令 $T=\sqrt n$ ,则时间复杂度为 $O(n\log n+q\sqrt n\log \sqrt n)$ 。
#pragma GCC optimize(2)
#include <bits/stdc++.h>
#define get(x) (numa[id[x]]-numb[id[x]])
using namespace std;
typedef long long ll;
const int N = 5e4 + 10, M = 400;
int n, m, q, t, len, now;
int L[M], R[M], add[M], a[N], b[N];
int pos[N * 3], numa[N * 3], numb[N * 3], w[N * 3], id[N * 3];
ll sum[N * 3], prew[N * 3], ans, cur[M];
vector<int> seq;
struct {
int i, x;
} p[N];
inline ll calc(int x) {
if (get(L[x]) + add[x] >= 0)return sum[R[x]] + add[x] * prew[R[x]];
if (get(R[x]) + add[x] < 0) return -(sum[R[x]] + add[x] * prew[R[x]]);
int l = L[x], r = R[x];
while (l < r) {
int mid = (l + r) >> 1;
get(mid) + add[x] >= 0 ? (r = mid) : (l = mid + 1);
}
return sum[R[x]] - (sum[l - 1] << 1) + (prew[R[x]] - (prew[l - 1] << 1)) * add[x];
}
inline void update(int x, int l, int r, int v) {
for (int i = l; i <= r; i++) {
ans -= abs(1ll * (numa[i] - numb[i] + add[x]) * w[i]);
cur[x] -= abs(1ll * (numa[i] - numb[i] + add[x]) * w[i]);
numa[i] += v;
ans += abs(1ll * (numa[i] - numb[i] + add[x]) * w[i]);
cur[x] += abs(1ll * (numa[i] - numb[i] + add[x]) * w[i]);
}
sort(id + L[x], id + R[x] + 1, [&](int i, int j) {
return numa[i] - numb[i] < numa[j] - numb[j];
});
for (int i = L[x]; i <= R[x]; i++) {
sum[i] = (i == L[x] ? 0 : sum[i - 1]) + 1ll * get(i) * w[id[i]];
prew[i] = (i == L[x] ? 0 : prew[i - 1]) + w[id[i]];
}
}
inline void change(int l, int r, int x) {
int bl = pos[l], br = pos[r];
if (bl == br) update(bl, l, r, x);
else {
update(bl, l, R[bl], x), update(br, L[br], r, x);
for (int i = bl + 1; i <= br - 1; i++)
add[i] += x, ans -= cur[i], ans += (cur[i] = calc(i));
}
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++)scanf("%d", &a[i]), seq.push_back(a[i]);
for (int i = 1; i <= n; i++)scanf("%d", &b[i]), seq.push_back(b[i]);
scanf("%d", &q);
for (int i = 1; i <= q; i++)scanf("%d%d", &p[i].i, &p[i].x), seq.push_back(p[i].x);
sort(seq.begin(), seq.end()), seq.erase(unique(seq.begin(), seq.end()), seq.end());
for (int i = 1; i <= n; i++) {
a[i] = lower_bound(seq.begin(), seq.end(), a[i]) - seq.begin() + 1;
b[i] = lower_bound(seq.begin(), seq.end(), b[i]) - seq.begin() + 1;
}
for (int i = 1; i <= q; i++)
p[i].x = lower_bound(seq.begin(), seq.end(), p[i].x) - seq.begin() + 1;
for (int i = 1; i <= n; i++)numa[a[i]]++, numb[b[i]]++;
m = (int) seq.size(), len = (int) sqrt(m);
for (int i = 1; i <= m; i++) {
if (i != m)w[i] = seq[i] - seq[i - 1];
numa[i] = numa[i - 1] + numa[i];
numb[i] = numb[i - 1] + numb[i];
}
while (now + len <= m)L[++t] = now + 1, R[t] = now + len, now += len;
if (now < m)L[++t] = now + 1, R[t] = m;
iota(id + 1, id + 1 + m, 1);
for (int i = 1; i <= t; i++) {
sort(id + L[i], id + R[i] + 1, [&](int x, int y) {
return numa[x] - numb[x] < numa[y] - numb[y];
});
for (int j = L[i]; j <= R[i]; pos[j++] = i) {
sum[j] = (j == L[i] ? 0 : sum[j - 1]) + 1ll * get(j) * w[id[j]];
prew[j] = (j == L[i] ? 0 : prew[j - 1]) + w[id[j]];
cur[i] += abs(1ll * (numa[j] - numb[j]) * w[j]);
}
ans += cur[i];
}
printf("%lld\n", ans);
for (int i = 1; i <= q; i++) {
if (a[p[i].i] > p[i].x)change(p[i].x, a[p[i].i] - 1, 1);
else change(a[p[i].i], p[i].x - 1, -1);
a[p[i].i] = p[i].x;
printf("%lld\n", ans);
}
return 0;
}