写了两天,共 7k 多字,点个赞不过分,真的。
以下代码纯手打,累死我了,如果有错误一定要告诉我,不然我拿模板去打比赛就 GG 了。
这周的算法学习还不错,按照惯例,先写几道看起来比较常见,但是自己曾经没有很顺利做出来的题目。另外总结一下树状数组的常用方法以及根号分治这种巧妙的思想。
D - Static Sushi
题目链接: D - Static Sushi
题意:
一个圆形寿司转盘,按次序放着寿司,一个人从原点开始可以沿着圆盘走动,给出每个寿司的营养值,这个人每走动一个单位就会消耗一个单位的营养,可以随时停止这个过程,问这个人最多可以得到多少营养。
解题思路:
去年第一次做这道题的时候被搞晕了,因为把问题想得太复杂了,的确,如果把边界情况、细节都要抠得很清楚的话,还是需要一点时间的。今年重做了一遍,几分钟就秒出来了,其实遇到这种问题,考虑能不能把问题简化一下,在一个更大的集合上求答案,保证答案一定在这个集合,而且最优解不会被干扰项破坏即可。
实现过程就是分别在顺时针和逆时针方向上求出每个点到原点的距离 $x$ 和 $rx$ ,然后在每个点上求顺时针最多该点以及逆时针最多走到该点,最多能够得到多少营养。
最后枚举顺时针的终点,用 $mx[i]+rmx[i+1]-x[i]$ 和 $mx[i]+rmx[i+1]-rx[i+1]$ 更新答案即可。
这样做是正确的,就是因为最优解肯定能够在这个过程中被求出来,而且这个集合中出现的不合法干扰项一定不会优于最优解。
代码如下:
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 100010;
int n;
LL c, v[N];
LL x[N], rx[N]; // 顺时针和逆时针下原点到 i 点的距离
LL s[N], rs[N]; // 顺时针和逆时针下走到 i 点可以得到的营养
LL mx[N], rmx[N]; // 顺时针和逆时针下最多走到 i 点可以得到最大的营养
int main()
{
cin >> n >> c;
for (int i = 1; i <= n; i ++ ) cin >> x[i] >> v[i];
for (int i = 1; i <= n; i ++ ) s[i] = s[i - 1] + v[i];
for (int i = 1; i <= n; i ++ ) mx[i] = max(mx[i - 1], s[i] - x[i]);
for (int i = n; i; i -- ) rx[i] = c - x[i];
for (int i = n; i; i -- ) rs[i] = rs[i + 1] + v[i];
for (int i = n; i; i -- ) rmx[i] = max(rmx[i + 1], rs[i] - rx[i]);
LL res = 0;
for (int i = 0; i <= n; i ++ )
{
res = max(res, mx[i] + rmx[i + 1] - x[i]);
res = max(res, mx[i] + rmx[i + 1] - rx[i + 1]);
}
cout << res << '\n';
return 0;
}
F - Range Set Query
题目链接: F - Range Set Query
题意:
给定一个序列,多次询问某个区间有多少种数。
解题思路:
首先对于每个询问 $[l,r]$ 来说,中间可能会有重复颜色的情况,如果去重呢?可以考虑只考虑最右边的颜色坐标,也就是说对于某种颜色,它最右边的坐标位置为 $1$ 。然后先把所有询问存下来,按照位置次序维护颜色位置,对于当前位置,处理所有以当前位置为右边界的询问,那么对于每次询问,只需要求一下 $[l,r]$ 的区间和就行了,求区间和可以用树状数组来处理。按照这样的顺序处理询问,是不会对右边界在后面的询问产生影响的。
代码如下:
#include <iostream>
#include <algorithm>
using namespace std;
typedef pair<int, int> PII;
const int N = 500010;
int n, q;
int c[N];
int tr[N];
inline int lowbit(int x)
{
return (x & (-x));
}
inline void add(int x, int c)
{
for (int i = x; i <= n; i += lowbit(i)) tr[i] += c;
}
inline int get(int x)
{
int res = 0;
for (int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}
inline int get(int l, int r)
{
return get(r) - get(l - 1);
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
cin >> n >> q;
for (int i = 1; i <= n; i ++ ) cin >> c[i];
vector<vector<PII>> query(n + 1);
for (int i = 0; i < q; i ++ )
{
int l, r;
cin >> l >> r;
query[r].emplace_back(l, i);
}
vector<int> res(q), pos(n + 1, -1);
for (int i = 1; i <= n; i ++ )
{
// 如果前面已经出现过 c[i] ,则需要将原来的位置清除
if (pos[c[i]] != -1) add(pos[c[i]], -1);
add(i, 1);
pos[c[i]] = i;
int r = i;
for (auto &[l, id] : query[r]) res[id] = get(l, r);
}
for (auto &u : res) cout << u << '\n';
return 0;
}
C - Multiple Sequences
题目链接: C - Multiple Sequences
题意:
给定 $n$ 和 $m$ ,问有多少个长度为 $n$ 的序列满足:
- 序列所有数都不超过 $m$ ;
- 序列每个数都是它前一个数倍数(假如有的话)。
解题思路:
此题的思路非常精巧,如果是 $O(n^2)$ 的做法,那么非常容易。
我们从倍数这个条件入手,考虑一个合法序列中,最多有多少种数,粗略计算不会超过 $20$ 种,因为 $2^{20}$ 已经超过 $m$ 了。
因此可以求一下以某个数结尾的序列个数,定义 $f[i,j]$ 为以 $i$ 结尾,序列种有 $j$ 种不同数的选数方案,属性是数量。
求出每个 $f[i,j]$ 以后,用组合数来求一下对于每个方案,它一共可以在 $n-1$ 个位置上发生变化,总共需要变化 $j-1$ 次。
代码如下:
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 200010, M = 20, mod = 998244353;
int n, m;
int f[N][M];
int fact[N], infact[N];
int qp(int a, int b)
{
int res = 1;
while (b)
{
if (b & 1) res = (LL)res * a % mod;
a = (LL)a * a % mod;
b >>= 1;
}
return res;
}
void init()
{
fact[0] = infact[0] = 1;
for (int i = 1; i < N; i ++ )
{
fact[i] = (LL)fact[i - 1] * i % mod;
infact[i] = (LL)infact[i - 1] * qp(i, mod - 2) % mod;
}
}
int C(int a, int b)
{
if (a < b) return 0;
return (LL)fact[a] * infact[b] % mod * infact[a - b] % mod;
}
int main()
{
init();
cin >> n >> m;
for (int i = 1; i <= m; i ++ )
{
f[i][1] = 1;
for (int j = 1; j < M - 1; j ++ )
if (f[i][j])
for (int k = i + i; k <= m; k += i)
f[k][j + 1] = (f[k][j + 1] + f[i][j]) % mod;
}
int res = 0;
for (int i = 1; i <= m; i ++ )
for (int j = 1; j < M; j ++ )
res = (res + (LL)f[i][j] * C(n - 1, j - 1) % mod) % mod;
cout << res << '\n';
return 0;
}
另外,在此记录一种初始化阶乘以及阶乘逆元的方法,不用快速幂,而且时间复杂度会优化一些。
void init()
{
fact[0] = infact[0] = 1;
fact[1] = infact[1] = 1;
for (int i = 2; i < N; i ++ )
{
fact[i] = (LL)fact[i - 1] * i % mod;
infact[i] = (LL)(mod - mod / i) * infact[mod % i] % mod;
}
for (int i = 2; i < N; i ++ ) infact[i] = (LL)infact[i] * infact[i - 1] % mod;
}
P1990 覆盖墙壁
题目链接: P1990 覆盖墙壁
题意:
给定两种形状的瓷砖,在 $2\times n$ 大小的地板上铺满,有多少种方案。
解题思路:
这是一道很裸但是模型非常经典的题目,
这类题目一般都可以按照一维表示列,其它维度用二进制表示状态的方式定义 DP 数组。而且比较常见的是 $f[i,j]$ 表示第 $i-1$ 列以及之前都已经放满,第 $i$ 列状态是 $j$ ,第 $i+1$ 列以及后面都为空的所有方案,属性是数量。
这样的定义就使得转移非常简单了,画一下图就能推出来。对于这道题而言,我用 $0$ 表示列为空, $1$ 表示只占用了上面的格子, $2$ 表示只占用了下面的格子, $3$ 表示两个格子都被占用。
写法一代码:
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 1000010, mod = 10000;
int n;
int f[N][4];
int main()
{
cin >> n;
f[0][3] = 1;
for (int i = 1; i <= n; i ++ )
{
f[i][0] = (f[i][0] + f[i - 1][3]) % mod;
f[i][1] = (f[i][1] + f[i - 1][0] + f[i - 1][2]) % mod;
f[i][2] = (f[i][2] + f[i - 1][0] + f[i - 1][1]) % mod;
f[i][3] = (f[i][3] + f[i - 1][0] + f[i - 1][1] + f[i - 1][2] + f[i - 1][3]) % mod;
}
cout << f[n][3] << '\n';
return 0;
}
写法二代码:
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 1000010, mod = 10000;
int n;
int f[N][4];
int main()
{
cin >> n;
f[0][3] = 1;
for (int i = 0; i < n; i ++ )
{
f[i + 1][0] = (f[i + 1][0] + f[i][3]) % mod;
f[i + 1][1] = (f[i + 1][1] + f[i][0] + f[i][2]) % mod;
f[i + 1][2] = (f[i + 1][2] + f[i][0] + f[i][1]) % mod;
f[i + 1][3] = (f[i + 1][3] + f[i][0] + f[i][1] + f[i][2] + f[i][3]) % mod;
}
cout << f[n][3] << '\n';
return 0;
}
树状数组
这部分是对树状数组的基本、进阶用法以及一些使用技巧做一个总结。树状数组写起来非常简单,但是却常常应用到,可以用来动态维护一个数组的前缀和、求逆序对、求区间最值、求数组第 $k$ 大/小值等。
部分内容参考了 OI WiKi 。
因为树状数组一般都是处理多次询问的问题,输入输出量较大,推荐都通过关闭同步流来加速,必要时把函数都加上 inline
修饰。
一维前缀和:单点修改、区间查询
// tr[] 维护的是 a[] 的前缀和
int n;
LL a[N];
LL tr[N];
#define lowbit(x) (x & -x)
void add(int x, LL c)
{
for (int i = x; i <= n; i += lowbit(i)) tr[i] += c;
}
LL get(int x)
{
LL res = 0;
for (int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}
LL get(int l, int r)
{
return get(r) - get(l - 1);
}
void build(int n)
{
for (int i = 1; i <= n; i ++ ) add(i, a[i]);
}
一维前缀和:区间修改、单点查询
// tr[] 维护的是 a[] 的差分数组的前缀和
// 对 i 点增加 x ,要 add(i, x), add(i+1,-x)
int n;
LL a[N];
LL tr[N];
#define lowbit(x) (x & -x)
void add(int x, LL c)
{
for (int i = x; i <= n; i += lowbit(i)) tr[i] += c;
}
void add(int l, int r, LL c)
{
add(l, c), add(r + 1, -c);
}
LL get(int x)
{
LL res = 0;
for (int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}
// 初始化
void build(int n)
{
for (int i = 1; i <= n; i ++ ) add(i, i, a[i]);
}
一维前缀和:区间修改、区间查询
// tr1[] 维护 a[] 的差分数组 b[] 的前缀和
// tr2[] 维护 b[i] * i 的前缀和
int n;
LL a[N];
LL tr1[N], tr2[N];
#define lowbit(x) (x & -x)
void add(LL tr[], int x, LL c)
{
for (int i = x; i <= n; i += lowbit(i)) tr[i] += c;
}
void add(int l, int r, LL c)
{
add(tr1, l, c), add(tr2, l, l * c);
add(tr1, r + 1, -c), add(tr2, r + 1, (r + 1) * -c);
}
LL get(LL tr[], int x)
{
LL res = 0;
for (int i = x; i; i -= lowbit(i)) res += tr[i];
return res;
}
LL get(int l, int r)
{
return get(tr1, r) * (r + 1) - get(tr2, r) - get(tr1, l - 1) * l + get(tr2, l - 1);
}
// 初始化
void build(int n)
{
for (int i = 1; i <= n; i ++ ) add(i, i, a[i]);
}
二维前缀和:单点修改、区间查询
// 原始数组大小为 n*m
int n, m;
LL a[N][N];
LL tr[N][N];
#define lowbit(x) (x & -x)
void add(int x, int y, LL c)
{
for (int i = x; i <= n; i += lowbit(i))
for (int j = y; j <= m; j += lowbit(j))
tr[i][j] += c;
}
LL get(int x, int y)
{
LL res = 0;
for (int i = x; i; i -= lowbit(i))
for (int j = y; j; j -= lowbit(j))
res += tr[i][j];
return res;
}
LL get(int x1, int y1, int x2, int y2)
{
return get(x2, y2) - get(x2, y1 - 1) - get(x1 - 1, y2) + get(x1 - 1, y1 - 1);
}
void build(int n, int m)
{
for (int i = 1; i <= n; i ++ )
for (int j = 1; j <= m; j ++ )
add(i, j, a[i][j]);
}
二维前缀和:区间修改、单点查询
// 原始数组大小为 n*m
int n, m;
LL a[N][N];
LL tr[N][N];
#define lowbit(x) (x & -x)
void add(int x, int y, LL c)
{
for (int i = x; i <= n; i += lowbit(i))
for (int j = y; j <= m; j += lowbit(j))
tr[i][j] += c;
}
void add(int x1, int y1, int x2, int y2, LL c)
{
add(x1, y1, c);
add(x1, y2 + 1, -c);
add(x2 + 1, y1, -c);
add(x2 + 1, y2 + 1, c);
}
LL get(int x, int y)
{
LL res = 0;
for (int i = x; i; i -= lowbit(i))
for (int j = y; j; j -= lowbit(j))
res += tr[i][j];
return res;
}
void build(int n, int m)
{
for (int i = 1; i <= n; i ++ )
for (int j = 1; j <= m; j ++ )
add(i, j, i, j, a[i]);
}
二维前缀和:区间修改、区间查询
// 原始数组大小为 n*m
// 推导的思路和一维相似,但有些复杂
int n, m;
LL a[N][N];
LL tr1[N][N], tr2[N][N], tr3[N][N], tr4[N][N];
#define lowbit(x) (x & -x)
void add(int x, int y, LL c)
{
for (int i = x; i <= n; i += lowbit(i))
for (int j = y; j <= m; j += lowbit(j))
{
tr1[i][j] += c;
tr2[i][j] += c * x;
tr3[i][j] += c * y;
tr4[i][j] += c * x * y;
}
}
void add(int x1, int y1, int x2, int y2, LL c)
{
add(x1, y1, c);
add(x1, y2 + 1, -c);
add(x2 + 1, y1, -c);
add(x2 + 1, y2 + 1, c);
}
LL get(int x, int y)
{
LL res = 0;
for (int i = x; i; i -= lowbit(i))
for (int j = y; j; j -= lowbit(j))
{
res += (x + 1) * (y + 1) * tr1[i][j];
res -= (y + 1) * tr2[i][j];
res -= (x + 1) * tr3[i][j];
res += tr4[i][j];
}
return res;
}
LL get(int x1, int y1, int x2, int y2)
{
return get(x2, y2) - get(x1 - 1, y2) - get(x2, y1 - 1) + get(x1 - 1, y1 - 1);
}
void build(int n, int m)
{
for (int i = 1; i <= n; i ++ )
for (int j = 1; j <= m; j ++ )
add(i, j, i, j, a[i]);
}
$O(n)$ 建树
// 可以在建树阶段稍作优化
void build()
{
for (int i = 1; i <= n; ++i)
{
tr[i] += a[i];
int j = i + lowbit(i);
if (j <= n) tr[j] += tr[i];
}
}
时间戳优化
对于多组测试数据,如果每次都暴力清空树状数组,很可能会超时,这时候可以为树状数组每个位置加入时间戳,这样就不必进行清空操作,只要时间戳和当前不同,就在原来的数组覆盖新的数据,再进行接下来的操作。
int mark[N], tr[N], timestamp;
// 处理每组数据前先更新时间戳
void reset()
{
timestamp ++ ;
}
void add(int x, int c)
{
for (int i = x; i <= n; i += lowbit(i))
{
if (mark[i] != timestamp) tr[i] = 0;
mark[i] = timestamp;
tr[i] += c;
}
}
int get(int x)
{
int res = 0;
for (int i = x; i; i -= lowbit(i))
if (mark[x] == timestamp)
res += tr[x];
return res;
}
一维求区间最值
树状数组维护区间最值用得会少一些,它可以支持 静态区间查询 或者 单点修改+区间查询 ,一般静态可以用 ST 表,动态可以用线段树,但其实也挺好用的,它会比 ST 表更加省空间,每次询问的时间复杂度是 $O(logn)$ ,会比线段树容易写。
这里有几道可以验证代码的模板题:
最大值:
// 最大值模板
// 其中 a[] 是原始数组, n 为原始数组大小
// 且原始数组的下标从 1 开始
// 若不需要动态修改,可以直接 build() + query()
// 若需要动态修改,则用 update() + query()
#define lowbit(x) (x & -x)
void build(int n)
{
for (int i = 1; i <= n; i ++ )
{
tr[i] = a[i];
for (int j = 1; j < lowbit(i); j <<= 1)
tr[i] = max(tr[i], tr[i - j]);
}
}
void update(int x)
{
for (int i = x; i <= n; i += lowbit(i))
{
tr[i] = a[i];
int len = lowbit(i);
for (int j = 1; j < len; j <<= 1)
tr[i] = max(tr[i], tr[i - j]);
}
}
// 求 [l,r] 最大值
int query(int l, int r)
{
int res = -INF;
while (true)
{
res = max(res, a[r]);
if (l == r) break;
for (r -- ; r - lowbit(r) >= l; r -= lowbit(r))
res = max(res, tr[r]);
}
return res;
}
最小值:
// 最小值模板
#define lowbit(x) (x & -x)
void build(int n)
{
for (int i = 1; i <= n; i ++ )
{
tr[i] = a[i];
for (int j = 1; j < lowbit(i); j <<= 1)
tr[i] = min(tr[i], tr[i - j]);
}
}
void update(int x)
{
for (int i = x; i <= n; i += lowbit(i))
{
tr[i] = a[i];
int len = lowbit(i);
for (int j = 1; j < len; j <<= 1)
tr[i] = min(tr[i], tr[i - j]);
}
}
// 求 [l,r] 最小值
int query(int l, int r)
{
int res = INF;
while (true)
{
res = min(res, a[r]);
if (l == r) break;
for (r -- ; r - lowbit(r) >= l; r -= lowbit(r))
res = min(res, tr[r]);
}
return res;
}
求第 $k$ 小值
// 建立权值树状数组,统计每个元素出现的次数
// MAX_VAL 为可能出现的最大值
int find_k_min(int k)
{
int cnt = 0, x = 0;
for (int i = log2(MAX_VAL); i >= 0; i -- )
{
x += (1 << i);
if (x >= MAX_VAL || cnt + tr[x] >= k) x -= (1 << i);
else cnt += tr[x];
}
return x + 1;
}
// 用法(输出第 k 小值)
cout << find_k_min(k) << endl;
求第 $k$ 大值
// 建立权值树状数组,统计每个元素出现的次数
// MAX_VAL 为可能出现的最大值
int find_k_max(int k)
{
// 第 k 大即第 n-k+1 小, n 为序列元素个数
k = n - k + 1;
int cnt = 0, x = 0;
for (int i = log2(MAX_VAL); i >= 0; i -- )
{
x += (1 << i);
if (x >= MAX_VAL || cnt + tr[x] >= k) x -= (1 << i);
else cnt += tr[x];
}
return x + 1;
}
// 用法(输出第 k 大值)
cout << find_k_max(k) << endl;
下标从 $0$ 开始
只需要在处理时先将坐标加 $1$ 。
四元环问题
先看一道题吧,一道思路不难,但实现起来非常麻烦,时间和空间都很紧张,复杂度证明也有一定的难度,因此我认为这道题可以作为根号分治的代表题目之一。
M. Similar Sets
题目链接: M. Similar Sets
题意:
给定 $n$ 个集合,每个集合的元素两两不同,问是否存在两个集合有至少两个元素是相同的。
解题思路:
暴力做肯定不行,好像也想不到什么数据结构能够转化成 $O(nlogn)$ 的复杂度处理,那么就看看 $O(n\sqrt n)$ 是否可行吧。
首先每个元素的范围很大,我们只需要考虑是否有两个元素相同,而不关心它们具体是什么,只要不破坏它们的大小关系以及个数即可,因此先把元素离散化,设离散化之后的最大值为 $mx$ ,元素个数为 $tot$ , $sq=\sqrt {tot}$ 。
然后把所有集合分成两类,一类是大小大于或等于 $sq$ ,一类大小小于 $sq$ ,两个作为最终答案的集合,有可能是 一个大集合与其它任意一个集合 ,也有可能是 两个小集合 ,可以发现这两种情况已经囊括了两个集合形成的所有组合。
首先处理第一种情况,一个大集合与其它任意一个集合是否存在至少两个元素相等,可以发现一共不会超过 $\sqrt n$ 个大集合,对于每个大集合,找其它集合是否和它存在两个元素相等,可以直接暴力 $O(tot)$ 处理,总的复杂度是 $O(tot\sqrt{n})$ ,非常好。
然后看第二种情况,最坏情况下,会有 $\sqrt{tot}$ 个小集合,每个小集合有 $\sqrt{tot}$ 个元素,如果每个集合是有序的,将小集合内的所有元素组成有序对,如果存在两个小集合至少有两个元素相等,意味着会有重复的有序对出现。那么有序对的个数,最多也是 $O(tot\sqrt{tot})$ ,只需要判断里面是否有重复的元素,问题就变得容易起来了。
但是,这道题的时间空间都卡得有点紧,并不适合用排序、哈希表之类的东西来判重(也有可能是我的代码还可以优化,至少这份代码是不支持的),所以还需要用个小技巧来判重。
我们按照有序对的第一个元素来分类,将第一个元素为 $i$ 的有序对的第二个元素加入到 $f[i]$ 中,同时存下这对有序对是来自哪个集合的,用于判重后获取答案。最后枚举每个 $f[i]$ ,每次枚举相当于是在所有第一个元素为 $i$ 的有序对中找,是否存在两个有序对的第二个元素是一样的,具体实现可以参考下面的代码。
代码如下:
#include <map>
#include <cmath>
#include <vector>
#include <numeric>
#include <iostream>
#include <algorithm>
using namespace std;
typedef pair<int, int> PII;
int main()
{
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
int T;
cin >> T;
while (T -- )
{
int n;
cin >> n;
vector<int> sz(n);
vector<vector<int>> g(n);
// 用 map 做离散化
int mx = 0;
map<int, int> mp;
for (int i = 0; i < n; i ++ )
{
cin >> sz[i];
g[i].resize(sz[i]);
for (auto &u : g[i])
{
cin >> u;
if (!mp.count(u)) mp[u] = mx ++ ;
u = mp[u];
}
sort(g[i].begin(), g[i].end());
}
int tot = accumulate(sz.begin(), sz.end(), 0);
int sq = sqrt(tot);
bool flag = false;
int x = -1, y = -1;
// 大集合与其它集合
for (int i = 0; i < n && !flag; i ++ )
if (sz[i] >= sq)
{
// st[i]=true 表示 i 出现过
vector<bool> st(mx, false);
for (auto &u : g[i]) st[u] = true;
for (int j = 0; j < n && !flag; j ++ )
{
if (i == j) continue;
int cnt = 0;
for (auto &u : g[j])
{
cnt += (st[u] == true);
if (cnt >= 2)
{
x = i + 1, y = j + 1;
flag = true;
break;
}
}
}
}
if (flag)
{
cout << x << ' ' << y << '\n';
continue;
}
// 小集合与小集合
vector<vector<PII>> f(mx);
for (int i = 0; i < n; i ++ )
if (sz[i] < sq)
for (int j = 0; j < sz[i]; j ++ )
for (int k = 0; k < j; k ++ )
f[g[i][j]].emplace_back(g[i][k], i);
// 用于保存来自哪个集合,同时判重
vector<int> pos(mx, -1);
for (int i = 0; i < mx && !flag; i ++ )
{
for (auto &[u, v] : f[i])
// 如果是第一次找到 u ,则存下
if (pos[u] == -1) pos[u] = v;
else
{
flag = true;
x = pos[u], y = v;
break;
}
// 记得这里要清空
for (auto &[u, v] : f[i]) pos[u] = -1;
}
if (flag) cout << x + 1 << ' ' << y + 1 << '\n';
else cout << "-1\n";
}
return 0;
}
根号分治
上面这道题,应该有很大的启发性,利用集合的大小进行分类,将问题分成两个部分来解决,就能把一个 $O(n^2)$ 的问题转化成 $O(n\sqrt n)$ 了,其实很多分治问题都是这个复杂度,它们都可以被称为根号分治,与上面这道题比较有关系的一类问题叫做四元环问题,还有它的简化版,即三元环问题。
参考资料:三、四元环计数
首先看看 三元环计数问题 ,就是说在一个无向图上,统计三个点形成的环有多少个,我们基于原图,重新建图,对于原图中所有的边连接的两个点,度数小的指向度数大的,如果度数相同,则将编号小的指向编号大的。
然后可以发现,如果原图中有环,这个环在新图上一定是这样的: $(a,b),(a,c),(b,c)$ 。可以用反证法来证明,非常简单。然后我们只需要在新图上找这样的环即可,首先枚举 $a$ ,再枚举它的出边,得到 $b$ ,然后枚举 $b$ 的出边,检查这个点是否能够从 $a$ 直接走到,这里可以开一个数组 $from[]$ ,每次枚举到一个 $a$ 时,先把 $a$ 能走到的所有点都设置为 $a$ ,如果 $b$ 能走到的点 $x$ ,是可以来自 $a$ ,那么就找到了这样的一个环了。
分析一下复杂度,枚举 $a$ ,以及它的所有边,复杂度是 $O(n+m)$ 的,如果 $b$ 的出度大于 $\sqrt m$ ,由于它只会连到出度不小于它的点,这样的点最多只有 $\sqrt m$ 个,因此总的复杂度不会超过 $O(m\sqrt m)$ 。
另外还可以顺便统计每个点在多少个三元环中。
三元环计数问题代码如下:
#include <vector>
#include <iostream>
using namespace std;
int main()
{
int n, m;
cin >> n >> m;
vector<vector<int>> g(n);
// 统计度数
vector<int> d(n);
for (int i = 0; i < m; i ++ )
{
int a, b;
cin >> a >> b;
a -- , b -- ;
g[a].emplace_back(b);
g[b].emplace_back(a);
d[a] ++ , d[b] ++ ;
}
vector<vector<int>> tg(n);
// 两个点,度数小的连上度数大的
// 度数相等,则编号小的连上编号大的
for (int i = 0; i < n; i ++ )
for (auto &j : g[i])
if (d[i] < d[j] || d[i] == d[j] && i < j)
tg[i].emplace_back(j);
// 方便编程,交换一下,只用 g[][]
g.swap(tg);
// cnt[i] 表示 i 在多少个三元环里
vector<int> from(n, -1), cnt(n);
int res = 0;
for (int i = 0; i < n; i ++ )
{
for (auto &u : g[i]) from[u] = i;
for (auto &u : g[i])
for (auto &v : g[u])
if (from[v] == i)
{
cnt[i] ++ ;
cnt[u] ++ ;
cnt[v] ++ ;
res ++ ;
}
}
cout << res << '\n';
for (auto &u : cnt) cout << u << ' ';
cout << '\n';
return 0;
}
再来看看四元环问题,问题基本同上,只是要找四个点形成的环。
这里会发现,如果按照三元环的方法建图,成四元环会有两种情况:
- $(a,b),(b,c),(c,d),(a,d)$
- $(a,b),(b,c),(a,d),(d,c)$
那么就要转换一下枚举的方式了,会发现,这两种情况都可以转化为 一条原图的边+一条新图的边形成的环 ,也就是:
- $(a,b),(b,c),(c,d),(a,d)$ 变成 $(a,b)+*a,d)$ 以及 $(b,c)+(c,d)$ ,其中加号左边为原图的无向边,加号右边为新图的有向边;
- $(a,b),(b,c),(a,d),(d,c)$ 变成 $(a,b)+(b,c)$ 以及 $(a,d)+(d,c)$ ,其中加号左边为原图的无向边,加号右边为新图的有向边;
首先用一个数组 $w[]$ 表示从 $a$ 点出发,得到的 $(a,b)+(b,c)$ 的数量,其中 $(a,b)$ 为无向边, $(b,c)$ 为有向边,在 原图 上枚举 $a$ ,枚举所有可以到达的点 $b$ ,再在 新图 从 $b$ 枚举 $c$ ,累加 $cnt[c]$ 即可。
时间复杂度分析和三元环是完全一样的。
也可以像三元环一样,统计每个点所在的四元环个数。
四元环计数问题代码如下:
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
int main()
{
int n, m;
cin >> n >> m;
vector<vector<int>> g(n);
vector<int> d(n);
// 建立原图
for (int i = 0; i < m; i ++ )
{
int a, b;
cin >> a >> b;
a -- , b -- ;
g[a].emplace_back(b);
g[b].emplace_back(a);
d[a] ++ , d[b] ++ ;
}
// 在自定的规则下排序
vector<int> order(n);
for (int i = 0; i < n; i ++ ) order[i] = i;
sort(order.begin(), order.end(), [&](int a, int b)
{
return d[a] < d[b] || d[a] == d[b] && a < b;
});
// 得到每个点的排名
vector<int> rank(n);
for (int i = 0; i < n; i ++ ) rank[order[i]] = i;
// 建立新图
vector<vector<int>> f(n);
for (int i = 0; i < n; i ++ )
for (auto &u : g[i])
if (rank[i] < rank[u])
f[i].emplace_back(u);
int res = 0;
// cnt[i] 表示在多少个四元环里
vector<int> w(n), cnt(n);
for (int i = 0; i < n; i ++ )
{
for (auto &u : g[i])
for (auto &v : f[u])
if (rank[i] < rank[v])
{
cnt[i] += w[v];
cnt[u] += w[v];
cnt[v] += w[v];
res += w[v];
w[v] ++ ;
}
// 要清空 w[]
for (auto &u : g[i])
for (auto &v : f[u])
if (rank[i] < rank[v])
{
w[v] -- ;
cnt[u] += w[v];
}
}
cout << res << '\n';
for (auto &u : cnt)
cout << u << ' ';
cout << '\n';
return 0;
}
用四元环的方法解决一开始的题目
思路是将集合本身也看作一个点,有 $n$ 个集合, $m$ 个元素,那么图中的点就有 $ps=n+m$ 个,简图方式是将集合与它的包含的所有元素连边,如果两个集合中有至少两个元素相同,那么就会形成一个四元环。
其它方面和上面的求四元环差不多,但最后找四元环的时候要注意,四个点哪两个才是集合编号。
代码如下:
#include <map>
#include <vector>
#include <iostream>
#include <algorithm>
using namespace std;
int main()
{
int T;
cin >> T;
while (T -- )
{
int n;
cin >> n;
vector<int> sz(n);
vector<vector<int>> a(n);
// ps 为图中点的个数
int ps = n;
map<int, int> mp;
for (int i = 0; i < n; i ++ )
{
cin >> sz[i];
a[i].resize(sz[i]);
for (auto &u : a[i])
{
cin >> u;
if (!mp.count(u)) mp[u] = ps ++ ;
u = mp[u];
}
}
vector<int> d(ps);
vector<vector<int>> g(ps);
for (int i = 0; i < n; i ++ )
for (auto &u : a[i])
{
g[i].emplace_back(u);
g[u].emplace_back(i);
d[i] ++ ;
d[u] ++ ;
}
vector<int> order(ps);
for (int i = 0; i < ps; i ++ ) order[i] = i;
sort(order.begin(), order.end(), [&](int a, int b)
{
return d[a] < d[b] || d[a] == d[b] && a < b;
});
vector<int> rank(ps);
for (int i = 0; i < ps; i ++ ) rank[order[i]] = i;
vector<vector<int>> f(ps);
for (int i = 0; i < ps; i ++ )
for (auto &u : g[i])
if (rank[i] < rank[u])
f[i].emplace_back(u);
int x = -1, y = -1;
vector<int> w(ps), p(ps, -1);
for (int i = 0; i < ps; i ++ )
{
for (auto &u : g[i])
{
for (auto &v : f[u])
if (rank[i] < rank[v])
{
if (p[v] != -1 && x == -1)
{
vector<int> t{i, p[v], u, v};
for (auto &k : t)
if (k < n)
{
if (x == -1) x = k;
else y = k;
}
}
p[v] = u;
}
}
for (auto &u : g[i])
for (auto &v : f[u])
if (rank[i] < rank[v])
{
w[v] -- ;
p[v] = -1;
}
}
if (x != -1 && y != -1) cout << x + 1 << ' ' << y + 1 << '\n';
else cout << "-1\n";
}
return 0;
}