자료구조와 알고리즘 - 세그먼트 트리(Segment Tree)

알고리즘.png


 

1. 구간합과 누적합 문제

1) 구간 합(Range Sum)

 구간 합 문제는 주어진 배열에서 연속된 특정 구간의 합을 계산하는 문제입니다. 예를 들어, arr = [2, 4, 6, 8, 10]라고 할 때

 

- 인덱스 1~3까지의 합 = arr[1] + arr[2] + arr[3] = 4 + 6 + 8 = 18

- 인덱스 0~3까지의 합 = arr[0] + arr[1] + arr[2] + arr[3] = 2 + 4 + 6 + 8 = 20

 

 위와 같은 방식으로 구한 구간합 방식의 시간복잡도는 얼마나 될까요. 배열의 길이를 n이라고 하고, 구하고자 하는 구간합의 범위를 a에서 b까지(a, b ≤ n)라고 할 때, 구간합을 계산하기 위해서는 a, a+1, a+2, ..., b까지의 모든 원소를 더해야 합니다. 이러한 계산은 최악의 경우 n개의 원소를 접근하게 되므로, 시간 복잡도는 O(n)으로 표현할 수 있습니다. 

2) 누적합(Prefix Sum) 배열

누적 합 배열은 배열의 각 원소까지의 합을 미리 계산해 놓는 방법입니다. 이를 통해 구간 합문제를 더 빠르게 계산할 수 있습니다. 예를 들어,

 

arr = [2, 4, 6, 8, 10]

 

- 0~0까지의 합: prefix = [2]

- 0~1까지의 합: prefix = [2, 6]

- 0~2까지의 합: prefix = [2, 6, 12]

- 0~3까지의 합: prefix = [2, 6 ,12, 20]

- 0~4까지의 합: prefix = [2, 6, 12, 20, 30]

- 누적합 배열 = [2, 6, 12, 20, 30]

 

이를 토대로 1 부터 3 까지의 구간 합을 구하면

 

구하고자 하는 값 = arr[1] + arr[2] + arr[3]

prefix[3] = arr[0] + arr[1] + arr[2] + arr[3]

prefix[0] = arr[0]

=> prefix[3] - prefix[0] = 18로 구간 합을 매번 새로 구하는 것보다 빠른 시간 안에 구간합을 계산할 수 있습니다.

 

 처음 누적합 배열을 만드는 과정은 O(n)이지만 한번 구해놓으면 이후부터는 O(1) 시간안에 구간합 문제를 해결 할 수 있습니다.

- 누적합 방식의 한계

 누적합 방식은 효율적으로 특정 구간의 합을 계산할 수 있지만, 데이터 업데이트 시에는 한계가 있습니다. 배열의 길이가 n일 경우, 원래 배열의 어떠한 원소가 변경되면 누적합 배열을 매번 새로 계산해야 하므로 업데이트에 O(n) 시간이 소요됩니다.

 

 예를 들어, 배열이 [2, 4, 6, 8, 10]일 때 누적합 배열은 [2, 6, 12, 20, 30]으로 구성됩니다. 만약 배열의 인덱스 1의 값을 5로 변경하면, 새로운 누적합 배열은 [2, 7, 13, 21, 31]로 다시 계산해야 합니다. 데이터 업데이트가 반복되면 매번 누적합 배열을 갱신해야 하므로, 이 접근 방식은 업데이트 효율성 측면에서 부족합니다.

 

 이런 한계를 해결하기 위한 방법 중 하나가 세그먼트 트리입니다. 세그먼트 트리를 사용하면 데이터 업데이트 시 보다 효율적으로 구간합을 관리할 수 있습니다.

 

2. 세그먼트 트리(Segment Tree)

 세그먼트 트리는 구간 합을 구하는 과정과 업데이트를 O(log n) 시간에 효율적으로 처리할 수 있는 자료구조입니다. 이는 세그먼트 트리가 완전 이진 트리 구조를 기반으로 하여, 각 노드가 특정 구간의 합을 저장하기 때문입니다.

 

1) 세그먼트 트리 개념

세그먼트 트리를 구현하기 전에 구간합과 업데이트 과정이 왜 O(log n)인지 보겠습니다.

 

- 사전 준비

1. 배열 [11, 4, 3, 5, 7, 6, 2, 9] 에서 2개 씩 짝을 지은 새로운 배열을 만듭니다.

세그1.png

 

2. 이 과정을 다음의 형태가 될 때까지 반복합니다.

세그2.png

 

- 구간 합 구하기

위 데이터 구조를 이용하여 배열 [11, 4, 3, 5, 7, 6, 2, 9]의 구간 (2~6)까지의 구간 합을 구하는 과정을 보겠습니다.

 

1. 0~7(전체)까지의 범위인 맨 아래부터 탐색을 합니다.

죄회1.png

 

 

2. 탐색을 위로 올립니다. 구간이 두 개로 나뉩니다. (0~3), (4~7)

- (2~3)범위는 (0~3)구간에 속하고

- (4~6)범위는 (4~7)구간에 속합니다.

죄회2.png

 

 

3. 탐색을 위로 올립니다. 구간이 4개로 나뉩니다.

- (2~3)범위는 (2~3)구간에 속합니다. => 탐색을 종료합니다.

- (4~5)범위는 (4~5)구간에 속합니다. => 탐색을 종료합니다.

- (6)범위는 (6~7)구간에 속합니다.

죄회3.png

 

 

4. 탐색을 위로 올립니다. (6)범위는 (6)구간에 속합니다. 최종적으로 (2~6)범위는 빨간색 영역의 합입니다.

죄회4.png

 

 위 그림을 뒤집어서 보면 완전 이진 트리 구조로 구성되어 있고 트리의 깊이는 (log n)입니다.(여기서는 n = 8이므로 깊이 = 3) 

 

 구간합을 구하기 위해서 자기가 속한 범위의 노드들만 탐색합니다. 이를 통해 원하는 구간의 합을 계산하기 위해서는 자기의 상위 노드들만 방문하면 되므로, 전체 트리를 탐색할 필요가 없습니다. 따라서, 특정 구간의 합을 구하는 과정은 O(log n)의 시간 복잡도입니다.

 

 

- 업데이트 과정

데이터를 업데이트할 때 자기가 속한 범위안에서만 수정하면 됩니다.

업데이트.png

 

 구간 합을 계산하는 것과 마찬가지로, 세그먼트 트리에서 업데이트를 진행할 때도 해당 노드가 속한 범위 안에서 탐색을 수행합니다. 이 과정은 트리의 깊이인 O(log n) 만큼의 시간 복잡도가 소요됩니다. 따라서, 업데이트 과정 역시 효율적으로 처리되어 O(log n)의 시간 복잡도를 가집니다. 하지만 그 만큼 용량도 추가되고 공간복잡도는 O(n)입니다. 

 

2) 세그먼트 트리 구현 방법

 

완전이진트리를 이용하여 세그먼트트리를 Top-Down방식으로 구현하는 방법입니다.

 

- 초기 설정

class SegmentTree:
    def __init__(self, data):
        self.n = len(data)
        self.tree = [0] * (4 * self.n)  # 세그먼트 트리를 저장할 리스트

 

1. self.n = len(data)

- 입력으로 주어진 배열 `data`의 길이를 `self.n`에 저장합니다. 이 값은 세그먼트 트리를 구성하는 데 사용됩니다.

 

2. self.tree = [0] * (4 * self.n)

-  트리의 용량을 결정합니다. 데이터의 개수가 50이면 2의 제곱수인 64지만 최대 4배 큰 배열로 할당됩니다. 이는 트리가 완전 이진 트리 형태를 갖기 때문에, 안전하게 모든 노드를 저장하기 위함입니다. .

 

 

- 초기 빌드과정

def build(self, data, node, start, end):
        if start == end:
            # leaf 노드인 경우
            self.tree[node] = data[start]
        else:
            mid = (start + end) // 2
            # 왼쪽 자식 노드와 오른쪽 자식 노드 구성
            self.build(data, 2 * node + 1, start, mid)
            self.build(data, 2 * node + 2, mid + 1, end)
            # 부모 노드의 값을 자식 노드 값으로 갱신
            self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]

 

1. 매개변수 설명

- data: 원래 배열. 이 배열의 값들을 세그먼트 트리에서 사용합니다.

- node: 현재 노드를 나타내는 인덱스입니다.

- start, end: 현재 노드가 담당하는 구간의 시작 인덱스와 끝 인덱스입니다. 

 

2. 기저 조건 (Leaf 노드 확인)

if start == end:
	self.tree[node] = data[start]

- 우선, 현재 노드가 리프 노드인지 확인합니다. 리프 노드는 `start`와 `end`가 같을 때 발생합니다.

- 이 경우, 해당 노드가 가리키는 값은 입력 배열 `data`의 `start` 위치에 있는 값이므로 이 값을 `self.tree[node]`에 저장합니다.

 

3. 재귀적 노드 분할

mid = (start + end) // 2
self.build(data, 2 * node + 1, start, mid) # 왼쪽자식노드
self.build(data, 2 * node + 2, mid + 1, end) # 오른쪽 자식노드

- 현재 노드가 리프 노드가 아니면, 배열의 구간을 반으로 나누어 왼쪽과 오른쪽 서브트리를 구성합니다.

- `mid` 변수를 이용하여 구간을 반으로 나누고, 왼쪽 서브트리를 재귀호출하여 구합니다.

- 오른쪽 서브트리도 재귀호출하여 구합니다.

 

4. 부모 노드 업데이트

self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]

- 왼쪽과 오른쪽 자식 노드의 값(리프 노드 값들 또는 재귀호출로 구해진 값을) 합쳐서 현재 부모 노드의 값을 업데이트합니다.

- 이렇게 함으로써 각 노드는 그 아래 노드들이 담당하는 구간의 합을 저장하게 됩니다.

 

 

5. 전체 흐름

- 기본적으로 먼저 리프 노드를 설정하고, 그 후 리프 노드들을 기반으로 부모 노드의 값을 계산하여 트리를 완성합니다.

- 최종적으로 세그먼트 트리의 각 노드는 특정 구간의 합을 저장합니다.

- 이 방식으로 세그먼트 트리를 구축하면, 배열의 모든 원소를 다루는 O(n) 시간 복잡도로 초기화가 가능합니다.

 

 

- 업데이트 과정

def update(self, idx, value, node, start, end):
        if start == end:
            # 리프 노드에서 값 수정
            self.tree[node] = value
        else:
            mid = (start + end) // 2
            if start <= idx <= mid:
                # 왼쪽 서브트리 업데이트
                self.update(idx, value, 2 * node + 1, start, mid)
            else:
                # 오른쪽 서브트리 업데이트
                self.update(idx, value, 2 * node + 2, mid + 1, end)
            # 부모 노드의 값을 자식 노드 값으로 갱신
            self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]

 update 메소드는 세그먼트 트리에서 특정 인덱스의 값을 업데이트하는 역할을 하는 재귀 함수입니다. 이 메소드는 주어진 인덱스의 값을 새로운 값으로 수정하고, 그에 따라 영향을 받는 부모 노드의 값도 갱신합니다. 각 부분을 자세히 설명하겠습니다.

 

1. 매개변수 설명

- idx: 업데이트할 배열의 인덱스입니다. 이 값을 새로운 값으로 변경합니다.

- value: `idx` 위치에 업데이트할 새로운 값입니다.

- node: 현재 노드를 나타내는 인덱스입니다.

- start, end: 현재 노드가 담당하는 구간의 시작과 끝 인덱스입니다.

 

2. 기저 조건 (리프 노드 확인)

if start == end:
	self.tree[node] = value

- 여기서 현재 노드가 리프 노드인지 확인합니다. `start`와 `end`가 같다면, 이 노드는 리프 노드입니다.

- 이 경우, 해당 인덱스의 값을 새로운 값으로 업데이트합니다. 즉, 세그먼트 트리의 해당 노드에 `value`가 저장됩니다.

 

2. 재귀적 노드 분할

mid = (start + end) // 2
if start <= idx <= mid:
	self.update(idx, value, 2 * node + 1, start, mid)
else:
	self.update(idx, value, 2 * node + 2, mid + 1, end)

- 만약 현재 노드가 리프가 아니라면, `mid`를 계산하여 구간을 반으로 나눕니다.

- `idx`가 현재 노드의 왼쪽 자식이 담당하는 구간에 속하면 `self.update`를 호출하여 왼쪽 서브트리(부모 노드의 왼쪽 자식)를 업데이트합니다.

- 그렇지 않으면 (즉, `idx`가 오른쪽 자식에 해당) 오른쪽 서브트리를 업데이트합니다.

 

3. 부모 노드 업데이트

self.tree[node] = self.tree[2 * node + 1] + self.tree[2 * node + 2]

- 왼쪽과 오른쪽 서브트리에서 업데이트가 끝나면, 현재 부모 노드의 값을 갱신합니다.

- 이때 부모 노드는 자식 노드의 합으로 업데이트되어야 하므로, 왼쪽 자식의 값과 오른쪽 자식의 값을 더하여 저장합니다.

 

4. 전체 흐름

- 가장 상위 노드에서부터 시작해 변경될 인덱스까지 내려가고, 실행 결과를 상위 노드에 전파합니다.

- 이 과정은 O(log n) 시간 복잡도로 수행됩니다.

 

-쿼리 과정

def query(self, L, R, node, start, end):
        if R < start or end < L:
            # 구간이 겹치지 않음
            return 0
        if L <= start and end <= R:
            # 구간이 완전히 포함됨
            return self.tree[node]
        # 겹치는 구간에서 왼쪽과 오른쪽 서브트리로 나누어 쿼리 수행
        mid = (start + end) // 2
        left_sum = self.query(L, R, 2 * node + 1, start, mid)
        right_sum = self.query(L, R, 2 * node + 2, mid + 1, end)
        return left_sum + right_sum

 

`query` 메소드는 세그먼트 트리에서 특정 구간 `[L, R]`의 합을 계산하는 함수입니다. 이 메소드는 세그먼트 트리를 통해 원하는 구간의 합을 구하는 데 사용됩니다. 각 부분을 자세히 설명하겠습니다.

 

1. 매개변수 설명

- L: 쿼리할 구간의 왼쪽 경계 인덱스입니다.

- R: 쿼리할 구간의 오른쪽 경계 인덱스입니다.

- node: 현재 노드를 나타내는 인덱스입니다.

- start, end: 현재 노드가 담당하는 구간의 시작과 끝 인덱스입니다.

 

2. 구간이 겹치지 않는 경우

if R < start or end < L:
	return 0

- 만약 주어진 쿼리 구간 [L, R]가 현재 노드가 담당하는 구간 [start, end]와 겹치지 않는 경우입니다.

- 이 조건은 두 인덱스의 비교를 통해 판별합니다.

- 이때는 관련된 값을 찾을 수 없기 때문에 0을 반환합니다. (합이 없으므로 0으로 처리됩니다.)

 

3. 구간이 완전히 포함되는 경우

if L <= start and end <= R:
	return self.tree[node]

- 만약 현재 노드가 담당하는 구간이 쿼리 구간 [L, R]에 완전히 포함되는 경우입니다.

- 이 경우, 해당 노드의 값(구간의 합)을 그대로 반환합니다. 이미 계산된 합이 저장되어 있으므로, 추가적인 계산이 필요 없습니다.

 

4. 구간이 일부 겹치는 경우

mid = (start + end) // 2
left_sum = self.query(L, R, 2 * node + 1, start, mid)
right_sum = self.query(L, R, 2 * node + 2, mid + 1, end)
return left_sum + right_sum

- 현재 노드의 구간이 쿼리 구간과 일부 겹치는 경우에는 재귀적으로 왼쪽과 오른쪽 서브트리로 나누어 쿼리를 수행합니다.

- 먼저 mid를 계산하여 현재 노드를 두 개의 자식 노드로 나눕니다.

- `self.query(L, R, 2 * node + 1, start, mid)`를 통해 왼쪽 자식의 구간에서 합을 구하고, 이어서 `self.query(L, R, 2 * node + 2, mid + 1, end)`를 통해 오른쪽 자식의 구간에서 합을 구합니다.

- 최종적으로 두 결과인 `left_sum`과 `right_sum`을 더하여 반환합니다.

 

5. 전체 흐름

- 주어진 구간 [L, R]의 합을 재귀적으로 탐색하여 구합니다.

- 케이스를 '구간 밖인 경우', '구간에 완전 포함하는 경우', '구간에 일부 포함하는 경우'를 분기하여 처리합니다.

- 최악의 경우에도 O(log n) 시간 복잡도로 쿼리를 처리할 수 있습니다.

 

-관련문제

https://www.acmicpc.net/problem/2042