[알고리즘] 세그먼트 트리 (Segment Tree)

개요

취직해서 지금 당장 좋은 점 중 하나는 취직 전까지는 내가 하고 싶은 공부를 할 수 있다는 것이기도 하다.

사실 옛날부터 세그트리에 대한 궁금증이 있었는데, 코테의 당락을 결정하지는 않을 알고리즘이라 미뤄두었던 기억이 나서 이번에 알아두려고 한다.

 

 

세그먼트 트리는 '여러 개의 데이터가 연속적으로 존재할 때, 특정 범위의 데이터의 합/곱/최대값/최소값 등을 구하는 방법'중 하나이다. 수열을 저장하는 배열 A가 존재한다고 가정했을 때, 다음과 같은 케이스를 생각해보자.

 

(A) 구간 l, r (l<=r)이 주어졌을 때, A[l] + A[l+1] + ... + A[r-1] + A[r]을 구하는 경우를 M회 반복

 

 이 경우 시간복잡도는 O(N)이다. 누적합, 구간합을 구해놓은 후 특정 구간의 합은 O(1)에 도출할 수 있기 때문이다.

 

 

(B) 구간 l, r (l<=r)이 주어졌을 때, A[l] + A[l+1] + ... + A[r-1] + A[r]을 구하고 나서, i번째 수를 v로 바꾸는 경우를 M회 반복

 

 이전 케이스와 다르게, 한번 구해놓은 누적합을 재사용할 수 없다. 따라서 구간합을 사용하게 되면 O(NM)의 시간이 걸린다. 

S[0] = A[0];
for(int i=1; i<N; i++) {
    S[i] = S[i-1] + A[i];
}

 

누적합을 구하기 위해 위와 같은 연산을 해 놓게 되는데, 이 때마다 합을 저장한 배열인 S의 값이 변경되기 때문이다. 세그먼트 트리를 사용하면 A[l]부터 A[r]을 구하는 연산과, i번째 수를 v로 바꾸는 경우 각각을 O(lgN)만에 수행할 수 있다.

 

 

 

세그먼트 트리

세그먼트 트리를 사용하면 트리 구조의 특성을 이용하여 O(logN)의 시간복잡도 안에 특정 구간의 데이터의 합, 최대값 등을 빠르게 구할 수 있다.

 

위 트리는 특정 구간의 데이터 합 정보를 갖는 세그먼트 트리에 저장되는 데이터의 정보를 표시한 그림이다. 리프 노드는 배열의 수 자체를 말하고, 인터널 노드는 배열의 왼쪽 자식과 오른쪽 자식의 합을 저장하고 있다. 

 

따라서 세그먼트 트리의 크기(높이)는 N, 즉 배열의 크기에 따라서 결정된다. 따라서 N이 2의 a배승 형태인 배열에서는 Full Binary Tree(모든 노드가 자식을 0개 또는 2개 가지는 이진 트리)가 생성되고, 리프노드가 N개인 Full Binary Tree는 lg(N)의 높이와 2N-1개의 노드 개수를 갖는다. 만약 N이 2의 배승꼴이 아니라면, lg(N)의 높이를 가지며, 세그먼트 트리를 만드는데 2^(⌈lg(N)⌉+1)-1 크기의 배열이 필요하다.

 

Tree 클래스

static class SegTree {
    long[] tree;
    int size;
    
    public SegTree(int arrSize) {
        // 트리 높이 구하기
        int h = (int)Math.ceil(Math.log(arrSize)/Math.log(2));
        // 높이를 이용한 배열 사이즈 구하기
        this.size = (int) Math.pow(2, h+1);
        tree = new long[size];
    }
}

 

위와 같이 필드로 배열을 갖는 클래스를 사용할 수 있다. 트리의 높이( lg(N)⌉)를 구하고, 높이가 h일때의 노드 크기(2^(h+1)-1)만큼 트리 배열 크기 size를 지정한다. (여기서는 인덱스를 1부터 시작하도록 했다)

 

참고) 트리를 1차원 배열로 표현하는 방법
- 루트 인덱스 = 1
- 왼쪽 자식노드의 인덱스 = 부모 노드의 인덱스 * 2
- 오른쪽 자식노드의 인덱스 = 부모 노드의 인덱스 * 2 + 1

 

 

초기화

public long init(long[] arr, int node, int start, int end) {
    // 배열과 시작과 끝이 같다면 leaf노드이므로 값 저장
    if(start==end) {
        return tree[node] = arr[start];
    }

    // 재귀 - leaf노드가 아니라면 자식노드(왼쪽+오른쪽)의 합 담기
    return tree[node] = init(arr, node*2, start, (start+end)/2) + init(arr, node*2+1, (start+end)/2+1, end);
}

 

이렇게 구간 합 트리를 초기화하기 위해서는 반복과 재귀를 사용할 수 있는데, 재귀를 사용하는 것이 직관적이고 편리하다.

 

- 파라미터로 받는 arr배열은 데이터를 담은 배열 A를, node는 초기화를 진행할 노드를 말한다. 현재 노드(node)를 기준으로 왼쪽 자식노드(2*node), 오른쪽 자식노드(2*node+1)를 탐색해나가며 재귀적으로 탐색한다.

- start와 end는 원배열에서 구간 합 범위의 시작과 끝 인덱스를 나타낸다. 원 배열 A를 절반씩 나누어 탐색해나가는 병합 정렬과 유사하다고 볼 수 있다. start와 end가 동일하다면 리프 노드에 도달했다고 볼 수 있으므로 원 배열의 값을 트리에 저장해둔다.

 

 

구간 합 구하기

public long sum(int node, int start, int end, int left, int right) {
    if(left>end || right<start) {
        return 0;
    }

    if(left<=start && end<=right) {
        return tree[node];
    }

    return sum(node*2, start, (start+end)/2, left, right) + sum(node*2+1, (start+end)/2+1, end, left, right);
}

 

이제 초기화한 트리를 이용하여 구간 합을 구할 수 있다. 

 

https://cano721.tistory.com/38

 

- start와 end는 원배열에서 구간 합 범위의 시작과 끝 인덱스를 나타내고, left와 right는 원하는 누적합의 시작과 끝 인덱스를 말한다.

- 모든 내부 노드는 자식으로 갖는 리프 노드의 합을 갖고 있으므로, 자식노드로 내려가며 확인해가다가 현재 배열이 찾고자 하는 범위에서 벗어나면 0을 반환하고, 범위 안에 포함되면 해당 값을 반환한다.

- 위 예시에서 3~5까지의 구간합을 구하기 위해서는 arr[3]~arr[4]의 합을 가지고 있는 5번 노드(tree[5])와 arr[5]의 값을 가지고 있는 12번 노드(tree[5])의 값을 더해서 해결할 수 있다.

 

 

 

값 업데이트하기

public void update(int node, int start, int end, int idx, long diff) {
    if(idx<start || end<idx) return;

    tree[node] += diff;

    if(start!=end) {
        update(node*2, start, (start+end)/2, idx, diff);
        update(node*2+1, (start+end)/2+1, end, idx, diff);
    }
}

 

https://cano721.tistory.com/38

 

값을 업데이트한다는 것은 리프 노드를 업데이트한다는 것이고, 이는 즉 값을 바꿀 리프 노드의 부모 노드들의 값을 모두 바꾸어주면 된다.

 

- start와 end는 원배열에서 구간 합 범위의 시작과 끝 인덱스를 나타낸다.

- 파라미터로 들어오는 idx는 값을 변경할 원 배열의 인덱스, diff는 원 배열값의 원래 값과 새로운 값의 차이를 말한다.

- 만약 현재 범위 내에 idx가 포함되지 않는다면, 아무것도 할 필요가 없다.

- 그것이 아니라면, 현재 tree의 노드에 diff만큼을 더해주고, 재귀적으로 반복한다.

 

 

static class SegTree {
    long[] tree;
    int size;

    public SegTree(int arrSize) {
        // 트리 높이 구하기
        int h = (int)Math.ceil(Math.log(arrSize)/Math.log(2));
        // 높이를 이용한 배열 사이즈 구하기
        this.size = (int) Math.pow(2, h+1);
        tree = new long[size];
    }

    public long init(long[] arr, int node, int start, int end) {
        // 배열과 시작과 끝이 같다면 leaf노드이므로 값 저장
        if(start==end) {
            return tree[node] = arr[start];
        }

        // 재귀 - leaf노드가 아니라면 자식노드(왼쪽+오른쪽)의 합 담기
        return tree[node] = init(arr, node*2, start, (start+end)/2) + init(arr, node*2+1, (start+end)/2+1, end);
    }

    public void update(int node, int start, int end, int idx, long diff) {
        if(idx<start || end<idx) return;

        tree[node] += diff;

        if(start!=end) {
            update(node*2, start, (start+end)/2, idx, diff);
            update(node*2+1, (start+end)/2+1, end, idx, diff);
        }
    }

    public long sum(int node, int start, int end, int left, int right) {
        if(left>end || right<start) {
            return 0;
        }

        if(left<=start && end<=right) {
            return tree[node];
        }

        return sum(node*2, start, (start+end)/2, left, right) + sum(node*2+1, (start+end)/2+1, end, left, right);
    }
}

 

 

정리

이렇게 세그먼트 트리를 이용하여 특정 구간의 데이터의 합이나 곱, 최대값 등을 O(logN)에 빠르게 구할 수 있는 방법을 알아보았다. 지금까지는 특정 구간의 데이터 합을 구하는 트리를 설명했는데, '부모 노드가 자식 노드의 정보를 모두 포함하도록' 트리를 만들어주면 합 말고 다양한 정보들 또한 빠르게 구할 수 있다.