题目描述
给定一个由大写英文字母组成的字符串 S。
你需要找到 一个 长度 最短 的回文串 P,使得 S 是 P 的前缀。
回文串是指正读和反读都一样的字符串(例如 “ABA”, “RACECAR”)。
样例
输入:
ABC
输出:
ABCBA
说明:ABCBA
是一个回文串,它以 S="ABC"
作为前缀,并且是满足这两个条件的最短字符串。
输入:
Z
输出:
Z
说明:Z
本身就是回文串,且以 S="Z"
为前缀。
输入:
TREE
输出:
TREERT
说明:TREERT
是回文串,以 S="TREE"
为前缀,且是最短的。
算法1 (字符串哈希)
O(N)
思路分析
-
目标: 我们要找最短的回文串
P
,使得S
是P
的前缀。这意味着P
可以写成S + T
的形式,其中T
是某个(可能为空)字符串,并且S + T
整体是一个回文串。我们希望T
尽可能短。 -
回文串的性质: 如果
P = S + T
是回文串,那么P
必须等于它的反转rev(P)
。即S + T = rev(S + T) = rev(T) + rev(S)
。 -
寻找 T: 我们需要找到最短的
T
使得S + T
是回文串。观察S + T = rev(T) + rev(S)
。为了让T
最短,我们需要让S
的尽可能长的 后缀 本身就是一个回文串。- 设
S
的长度为N
。 - 假设
S
的最长回文后缀是S[k...N-1]
(从索引k
到末尾)。这意味着S
的前k
个字符S[0...k-1]
是非回文的部分(相对于整个S
而言)。 - 要构成最短的回文串
P
,我们需要在S
的末尾添加的字符串T
恰好是S
前k
个字符的反转,即T = rev(S[0...k-1])
。 - 那么
P = S + rev(S[0...k-1])
。 - 例如:
S = ABC
。最长回文后缀是C
(k=2)。需要添加rev(S[0...1]) = rev(AB) = BA
。结果P = ABCBA
。S = RACE
。最长回文后缀是E
(k=3)。需要添加rev(S[0...2]) = rev(RAC) = CAR
。结果P = RACECAR
。S = ABACABA
。最长回文后缀是ABACABA
(k=0)。需要添加rev(S[0...-1])
(空串)。结果P = ABACABA
。
- 设
-
等价问题: 找到
S
的最长回文后缀,等价于找到S
的最长 前缀,使得这个前缀等于S
的反转串rev(S)
的对应长度的 前缀。 更正: 找到S
的最长回文后缀S[k...N-1]
,等价于找到S
的 后缀S[k...N-1]
,使得它等于rev(S)
的对应长度N-k
的 前缀rev(S)[0...N-k-1]
。 -
使用哈希查找: 我们可以高效地判断
S
的后缀是否等于rev(S)
的前缀。- 计算
S
的字符串哈希值。 - 计算
S
的反转串S_rev
的字符串哈希值。 - 从最长的可能长度
L=N
开始,向下检查到L=1
:- 比较
S
的长度为L
的后缀S[N-L...N-1]
的哈希值,是否等于S_rev
的长度为L
的前缀S_rev[0...L-1]
的哈希值。 - 为了减少哈希冲突,可以使用多组不同的基数和模数(代码中用了 3 组)。
- 找到的第一个(即最大的)满足哈希值相等的长度
L
,就对应了最长的回文后缀S[N-L...N-1]
。
- 比较
- 令找到的最大长度为
k
(在代码中)。 - 那么需要添加的字符是
S
的前m = N - k
个字符的反转。 - 最终答案是
S + reverse(S.substr(0, m))
。
- 计算
时间复杂度
- 计算
S
和S_rev
的哈希值:O(N)。 - 查找最长匹配长度
k
:循环最多 N 次,每次哈希值比较是 O(D),其中 D 是哈希的维数 (这里 D=3)。总共 O(N)。 - 构造最终答案:字符串截取、反转、拼接,最坏情况 O(N)。
- 总时间复杂度为 O(N)。
C++ 代码
#include <bits/stdc++.h> // 引入所有标准库
using namespace std;
using i64 = long long; // 定义 i64 为 long long 的别名
// 字符串哈希模板
// D: 哈希维数 (使用多少组不同的哈希)
// B: 指向基数数组的指针
// P: 指向模数数组的指针
template <int D, const int *B, const int *P>
struct StringHash {
std::vector<std::array<int, D>> h; // h[i][k] 存储前缀 s[0...i-1] 的第 k 维哈希值
// 构造函数,从字符串或字符数组计算哈希值
template <class T>
StringHash(const T &s) : h(s.size() + 1) { // 哈希数组大小为 N+1
for (auto i = 0U; i < s.size(); i++) { // 遍历字符串
for (int k = 0; k < D; k++) { // 计算每一维哈希
// h[i+1] = (h[i] * base + char_value) % mod
// +11 是为了避免字符本身为0或接近0,增加区分度
h[i + 1][k] = (1LL * h[i][k] * B[k] + s[i] + 11) % P[k];
}
}
}
StringHash(const char *s) : StringHash(std::string(s)) {} // 支持 C 风格字符串
// 查询区间 [l, r) (即 s[l...r-1]) 的哈希值
std::array<int, D> get(int l, int r) const {
// 静态变量存储预计算的基数幂次 B^len % P
// static 只初始化一次,且在函数调用间保持
static std::vector<std::array<int, D>> spow(1);
assert(l < r); // 确保区间有效
// 如果需要的幂次 len = r - l 超出已计算范围,则扩展 spow
if (static_cast<int>(spow.size()) < r - l + 1) {
if (spow[0][0] == 0) { // 首次进入或静态变量重置?初始化 B^0 = 1
spow[0].fill(1);
}
int n = spow.size(); // 当前已计算到的最大幂次+1
spow.resize(r - l + 1); // 扩展到需要的长度
for (int i = n; i < static_cast<int>(spow.size()); i++) {
for (int k = 0; k < D; k++) {
// spow[i] = spow[i-1] * B % P
spow[i][k] = 1LL * spow[i - 1][k] * B[k] % P[k];
}
}
}
std::array<int, D> res = {}; // 存储结果哈希值
int len = r - l; // 区间长度
for (int k = 0; k < D; k++) {
// hash(l, r) = (h[r] - h[l] * B^len) % P
res[k] = h[r][k] - 1LL * h[l][k] * spow[len][k] % P[k];
res[k] += (res[k] < 0 ? P[k] : 0); // 结果取正模
}
return res;
}
};
// 定义 3 组哈希的基数和模数 (选择大的素数)
static const int B3[3] = {1000003, 19260817, 998244353};
static const int P3[3] = {int(1e9) + 9, int(1e9) + 21, int(1e9) + 33};
// 定义 Hash 类型为使用 B3, P3 的 3 维哈希
using Hash = StringHash<3, B3, P3>;
int main() {
// 加速 IO
ios::sync_with_stdio(false);
cin.tie(nullptr);
string S; // 输入字符串 S
cin >> S;
int N = S.size(); // S 的长度
string s_rev = S; // 创建 S 的副本
reverse(s_rev.begin(), s_rev.end()); // 反转副本得到 S_rev
Hash hash_S(S); // 计算 S 的哈希
Hash hash_S_rev(s_rev); // 计算 S_rev 的哈希
// 从长到短查找 S 的后缀 与 S_rev 的前缀 匹配的最长长度 k
for (int k = N; k >= 1; k--) {
// 比较 S[N-k...N-1] 的哈希 与 S_rev[0...k-1] 的哈希
if (hash_S.get(N - k, N) == hash_S_rev.get(0, k)) {
// 找到最长的匹配长度 k (即最长回文后缀的长度)
int m = N - k; // 需要反转并添加的前缀 S[0...m-1] 的长度
string R = S.substr(0, m); // 取出此前缀
reverse(R.begin(), R.end()); // 反转此前缀
cout << S + R << "\n"; // 输出 S 加上反转后的前缀
return 0; // 找到答案,结束程序
}
}
// 理论上 k=0 (空串) 总会匹配,或者至少 k=1 会匹配 (单个字符是回文)
// 但循环从 k=N 到 1,如果 S 本身是回文,k=N 时会匹配
// 如果需要处理空串 S 的情况(虽然题目保证长度 >= 1),需要调整
// 这里因为循环到 k=1 必能找到解 (单个字符后缀匹配单个字符前缀)
// 如果k=N没匹配,则k=N-1, ..., k=1中一定会找到匹配
return 0; // 正常结束 (虽然上面已经 return 了)
}
算法2 (Manacher 算法)
O(N)
思路分析
-
目标回顾: 找到最短回文串
P = S + T
,其中T = rev(S[0...k-1])
,k
是S
最长回文后缀S[k...N-1]
的起始索引。我们需要找到这个最小的k
。 -
Manacher 算法: Manacher 算法可以在 O(N) 时间内计算出以字符串中每个字符(以及字符间的间隙)为中心的最长回文子串的半径。它通常在一个转换后的字符串
t
(例如#A#B#C#
) 上操作,计算半径数组r
。 -
利用 Manacher 找最长回文后缀:
- 我们要求
S
的最长回文后缀S[k...N-1]
。 - 这对应于在转换后的字符串
t
中,找到一个以某个中心i
为对称中心的回文串,该回文串恰好延伸到t
的末尾。 - 令
S
的长度为n
,t
的长度为m = 2n + 1
。 - 考虑
S
的后缀S[l...n-1]
。它对应t
中的子串t[2l+1 ... 2n-1]
。 - 这个子串的对称中心在
t
中的索引是i = (2l+1 + 2n-1) / 2 = l + n
。 S[l...n-1]
是一个回文串,当且仅当以t[i] = t[l+n]
为中心的回文串至少覆盖t[2l+1 ... 2n-1]
。- Manacher 算法计算的半径
r[i]
表示以t[i]
为中心的回文串向两边扩展的长度(包括中心)。该回文串的右边界是i + r[i] - 1
。 - 我们需要这个回文串包含
t
的最后一个有效字符t[m-2]
(对应S[n-1]
) 或者说到达t
的边界m-1
。 - 因此,我们需要
i + r[i] - 1 >= m - 1 = 2n
(或者至少是2n-1
?)。更准确地说,我们需要从中心i = l+n
到右边界m-1
的距离,即(m-1) - i + 1 = m - i = (2n+1) - (l+n) = n - l + 1
,必须小于等于半径r[i]
。 - 所以,
S[l...n-1]
是回文串的条件是r[l+n] >= n - l + 1
。 - 我们要找的是最小的
l
(对应最长的回文后缀),满足r[l+n] >= n - l + 1
。
- 我们要求
-
代码实现:
- 运行 Manacher 算法得到半径数组
r
(在转换后的字符串t
上)。 - 从
l = 0
开始,检查条件r[l+n] >= n - l + 1
是否满足。 - 代码中的循环
while (r[l + n] < n - l + 1)
是在寻找第一个 不 满足r[l+n] < n-l+1
(即满足r[l+n] >= n-l+1
)的l
。循环停止时,l
就是我们要找的最小索引。 - 找到这个
l
后,需要添加的字符串是S
的前l
个字符的反转rev(S[0...l-1])
。 - 最终答案是
S + rev(S.substr(0, l))
。
- 运行 Manacher 算法得到半径数组
时间复杂度
- Manacher 算法:O(N)。
- 查找最小
l
的循环:最多 O(N)。 - 构造最终答案:O(N)。
- 总时间复杂度为 O(N)。
C++ 代码
#include <bits/stdc++.h> // 引入所有标准库
// 类型别名,这里 i128, u128 未使用
using i64 = long long;
using u64 = unsigned long long;
using u32 = unsigned;
using i128 = __int128;
using u128 = unsigned __int128;
// 使用 C++20 ranges 和 views
namespace ranges = std::ranges;
namespace views = std::views;
// Manacher 算法实现
// 输入: 字符串 s
// 输出: 半径数组 r (在转换后的字符串 t 上的半径)
std::vector<int> manacher(std::string s) {
std::string t = "#"; // 转换后的字符串 t
for (auto c : s) { // 在字符间插入 '#'
t += c;
t += '#';
}
int n = t.size(); // t 的长度
std::vector<int> r(n); // 半径数组
// i: 当前中心, j: 当前达到最右边界的回文串的中心
for (int i = 0, j = 0; i < n; i++) {
// 利用已计算的回文信息进行初始化 (优化)
// 如果 i 在当前最右回文串的覆盖范围内 (i < j + r[j])
if (2 * j - i >= 0 && j + r[j] > i) {
// r[i] 的初始值可以取 i 关于 j 的对称点 2*j-i 的半径 r[2*j-i]
// 但不能超过右边界 j + r[j]
r[i] = std::min(r[2 * j - i], j + r[j] - i);
}
// 从初始半径开始向外扩展,检查是否匹配
// r[i] 是当前已知的半径,尝试增加半径 (r[i]++)
// 检查 t[i - r[i]] 和 t[i + r[i]] 是否相等且在边界内
while (i - r[i] >= 0 && i + r[i] < n && t[i - r[i]] == t[i + r[i]]) {
r[i]++; // 半径增加
}
// 如果当前中心 i 扩展到的右边界超过了之前的最右边界 j + r[j]
if (i + r[i] > j + r[j]) {
j = i; // 更新最右回文串的中心为 i
}
}
return r; // 返回半径数组
}
int main() {
// 加速 IO
std::ios::sync_with_stdio(false);
std::cin.tie(nullptr);
std::string S; // 输入字符串 S
std::cin >> S;
// 运行 Manacher 算法得到半径数组 r
auto r = manacher(S);
int n = S.size(); // S 的长度
int l = 0; // 寻找最小的 l,使得 S[l...n-1] 是回文
// 循环条件:r[l + n] < n - l + 1
// l+n 是 S[l...n-1] 在 t 中的中心索引
// n-l+1 是使得以 l+n 为中心的回文串能覆盖到 t 末尾所需的最小半径
// 只要半径不够,就增加 l (即缩短后缀)
while (r[l + n] < n - l + 1) {
l++;
}
// 循环结束后,l 是最小的使得 S[l...n-1] 是回文串的起始索引
auto ans = S; // 最终答案以 S 开始
// 将 S 的前 l 个字符 S[0...l-1] 反转后添加到 ans 末尾
// 使用 C++20 ranges 和 views 进行反转和复制
// S.substr(0, l) 获取前缀
// | views::reverse 反转视图
// ranges::copy 将反转后的字符复制到 ans 的末尾 (通过 back_inserter)
ranges::copy(S.substr(0, l) | views::reverse, std::back_inserter(ans));
// 输出最终答案
std::cout << ans << "\n";
return 0; // 程序正常结束
}