题目描述
难度分:1600
输入两个长度均≤5×105的字符串s和字符串t,只包含0
和1
。
重排s中的字符,使得s中有尽量多的子串等于t。
输出重排后的s。如果有多个答案,输出任意一个。
输入样例1
101101
110
输出样例1
110110
输入样例2
10010110
100011
输出样例2
01100011
输入样例3
10
11100
输出样例3
01
算法
Z函数+贪心构造
首先过滤一下无解的情况,如果n<m或者s中的0
数量或1
数量达不到t中对应数字的数量,肯定就无解了。
否则,至少s第一个长度为m的子串可以是t的样子,然后我们滑动窗口来构造后面的部分。假设此时[0,r]已经构造完毕了,如果此时遍历到了i>r,只要前缀[0,r]长度为m−(i−r)的后缀是t的前缀,并且t的后i−r个字符能够被当前剩下的0
和1
覆盖住(0
和1
的数量足够),那么就可以做到s的子串[0,i]也能以t为后缀结尾,按这个策略贪心地构造即可。
那么这里就有个关键问题,如何能够快速判断“前缀[0,r]长度为m−(i−r)的后缀是t的前缀”?其实只要利用Z算法预处理出t的z数组就可以了(z[i]表示t后缀[i,m−1]的所有前缀中,也是t的前缀的最大长度),只要满足m−(i−r)≤z[i−r],说明s的子串[i−m+1,i]也可以是t的样子。
复杂度分析
设s串的长度为n,t串的长度为m。
时间复杂度
对t串用Z函数预处理,时间复杂度为O(m)。然后遍历[0,n)构建新的s串,均摊时间复杂度为O(n)。因此,算法整体的时间复杂度为O(n+m)。
空间复杂度
Z函数得到的z数组空间复杂度为O(m)。而由于构造并不是在s串上进行原地操作,所以空间复杂度为O(n)。因此,算法整体的额外空间复杂度为O(n+m)。
C++ 代码
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
int scnt[2], tcnt[2];
string s, t;
template<class S>
vector<int> z_algo(const S& s) {
int n = s.size();
vector<int> r(n + 1);
r[0] = 0;
for(int i = 1, j = 0; i <= n; i++) {
int& k = r[i];
k = j + r[j] <= i? 0: min(j + r[j] - i, r[i - j]);
while(i + k < n and s[k] == s[i + k]) k++;
if(j + r[j] < i + r[i]) j = i;
}
r[0] = n;
return r;
}
int main() {
cin >> s >> t;
int n = s.size(), m = t.size();
scnt[0] = scnt[1] = tcnt[0] = tcnt[1] = 0;
for(int i = 0; i < n; i++) {
scnt[s[i] - '0']++;
}
for(int i = 0; i < m; i++) {
tcnt[t[i] - '0']++;
}
bool flag = n >= m;
for(int i = 0; i <= 1; i++) {
if(scnt[i] < tcnt[i]) {
flag = false;
break;
}
}
if(!flag) {
cout << s << endl;
exit(0);
}
vector<int> z = z_algo(t);
string ans = string(n, '*');
for(int i = 0; i < m; i++) {
scnt[t[i] - '0']--;
ans[i] = t[i];
}
int r = m - 1;
for(int i = m; i < n; i++) {
int len = m - (i - r); // 前缀长度
if(len < 0) break;
if(len <= z[m - len]) {
int zero = 0, one = 0;
for(int j = r + 1, index = len; j <= i; j++, index++) {
if(t[index] == '0') {
zero++;
}else {
one++;
}
}
if(zero <= scnt[0] && one <= scnt[1]) {
scnt[0] -= zero;
scnt[1] -= one;
for(int j = r + 1, index = len; j <= i; j++, index++) {
ans[j] = t[index];
}
r = i;
}
}
}
for(int i = 0; i < n; i++) {
if(ans[i] == '*') {
if(scnt[0] > 0) {
ans[i] = '0';
scnt[0]--;
}else {
ans[i] = '1';
scnt[1]--;
}
}
}
cout << ans << endl;
return 0;
}