P4897 【模板】最小割树(Gomory-Hu Tree) 题解

· · 题解

分析

最小割树用于处理任意割断代价总和最小的边使得任意两个点不连通的问题。

很显然,每次得到询问之后处理是不优秀的,时间复杂度为 O(Qn^2m),那么我们就可以用 Gomory-Hu 算法优化到 O(n^3m+Q)

具体来说,这个算法把这些点建立成了一棵树,然后使得其中任意两个点 s,t 的最小割就是这两个点树上简单路径上的边权最小值。

为了建立出来这么一棵树,我们需要分治。

  1. 在当前的点集中任意选取两个点 s,t,在原图上跑最小割,得到结果。
  2. 把当前最小割划分出来的两个集合分别向下递归,得到两棵子树的根。
  3. 连接这两个根,边权为当前最小割的值。
  4. 如果 s 这个集合包含了点集外的数,返回根为 s;如果 t 这个集合包含了点集外的数,返回根为 t;否则当前点集一定是全集,特判一下即可。

这就是构建最小割树的过程,之后对于每个询问,我们可以预处理回答,也可以处理倍增后在线回答询问,一个是 O(Q) 的,一个是 (Q \log n) 的。

还有一种构造过程:

  1. 随便选取两个点 s,ts 连向 t 一条双向边,边权为源点为 s,汇点为 t 的最小割。
  2. 得到两个集合,这两个集合往下递归即可。

第二种做法更加简单,推荐选择第二种做法。

这个做法的正确性在于,我们选取了两个点,假设最小割割成了两个集合 S,T,那么在分治的过程中,我们至少会有一次任取两个点一个取到了 S 集合,一个取到了 T 集合,这个显然可得。

于是任意询问两点最小割的问题就解决了。

代码

代码中的网络流使用 Dinic 实现。

#include<bits/stdc++.h>
#define ll long long
#define N 50005
using namespace std;
vector<ll> op[N],w[N];
ll n,m,i,x,y,z;
ll tot=1,cur[N],ne[N],la[N],to[N],val[N],vis[N],we[N],q[N],he,ta,dis[N],s,t,inf,tofa[N],id[N],idd[N],st[N][21],st2[N][21],dep[N];
inline void merge(ll x,ll y,ll z){
    tot++,ne[tot] = la[x],la[x] = tot,to[tot] = y,val[tot] = z,we[tot] = z;
    tot++,ne[tot] = la[y],la[y] = tot,to[tot] = x,val[tot] = 0,we[tot] = 0;
}
bool bfs(){
    for(ll i=1;i<=n;i++) dis[i]=-1;
    q[he=ta=1]=s,dis[s]=0;
    while(he<=ta){
        ll tmp = q[he++];
        for(ll i=la[tmp];i;i=ne[i]){
            if(val[i]>0&&dis[to[i]]==-1){
                dis[to[i]]=dis[tmp]+1;
                q[++ta]=to[i];
            }
        }
    }
    return dis[t]!=-1;
}
ll dfs(ll x,ll step){
    if(x==t||!step) return step;
    ll used = 0;
    for(ll i=cur[x];i;i=ne[i]){
        cur[x] = i;
        if(dis[to[i]]==dis[x]+1&&val[i]>0){
            ll temp = dfs(to[i],min(step-used,val[i]));
            used += temp,val[i] -= temp,val[i^1] += temp;
            if(used==step) return used;
        }
    }
    return used;
}
ll get_max(ll x,ll y){
    ll ans = 0;
    for(ll i=0;i<=tot;i++) val[i]=we[i];
    s=x,t=y;
    while(bfs()){
        for(ll i=1;i<=n;i++) cur[i]=la[i];
        ans += dfs(s,1e18);
    }
    return ans;
}
ll solve(ll l,ll r){
    if(l==r) return id[l];
    ll temp = get_max(id[l],id[l+1]);
    ll tot = l,pos = tot-1,root = 0;
    for(ll i=l;i<=r;i++) vis[id[i]]=1;
    for(ll i=l;i<=r;i++) if(dis[id[i]]!=-1) idd[tot++]=id[i];
    pos=tot-1;
    for(ll i=l;i<=r;i++) if(dis[id[i]]==-1) idd[tot++]=id[i];
    for(ll i=l;i<=r;i++) id[i]=idd[i];
    for(ll i=1;i<=n;i++){
        if(!vis[i]&&dis[i]==-1) root=-2;
        if(!vis[i]&&dis[i]!=-1) root=-1;
    }
    for(ll i=l;i<=r;i++) vis[id[i]]=0;
    ll ls = solve(l,pos),rs = solve(pos+1,r);
    op[ls].push_back(rs),w[ls].push_back(temp),op[rs].push_back(ls),w[rs].push_back(temp);
    if(root==-1) return ls;
    if(root==-2) return rs;
    return root;
}
void dfss(ll x,ll fa){
    for(ll i=1;i<=20;i++){
        st[x][i] = st[st[x][i-1]][i-1];
        st2[x][i] = min(st2[x][i-1],st2[st[x][i-1]][i-1]);
    }
    for(ll i=0;i<op[x].size();i++){
        if(op[x][i]==fa) continue;
        st[op[x][i]][0] = x,st2[op[x][i]][0] = w[x][i],dep[op[x][i]] = dep[x]+1;
        dfss(op[x][i],x);
    }
}
ll found(ll x,ll y){
    ll ans = LLONG_MAX;
    if(dep[x]>dep[y]) swap(x,y);
    for(ll i=20;i>=0;i--) if(dep[st[y][i]]>=dep[x]) ans=min(ans,st2[y][i]),y=st[y][i];
    if(x==y) return ans;
    for(ll i=20;i>=0;i--){
        if(st[x][i]!=st[y][i]){
            ans=min(ans,min(st2[x][i],st2[y][i]));
            x=st[x][i],y=st[y][i];
        }
    }
    return min(ans,min(st2[x][0],st2[y][0]));
}
int main(){
    ios::sync_with_stdio(false);
    cin>>n>>m,n++;
    while(m--){
        cin>>x>>y>>z,x++,y++;
        merge(x,y,z),merge(y,x,z);
    }
    for(i=1;i<=n;i++) id[i]=i;
    solve(1,n);
    dep[1]=1;
    dfss(1,-1);
    cin>>m;
    while(m--){
        cin>>x>>y;
        x++,y++;
        cout<<found(x,y)<<endl;
    }
    return 0;
}
/*
Input:
4 5
1 2 2
2 3 2
4 2 3
4 3 1
1 3 1
3
1 4
2 4
2 3

Output:
3
4
4
*/

第二种写法的代码:

#include<bits/stdc++.h>
#define ll long long
#define N 50005
using namespace std;
vector<ll> op[N],w[N];
ll n,m,i,x,y,z;
ll tot=1,cur[N],ne[N],la[N],to[N],val[N],vis[N],we[N],q[N],he,ta,dis[N],s,t,inf,tofa[N],id[N],idd[N],st[N][21],st2[N][21],dep[N];
inline void merge(ll x,ll y,ll z){
    tot++,ne[tot] = la[x],la[x] = tot,to[tot] = y,val[tot] = z,we[tot] = z;
    tot++,ne[tot] = la[y],la[y] = tot,to[tot] = x,val[tot] = 0,we[tot] = 0;
}
bool bfs(){
    for(ll i=1;i<=n;i++) dis[i]=-1;
    q[he=ta=1]=s,dis[s]=0;
    while(he<=ta){
        ll tmp = q[he++];
        for(ll i=la[tmp];i;i=ne[i]){
            if(val[i]>0&&dis[to[i]]==-1){
                dis[to[i]]=dis[tmp]+1;
                q[++ta]=to[i];
            }
        }
    }
    return dis[t]!=-1;
}
ll dfs(ll x,ll step){
    if(x==t||!step) return step;
    ll used = 0;
    for(ll i=cur[x];i;i=ne[i]){
        cur[x] = i;
        if(dis[to[i]]==dis[x]+1&&val[i]>0){
            ll temp = dfs(to[i],min(step-used,val[i]));
            used += temp,val[i] -= temp,val[i^1] += temp;
            if(used==step) return used;
        }
    }
    return used;
}
ll get_max(ll x,ll y){
    ll ans = 0;
    for(ll i=0;i<=tot;i++) val[i]=we[i];
    s=x,t=y;
    while(bfs()){
        for(ll i=1;i<=n;i++) cur[i]=la[i];
        ans += dfs(s,1e18);
    }
    return ans;
}
void solve(ll l,ll r){
    if(l==r) return ;
    ll temp = get_max(id[l],id[l+1]);
    ll tt1 = id[l],tt2 = id[l+1];
    ll tot = l,pos = tot-1;
    for(ll i=l;i<=r;i++) vis[id[i]]=1;
    for(ll i=l;i<=r;i++) if(dis[id[i]]!=-1) idd[tot++]=id[i];
    pos=tot-1;
    for(ll i=l;i<=r;i++) if(dis[id[i]]==-1) idd[tot++]=id[i];
    for(ll i=l;i<=r;i++) id[i]=idd[i];
    for(ll i=l;i<=r;i++) vis[id[i]]=0;
    solve(l,pos),solve(pos+1,r);
    op[tt1].push_back(tt2),w[tt1].push_back(temp);
    op[tt2].push_back(tt1),w[tt2].push_back(temp);
    return ;
}
void dfss(ll x,ll fa){
    for(ll i=1;i<=20;i++){
        st[x][i] = st[st[x][i-1]][i-1];
        st2[x][i] = min(st2[x][i-1],st2[st[x][i-1]][i-1]);
    }
    for(ll i=0;i<op[x].size();i++){
        if(op[x][i]==fa) continue;
        st[op[x][i]][0] = x,st2[op[x][i]][0] = w[x][i],dep[op[x][i]] = dep[x]+1;
        dfss(op[x][i],x);
    }
}
ll found(ll x,ll y){
    ll ans = LLONG_MAX;
    if(dep[x]>dep[y]) swap(x,y);
    for(ll i=20;i>=0;i--) if(dep[st[y][i]]>=dep[x]) ans=min(ans,st2[y][i]),y=st[y][i];
    if(x==y) return ans;
    for(ll i=20;i>=0;i--){
        if(st[x][i]!=st[y][i]){
            ans=min(ans,min(st2[x][i],st2[y][i]));
            x=st[x][i],y=st[y][i];
        }
    }
    return min(ans,min(st2[x][0],st2[y][0]));
}
int main(){
    ios::sync_with_stdio(false);
    cin>>n>>m,n++;
    while(m--){
        cin>>x>>y>>z,x++,y++;
        merge(x,y,z),merge(y,x,z);
    }
    for(i=1;i<=n;i++) id[i]=i;
    solve(1,n);
    dep[1]=1;
    dfss(1,-1);
    cin>>m;
    while(m--){
        cin>>x>>y;
        x++,y++;
        cout<<found(x,y)<<endl;
    }
    return 0;
}
/*
Input:
4 5
1 2 2
2 3 2
4 2 3
4 3 1
1 3 1
3
1 4
2 4
2 3

Output:
3
4
4
*/