Bootstrap

【备战秋招】每日一题:2023.08.12-美团机试-第五题-树上染色

为了更好的阅读体检,可以查看我的算法学习网
在线评测链接:P1445

题目内容

小美拿到了一棵树,每个节点有一个权值。初始每个节点都是白色。

小美有若干次操作,每次操作可以选择两个相邻的节点,如果它们都是白色且权值的乘积是完全平方数,小美就可以把这两个节点同时染红。

小美想知道,自己最多可以染红多少个节点?

输入描述

第一行输入一个正整数 n n n,代表节点的数量。

第二行输入 n n n个正整数 a i a_i ai,代表每个节点的权值。

接下来的 n − 1 n-1 n1行,每行输入两个正整教 u , v u,v u,v,代表节点 u u u和节点 v v v有一条边连接

1 ≤ n ≤ 1 0 5 1 \leq n \leq 10^5 1n105

1 ≤ a i ≤ 1 0 9 1 \leq a_i \leq 10^9 1ai109

1 ≤ u , v ≤ n 1 \leq u,v \leq n 1u,vn

输出描述

输出一个整数表示最多可以染红的节点数量。

样例

输入输出示例仅供调试,后台判题数据一般不包含示例

输入

3
3 3 12
1 2
2 3

输出

2

说明

可以染红第二个和第三个节点。

请注意,此时不能再染红第一个和第二个节点,因为第二个节点已经被染红。

因此,最多染红 2 2 2个节点

思路:树形DP

树形DP 问题。
考虑DP的状态定义:

  • d p [ i ] [ 0 ] dp[i][0] dp[i][0] 表示以 i 为子树,不选择 i 这个节点进行染色,i 这棵子树可以染色的结点最大数量
  • d p [ i ] [ 1 ] dp[i][1] dp[i][1] 表示以 i 为子树,对 i 这个节点进行染色,i 这棵子树可以染色的结点最大数量

状态转移方程为:

  • $dp[i][0] = \sum\limits_{j\in son(i)} \max(dp[j][0], dp[j][1]) $

    即对于 i i i 的所有儿子节点 j j j ,取 d p [ j ] [ 0 ] dp[j][0] dp[j][0] d p [ j ] [ 1 ] dp[j][1] dp[j][1] 的较大值。

  • d p [ i ] [ 1 ] = max ⁡ j ∈ s o n ( i ) ( d p [ i ] [ 0 ] − max ⁡ ( d p [ j ] [ 0 ] , d p [ j ] [ 1 ] ) + d p [ j ] [ 0 ] + 2 ) dp[i][1] = \max\limits_{j\in son(i)}(dp[i][0]-\max(dp[j][0], dp[j][1])+dp[j][0]+2) dp[i][1]=json(i)max(dp[i][0]max(dp[j][0],dp[j][1])+dp[j][0]+2)

    这里需要满足 a [ i ] × a [ j ] a[i] \times a[j] a[i]×a[j] 是一个完全平方数
    首先,由于我们只可以对一个节点染色一次,所以我们选择一个 a [ i ] × a [ j ] a[i] \times a[j] a[i]×a[j] 为完全平方数的 j j j ,将这个 i i i j j j 同时染为红色。

    d p [ i ] [ 0 ] dp[i][0] dp[i][0] 不同的是,其他的儿子都是取 m a x ( d p [ x x x ] [ 0 ] , d p [ x x x ] [ 1 ] ) max(dp[xxx][0], dp[xxx][1]) max(dp[xxx][0],dp[xxx][1]) ,而 j j j 是取 d p [ j ] [ 0 ] + 2 dp[j][0]+2 dp[j][0]+2

    转移到 d p [ i ] [ 1 ] dp[i][1] dp[i][1] 就是 d p [ i ] [ 1 ] = ∑ x x x ∈ s o n ( i ) , x x x ≠ j max ⁡ ( d p [ x x x ] [ 0 ] , d p [ x x x ] [ 1 ] ) + d p [ j ] [ 0 ] + 2 dp[i][1] = \sum\limits_{xxx \in son(i),xxx\neq j}\max(dp[xxx][0],dp[xxx][1])+dp[j][0]+2 dp[i][1]=xxxson(i),xxx=jmax(dp[xxx][0],dp[xxx][1])+dp[j][0]+2

    d p [ i ] [ 1 ] = d p [ i ] [ 0 ] − max ⁡ ( d p [ j ] [ 0 ] , d p [ j ] [ 1 ] ) + d p [ j ] [ 0 ] + 2 dp[i][1] = dp[i][0]-\max{(dp[j][0],dp[j][1])}+dp[j][0]+2 dp[i][1]=dp[i][0]max(dp[j][0],dp[j][1])+dp[j][0]+2

时间复杂度: O ( n ) O(n) O(n)

import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);
        int n = scanner.nextInt();

        List<Integer> a = new ArrayList<>();
        for (int i = 0; i < n; ++i) {
            a.add(scanner.nextInt());
        }

        List<List<Integer>> g = new ArrayList<>(n);
        for (int i = 0; i < n; ++i) {
            g.add(new ArrayList<>());
        }

        for (int i = 1; i < n; ++i) {
            int u = scanner.nextInt() - 1;
            int v = scanner.nextInt() - 1;
            g.get(u).add(v);
            g.get(v).add(u);
        }

        int[][] dp = new int[n][2];
        dfs(0, -1, g, dp, a);
        System.out.println(Math.max(dp[0][0], dp[0][1]));
    }

    static void dfs(int u, int fa, List<List<Integer>> g, int[][] dp, List<Integer> a) {
        for (int v : g.get(u)) {
            if (v == fa) continue;
            dfs(v, u, g, dp, a);
            dp[u][0] += Math.max(dp[v][0], dp[v][1]);
        }

        for (int v : g.get(u)) {
            if (v == fa) continue;
            long val = (long) a.get(v) * a.get(u);
            long sq = (long) Math.sqrt(val + 0.5);
            if (sq * sq != val) continue;
            dp[u][1] = Math.max(dp[u][1], (dp[u][0] - Math.max(dp[v][0], dp[v][1])) + dp[v][0] + 2);
        }
    }
}
;