Tree DP and Rerooting
The Intuition
Some tree questions are about a single subtree: the size of the subtree at each node, the deepest leaf below a node, the maximum-sum path inside a subtree. For these, the natural plan is to compute children first, then bring their answers up to the parent. This is the bottom-up pattern. It runs in O(n) because each node is visited once and does constant work per child.
Other tree questions are about the whole tree but answered "from the point of view of every node". For example, the sum of the distances from a node to every other node. If you ask that question for one fixed root, a normal bottom-up walk works. But if you need the answer for every node as the root, one fresh walk per node costs O(n) each and O(n*n) overall.
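To make the slow approach concrete, here is a minimal sketch of the "one fresh walk per node" baseline. The function name `sum_of_distances_brute` is hypothetical, and the sketch assumes nodes are labeled 0..n-1 and the input is a connected tree. It is mainly useful as a reference to check a faster solution against on small inputs.

```python
from collections import deque

def sum_of_distances_brute(n, edges):
    """Baseline: one BFS per start node. O(n) per walk, O(n*n) in total."""
    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

    result = []
    for start in range(n):
        dist = [-1] * n          # -1 marks "not reached yet"
        dist[start] = 0
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for nxt in graph[node]:
                if dist[nxt] == -1:
                    dist[nxt] = dist[node] + 1
                    queue.append(nxt)
        # Assumes a connected tree, so every entry of dist is >= 0 here.
        result.append(sum(dist))
    return result
```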
Rerooting gives you all the answers in one extra pass. The first pass picks any node as the root and computes the answer for that single root using bottom-up tree DP. The second pass walks the tree again, this time from the root outward. When the walk steps from a parent to a child, the second pass updates the answer to reflect "what if this child were the root instead?". The update is small because moving the root across one edge changes the contributions of only two groups: the nodes inside the child's subtree and the nodes everywhere else.
Take sum of distances as the concrete case. When the root moves from a node u to its child v, every node inside v's subtree is one step closer (so the sum drops by the size of v's subtree). Every node outside v's subtree is one step farther (so the sum grows by the count of nodes outside). The new answer is answer(u) - size(v) + (n - size(v)), where size(v) is the subtree size computed in the first pass. That is the rerooting formula. It runs in O(1) per edge, so the second pass is O(n). Both passes together are O(n).
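Written out as a formula, using the same names as the text (answer, size, n), the update simplifies to a single expression. The path tree in the worked lines is an illustrative assumption, not an example from the original.

```latex
\begin{aligned}
\mathrm{answer}(v) &= \mathrm{answer}(u) - \mathrm{size}(v) + \bigl(n - \mathrm{size}(v)\bigr) \\
                   &= \mathrm{answer}(u) + n - 2\,\mathrm{size}(v)
\end{aligned}

% Worked example on the path 0 - 1 - 2 - 3, rooted at 0 (n = 4):
%   answer(0) = 1 + 2 + 3 = 6, size(1) = 3, size(2) = 2, size(3) = 1
%   answer(1) = 6 + 4 - 2*3 = 4
%   answer(2) = 4 + 4 - 2*2 = 4
%   answer(3) = 4 + 4 - 2*1 = 6
```

Each value can be checked by hand: from node 1, for instance, the distances are 1, 1, and 2, which sum to 4.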
The plain version of tree DP (just the first pass) is enough for any question whose answer can be expressed as "combine my children's answers". The rerooting addition is needed only when you must answer the same question from every possible viewpoint.
When to use:
Subtree sizes, depths, counts, sums, max gain, House Robber on trees, and any question of the form "compute X for every node as if it were the root". Reroot whenever the question wants results for every node and one fresh walk per root is too slow.
When not to use:
- The graph has cycles (this is for trees only)
- The answer is global and does not change with the root (just compute it once)
- The state at each node depends on values that cannot be combined cheaply when the root moves
Variations:
- Bottom-up only: When you need a single answer for one fixed root.
- Two pass with reroot: When you need the answer for every node as the root.
- Carry extra state: Sometimes you also need the depth, the deepest leaf, or a count along with the main value. Carry them as a small tuple from each child.
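As an illustration of carrying a small tuple up from each child, here is a sketch of House Robber on a binary tree. The names `Node` and `rob_tree` are hypothetical and defined only to keep the example self-contained. Each call returns two values, the best total if the node is taken and the best total if it is skipped, and the parent combines them.

```python
class Node:
    """Minimal binary tree node, defined here only for this sketch."""
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def rob_tree(root):
    """Largest sum of node values such that no two chosen nodes are adjacent."""
    def dfs(node):
        # Returns (best if this node is taken, best if this node is skipped).
        if node is None:
            return 0, 0
        left_take, left_skip = dfs(node.left)
        right_take, right_skip = dfs(node.right)
        # Taking this node forces both children to be skipped.
        take = node.val + left_skip + right_skip
        # Skipping this node lets each child choose its better option.
        skip = max(left_take, left_skip) + max(right_take, right_skip)
        return take, skip

    return max(dfs(root))
```

A single number per child would not be enough here, because the parent needs to know what the child's subtree is worth in both cases.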
Edge cases:
- a tree with a single node (the answer is 0 for distance-sum problems)
- a long thin tree where recursion depth equals the number of nodes (use an iterative DFS for very large inputs)
- disconnected input (this pattern assumes a connected tree)
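The sketch below shows one way to handle these cases for the sum-of-distances problem without recursion. The function name `sum_of_distances_iterative` and the choice to raise `ValueError` on invalid input are assumptions made for illustration. It records a preorder with an explicit stack, then runs the bottom-up pass over that order reversed and the rerooting pass over it forward.

```python
def sum_of_distances_iterative(n, edges):
    """Two-pass rerooting with an explicit stack, safe for long thin trees."""
    if n == 0:
        return []
    if len(edges) != n - 1:
        raise ValueError("a tree on n nodes has exactly n - 1 edges")
    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

    # Preorder from node 0: every node appears after its parent.
    parent = [-1] * n
    seen = [False] * n
    seen[0] = True
    order = []
    stack = [0]
    while stack:
        node = stack.pop()
        order.append(node)
        for nxt in graph[node]:
            if not seen[nxt]:
                seen[nxt] = True
                parent[nxt] = node
                stack.append(nxt)
    if len(order) != n:
        raise ValueError("input is disconnected")

    count = [1] * n
    answer = [0] * n
    # Pass 1: reversed preorder visits every child before its parent.
    for node in reversed(order):
        p = parent[node]
        if p != -1:
            count[p] += count[node]
            answer[p] += answer[node] + count[node]
    # Pass 2: preorder visits every parent before its children.
    for node in order:
        p = parent[node]
        if p != -1:
            answer[node] = answer[p] - count[node] + (n - count[node])
    return answer
```

With a single node both loops do nothing and the result is [0]. The edge-count check together with the reachability check rejects both cyclic and disconnected input.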
Key Points
- First pass walks from leaves up. Each parent reads answers from its children and computes its own.
- This works when the question is about a subtree (size, depth, max gain, count).
- Some questions ask "what is the answer if X were the root?" for every node X.
- Doing that with one fresh walk per root is slow (O(n*n)).
- Rerooting does it in O(n). A second walk goes from the original root outward and adjusts the answer when the root moves to a neighbor.
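- The adjustment accounts for both sides of the edge being crossed: nodes in the new root's subtree get closer, and all other nodes get farther. For sum of distances, subtract size(v) and add n - size(v).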
Code Template
class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def subtree_size(root):
    """Size of the subtree at every node. Classic bottom-up tree DP."""
    sizes = {}

    def dfs(node):
        if not node:
            return 0
        sizes[node] = 1 + dfs(node.left) + dfs(node.right)
        return sizes[node]

    dfs(root)
    return sizes

# Rerooting example on an undirected tree given as adjacency list.
# Goal: for every node v, return the sum of distances from v to all other nodes.
def sum_of_distances(n, edges):
    graph = [[] for _ in range(n)]
    for a, b in edges:
        graph[a].append(b)
        graph[b].append(a)

    count = [1] * n   # size of subtree rooted at each node when 0 is the root
    answer = [0] * n  # final result; first computed for root 0, then propagated

    # Pass 1: bottom-up from root 0. Compute subtree sizes and the answer for node 0.
    def post_order(node, parent):
        for nxt in graph[node]:
            if nxt == parent:
                continue
            post_order(nxt, node)
            count[node] += count[nxt]
            answer[node] += answer[nxt] + count[nxt]

    # Pass 2: top-down. Move the root from a node to one of its children.
    # When we move the root from u to v, distances to nodes inside v's subtree go down by 1
    # for each of count[v] nodes, and distances to nodes outside go up by 1 for each of (n - count[v]) nodes.
    def pre_order(node, parent):
        for nxt in graph[node]:
            if nxt == parent:
                continue
            answer[nxt] = answer[node] - count[nxt] + (n - count[nxt])
            pre_order(nxt, node)

    post_order(0, -1)
    pre_order(0, -1)
    return answer

Common Mistakes
- Forgetting that a child's answer must be computed before its parent's
- Mixing up "answer when this is the root" with "answer for the subtree under this node"
- Applying only half of the rerooting update: the new root's subtree size must be subtracted and the count of all other nodes must be added
- Hitting the recursion limit on a long thin tree (use an explicit stack if needed)