티스토리 뷰
가중치가 주어진 트리에서의 최단거리를 구하는 알고리즘이다. 물론 트리도 그래프의 일종이므로 최단거리 알고리즘인 다익스트라나 SPFA알고리즘을 사용할 수 있겠으나, 이 문제는 쿼리가 많이 들어오기 때문에 트리 전용 최단거리 알고리즘이 필요하다. 그 알고리즘이 바로 LCA알고리즘을 조금만 응용시키면 두 정점사이의 최단거리를 O(logN)만에 구할 수 있다. 참고로 이 알고리즘은 오직 트리에서만 가능한데, 특별히 이 알고리즘이 더 빠른 이유는 트리의 특징중에 하나인 노드~노드로 갈 수 있는 경로가 unique하기 때문이다.
LCA알고리즘을 어떻게 응용하느냐가 관건이다. 뭐 복잡하지도 않고 단순히 생각해보면 정점 A~B의 최단거리를 구하려면 dist(A~LCA(A,B)+dist(B~LCA(A,B)'가 될 것이다. 설령 LCA(A,B)는 O(logN)만에 빠르게 구했다 할지라도, 그 간선의 가중치의 합을 구하는데 O(N)이 걸리면 말짱 도루묵이다. 따라서 간선의 가중치의 정보(dist)를 어떻게 빨리 더하고, 저장하느냐도 중요한데, 그 정답은 당연하게도 par배열을 저장했을 때처럼 2의 멱수 단위로 나눠서 저장하면 된다. 그래서 par[][]배열처럼 'dist[node][i]=현재 node에서 2^i번째 조상까지의 거리의 합' 이라고 정의를 하자. 그리고 똑같이 점화식을 구해야 하는데 이 역시도 par배열을 구했을 때처럼 하면 된다. 그 점화식은
dist[node][i] = dist[node][i-1](X) + dist[par[node][i-1]][i-1](Y) 이다.
X부분은 '현재 노드~2^(i-1)조상까지의 거리'이고, Y부분은 '2^(i-1)조상 노드~그 조상의 2^(i-1)조상까지의 거리'이다. 이렇게 dist배열을 만들어주고 난 이후는 LCA를 구하는 과정과 동일하다. 다만, 내가 관심있는 것은 최단 거리이니 LCA를 저장해놓을 필요는 없고 그냥 LCA를 만날 때까지의 거리만 계속 더하면 된다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69 |
#include<stdio.h>
#include<vector>
#include<algorithm>
#define SIZE 40009
using namespace std;
int n, q, max_level=20;
int par[SIZE][21], depth[SIZE], dist[SIZE][21];
vector<pair<int, int>> tree[SIZE];
void dfs(int node, int parent, int d) {
depth[node] = depth[parent] + 1;
par[node][0] = parent;
dist[node][0] = d;
for (int i = 1; i <= max_level; i++) {
par[node][i] = par[par[node][i - 1]][i - 1];
dist[node][i] = dist[node][i - 1] + dist[par[node][i - 1]][i - 1];
}
for (auto nx : tree[node]) {
int next = nx.first, nextCost = nx.second;
if (next != parent) dfs(next, node, nextCost);
}
return;
}
int LCA_Dist(int a, int b) {
int result = 0;
if (a != b) {
if (depth[a] > depth[b]) swap(a, b);
for (int i = max_level; i >= 0; i--) {
if (depth[a] <= depth[par[b][i]]) {
result += dist[b][i];
b = par[b][i];
}
}
}
if (a != b) {
for (int i = max_level; i >= 0; i--) {
if (par[a][i] != par[b][i]) {
result += (dist[a][i] + dist[b][i]);
a = par[a][i]; b = par[b][i];
}
}
result += (dist[a][0] + dist[b][0]);
//a와 b노드는 LCA 노드의 바로 밑임. 따라서 최종적으로 한번 더 더해줘야함
}
return result;
}
int main() {
scanf("%d", &n);
for (int i = 0; i < n - 1; i++) {
int a, b, c; scanf("%d %d %d", &a, &b, &c);
tree[a].push_back(make_pair(b, c));
tree[b].push_back(make_pair(a, c));
}
dfs(1, 0, 0);
scanf("%d", &q);
while (q--) {
int a, b; scanf("%d %d", &a, &b);
printf("%d\n", LCA_Dist(a, b));
}
return 0;
}
|
cs |
'Problem Solving > LCA' 카테고리의 다른 글
LCA(Lowest Common Ancestor)알고리즘 (0) | 2018.12.02 |
---|