题目描述
难度分:1900
输入n(1≤n≤3×105)和长为n的数组a(−10≤a[i]≤10),下标从1开始。
- 定义s(l,r)=a[l]+a[l+1]+…+a[r]。
- 定义长为 n×(n+1)2的数组b=[s(1,1),s(1,2),s(1,3),…,s(1,n),s(2,2),s(2,3),s(2,4),…,s(2,n),s(3,3),s(3,4),s(3,5),…,s(3,n),…s(n,n)]
然后输入q(1≤q≤3×105)和q个询问。每个询问输入两个数l和r(1≤l≤r≤n×(n+1)2)。
输出b[l]+b[l+1]+…b[r]。
输入样例
4
1 2 5 10
15
1 1
1 2
1 3
1 4
1 5
1 10
5 10
6 10
2 8
3 4
3 10
3 8
5 6
5 5
1 8
输出样例
1
4
12
30
32
86
56
54
60
26
82
57
9
2
61
算法
前缀和+二分
先预处理出a的前缀和数组pres、二阶前缀和数组press,以及一个数组s,s[i]=s(i,i)+s(i,i+1)+…+s(i,n−1)+s(i,n)(s[i]可以由press和pres两个数组求出来),再对s求前缀和。
对于某个询问(l,r),我们首先要知道l和r这两个位置对应的是b数组中的哪个s(i,j)?这可以通过二分查找确定,i=1时能够贡献b的n个元素,i=2时能够贡献b的n−1个元素,……,因此这样就可以二分确定l和r对应s(i,j)的(i,j)。
假设lb是l通过二分确定的(i,j),ub是r通过二分确定的(i,j)。
如果i=lb[0]=ub[0],说明答案是
s(i,lb[1])+s(i,lb[1]+1+…+s(i,ub[1])
=Σlb[1]j=lb[1]a[j]+Σlb[1]+1j=lb[1]a[j]+…+Σub[1]j=lb[1]a[j]
=Σlb[1]j=1a[j]+Σlb[1]+1j=1a[j]+…+Σub[1]j=1a[j]−(ub[1]−lb[1]+1)×Σlb[1]−1j=1a[j]
=press[ub[1]]−press[lb[1]−1]−(ub[1]−lb[1]+1)×pres[lb[1]−1]
如果lb[0]<ub[0],答案就可以由3段构成:
-
第一段i=lb[0]对应j∈[lb[1],n],可以用情况1中的方式求得贡献。
-
第三段i=ub[0]对应j∈[ub[0],ub[1]],也可用情况1中的方式求得贡献。
-
中间部分的段都是完整的,贡献为s[lb[0]+1]+s[lb[0]+2]+…+s[ub[0]−1],可以通过s的前缀和数组求得。
因此对于每个询问,O(log2n)二分查找之后计算答案就只需要O(1)的时间复杂度。
复杂度分析
时间复杂度
先O(n)预处理出a数组的前缀和数组pres、二阶前缀和数组press,以及s数组。之后对q个询问的每一个,都要进行O(log2n)的二分查找,因此处理所有询问的时间复杂度为O(qlog2n)。因此,整个算法的时间复杂度为O(n+qlog2n)。
空间复杂度
空间消耗就是三个线性空间的数组s、pres、press,因此额外空间复杂度为O(n)。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 300010;
int n, q, a[N];
LL pres[N], press[N], s[N];
array<LL, 2> get(LL x) {
// [n,n-1,...,2,1]
LL l = 1, r = n;
while(l < r) {
LL mid = l + r >> 1;
LL m = n - mid + 1;
if(mid*(n + m)/2 >= x) {
r = mid;
}else {
l = mid + 1;
}
}
LL c = r - 1;
return {r, x - c*(n + n - c + 1)/2 + c};
}
int main() {
scanf("%d", &n);
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
pres[i] = pres[i - 1] + a[i];
press[i] = press[i - 1] + pres[i];
}
for(int i = 1; i <= n; i++) {
LL cnt = n - i + 1;
s[i] = s[i - 1] + press[n] - press[i - 1] - cnt * pres[i - 1];
}
scanf("%d", &q);
for(int i = 1; i <= q; i++) {
LL l, r;
scanf("%lld%lld", &l, &r);
auto lb = get(l);
auto ub = get(r);
if(lb[0] == ub[0]) {
LL cnt = ub[1] - lb[1] + 1;
printf("%lld\n", press[ub[1]] - press[lb[1] - 1] - cnt * pres[lb[0] - 1]);
}else {
LL cnt = n - lb[1] + 1;
LL s1 = press[n] - press[lb[1] - 1] - cnt * pres[lb[0] - 1];
LL s2 = lb[0] + 1 <= ub[0] - 1? s[ub[0] - 1] - s[lb[0]]: 0;
cnt = ub[1] - ub[0] + 1;
LL s3 = press[ub[1]] - press[ub[0] - 1] - cnt * pres[ub[0] - 1];
printf("%lld\n", s1 + s2 + s3);
}
}
return 0;
}