티스토리 뷰

ACM 예선에서 트리에서의 두 정점간의 최단거리를 구하는 방법을 몰라서 충격받고 바로 그 방법인 LCA를 공부했다. 공부하고보니 그렇게 어려운 알고리즘도 아니어서 더욱 아까웠다.

LCA 알고리즘은 말 그대로 트리에서 두 정점간의 가장 깊은(루트에서 먼) 부모 노드를 찾아내는 알고리즘이다. 그 부모 노드를 찾으면 거리 또한 자연스럽게 알아낼 수 있다. 일단 naive하게 두 노드가 서로 만날 때까지 위로 한칸 한칸씩만 올라가면, 트리가 일직선으로 쭉 뻗어있는 모형에서 시간이 O(N)이 소요된다(트리의 노드의 개수가 N개). 따라서 11438-LCA2 같이 노드가 10만개, query가 10만개가 들어올 경우 시간이 터지게 된다. 따라서 O(N)보다 더 빠르게 찾아내야 하는데 실제로 O(logN)만에 가능하다. 이제부터 그 방법을 설명하겠다. O(logN)이 가능한 이유는 크게 두 가지 전제가 있기 때문이다.


1. LCA(a,b)=m이면, m을 기준으로 그 밑은 절대로 LCA가 존재하지 않고, 그 위는 반드시 LCA가 존재한다.

2. 모든 정수꼴은 이진수의 형태로 표현이 가능하다.


1번 전제를 통해 우리는 이분 탐색의 아이디어를 차용할 수 있다. 이분 탐색이 가능하기 위해서는 기준점의 한 쪽은 반드시 정답, 다른 한 쪽은 반드시 오답이 되기 때문인데 이 말이 곧 1번 전제와 같은 뜻이다. 그리고 2번 전제는 밑에서 설명하겠다.

먼저 LCA알고리즘을 수행하기 위해서는 전처리가 필요한데, 채워야할 정보는 1. 노드의 깊이를 저장하는 'depth[]배열, 2.par[a][i]=현재 노드 a에서 2^i만큼 위의 조상' 이다. par[][]배열을 구체적으로 설명하자면 par[a][0]=a 노드의 2^0(=1)만큼의 조상노드(=바로 위의 부모노드)이다. 따라서 i의 값이 증가할수록 기하급수적으로 2^i가 커져서 얕은 깊이의 노드라면 다 못 채워질수도 있을텐데, 그러면 그냥 0으로 놔두면 된다. 이 두 배열을 채우기 위해서 dfs를 순회할 것이다. 그 소스코드는 다음과 같다.

1
2
3
4
5
6
7
8
9
10
11
void dfs(int node, int parent) {
    depth[node] = depth[parent] + 1;
    par[node][0= parent;
 
    for (int i = 1; i <= max_level; i++) {
        par[node][i] = par[par[node][i - 1]][i - 1];
    }
    for (int next : tree[node])
        if(next!=parent) dfs(next, node);
    return;
}
cs

위에서 par[][]배열의 중요한 점화식이 등장했다. par[node][i]=par[par[node][i-1]][i-1]인데 이게 무슨 뜻이냐면, i가 커지면서 node의 2^i번째 조상을 찾으러 수고롭게 위로 올라갈 필요가 없이, 바로 직전인 2^(i-1)번째 조상의 2^(i-1)번째 조상이 그것과 같다는 뜻이다. 일종의 dp의 메모이제이션 기법이라고 할 수 있다. 그리고 for문에서 등장하는 변수 max_level은 쉽게 얘기해서 log2(MaxOfN)의 upper bound로 설정하면 된다. 이 문제에서는 N=10만이므로 log2(10만)을 넉넉잡아서 20이면 충분할 것이다.


여기까지가 전처리 과정이었고, 이제 본격적으로 query를 받아서 LCA를 찾아내는 과정을 볼 것이다. 그 과정도 두 개의 과정으로 나뉘는데 첫 번째는 query로 들어온 두 노드(a,b라고 하자)의 깊이를 일치시켜 주는 것이다. 즉, 높이가 더 깊은 노드를 얕은 노드의 깊이까지 끌어 올리는 것이다. 여기서 '트리'라는 그래프는 노드~노드로 가는 경로가 unique하기 때문에 노드를 끌어 올린다고해서, LCA에 영향을 끼치지 않는다. 편의상 깊이가 더 깊은 쪽을 b노드 쪽이라고 하고, 만약 a가 더 깊다면 b와 swap해주도록 하자. 그리고 두 번째로는 깊이가 일치된 두 노드가 서로 LCA가 일치할 때까지 위로 쭉쭉 올리는 것이다. 다만 여기서 중요한 점은 2의 멱수만큼 위로 올리는 것이다. 위에서 언급했다시피 1개씩만 올라가면 시간이 터질 수도 있기 때문이다. 따라서 큼직하게 2의 멱수만큼 위로 올라가는 것인데, 이렇게 했을 때의 문제점은 빈틈이 생길 수도 있다는 점이다. 1->2->4->8...이렇게 올라가는 과정에서 2->4에서는 3이 비고, 4->8에서는 5,6,7이 비는 것처럼 말이다. 이럴 때는 다시 멱수의 scale을 줄여서 꼭 맞을 때까지 줄여나가면 된다(예를 들어 11이라는 수를 8+2+1로 표현 가능한 것처럼 말이다). 이것이 위에서 얘기한 전제 2번. '모든 정수꼴은 이진수의 형태로 표현이 가능하다.'라는 말과 일맥상통하다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
int LCA(int u, int v) {
    if (depth[u] != depth[v]) {
        if (depth[u] > depth[v]) swap(u, v);
 
        for (int i = max_level; i >= 0; i--) {
            if (depth[u] <= depth[par[v][i]])
                v = par[v][i];
        }
    }
    int lca = u;
    if (u != v) {
        for (int i = max_level; i >= 0; i--) {
            if (par[u][i] != par[v][i]) {
                u = par[u][i];    v = par[v][i];
            }
        }
        lca = par[u][0];
    }
    return lca;
}
cs

LCA의 코드와 논리를 살펴보니 2의 멱수만큼 구간들을 큼직하게 저장해놓고 정수꼴을 표현하는게 마치 'Fenwick Tree' 자료구조와도 흡사하다.


11438-LCA2 전체코드는 다음과 같다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include<stdio.h>
#include<vector>
#include<algorithm>
using namespace std;
#define SIZE 100009
vector<int> tree[SIZE];
int n, q, max_level = 20;
int depth[SIZE], par[SIZE][21];
void dfs(int node, int parent) {
    depth[node] = depth[parent] + 1;
    par[node][0= parent;
 
    for (int i = 1; i <= max_level; i++) {
        par[node][i] = par[par[node][i - 1]][i - 1];
    }
    for (int next : tree[node])
        if(next!=parent) dfs(next, node);
    return;
}
int LCA(int u, int v) {
    if (depth[u] != depth[v]) {
        if (depth[u] > depth[v]) swap(u, v);
 
        for (int i = max_level; i >= 0; i--) {
            if (depth[u] <= depth[par[v][i]])
                v = par[v][i];
        }
    }
    int lca = u;
    if (u != v) {
        for (int i = max_level; i >= 0; i--) {
            if (par[u][i] != par[v][i]) {
                u = par[u][i];    v = par[v][i];
            }
        }
        lca = par[u][0];
    }
    return lca;
}
int main() {
    scanf("%d"&n);
 
    for (int i = 0; i < n - 1; i++) {
        int a, b;    scanf("%d %d"&a, &b);
        tree[a].push_back(b);    tree[b].push_back(a);
    }
 
    scanf("%d"&q);
    dfs(10);
    while (q--) {
        int a, b;    scanf("%d %d"&a, &b);
        printf("%d\n", LCA(a,b));
    }
 
    return 0;
}
 
cs


LCA알고리즘을 배우면서 2의 멱수(Binary적인 테크닉?)은 시간을 logN으로 줄일 수 있는 효과적인 방법이란 것을 깨달았다.

'Problem Solving > LCA' 카테고리의 다른 글

1761-정점들의 거리(LCA)  (0) 2018.12.02
댓글