题解:P3806 【模板】点分治 1 / 点分治学习笔记
前置知识:
- 树的重心
- 题目描述
给定一棵有
一眼:这题我会,可以
数据规模与约定:- 对于
时间限制:200ms。
提示:本题不卡常。
这就要用到点分治了。
点分治适合处理大规模的树上路径信息问题。 ——OI-wiki
- Part 1:求树的重心
由树的重心性质可知,只需要求
证明:每次的子树找重心换根后,其每个新的子树的节点数一定不会大于
以下是实现以上过程的代码,sz
是当前子树的大小,maxp
是最大子树的大小,新的根就是
void getroot(int u,int fa){
sz[u]=1;
maxp[u]=0;
for(int i=0;i<G[u].size();i++){
int v=G[u][i].se;
if(v==fa||vis[v]) continue;
getroot(v,u);
sz[u]+=sz[v];
maxp[u]=max(maxp[u],sz[v]);
}
maxp[u]=max(maxp[u],sum-sz[u]);
if(maxp[u]<maxp[root]) root=u;
}
- Part 2:点分治
点分治是一个离线算法,你不需要知道它为什么离线。
离线记录询问并在分治过程中处理,记为 que
数组。
设当前根为 rem
数组中,存储方式为 rem
数组。令 jud
数组表示在 pd
数组表示树上距离为
证明:由于树上的每个最短路径都是从某个节点到某个祖宗节点再到某个后代节点,所以此操作可以不重不漏的统计最短路径。
此部分的代码如下:
for(int j=1;j<=dfn;j++){
for(int k=1;k<=m;k++){
if(que[k]>=rem[j]) pd[k]|=jud[que[k]-rem[j]];
}
}
for(int i=1;i<=m;++i){
if(pd[i]) cout<<"AYE"<<endl;
else cout<<"NAY"<<endl;
}
最后别忘了清空,注意:用 memset 会 TLE。
for(int i=1;i<=p;i++) jud[q[i]]=0;
其中 q
就是用来记录哪个节点该被清空。
以下是实现以上过程的代码:
void getdis(int u,int fa){
rem[++dfn]=dis[u];
for(int i=0;i<G[u].size();i++){
int v=G[u][i].se;
if(v==fa||vis[v]) continue;
dis[v]=dis[u]+G[u][i].fi;
getdis(v,u);
}
}
void calc(int u){
int p=0;
for(int i=0;i<G[u].size();i++){
int v=G[u][i].se;
if(vis[v]) continue;
dfn=0;
dis[v]=G[u][i].fi;
getdis(v,u);
for(int j=1;j<=dfn;j++){
for(int k=1;k<=m;k++){
if(que[k]>=rem[j]) pd[k]|=jud[que[k]-rem[j]];
}
}
for(int j=1;j<=dfn;j++){
q[++p]=rem[j];
jud[rem[j]]=1;
}
}
for(int i=1;i<=p;i++) jud[q[i]]=0;
}
最后就是朴实无华的每次找根并解决询问的过程。
void solve(int u){
vis[u]=1;
jud[0]=1;
calc(u);
for(int i=0;i<G[u].size();i++){
int v=G[u][i].se;
if(vis[v]) continue;
sum=sz[v];
root=0;
maxp[0]=1e7+1;
getroot(v,0);
solve(root);
}
}
完整代码如下,时间复杂度
#include<bits/stdc++.h>
#define PII pair<int,int>
#define fi first
#define se second
using namespace std;
int n,m,u,v,w,maxp[100001],sz[100001],dis[100001],rem[100001],q[100001],que[100001],sum,dfn,root,ans;
bool vis[100001],pd[10000001],jud[10000001];
vector<PII>G[100001];
void getroot(int u,int fa){
sz[u]=1;
maxp[u]=0;
for(int i=0;i<G[u].size();i++){
int v=G[u][i].se;
if(v==fa||vis[v]) continue;
getroot(v,u);
sz[u]+=sz[v];
maxp[u]=max(maxp[u],sz[v]);
}
maxp[u]=max(maxp[u],sum-sz[u]);
if(maxp[u]<maxp[root]) root=u;
}
void getdis(int u,int fa){
rem[++dfn]=dis[u];
for(int i=0;i<G[u].size();i++){
int v=G[u][i].se;
if(v==fa||vis[v]) continue;
dis[v]=dis[u]+G[u][i].fi;
getdis(v,u);
}
}
void calc(int u){
int p=0;
for(int i=0;i<G[u].size();i++){
int v=G[u][i].se;
if(vis[v]) continue;
dfn=0;
dis[v]=G[u][i].fi;
getdis(v,u);
for(int j=1;j<=dfn;j++){
for(int k=1;k<=m;k++){
if(que[k]>=rem[j]) pd[k]|=jud[que[k]-rem[j]];
}
}
for(int j=1;j<=dfn;j++){
q[++p]=rem[j];
jud[rem[j]]=1;
}
}
for(int i=1;i<=p;i++) jud[q[i]]=0;
}
void solve(int u){
vis[u]=1;
jud[0]=1;
calc(u);
for(int i=0;i<G[u].size();i++){
int v=G[u][i].se;
if(vis[v]) continue;
sum=sz[v];
root=0;
maxp[0]=n;
getroot(v,0);
solve(root);
}
}
int main(){
cin>>n>>m;
for(int i=1;i<n;i++){
cin>>u>>v>>w;
G[u].push_back({w,v});
G[v].push_back({w,u});
}
for(int i=1;i<=m;i++) cin>>que[i];
maxp[0]=n;
sum=n;
getroot(1,0);
solve(root);
for(int i=1;i<=m;i++){
if(pd[i]) cout<<"AYE"<<endl;
else cout<<"NAY"<<endl;
}
}