在这里总结一下自己的模板,hah~
这里总结的是自己容易理解的、写起来顺手的模板,不一定是大众模板。其实,不管是什么板子,只要看到一道题就能确定它属于哪个板子,然后能快速写出来即可。
快速排序
这个模板需要注意的是:
- 在取 $x$ 的时候,我这里使用了 $x = a[l + ((r - l) >> 1)]$ 的方式,这样可以避免因 $l + r$ 造成的溢出现象;
- 此外,$quick\_sort$ 函数中最外层循环,需要判断的是 $i$ 与 $j$ 的大小,而不是 $l$ 和 $r$ 的大小;
- 该模板属于
二路快排
,即将小于等于 $x$ 的放在数组的左边,将大于 $x$ 放在数组的右边; - 而对于
三路快排
来说,它实现的结果是将小于 $x$ 的放在数组的左边,将大于 $x$ 放在数组的右边,同时将等于 $x$ 的放在数组的中间。
#include <iostream>
using namespace std;
const int N = 1000010;
int a[N];
int n;
void quick_sort(int a[], int l, int r) {
if (l >= r) return;
int x = a[l + ((r - l) >> 1)], i = l - 1, j = r + 1;
while (i < j) {
do i++; while(a[i] < x);
do j--; while(a[j] > x);
if (i < j) swap(a[i], a[j]);
}
quick_sort(a, l, j);
quick_sort(a, j + 1, r);
}
int main() {
scanf("%d", &n);
for (int i = 0; i < n; ++i) scanf("%d", &a[i]);
quick_sort(a, 0, n - 1);
for (int i = 0; i < n; ++i) printf("%d ", a[i]);
return 0;
}
归并排序
该模板需要注意的是:
- 由于归并排序的特性,需要开一个新的数组,用来存放归并操作后的元素;
- $merge\_sort$ 中的 $l$ 和 $r$ 表示当前区间的左右边界,左右全闭;
- 主要难点在于归并的过程,格外注意 $i$ 与 $j$ 的初始值。
#include <iostream>
using namespace std;
const int N = 1000010;
int n;
int a[N], temp[N];
void merge_sort(int a[], int l, int r) {
if (l >= r) return;
int mid = l + ((r - l) >> 1);
merge_sort(a, l, mid);
merge_sort(a, mid + 1, r);
int k = 0, i = l, j = mid + 1;
while (i <= mid && j <= r) {
if (a[i] <= a[j]) temp[k++] = a[i++];
else temp[k++] = a[j++];
}
while (i <= mid) temp[k++] = a[i++];
while (j <= r) temp[k++] = a[j++];
for (i = l, j = 0; i <= r; i++, j++) a[i] = temp[j];
}
int main() {
scanf("%d", &n);
for (int i = 0; i < n; ++i) scanf("%d", &a[i]);
merge_sort(a, 0, n - 1);
for (int i = 0; i < n; ++i) printf("%d ", a[i]);
return 0;
}
堆排序
堆排序分为入堆 $heapInsert$ 以及堆化处理 $heapify$ 两部分,下面是自己总结的模板。
对于代码的解释,在我的 $Github$ 中的 $repo$ 里有。
class Solution {
public int[] sortArray(int[] nums) {
if (nums == null || nums.length == 0) return new int[0];
for (int i = 0; i < nums.length; i++) {
heapInsert(nums, i);
}
int heapSize = nums.length;
swap(nums, 0, --heapSize);
while (heapSize > 0) {
heapify(nums, 0, heapSize);
swap(nums, 0, --heapSize);
}
return nums;
}
private void heapify(int[] nums, int i, int heapSize) {
int left = 2 * i + 1;
while (left < heapSize) {
int largest = left + 1 < heapSize && nums[left + 1] > nums[left] ? left + 1 : left;
largest = nums[i] > nums[largest] ? i : largest;
if (i == largest) break;
swap(nums, i, largest);
i = largest;
left = 2 * i + 1;
}
}
private void heapInsert(int[] nums, int i) {
while (nums[i] > nums[(i - 1) / 2]) {
swap(nums, i, (i - 1) / 2);
i = (i - 1) / 2;
}
}
private void swap(int[] nums, int i, int j) {
int t = nums[i];
nums[i] = nums[j];
nums[j] = t;
}
}
二分
整数二分
该模板有两种情况,但是都基于一个前提:如果给定的序列有序(有单调性),那么一定可以二分;而可以进行二分的题目不一定有序。
- 首先需要根据 $mid$ 的值,来判断具体属于哪种情况;
- $mid = (l + r +1) / 2$,然后去 $check(mid)$:
$\qquad$2.1 如果返回 $true$,那么说明 $mid$ 在 $[mid, r]$ 范围内。因此,需要将 $l$ 更新到 $mid$ 的位置,即 $l = mid$;
$\qquad$2.2 如果返回 $false$,那么说明 $mid$ 在 $[l, mid - 1]$ 范围内。因此,需要将 $r$ 更新到 $mid - 1$ 的位置,即 $r = mid - 1$。 - 计算 $mid = (l + r) / 2$,然后去 $check(mid)$:
$\qquad$3.1 如果返回 $true$,那么说明 $mid$ 在 $[l, mid]$ 范围内。因此,将 $r$ 更新成 $mid$;
$\qquad$3.2 如果返回 $false$,那么说明 $mid$ 在 $[mid + 1, r]$ 范围内。因此,将 $l$ 更新成 $mid +1$。 - 下面通过一个例子,来解释以上的用法。
#include <iostream>
using namespace std;
const int N = 100010;
int n, q;
int a[N];
int main() {
scanf("%d %d", &n, &q);
for (int i = 0; i < n; ++i) scanf("%d", &a[i]);
while (q--) {
int k;
scanf("%d", &k);
// 查找 k 的起始位置
int l = 0, r = n - 1;
while (l < r) {
// + 的优先级大于 >> 的优先级
int mid = l + r >> 1;
if (a[mid] >= k) r = mid;
else l = mid + 1;
}
// 经过 while 操作后,如果 a[] 中没有 k,那么最终 l 或者 r 会来到从左往右第一个大于等于 k 的位置
// 这里使用 a[l] 或 a[r] 都是一样的
if (a[l] != k) cout << "-1 -1" << endl;
else {
cout << l << " ";
// 否则,开始查找 k 的终止位置
int l = 0, r = n - 1;
while (l < r) {
int mid = l + r + 1 >> 1;
if (a[mid] <= k) l = mid;
else r = mid - 1;
}
cout << l << endl;
}
}
return 0;
}
浮点数二分
对于浮点数二分,其实和整数二分是一样的,只是在 $if…else$ 中没有加 $1$ 或减 $1$ 的操作。
但需要注意的是,当我们求得右边界 $r$ 与左边界 $l$ 的差小于一个很小的数的时候,其实就可以停止了。
下面以求解 $\sqrt{x}$ 为例:
#include <iostream>
using namespace std;
int main() {
double x;
scanf("%lf", &x);
double l = 0, r = x;
while (r - l > 1e-8) {
double mid = (l + r) / 2;
if (mid * mid >= x) r = mid;
else l = mid;
}
printf("%lf", l);
return 0;
}
高精度问题
有以下几种情况:
- $A + B$,$A$ 或 $B$ 的位数在 $10^6$ 左右;
- $A - B$,$A$ 或 $B$ 的位数在 $10^6$ 左右;
- $A * a$,$A$ 的位数在 $10^6$ 左右,$a$ 通常是一个小于等于 $10000$ 的数;
- $A \div a$,同上;
- 其它情况使用的比较少。
大整数如何存储?存在数组中,索引为 $0$ 的位置存储大整数的个位,索引为 $1$ 的位置存储大整数的十位,以此类推。这样做的好处是:如果存在进位的话,可以直接在数组中的最后一个位置添加一位即可。
高精度加法
在 $if(carry)$ 中,为什么要向 $C$ 中 $push\_back(1)$ 而不是 $push\_back(carry)$ 呢?这是因为两个位数等于 $1$ 的两个数,它们相加之后得到的进位最多就是 $1$。其实,这里直接 $push\_back(carry)$ 也可以。
#include <iostream>
#include <vector>
using namespace std;
const int N = 1e5 + 10;
// C = A + B
vector<int> add(vector<int>& A, vector<int>& B) {
vector<int> C;
int carry = 0;
for (int i = 0; i < A.size() || i < B.size(); ++i) {
if (i < A.size()) carry += A[i];
if (i < B.size()) carry += B[i];
// 此时,carry 就等于 A 和 B 中的分别的单独一位再加上 carry
C.push_back(carry % 10);
// 进位
carry /= 10;
}
if (carry) C.push_back(1);
return C;
}
int main() {
// a = "123456789"
string a, b;
vector<int> A, B;
cin >> a >> b;
// A = [9, 8, 7, 6, 5, 4, 3, 2, 1]
for (int i = a.size() - 1; i >= 0; --i) A.push_back(a[i] - '0'); // 这里将字符转换成数字
for (int i = b.size() - 1; i >= 0; --i) B.push_back(b[i] - '0');
auto C = add(A, B);
for (int i = C.size() - 1; i >= 0; --i) printf("%d", C[i]);
return 0;
}
高精度加法
思路是:
- 用 $A$ 和 $B$ 中较大的数减去较小的数,否则将结果添加一个负号;
- 也就是在计算 $A - B$ 时,如果 $A>=B$,那么就计算 $A - B$;
- 如果如果 $A < B$,那么就计算 $-(B - A)$。
#include <iostream>
#include <vector>
using namespace std;
bool cmp(vector<int>& A, vector<int>& B) {
if (A.size() != B.size()) return A.size() > B.size();
for (int i = A.size() - 1; i >= 0; --i) {
if (A[i] != B[i]) return A[i] > B[i];
}
return true;
}
// 已经确保 A 大于等于 B
vector<int> sub(vector<int>& A, vector<int>& B) {
vector<int> C;
int t = 0;
for (int i = 0; i < A.size(); ++i) {
t = A[i] - t;
if (i < B.size()) t -= B[i];
C.push_back((t + 10) % 10);
if (t < 0) t = 1;
else t = 0;
}
// 去掉前导 0,如果 A-B=0,那么单独的这一个 0 就不需要去掉
while (C.size() > 1 && C.back() == 0) C.pop_back();
return C;
}
int main() {
string a, b;
cin >> a >> b;
vector<int> A, B;
for (int i = a.size() - 1; i >= 0; --i) A.push_back(a[i] - '0');
for (int i = b.size() - 1; i >= 0; --i) B.push_back(b[i] - '0');
if (cmp(A, B)) {
auto C = sub(A, B);
for (int i = C.size() - 1; i >= 0; --i) printf("%d", C[i]);
} else {
auto C = sub(B, A);
printf("-");
for (int i = C.size() - 1; i >= 0; --i) printf("%d", C[i]);
}
return 0;
}
前缀和
一维前缀和
需要注意的是,原数组在读入的时候,是从下标 $1$ 开始的,那么前缀和数组也是从 $1$ 开始的。
#include <iostream>
using namespace std;
const int N = 100010;
int n, m;
int a[N], s[N];
int main() {
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; ++i) scanf("%d", &a[i]);
for (int i = 1; i <= n; ++i) s[i] = s[i - 1] + a[i];
while (m--) {
int l, r;
scanf("%d %d", &l, &r);
printf("%d\n", s[r] - s[l - 1]);
}
return 0;
}
二维前缀和
计算方式如下:
- https://i.loli.net/2020/09/10/B2gJikVIhG5aTQ8.png
- 当前 $(i,j)$ 点的前缀和等于 $S_{i, j}=S_{i-1, j}+S_{i, j-1}-S_{i-1, j-1}+a_{i, j}$;
- 求指定范围内的前缀和,例如,求解 $(x_{1}, y_{1})$ 到 $(x_{2}, y_{2})$ 范围内的前缀和,即 $S_{x_{2}, y_{2}}-S_{x_{2}, y_{1} - 1}-S_{x_{1} - 1, y_{2}}+S_{x_{1} - 1, y_{1} - 1}$。
#include <iostream>
using namespace std;
const int N = 1010;
int n, m, q;
int a[N][N], s[N][N];
int main() {
scanf("%d %d %d", &n, &m, &q);
for(int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j)
scanf("%d", &a[i][j]);
}
for (int i = 1; i <= n; ++i) {
for (int j = 1; j <= m; ++j)
s[i][j] = s[i - 1][j] + s[i][j - 1] - s[i - 1][j - 1] + a[i][j];
}
while (q--) {
int x1, y1, x2, y2;
scanf("%d %d %d %d", &x1, &y1, &x2, &y2);
printf("%d\n", s[x2][y2] - s[x2][y1 - 1] - s[x1 - 1][y2] + s[x1 - 1][y1 - 1]);
}
return 0;
}
差分
一维差分
思路:
- 给定数组 $a_{1}、a_{2}、a_{3}、…、a_{n}$,目的是构造 $b_{1}、b_{2}、b_{3}、…、b_{n}$;
- 满足 $a_{i} = b_{1}+b_{2}+b_{3}+…+b_{i}$;
- 构造方式:$b_{1}=a_{1}$、$b_{2}=a_{2}-a_{1}$、…、$b_{n}=a_{n}-a_{n-1}$;
- 作用:在数组 $A$ 中,如果想让区间 $[l, r]$ 内每个元素都加上一个值 $c$,那么就可以使用差分来做;
- 即返回 $b_{l} + c$ 与 $b_{r+1} - c$ 即可,因为 $b_{l} + c$ 表示将 $l$ 及其后面所有的数都加上 $c$,$b_{r+1} - c$ 表示将 $r+1$ 及其后面的所有数都减去 $c$;
- 所以,从 $r+1$ 开始,后面的计算都抵消了。
#include <iostream>
using namespace std;
const int N = 100010;
int n, m;
int a[N], b[N];
void insert(int l, int r, int c) {
b[l] += c;
b[r + 1] -= c;
}
int main() {
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; ++i) scanf("%d", &a[i]);
for (int i = 1; i <= n; ++i) insert(i, i, a[i]);
while (m--) {
int l, r, c;
scanf("%d %d %d", &l, &r, &c);
insert(l, r, c);
}
// 求前缀和
for (int i = 1; i <= n; ++i) b[i] += b[i - 1];
for (int i = 1; i <= n; ++i) printf("%d ", b[i]);
return 0;
}
/*
// 在最后的时候,也可以写成下面的代码
for (int i = 1; i <= n; ++i) a[i] = a[i - 1] + b[i];;
for (int i = 1; i <= n; ++i) printf("%d ", a[i]);
*/
二维差分
TODO
双指针
如果想要实现一个功能:在一句话中,将英文单词给分隔出来,那么就可以使用双指针的方式。
#include <iostream>
#include <string.h>
using namespace std;
int main() {
char str[1000];
gets(str);
int n = strlen(str);
for (int i = 0; i < n; ++i) {
int j = i;
while (j < n && str[j] != ' ') j++;
for (int k = i; k < j; ++k) cout << str[k];
cout << endl;
i = j;
}
return 0;
}
AcWing 799 例题,这道题最好手动模拟一下,便于理解。
#include <iostream>
using namespace std;
const int N = 100010;
int a[N],s[N];
int main() {
int n;
cin >> n;
for (int i = 0; i < n; ++i) cin >> a[i];
int ans = 0;
for (int i = 0, j = 0; i < n; ++i) {
s[a[i]]++;
while (s[a[i]] > 1) {
s[a[j]]--;
j++;
}
ans = max(ans, i - j + 1);
}
cout << ans;
return 0;
}
位运算
如下:
- 求 $n$ 的二进制表示中,第 $k$ 位的数字是多少:$n >> k \& 1$。也就是先将 $n$ 往右移动 $k$ 位,然后再与 $1$ 进行位运算;
- 返回 $n$ 的最后一位 $1$:$lowbit(n) = n \& -n$。例如,给定 $n = 100010010$,那么经过 $lowbit(n)$ 运算后,得到的结果是 $10$。
数组
在声明数组的时候,如果将数组声明为全局变量,那么数组内的值会被初始化为 $0$;而如果将数组声明为局部变量,那么数组内的值会被初始化为随机值。
链表和邻接表
在描述链表的时候,可以使用结构体声明链表,还可以使用数组来模拟链表,具体的实现方式为:
- 在使用数组模拟单链表时,$e[i]$ 表示当前节点的值,而 $ne[i]$ 表示当前节点所指向的下一个节点的索引值;
- 对于最后一个 $null$ 节点,用 $-1$ 表示,$head$ 表示头节点的索引。
单向链表
下面包括初始化单链表、在头节点后插入、在指定位置节点后插入、删除节点等几个方法。
#include <iostream>
using namespace std;
const int N = 100010;
int head, e[N], ne[N], idx;
void init() {
head = -1;
idx = 0;
}
void add_to_head(int x) {
e[idx] = x;
ne[idx] = head;
head = idx++;
}
// 将新节点 x 插入到索引为 k 节点的后面
void add(int k, int x) {
e[idx] = x;
ne[idx] = ne[k];
ne[k] = idx++;
}
// 删除索引为 k 节点的后一个节点
void remove(int k) {
ne[k] = ne[ne[k]];
}
int main() {
int m;
cin >> m;
init();
while (m--) {
int k, x;
char op;
cin >> op;
if (op == 'H') {
cin >> x;
add_to_head(x);
} else if (op == 'D') {
cin >> k;
// 如果 k 为 0 的话,表示删除头节点的下一个节点
if (!k) head = ne[head];
else remove(k - 1);
} else {
cin >> k >> x;
add(k - 1, x);
}
}
for (int i = head; i != -1; i = ne[i]) cout << e[i] << ' ';
cout << endl;
return 0;
}
双向链表
这里在实现双向链表的时候,使用 $0$ 表示双向链表中最开始的节点的索引,使用 $1$ 表示双向链表中结束的节点的索引。
#include <iostream>
using namespace std;
const int N = 100010;
int e[N], l[N], r[N], idx;
void init() {
r[0] = 1;
l[1] = 0;
idx = 2;
}
// 在 k 位置的右侧插入一个节点 x,格外注意步骤
// 如果要在 k 位置的左侧插入,则直接 add(l[k], x) 即可
void insert(int k, int x) {
e[idx] = x;
r[idx] = r[k];
l[idx] = k;
// 注意:下面这两步不能交换
l[r[k]] = idx;
r[k] = idx++;
}
// 删除第 k 个节点
void remove(int k) {
r[l[k]] = r[k];
l[r[k]] = l[k];
}
int main() {
int m;
cin >> m;
init();
while (m--) {
int k, x;
string op;
cin >> op;
if (op == "L") {
cin >> x;
insert(0, x);
} else if (op == "R") {
cin >> x;
insert(l[1], x);
} else if (op == "D") {
cin >> k;
remove(k + 1);
} else if (op == "IL") {
cin >> k >> x;
insert(l[k + 1], x);
} else {
cin >> k >> x;
insert(k + 1, x);
}
}
for (int i = r[0]; i != 1; i = r[i]) cout << e[i] << ' ';
cout << endl;
return 0;
}
栈
下面介绍了使用数组来模拟栈的过程:
- 定义栈以及栈顶:$stk[N]、 tt$;
- 入栈:$stk[++tt] = x$;
- 出栈:$tt-\-$;
- 判断栈是否为空:如果 $tt > 0$,则栈不为空,否则为空;
- 获取栈顶元素:$stk[tt]$。
队列
使用数组来模拟队列,在队尾入队,在队头出队:
- 定义队列、队头、队尾:$q[N]、hh、tt = -1$;
- 入队:$q[\+\+tt] = x$;
- 出队:$hh++$;
- 判断队列是否为空:如果 $hh <= tt$,则不为空,否则为空;
- 获取队头元素:$q[hh]$;
- 获取队尾元素:$q[tt]$。
单调栈
以下是单调栈的思路与应用。
- 这里使用数组来模拟单调栈;
- 当栈不为空并且栈顶元素大于当前遍历的元素时,则栈顶元素出栈;
- 如果栈不为空,那么此时就可以收集答案了;
- 而如果栈为空,那么输出 $-1$;
- 最后记得将当前元素加入到栈中。
#include <iostream>
using namespace std;
const int N = 100010;
// stk 表示单调栈,tt 表示指向单调栈的栈顶元素
int stk[N], tt;
int n;
int main() {
cin >> n;
for (int i = 0; i < n; i++) {
int x;
cin >> x;
// 单调栈中有元素,并且栈顶元素大于等于 x,则栈顶元素出栈
while (tt && stk[tt] >= x) tt--;
// 如果单调栈中有元素的话,那么说明栈顶元素就是左边第一个小于 x 的元素
if (tt) cout << stk[tt] << ' ';
else cout << -1 << ' ';
stk[++tt] = x;
}
return 0;
}
单调队列
典型的应用就是:求解滑动窗口里面的最大值或者最小值。当然,关于该题的详细解释,可以看之前我在 $LeetCode$ 上分享的一个题解。
#include <iostream>
using namespace std;
const int N = 1000010;
int n, k;
int a[N], q[N], tt;
int main() {
scanf("%d%d", &n, &k);
for (int i = 0; i < n; i++) scanf("%d", &a[i]);
// 使用单调队列
int hh = 0, tt = -1;
for (int i = 0; i < n; i++) {
// 判断队头元素是否已经滑出窗口
if (hh <= tt && i - k + 1 > q[hh]) hh++;
while (hh <= tt && a[q[tt]] >= a[i]) tt--;
q[++tt] = i;
// 判断滑动窗口是否形成
if (i >= k - 1) printf("%d ", a[q[hh]]);
}
puts("");
hh = 0, tt = -1;
for (int i = 0; i < n; i++) {
// 判断队头元素是否已经滑出窗口
if (hh <= tt && i - k + 1 > q[hh]) hh++;
while (hh <= tt && a[q[tt]] <= a[i]) tt--;
q[++tt] = i;
// 判断滑动窗口是否形成
if (i >= k - 1) printf("%d ", a[q[hh]]);
}
return 0;
}
KMP
整体分为两部分,一部分是两个串进行匹配的过程,另一部分是求解 $next$ 数组的过程。
#include <iostream>
using namespace std;
const int N = 100010, M = 1000010;
int n, m;
// next 数组要的长度要与模板串的长度相同
int ne[N];
// 字符串为 s,模板串为 p
char s[M], p[N];
int main() {
cin >> n >> p + 1 >> m >> s + 1;
// 根据模板串 p 来初始化 next 数组
for (int i = 2, j = 0; i <= n; i++) {
while (j && p[i] != p[j + 1]) j = ne[j];
if (p[i] == p[j + 1]) j++;
ne[i] = j;
}
// 匹配的过程,i 遍历字符串所有的字符
for (int i = 1, j = 0; i <= m; i++) {
// 如果字符串 s 在 i 位置上的字符与模板串 p 在 j+1 上的字符不相等,
// 那么 j 就需要来到 ne[j] 的位置
while (j && s[i] != p[j + 1]) j = ne[j];
// 如果 i 所对应字符与 j+1 所对应的字符相等,那么 j 就来到下一个位置
if (s[i] == p[j + 1]) j++;
// j 来到模式串的最后,也就是 n,则说明匹配成功
if (j == n) {
// 输出模式串在字符串中的起始位置
printf("%d ", i - n);
j = ne[j];
}
}
return 0;
}
Trie
前缀树(字典树)可以高效的存储或查找字符串。
- 这里涉及到两个操作,一个操作是将一个字符串插入到 $Trie$ 中,另一个操作是查询一个字符串在 $Trie$ 中出现的次数;
- 注意到,以上两个操作在代码实现层面是非常相似的;
- 注意,此处的 $idx$ 的作用与单链表中的 $idx$ 的作用是相同的。
#include <iostream>
using namespace std;
const int N = 100010;
// son 表示每个节点最多可以向外扩展 26 个小写英文字母,也就是每个节点的所有孩子
// cnt 表示以当前节点结尾的单词的个数,
// 对于下标为 0 的节点来说,它即使根节点,又是空节点,也就是不存储数据
int son[N][26], cnt[N], idx;
char str[N];
// 插入操作
void insert(char str[]) {
int p = 0;
// str[] 是以 \0 结尾的,所以可以直接用 str[i] 来作为判断条件
for (int i = 0; str[i]; i++) {
// 将 'a'~'z' 映射为 0~25
int u = str[i] - 'a';
// 在插入的时候,如果 u 不存在的话,那么就直接创建出来
if (!son[p][u]) son[p][u] = ++idx;
// 如果已存在,则走到下一个节点
// son[p][u] 表示节点 p 的第 u 个孩子
p = son[p][u];
}
// for 循环结束的时候,p 就来到了 str 这个单词的末尾,然后让 cnt 加 1,
// 表示以 p 点结尾的单词的数量多了一个
cnt[p]++;
}
// 查询操作,返回字符串 str 一共出现了多少次
int query(char str[]) {
int p = 0;
for (int i = 0; str[i]; i++) {
int u = str[i] - 'a';
// 如果树中不存在这个节点,那么直接返回 0
if (!son[p][u]) return 0;
// 否则,接着往下走
p = son[p][u];
}
return cnt[p];
}
int main() {
int n;
scanf("%d", &n);
while (n--) {
char op[2];
scanf("%s%s", op, str);
if (op[0] == 'I') insert(str);
else printf("%d\n", query(str));
}
return 0;
}
并查集
并查集的两个作用:一个是将给定的两个集合进行合并;另一个是询问两个元素是否在同一个集合中。
实现的原理:
- 每个集合用树来表示,树根的编号就是整个集合的编号;
- 对于集合中的每个节点 $x$ 而言,它存储的是当前节点的父节点,即 $p[x]$ 表示节点 $x$ 的父节点。
如何判断树根?如果 p[x] == x
,则说明当前是树根。
如何求 $x$ 所在集合的编号?其实,查找 $x$ 所在集合编号的过程,就是找 $x$ 所在树根的过程,即当 while(p[x] != x)
时,执行 x = p[x]
,让 $x$ 一直往上走,等到 $while$ 执行完以后,$x$ 就来到了树根的位置。根据这种方式,我们就可以判断两个节点是否在同一集合中,即分别求两个节点所在的集合编号,然后判断编号是否相等即可。
如何合并两个集合?直接将其中一个集合整体地接在另一个集合的树根下面即可。例如 $p[x]$ 是 $x$ 的集合编号,$p[y]$ 是 $y$ 的集合编号,合并操作可以是 p[x] = y
。也就是直接把 $x$ 所在集合直接插到 $y$ 所在集合中。
可以优化的地方:在求 $x$ 所在集合编号的时候,每次都是从 $x$ 开始一直遍历到根节点。可以看到,时间复杂度与树的高度有关。因此,这里所优化的点在于:路径压缩,即只要 $x$ 找到了根节点,那么当前节点 $x$ 以及沿途的所有节点都直接修改为与根节点相连。
#include <iostream>
using namespace std;
const int N = 100010;
// p[x] 表示当前节点 x 的父节点,
// 初始化时,每个节点单独成为一个集合,这个节点的树根就是自己
int p[N];
int n, m;
// 返回 x 所在集合的编号,同时进行路径压缩
int find(int x) {
if (p[x] != x) p[x] = find(p[x]);
return p[x];
}
int main() {
scanf("%d%d", &n, &m);
// 初始化
for (int i = 1; i <= n; i++) p[i] = i;
while (m--) {
char op[2];
int a, b;
// scanf 在读入字符串的时候会自动忽略空格和回车
// 而在读入字符的时候,会读入空格和回车
scanf("%s%d%d", op, &a, &b);
// 对于合并操作,让 a 所在集合的的父节点指向 b 所在集合的编号
if (op[0] == 'M') p[find(a)] = find(b);
else {
if (find(a) == find(b)) printf("%s\n", "Yes");
else printf("%s\n", "No");
}
}
return 0;
}
以上是最基本的并查集的使用,此外,还可以额外记录一些信息,例如统计某个集合中节点的数量。但需要注意的是:仅保证每个集合中根节点的 $size$ 是有意义的。因此,在合并两个集合时,直接 size[b] += size[a]
即可。
之所以使用 $sizee$,是因为如果使用 $size$ 的话,会编译出错,产生歧义。
#include <iostream>
using namespace std;
const int N = 100010;
// size 表示当前节点所在集合的数量
int p[N], sizee[N];
int n, m;
// 返回 x 所在集合的编号,同时进行路径压缩
int find(int x) {
if (p[x] != x) p[x] = find(p[x]);
return p[x];
}
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
p[i] = i;
// 初始化时,每个集合中的节点数量就是它自己,即为 1
sizee[i] = 1;
}
while (m--) {
char op[5];
int a, b;
scanf("%s", op);
// 合并两个集合,在合并的过程中,更新 size
if (op[0] == 'C') {
scanf("%d%d", &a, &b);
// 如果 a 和 b 已经在同一集合中了,则什么都不做
if (find(a) == find(b)) continue;
sizee[find(b)] += sizee[find(a)];
p[find(a)] = find(b);
// Q1 操作,判断 a 和 b 是否在同一集合中
} else if (op[1] == '1') {
scanf("%d%d", &a, &b);
if (find(a) == find(b)) printf("%s\n", "Yes");
else printf("%s\n", "No");
// Q2 操作,输出 a 所在集合中的节点数量
} else {
scanf("%d", &a);
printf("%d\n", sizee[find(a)]);
}
}
return 0;
}
哈希表
手写一个哈希表,通常定义哈希表的存储结构有两种方式:拉链法
和 开放寻址法
。
格外说明一下:对于删除操作而言,使用到的情况比较少。如果真的要支持删除操作的话,一般情况下不会将当前的值删除,而是使用布尔数组做一个标记。
拉链法
对于拉链法来说,如果当前的槽出现了冲突,那么就在当前槽的位置引出一个链表,用于存储发生冲突的 $value$。此外,在设计哈希函数进行取模时,即 $H(key)=key \\% mod$,一般将 $mod$ 设置为质数
,这样可以减少冲突的发生。
需要注意的地方如下:
- 在初始化数组槽的时候,将它们都初始化为 $-1$,其中 $memset$ 在 $cstring$ 头文件中;
- 在处理冲突时,使用到了链表的头插法;
- 在求某个关键字的 $value$ 的时候,由于我们的目的是为了将负数关键字也映射到正数范围内,因此,采用了模 $N$ 与加 $N$ 的操作。
// 万能头文件
// #include <bits/stdc++.h>
#include <iostream>
#include <cstring>
using namespace std;
// 大于 100000 的第一个质数是 100003,因此用该数对关键字取模
const int N = 100003;
// 定义槽数组以及链表结构
int h[N], e[N], ne[N], idx;
void insert(int x) {
// 在 C++ 中,对于 x % N 来说,
// 如果 x 是负数,则取模之后也是负数,
// 如果 x 是正数,则取模之后也是正数,
// 再加上一个 N 之后,那么它的结果就一定是一个正数,
// 加 N 模 N 的目的就是让 x 最终都变成一个正数
int k = (x % N + N) % N;
// 单链表的头插法
e[idx] = x, ne[idx] = h[k], h[k] = idx++;
}
bool find(int x) {
int k = (x % N + N) % N;
for (int i = h[k]; i != -1; i = ne[i])
if (e[i] == x)
return true;
return false;
}
int main() {
int n;
scanf("%d", &n);
// 将哈希槽初始化为 -1
memset(h, -1, sizeof h);
while (n--) {
char op[2];
int x;
scanf("%s%d", op, &x);
if (op[0] == 'I') insert(x);
else {
if (find(x)) printf("%s\n", "Yes");
else printf("%s\n", "No");
}
}
return 0;
}
开放寻址法
对于开放寻址法,有如下需要注意的地方:
- 对于新开的数组,一般情况下要比输入的数据范围大 $2~3$ 倍,这样的话,冲突的概率会比较低;
- 注意 $find$ 函数的作用;
- 注意,在声明无穷大时,一般使用 $0x3f3f3f3f$ 这个比较神奇的数字;
- 在 $memset$ 中使用到了 $0x3f$,这是因为 $memset$ 是按照字节来进行填充的,对于 $int$ 类型的数组,共含有 $4$ 个字节,每一个字节都是 $0x3f$,因此对于每个数的话,就是 $0x3f3f3f3f$。
#include <iostream>
#include <cstring>
using namespace std;
// 数组的大小开成 2 倍或 3 倍的给定数据范围
// 大于 200000 的第一个质数是 200003,因此用该数对关键字取模
// 定义 null 代表不在 -10^9~10^9 之间的数,也就是表示无穷大
const int N = 200003, null = 0x3f3f3f3f;
int h[N];
// 该 find 函数的作用是:如果 x 存在的话,那么就返回 x 所在的位置,
// 如果 x 不存在的话,那么就返回 x 应该存放的位置
int find(int x) {
int k = (x % N + N) % N;
// 如果 h[k] 这个位置已经有元素了,并且此时该元素还不等于 x
// 那么就一直往后走
while (h[k] != null && h[k] != x) {
k++;
if (k == N) k = 0;
}
return k;
}
int main() {
int n;
scanf("%d", &n);
memset(h, 0x3f, sizeof h);
while (n--) {
char op[2];
int x;
scanf("%s%d", op, &x);
int k = find(x);
if (op[0] == 'I') h[k] = x;
else {
if (h[k] != null) printf("%s\n", "Yes");
else printf("%s\n", "No");
}
}
return 0;
}
字符串的哈希
这里的字符串哈希,其实是字符串前缀哈希
,也就是说,对于当前的 $h[i]$,它存储的是字符串从 $1$ 到 $i$ 前缀的哈希值。
例如给定字符串 $str=“ABCDABCAAA”$,那么 $h[1]$ 记录的是 $A$ 的哈希值,$h[2]$ 记录的是 $AB$ 的哈希值,以此类推。
如何求字符串的哈希值呢?这里把字符串看成 $p$ 进制的数。例如给定字符串 $str=“ABCD”$,那么我们将 $A$ 对应的值映射为 $1$,将 $B$ 对应的值映射为 $2$,将 $C$ 对应的值映射为 $3$,依次类推。所以 $“ABCD”$ 就变成了 $(ABCD)_{p}$,那么在计算的的时候,将其转换成 $10$ 进制,即 $(1\*p^{3}+2\*p^{2}+3\*p^{1}+4\*p^{0}) \% Q$。
由于将其转换成 $10$ 进制的数可能会很大,因此将其结果取模,即映射到 $0$ 到 $Q-1$ 之间。
需要注意的是:不能将字符映射成 $0$。例如,$A$ 映射为 $0$,那么 $AA$ 也可以映射为 $0$......这样的话,就会将不同的字符串映射成同一个数字了,这样就出现了错误。
一般情况下根据经验来说,假定不会发生冲突的话,可以将 $p$ 置为 $131$ 或 $13331$,将 $Q$ 置为 $2^{64}$。
该哈希方式的核心在于:如何求从 $L$ 到 $R$ 区间内字符串的哈希值?求解方式为:$h[R]-h[L]\*p^{R-L+1}$。
这种字符串哈希方法适用于求解:在一个字符串中,判断给定的两个区间内的字符串是否相同。
#include <iostream>
using namespace std;
typedef unsigned long long ULL;
const int N = 100010, P = 131;
int n, m;
char str[N];
// p 是用来处理多少次方的
ULL h[N], p[N];
ULL get(int l, int r) {
// 公式:h[r]-h[l-1]*p^(r-l+1)
return h[r] - h[l - 1] * p[r - l + 1];
}
int main() {
// 将字符串的第一个字符放在索引为 1 的位置,以此类推
scanf("%d%d%s", &n, &m, str + 1);
p[0] = 1;
for (int i = 1; i <= n; i++) {
p[i] = p[i - 1] * P;
h[i] = h[i - 1] * P + str[i];
}
while (m--) {
int l1, r1, l2, r2;
scanf("%d%d%d%d", &l1, &r1, &l2, &r2);
if (get(l1, r1) == get(l2, r2)) printf("%s\n", "Yes");
else printf("%s\n", "No");
}
return 0;
}
STL
TODO
DFS 和 BFS
这里总结一下它们之间的区别:
- 数据结构:$DFS$ 使用 $stack$,而 $BFS$ 使用 $queue$;
- 空间:$DFS$ 使用的空间为 $O(h)$,也就是与树的高度有关,而 $BFS$ 使用的空间与当前层的节点个数有关,呈指数级别,即 $O(2^h)$;
- 此外,由于 $BFS$ 每次搜索的时候,搜索的都是最近的节点,因此它有一个“最短路”的概念,而 $DFS$ 不具备;
DFS
在使用 $DFS$ 时,有几点需要注意的地方:回溯、剪枝以及搜素的顺序。
#include <iostream>
using namespace std;
const int N = 10;
int n;
int path[N];
// 如果某个位置为 true 的话,那么说明当前这个数字已经被用过了
bool st[N];
void dfs(int u) {
// 当 u 等于 n 的时候,就相当于来到了叶节点,也就是最后一层
if (u == n) {
for (int i = 0; i < n; i++) printf("%d ", path[i]);
printf("\n");
return;
}
// 当 u 小于 n 的时候,我们就枚举一下当前位置可以填哪些数字
for (int i = 1; i <= n; i++) {
// 如果当前这个数字没有被用过,那么当前这个数字接下来才可以使用
if (!st[i]) {
path[u] = i;
st[i] = true;
// 来到下一层
dfs(u + 1);
// 回溯
st[i] = false;
}
}
}
int main() {
cin >> n;
// 从第 0 个空位置开始填数字
dfs(0);
return 0;
}
使用全排列的方式解决 $n$ 皇后问题:
#include <iostream>
using namespace std;
// 由于需要考虑到正反对角线,所以这里开 2 倍的 N
const int N = 20;
int n;
char g[N][N];
// dg 表示正对角线(diagonal),udg 表示反对角线
bool col[N], dg[N], udg[N];
void dfs(int u) {
if (u == n) {
for (int i = 0; i < n; i++) printf("%s\n", g[i]);
printf("\n");
return;
}
for (int i = 0; i < n; i++) {
// 如果这一列之前没有放过,并且正对角线和反对角线都没有被放过
if (!col[i] && !dg[u + i] && !udg[n - u + i]) {
g[u][i] = 'Q';
col[i] = dg[u + i] = udg[n - u + i] = true;
dfs(u + 1);
col[i] = dg[u + i] = udg[n - u + i] = false;
g[u][i] = '.';
}
}
}
int main() {
cin >> n;
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
g[i][j] = '.';
dfs(0);
return 0;
}
另外一种搜索方式是:枚举每一个格子,直到枚举完第 $n^2$ 个格子之后,即可求得最终的结果。
#include <iostream>
using namespace std;
const int N = 20;
int n;
char g[N][N];
bool row[N], col[N], dg[N], udg[N];
void dfs(int x, int y, int s) {
// 如果 y 越界了,那么就来到下一行的起始位置
if (y == n) y = 0, x++;
// 已经枚举完最后一行了,此时需要检查皇后的数量
if (x == n) {
// 找到了一组解
if (s == n) {
for (int i = 0; i < n; i++) printf("%s\n", g[i]);
printf("\n");
}
return;
}
// 枚举当前格子的两种选择
// 不放置皇后
dfs(x, y + 1, s);
// 放置皇后
if (!row[x] && !col[y] && !dg[x + y] && !udg[x - y + n]) {
g[x][y] = 'Q';
row[x] = col[y] = dg[x + y] = udg[x - y + n] = true;
dfs(x, y + 1, s + 1);
row[x] = col[y] = dg[x + y] = udg[x - y + n] = false;
g[x][y] = '.';
}
}
int main() {
cin >> n;
for (int i = 0; i < n; i++)
for (int j = 0; j < n; j++)
g[i][j] = '.';
// 从 (0, 0) 点开始搜索,并记录当前有几个皇后
dfs(0, 0, 0);
return 0;
}
BFS
BFS 适用于只有所有的边的权重为 $1$ 的树中进行搜索,例如 走迷宫 这道题。当然,对于 $BFS$ 来说,这里有一个模板框架,即:
- 将初始状态入队;
- 在 $while$ 条件中指明队列不为空;
- 出队一个节点,然后扩展这个节点。
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef pair<int, int> PII;
const int N = 110;
int n, m;
// 存储题干给定的图
int g[N][N];
// 存储每个点到起始点的距离
int d[N][N];
// 模拟一个队列
PII q[N * N];
// 记录走过的点
PII Prev[N][N];
int bfs() {
// 队头和队尾
int hh = 0, tt = 0;
q[0] = {0, 0};
memset(d, -1, sizeof d);
d[0][0] = 0;
// 定义上下左右四个方向
int dx[4] = {-1, 0, 1, 0}, dy[4] = {0, 1, 0, -1};
while (hh <= tt) {
// 取出队头元素
auto t = q[hh++];
// 遍历四个方向
for (int i = 0; i < 4; i++) {
// 计算新的方向
int x = t.first + dx[i], y = t.second + dy[i];
// d[x][y] == -1 表示下一个点还没有走过的话
if (x >= 0 && x < n && y >= 0 && y < m && g[x][y] == 0 && d[x][y] == -1) {
d[x][y] = d[t.first][t.second] + 1;
Prev[x][y] = t;
// 加入到队列中
q[++tt] = {x, y};
}
}
}
// 如果想要记录路径,则可以从终点一直往起点找
int x = n - 1, y = m - 1;
while (x || y) {
cout << '(' << x << ", " << y << ')'<< endl;
auto t = Prev[x][y];
x = t.first, y = t.second;
}
return d[n - 1][m - 1];
}
int main() {
cin >> n >> m;
for (int i = 0; i < n; i++)
for (int j = 0; j < m; j++)
cin >> g[i][j];
cout << bfs() << endl;
return 0;
}
树与图
这里讨论的是有向图
和无向图
在如何存储问题上的区别。对于有向图
来说,仅存储一条 $a→b$ 的边即可;而对于无向图
来说,相当于存储 $a→b$ 以及 $b→a$ 两条边。
拿有向图
来说,可以使用邻接矩阵
或邻接表
来存储。邻接矩阵
适用于存储稠密图
,也就是边的数量接近顶点数量的平方,而邻接表
适用于存储稀疏图
,也就是边的数量小于顶点数量的平方。
使用邻接矩阵
存储 $a→b$ 这条边:$g[a][b]$;使用邻接表
存储如下:
#include <cstring>
#include <iostream>
#include <algorithm>
const int N = 100010, M = N * 2;
// h[N] 表示邻接表的 N 个头节点,剩下的数组所代表的含义与链表结构相同
int h[N], e[M], ne[M], idx;
// 插入 a→b 这条边
void insert(int a, int b) {
// 这里的插入方式,与链表的头插法是一样的
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
int main() {
// 初始化
idx = 0;
memset(h, -1, sizeof h);
}
树的深度优先遍历:
#include <cstring>
#include <iostream>
#include <algorithm>
const int N = 100010, M = N * 2;
int h[N], e[M], ne[M], idx;
bool st[N];
void dfs(int u) {
st[u] = true;
// 遍历节点 u 的所有出度
for (int i = h[u]; i != -1; i = ne[i]) {
// j 表示节点 i 在图里面的编号
int j = e[i];
// 如果节点 j 没有被搜过,那么就 dfs(j)
if (!st[j]) dfs(j);
}
}
int main() {
memset(h, -1, sizeof h);
dfs(1);
}
树的重心:这个一开始可能会比较难理解,但是画图之后就容易分析了。对于一个树来说,如果删除掉某个节点 $node$,剩下的节点可以组成多个连通分量,那么在这些连通分量中,选出节点最多的节点数量 $num$,此时的这个 $num$ 就是节点 $node$ 的重心。
对于 树的重心 这道题,假如删除节点 $4$,那么如何求得除了 $4$ 这个节点以外的连通分量的节点数量呢?对于 $4$ 的孩子节点来说,可以通过 $DFS$ 进行信息的返回,而对于 $4$ 的父节点来说,可以使用 $n-size_4$ 来计算。
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 100010, M = N * 2;
int n;
int h[N], e[M], ne[M], idx;
bool st[N];
int ans = N;
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
// 返回以 u 为根的子树中的点的数量
int dfs(int u) {
st[u] = true;
int sum = 1, res = 0;
// 遍历节点 u 的所有出度
for (int i = h[u]; i != -1; i = ne[i]) {
// j 表示节点 i 在图里面的编号
int j = e[i];
// 如果节点 j 没有被搜过,那么就 dfs(j)
if (!st[j]) {
int s = dfs(j);
res = max(res, s);
sum += s;
}
}
res = max(res, n - sum);
ans = min(ans, res);
return sum;
}
int main() {
cin >> n;
memset(h, -1, sizeof h);
for (int i = 0; i < n; i++) {
int a, b;
cin >> a >> b;
// 由于是无向图,所以需要声明两条边
add(a, b), add(b, a);
}
dfs(1);
cout << ans << endl;
return 0;
}
树的广度优先遍历:
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 100010;
int n, m;
// 使用邻接表存储结构
int h[N], e[N], ne[N], idx;
// d 表示距离,q 表示队列
int d[N], q[N];
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
int bfs() {
// 队头和队尾
int hh = 0, tt = 0;
// 第一个起点是节点 1
q[0] = 1;
// -1 表示没有被遍历过
memset(d, -1, sizeof d);
// 第一个节点的距离是 0
d[1] = 0;
while (hh <= tt) {
// 取队头的顶点
int t = q[hh++];
for (int i = h[t]; i != -1; i = ne[i]) {
int j = e[i];
// 如果 j 这个点没有被扩展过的话,则去扩展
if (d[j] == -1) {
d[j] = d[t] + 1;
q[++tt] = j;
}
}
}
return d[n];
}
int main() {
cin >> n >> m;
memset(h, -1, sizeof h);
for (int i = 0; i < m; i++) {
int a, b;
cin >> a >> b;
add(a, b);
}
cout << bfs() << endl;
return 0;
}
图的广度优先遍历的应用:求 有向图的拓扑序列。
需要注意的是:拓扑序列只出现在有向无环图中。
整体的步骤:首先,将所有入度为 $0$ 的点入队,当队列不为空的时候,弹出队头的点 $t$,然后枚举 $t$ 所有的出边,例如 $t→j$,再将这条边删除,同时将点 $j$ 的入度数量减 $1$,如果点 $j$ 的入度数量为 $0$ 了,那么就让点 $j$ 入队。
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 100010;
int n, m;
// 定义邻接表
int h[N], e[N], ne[N], idx;
// 定义队列和入度
int q[N], d[N];
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
bool topsort() {
// 定义队头和队尾
int hh = 0, tt = -1;
// 将所有入度为 0 的点,插入到队列中
for (int i = 1; i <= n; i++)
if (d[i] == 0) q[++tt] = i;
while (hh <= tt) {
// 取出队头元素
int t = q[hh++];
// 枚举所有的出边
for (int i = h[t]; i != -1; i = ne[i]) {
int j = e[i];
// 让点 j 的入度减 1
d[j]--;
if (d[j] == 0) q[++tt] = j;
}
}
// 判断是否所有的点都已经进入到队列了
return tt == n - 1;
}
int main() {
cin >> n >> m;
memset(h, -1, sizeof h);
// 共有 m 条边
for (int i = 0; i < m; i++) {
int a, b;
cin >> a >> b;
add(a, b);
// 插入一条边之后,b 的入度要加 1
d[b]++;
}
if (topsort()) {
// 队列中的次序就是拓扑序
for (int i = 0; i < n; i++) printf("%d ", q[i]);
printf("\n");
} else printf("%d", -1);
return 0;
}
最短路
对于最短路径问题,之前在我的 博客 中总结过。
约定:$n$ 表示顶点,$m$ 表示边,则有:
$$ 最短路 \begin{cases} 单源最短路& \begin{cases} 所有边权都是正数& \begin{cases} 朴素 Dijkstra:O(n^2+m)& \\\\ 堆优化版的 Dijkstra:O(mlogn)& \end{cases} \\\\ 存在负权边& \begin{cases} Bellman-Ford:O(n*m)& \\\\ SPFA:一般:O(m),最坏:O(n * m) & \end{cases}\\\\ \end{cases} \\\\ 多源最短路:Folyd:O(n^3)& \end{cases} $$
由于朴素 $Dijkstra$ 算法与边无关,所以适用于稠密图
;而堆优化版的 $Dijkstra$ 算法适用于稀疏图
。
Dijkstra 算法
算法步骤如下:
- 初始化距离:$dist[1]=0, dist[i] = +\infty$,即第一个点的距离置为 $0$,其余点的距离置为正无穷大;
- $i$ 从 $1$ 到 $n$ 开始遍历,使用集合 $s$ 表示当前已经确定最短距离的点;
- 在遍历过程中,找到不在集合 $s$ 中的距离 $s$ 最近的点 $t$,然后将 $t$ 加入到集合 $s$ 中,最后用 $t$ 来更新其它所有点的距离;
- 更新的方式:例如有两条路径:$1→x$ 以及 $1→t→x$,那么我们就判断 $dist[t]+w<dist[x]$ 是否成立($w$ 表示 $t→x$的权重),如果成立的话,则更新 $x$ 的距离。
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 510, INF = 0x3f3f3f3f;
int n, m;
// 稠密图,使用邻接矩阵存储
int g[N][N];
int dist[N];
// 判断当前点的最短路是否已经最终确定了
bool st[N];
int dijkstra() {
memset(dist, 0x3f, sizeof dist);
dist[1] = 0;
for (int i = 0; i < n; i++) {
int t = -1;
for (int j = 1; j <= n; j++)
// 这里的 if 实现的功能是:在所有没被确定最短距离的点中,找到 dist 最小的点
// 如果当前的点 j 还没有确定最短路的话,或者 t 不是最短的
if (!st[j] && (t == -1 || dist[t] > dist[j]))
t = j;
st[t] = true;
// 更新距离
for (int j = 1; j <= n; j++)
dist[j] = min(dist[j], dist[t] + g[t][j]);
}
// 如果点 1 和点 n 不连通,那么返回 -1
if (dist[n] == INF) return -1;
return dist[n];
}
int main() {
scanf("%d%d", &n, &m);
// 初始化邻接矩阵
//memset(g, 0x3f, sizeof g);
// 初始化邻接矩阵
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++)
if (i == j) g[i][j] = 1;
else g[i][j] = INF;
while (m--) {
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
// 由于可能存在重边和自环,对于重边来说,仅保留权重最小的边即可
g[a][b] = min(g[a][b], c);
}
int t = dijkstra();
printf("%d\n", t);
return 0;
}
堆优化版的 Dijkstra 算法
朴素的 $Dijkstra$ 算法最耗时的地方是:在一堆数中找最小的那个数。基于这样的一个问题,我们就可以使用 堆
来优化。这里的堆可以手写一个堆,也可以直接使用 $C++$ 提供的优先队列。
#include <cstring>
#include <iostream>
#include <algorithm>
#include <queue>
using namespace std;
const int N = 160010, INF = 0x3f3f3f3f;
// 维护节点及其编号,所以用 pair
typedef pair<int, int> PII;
int n, m;
//稀疏图,使用邻接表存储,w 表示边的权重
int h[N], w[N], e[N], ne[N], idx;
int dist[N];
// 如果为 true,则说明这个点的最短路径已经确定
bool st[N];
void add(int a, int b, int c) {
// 需要存储权重
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}
int dijkstra() {
memset(dist, 0x3f, sizeof dist);
dist[1] = 0;
// 定义小根堆
priority_queue<PII, vector<PII>, greater<PII>> heap;
// 先将 1 号点放进堆中,距离是 0,编号是 1
// 顺序不能颠倒,pair 在进行排序时,首先是根据 first 排序,然后再根据 second 排序,
// 这里显然需要根据距离进行排序
heap.push({0, 1});
// 当堆不为空时进行循环
while (heap.size()) {
// 找到不在集合 s 中距离最短的点
PII t = heap.top();
heap.pop();
// 拿到距离和编号
int distance = t.first, ver = t.second;
// 如果当前是一个冗余的点,则不需要再进行计算了
if (st[ver]) continue;
st[ver] = true;
// 用当前的点去更新其它的点,也就是遍历当前点的所有临边
for (int i = h[ver]; i != -1; i = ne[i]) {
int j = e[i];
// 如果从源点到 j 点的距离大于从 t 点到 j 点的距离,那么就需要更新
if (dist[j] > distance + w[i]) {
dist[j] = distance + w[i];
// 再将点 j 放到优先队列中
heap.push({dist[j], j});
}
}
}
// 如果点 1 和点 n 不连通,那么返回 -1
if (dist[n] == INF) return -1;
return dist[n];
}
int main() {
scanf("%d%d", &n, &m);
// 初始化邻接表的表头
memset(h, -1, sizeof h);
while (m--) {
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
// 用邻接表的话,对于重边来说,就无所谓了
add(a, b, c);
}
int t = dijkstra();
printf("%d\n", t);
return 0;
}
Bellman-Ford 算法
算法思路:
- 该算法可以处理含有负权边的图;
- 使用 $for$ 循环迭代 $n$ 次,然后使用一个数组将 $dist$ 数组进行备份,被分成 $backup$ 数组,也就是使用上一次存储的结果,防止出现串联,然后每一次的再使用一个 $for$ 来循环所有的边,这里的边用 $(a, b, w)$ 来表示;
- 遍历的时候进行更新,即 $dist[b] = min(dist[b], backup[a] + w)$。也就是
松弛
操作; - 对于
松弛
操作,也就是判断 $1→a→b$ 所经过的路径和是否比 $1→b$ 所经过的路径短,如果短的话就更新; - 经过两个 $for$ 循环之后,对于任意给定的边 $(a, b, w)$,都满足 $dist[b] <= dist[a] + w$ 这一
三角不等式
成立; - 如果图中含有负权回路的话,则最终的结果会产生负无穷。因此,在一般情况下,在使用该算法时,图中不会存在负权回路。
例题:有边数限制的最短路,该题是就具有实际意义的。假定以下场景:某人想乘飞机从 $A$ 地到达 $B$ 地,中间可以经过其它的机场进行中转,这个人有一个兴奋值 $k$,求出从 $A$ 到 $B$ 的不超过 $k$ 的情况下的最少花费的中转金额。
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 510, M = 10010;
int n, m, k;
int dist[N], backup[N];
struct Edge{
int a, b, w;
}edge[M];
int bellman_ford() {
memset(dist, 0x3f, sizeof dist);
dist[1] = 0;
// 根据题干要求,不超过 k 次
for (int i = 0; i < k; i++) {
// 每次在迭代之前,需要将 dist 数组被分到 backup 中
memcpy(backup, dist, sizeof dist);
for (int j = 0; j < m; j++) {
int a = edge[j].a, b = edge[j].b, w = edge[j].w;
dist[b] = min(dist[b], backup[a] + w);
}
}
// 这里为什么不适用 dist[n] == 0x3f3f3f3f 进行判断?
// 考虑这样一种情况:5 → n,其中权重为 -2,
// 假如一开始的时候 5 为 正无穷,n 也为正无穷,那么就有可能出现 5 的正无穷减去 2,然后将 n 的正无穷进行更新了,
// 所以最终的结果:n 可能是正无穷并减去一个很小的数,因此就不能直接使用 dist[n] 是否等于正无穷来判断
if (dist[n] > 0x3f3f3f3f / 2) return -1;
return dist[n];
}
int main() {
scanf("%d%d%d", &n, &m, &k);
// 读入 m 条边,并存储
for (int i = 0; i < m; i++) {
int a, b, w;
scanf("%d%d%d", &a, &b, &w);
edge[i] = {a, b, w};
}
int t = bellman_ford();
// 最短路不存在
if (t == -1) printf("%s\n", "impossible");
else printf("%d\n", t);
return 0;
}
SPFA 算法
该算法的实现细节如下:
- $SPFA$ 算法是对 $Bellman-Ford$ 算法的优化,优化的点在哪儿呢?
- 对于 $Bellman-Ford$ 算法中的第二个 $for$ 循环,它循环的是每条边。值得注意的是:并不是所有的边都会执行
松弛
操作。因此,$SPFA$ 算法优化的点就在这儿。 - $SPFA$ 利用宽度优先搜索来做优化,队列中存储的就是变小的 $dist[a]$,因此,只要 $dist[a]$ 变小了,那么就可以去更新 $dist[b]$ 了;
- 具体实现方式为:首先将点 $1$ 加入队列中,当队列不为空的时候就一直循环;取出队头的点 $t$,并将队头的点弹出;然后更新 $t$ 的所有出边,例如 $(t, b, w)$,如果更新成功的话,则将 $b$ 加入到队列中。当然,如果队列中已经存在 $b$ 了的话,则无需重复加入。
#include <cstring>
#include <iostream>
#include <algorithm>
#include <queue>
using namespace std;
const int N = 100010, INF = 0x3f3f3f3f;
typedef pair<int, int> PII;
int n, m;
int h[N], w[N], e[N], ne[N], idx;
int dist[N];
// 判断当前的点是否已经存在于队列中,防止队列中存储重复的点
bool st[N];
void add(int a, int b, int c) {
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}
int spfa() {
memset(dist, 0x3f, sizeof dist);
dist[1] = 0;
queue<int> q;
q.push(1);
st[1] = true;
while (q.size()) {
int t = q.front();
q.pop();
// 点 t 已经从队列中出来了,则置为 false
st[t] = false;
// 遍历点 t 所有的临边
for (int i = h[t]; i != -1; i = ne[i]) {
int j = e[i];
// 判断该点是否能更新
if (dist[j] > dist[t] + w[i]) {
dist[j] = dist[t] + w[i];
// 如果 j 不在队列中,我们则将 j 加入到队列中
if (!st[j]) {
q.push(j);
st[j] = true;
}
}
}
}
if (dist[n] == INF) return -1;
return dist[n];
}
int main() {
scanf("%d%d", &n, &m);
memset(h, -1, sizeof h);
while (m--) {
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c);
}
int t = spfa();
if (t != -1) printf("%d\n", t);
else printf("%s\n", "impossible");
return 0;
}
SPFA 判断是否存在负环
这里需要一个额外的数组 $cnt[N]$,它用于存储从源点到当前点所经历过的边数。我们在更新 $dist[x]= dist[t] + w[i]$ 的同时,也需要更新 $cnt[x] = cnt[t] + 1$。
例如,有 $1→t→x$ 和 $1→x$,那么如果前者路径短,则需要更新 $cnt[x]$,也就是从源点到 $t$ 的边数,在加上 $t→x$ 这条边。
最后,如果 $cnt[j] >= n$,说明从点 $1$ 到点 $j$,所经历的路径至少包含 $n$ 条边,也就至少存在 $n+1$ 个点。但是,由于只有 $1$ 到 $n$ 共 $n$ 个点,那么根据 抽屉原理
,则说明有 $2$ 个点是相同的,因此也就出现了负环。
#include <cstring>
#include <iostream>
#include <algorithm>
#include <queue>
using namespace std;
const int N = 100010, INF = 0x3f3f3f3f;
typedef pair<int, int> PII;
int n, m;
int h[N], w[N], e[N], ne[N], idx;
// cnt 用于记录从源点到当前点所经历过的边数
int dist[N], cnt[N];
bool st[N];
void add(int a, int b, int c) {
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}
bool spfa() {
// 无需初始化
// 这里需要将每个点放到初始的集合中,因为,如果只将某个点 1 放集合中的话,
// 那么有可能负环本身就不经过该点 1,从而就可能找不出负环
queue<int> q;
for (int i = 1; i <= n; i++) {
st[i] = true;
q.push(i);
}
while (q.size()) {
int t = q.front();
q.pop();
// 点 t 已经从队列中出来了,则置为 false
st[t] = false;
// 遍历点 t 所有的临边
for (int i = h[t]; i != -1; i = ne[i]) {
int j = e[i];
// 判断该点是否能更新
if (dist[j] > dist[t] + w[i]) {
dist[j] = dist[t] + w[i];
// 同时,更新 cnt
cnt[j] = cnt[t] + 1;
if (cnt[j] >= n) return true;
// 如果 j 不在队列中,我们则将 j 加入到队列中
if (!st[j]) {
q.push(j);
st[j] = true;
}
}
}
}
return false;
}
int main() {
scanf("%d%d", &n, &m);
memset(h, -1, sizeof h);
while (m--) {
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c);
}
bool t = spfa();
if (t) printf("%s\n", "Yes");
else printf("%s\n", "No");
return 0;
}
Floyd 算法
该算法实现的细节如下:
- 使用邻接矩阵来存储图,即 $d[i][j]$;
- 使用三重循环,第一层:$k$ 从 $1$ 到 $n$,第二层:$i$ 从 $1$ 到 $n$,第三层:$j$ 从 $1$ 到 $n$,然后更新 $d[i][j] = min(d[i][j], d[i][k] + d[k][j])$ 即可。
- 执行完之后,$d[i][j]$ 存储的就是从 $i$ 到 $j$ 的最短路径长度;
- 状态转移方程为:$d[k][i][j]=d[k-1][i][k]+d[k-1][k][j]$。
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 210, INF = 1e9;
// k 表示询问次数
int n, m, k;
// 定义邻接矩阵
int d[N][N];
void floyd() {
for (int k = 1; k <= n; k++)
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++)
d[i][j] = min(d[i][j], d[i][k] + d[k][j]);
}
// 对于重边,则保留最短的边,
// 对于自环,则直接删除即可
int main() {
scanf("%d%d%d", &n, &m, &k);
// 初始化邻接矩阵
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++)
if (i == j) d[i][j] = 0;
else d[i][j] = INF;
while (m--) {
int a, b, w;
scanf("%d%d%d", &a, &b, &w);
// 如果有多条边,则保留最小的边
d[a][b] = min(d[a][b], w);
}
floyd();
while (k--) {
int a, b;
scanf("%d%d", &a, &b);
if (d[a][b] > INF / 2) printf("%s\n", "impossible");
else printf("%d\n", d[a][b]);
}
return 0;
}
最小生成树
在求最小生成树问题时,一般情况下对应的都是 无向图
,假设图中顶点的数量为 $n$,边的数量为 $m$,则有以下两个算法:
$$ 最小生成树 \begin{cases} 普利姆(Prim)算法& \begin{cases} 朴素 Prim,O(n^2+m),稠密图& \\\\ 堆优化版 Prim,O(mlogn), 稀疏图& \end{cases} \\\\ 克鲁斯卡尔(Kruskal)算法,O(mlogm,稀疏图)& \end{cases} $$
$Kruskal$ 算法所花费的时间在于对所有的边进行排序。
朴素 Prim
该算法实现的步骤如下所示:
- 初始化所有点的距离为正无穷 $dist[i]=+\infty$;
- 迭代 $n$ 次,然后找到不在当前集合中的距离最小的点,然后赋值给 $t$;
- 用 $t$ 去更新其它点到当前集合的距离;
- 最后再将 $t$ 加入到集合中。
需要注意的是:
迪杰斯特拉
算法是用 $t$ 去更新其它点到源点
的距离,而普利姆
算法是用 $t$ 去更新其它点到当前集合
的距离。
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 510, INF = 0x3f3f3f3f;
int n, m;
// 由于顶点少,边数多,所以是稠密图,对于稠密图,这里使用邻接矩阵
int g[N][N];
int dist[N];
bool st[N];
int prim() {
// 将所有顶点的距离初始化为正无穷
memset(dist, 0x3f, sizeof dist);
int ans = 0;
for (int i = 0; i < n; i++) {
int t = -1;
for (int j = 1; j <= n; j++)
// 顶点 j 位于集合外
if (!st[j] && (t == -1 || dist[t] > dist[j]))
t = j;
// 如果不是第一个顶点,说明当前是不连通的
if (i && dist[t] == INF) return INF;
// 先累加,然后再进行 for 循环的更新,需要考虑到自环的 case
if (i) ans += dist[t];
// 用 t 去更新其它点到集合的距离
for (int j = 1; j <= n; j++)
dist[j] = min(dist[j], g[t][j]);
// 将顶点加入到集合中
st[t] = true;
}
return ans;
}
int main() {
scanf("%d%d", &n, &m);
memset(g, 0x3f, sizeof g);
while (m--) {
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
// 如果有重边,则保留长度最小的边
// 无向图是一种特殊的有向图,所以需要处理两个方向
g[a][b] = g[b][a] = min(g[a][b], c);
}
int t = prim();
if (t == INF) printf("%s\n", "impossible");
else printf("%d\n", t);
return 0;
}
Kruskal 算法
该算法步骤如下所示:
- 将所有的边按照权重,从小到大排序;
- 从小到大枚举每条边 $(a, b)$,如果 $(a, b)$ 不连通的话(也就是这两个顶点不在同一集合中),则将 $(a, b)$ 加入到集合中。
其中,在枚举每条边以及判断是否连通时,可以使用之前的 并查集
的方法。
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 200010;
int n, m;
// 并查集的 p 数组
int p[N];
struct Edge {
int a, b, w;
// 重载小于号,方便按照权重进行排序
bool operator< (const Edge &W) const {
return w < W.w;
}
}edges[N];
int find(int x) {
// 如果 x 不是祖宗节点,则一直朝上找
if (p[x] != x) p[x] = find(p[x]);
return p[x];
}
int main() {
scanf("%d%d", &n, &m);
// 读所有的边
for (int i = 0; i < m; i++) {
int a, b, w;
scanf("%d%d%d", &a, &b, &w);
edges[i] = {a, b, w};
}
// 将所有的边进行排序
sort(edges, edges + m);
// 初始化并查集
for (int i = 1; i <= n; i++) p[i] = i;
int ans = 0, cnt = 0;
// 从小到大枚举所有边
for (int i = 0; i < m; i++) {
int a = edges[i].a, b = edges[i].b, w = edges[i].w;
a = find(a), b = find(b);
// 如果两个顶点不连通的话,则将 (a, b) 这条边加入到集合中
if (a != b) {
// 加入到同一集合中
p[a] = b;
// ans 表示最小生成树的边的权重之和
ans += w;
// cnt 表示当前已经加入了多少条边
cnt++;
}
}
// 如果不连通
if (cnt < n - 1) printf("%s\n", "impossible");
else printf("%d\n", ans);
return 0;
}
二分图
对于二分图的概念来说,首先有两个包含许多顶点的集合,集合内部的顶点互不相连,而两个不同集合之间的顶点是相连的,并且不会出现奇数环,这样的图称作二分图。
性质:如果一个图是二分图,那么当且仅当这个图可以被染色,并且,当且仅当图中不含奇数环。
奇数环:对于图中的一个环,如果环中的边数为奇数,那么则称该环为奇数环。
而判断一个图是否是二分图
,有以下两种算法:
$$ 二分图 \begin{cases} 染色法,O(n+m)& \\\\ 匈牙利算法,O(n*m)& \end{cases} $$
染色法
如果在染色的过程中出现了矛盾(奇数环),那么说明改图不是二分图。算法流程:首先从 $1$ 到 $n$ 遍历每个顶点,如果当前的顶点 $i$ 没有被染色,那么就 $dfs(i, 1)$,也就是使用深度优先遍历的方式,将顶点 $i$ 染成 $1$ 号色。此外,另一种颜色我们声明为 $2$ 号色。当然,在实际的编码过程中,使用 $-1$、$0$ 以及 $1$ 分别代表未染色、白色、黑色即可。
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
// 由于是无向图,因此在存储边的时候,需要多存储一条边
const int N = 100010, M = 200010;
int n, m;
// 使用邻接表进行存储
int h[N], e[M], ne[M], idx;
// color 表示每个顶点的颜色,-1 表示未染色,0 表示白色,1 表示黑色
int color[N];
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
bool dfs(int u, int c) {
// 首先记录一下当前顶点 u 的颜色是 c
color[u] = c;
// 然后遍历当前顶点 u 的所有相邻的点
for (int i = h[u]; i != -1; i = ne[i]) {
int j = e[i];
// 如果当前这个顶点没有染过颜色,则将其进行染成另外一种颜色
if (color[j] == -1) {
if (!dfs(j, !c)) return false;
// 如果顶点 j 已经染过颜色的话,就判断一下有没有矛盾,
// 也就是说,如果一条边的两个顶点的颜色都是相同的话,则出现了矛盾,也就是出现了奇数环
} else if (color[j] == c) return false;
}
return true;
}
int main() {
scanf("%d%d", &n, &m);
// 邻接表的初始化操作
memset(h, -1, sizeof h);
// color 表示每个顶点的颜色,-1 表示未染色,0 表示白色,1 表示黑色
memset(color, -1, sizeof color);
while (m--) {
int a, b;
scanf("%d%d", &a, &b);
// 对于无向图,两个顶点之间需要添加两条边
add(a, b), add(b, a);
}
// 在染色的过程中,是否有矛盾发生,也就是记录是否出现了奇数环
bool flag = true;
for (int i = 1; i <= n; i++) {
// 如果当前这个顶点没有染过颜色,则将其进行染成另外一种颜色
if (color[i] == -1) {
// 如果 dfs 返回的是 false,那么就说明有矛盾发生
if (!dfs(i, 0)) {
flag = false;
break;
}
}
}
if (flag) printf("%s\n", "Yes");
else printf("%s\n", "No");
return 0;
}
匈牙利算法
该算法用于求二分图中的最大匹配数量。
算法思路:
- 想象一个场景,一组男生和一组女生,他们之间不同的人之间可能会存在多条线进行连接起来,也就是男生会连向女生,女生也会连向男生,每条线代表好感度。而匈牙利算法就是找到匹配最多的、可以组成情侣对儿的数量。
- 也就是说,当前顶点 $A$ 去匹配 $B$ 的时候,如果顶点 $B$ 已经和别的顶点 $C$ 进行匹配了,那么就再判断顶点 $C$ 能不能去匹配其它点;
- 如果顶点 $C$ 能去匹配其它顶点的话,那么此时顶点 $A$ 就可以与顶点 $B$ 进行匹配了。
- 需要注意的是:如果枚举左边部分的点集,那么此时存储的就是从左侧顶点到右侧顶点的一个有向边。
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 510, M = 100010;
int n1, n2, m;
// 定义临界表
int h[N], e[M], ne[M], idx;
// match 表示右半部分的顶点与左侧的哪些点已经相连
int match[N];
// 避免重复搜索某个点
bool st[N];
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
bool find(int x) {
// 首先,枚举当前左半部分顶点 x 已经与右半部分连接的顶点
// 也就是枚举当前男生已经看上的妹子
for (int i = h[x]; i != -1; i = ne[i]) {
int j = e[i];
// 如果当前的妹子没有考虑过别的男生,那么这个妹子就考虑 x 这个男生
if (!st[j]) {
st[j] = true;
// 如果这个妹子还没有匹配任何男生,或者这个妹子已经匹配了其它男生,
// 但是我们可以为那个男生找一个备胎...
if (match[j] == 0 || find(match[j])) {
match[j] = x;
return true;
}
}
}
return false;
}
int main() {
scanf("%d%d%d", &n1, &n2, &m);
memset(h, -1, sizeof h);
while (m--) {
int a, b;
scanf("%d%d", &a, &b);
add(a, b);
}
int ans = 0;
// 枚举左侧的点
for (int i = 1; i <= n1; i++) {
// 先将右半部分的顶点清空,这是为了保证对于每个右半部分的顶点,我只考虑一次
memset(st, false, sizeof st);
// 如果左半部分的顶点 i 找到右半部分的顶点,则 ans++
if (find(i)) ans++;
}
printf("%d\n", ans);
return 0;
}
数论
质数
定义:在大于 $1$ 的整数中,如果只包含 $1$ 和它本身这两个约数,则称为质数(素数)。
质数的判断通常有以下几种方式:
1.试除法,相当于暴力,$O(n)$。
bool is_prime(int n) {
if (n < 2) return false;
for (int i = 2; i < n; i++)
if (n % i == 0) return false;
return true;
}
2.可以对试除法进行优化,优化的地方就是 $for$ 循环中的判断条件。如果一个数 $d$ ($d$ 为 $n$ 的约数)能整除 $n$,那么 $n/d$ 也是能整除 $n$ 的。例如 $n=12, d=3$ 时,由于 $3$ 能整除 $12$,所以 $12/3=4$ 也能整除 $12$。
因此可以看到,$d$ 和 $n/d$ 都是成对儿出现的。所以,我们在每次枚举的时候,都枚举较小的那个数。即:
$$ \qquad d <= \frac{n}{d} $$
$$ \Rightarrow d^2<=n $$
$$ \Rightarrow d<= \sqrt{n} $$
这里的 $d$ 和 $i$ 都是一样的,都指的是同一个数。
至此,时间复杂度从 $O(n)$ 降低到 $O(\sqrt{n})$。因此,可以将 $for$ 循环的判断语句中,修改为 $i <= n / i$。需要注意的是:这里不使用 $i * i <= n$ 的原因是 $i$ 在很接近整数的最大值的时候,$i * i$ 操作有可能会溢出。同时,不使用 $i <= sort(n)$ 的原因是 $sort$ 函数在计算时很慢。
#include <iostream>
using namespace std;
int n;
bool is_prime(int x) {
if (x < 2) return false;
for (int i = 2; i <= x / i; i++)
if (x % i == 0) return false;
return true;
}
int main() {
scanf("%d", &n);
while (n--) {
int x;
scanf("%d", &x);
if (is_prime(x)) printf("%s\n", "Yes");
else printf("%s\n", "No");
}
return 0;
}
质因数
在进行分解质因数的时候,如果想要输出这个数的每个质因数的底数和指数,那么一种方式是使用暴力的方式,如下:
void divide(int x) {
for (int i = 2; i <= x; i++) {
// i 一定是质数
if (x % i == 0) {
int s = 0;
while (x % i == 0) {
n /= i;
s++;
}
printf("%d %d\n", i, s);
}
}
}
当然,$for$ 循环中的判断条件也是可以优化的。那么如何进行优化呢?
有一个很重要的性质:$x$ 中最多只包含一个大于 $\sqrt{x}$ 的质数因子。因此,时间复杂度也可以降低到 $O(\sqrt{n})$。
void divide(int x) {
for (int i = 2; i <= x / i; i++) {
// i 一定是质数
if (x % i == 0) {
int s = 0;
while (x % i == 0) {
x /= i;
s++;
}
printf("%d %d\n", i, s);
}
}
if (x > 1) printf("%d %d\n", x, 1);
printf("\n");
}
筛质数
有一种方法是这样的:假如想要得到从 $1$ 到 $n$ 范围内所有的质数,那么我们可以先将 $2、3、4、5、6、…、n$ 整理出来,然后删掉倍数为 $2$ 的数字,再删掉倍数为 $3$ 的数字,再删掉倍数为 $4$ 数字…最终剩下的数字,就是从 $1$ 到 $n$ 范围内所有的质数了。
显然可以看出,上述方法就是从 $2$ 开始依次枚举,因此需要用到数组。代码如下所示:
int primes[N], cnt;
// 表示 st[i] 是否已经被筛掉了
bool st[N];
void get_primes(int x) {
for (int i = 2; i <= x; i++) {
// 如果当前这个数没有被删(筛)过的话,则说明该数是一个质数
if (!st[i]) primes[cnt++] = i;
// 开始筛质数,也就是依次删除掉所有 2 的倍数、所有 3 的倍数、所有 4 的倍数......
// 注意:j 的初始值是 i + i
for (int j = i + i; j <= x; j += i)
st[j] = true;
}
}
上述方法的时间复杂度为 $O(n\log{n})$。因为对于 $for$ 循环来说,有:
$$ \frac{n}{2}+\frac{n}{3}+\frac{n}{4}+…+\frac{n}{n} = n(\frac{1}{2}+\frac{1}{3}+\frac{1}{4}+…+\frac{1}{n}) $$
对于 $\frac{1}{2}+\frac{1}{3}+\frac{1}{4}+…+\frac{1}{n}$ 来说,当 $n$ 趋于 $\infty$ 时,有:
$$ \lim_{x \rightarrow \infty} (\frac{1}{2}+\frac{1}{3}+\frac{1}{4}+…+\frac{1}{n})=\ln{n}+C $$
而 $\ln{n} = \log_{e}{n}$,因此:
$$ \log_{e}{n} < \log_{2}{n} $$
$$ \Rightarrow n\log_{e}{n} < n\log_{2}{n} $$
所以,从整体上来说,上述这种朴素筛法的时间复杂度为 $O(n\log{n})$。
如何优化?可以发现,第二个 $for$ 循环的作用是将每个数的倍数删掉,但其实可以不用这么做,我们可以将所有质数的倍数删掉即可,也就是将第二个 $for$ 循环放在 $if$ 内部即可。也就是说,当一个数不是质数的时候,我们就不需要筛掉它所有的倍数了。如下所示:
void get_primes(int x) {
for (int i = 2; i <= x; i++) {
if (!st[i]) {
primes[cnt++] = i;
for (int j = i + i; j <= x; j += i)
st[j] = true;
}
}
}
当然,也可以写成如下,效果都是一样的:
void get_primes(int x) {
for (int i = 2; i <= x; i++) {
if (st[i]) continue;
primes[cnt++] = i;
for (int j = i + i; j <= x; j += i)
st[j] = true;
}
}
此外,还有一个质数定理:在 $1$ 到 $n$ 中,一共有 $\frac{n}{\ln{n}}$ 个质数。因此,上述这种方法的时间复杂度是 $O(nloglogn)$,大约是 $O(n)$。
还有一种方法是线性筛法
,当给定的 $n$ 为 $10^7$ 及其以上的时候,该方法比上述方法大约快一倍。其核心思想是:对于每一个数 $x$,只会被它的最小质因子筛掉。
TODO
约数
在使用试除法求 $n$ 的约数的时候,其实与前面介绍的求质数的思路是一样的,也有一个优化的过程。即我们仅枚举较小的那个数即可,也就是 $d <= \sqrt{n}$。需要注意的是:在每次向 $vector$ 中添加约数的过程中,添加的顺序是 $i、n / i$。
#include <iostream>
#include <algorithm>
#include <vector>
using namespace std;
int n;
vector<int> divisor(int x) {
vector<int> ans;
for (int i = 1; i <= x / i; i++) {
if (x % i == 0) {
ans.push_back(i);
// x/i 也是 x 的一个约数,有可能 x 是 i 的平方,
// 也就是说,如果 i 和 x/i 相同的话,那么我们只添加一个数即可
if (i != x / i) ans.push_back(x / i);
}
}
// 由于在存储的时候,是按照 i、x/i 的顺序存到 vector 的,
// 因此,需要进行排序
sort(ans.begin(), ans.end());
return ans;
}
int main() {
scanf("%d", &n);
while (n--) {
int x;
scanf("%d", &x);
auto ans = divisor(x);
for (auto t : ans) printf("%d ", t);
printf("\n");
}
return 0;
}
约数的个数与约数的和
对于如何求一个数 $n$ 含有约数的个数问题,可以使用如下公式,如果一个数 $n$ 在分解完质因数之后,可以写成如下形式:
$$ n = p_{1}^{\alpha_{1}}·p_{2}^{\alpha_{2}}·…·p_{k}^{\alpha_{k}} $$
那么,整数 $n$ 的约数个数为 $(\alpha_{1} + 1)·(\alpha_{2} + 1)·…·(\alpha_{k} + 1)$。
而对于约数之和来说,可以使用如下公式:
$$ (p_{1}^{0} + p_{1}^{1} + p_{1}^{2} +…+ p_{k}^{\alpha_{1}})·(p_{2}^{0} + p_{2}^{1} + p_{2}^{2} +…+ p_{k}^{\alpha_{2}})·…·(p_{k}^{0} + p_{k}^{1} + p_{k}^{2} +…+ p_{k}^{\alpha_{k}}) $$
在实际的编码过程中,例如 约数个数 这道题,对于给定的 $n$ 个数,我们分别将每个数进行质因数分解之后,再将各自的指数相加即可。当然,这里使用 $map$ 来存储底数和指数。
#include <iostream>
#include <algorithm>
#include <unordered_map>
using namespace std;
typedef long long LL;
const int MOD = 1e9 + 7;
int main() {
int n;
scanf("%d", &n);
unordered_map<int, int> primes;
while (n--) {
int x;
scanf("%d", &x);
// 分解质因数
for (int i = 2; i <= x / i; i++) {
while (x % i == 0) {
x /= i;
primes[i]++;
}
}
// 如果 x 是一个比较大的质因数,那么需要再加上
if (x > 1) primes[x]++;
}
LL ans = 1;
for (auto prime : primes)
ans = ans * (prime.second + 1) % MOD;
cout << ans << endl;
return 0;
}
求约数之和如下所示:
#include <iostream>
#include <algorithm>
#include <unordered_map>
using namespace std;
typedef long long LL;
const int MOD = 1e9 + 7;
int main() {
int n;
scanf("%d", &n);
unordered_map<int, int> primes;
while (n--) {
int x;
scanf("%d", &x);
for (int i = 2; i <= x / i; i++) {
while (x % i == 0) {
x /= i;
primes[i]++;
}
}
if (x > 1) primes[x]++;
}
LL ans = 1;
for (auto prime : primes) {
// 底数
int p = prime.first;
// 指数
int a = prime.second;
LL t = 1;
// 这里要计算 p^0 + p^1 + ... + p^a,
// 有很多方法实现上述公式,我们一开始让 t=p*t+1,此时 t=p+1,
// 然后 t=t*p+1=(p+1)*p+1=p^2+p+1,依次类推
while (a--) t = (t * p + 1) % MOD;
ans = ans * t % MOD;
}
cout << ans << endl;
return 0;
}
最大公约数
欧几里得算法(辗转相除法),在此之前,我们知道一些性质:如果 $d$ 能整除 $a$,斌且 $d$ 能整除 $b$,那么 $d$ 就能整除 $a+b$ 或者 $ax+by$。那么 $(a,, b)$ 的最大公约数就等于 $(b, a mod b)$ 的最大公约数。
由于 $a mod b = a - \frac{a}{b} * b = a - c * b$,那么 $(a, b)$ 的最大公约数等于 $(a, a - c * b)$ 的最大公约数。然后,根据一开始提到的性质,就能够证明它俩的公约数是相等的。因此,最终可以得到 $(a, b)$ 的最大公约数等于 $(b, a mod b)$ 的最大公约数。时间复杂度为 $O(\log{n})$。
#include <iostream>
#include <algorithm>
using namespace std;
int gcb(int a, int b) {
// 当 b 不为 0 的时候,返回 gcb(b, a % b),
// 当 b 为 0 的时候,相当于 gcb(a, 0),此时最大公约数就是 a
return b ? gcb(b, a % b) : a;
}
int main() {
int n;
cin >> n;
while (n--) {
int a, b;
cin >> a >> b;
int t = gcb(a, b);
printf("%d\n", t);
}
return 0;
}
欧拉函数
欧拉函数 $\varphi(n)$ 指的是小于或等于 $n$ 的正整数中与 $n$ 互质的数的个数。例如 $\varphi(8)=4$,即 $1、3、5、7$ 分别与 $8$ 互质。
对于如何求解 $n$ 的欧拉函数,可以使用如下方法:即分解质因数,然后再求解。
$$ n = p_{1}^{\alpha_{1}}·p_{2}^{\alpha_{2}}·…·p_{k}^{\alpha_{k}} $$
$$ \varphi(n) = n(1-\frac{1}{p_{1}})(1-\frac{1}{p_{2}})…(1-\frac{1}{p_{k}}) $$
思路就是使用容斥原理
,步骤如下:
首先,从 $1$ 到 $n$ 中去掉 $p_{1}、p_{1}、…、p_{k}$ 所有的倍数,而数字 $i$ 的所有的倍数一共有多少个呢?其实一共有 $\frac{n}{p_{i}}$ 个。那么,这一步的操作,就等同于:
$$ n - \frac{n}{p_{1}} - \frac{n}{p_{2}} - … - \frac{n}{p_{k}} $$
然后,通过上面的公式,就会多去掉一些数,例如去掉的这个数有可能是 $p_{1}$ 的倍数,也有可能是 $p_{2}$ 的倍数,那么本该去掉一次的,然而实际却去掉了两次。现在就需要将多去掉的次数加回来。即再加上所有 $p_{i} \* p_{j}$ 的倍数,然后再减去所有 $p_{i} \* p_{j} * p_{k}$ 的倍数。如下所示:
$$\varphi(n) = n - \frac{n}{p_{1}} - \frac{n}{p_{2}} - … - \frac{n}{p_{k}} $$
$$ \+ \frac{n}{p_{1}p_{2}} + \frac{n}{p_{1}p_{3}} + … $$
$$ \- \frac{n}{p_{1}p_{2}p_{3}} - \frac{n}{p_{1}p_{2}p_{4}} - … $$
$$ \+ \frac{n}{p_{1}p_{2}p_{3}p_{4}} + … $$
$$ \- \frac{n}{p_{1}p_{2}p_{3}p_{4}p_{5}} - … $$
$$ … $$
上述属于容斥原理
,化简后就等于下列式子:
$$ \varphi(n) = n(1-\frac{1}{p_{1}})(1-\frac{1}{p_{2}})…(1-\frac{1}{p_{k}}) $$
#include <iostream>
#include <algorithm>
using namespace std;
int main() {
int n;
scanf("%d", &n);
while (n--) {
int x;
scanf("%d", &x);
int ans = x;
// 分解质因数
for (int i = 2; i <= x / i; i++) {
if (x % i == 0) {
// ans = ans * (1 - 1 / i);
// 将上式化简
ans = ans / i * (i - 1);
while (x % i == 0) x /= i;
}
}
if (x > 1) ans = ans / x * (x - 1);
printf("%d\n", ans);
}
return 0;
}
上述方法是从定义出发来求解欧拉函数,而如果想要求 $1$ 到 $n$ 每一个数的欧拉函数,可以使用筛选法求欧拉函数
,也就是前面提到的线性筛法
求质数的方法。
TODO
快速幂
该方法可以在 $O(\log{k})$ 的时间内计算 $a^{k} \% p$ 的结果。整体思路是:首先预处理 $\log{k}$ 项结果,即 $a^{2^{0}} \% p$、$a^{2^{1}} \% p$、$a^{2^{2}} \% p$、…、$a^{2^{\log{k}}} \% p$,然后 $a^{k}$ 就可以转化成:
$$ a^{k} = a^{2^{x_{1}}}·a^{2^{x_{2}}}·a^{2^{x_{3}}}·…·a^{2^{x_{t}}} $$
$$ =a^{2^{x_{1}}+2^{x_{2}}+2^{x_{3}}+…+2^{x_{t}}} $$
也就是将 $k$ 转换成已预处理结果的和。因此,我们可以将 $k$ 表示成二进制,例如 $(k)_{10} = (110110)_{2}$,可得 $k = 2^{1} + 2^{2} + 2^{4} + 2^{5}$。而在预处理时,每一个数都是上一个数的平方再模 $p$。
举个例子:求 $4^{5} \% 10$ 的结果。
首先进行预处理:$4^{2^{0}} \% 10= 4$、$4^{2^{1}} \% 10= 6$、$4^{2^{2}} \% 10= 6$,然后:
$$ 4^{5} = 4^{101} = 4^{2^{0}+2^{2}} = 4^{2^{0}} · 4^{2^{2}} $$
通过查表,可知:$4^{2^{0}} · 4^{2^{2}} = 4 × 6 = 24$,当然最后还需要对 $10$ 取模。因此,最终的结果为 $4$。
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
int qmi(int a, int k, int p) {
int ans = 1;
while (k) {
// 如果指数 k 的二进制表示形式中,最后一位是 1 的话,
// 则要先乘以一个底数
if (k & 1) ans = (LL)ans * a % p;
a = (LL)a * a % p;
// 由于 k 是使用二进制表示的,因此上面的个位处理完之后,
// 需要来到二进制的下一位(十位),也就是将二进制的个位删掉,
// 然后再进行计算
k >>= 1;
}
return ans;
}
int main() {
int n;
scanf("%d", &n);
while (n--) {
int a, k, p;
scanf("%d%d%d", &a, &k, &p);
printf("%d\n", qmi(a, k, p));
}
return 0;
}
使用快速幂求解逆元。
TODO
扩展的欧几里得算法
在此之前,需要知道 裴蜀定理,其内容为:设 $a$、$b$ 是不全为零的整数,则存在整数 $x$、$y$,使得 $ax+by=gcd(a, b)$。
#include <iostream>
#include <algorithm>
using namespace std;
int exgcd(int a, int b, int &x, int &y) {
// 如果 b 为 0,则返回 a,
// 那么 gcd(a, 0) = a,此时 a*x + 0*y = a,那么 x = 1、y = 0 就是其中一组解
if (!b) {
x = 1, y = 0;
return a;
}
// 得到最大公约数
// b*y + (a%b)*x = d = gcd(a, b)
// 而 a%b = a - ⌊a/b ⌋*b
// 联立两式,可得 ax = b(y-⌊a/b ⌋ x)= d
int d = exgcd(b, a % b, y, x);
y -= a / b * x;
return d;
}
int main() {
int n;
scanf("%d", &n);
while (n--) {
int a, b, x, y;
scanf("%d%d", &a, &b);
// x 和 y 通过引用传递
exgcd(a, b, x, y);
printf("%d %d\n", x, y);
}
return 0;
}
扩展欧几里得算法的应用:线性同余方程。
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
int exgcd(int a, int m, int &x, int &y) {
if (!m) {
x = 1, y = 0;
return a;
}
int d = exgcd(m, a % m, y, x);
y -= a / m * x;
return d;
}
int main() {
int n;
scanf("%d", &n);
while (n--) {
int a, b, m;
scanf("%d%d%d", &a, &b, &m);
int x, y;
int d = exgcd(a, m, x, y);
// 如果 b 不是 d 的倍数,则无解
if (b % d) printf("%s\n", "impossible");
else printf("%d\n", (LL)x * (b / d) % m);
}
return 0;
}
中国剩余定理
中国剩余定理用于求解如下形式的一元线性同余方程组(其中 $n_{1}、n_{2}、…n_{k}$ 两两互质):
$$ \begin{cases} x \equiv a_{1} (mod \; n_{1})& \\\\ x \equiv a_{2} (mod \; n_{2}) && \\\\ \quad \qquad …& \\\\ x \equiv a_{k} (mod \; n_{k}) && \end{cases} $$
算法的流程为:首先计算所有模数的乘积,即 $n = n_{1}·n_{2}·…·n_{k}$;其次,对于第 $i$ 个方程,计算 $m_{i} = \frac{n}{n_{i}}$,再计算 $m_{i}$ 在 $n_{i}$ 意义下的逆元 $m_{i}^{-1}$;然后,计算 $c_{i} = m_{i}m_{i}^{-1}$;最后,方程组的唯一解为:
$$ a = \sum_{i = 1}^k a_{i}c_{i} (mod \; n) $$
这里在求逆元时,可以使用扩展欧几里得算法
,详见 OI Wiki。
高斯消元
高斯消元可以在约 $n^3$ 的时间内求解一个包含 $m$ 个方程和 $n$ 个未知数的多元线性方程组,如下所示:
$$ \begin{cases} a_{11}x_{1} + a_{12}x_{2} + … + a_{1n}x_{n} = b_{1},& \\\\ a_{21}x_{1} + a_{22}x_{2} + … + a_{2n}x_{n} = b_{2},& \\\\ …&\\\\ a_{m1}x_{1} + a_{m2}x_{2} + … + a_{mn}x_{n} = b_{m} \end{cases} $$
而对于上述方程的解,则有且仅有三种情况:存在唯一解、无数解、无解。
具体流程如下(详见 OI Wiki):
- 将增广矩阵行初等行变换为行最简阶梯形;
- 还原线程方程组;
- 求解第一个变量;
- 补充自由未知量;
- 列表示方程组通解。
在编码时,可以按照以下的方式进行:
- 枚举每一列 $col$;
- 找到绝对值最大的一行;
- 将该行换到最上面;
- 将该行第 $1$ 个数变成 $1$;
- 将下面所有行的第 $col$ 列消成 $0$。
#include <iostream>
#include <algorithm>
// 这里使用到了 fabs 函数,即 float abs
#include <cmath>
using namespace std;
const int N = 110;
// 由于浮点数的原因,可能在判断一个数是否是 0 的时候会存在误差;
// 因此,当某个数的绝对值小于 10 的 -6 次方的时候,则认为这个数就是 0
const double eps = 1e-6;
int n;
double a[N][N];
// 0 表示有唯一解
// 1 表示无穷多组解
// 2 表示无解
int gauss() {
// 定义行和列
int r, c;
// 从第 0 列开始枚举
for (r = 0, c = 0; c < n; c++) {
// 从这一行开始,找到绝对值最大的那一行
int t = r;
for (int i = r; i < n; i++)
// 如果当前遍历到的值大于之前备选答案中的值,则更新这个值
if (fabs(a[i][c]) > fabs(a[t][c]))
t = i;
// 如果当前这个数是 0 的话,则 continue
if (fabs(a[t][c]) < eps) continue;
// 将绝对值最大的一行换到最上面
for (int i = c; i < n + 1; i++) swap(a[t][i], a[r][i]);
// 将改行第一个数字变成 1,也就是将这一行所有的数都除以第一个数;
// 注意:这里需要最后再更新第一个数;
// 如果先将第一个数进行更新的话,那么后面的数在除的时候就相当于除以 1 了,就不对了
for (int i = n; i >= c; i--) a[r][i] /= a[r][c];
// 将下面所有行的第 c 列消成 0
for (int i = r + 1; i < n; i++)
// 如果当前遍历到的数已经是 0 了,那么就不用操作了;
// 反之,如果当前的数不是 0 的话,再去操作
if (fabs(a[i][c]) > eps)
for (int j = n; j >= c; j--)
a[i][j] -= a[r][j] * a[i][c];
r++;
}
// 不是唯一解,此时需要判断是无解还是无穷多解
if (r < n) {
for (int i = r; i < n; i++)
// 出现了 0 等于非 0 的情况
if (fabs(a[i][n]) > eps)
return 2;
return 1;
}
// 有唯一解
return 0;
}
int main() {
cin >> n;
for (int i = 0; i < n; i++)
for (int j = 0; j < n + 1; j++)
cin >> a[i][j];
int t = gauss();
if (t == 0) {
for (int i = 0; i < n; i++) printf("%.2lf\n", a[i][n]);
}
else if (t == 1) printf("%s\n", "Infinite group solutions");
else printf("%s\n", "No solution");
return 0;
}
组合数
从 $n$ 个不同的元素中,任取 $m(m<=n)$ 个元素的所有组合个数,叫做从 $n$ 个不同元素中取出 $m$ 个有元素的组合数,用 $C_{n}^{m}$ 表示。
对于排列来说,它的计算公式为:$A_{n}^{m} = n(n-1)(n-2)(n-3)···(n-m+1) = \frac{n!}{(n-m)!}$
所以,组合数的计算公式为:
$$ C_{n}^{m} = \frac{A_{n}^{m}}{m!} = \frac{n!}{m!(n-m)!} = \binom{n}{m} $$
在进行实现时,可以预处理所有的 $C_{n}^{m}$ 的值,即 $C_{n}^{m} = C_{n-1}^{m} + C_{n-1}^{m-1}$,这样就可以在进行循环时,直接查表得到结果了。
注意:以上的递推方式
在实现时,需要格外注意数据的范围。我们应根据不同的数据范围,使用不同的方法来求解组合数问题。适用于以上方法的数据范围是:$1<=n<10000,1<=b<=a<=2000$。其中,$n$ 表示询问次数,$a、b$ 表示给定的整数。
#include <iostream>
#include <algorithm>
const int N = 2010, MOD = 1e9 + 7;
int c[N][N];
void init() {
for (int i = 0; i < N; i++)
for (int j = 0; j <= i; j++)
// 如果 j 为 0
if (!j) c[i][j] = 1;
// dp 的思想
else c[i][j] = (c[i - 1][j] + c[i - 1][j - 1]) % MOD;
}
int main() {
init();
int n;
scanf("%d", &n);
while (n--) {
int a, b;
scanf("%d%d", &a, &b);
printf("%d\n", c[a][b]);
}
return 0;
}
而对于给定数据在 $1<=b<=a<=10^5$ 范围内的情况来说,可以使用预处理的方式。这里的预处理,指的是对某个数阶乘的结果
以及某个数阶乘结果的逆元
进行预处理。如下所示:
求逆元时,可以使用快速幂,即费马小定理。
$$ fact[i] = i! \; \% \; MOD $$
$$ infact[i] = (i!)^{-1} \; \% \; MOD $$
因此,式子 $C_{a}^{b} = \frac{a!}{(b-a)! \; · \; b!}$ 就可以转化成:
$$ C_{a}^{b} = fact[a] × infact[a-b] × infact[b] $$
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 100010, MOD = 1e9 + 7;
// fact 表示阶乘模上 p 的值
// infact 表示阶乘的逆元模上 p 的值
int fact[N], infact[N];
// 使用快速幂求逆元
int qmi(int a, int k, int p) {
int ans = 1;
while (k) {
if (k & 1) ans = (LL) ans * a % p;
a = (LL) a * a % p;
k >>= 1;
}
return ans;
}
int main() {
fact[0] = infact[0] = 1;
for (int i = 1; i < N; i++) {
fact[i] = (LL) fact[i - 1] * i % MOD;
infact[i] = (LL) infact[i - 1] * qmi(i, MOD - 2, MOD) % MOD;
}
int n;
scanf("%d", &n);
while (n--) {
int a, b;
scanf("%d%d", &a, &b);
printf("%d\n", (LL) fact[a] * infact[b] % MOD * infact[a - b] % MOD);
}
return 0;
}
对于给定数据范围在 $1<=b<=a<=10^{18}$ 来说,就需要使用到卢卡斯($Lucas$)定理。$Lucas$ 定理用于求解大组合数取模的问题,其中 $p$ 必须为质数(素数)。正常情况下组合数的运算可以通过上述递推公式的方法进行求解,但当问题规模很大而模数不是一个很大的质数的时候,就需要用到 $Lucas$ 定理进行求解。例如 求组合数 III。
公式如下:
$$ C_\{a}^{b} \equiv C_\{a \; \% \; p}^{b \; \% \; p} · C_\{a / p}^{b / p} (\% \; p) $$
$$ \Rightarrow \binom{a}{b} mod \; p = \binom{a / p}{b / p} · \binom{a \; mod \; p}{b \; mod \; p} \; mod \; p $$
其中,这里的 $mod$ 表示取模 %
的意思。对于 $a \; mod \; p$ 和 $b \; mod \; p$ 一定是小于 $p$ 的数,因此可以直接求解。而前面的 $\binom{a / p}{b / p}$ 可以继续使用该定理求解。
如何推导?
首先,可以将 $a$ 和 $b$ 写成以下形式:
$$ a = a_{k}·p^k + a_{k-1}·p^{k-1} + … + a_{0}·p^0 $$
$$ b = b_{k}·p^k + b_{k-1}·p^{k-1} + … + b_{0}·p^0 $$
然后,使用生成函数方法,可以得到:
$$ (1 + x) ^ p = C_{p}^{0}·x^0 + C_{p}^{1}·x^1 + C_{p}^{2}·x^2 + …+ C_{p}^{p}·x^p $$
由于 $p$ 是质数,所以 $p$ 中是不包含任何质因子的,因此中间的项可以消掉。所以,可得:
$$ (1 + x) ^ p \equiv 1 + x^p (mod \; p) $$
因此,有:
$$ (1 + x) ^ a = (1 + x^{p_{0}})^{a_{0}}(1+x^{p_{1}})^{a_{1}}(1+x^{p_{2}})^{a_{2}}···(1+x^{p_{k}})^{a_{k}} $$
$$ C_{a}^{b} \equiv C_{a_{k}}^{b_{k}} · C_{a_{k-1}}^{b_{k-1}}· C_{a_{k-2}}^{b_{k-2}} ··· C_{a_{0}}^{b_{0}} (mod \; p) $$
需要注意的是:当 $b_{i} > a_{i}$ 的时候,$C_{a}^{b} = 0$。
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
int p;
// 快速幂
int qmi(int a, int k) {
int ans = 1;
while (k) {
if (k & 1) ans = (LL)ans * a % p;
a = (LL)a * a % p;
k >>= 1;
}
return ans;
}
// 通过定理求组合数 C(a, b)
int C(LL a, LL b) {
int ans = 1;
for (int i = 1, j = a; i <= b; i++, j--) {
ans = (LL)ans * j % p;
ans = (LL)ans * qmi(i, p - 2) % p;
}
return ans;
}
int lucas(LL a, LL b) {
if (a < p && b < p) return C(a, b);
return (LL) C(a % p, b % p) * lucas(a / p, b / p) % p;
}
int main() {
int n;
cin >> n;
while (n--) {
LL a, b;
cin >> a >> b >> p;
cout << lucas(a, b) << endl;
}
return 0;
}
此外,还有使用分解质因数求解组合数的方法,该方法适用于使用高精度的方式将结果计算出来的场景。
TODO
卡特兰数
推荐阅读 OI Wiki
对于 满足条件的01序列 这道题来说,我们可以将其转化成在二维平面内走格子的场景。例如,可以看成:从 $(0, 0)$ 点走到 $(6, 6)$ 点,总共的走法有哪些?
如果规定往右走一步用 $0$ 表示,往上走一步用 $1$ 表示,那么由于限制因素为任意前缀序列中 $0$ 的个数都不能少于 $1$ 的个数 这一限制条件,则对应到二维平面上,就相当于需要满足 $x>=y$。因此,在走的过程中,只能走在从 $(0, 0)$ 到 $(6, 6)$ 的连接线及其以下的位置。
可以使用总的方案数
减去经过了连接线以上的方案数
,最后剩下的就是所求的方案数
。总的方案数
为:$C_{12}^{6}$,而经过了连接线以上的方案数
在求解的时候,需要做一个关于红线的轴对称。也就是说,将一开始进入到红线的点以及之后的所有的线,都做关于该红线的轴对称(具体参见数学知识(三)02:10:10)。所以,减去的不合法的方案数就是 $C_{12}^{5}$。因此,最终的结果就是:$C_{12}^{6} - C_{12}^{5}$,即:
$$ C_{2n}^{n} - C_{2n}^{n-1} = \frac{1}{n + 1} C_{2n}^{n}= \frac{1}{n + 1} \binom{2n}{n} $$
$C_{12}^{6}$ 表示从 $12$ 条边中选择 $6$ 条往上走的边。
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
const int MOD = 1e9 + 7;
int qmi(int a, int k, int p) {
int ans = 1;
while (k) {
if (k & 1) ans = (LL)ans * a % p;
a = (LL)a * a % p;
k >>= 1;
}
return ans;
}
int main() {
int n;
cin >> n;
int a = 2 * n, b = n;
int ans = 1;
for (int i = a; i > a - b; i--) ans = (LL)ans * i % MOD;
for (int i = 1; i <= b; i++) ans = (LL)ans * qmi(i, MOD - 2, MOD) % MOD;
ans = (LL)ans * qmi(n + 1, MOD - 2, MOD) % MOD;
cout << ans << endl;
return 0;
}
容斥原理
这里涉及到集合中交集和并集的概念,如何区分交集和并集?
交集指的是两个集合相交的部分,也就是两个集合中都含有的部分,相当于把多余的元素倒掉,所以交集的开口向下,即 $\cap$。
并集指的是两个集合中所有的元素,相当于把两个集合的元素全部放在一起,那就需要一个大的容器来装,所以并集的符号开口向上,即 $\cup$。
一个简单的容斥原理的例子是:求三个相交在一起的圆的面积。有以下等式:
$$ |S_1 \cup S_2 \cup S_3| = |S_1| + |S_2| + |S_3| - |S_1 \cap S_2| - |S_1 \cap S_3| - |S_2 \cap S_3| + |S_1 \cap S_2 \cap S_3| $$
时间复杂度为:
$$ C_{n}^{1} +C_{n}^{2} + … + C_{n}^{n} = 2^n - C_{n}^{0} = 2^n - 1 $$
其中,组合恒等式为:
$$ C_{k}^{1} - C_{k}^{2} + … + (-1)^{k-1} C_{k}^{k} = 1 $$
这里涉及到的例题为 能被整除的数。那么对于给定的样例,我们可以使用容斥原理来做。
能被 $2$ 整除的数有 $S_2 : \{2、4、6、8、10\}$,能被 $3$ 整除的数有 $S_3 : \{3、6、9 \}$。因此,有:
$$ |S_2 \cup S_3| = |S_2| + |S_3| - |S_2 \cap S_3| = 5 + 3 - 1 = 7 $$
其中,$|S_{p}|$ 表示 $1$ 到 $n$ 中 $p$ 的倍数的个数,即 $\lfloor \frac{n}{p} \rfloor$。扩展一下, $|S_{p_1} \cap S_{p_2} \cap S_{p_3} \cap … \cap S_{p_k}|$ 的计算方式为 $\lfloor \frac{n}{p_1 p_2 p_3···p_k} \rfloor$。
这里为什么是下取整?
因为我们可以将这 $n$ 个数分成两部分,即能被 $p$ 整除的部分和不能被 $p$ 整除的部分。而能被 $p$ 整除的部分为 $\frac{n}{p}$,不能被 $p$ 整除的部分可以为 $p、2p、3p…kp$,由于 $k$ 小于 $n$,因此,该部分为 $\lfloor \frac{n}{p} \rfloor$。
当然在求 $|S_{p}|$ 的时候,可以使用位运算的方式。将每个 $i$ 看作是 $n$ 位的二进制数,对于每一位上的数,$1$ 表示选,$0$ 表示不选,每一个二进制数都对应一种选法。
#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 20;
int n, m;
int p[N];
int main() {
scanf("%d%d", &n, &m);
// 读入 m 个质数
for (int i = 0; i < m; i++) cin >> p[i];
int ans = 0;
for (int i = 1; i < 1 << m; i++) {
// t 表示当前所有质数的乘积
// cnt 表示当前 i 中包含 1 的个数,因为这里需要将 i 看作是二进制数
int t = 1, cnt = 0;
for (int j = 0; j < m; j++) {
// 如果当前位是 1
if (i >> j & 1) {
cnt++;
if ((LL)t * p[j] > n) {
t = -1;
break;
}
t *= p[j];
}
}
if (t != -1) {
// 如果有奇数个集合
if (cnt % 2) ans += n / t;
else ans -= n / t;
}
}
cout << ans << endl;
return 0;
}
博弈论
对于 $Nim$ 游戏来说,假如给定两堆石子,第一堆石子共 $2$ 个,第二堆石子共 $3$ 个,那么先手是必赢的。
因为,先手可以选择在第二堆石子中拿 $1$ 个,那么以后,后手不管拿几个,先手直接镜像拿相应的个数即可。这样可以保证先手每次能拿到,而后手最终会走到石子数量为 $0$ 的状态,也就没有石子拿了。
对于先手来说,他有两种状态:
- 先手必胜状态:可以走到某个让对手必败的状态;
- 先手必败状态:走不到任何一个让对手必败的状态。
给定 $n$ 堆石子,第 $i$ 堆有 $A_i$ 个石子,将所有的石子进行异或操作,那么先手必胜的条件是:$ A_1 \land A_2 \land A_3 \land … \land A_n != 0$,其中,符号 $\land$ 表示异或操作。
数学知识(四)01:18:19