티스토리 뷰

 임의의 n개의 수 중에서 k번째로 큰 수(혹은 작은 수)를 찾으려면 어떻게 해야 할까? 가장 단순한 방법으로는 n개를 모두 정렬한 뒤 앞이나 뒤에서부터 k번째 있는 수를 찾으면 된다. 아니면 크기가 k짜리의 최소 힙(최대 힙)을 이용할 수도 있다. 그런데 사실 이것들은 조금 꼼수같다는 기분이 든다. 꼼수부리지 않는 정석적인 알고리즘이 바로 Quick Selection 알고리즘이다.

  이 알고리즘은 Quick Sort 알고리즘을 약간 변형시킨 알고리즘이다. Quick Sort 알고리즘은 pivot을 하나 정해서 자신의 위치를 찾아서 고정시키고 그 앞뒤로 다시 재귀호출을 통해서 수를 정렬시키는데, Quick Selection은 이때 앞뒤로 2번의 재귀 호출이 아닌 본인에게 필요한 딱 1번의 재귀호출만 하는 함수이다. 그 원리는 간단하다. 방금 말했다시피 pivot이 자신의 자리를 찾아서 고정하는데 자리가 고정되었다는 의미는 전체 수 중에서 pivot이 몇 번째로 큰지(작은지)확정되었다는 뜻이다. 예를 들어 만약 전체 10개의 수가 있는데 이번에 pivot이 index가 3에 위치하게 되었다면 그 pivot은 4번째로 작은 수라는 것이 확정이라는 것이다. 그런데 이때 사용자는 7번째로 작은 수를 찾는 것이 목표라고 가정하자. 그렇다면 여기서 나는 index 3보다 앞을 봐야할까, 뒤를 봐야할까? 정답은 뒤를 봐야한다. 앞을 본다 한들 1~3번째 작은 수만 찾을 수 있기 때문이다. 따라서 뒤를 봐야지만 5~10번째 작은 수를 찾을 수 있는 것이다. 이렇게 내가 찾고자 하는 k번째 수와 현재 pivot의 고정된 인덱스의 대소 비교를 통해서 앞을 볼지, 뒤를 볼지 분기를 나눠줘야 한다. 이 알고리즘의 시간 복잡도는 최적의 경우 O(N)이지만, pivot을 설정해야 하므로 최악의 경우 O(N^2)이 걸릴 수도 있어서 다소 조심스럽다. N이 만약 10만이 넘어간다면 차라리 O(NlogN)이 보장된 정렬이 더 안전할 수도 있다. 

 참고로 k번째 작은 수를 찾는 알고리즘은 이미 stl에 구현되어 있는데 'nth_element'이고, 'algorithm'헤더에 있다. 직접 구현한 Quick Selection의 경우 11004-K번째 수 문제에서 시간이 초과되었는데, 이 'nth_element'는 통과했다. 아마 좋은 최적화를 거친 모양이다. 따라서 간단한 예시 코드만 첨부하겠다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include<stdio.h>
#include<algorithm>
using namespace std;
int arr[100];
int Size, k;
void qSelection(int l, int r) {
    if (l >= r) return;
 
    int i = l + 1, j = r, pivot = arr[l];
 
    while (i <= j) {
        while (arr[i] < pivot && i <= j) i++;
        while (arr[j] >= pivot && i <= j) j--;
 
        if (i < j) swap(arr[i], arr[j]);
    }
    swap(arr[l], arr[j]);
 
    if (Size - j == k) {
        printf("%d-th greatest element : %d\n", k, arr[j]);
        return;
    }
 
    printf("%d는 %d번째로 큰 수입니다\n", arr[j], Size - j);
    if (Size - j > k) qSelection(j + 1, r);
    else qSelection(l, j - 1);
 
    return;
}
int main() {
    printf("Size: "); scanf("%d"&Size);
    printf("k-th element : ");    scanf("%d"&k);
 
    for (int i = 0; i < Size; i++)
        scanf("%d"&arr[i]);
 
    qSelection(0, Size - 1);
 
    return 0;
}
 
cs


그리고 다음은 nth_element의 코드이다. 파라미터는

nth_element(시작지점 주소, k번째(0 base), 끝지점 주소)

이고, 함수 실행 결과, k번째 원소를 기준으로 그곳의 왼쪽은 작은 값, 오른 쪽은 큰 값으로 partition된다. 따라서 arr[k-1]번째는 k번째로 작은 수임이 보장된다.


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#include<stdio.h>
#include<algorithm>
using namespace std;
#define SIZE 5000009
int arr[SIZE];
int main() {
    int n,k;    scanf("%d %d"&n,&k);
 
    for (int i = 0; i < n; i++)
        scanf("%d"&arr[i]);
 
    nth_element(arr, arr + k-1, arr + n);
 
    printf("%d\n", arr[k - 1]);
 
    return 0;
}
cs


댓글