개요
옛날부터 세그트리에 대한 궁금증이 있었는데, 코테의 당락을 결정하지는 않을 알고리즘이라 취업과는 상관없다는 생각에 미뤄두었던 기억이 나서 이번에 알아두려고 한다.
세그먼트 트리는 '여러 개의 데이터가 연속적으로 존재할 때, 특정 범위의 데이터의 합/곱/최대값/최소값 등을 구하는 방법'중 하나이다. 수열을 저장하는 배열 A가 존재한다고 가정했을 때, 다음과 같은 케이스를 생각해보자.
(A) 구간 l, r (l<=r)이 주어졌을 때, A[l] + A[l+1] + ... + A[r-1] + A[r]을 구하는 경우를 M회 반복
이 경우 시간복잡도는 O(N)이다. 누적합, 구간합을 구해놓은 후 특정 구간의 합은 O(1)에 도출할 수 있기 때문이다.
(B) 구간 l, r (l<=r)이 주어졌을 때, A[l] + A[l+1] + ... + A[r-1] + A[r]을 구하고 나서, i번째 수를 v로 바꾸는 경우를 M회 반복
이전 케이스와 다르게, 한번 구해놓은 누적합을 재사용할 수 없다. 따라서 구간합을 사용하게 되면 O(NM)의 시간이 걸린다.
S[0] = A[0];
for(int i=1; i<N; i++) {
S[i] = S[i-1] + A[i];
}
누적합을 구하기 위해 위와 같은 연산을 해 놓게 되는데, 이 때마다 합을 저장한 배열인 S의 값이 변경되기 때문이다. 세그먼트 트리를 사용하면 A[l]부터 A[r]을 구하는 연산과, i번째 수를 v로 바꾸는 경우 각각을 O(lgN)만에 수행할 수 있다.
세그먼트 트리
세그먼트 트리를 사용하면 트리 구조의 특성을 이용하여 O(logN)의 시간복잡도 안에 특정 구간의 데이터의 합, 최대값 등을 빠르게 구할 수 있다.

위 트리는 특정 구간의 데이터 합 정보를 갖는 세그먼트 트리에 저장되는 데이터의 정보를 표시한 그림이다. 리프 노드는 배열의 수 자체를 말하고, 인터널 노드는 배열의 왼쪽 자식과 오른쪽 자식의 합을 저장하고 있다.
따라서 세그먼트 트리의 크기(높이)는 N, 즉 배열의 크기에 따라서 결정된다. 따라서 N이 2의 a배승 형태인 배열에서는 Full Binary Tree(모든 노드가 자식을 0개 또는 2개 가지는 이진 트리)가 생성되고, 리프노드가 N개인 Full Binary Tree는 lg(N)의 높이와 2N-1개의 노드 개수를 갖는다. 만약 N이 2의 배승꼴이 아니라면, ⌈lg(N)⌉의 높이를 가지며, 세그먼트 트리를 만드는데 2^(⌈lg(N)⌉+1)-1 크기의 배열이 필요하다.
Tree 클래스
static class SegTree {
long[] tree;
int size;
public SegTree(int arrSize) {
// 트리 높이 구하기
int h = (int)Math.ceil(Math.log(arrSize)/Math.log(2));
// 높이를 이용한 배열 사이즈 구하기
this.size = (int) Math.pow(2, h+1);
tree = new long[size];
}
}
위와 같이 필드로 배열을 갖는 클래스를 사용할 수 있다. 트리의 높이( ⌈lg(N)⌉)를 구하고, 높이가 h일때의 노드 크기(2^(h+1)-1)만큼 트리 배열 크기 size를 지정한다. (여기서는 인덱스를 1부터 시작하도록 했다)
참고) 트리를 1차원 배열로 표현하는 방법
- 루트 인덱스 = 1
- 왼쪽 자식노드의 인덱스 = 부모 노드의 인덱스 * 2
- 오른쪽 자식노드의 인덱스 = 부모 노드의 인덱스 * 2 + 1
초기화
public long init(long[] arr, int node, int start, int end) {
// 배열과 시작과 끝이 같다면 leaf노드이므로 값 저장
if(start==end) {
return tree[node] = arr[start];
}
// 재귀 - leaf노드가 아니라면 자식노드(왼쪽+오른쪽)의 합 담기
return tree[node] = init(arr, node*2, start, (start+end)/2) + init(arr, node*2+1, (start+end)/2+1, end);
}
이렇게 구간 합 트리를 초기화하기 위해서는 반복과 재귀를 사용할 수 있는데, 재귀를 사용하는 것이 직관적이고 편리하다.
- 파라미터로 받는 arr배열은 데이터를 담은 배열 A를, node는 초기화를 진행할 노드를 말한다. 현재 노드(node)를 기준으로 왼쪽 자식노드(2*node), 오른쪽 자식노드(2*node+1)를 탐색해나가며 재귀적으로 탐색한다.
- start와 end는 원배열에서 구간 합 범위의 시작과 끝 인덱스를 나타낸다. 원 배열 A를 절반씩 나누어 탐색해나가는 병합 정렬과 유사하다고 볼 수 있다. start와 end가 동일하다면 리프 노드에 도달했다고 볼 수 있으므로 원 배열의 값을 트리에 저장해둔다.
구간 합 구하기
public long sum(int node, int start, int end, int left, int right) {
if(left>end || right<start) {
return 0;
}
if(left<=start && end<=right) {
return tree[node];
}
return sum(node*2, start, (start+end)/2, left, right) + sum(node*2+1, (start+end)/2+1, end, left, right);
}
이제 초기화한 트리를 이용하여 구간 합을 구할 수 있다.

- start와 end는 원배열에서 구간 합 범위의 시작과 끝 인덱스를 나타내고, left와 right는 원하는 누적합의 시작과 끝 인덱스를 말한다.
- 모든 내부 노드는 자식으로 갖는 리프 노드의 합을 갖고 있으므로, 자식노드로 내려가며 확인해가다가 현재 배열이 찾고자 하는 범위에서 벗어나면 0을 반환하고, 범위 안에 포함되면 해당 값을 반환한다.
- 위 예시에서 3~5까지의 구간합을 구하기 위해서는 arr[3]~arr[4]의 합을 가지고 있는 5번 노드(tree[5])와 arr[5]의 값을 가지고 있는 12번 노드(tree[5])의 값을 더해서 해결할 수 있다.
값 업데이트하기
public void update(int node, int start, int end, int idx, long diff) {
if(idx<start || end<idx) return;
tree[node] += diff;
if(start!=end) {
update(node*2, start, (start+end)/2, idx, diff);
update(node*2+1, (start+end)/2+1, end, idx, diff);
}
}

값을 업데이트한다는 것은 리프 노드를 업데이트한다는 것이고, 이는 즉 값을 바꿀 리프 노드의 부모 노드들의 값을 모두 바꾸어주면 된다.
- start와 end는 원배열에서 구간 합 범위의 시작과 끝 인덱스를 나타낸다.
- 파라미터로 들어오는 idx는 값을 변경할 원 배열의 인덱스, diff는 원 배열값의 원래 값과 새로운 값의 차이를 말한다.
- 만약 현재 범위 내에 idx가 포함되지 않는다면, 아무것도 할 필요가 없다.
- 그것이 아니라면, 현재 tree의 노드에 diff만큼을 더해주고, 재귀적으로 반복한다.
static class SegTree {
long[] tree;
int size;
public SegTree(int arrSize) {
// 트리 높이 구하기
int h = (int)Math.ceil(Math.log(arrSize)/Math.log(2));
// 높이를 이용한 배열 사이즈 구하기
this.size = (int) Math.pow(2, h+1);
tree = new long[size];
}
public long init(long[] arr, int node, int start, int end) {
// 배열과 시작과 끝이 같다면 leaf노드이므로 값 저장
if(start==end) {
return tree[node] = arr[start];
}
// 재귀 - leaf노드가 아니라면 자식노드(왼쪽+오른쪽)의 합 담기
return tree[node] = init(arr, node*2, start, (start+end)/2) + init(arr, node*2+1, (start+end)/2+1, end);
}
public void update(int node, int start, int end, int idx, long diff) {
if(idx<start || end<idx) return;
tree[node] += diff;
if(start!=end) {
update(node*2, start, (start+end)/2, idx, diff);
update(node*2+1, (start+end)/2+1, end, idx, diff);
}
}
public long sum(int node, int start, int end, int left, int right) {
if(left>end || right<start) {
return 0;
}
if(left<=start && end<=right) {
return tree[node];
}
return sum(node*2, start, (start+end)/2, left, right) + sum(node*2+1, (start+end)/2+1, end, left, right);
}
}
정리
이렇게 세그먼트 트리를 이용하여 특정 구간의 데이터의 합이나 곱, 최대값 등을 O(logN)에 빠르게 구할 수 있는 방법을 알아보았다. 지금까지는 특정 구간의 데이터 합을 구하는 트리를 설명했는데, '부모 노드가 자식 노드의 정보를 모두 포함하도록' 트리를 만들어주면 합 말고 다양한 정보들 또한 빠르게 구할 수 있다.
'[ CS기초 ] > 알고리즘' 카테고리의 다른 글
| [알고리즘] Traveling Salesman Problem (0) | 2024.04.08 |
|---|---|
| [알고리즘] LIS로 알아보는 역추적 기법 (0) | 2024.03.23 |
| [알고리즘] union-find 알고리즘 (0) | 2023.04.04 |
| [알고리즘] 누적합(부분 배열 합) (0) | 2023.02.18 |
| [알고리즘] 플로이드 알고리즘 (0) | 2023.01.09 |