品酒大会问题
解题思路
本题对于每个 $r$,都要求出 $r$ 相似的对数,并且求出所有 $r$ 相似的酒能得到的最大美味值。
可以发现本题的 $r$ 相似,其实就是某两个后缀的最长公共前缀 $\geq r$,因此这启发我们可以用后缀数组来做本题,我们知道跑一遍后缀数组之后,我们就能将所有后缀从小到大排序,并且还有一个 $height$ 数组能知道任意两个排名相邻的后缀的最大公共前缀。从而还能得到任意两个后缀的最大公共前缀,就是这两个后缀之间所有排名相邻的后缀的最大公共前缀的最小值。
然后考虑每个问题应该怎么求,对于一个固定的 $r$,假设某一个 $height_i < r$,首先意味着排名第 $i$ 的后缀和排名第 $i-1$ 的后缀的最长公共前缀 $< r$,它们俩一定不可能 $r$ 相似。同时,通过 $height$ 数组的一些性质,我们知道不可能存在一对 $r$ 相似的酒 $x, y$,使得 $x$ 的排名在 $i$ 前面,$y$ 的排名在 $i$ 后面,因为这样两个后缀的最长公共前缀一定是中间所有 $height$ 的最小值,而此时已经存在 $height_i < r$,所以 $x$ 和 $y$ 的最长公共前缀一定不可能 $\geq r$。只有 $x$ 和 $y$ 同时在 $i$ 的一边,他们的最长公共前缀才有可能 $\geq r$,才会 $r$ 相似。因此我们先找出所有的 $height_i < r$ 的位置,将所有后缀从这些位置断开,则所有后缀会被分成若干个区间,此时每个区间内部的 $height$ 一定 $\geq r$。所以每个区间内部任取两个后缀,他们的最长公共前缀都一定 $\geq r$。假设某一个区间的长度为 $s$,那么这个区间内的对数($r$ 相似的对数)就应该是 $C_s^2$,对于所有区间总计一下对数的总和,就是总共的 $r$ 相似的对数。
然后还要求所有 $r$ 相似的数对中,权值乘积的最大值。由于乘积有正有负,乘积一共有三种可能,一种是正数乘正数,此时要想让乘积最大,肯定是取一个最大值,取一个次大值。一种是负数乘负数,此时要想乘积最大,肯定是取一个最小值,取一个次小值。最后一种是正数乘负数,只有当区间内只有两个数时才会存在这种情况,因为如果区间中有三个数,则一定有正数乘正数或负数乘负数,这两种情况的结果都是正数,都比正数乘负数的结果要更优。而只有两个数时,我们也可以把这种情况归类到取一个最大值,取一个次大值。
因此我们要想统计乘积的最大值,只需要去维护每个区间内的最大值、次大值、最小值、次小值,然后取一个乘积的最大值即可,最后再从所有区间的乘积最大值中统计一个总的最大值,就是所有 $r$ 相似的数对的乘积最大值。
可以发现对于每一个固定的 $r$,本题的两个问题都比较好求,但是本题还需要枚举所有 $r$,都去求一遍这两个问题。这里还要考虑如何去枚举,如果我们从小到大枚举所有的 $r$,那么所有后缀会从一开始的一整个区间到后面被逐渐的分成若干个小块,将一整段拆分成若干小段的话,最大值、次大值、最小值、次小值并不能很好的去进行维护。因此我们可以反过来枚举,从大到小的去枚举 $r$,那么最开始 $r=n$ 时就意味着所有后缀都是自己一个区间,然后随着 $r$ 逐渐变小,会开始出现 $height_i \geq r$,此时就意味着 $i$ 和 $i-1$ 之间的分界线消失了,我们就需要将 $i$ 和 $i-1$ 进行合并,如果是这样一个逐渐合并的过程,那么最大值、次大值、最小值、次小值就比较好去维护了,首先对于每个区间要维护一个长度 $cnt$,但需要维护最大值、次大值、最小值、次小值 $mx1, mx2, mn1, mn2$,这样每两个区间合并时后的新区间的信息也比较好去维护。
然后我们需要用一个方法来动态的维护所有区间之间的合并,这里可以用并查集来维护。
后缀数组是预处理的,因此每一个 $r$ 中都需要进行一个并查集的合并操作,这里不写按秩合并的话并查集的时间复杂度是 $O(logn)$ 的,因此整个的时间复杂度就是 $O(nlogn)$
C++ 代码
#include <iostream>
#include <cstring>
#include <vector>
using namespace std;
typedef long long LL;
typedef pair<LL, LL> PLL;
const int N = 300010;
const LL INF = 2e18;
int n, m;
char s[N]; //字符串
int w[N]; //权值
int sa[N], x[N], y[N], c[N], rk[N], height[N]; //后缀数组
int p[N]; //并查集
int cnt[N]; //每个集合的大小
LL mx1[N], mx2[N], mn1[N], mn2[N]; //最大值、次大值、最小值、次小值
vector<int> hs[N]; //hs[i] 存储所有 height 为 i 的后缀
PLL res[N]; //记录答案
void get_sa() //预处理 sa
{
//以首字母为关键字排序
for(int i = 1; i <= n; i++) c[x[i] = s[i]]++;
for(int i = 2; i <= m; i++) c[i] += c[i - 1];
for(int i = n; i >= 1; i--) sa[c[x[i]]--] = i;
//以第 1 ~ k 个字母为第一关键字,以第 k + 1 ~ 2 * k 个字母为第二关键字排序
for(int k = 1; k <= n; k <<= 1)
{
//排序第二关键字
int num = 0;
for(int i = n - k + 1; i <= n; i++) y[++num] = i;
for(int i = 1; i <= n; i++)
if(sa[i] > k)
y[++num] = sa[i] - k;
//排序第一关键字
for(int i = 1; i <= m; i++) c[i] = 0;
for(int i = 1; i <= n; i++) c[x[i]]++;
for(int i = 2; i <= m; i++) c[i] += c[i - 1];
for(int i = n; i >= 1; i--) sa[c[x[y[i]]]--] = y[i], y[i] = 0;
//离散化前 2 * k 字母作为下一轮的第一关键字
swap(x, y);
x[sa[1]] = 1, num = 1;
for(int i = 2; i <= n; i++)
x[sa[i]] = (y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) ? num : ++num;
if(num == n) break;
m = num;
}
}
void get_height() //预处理 height
{
for(int i = 1; i <= n; i++) rk[sa[i]] = i;
for(int i = 1, k = 0; i <= n; i++)
{
if(rk[i] == 1) continue;
if(k) k--;
int j = sa[rk[i] - 1];
while(i + k <= n && j + k <= n && s[i + k] == s[j + k]) k++;
height[rk[i]] = k;
}
}
int find(int x) //返回 x 所在的集合
{
if(p[x] != x) p[x] = find(p[x]);
return p[x];
}
LL get(int x) //C(x, 2)
{
return x * (x - 1ll) / 2;
}
PLL calc(int r)
{
static LL sz = 0, maxv = -INF; //记录当前的区间个数、权值最大值
for(int i = 0; i < hs[r].size(); i++) //枚举所有会消失的分界点
{
int j = hs[r][i]; //接下来需要合并 j 和 j - 1
int a = find(j - 1), b = find(j);
sz -= get(cnt[a]) + get(cnt[b]); //减去两个区间内部的所有数对
p[a] = b; //a 和 b 合并
cnt[b] += cnt[a];
sz += get(cnt[b]); //加上新区见内部的所有数对
if(mx1[a] >= mx1[b]) //更新最大值、次大值
{
mx2[b] = max(mx1[b], mx2[a]);
mx1[b] = mx1[a];
}
else if(mx1[a] > mx2[b]) mx2[b] = mx1[a];
if(mn1[a] <= mn1[b]) //更新最小值、次小值
{
mn2[b] = min(mn1[b], mn2[a]);
mn1[b] = mn1[a];
}
else if(mn1[a] < mn2[b]) mn2[b] = mn1[a];
maxv = max(maxv, max(mx1[b] * mx2[b], mn1[b] * mn2[b]));
}
if(maxv == -INF) return {sz, 0}; //maxv == -INF 说明还没有合并过
return {sz, maxv};
}
int main()
{
scanf("%d", &n), m = 122;
scanf("%s", s + 1);
for(int i = 1; i <= n; i++) scanf("%d", &w[i]);
get_sa(); //预处理 sa
get_height(); //预处理 height
for(int i = 2; i <= n; i++) hs[height[i]].push_back(i); //预处理 hs
//初始化并查集
for(int i = 1; i <= n; i++)
{
p[i] = i, cnt[i] = 1;
mx1[i] = mn1[i] = w[sa[i]];
mx2[i] = -INF, mn2[i] = INF;
}
for(int i = n - 1; i >= 0; i--) res[i] = calc(i); //计算 r = i 时的答案
for(int i = 0; i < n; i++) printf("%lld %lld\n", res[i].first, res[i].second);
return 0;
}