思路
题目给出一个长度为$n$的序列,求长度大于1的严格上升子序列的方案数,且每个方案序列不能重复
1. 我们设$dp[i]$状态表示以$i$结尾的严格上升子序列的方案数,易得转移方程$dp[i] = \sum_{j=0}^{i-1}dp[j]$,我们知道$O(n^{2})$的复杂度是会TLE的,所以我们考虑用线段树维护$[0,i-1]$区间的$dp$数组的和,因为序列中的元素值$<=|10^{9}|$,所以我们需要对序列离散化,并去重排序
2. 关于线段树的更新操作,因为我们只需查询$[0,i-1]$区间的和加到下标$i$上,所以我们只需单点修改,不用维护懒标记和使用pushdown函数
3. 怎么保证方案不重复呢?例如序列1 2 3 5 4 5, 显然第二次$dp[5]$的方案数是包含第一次的,我们每次做修改操作之前记录上一次该位置的方案数,新的方案数-上一次的方案数即为增加的方案数
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1e5 + 5;
int a[N];
LL ans = 0, mod = 1e9 + 7;
vector<int> v;
struct Node{
int l, r;
LL sum;
bool vis;
}tr[4 * N];
void build(int u, int l, int r) {
tr[u] = {l, r};
if(l == r) {
tr[u] = {l, l, 0, false};
return;
}
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
}
void pushup(int u) {
tr[u].sum = (tr[u << 1].sum % mod + tr[u << 1 | 1].sum % mod) % mod;
}
void update(int u, int pos, LL val) {
if(tr[u].l == pos && tr[u].r == pos) {
if(!tr[u].vis) ++tr[u].sum; //如果该节点没有被访问过,则加上只有自身的方案{pos}
tr[u] = {pos, pos, tr[u].sum + val, true};
return;
}
int mid = tr[u].l + tr[u].r >> 1;
if(pos <= mid) update(u << 1, pos, val);
else update(u << 1 | 1, pos, val);
pushup(u);
}
LL query(int u, int l, int r) {
if(tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
int mid = tr[u].l + tr[u].r >> 1;
LL sum = 0;
if(l <= mid) sum = (sum % mod + query(u << 1, l, r) % mod) % mod;
if(r > mid) sum = (sum % mod + query(u << 1 | 1, l, r) % mod) % mod;
return sum;
}
int main()
{
int n;
scanf("%d", &n);
v.resize(n);
for(int i = 0;i < n;i++) {
scanf("%d", &a[i]);
v[i] = a[i];
}
v.erase(unique(v.begin(), v.end()), v.end());
sort(v.begin(), v.end());
build(1, 0, v.size() - 1);
for(int i = 0;i < n;i++) {
int pos = lower_bound(v.begin(), v.end(), a[i]) - v.begin();
if(pos == 0) update(1, 0, 0);
else {
LL pre = query(1, pos, pos);
LL res = query(1, 0, pos - 1);
if(res == 0) {
update(1, pos, 0);
continue;
}
if(pre > 0) --pre;
//pre为当前节点上一次修改后的方案数,因为pre中包含了长度为1{pos}的方案,所以我们-1
//res为当前节点本次修改后的方案数,res - pre即为方案增量
res -= pre;
update(1, pos, res);
ans = (ans % mod + res % mod) % mod;
}
}
printf("%lld", ans);
return 0;
}