更好的阅读体验:https://www.cnblogs.com/Tenshi/p/14619631.html
原理
cdq分治用来解决什么样的问题呢?一般来说可以:
1. 统计具有三维属性 ($a,b,c$) 的、满足一定的比较关系 $data$ 有多少对。
2. 优化一些数据结构,算法。
流程:
下面说一下cdq分治解决上面模板题的流程。
排序解决第一维情况
首先我们将整个 $data$ 集排序。(以 $a$ 为第一关键字,以 $b$ 为第二关键字,以 $c$ 为第三关键字)
这样做便可以处理掉第一维的情况。
接下来,我们采取分治的思想来统计满足题意的对数。
对于目前需要处理的区间 $[l,r]$ :
记 $mid=\lfloor \frac{l+r}{2} \rfloor$ ,左区间 $[l,mid]$ ,右区间 $[mid+1,r]$ 。
分情况讨论:
- 假如满足条件的点对 $(p,q)$ 都在左区间或者右区间,那么我们递归处理下去就好。
- 否则,$p$ 必然在左区间,$q$ 必然在右区间,我们对这样的点对进行统计。
故重点在于第二种情况:
此时第一维情况($a_p\leq a_q$)必然是满足的。
采取双指针解决第二维情况
假设现在的左右两个区间是按照 $b$ 为关键字排好序了,我们记左区间的指针为 $j$ ,右区间的指针为 $i$ 。
因为我们的递归是自下而上的(基于归并排序的思想),所以我们可以在处理当前区间时进行以 $b$ 为关键字的排序,这样就可以保证处理到每一个区间时都是以 $b$ 为关键字的排序的了。
对于每一个 $i$ ,我们让 $j$ 右移,直到找到第一个 $j$ ,满足 $b_j>b_i$ ,那么对于 $x\in[l,j-1]$ 均有 $b_i\geq b_x$ 。
采用树状数组解决第三维情况
最后,只要解决第三维情况,就可以统计出每一个 $data_i$ 相应的贡献(对数)了:对于元素 $data_x$ 其中 $x\in[l,j-1]$ ,我们将这样的 $data$ 的属性 $c$ 值放入树状数组中,只需查询一下 $query(c_i)$ 就可以找到满足 $c_i\geq c_x$ 的个数了,注意到与此同时,前面两维情况也同时被满足,所以这样的个数就是所求的贡献。
细节:
我们记三维数据同时相等的 $data$ 为同一类,发现分治的时候同一类之间并不好处理,所以我们将同一类进行合并,用 $cnt$ 来记录每一类的个数,那么对不同类之间进行分治就可以了。
代码
#pragma GCC optimize("O3")
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=1e5+5, M=2e5+5;
int n,m;
struct data{
int a,b,c,cnt,res;
bool operator<(const data &o)const{
if(a!=o.a) return a<o.a;
if(b!=o.b) return b<o.b;
return c<o.c;
}
bool operator==(const data &o)const{
return a==o.a && b==o.b && c==o.c;
}
}e[N], tmp[N];
int tot;
int tr[M];
int lowbit(int x){return x&-x;}
void add(int p,int k){
for(;p<M;p+=lowbit(p)) tr[p]+=k;
}
int query(int p){
int res=0;
for(;p;p-=lowbit(p)) res+=tr[p];
return res;
}
void cdq(int l,int r){
if(l>=r) return;
int mid=l+r>>1;
cdq(l,mid), cdq(mid+1,r);
// 这里定义左区间对应的指针为 j, 右区间对应的指针为 i.
for(int j=l, i=mid+1, k=l;k<=r;k++)
if(i>r || j<=mid && e[j].b<=e[i].b) add(e[j].c,e[j].cnt), tmp[k]=e[j++]; // 如果说右区间的指针已经走到边界,或者左区间的b值比较小。
else e[i].res+=query(e[i].c), tmp[k]=e[i++];
for(int j=l;j<=mid;j++) add(e[j].c,-e[j].cnt); // 恢复桶
for(int k=l;k<=r;k++) e[k]=tmp[k]; // 完成排序,复制
}
int ans[M];
int main(){
cin>>n>>m;
for(int i=1;i<=n;i++){
int a,b,c; cin>>a>>b>>c;
e[i]={a,b,c,1};
}
sort(e+1,e+1+n);
for(int i=1;i<=n;i++)
if(e[i]==e[tot]) e[tot].cnt++;
else e[++tot]=e[i];
cdq(1,tot);
for(int i=1;i<=tot;i++) ans[e[i].res+e[i].cnt-1]+=e[i].cnt;
for(int i=0;i<n;i++) cout<<ans[i]<<'\n';
return 0;
}