数学知识之矩阵乘法
$矩阵乘法是一种高效的算法可以把一些一维递推优化到log(n),是一种应用性极强的算法。$
$矩阵,是线性代数中的基本概念之一。$
$矩阵乘法可以解决从一个状态转换到另一个状态的动态规划问题$
$这里例举一个斐波那契数列求和的问题$
$第i个状态数组标识为$
状态序号 | $第i个数$ | $第 i + 1 个数$ | $ 前i个数的和$ |
---|---|---|---|
$i$ | $f[i]$ | $f[i + 1]$ | $S[i]$ |
$i + 1$ | $f[i + 1] $ | $f[i + 2] =f[i] + f[i + 1]$ | $ S[i + 1] = S[i] + f[i + 1] $ |
$现有一个矩阵A[N][N]$
$能使F[i] 变换为 F[i+1]$
$f[i + 1] = f[i] \times 0 + f[i + 1] \times 1 + S[i] \times 0, 第一列[0,1,0]$
$f[i + 2] =f[i] \ times 1 + f[i + 1] \times 1 + S[i] \times 0, 第二列[1,1,0]$
$ S[i + 1] =f[i] \ times 0 + f[i + 1] \times 1 + S[i] \times 1, 第三列[0,1,1]$
$ A= \left\[ \begin{matrix} 0 & 1 & 0 \\\ 1 & 1 & 1 \\\ 0 & 0 & 1 \\\ \end{matrix} \right\] $
序列 乘 矩阵
$按照上列求出的A矩阵$
$依次求出每一个乘出来的值$
$F[i + 1][j] = \sum _{k = 0}^{k < N}F[i][k] \times A[k][j] (0 \leq j < n)$
$Code$
void mul(int c[N], int a[N], int b[N][N])
{
int temp[N] = {0};
for (int i = 0; i < N; i ++ )
for (int j = 0; j < N; j ++ )
temp[i] = (temp[i] + (LL)a[j] * b[j][i]) % m;
memcpy(c, temp, sizeof temp);
}
矩阵乘矩阵
$由于每次把一个序列和矩阵相乘,想要达到目标序列的时间复杂度是O(n)$
$因为每一次乘的矩阵都是常量, 所以可以用快速幂,来优化计算A^n,时间复杂度O(log_2n)$
$那么怎样才能把两个矩阵相乘呢?$
$矩阵相乘,设A \times B = C$
$那么序列F \times A \times B = F \times C$
$对于C中的每一个数C[i][j]表示F \times A \times B 的第j位数是有多少个F[i]组成的$
$在F \times A 中 的 每一位数AF[k] 中包含了 A[i][k]个F[i]$
$对于 每一位 AF[k] 包含的F[i] ,在ABF[j]中都贡献了 B[k][j]倍$
$那么C[i][j] = \sum _{k=0}^{k<N} A[i][k] \times B[k][j]$
$Code$
void mul(int c[N][N], int a[N][N], int b[N][N])
{
int temp[N][N] = {0};
for (int i = 0; i < N; i ++ )
for (int j = 0; j < N; j ++ )
for (int k = 0; k < N; k ++ )
temp[i][j] = (temp[i][j] + (LL)a[i][k] * b[k][j]) % m;
memcpy(c, temp, sizeof temp);
}
$Code$
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = ;//...
int n, m;
void mul(int c[N], int a[N], int b[N][N])
{
int temp[N] = {0};
for (int i = 0; i < N; i ++ )
for (int j = 0; j < N; j ++ )
temp[i] = (temp[i] + (LL)a[j] * b[j][i]) % m;
memcpy(c, temp, sizeof temp);
}
void mul(int c[N][N], int a[N][N], int b[N][N])
{
int temp[N][N] = {0};
for (int i = 0; i < N; i ++ )
for (int j = 0; j < N; j ++ )
for (int k = 0; k < N; k ++ )
temp[i][j] = (temp[i][j] + (LL)a[i][k] * b[k][j]) % m;
memcpy(c, temp, sizeof temp);
}
void print()
{
//...
}
int main()
{
cin >> n >> m;
int f[N] = {};//...
int a[N][N] = {
//...
};
n -- ;
while (n)
{
if (n & 1)mul(f, f, a);
mul(a, a, a);
n >>= 1;
}
print();
return 0;
}