티스토리 뷰
[CF 1092F] Tree with Maximum Cost - 2100
문제 설명 :
각 노드별로 가중치가 정해져있는 트리가 주어진다. 이 트리에서의 최대 cost를 구해야 하는데 cost는
(기준노드~다른노드까지의 거리) * (다른 노드의 가중치) 의 총합으로 정의된다. 여기서 거리는 기준노드~다른노드까지 거쳐야 하는 간선의 개수이다.
풀이 :
어떤 임의의 노드 i를 기준으로 cost를 구한다고 하자. 그리고 i에서 j로 기준노드를 변경할 때 어떠한 연관성이 있는지 살펴볼 것이다.
$cost_i= 1*(a1+a2+...a_p) + 2*(a_(p+1)+...a_q) + 3*(...) + 1*(b1+b2+..b_s)+2*(...)$이다. 이 식에서 a1,a2...a_p는 노드 i에서 거리가 1만큼 떨어진 다른 노드들의 집합이다. 비슷하게 a_(p+1)...a_q는 거리가 2만큼 떨어진 노드들의 집합이다. 그리고 b1,b2...b_s 역시 거리가 1만큼 떨어진 노드들의 집합인데 a와 차이점은 j로 향하는 방향에 있는 노드들이다. 즉 a노드를 i노드를 기준으로 특정 방향, b노드는 특정 방향과 반대 방향 노드들의 집합이라고 정의된다.
이번에는 cost_j를 정의할건데 i와 j간의 변화를 관찰해보자. i->j로 이동하면서 기존의 a집합에 속했던 노드들의 거리는 모두 1씩 증가할 것이고, b집합에 속했던 노드들은 1씩 감소할 것이다. 그리고 새롭게 a_i라는 가중치가 생길 것이다.
따라서 $cost_j = cost_i + ( 증가한 부분 ) - ( 감소한 부분)$ 이다. 증가한 부분은 거리의 a집합에서 가중치가 모두 1씩 증가했으므로 $a1+a2+...+alastA$가 될 것이다. 반대로 감소한 부분은 b집합에서 가중치가 모두 1씩 감소했으므로 $b1+b2+...+b_lastB$가 될 것이다. 문제는 이 식을 어떻게 빨리 계산하느냐이다. 만약 모든 노드에 대해서 다른 노드들을 모두 한번씩 순회해서 계산해본다면 $O(V^2)$이므로 불가능하다. 나는 어떻게든 각 노드마다 계산을 $O(logV)$나 $O(1)$만에 해내야 한다. 결론을 얘기하자면 $O(1)$만에 가능하다.
트리를 구성한 뒤에 나는 임의의 루트 노드를 설정해서 트리의 위계가 존재한다고 가정할 것이다. 그래서 나는 트리의 임의의 노드 i의 subtree를 정의할 수 있고, subtree의 가중치들의 합도 구할 수 있다. 이렇게 구한 노드 i의 subtree의 가중치 합을 나는 dp[i]로 정의하겠다. 그러면 위에서 설명했던 (증가한 부분) = (모든 노드의 가중치) - ( dp[j] ) 이다. 증가한 부분은 노드 j의 subtree를 제외한 나머지 노드들의 가중치의 합이므로 전체에서 j 가중치의 합을 빼준다는 개념이다. 그리고 (감소한 부분) = dp[j] 이다. 위와 비슷한 이유이다.
따라서 최종 식은 $cost_j = cost_i + ( total - dp[j] ) - dp[j]$ 이고 정리하면 $cost_i + total - 2*dp[j]$가 된다.
주의할 점 :
.
배울 점 :
전형적인 tree dp문제이다. 트리의 임의의 노드를 루트로 설정해서 위계를 만들면 트리의 '방향성'을 만든다는 것과 같다. 그래서 특정 노드 i의 위쪽 가중치의 합은 dp[]처럼 미리 구해놓으면 나머지 아래쪽 가중치의 합은 전체 합에서 dp[]를 빼주면 된다는 개념이다.
코드 :
const int SIZE= 200009;
vector<int> tree[SIZE];
int c[SIZE];
ll dp[SIZE];
ll sum=0,total=0,ans=0;
void dfs1(int node,int p,int dist){
total+=1LL*dist*c[node];
dp[node]=c[node];
for(int &next:tree[node]){
if(p==next) continue;
dfs1(next,node,dist+1);
dp[node]+=dp[next];
}
}
void dfs2(int node,int p,ll cur){
ans=max(ans,cur);
for(int &next:tree[node]){
if(p==next) continue;
dfs2(next,node,cur+sum-2*dp[next]);
}
}
int main(){
int n; scanf("%d",&n);
for(int i=1 ; i<=n ; i++){
scanf("%d",&c[i]);
sum+=c[i];
}
for(int i=0 ; i<n-1 ; i++){
int a,b; scanf("%d %d",&a,&b);
tree[a].push_back(b);
tree[b].push_back(a);
}
dfs1(1,0,0);
dfs2(1,0,total);
printf("%lld\n",ans);
return 0;
}
'Problem Solving > Dynamic Programming' 카테고리의 다른 글
[CF 1082E] Increasing Frequency (0) | 2019.11.14 |
---|---|
[CF 1155D] Beautiful Array (0) | 2019.08.20 |
[CF 1114D] Flood Fill (0) | 2019.07.25 |
[Codeforces 1036C] Classy Numbers (0) | 2019.05.13 |
[Codefocres 1061C] Multiplicity (0) | 2019.04.29 |