题解:P3806 【模板】点分治 1 / 点分治学习笔记

· · 算法·理论

前置知识:

给定一棵有 n 个点的树,询问树上距离为 k 的点对是否存在。

一眼:这题我会,可以 O(n^2\log n) 用倍增 LCA 求得。

数据规模与约定:- 对于 100\% 的数据,保证 1 \leq n\leq 10^4

时间限制:200ms。

提示:本题不卡常

这就要用到点分治了。

点分治适合处理大规模的树上路径信息问题。 ——OI-wiki

由树的重心性质可知,只需要求 \log n 次重心,每个点就都确定了,树的重心一定是最大子树最小的点,每次找到子树的重心把根节点换成它,而点分治的“分治”就得名于此,而这种做法相比于暴力换根的 O(n^2) 变成了 O(n\log n),不会被卡掉。

证明:每次的子树找重心换根后,其每个新的子树的节点数一定不会大于 \dfrac n 2n 为当前子树的节点数),从而保证了换根的复杂度是 O(n\log n)

以下是实现以上过程的代码,sz 是当前子树的大小,maxp 是最大子树的大小,新的根就是 \min_{v\in u}\text maxp_vv,其中 u 是当前子树内的所有节点。

void getroot(int u,int fa){
    sz[u]=1;
    maxp[u]=0;
    for(int i=0;i<G[u].size();i++){
        int v=G[u][i].se;
        if(v==fa||vis[v]) continue;
        getroot(v,u);
        sz[u]+=sz[v];
        maxp[u]=max(maxp[u],sz[v]);
    }
    maxp[u]=max(maxp[u],sum-sz[u]);
    if(maxp[u]<maxp[root]) root=u;
}

点分治是一个离线算法,你不需要知道它为什么离线

离线记录询问并在分治过程中处理,记为 que 数组。

设当前根为 root,他的所有子树为 v_1,v_2\sim v_n,假设当前处理的子树为 v_i,我们先给每个节点求出 dfn 序(然而并没什么用,只是为了方便统计),求出该子树中每个节点与根节点之间的距离 dis,并存储于 rem 数组中,存储方式为 rem_{dfn}=dis,遍历完每个子树后要清空 rem 数组。令 jud 数组表示在 v_1\sim v_{i-1} 中是否有与根节点长度为 dis 的点,若有则 jud_{dis}=1,否则 jud_{dis}=0pd 数组表示树上距离为 k 的点对是否存在,pd_k=1 则表示存在,pd_k=0 则表示不存在,判断标准就是如果 jud_{que_k-rem_j}=1,则 pd_k=1,如果 jud_{que_k-rem_j}=0,则 pd_k=0,通俗的解释就是把 v_i 中的节点与每个询问的 k(即 que_i)挨个配对,如果能配上则 pd_k=1

证明:由于树上的每个最短路径都是从某个节点到某个祖宗节点再到某个后代节点,所以此操作可以不重不漏的统计最短路径。

此部分的代码如下:

        for(int j=1;j<=dfn;j++){
            for(int k=1;k<=m;k++){
                if(que[k]>=rem[j]) pd[k]|=jud[que[k]-rem[j]];
            }
        }
    for(int i=1;i<=m;++i){
        if(pd[i]) cout<<"AYE"<<endl;
        else cout<<"NAY"<<endl;
    }

最后别忘了清空,注意:用 memset 会 TLE。

    for(int i=1;i<=p;i++) jud[q[i]]=0;

其中 q 就是用来记录哪个节点该被清空。

以下是实现以上过程的代码:

void getdis(int u,int fa){
    rem[++dfn]=dis[u];
    for(int i=0;i<G[u].size();i++){
        int v=G[u][i].se;
        if(v==fa||vis[v]) continue;
        dis[v]=dis[u]+G[u][i].fi;
        getdis(v,u);
    }
}
void calc(int u){
    int p=0;
    for(int i=0;i<G[u].size();i++){
        int v=G[u][i].se;
        if(vis[v]) continue;
        dfn=0;
        dis[v]=G[u][i].fi;
        getdis(v,u);
        for(int j=1;j<=dfn;j++){
            for(int k=1;k<=m;k++){
                if(que[k]>=rem[j]) pd[k]|=jud[que[k]-rem[j]];
            }
        }
        for(int j=1;j<=dfn;j++){
            q[++p]=rem[j];
            jud[rem[j]]=1;
        }
    }
    for(int i=1;i<=p;i++) jud[q[i]]=0;
}

最后就是朴实无华的每次找根并解决询问的过程。

void solve(int u){
    vis[u]=1;
    jud[0]=1;
    calc(u);
    for(int i=0;i<G[u].size();i++){
        int v=G[u][i].se;
        if(vis[v]) continue;
        sum=sz[v];
        root=0;
        maxp[0]=1e7+1;
        getroot(v,0);
        solve(root);
    }
}

完整代码如下,时间复杂度 O(nm)(也许):

#include<bits/stdc++.h>
#define PII pair<int,int>
#define fi first
#define se second
using namespace std;
int n,m,u,v,w,maxp[100001],sz[100001],dis[100001],rem[100001],q[100001],que[100001],sum,dfn,root,ans;
bool vis[100001],pd[10000001],jud[10000001];
vector<PII>G[100001];
void getroot(int u,int fa){
    sz[u]=1;
    maxp[u]=0;
    for(int i=0;i<G[u].size();i++){
        int v=G[u][i].se;
        if(v==fa||vis[v]) continue;
        getroot(v,u);
        sz[u]+=sz[v];
        maxp[u]=max(maxp[u],sz[v]);
    }
    maxp[u]=max(maxp[u],sum-sz[u]);
    if(maxp[u]<maxp[root]) root=u;
}
void getdis(int u,int fa){
    rem[++dfn]=dis[u];
    for(int i=0;i<G[u].size();i++){
        int v=G[u][i].se;
        if(v==fa||vis[v]) continue;
        dis[v]=dis[u]+G[u][i].fi;
        getdis(v,u);
    }
}
void calc(int u){
    int p=0;
    for(int i=0;i<G[u].size();i++){
        int v=G[u][i].se;
        if(vis[v]) continue;
        dfn=0;
        dis[v]=G[u][i].fi;
        getdis(v,u);
        for(int j=1;j<=dfn;j++){
            for(int k=1;k<=m;k++){
                if(que[k]>=rem[j]) pd[k]|=jud[que[k]-rem[j]];
            }
        }
        for(int j=1;j<=dfn;j++){
            q[++p]=rem[j];
            jud[rem[j]]=1;
        }
    }
    for(int i=1;i<=p;i++) jud[q[i]]=0;
}
void solve(int u){
    vis[u]=1;
    jud[0]=1;
    calc(u);
    for(int i=0;i<G[u].size();i++){
        int v=G[u][i].se;
        if(vis[v]) continue;
        sum=sz[v];
        root=0;
        maxp[0]=n;
        getroot(v,0);
        solve(root);
    }
}
int main(){
    cin>>n>>m;
    for(int i=1;i<n;i++){
        cin>>u>>v>>w;
        G[u].push_back({w,v});
        G[v].push_back({w,u});
    }
    for(int i=1;i<=m;i++) cin>>que[i];
    maxp[0]=n;
    sum=n;
    getroot(1,0);
    solve(root);
    for(int i=1;i<=m;i++){
        if(pd[i]) cout<<"AYE"<<endl;
        else cout<<"NAY"<<endl;
    }
}