티스토리 뷰

[Codeforces 1065C] Make It Equal


참 오랫동안 붙잡고 있었던 문제이다. 처음에는 바이너리 서치로 풀려고 해봤는데 시간이 터졌다. 그렇게 어떻게 풀어야 할지 일주일 넘게 고민하다가 결국 솔루션을 봤다. 

솔루션은 '지금까지 소모된 cost'(=sum) + '현재 높이에서 새롭게 추가된 막대의 수'(=cnt[i])의 총합을 계산해서 k이하일 때는 계속 sum에 누적해나가고, k를 초과했다면 slice수를 추가하고, sum=cnt[i-1]로 두면 된다. 그리고 다음 높이(i-1)에서 소모되는 cost는 cnt[i-1]+=cnt[i]로 구할 수 있다.

문제의 첫 번째 예시 입력처럼 1 2 2 3 4 가 있다고 가정해보자. 가장 높이가 높은 4부터 시작해서 그 다음 높이인 3에 도달했다고 하자. 그러면 sum은 1, cnt[3]는 1+1=2가 되어 총합은 3이다. 그리고 다음 높이인 cnt[3-1]+=cnt[3]이 되어서 4가 된다. 2가 되었을 때는 현재 sum=3, cnt[3]는 4가 되어서 총합이 7이 된다. 이는 k=5다 큰 값이므로 slice를 한번 증가시켜야 함을 의미한다. 따라서 slice 1에는 높이 4,3이 포함되고, slice 2에는 높이 2가 해당되는 것이다. 

이 풀이가 성립하는 이유는 바로 input 데이터 값의 범위를 이용하기 때문이다. 입력으로 주어진 막대의 높이(h)값을 모두 cnt[]배열에 카운트트한다. 그 다음에 가장 높은 높이부터 가장 작은 높이까지 1씩 줄여나가면서 막대의 높이를 살펴본다. 실제로 시간복잡도를 계산해봐도 O(max(n,h))라고 할 수 있다. 이 문제도 일종의 라인스위핑 유형의 문제라고 볼 수도 있을 것 같다.



1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include<stdio.h>
#include<math.h>
#include<string.h>
#include<iostream>
#include<functional>
#include<string>
#include<algorithm>
#include<vector>
#include<map>
#include<set>
#include<queue>
#include<stdlib.h>
#include<stack>
using namespace std;
typedef pair<int,int> pii;
typedef pair<int,int> Cord;
typedef long long ll;
/****************************Default Template******************************/
#define SIZE 200009
// #define IMP
// #define INF
// const int MOD=
struct S{
    int a,b;
 
    S(){}
    S(int _a,int _b){
        a=_a;
        b=_b;
    }
 
    const bool operator<(const S &o) const{
        return a<o.a;
    }
};
priority_queue<int> pq;
priority_queue<int,vector<int>,greater<int>> mpq;
map<int,int> mp;
stack<int> st;
set<int> set_i;
/****************************Default Template******************************/
 
 
int cnt[SIZE], ans=0;
int main(){
    int n,k;    scanf("%d %d",&n,&k);
 
    int min_h=1<<30,max_h=0;
    for(int i=1 ; i<=n ; i++){
        int inp;    scanf("%d",&inp);
        cnt[inp]++;
        min_h=min(min_h,inp);
        max_h=max(max_h,inp);
    }
 
    int sum=0;
    for(int i=max_h ; i>min_h ;i--){
        sum+=cnt[i];
        cnt[i-1]+=cnt[i];
 
        if(sum>k){
            sum=cnt[i];
            ans++;
        }
    }
    if(sum) ans++;
 
    printf("%d\n",ans);
    return 0;
}
 
cs



댓글