题目描述
给你一个字符串 num
。如果一个数字字符串的奇数位下标的数字之和与偶数位下标的数字之和相等,那么我们称这个数字字符串是 平衡的。
请你返回 num
不同排列 中,平衡 字符串的数目。
由于答案可能很大,请你将答案对 10^9 + 7
取余 后返回。
一个字符串的 排列 指的是将字符串中的字符打乱顺序后连接得到的字符串。
样例
输入:num = "123"
输出:2
解释:
num 的不同排列包括:"123","132","213","231","312" 和 "321"。
它们之中,"132" 和 "231" 是平衡的。所以答案为 2。
输入:num = "112"
输出:1
解释:
num 的不同排列包括:"112" ,"121" 和 "211"。
只有 "121" 是平衡的。所以答案为 1。
输入:num = "12345"
输出:0
解释:
num 的所有排列都是不平衡的。所以答案为 0。
限制
2 <= num.length <= 80
num
中的字符只包含数字'0'
到'9'
。
算法
(动态规划,组合数学) $O(n^2 U + \log mod)$
- 设状态 $f(i, j, k)$ 表示使用了 $i$ 个数字,其中放在偶数位置的字符个数为 $j$ 个,且偶数位置的数字之和为 $k$ 的方案数。
- 初始时,$f(0, 0, 0) = 1$,其余为 $0$ 待定。
- 转移时,对于第 $i$ 个数字 $x$,可以选择放在偶数位置,即转移 $f(i, j, k) = f(i, j, k) + f(i - 1, j - 1, k - x) * j$;也可以选择放在奇数位置上,即转移 $f(i, j, k) = f(i, j, k) + f(i - 1, j, k) * (i - j)$。
- 最终答案 $f(n, (n + 1) / 2, u / 2)$,其中 $u$ 是数字总和。
- 但这里仍存在重复的结果,即相同数字不同的排列出现了重复,需要通过除以相同数字个数的阶乘去重。
- 可以通过倒序枚举后面两维,省略掉第一维的状态存储。
时间复杂度
- 预处理阶乘逆元的时间复杂度为 $O(n + \log mod)$。
- 动态规划的状态数为 $O(n^2 U)$,转移时间为常数。其中 $U$ 为所有数字的总和。
- 故总时间复杂度为 $O(n^2 U + \log mod)$
空间复杂度
- 需要 $O(nU)$ 的额外空间存储预处理阶乘逆元和动态规划的状态。
C++ 代码
#define LL long long
const int M = 41, U = 361;
const int mod = 1000000007;
int inv[2 * M];
int power(int x, int y) {
int res = 1, p = x;
for (; y; y >>= 1) {
if (y & 1)
res = (LL)(res) * p % mod;
p = (LL)(p) * p % mod;
}
return res;
}
auto init = []{
int fac = 1;
for (int i = 1; i < 2 * M; i++)
fac = (LL)(fac) * i % mod;
inv[2 * M - 1] = power(fac, mod - 2);
for (int i = 2 * M - 2; i >= 0; i--)
inv[i] = (LL)(inv[i + 1]) * (i + 1) % mod;
return 0;
}();
class Solution {
private:
int f[M][U];
public:
int countBalancedPermutations(string num) {
const int n = num.size();
int u = 0;
vector<int> cnt(10, 0);
for (char c : num) {
++cnt[c - '0'];
u += c - '0';
}
if (u & 1)
return 0;
u >>= 1;
const int m = (n + 1) >> 1;
memset(f, 0, sizeof(f));
f[0][0] = 1;
for (int i = 1; i <= n; i++) {
int x = num[i - 1] - '0';
for (int j = min(i, m); j >= 0; j--)
for (int k = u; k >= 0; k--) {
f[j][k] = (LL)(f[j][k]) * (i - j) % mod;
if (j >= 1 && k >= x)
f[j][k] = (f[j][k] + (LL)(f[j - 1][k - x]) * j) % mod;
}
}
int ans = f[m][u];
for (int i = 0; i < 10; i++)
ans = (LL)(ans) * inv[cnt[i]] % mod;
return ans;
}
};