可持久化线段树求 k 短路

· · 题解

让我们直接快进到这里:

一条 s\to t 的路径可以用一个边序列 e_1,\ldots,e_k 来描述,满足:

这条路径的形式就是依次经过每个 e_i,在相邻 e_i 之间只走最短路树上的边。

于是我们只需要对满足上述条件的边序列求第 k 小权值和即可,这可以用一个带扩展贪心来解决:

带扩展贪心的一般形式就是从一个最初状态开始依次扩展出每个可能的状态,并且只能由小的状态扩展到大的状态,每个状态的扩展方式唯一,那么再利用一个优先队列,每次取出队首进行扩展,连续 k 次就可以求出答案了。

在该问题中,扩展方式如下:

这样显然每个状态的扩展方式唯一,即不会有一个方案算两次(因为每个状态的前驱唯一)。

可以用可持久化线段树实现上述算法,我们对于每个 a 求出所有起点是 a 的祖先的非树边权值构成的线段树,由于要支持的操作仅仅是求某棵线段树上某值的后继和求某棵线段树上的最小值,都可以 O(\log n) 完成。

外面再套一个 priority_queue 维护当前所有状态,时间复杂度为 O(n+m\log n+k(\log k+\log n)),空间复杂度为 O(m\log n+k)

对比一下其他人说的堆的做法:

考虑将上述算法中的线段树换成堆,那么第二种扩展就不用是求后继了,而是取出堆中的一个结点后加入它的两个子结点(这样也满足大的状态总是被更小的状态扩展出),相当于对每个堆以一种拓扑序遍历。

这里需要使用一种可持久化的堆(不一定是可并堆,和线段树一样只需要单点加入即可),以最容易写的左偏树为例,由于堆的空间线性,而且求出两个子结点的过程 O(1),因此复杂度比上面更好一些,时间复杂度为 O(n+m\log n+k\log k),空间复杂度为 O(n\log n+m+k)

可见堆做法是更优越的,所以上面的线段树算法实际上没什么用,但这里提供一个思路。

#include <bits/stdc++.h>
using namespace std;
const int MAXN=5010,MAXM=200010,MAXD=4000010;
const double eps=1e-8;
double res,z,dis[MAXN];
int n,m,x,y,tot,cnt,ans,rt[MAXN],ch[MAXD][2],vis[MAXN];
pair <double,pair<int,int> > eg[MAXM];
vector <int> add[MAXN];
vector < pair<pair<int,int>,double> > v[MAXN];
priority_queue < pair<double,pair<int,int> > > q;
void dfs (int x) {
    vis[x]=1;
    int len=v[x].size();
    for (int i=0;i<len;i++) {
        int y=v[x][i].first.first;
        double z=v[x][i].second;
        if (vis[y]) {continue;}
        if (fabs(dis[y]-dis[x]-z)<=eps) {
            v[x][i].first.second=1;
            dfs(y);
        }
    }
}
int modify (int rt,int l,int r,int v) {
    int p=++tot;
    if (l==r) {return p;}
    int mid=(l+r)>>1;
    if (v<=mid) {
        ch[p][0]=modify(ch[rt][0],l,mid,v),ch[p][1]=ch[rt][1];
    } else {
        ch[p][1]=modify(ch[rt][1],mid+1,r,v),ch[p][0]=ch[rt][0];
    }
    return p;
}
int queryfir (int rt,int l,int r) {
    if (l==r) {return (rt?l:0);}
    int mid=(l+r)>>1;
    if (ch[rt][0]) {return queryfir(ch[rt][0],l,mid);}
    else {return queryfir(ch[rt][1],mid+1,r);}
}
int querynx (int rt,int l,int r,int v) {
    if (l==r) {return (rt?(l>v?l:0):0);}
    int mid=(l+r)>>1;
    if (mid<=v) {return querynx(ch[rt][1],mid+1,r,v);}
    else {
        int tmp=querynx(ch[rt][0],l,mid,v);
        if (tmp) {return tmp;}
        return querynx(ch[rt][1],mid+1,r,v);
    }
}
void dfs2 (int x) {
    int len=add[x].size();
    for (int i=0;i<len;i++) {rt[x]=modify(rt[x],1,cnt,add[x][i]);}
    len=v[x].size();
    for (int i=0;i<len;i++) {
        int y=v[x][i].first.first;
        if (v[x][i].first.second) {
            rt[y]=rt[x];
            dfs2(y);
        }
    }
}
void print (int rt,int l,int r) {
    printf("%d  %d  %d\n",rt,l,r);
    if (l==r) {return;}
    int mid=(l+r)>>1;
    print(ch[rt][0],l,mid);
    print(ch[rt][1],mid+1,r);
    return;
}
int main () {
    scanf("%d%d%lf",&n,&m,&res);
    for (int i=1;i<=m;i++) {
        scanf("%d%d%lf",&x,&y,&z);
        if (x==n) {continue;}
        v[y].push_back(make_pair(make_pair(x,0),z));
    }
    for (int i=1;i<=n-1;i++) {dis[i]=1e12;}
    q.push(make_pair(0,make_pair(n,0)));
    while (q.size()) {
        pair<double,pair<int,int> > a=q.top();
        q.pop();
        int nw=a.second.first;
        if (vis[nw]) {continue;}
        vis[nw]=1;
        int len=v[nw].size();
        for (int i=0;i<len;i++) {
            int y=v[nw][i].first.first;
            double z=v[nw][i].second;
            if (dis[y]>dis[nw]+z) {
                dis[y]=dis[nw]+z;
                q.push(make_pair(-dis[y],make_pair(y,0)));
            }
        }
    }
    memset(vis,0,sizeof(vis));
    dfs(n);
    for (int i=1;i<=n;i++) {
        int len=v[i].size();
        for (int j=0;j<len;j++) {
            int y=v[i][j].first.first;
            double z=v[i][j].second;
            if (!v[i][j].first.second) {eg[++cnt]=make_pair(dis[i]+z-dis[y],make_pair(y,i));}
        }
    }
    sort(eg+1,eg+cnt+1);
    for (int i=1;i<=cnt;i++) {
        //printf("%d   %.6f   %d  %d\n",i,eg[i].first,eg[i].second.first,eg[i].second.second);
        add[eg[i].second.first].push_back(i);
    }
    dfs2(n);
    //print(rt[1],1,cnt);
    res-=dis[1],ans=1;
    int tmp=queryfir(rt[1],1,cnt);
    if (tmp) {q.push(make_pair(-eg[tmp].first,make_pair(1,tmp)));}
    while (q.size()) {
        pair<double,pair<int,int> > a=q.top();
        q.pop();
        double val=dis[1]-a.first;
        if (res-val<=-eps) {break;}
        res-=val,ans++;
        int c=a.second.first,d=a.second.second;
        int nx=querynx(rt[c],1,cnt,d),tmp=eg[d].second.second;
        if (nx) {q.push(make_pair(a.first+eg[d].first-eg[nx].first,make_pair(c,nx)));}
        nx=queryfir(rt[tmp],1,cnt);
        if (nx) {q.push(make_pair(a.first-eg[nx].first,make_pair(tmp,nx)));}
    }
    printf("%d\n",ans);
    return 0;
}