动物王国中有三类动物A,B,C,这三类动物的食物链构成了有趣的环形。
A吃B, B吃C,C吃A。
现有N个动物,以1-N编号。
每个动物都是A,B,C中的一种,但是我们并不知道它到底是哪一种。
有人用两种说法对这N个动物所构成的食物链关系进行描述:
第一种说法是”1 X Y”,表示X和Y是同类。
第二种说法是”2 X Y”,表示X吃Y。
此人对N个动物,用上述两种说法,一句接一句地说出K句话,这K句话有的是真的,有的是假的。
当一句话满足下列三条之一时,这句话就是假话,否则就是真话。
1) 当前的话与前面的某些真的话冲突,就是假话;
2) 当前的话中X或Y比N大,就是假话;
3) 当前的话表示X吃X,就是假话。
你的任务是根据给定的N和K句话,输出假话的总数。
输入格式
第一行是两个整数N和K,以一个空格分隔。
以下K行每行是三个正整数 D,X,Y,两数之间用一个空格隔开,其中D表示说法的种类。
若D=1,则表示X和Y是同类。
若D=2,则表示X吃Y。
输出格式
只有一个整数,表示假话的数目。
数据范围
$1≤N≤50000$,
$0≤K≤100000$
输入样例:
100 7
1 101 1
2 1 2
2 2 3
2 3 3
1 1 3
2 3 1
1 5 5
输出样例:
3
C++ 代码
#include<iostream>
using namespace std;
const int N = 50010;
int p[N],d[N];
int n,k;
int dd,x,y;
/*
%3 规则:
余 1 : 吃 根
余 2 : 被根吃
余 0 : 与根是同类
注: d[x]的含义 : x 到 px(x的根结点)的距离。
*/
int find(int x)
{
if(p[x]!=x)
{
int t = find(p[x]);
d[x]+=d[p[x]]; // x 经 p[x] 到祖宗结点的距离。
p[x] = t;
}
return p[x];
}
int main()
{
cin>>n>>k;
for(int i=0;i<n;i++) p[i]=i;
int res = 0;
while(k--)
{
scanf("%d%d%d",&dd,&x,&y);
if(x>n||y>n) res++; // 越界 为假话。
else
{
int px = find(x);
int py = find(y);
if(dd==1) // 同类
{
if(px==py&&(d[x]-d[y])%3) res++; // 同根生 ,辈分不同, 为假话。
else if(px!=py) // 不同根。
{
p[px]=py; //让py 成为 px 的父亲。
d[px]=d[y]-d[x]; // 因为 x,y 是同类 所以 (d[x]+d[px]-d[y])%3==0
}
}
else // x 吃 y
{ // x 吃 y : 说明 d[x] 比 d[y] 多 1. 所以 (d[x]-d[y]-1)%3==0
if (px == py && (d[x] - d[y] - 1) % 3) res ++ ;
else if (px != py)
{
p[px] = py;
d[px] = d[y] + 1 - d[x]; // (d[x]+d[px]-d[y]-1)%3==0
}
}
}
}
printf("%d\n", res);
return 0;
}
python3 代码
import sys
def find(x):
if p[x] != x:
t = find(p[x])
path[x] += path[p[x]]
p[x] = t
return p[x]
if __name__ == '__main__':
n, k = map(int, input().split())
count = 0
p = list(range(n+10))
path = [0] * (n + 10)
for _ in range(k):
op, a, b = map(int,sys.stdin.readline().strip().split())
if a > n or b > n:
count += 1
else:
pa = find(a)
pb = find(b)
diff = (path[a] - path[b]) % 3
if op == 1:
if pa == pb:
if diff : count += 1
else:
p[pa] = pb
path[pa] = path[b] - path[a]
if op == 2:
if pa == pb:
if diff != 1: count += 1
else:
p[pa] = pb
path[pa] = path[b] - path[a] + 1
print(count)
java 代码
import java.io.*;
public class Main{
private static final int N = 50010;
private static int[] p = new int[N];
private static int[] d = new int[N];
public static int find(int x){
if(p[x] != x){
int t = find(p[x]);
d[x] += d[p[x]];
p[x] = t;
}
return p[x];
}
public static void main(String[] args) throws IOException{
BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
String[] strs = reader.readLine().split(" ");
int n = Integer.parseInt(strs[0]);
int k = Integer.parseInt(strs[1]);
for(int i = 1; i <= n; i++){
p[i] = i;
}
int res = 0;
while(k-- > 0){
strs = reader.readLine().split(" ");
int t = Integer.parseInt(strs[0]);
int x = Integer.parseInt(strs[1]);
int y = Integer.parseInt(strs[2]);
if(x > n || y > n) res++;
else{
//px,py是x,y的祖宗节点
int px = find(x);
int py = find(y);
if(t == 1){
if(px == py && (d[x] - d[y]) % 3 != 0) res++;
else if(px != py){
p[px] = py;
//(d[x] + ? - d[y]) % 3 == 0 => ? = dy - dx
d[px] = d[y] - d[x];
}
}else{
if(px == py && (d[x] - d[y] - 1) % 3 != 0) res++;
else if(px != py){
//(dx + ? - dy - 1) % 3 == 0 => ? = dy - dx + 1
p[px] = py;
d[px] = d[y] - d[x] + 1;
}
}
}
}
System.out.println(res);
reader.close();
}
}