算法思路
本题类似于 AcWing 1169. 糖果 : 差分约束思路建图, 使用$spfa$算法求最长路并判断图中
是否存在正环. 由于$spfa$最坏时间复杂度为$O(nm)$, 需要用栈结构替换队列以优化时间.
考虑用$tarjan$得到的强连通分量:
- 强连通分量内部一定存在环, 本题权重为$0$或$1$, 若强连通分量内部存在边权值大于$0$, 则图中存在正环;
否则分量内边权均为$0$ — 根据定义$dist(u_1) = dist(u_2) = … $, 即分量内所有顶点取值相同,
可以视为同一顶点.
对于缩点后得到的$DAG$, 可以按拓扑序线性时间求解最长路.
具体实现
$tarjan$强连通分量常见操作:
-
$tarjan$求强连通分量
-
缩点得到$DAG$
-
按拓扑序(强联通分量编号逆序)递推求解
对于差分约束问题我们还需要一个明确下界以及满足条件的源点. 由于数值下界为$1$:
- 建立虚拟源点$s$, $dist(s) = 0$, $u\ge s + 1$, 即从$s$向所有顶点连一条权重为$1$的有向边.
代码实现 $O(V + E)$
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
typedef long long ll;
const int N = 1e5 + 10, M = 3e5 * 2 + 10;
int n, m;
int h[N], hs[N], e[M], w[M], ne[M], idx;
int dfn[N], low[N], timestamp;
int stk[N], top; bool in_stk[N];
int id[N], cize[N], scc_cnt;
int dist[N];
void add(int h[], int u, int v, int c)
{
e[idx] = v, w[idx] = c, ne[idx] = h[u], h[u] = idx ++;
}
void tarjan(int u)
{
dfn[u] = low[u] = ++ timestamp;
stk[++ top] = u, in_stk[u] = true;
for( int i = h[u]; ~i; i = ne[i] )
{
int v = e[i];
if( !dfn[v] )
{
tarjan(v);
low[u] = min(low[u], low[v]);
}
else if( in_stk[v] )
{
low[u] = min(low[u], low[v]);
}
}
if( dfn[u] == low[u] )
{
++ scc_cnt;
int v;
do {
v = stk[top --];
in_stk[v] = false;
id[v] = scc_cnt;
cize[scc_cnt] ++;
} while( u != v );
}
}
int main()
{
scanf("%d%d", &n, &m);
memset(h, -1, sizeof h);
while( m -- )
{
int t, a, b;
scanf("%d%d%d", &t, &a, &b);
switch( t )
{
case 1: add(h, a, b, 0), add(h, b, a, 0); break;
case 2: add(h, a, b, 1); break;
case 3: add(h, b, a, 0); break;
case 4: add(h, b, a, 1); break;
case 5: add(h, a, b, 0); break;
}
}
for( int u = 1; u <= n; u ++ ) add(h, 0, u, 1); //u >= s + 1, s = 0
//1. tarjan
tarjan(0); //0可以连接所有顶点
//2. 缩点建DAG
memset(hs, -1, sizeof hs);
bool success = true;
for( int u = 0; u <= n; u ++ )
{
for( int i = h[u]; ~i; i = ne[i] )
{
int v = e[i];
if( id[u] != id[v] ) add(hs,id[u], id[v], w[i]);
else if( w[i] > 0 )
{//在同一联通分量内 存在正权 -> 存在正环
success = false;
break;
}
}
}
//3. 拓扑序递推
if( !success ) puts("-1");
else
{
//dist[0] = 0
for( int u = scc_cnt; u; u -- )
{
for( int i = hs[u]; ~i; i = ne[i] )
{
int v = e[i];
dist[v] = max(dist[v], dist[u] + w[i]);
}
}
ll res = 0;
for( int u = 1; u <= scc_cnt; u ++ )
res += (ll)dist[u] * cize[u];
printf("%lld\n", res);
}
return 0;
}