자료구조와 알고리즘 - 유니온 파인드(Union-Find)


 

 

다음과 같이 A와 B가 연결된 상태에서

 

 

C가 추가되어서 A랑 연결되었다고 하면

 

결론적으로 C,B도 같이 연결된 상태가 된다.

 

추가된 C가 B랑 연결되었다고 판단하는 방법은 그래프 탐색을 하면 된다.

 

하지만 매번 추가되는 노드마다 그래프탐색을 수행하는 방법은 비효율적이다.

더 효율적인 방법은 각 노드마다 그룹을 정하고 이를 비교하여 같은 그룹에 속한다면 두 노드는 연결된 상태이고

그렇지 않다면 연결되지 않다고 하면 된다.

 

 

위 예시에서는 노드의 루트노드를 그룹으로 정하면 된다.

예를 들어 B의 그룹(루트노드)은 A이고 C의 그룹(루트노드)도 A이므로 C랑 B는 연결되어 있다.

 

 

이때 그룹을 정하고, 두 원소가 같은 그룹인지 판별할 때 쓰이는 알고리즘이 유니온파인드(Union-Find)이다.

 

 

 


유니온 파인드 알고리즘의 기본적인 동작은 두 원소가 속한 집합을 합치는 연산(Union)과 두 원소가 같은 집합에 속하는지 여부를 판단하는 연산(Find)으로 구성된다.

 

Union 연산은 두 원소가 속한 집합을 합쳐서 하나의 집합으로 만드는 것이다. 이 때, 각각의 집합은 트리 형태로 구성되며, 합쳐지는 집합 중 하나의 루트 노드를 다른 집합의 루트 노드의 부모 노드로 설정한다.

 

Find 연산은 두 원소가 같은 집합에 속하는지 여부를 판단하는 것이다. 이 때, 각각의 원소가 속한 집합의 루트 노드를 찾아서 비교하면 된다. 만약 루트 노드가 같으면 두 원소는 같은 집합에 속하는 것이고, 다르면 서로 다른 집합에 속하는 것이다.

 


def initialize(n):
    parent = [i for i in range(n)]
    return parent

def find(x, parent):
    if x != parent[x]:
        find(parent[x], parent)
    return x

def union(x, y, parent):
    xroot = find(x, parent)
    yroot = find(y, parent)
    if xroot != yroot:
    	parent[xroot] = yroot

 

위 코드에서 parent 리스트는 각 원소의 부모 노드를 저장한다. 이때 제일 위에 위치한 루트노드는 자기 자신의 값을 저장한다.

find함수에서 루트노드까지 재귀를 시행하여 x의 루트노드를 찾는다. x == parent[x]가 종료 조건이다.

(루트노드에서 x == parent[x]이므로)

 

union함수에서 x의 루트노드와 y의 루트노드가 같다면 서로 연결해도 차이가 없기 때문에 다를 때만 연결한다.

 

 


 

 

유니온-파인드 알고리즘 최적화 방법

 

1. find 최적화

find연산을 할때 B의 루트노드를 찾기 위해서 최악의 경우 다음과 같이 높이만큼의 시간복잡도가 생긴다.

 

 

 

이 문제를 해결하기 위해서 B의 parent를 부모노드로 설정하지 않고 루트노드로 설정한다.

 

 

-수정코드

def find(x, parent):
    if x != parent[x]:
        parent[x] = find(parent[x], parent)
    return parent[x]

find 함수가 호출될 때마다 parent 배열을 갱신하여 경로 압축을 수행한다. 이로 인해, find 함수를 다시 호출할 때 루트 노드까지의 경로를 따라가지 않아도 된다. 따라서, find 함수의 실행 시간이 크게 개선된다.

 

 

2.union 최적화

 

다음과 같이 A,B를 연결하려고 할 때 두 가지 경우가 있다.

 

두 가지의 상황을 비교해보면

A를 루트노드로 둔 트리의 깊이 = 3

B를 루트노드로 둔 트리의 깊이 = 4

 

find연산을 할 때 높이가 낮아야 연산량이 줄어들기 때문에 깊이가 낮은 연결 방식이 더 효율적이다.

 

따라서 연결하려는 두 개의 노드의 랭크(깊이)를 비교하여

랭크가 더 높은 노드가 부모노드가 된다.(A랭크>B랭크 ) 

만약에 두 개의 랭크가 같다면 아무데나 붙히고 rank += 1을 하면 된다.

 

-수정코드

def initialize(n):
    parent = [i for i in range(n)]
    rank = [0] * n
    return parent, rank
    
    
 def union(x, y, parent, rank):
    xroot = find(x, parent)
    yroot = find(y, parent)

    if xroot == yroot:
        return

    if rank[xroot] < rank[yroot]:
        parent[xroot] = yroot
    elif rank[xroot] > rank[yroot]:
        parent[yroot] = xroot
    else:
        parent[yroot] = xroot
        rank[xroot] += 1

 

노드의 랭크를 저장하는 rank 배열을 추가한다. union 함수에서 두 루트 노드의 랭크를 비교하여 작은 쪽의 랭크를 가진 부모를 큰 쪽으로 연결한다. 이때, 랭크가 같은 경우에는 한쪽의 부모를 변경한 후, 해당 루트 노드의 랭크를 1 증가시켜준다.

 

 

 

 

- union-find 최종 코드

def initialize(n):
    parent = [i for i in range(n)]
    rank = [0] * n
    return parent, rank

def find(x, parent):
    if x != parent[x]:
        parent[x] = find(parent[x], parent)
    return parent[x]

def union(x, y, parent, rank):
    xroot = find(x, parent)
    yroot = find(y, parent)

    if xroot == yroot:
        return

    if rank[xroot] < rank[yroot]:
        parent[xroot] = yroot
    elif rank[xroot] > rank[yroot]:
        parent[yroot] = xroot
    else:
        parent[yroot] = xroot
        rank[xroot] += 1

 

 

- 유니온파인드 기본문제

https://www.acmicpc.net/problem/1717

 

1717번: 집합의 표현

초기에 $n+1$개의 집합 $\{0\}, \{1\}, \{2\}, \dots , \{n\}$이 있다. 여기에 합집합 연산과, 두 원소가 같은 집합에 포함되어 있는지를 확인하는 연산을 수행하려고 한다. 집합을 표현하는 프로그램을 작

www.acmicpc.net

import sys
input = sys.stdin.readline
n,m = [int(i) for i in input().split()]

parent = [i for i in range(n+1)]
rank = [0]*(n+1)

def find(parent, x):
    if parent[x] != x:
        parent[x] = find(parent,parent[x])
    return parent[x]

def union(x,y,parent,rank):
    xroot = find(parent, x)
    yroot = find(parent,y)

    if xroot == yroot:
        return

    if rank[xroot] < rank[yroot]:
        parent[xroot] = yroot
    elif rank[xroot] > rank[yroot]:
        parent[yroot] = xroot
    else:
        parent[yroot] = xroot
        rank[xroot] += 1

for _ in range(m):
    c,a,b = [int(i) for i in input().split()]
    if c == 0:
        union(a,b,parent,rank)
    else:
        aroot = find(parent, a)
        broot = find(parent,b)
        if aroot == broot:
            print("YES")
        else:
            print("NO")