题解 P5960 【【模板】差分约束算法】

· · 题解

UPD(2020/12/23):修正了入队次数为 n 即判断有负环的 bug。

UPD(2024/08/22):修正了 dis 数组没有开 long long 导致溢出的问题。

一个差分约束系统是这样的:

给出 n 个变量和 m 个约束条件,形如 x_i - x_j \leq c_k,你需要求出一组解,使得所有约束条件均被满足。

怎样解这个差分约束系统呢?我们将上面的不等式变形一下:

x_i \leq x_j + c_k

容易发现这个形式和最短路中的三角形不等式 dis_v \leq dis_u + w 非常相似。

因此我们就将这个问题转化为一个求最短路的问题:比如对于上面这个不等式,我们从 ji 连一条权值为 c_k 的边。

接下来,我们再新建一个 0 号点,从 0 号点向其他所有点连一条权值为 0 的边。

这个操作相当于新增了一个变量 x_0n 个约束条件:x_i \leq x_0,从而将所有变量都和 x_0 这一个变量联系起来。

然后以 0 号点为起点,用 Bellman-Ford 跑最短路。如果有负权环,差分约束系统无解。否则设从 0 号点到 i 号点的最短路为 dis_i,则 x_i = dis_i 即为差分约束系统的一组可行解。

负环的判断

建议先做 P3385,不过这里还是先大概讲一下吧。

首先,如果点 u 到点 v 的路径上存在一个负环,则 u \to v 的最短路不存在(可以一直绕着这个负环走,从而最短路可以取到 -\infty)。

原图中有 n+1 个点(注意添加了一个超级源点),如果不存在负环的话,则最短路最多经过 n+1 个点,n 条边(没有负环时,一个点经过两次显然不优)。

用 Bellman-Ford 的话,最多会执行 n 轮松弛操作(啥是松弛?简单来说就是根据三角形不等式更新最短路的过程),如果 n 轮松弛结束后仍然存在能松弛的边,则一定存在负环。

啥你用的是 SPFA?

(偷偷告诉你 SPFA 码量比 Bellman-Ford 长而且最坏复杂度和 Bellman-Ford 一样)

那就统计一个点入队的次数,如果一个点进队了 n+1 次那就一定有负环。

扩展

很多时候差分约束的条件并不是简单的小于等于号,这时候我们需要稍微做点变形。

如果有 x_i - x_j \geq c_k,则可以两边同时乘 -1,将不等号反转过来。

如果有 x_i - x_j = c_k,则可以把这个等式拆分为 x_i - x_j \leq c_kx_i - x_j \geq c_k 两个约束条件。

Code

#include <cstring>
#include <iostream>
#include <queue>
using namespace std;
using i64 = long long;
const int maxn = 5'000;
const int maxm = 10'000;
struct edge {
  int v, w, next;
} e[maxm + 5];
int head[maxn + 5], tot[maxn + 5], vis[maxn + 5], cnt, n, m;
i64 dis[maxn + 5];
void addedge(int u, int v, int w) {
  e[++cnt].v = v;
  e[cnt].w = w;
  e[cnt].next = head[u];
  head[u] = cnt;
}
bool spfa(int s) {
  queue<int> q;
  memset(dis, 63, sizeof(dis));
  dis[s] = 0, vis[s] = 1;
  q.push(s);
  while (!q.empty()) {
    int u = q.front();
    q.pop();
    vis[u] = 0;
    for (int i = head[u]; i; i = e[i].next) {
      int v = e[i].v;
      if (dis[v] > dis[u] + e[i].w) {
        dis[v] = dis[u] + e[i].w;
        if (!vis[v]) {
          vis[v] = 1, tot[v]++;
          if (tot[v] == n + 1) return false;  // 注意添加了一个超级源点
          q.push(v);
        }
      }
    }
  }
  return true;
}
int main() {
  cin >> n >> m;
  for (int i = 1; i <= n; i++) addedge(0, i, 0);
  for (int i = 1; i <= m; i++) {
    int v, u, w;
    cin >> v >> u >> w;
    addedge(u, v, w);
  }
  if (!spfa(0))
    cout << "NO" << endl;
  else
    for (int i = 1; i <= n; i++) cout << dis[i] << ' ';
  return 0;
}