最小点集覆盖
前言
宣传: 算法主页
问题描述
给定一张有向图 G=(V,E) 。
每一次操作可以用 W−u 的代价来消除所有 u 射出的边,
或用 W+u 的代价来消除所有射入 u 的边。
求将所有边清除所需的最小代价。
过程
设新图为 G′ ,原图为 G 。
在新图中建立源点 S 和汇点 T。
将每一个点 u 拆分成两部分 u 和 u′ 。
从 S 向所有 u 连权值为 W−u 的边。
从所有 u′ 向 T 连权值为 W+u 的边。
对于原图中所有边 (u,v) ,在新图中建边 (u,v′) ,权值为 inf。
在新图中求最小割,即为所求的最小代价。
分析
对于每一条边 (u,v) ,由于 G′ 中 (u,v′) 权值为 inf ,所以一定不会选这条边。
在最小割 [S,T] 中,要么 u,v∈S ,要么 u,v∈T ,所以代价为 W−u 或 W+v 。
实现
#include <bits/stdc++.h>
using namespace std;
const int N = 10010, M = 200010, inf = 0x3f3f3f3f;
int n, m, S, T;
int h[N], e[M], c[M], ne[M], idx;
int q[N], dep[N], cur[N];
bool st[N];
int read()
{
int x;
scanf("%d", &x);
return x;
}
void add(int u, int v, int w)
{
e[idx] = v, c[idx] = w, ne[idx] = h[u], h[u] = idx ++ ;
e[idx] = u, c[idx] = 0, ne[idx] = h[v], h[v] = idx ++ ;
}
bool bfs()
{
memset(dep, -1, sizeof dep);
int hh = 0, tt = 0;
q[0] = S, dep[S] = 0;
cur[S] = h[S];
while (hh <= tt)
{
int t = q[hh ++ ];
for (int i = h[t]; ~i; i = ne[i])
{
int j = e[i];
if (dep[j] == -1 && c[i] > 0)
{
dep[j] = dep[t] + 1;
cur[j] = h[j];
if (j == T)return true;
q[++ tt] = j;
}
}
}
return false;
}
int find(int u, int limit)
{
if (u == T)return limit;
int flow = 0;
for (int i = cur[u]; ~i && flow < limit; i = ne[i])
{
cur[u] = i;
int j = e[i];
if (dep[j] == dep[u] + 1 && c[i] > 0)
{
int t = find(j, min(c[i], limit - flow));
if (!t)dep[j] = -1;
c[i] -= t, c[i ^ 1] += t, flow += t;
}
}
return flow;
}
int dinic()
{
int res = 0, flow;
while (bfs()) while (flow = find(S, inf))res += flow;
return res;
}
void dfs(int u)
{
st[u] = true;
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (st[j] || c[i] == 0)continue;
dfs(j);
}
}
int main()
{
scanf("%d%d", &n, &m);
S = 0, T = n + n + 1;
memset(h, -1, sizeof h);
for (int i = n + 1; i <= n + n; i ++ )add(i, T, read());
for (int i = 1; i <= n; i ++ )add(S, i, read());
for (int i = 0; i < m; i ++ )
{
int u, v;
scanf("%d%d", &u, &v);
add(u, n + v, inf);
}
printf("%d\n", dinic());
dfs(S);
int cnt = 0;
for (int i = 0; i < idx; i += 2)
{
int a = e[i ^ 1], b = e[i];
if (st[a] && !st[b])cnt ++ ;
}
printf("%d\n", cnt);
for (int i = 0; i < idx; i += 2)
{
int a = e[i ^ 1], b = e[i];
if (st[a] && !st[b])
{
if (a == S)printf("%d -\n", b);
if (b == T)printf("%d +\n", a - n);
}
}
return 0;
}