题目描述
难度分:2600
输入n(1≤n≤4×105),m(1≤m≤4×105)和长为n的数组 a(1≤a[i]≤109),数组下标从1开始。
然后输入m个询问,每个询问输入两个数i(1≤i≤n)和b(1≤b≤109)。
对于每个询问,输出把a[i]替换成b后,a的最长严格递增子序列(LIS)的长度。
注意每个询问之间彼此独立,比如第一个询问把a[1]替换成6,那么对于第二个询问,a[1] 还是原来的值。
输入样例1
4 4
1 2 3 4
1 1
1 4
4 3
4 5
输出样例1
4
3
3
4
输入样例2
4 2
1 3 2 6
3 5
2 4
输出样例2
4
3
算法
前后缀分解+动态规划
先做一个定义,对所有位置的元素进行一个分类:
- 如果a[i]不在任何LIS中,属于第1类。
- 如果a[i]在至少一个LIS中,但不在所有LIS中,属于第2类。
- 如果a[i]在所有LIS中,属于第3类。
状态定义
f[i]表示以a[i]结尾的最长上升子序列长度,g[i]表示以a[i]开头的最长上升子序列长度。这样一来,如果i属于某个最长上升子序列,一定满足f[i]+g[i]=maxlen+1,其中maxlen就是a的最长上升子序列长度,加1是因为a[i]算了两次。
状态转移
f和g的状态转移是类似的,由于n≤2×105,所以需要用O(nlog2n)的做法来求。利用树状数组或线段树来优化这个DP
,对a[i]的数值开权值线段树seg,seg[a[i]]表示以数值a[i]结尾的最长上升子序列的最大长度,a[i]的值域比较小,不需要对数值进行离散化。
然后开始枚举i∈[1,n],看每个a[i]属于哪种类型?可以发现,如果满足f[i]+g[i]=maxlen+1,那么i要么是2类型要么是3类型。否则就应该是1类型,它不属于任何一个LIS。重点是如何区分2和3两种类型,可以发现如果存在i≠j满足f[i]=f[j]和g[i]=g[j],那么i和j就都是2类型,因为无论是选i还是j都不影响最长上升子序列。所以用一个哈希表存储二元组(f[i],g[i])的频数,在f[i]+g[i]=maxlen+1的情况下,如果(f[i],g[i])出现次数超过1,i就是2类型,否则是3类型。
目前为止还只是DP
的预处理,还需要做一个前后缀的预处理,便于离线处理每条查询。
前后缀分解+双指针
将所有的询问queries按照修改的索引位置x排序,当遍历到x位置时(假设此时的询问是(x,b,index)),将所有满足j<x的a[j]对应的f[j]放到一棵最大值的线段树seg中,填入seg[a[j]]=max(f[j],seg[a[j]])。这样一来,遍历到x时,所有x左边的a[j]都已经加入到线段树中了,在线段树中查询<b的最大DP
值,这样就可以知道x左边能被b接上的最长递增子序列长度len,将其赋值给pre[index]=len。
同理,也用双指针算法配合线段树预处理出suf数组。有了pre和suf两个数组,如果a[x]=b这个元素要选进上升子序列,那么包含它的最长上升子序列长度就应该是pre[x]+suf[x]+1。分为以下三种情况:
- 如果x是第1类,修改a[x]后可能会延长包含x位置元素的最长上升子序列,但由于a[x]不在整个数组的LIS中,所以延长也不可能会超过maxlen,所以此时答案就是maxlen。
- 如果x是第2类,修改a[x]后也可能会延长包含x位置元素的最长上升子序列,由于a[x]不会存在于每个LIS中,所以整个数组仍然还会存在一个长度为maxlen的LIS。此时答案要么是pre[x]+suf[x]+1,要么是maxlen,取其中的较大值。
- 如果x是第3类,修改a[x]后仍然可能会延长包含x位置元素的最长上升子序列,且由于a[x]是数组LIS的必要元素,所以它被改变后LIS的长度可能会减少1(如果没有减少1,则pre[x]+suf[x]+1=maxlen,不影响答案),即maxlen−1。所以此时的答案要么是pre[x]+suf[x]+1,要么是maxlen−1,取其中的较大值。
最后需要注意的是,a[i]和每次询问的b范围较大,需要先进行离散化,再开权值线段树。
复杂度分析
时间复杂度
再最差情况下,需要离散化n+m个元素,在离散化的过程中,瓶颈在于排序,时间复杂度为O(n+m)log2(n+m)。DP
预处理出f和g,每次都要遍历i∈[1,n],对于每个i会存在线段树的查询操作,在最差情况下线段树的大小为n+m,因此状态转移的复杂度为O(log2(n+m)),总的时间复杂度为O(nlog2(n+m))。双指针预处理出pre和suf数组需要遍历[1,n]和[1,m],每次也需要在线段树上做查询,因此时间复杂度为O((n+m)log2(n+m))。最后遍历所有询问,利用所有预处理出来的辅助数组O(1)算出答案即可,时间复杂度为O(m)。
综上,整个算法的时间复杂度为O((n+m)log2(n+m))。
空间复杂度
f、g、pre、suf数组都是线性的,空间复杂度为O(n)。为了节省空间,可以复用一棵线段树(否则有可能超空间),线段树的空间复杂度为O(4n)。所以整个算法的额外空间复杂度为O(n),可以看成是一个常数较大的线性消耗。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
const int N = 400001;
int n, m, a[N], f[N], g[N], pre[N], suf[N];
struct custom_hash {
static uint64_t splitmix64(uint64_t x) {
x += 0x9e3779b97f4a7c15;
x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9;
x = (x ^ (x >> 27)) * 0x94d049bb133111eb;
return x ^ (x >> 31);
}
size_t operator()(uint64_t x) const {
static const uint64_t FIXED_RANDOM = chrono::steady_clock::now().time_since_epoch().count();
return splitmix64(x + FIXED_RANDOM);
}
};
class SegmentTree {
public:
struct Info {
int l, r, maximum;
Info() {}
Info(int left, int right, int val): l(left), r(right), maximum(val) {}
};
vector<Info> seg;
explicit SegmentTree() {}
void build(int u, int l, int r) {
seg.resize((r - l + 5)*4);
if(l == r) {
seg[u] = Info(l, r, 0);
}else {
int mid = l + r >> 1;
build(u<<1, l, mid);
build(u<<1|1, mid + 1, r);
pushup(u);
}
}
void modify(int pos, int val) {
modify(1, pos, val);
}
Info query(int l, int r) {
if(l > r) return Info(0, 0, 0);
return query(1, l, r);
}
private:
void modify(int u, int pos, int val) {
if(seg[u].l == pos && seg[u].r == pos) {
seg[u] = Info(pos, pos, val);
}else {
int mid = seg[u].l + seg[u].r >> 1;
if(pos <= mid) modify(u<<1, pos, val);
else modify(u<<1|1, pos, val);
pushup(u);
}
}
Info query(int u, int l, int r) {
if(l <= seg[u].l && seg[u].r <= r) return seg[u];
int mid = seg[u].l + seg[u].r >> 1;
if(r <= mid) {
return query(u<<1, l, r);
}else if(mid < l) {
return query(u<<1|1, l, r);
}else {
return merge(query(u<<1, l, r), query(u<<1|1, l, r));
}
}
void pushup(int u) {
seg[u] = merge(seg[u<<1], seg[u<<1|1]);
}
Info merge(const Info& lchild, const Info& rchild) {
Info info;
info.l = lchild.l;
info.r = rchild.r;
info.maximum = max(lchild.maximum, rchild.maximum);
return info;
}
};
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin >> n >> m;
vector<int> vals;
for(int i = 1; i <= n; i++) {
cin >> a[i];
vals.push_back(a[i]);
f[i] = g[i] = 1;
}
vector<array<int, 3>> queries;
for(int i = 1; i <= m; i++) {
int x, b;
cin >> x >> b;
vals.push_back(b);
queries.push_back({x, b, i});
}
sort(queries.begin(), queries.end());
sort(vals.begin(), vals.end());
vals.erase(unique(vals.begin(), vals.end()), vals.end());
unordered_map<int, int, custom_hash> val2pos;
int sz = vals.size();
for(int i = 0; i < sz; i++) {
val2pos[vals[i]] = i + 1;
}
int maxlen = 1;
SegmentTree seg;
seg.build(1, 1, sz);
for(int i = 1; i <= n; i++) {
int index = val2pos[a[i]];
int prelen = seg.query(1, index - 1).maximum;
if(seg.query(index, index).maximum < prelen + 1) {
seg.modify(index, prelen + 1);
}
f[i] = prelen + 1;
maxlen = f[i] > maxlen? f[i]: maxlen;
}
seg.build(1, 1, sz);
g[n + 1] = 0;
for(int i = n; i >= 1; i--) {
int index = val2pos[a[i]];
int suflen = seg.query(index + 1, sz).maximum;
if(seg.query(index, index).maximum < suflen + 1) {
seg.modify(index, suflen + 1);
}
g[i] = suflen + 1;
}
unordered_map<int, unordered_map<int, int, custom_hash>, custom_hash> mp;
for(int i = 1; i <= n; i++) {
mp[f[i]][g[i]]++;
}
seg.build(1, 1, sz);
for(int i = 0, j = 1; i < m; i++) {
int x = queries[i][0], b = queries[i][1], index = queries[i][2];
while(j < x) {
int val_idx = val2pos[a[j]];
if(seg.query(val_idx, val_idx).maximum < f[j]) {
seg.modify(val_idx, f[j]);
}
j++;
}
pre[index] = seg.query(1, val2pos[b] - 1).maximum;
}
seg.build(1, 1, sz);
for(int i = m - 1, j = n; i >= 0; i--) {
int x = queries[i][0], b = queries[i][1], index = queries[i][2];
while(j > x) {
int val_idx = val2pos[a[j]];
if(seg.query(val_idx, val_idx).maximum < g[j]) {
seg.modify(val_idx, g[j]);
}
j--;
}
suf[index] = seg.query(val2pos[b] + 1, sz).maximum;
}
for(int i = 0; i < m; i++) {
int x = queries[i][0], index = queries[i][2];
if(f[x] + g[x] == maxlen + 1) {
if(mp[f[x]][g[x]] > 1) {
a[index] = max(pre[index] + suf[index] + 1, maxlen);
}else {
a[index] = max(pre[index] + suf[index] + 1, maxlen - 1);
}
}else {
a[index] = max(pre[index] + suf[index] + 1, maxlen);
}
}
for(int i = 1; i <= m; i++) {
cout << a[i] << '\n';
}
return 0;
}