티스토리 뷰

[CF 1092F] Tree with Maximum Cost - 2100

 

문제 설명 : 

각 노드별로 가중치가 정해져있는 트리가 주어진다. 이 트리에서의 최대 cost를 구해야 하는데 cost는 

(기준노드~다른노드까지의 거리) * (다른 노드의 가중치) 의 총합으로 정의된다. 여기서 거리는 기준노드~다른노드까지 거쳐야 하는 간선의 개수이다.

 

풀이 : 

어떤 임의의 노드 i를 기준으로 cost를 구한다고 하자. 그리고 i에서 j로 기준노드를 변경할 때 어떠한 연관성이 있는지 살펴볼 것이다.

$cost_i= 1*(a1+a2+...a_p) + 2*(a_(p+1)+...a_q) + 3*(...) + 1*(b1+b2+..b_s)+2*(...)$이다. 이 식에서 a1,a2...a_p는 노드 i에서 거리가 1만큼 떨어진 다른 노드들의 집합이다. 비슷하게 a_(p+1)...a_q는 거리가 2만큼 떨어진 노드들의 집합이다. 그리고 b1,b2...b_s 역시 거리가 1만큼 떨어진 노드들의 집합인데 a와 차이점은 j로 향하는 방향에 있는 노드들이다. 즉 a노드를 i노드를 기준으로 특정 방향, b노드는 특정 방향과 반대 방향 노드들의 집합이라고 정의된다.

이번에는 cost_j를 정의할건데 i와 j간의 변화를 관찰해보자. i->j로 이동하면서 기존의 a집합에 속했던 노드들의 거리는 모두 1씩 증가할 것이고, b집합에 속했던 노드들은 1씩 감소할 것이다. 그리고 새롭게 a_i라는 가중치가 생길 것이다. 

따라서 $cost_j = cost_i + ( 증가한 부분 ) - ( 감소한 부분)$ 이다. 증가한 부분은 거리의 a집합에서 가중치가 모두 1씩 증가했으므로 $a1+a2+...+alastA$가 될 것이다. 반대로 감소한 부분은 b집합에서 가중치가 모두 1씩 감소했으므로 $b1+b2+...+b_lastB$가 될 것이다. 문제는 이 식을 어떻게 빨리 계산하느냐이다. 만약 모든 노드에 대해서 다른 노드들을 모두 한번씩 순회해서 계산해본다면 $O(V^2)$이므로 불가능하다. 나는 어떻게든 각 노드마다 계산을 $O(logV)$나 $O(1)$만에 해내야 한다. 결론을 얘기하자면 $O(1)$만에 가능하다. 

트리를 구성한 뒤에 나는 임의의 루트 노드를 설정해서 트리의 위계가 존재한다고 가정할 것이다. 그래서 나는 트리의 임의의 노드 i의 subtree를 정의할 수 있고, subtree의 가중치들의 합도 구할 수 있다. 이렇게 구한 노드 i의 subtree의 가중치 합을 나는 dp[i]로 정의하겠다. 그러면 위에서 설명했던 (증가한 부분) = (모든 노드의 가중치) - ( dp[j] ) 이다. 증가한 부분은 노드 j의 subtree를 제외한 나머지 노드들의 가중치의 합이므로 전체에서 j 가중치의 합을 빼준다는 개념이다. 그리고 (감소한 부분) = dp[j] 이다. 위와 비슷한 이유이다.

따라서 최종 식은 $cost_j = cost_i + ( total - dp[j] ) - dp[j]$ 이고 정리하면 $cost_i + total - 2*dp[j]$가 된다.

 

 

주의할 점 : 

.

 

배울 점 : 

전형적인 tree dp문제이다. 트리의 임의의 노드를 루트로 설정해서 위계를 만들면 트리의 '방향성'을 만든다는 것과 같다. 그래서 특정 노드 i의 위쪽 가중치의 합은 dp[]처럼 미리 구해놓으면 나머지 아래쪽 가중치의 합은 전체 합에서 dp[]를 빼주면 된다는 개념이다.

 

코드 : 

const int SIZE= 200009;

vector<int> tree[SIZE];
int c[SIZE];
ll dp[SIZE];

ll sum=0,total=0,ans=0;

void dfs1(int node,int p,int dist){
    total+=1LL*dist*c[node];
    dp[node]=c[node];

    for(int &next:tree[node]){
        if(p==next) continue;

        dfs1(next,node,dist+1);
        dp[node]+=dp[next];
    }
}

void dfs2(int node,int p,ll cur){
    ans=max(ans,cur);

    for(int &next:tree[node]){
        if(p==next) continue;

        dfs2(next,node,cur+sum-2*dp[next]);
    }
}

int main(){
    int n;  scanf("%d",&n);

    for(int i=1 ; i<=n ; i++){
        scanf("%d",&c[i]);
        sum+=c[i];
    }

    for(int i=0 ; i<n-1 ; i++){
        int a,b;    scanf("%d %d",&a,&b);
        tree[a].push_back(b);
        tree[b].push_back(a);
    }

    dfs1(1,0,0);
    dfs2(1,0,total);

    printf("%lld\n",ans);
    return 0;
}

 

'Problem Solving > Dynamic Programming' 카테고리의 다른 글

[CF 1082E] Increasing Frequency  (0) 2019.11.14
[CF 1155D] Beautiful Array  (0) 2019.08.20
[CF 1114D] Flood Fill  (0) 2019.07.25
[Codeforces 1036C] Classy Numbers  (0) 2019.05.13
[Codefocres 1061C] Multiplicity  (0) 2019.04.29
댓글