题目描述
如果一个数组的任意两个相邻元素之和都是 完全平方数,则该数组称为 平方数组。
给定一个整数数组 nums
,返回所有属于 平方数组 的 nums
的排列数量。
如果存在某个索引 i
使得 perm1[i] != perm2[i]
,则认为两个排列 perm1
和 perm2
不同。
样例
输入:nums = [1,17,8]
输出:2
解释:[1,8,17] 和 [17,8,1] 是有效的排列。
输入:nums = [2,2,2]
输出:1
限制
1 <= nums.length <= 12
0 <= nums[i] <= 10^9
算法
(状态压缩动态规划) O(2n⋅n2)
- 首先我们不考虑重复的情况,即假设排列的下标不同就是不同的排列,我们尝试用动态规划求解。
- 状态 f(mask,i) 表示已经用了集合 mask 中的数字,且最后一个数字的下标为 i,合法可平方排列的方案数,这里的集合 mask 用一个二进制数来表示,二进制位 bk 为 0 代表第 k 个数字没用过。
- 初始时 f({i},i)=1,其余为 0。转移时,首先枚举 j 使得 j 不在 mask 中,然后判断 nums[i]+nums[j] 是否为完全平方数,如果符合条件,则 f(mask⋃{j},j)=f(mask⋃{j})+f(mask,i)。
- 最终答案为 ∑ni=0f(2n−1,i)。
- 然后需要处理重复的情况,由排列转组合的经验得知,我们只需要找到数组 nums 中每种数字出现的次数,用答案除以每种数字出现次数的阶乘即可。
时间复杂度
- 状态数为 O(2n⋅n),转移需要 O(n) 的时间,去重需要首先排序,然后扫一遍,故总时间复杂度为 O(2n⋅n2)。
空间复杂度
- 需要 O(2n⋅n) 的额外空间存储状态。
C++ 代码
class Solution {
private:
bool isSquare(int x) {
int t = sqrt(x);
return t * t == x;
}
int fact(int x) {
int t = 1;
for (int i = 2; i <= x; i++)
t *= i;
return t;
}
public:
int numSquarefulPerms(vector<int>& nums) {
int n = nums.size();
sort(nums.begin(), nums.end());
vector<int> s;
int cnt = 1;
for (int i = 1; i < n; i++) {
if (nums[i] != nums[i - 1]) {
s.push_back(cnt);
cnt = 0;
}
cnt++;
}
s.push_back(cnt);
vector<vector<int>> f(1 << n, vector<int>(n, 0));
for (int i = 0; i < n; i++)
f[1 << i][i] = 1;
for (int mask = 1; mask < (1 << n); mask++)
for (int i = 0; i < n; i++)
if (mask & (1 << i))
for (int j = 0; j < n; j++)
if (!(mask & (1 << j)) && isSquare(nums[i] + nums[j]))
f[mask | (1 << j)][j] += f[mask][i];
int ans = 0;
for (int i = 0; i < n; i++)
ans += f[(1 << n) - 1][i];
for (int i = 0; i < s.size(); i++)
ans /= fact(s[i]);
return ans;
}
};
直接递归来做似乎好一些
时间复杂度O(2^n),空间O(n)
class Solution { vector<int> nums; vector<bool> st; vector<int> plan; int cnt = 0; public: int numSquarefulPerms(vector<int>& _nums) { nums = _nums; sort(nums.begin(), nums.end()); st.resize(nums.size(), false); dfs(-1); return cnt; } void dfs(int last) { if (plan.size() == nums.size()) { cnt++; return; } for (int i = 0; i < st.size(); ++i) { if (st[i]) continue; int t = last + nums[i]; if (last == -1 || (int)sqrt(t) * (int)sqrt(t) == t) { st[i] = true; plan.push_back(nums[i]); dfs(nums[i]); plan.pop_back(); st[i] = false; } while (i + 1 < st.size() && nums[i + 1] == nums[i]) i++; } } };
直接递归的理论时间复杂度是 O(n!),但中途有剪枝,不会超时