为了更好的阅读体检,可以查看我的算法学习网
在线评测链接:P1445
题目内容
小美拿到了一棵树,每个节点有一个权值。初始每个节点都是白色。
小美有若干次操作,每次操作可以选择两个相邻的节点,如果它们都是白色且权值的乘积是完全平方数,小美就可以把这两个节点同时染红。
小美想知道,自己最多可以染红多少个节点?
输入描述
第一行输入一个正整数 n n n,代表节点的数量。
第二行输入 n n n个正整数 a i a_i ai,代表每个节点的权值。
接下来的 n − 1 n-1 n−1行,每行输入两个正整教 u , v u,v u,v,代表节点 u u u和节点 v v v有一条边连接
1 ≤ n ≤ 1 0 5 1 \leq n \leq 10^5 1≤n≤105
1 ≤ a i ≤ 1 0 9 1 \leq a_i \leq 10^9 1≤ai≤109
1 ≤ u , v ≤ n 1 \leq u,v \leq n 1≤u,v≤n
输出描述
输出一个整数表示最多可以染红的节点数量。
样例
输入输出示例仅供调试,后台判题数据一般不包含示例
输入
3
3 3 12
1 2
2 3
输出
2
说明
可以染红第二个和第三个节点。
请注意,此时不能再染红第一个和第二个节点,因为第二个节点已经被染红。
因此,最多染红 2 2 2个节点
思路:树形DP
树形DP 问题。
考虑DP的状态定义:
- d p [ i ] [ 0 ] dp[i][0] dp[i][0] 表示以 i 为子树,不选择 i 这个节点进行染色,i 这棵子树可以染色的结点最大数量
- d p [ i ] [ 1 ] dp[i][1] dp[i][1] 表示以 i 为子树,对 i 这个节点进行染色,i 这棵子树可以染色的结点最大数量
状态转移方程为:
-
$dp[i][0] = \sum\limits_{j\in son(i)} \max(dp[j][0], dp[j][1]) $
即对于 i i i 的所有儿子节点 j j j ,取 d p [ j ] [ 0 ] dp[j][0] dp[j][0] 和 d p [ j ] [ 1 ] dp[j][1] dp[j][1] 的较大值。
-
d p [ i ] [ 1 ] = max j ∈ s o n ( i ) ( d p [ i ] [ 0 ] − max ( d p [ j ] [ 0 ] , d p [ j ] [ 1 ] ) + d p [ j ] [ 0 ] + 2 ) dp[i][1] = \max\limits_{j\in son(i)}(dp[i][0]-\max(dp[j][0], dp[j][1])+dp[j][0]+2) dp[i][1]=j∈son(i)max(dp[i][0]−max(dp[j][0],dp[j][1])+dp[j][0]+2)
这里需要满足 a [ i ] × a [ j ] a[i] \times a[j] a[i]×a[j] 是一个完全平方数
首先,由于我们只可以对一个节点染色一次,所以我们选择一个 a [ i ] × a [ j ] a[i] \times a[j] a[i]×a[j] 为完全平方数的 j j j ,将这个 i i i 和 j j j 同时染为红色。与 d p [ i ] [ 0 ] dp[i][0] dp[i][0] 不同的是,其他的儿子都是取 m a x ( d p [ x x x ] [ 0 ] , d p [ x x x ] [ 1 ] ) max(dp[xxx][0], dp[xxx][1]) max(dp[xxx][0],dp[xxx][1]) ,而 j j j 是取 d p [ j ] [ 0 ] + 2 dp[j][0]+2 dp[j][0]+2 。
转移到 d p [ i ] [ 1 ] dp[i][1] dp[i][1] 就是 d p [ i ] [ 1 ] = ∑ x x x ∈ s o n ( i ) , x x x ≠ j max ( d p [ x x x ] [ 0 ] , d p [ x x x ] [ 1 ] ) + d p [ j ] [ 0 ] + 2 dp[i][1] = \sum\limits_{xxx \in son(i),xxx\neq j}\max(dp[xxx][0],dp[xxx][1])+dp[j][0]+2 dp[i][1]=xxx∈son(i),xxx=j∑max(dp[xxx][0],dp[xxx][1])+dp[j][0]+2
即 d p [ i ] [ 1 ] = d p [ i ] [ 0 ] − max ( d p [ j ] [ 0 ] , d p [ j ] [ 1 ] ) + d p [ j ] [ 0 ] + 2 dp[i][1] = dp[i][0]-\max{(dp[j][0],dp[j][1])}+dp[j][0]+2 dp[i][1]=dp[i][0]−max(dp[j][0],dp[j][1])+dp[j][0]+2
时间复杂度: O ( n ) O(n) O(n)
import java.util.*;
public class Main {
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in);
int n = scanner.nextInt();
List<Integer> a = new ArrayList<>();
for (int i = 0; i < n; ++i) {
a.add(scanner.nextInt());
}
List<List<Integer>> g = new ArrayList<>(n);
for (int i = 0; i < n; ++i) {
g.add(new ArrayList<>());
}
for (int i = 1; i < n; ++i) {
int u = scanner.nextInt() - 1;
int v = scanner.nextInt() - 1;
g.get(u).add(v);
g.get(v).add(u);
}
int[][] dp = new int[n][2];
dfs(0, -1, g, dp, a);
System.out.println(Math.max(dp[0][0], dp[0][1]));
}
static void dfs(int u, int fa, List<List<Integer>> g, int[][] dp, List<Integer> a) {
for (int v : g.get(u)) {
if (v == fa) continue;
dfs(v, u, g, dp, a);
dp[u][0] += Math.max(dp[v][0], dp[v][1]);
}
for (int v : g.get(u)) {
if (v == fa) continue;
long val = (long) a.get(v) * a.get(u);
long sq = (long) Math.sqrt(val + 0.5);
if (sq * sq != val) continue;
dp[u][1] = Math.max(dp[u][1], (dp[u][0] - Math.max(dp[v][0], dp[v][1])) + dp[v][0] + 2);
}
}
}