题解 P2487 【[SDOI2011]拦截导弹】
cdq分治优化dp
导弹拦截拓展版
与大家都做过的入门级习题不同的是,这题不仅有高度还有速度
也就是是一个三维问题(分别是高度,速度,时间即输入顺序,下面我会分别用x,y,z表示这三维)
我们考虑借鉴只有二维时的思路
这一道题的第一问其实就是让我们求多了一维限制的最长不上升子序列
而第二问是求每枚导弹被拦截掉的概率,即每枚导弹在所有最长不上升子序列方案里的出现次数/总最长不上升子序列方案数
我们可以比较套路的定义一下dp
定义:
dp1[i]表示以第i枚导弹为结尾的方案中的最长不上升子序列长度
dp2[i]表示以第i枚导弹为开头的方案中的最长不上升子序列长度
f1[i]表示以第i枚导弹为结尾的最长不上升子序列方案总数
f2[i]表示以第i枚导弹为开头的最长不上升子序列方案总数
那么第一问的答案即max{dp1[i]}或者max{dp2[i]}
第二问第i枚导弹的答案即f1[i]*f2[i](乘法原理)且dp1[i]+dp2[i]=ans-1,ans为第一问答案,减1是i导弹被计算了两次
故我们只要把dp解决这个问题就结束了
显然O(n^2)的dp类比导弹拦截能很容易的得出
//dp1的代码,dp2对称一下容易得到
//这里默认z按从小到大排好了
for(int i=1;i<=n;i++){
for(int j=1;j<i;j++){
if(a[i].x<=a[j].x&&a[i].y<=a[j].y){
if(dp1[i]==dp1[j]+1)
f1[i]+=f1[j];
else if(dp1[i]<dp1[j]+1){
dp1[i]=dp1[j]+1;
f1[i]=f1[j];
}
}
}
}
这是一个三维偏序的相关问题
我们考虑用cdq分治来优化dp
这里由于要用到树状数组,我懒得离散化x和y所以我用z搞树状数组,用第一维排序
我们把按x排好序后发现只能左边的导弹的dp值贡献给右边
这里我们应该先把左边的dp值都算完了再去更新右边的
故这里的递归顺序应是左中右
这个地方特别注意
bool cmp1(node i,node j){
return i.x==j.x?(i.y==j.y?i.z<j.z:i.y>j.y):i.x>j.x;
}
bool cmp2(node i,node j){
return i.x==j.x?(i.y==j.y?i.z<j.z:i.y<j.y):i.x<j.x;
}
bool cmp3(node i,node j){
return i.y>j.y;
}
bool cmp4(node i,node j){
return i.y<j.y;
}
//顺序什么的可以自己搞定
void cdq(int l,int r,int opt){//opt==1是计算dp1,2是计算dp2
if(l==r)return ;
int mid=l+((r-l)>>1);
cdq(l,mid,opt);//先把左边算好
if(opt==1){
sort(a+l,a+mid+1,cmp3);
sort(a+mid+1,a+r+1,cmp3);
}
else{
sort(a+l,a+mid+1,cmp4);
sort(a+mid+1,a+r+1,cmp4);
}
merge_sort(l,r,opt);//把左边对右边的贡献加到右边上
//这样写而不是先去递归右边主要是怕右边的左半部分还每加上左边的贡献就去算对右边右半部分的贡献了
if(opt==1)sort(a+mid+1,a+r+1,cmp1);
else sort(a+mid+1,a+r+1,cmp2);
//记得排回来
cdq(mid+1,r,opt);
return ;
}
然后注意树状数组维护时算f会比较麻烦
因为f需要所有方案的和
放代码
#include<bits/stdc++.h>
using namespace std;
const int maxn=50005;
struct node{
int x,y,z;
int dp;
double f;
}a[maxn];
int n;
int tree1[maxn];
double tree2[maxn];
int dp1[maxn],dp2[maxn];
double f1[maxn],f2[maxn];
int read(){
int x=0,y=1;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')y=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
return x*y;
}
bool cmp1(node i,node j){
return i.x==j.x?(i.y==j.y?i.z<j.z:i.y>j.y):i.x>j.x;
}
bool cmp2(node i,node j){
return i.x==j.x?(i.y==j.y?i.z<j.z:i.y<j.y):i.x<j.x;
}
bool cmp3(node i,node j){
return i.y>j.y;
}
bool cmp4(node i,node j){
return i.y<j.y;
}
int lowbit(int x){
return x&-x;
}
void update(int x,int dp,double f){
while(x<=n){
if(tree1[x]<dp){
tree1[x]=dp;
tree2[x]=f;
}
else if(tree1[x]==dp)
tree2[x]+=f;
x+=lowbit(x);
}
return ;
}
int query1(int x){
int ans=0;
while(x){
ans=max(ans,tree1[x]);
x-=lowbit(x);
}
return ans;
}
double query2(int x,int dp){
double ans=0;
while(x){
if(tree1[x]==dp)ans+=tree2[x];
x-=lowbit(x);
}
return ans;
}
void del(int x){
while(x<=n){
tree1[x]=tree2[x]=0;
x+=lowbit(x);
}
return ;
}
bool check(int p,int q,int opt){
if(opt==1)return a[p].y>=a[q].y;
return a[p].y<=a[q].y;
}
void merge_sort(int l,int r,int opt){
int mid=l+((r-l)>>1);
int p=l,q=mid+1;
while(p<=mid&&q<=r){
if(check(p,q,opt)){
update(a[p].z,a[p].dp,a[p].f);
p++;
}
else{
int val=query1(a[q].z)+1;
if(a[q].dp<val){
a[q].dp=val;
a[q].f=query2(a[q].z,val-1);
}
else if(a[q].dp==val)
a[q].f+=query2(a[q].z,val-1);
q++;
}
}
while(q<=r){
int val=query1(a[q].z)+1;
if(a[q].dp<val){
a[q].dp=val;
a[q].f=query2(a[q].z,val-1);
}
else if(a[q].dp==val)
a[q].f+=query2(a[q].z,val-1);
q++;
}
for(int i=l;i<p;i++)del(a[i].z);
return ;
}
void cdq(int l,int r,int opt){
if(l==r)return ;
int mid=l+((r-l)>>1);
cdq(l,mid,opt);
if(opt==1){
sort(a+l,a+mid+1,cmp3);
sort(a+mid+1,a+r+1,cmp3);
}
else{
sort(a+l,a+mid+1,cmp4);
sort(a+mid+1,a+r+1,cmp4);
}
merge_sort(l,r,opt);
if(opt==1)sort(a+mid+1,a+r+1,cmp1);
else sort(a+mid+1,a+r+1,cmp2);
cdq(mid+1,r,opt);
return ;
}
int main(){
n=read();
for(int i=1;i<=n;i++){
a[i].x=read();a[i].y=read();a[i].z=i;
a[i].dp=1;a[i].f=1;
}
sort(a+1,a+n+1,cmp1);
cdq(1,n,1);
int ans=0;
for(int i=1;i<=n;i++){
dp1[a[i].z]=a[i].dp;
f1[a[i].z]=a[i].f;
ans=max(ans,dp1[a[i].z]);
}
printf("%d\n",ans);
for(int i=1;i<=n;i++){
a[i].z=n-a[i].z+1;
a[i].dp=1;a[i].f=1;
}
sort(a+1,a+n+1,cmp2);
cdq(1,n,2);
double k=0;
for(int i=1;i<=n;i++){
dp2[n-a[i].z+1]=a[i].dp;
f2[n-a[i].z+1]=a[i].f;
if(dp2[n-a[i].z+1]==ans)k+=f2[n-a[i].z+1];
}
for(int i=1;i<=n;i++){
if(dp1[i]+dp2[i]==ans+1)
printf("%.5lf ",f1[i]*f2[i]/k);
else
printf("0.00000 ");
}
return 0;
}