题目描述
难度分:1982
输入n(1≤n≤2×105)和长为n的数组a(1≤a[i]≤109)。
你可以修改a中恰好一个数(修改成任意整数)。输出a的最长严格递增子序列(LIS)的长度。
注:子序列不一定连续。
输入样例1
4
3 2 2 4
输出样例1
3
输入样例2
5
4 5 3 6 7
输出样例2
4
算法
一周灵茶AK时刻!
前后缀分解+线段树优化DP
这个题自己要写的代码非常短,大部分的代码都是经典问题和数据结构,可以直接复制粘贴模板。比较直观的一个思路就是要枚举这个要操作的元素a[i],f[x]是以a[x]结尾的最长上升子序列长度,g[y]是以g[y]开头的最长上升子序列长度。这两个数组可以用O(nlog2n)求LIS的方法求出来,这是个经典的算法,可以用同一个模板(用模板求出f,再把a逆序并求相反数,求出g,把g反转过来即可,注意求完g后还原a数组)。
如果x<i<y,只要a[x]+1<a[y],就可以把a[i]修改为[a[x]+2,a[y]−1]里面的某个数,从而把以a[x]结尾的最长上升子序列和以a[y]开头的最长上升子序列接起来。构建一个线段树,索引是a[i](注意a数组的值域比较大,需要对其进行离散化),索引上对应的值是f[i],初始情况下线段树上所有索引位置上都是0。
然后枚举这y,这时候需要为i腾一个位置出来,这个位置至少是i=y−1,所以x的位置应该满足x≤y−2,当遍历到y的时候就要把f[y−2]插入到线段树的a[y−2]位置。接下来求线段树在值域[1,a[y]−1)上的最大值就可以了,这就是我们需要的最大f[x],维护1+f[x]+g[y]的最大值。
注意i=1、i=n这两个边界情况,分别没有x和y,需要特判一下。
复杂度分析
时间复杂度
预处理出f和g的时间复杂度是O(nlog2n)的;遍历y求解是O(n)的,但对于每个y,需要对线段树进行单点修改和区间查询操作,时间复杂度为O(log2n)。因此,整个算法的时间复杂度为O(nlog2n)。
空间复杂度
线段树的空间复杂度为O(n);对a数组中的元素进行离散化空间消耗为O(n);f和g两个数组的时间复杂度为O(n);对线段树进行递归的时候递归深度是O(log2n)规模的,空间复杂度为O(log2n)。因此,整个算法的额外空间复杂度为O(n)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
const int N = 200010;
int n, a[N], f[N], g[N];
class SegmentTree {
public:
struct Info {
int l, r, v;
Info() {}
Info(int left, int right, int val): l(left), r(right), v(val) {}
} seg[N<<2];
explicit SegmentTree() {}
void build(int u, int l, int r) {
if(l == r) {
seg[u] = Info(l, r, 0);
}else {
int mid = l + r >> 1;
build(u<<1, l, mid);
build(u<<1|1, mid + 1, r);
pushup(u);
}
}
void modify(int pos, int val) {
modify(1, pos, val);
}
Info query(int l, int r) {
if(l > r) return Info(0, 0, 0);
return query(1, l, r);
}
private:
void modify(int u, int pos, int val) {
if(seg[u].l == pos && seg[u].r == pos) {
seg[u] = Info(pos, pos, val);
}else {
int mid = seg[u].l + seg[u].r >> 1;
if(pos <= mid) {
modify(u<<1, pos, val);
}else {
modify(u<<1|1, pos, val);
}
pushup(u);
}
}
Info query(int u, int l, int r) {
if(l <= seg[u].l && seg[u].r <= r) {
return seg[u];
}
int mid = seg[u].l + seg[u].r >> 1;
if(r <= mid) {
return query(u<<1, l, r);
}else if(mid < l) {
return query(u<<1|1, l, r);
}else {
return merge(query(u<<1, l, r), query(u<<1|1, l, r));
}
}
void pushup(int u) {
seg[u] = merge(seg[u<<1], seg[u<<1|1]);
}
Info merge(const Info& lchild, const Info& rchild) {
Info info;
info.l = lchild.l;
info.r = rchild.r;
info.v = max(lchild.v, rchild.v);
return info;
}
};
void get(int dp[]) {
vector<int> ends;
for(int i = 1; i <= n; i++) {
if(ends.empty()) {
ends.push_back(a[i]);
dp[i] = 1;
}else{
if(a[i] > ends.back()){
ends.push_back(a[i]);
dp[i] = ends.size();
}else {
int p = lower_bound(ends.begin(), ends.end(), a[i]) - ends.begin();
dp[i] = p + 1;
ends[p] = a[i];
}
}
}
}
int main() {
scanf("%d", &n);
vector<int> vals;
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
vals.push_back(a[i]);
}
// 离散化a的元素值
sort(vals.begin(), vals.end());
vals.erase(unique(vals.begin(), vals.end()), vals.end());
// 预处理出f和g
get(f);
reverse(a + 1, a + n + 1);
for(int i = 1; i <= n; i++) {
a[i] *= -1;
}
get(g);
reverse(a + 1, a + n + 1);
for(int i = 1; i <= n; i++) {
a[i] *= -1;
}
reverse(g + 1, g + n + 1);
int sz = vals.size();
int ans = f[n]; // 初始化答案为原始的LIS长度
SegmentTree seg;
seg.build(1, 1, sz);
for(int i = 3; i <= n; i++) {
int index = lower_bound(vals.begin(), vals.end(), a[i - 2]) - vals.begin() + 1;
seg.modify(index, f[i - 2]);
index = lower_bound(vals.begin(), vals.end(), a[i] - 1) - vals.begin() + 1;
ans = max(ans, 1 + g[i] + seg.query(1, index - 1).v); // 小于a[i]-1的f最大值
}
if(f[n - 1] == f[n]) {
ans = max(ans, f[n] + 1);
}
if(g[1] == g[2]) {
ans = max(ans, f[n] + 1);
}
printf("%d\n", ans);
return 0;
}