1 条题解

  • 0
    @ 2025-8-24 21:54:02

    自动搬运

    查看原文

    来自洛谷,原作者为

    avatar YoungLove
    **

    搬运于2025-08-24 21:54:02,当前版本为作者最后更新于2017-09-18 18:18:08,作者可能在搬运后再次修改,您可在原文处查看最新版

    自动搬运只会搬运当前题目点赞数最高的题解,您可前往洛谷题解查看更多

    以下是正文


    Youngsc

    又WA又T又MLE了一节课终于搞出来了这道题。

    这道题很显然是一道树形DP,而且用到了组合数学中的乘法原理。

    f[i][j]f[i][j]表示以ii为根的这棵子树在ii为颜色jj的时候的方案数,根据乘法原理可得f[i][j]=πf[k][c]f[i][j]=πf[k][c] 其中kkii的所有儿子,cc是所有与jj不同的颜色。

    因此我们可以很容易相处一个O(n3)O(n^3)的是算法:枚举所有点,枚举它的颜色,枚举每一个儿子,枚举儿子的颜色,因为一个点只有一个父亲,所以枚举儿子的那一层加上枚举每一个点一共是O(n)O(n)的。

    回过头来再看数据n<=5000n<=5000,m<=5000m<=5000,上述的算法似乎运行起来很吃力。那我们需要向一个办法去将复杂度变成O(n2)O(n^2),仔细思考后我们会发现,其实枚举儿子那一层是没有必要的,因为我只要颜色不相同的总数,与其O(m)O(m)地去求和,不如之前在枚举到它的时候就处理好所有的和,然后O(1)O(1)地去剪掉这个颜色,从而求出除去这个颜色外的方案数,达到O(n2)O(n^2)的复杂度。

    这道题还是很不错的,这种思想在很多时候都能用得上。

    代码在这里

    # include <algorithm>
    # include <iostream>
    # include <cstring>
    # include <cstdio>
    # include <vector>
    # include <cmath>
    # define R register
    # define mod 1000000007
    
    using namespace std;
    
    int e,n,m,f[5010][5010],h[5010],tot[5010],sum[5010],x,y,d;
    struct pp{int v,pre;}ed[10010];
    
    inline void in(R int &a){
        R char c = getchar();R int x=0, f=1;
        while(!isdigit(c)) {if(c == '-') f=-1; c = getchar();}
        while(isdigit(c)) x = (x<<1)+(x<<3)+c-'0',c = getchar();
        a=x*f;
    }
    
    inline void add(R int x,R int y){
        ed[++e] = (pp){y,h[x]};
        h[x] = e;
    }
    
    inline void dfs(R int fa,R int x){//树形DP一般用DFS来实现
        for(R int i=h[x]; i; i=ed[i].pre)
        {
            R int p = ed[i].v;
            if(p == fa) continue;
            dfs(x,p);
        }
        for(R int j=1; j<=m; ++j)
        {
            if(!f[x][j]) continue;//没有这种颜色
            for(R int i=h[x]; i; i=ed[i].pre)
            {
                R int p = ed[i].v;
                if(p == fa) continue;
                f[x][j] = 1LL*f[x][j]*(tot[p]-f[p][j])%mod;
            }
            while(f[x][j]<0) f[x][j] += mod;//上边(tot[p]-f[p][j])可能会变成负数,这里把它变回来。
            tot[x] = (1LL*tot[x]+1LL*f[x][j])%mod;
        }
    }
    
    inline int yg(){
        in(n),in(m);
        for(R int i=1; i<=n; ++i)
        {
            in(sum[i]);
            for(R int j=1; j<=sum[i]; ++j) in(d),f[i][d]++;
        }
        for(R int i=1; i<n; ++i)
        {
            in(x),in(y);
            add(x,y),add(y,x);
        }
        add(0,1);//为了好写,新建一个原点连向1,这样就不用额外求tot[1]了
        dfs(0,0);
        cout << tot[1];//tot[1]就是最终的答案
    }
    
    int youngsc = yg();
    int main(){;}
    
    • 1

    信息

    ID
    2871
    时间
    3000ms
    内存
    512MiB
    难度
    5
    标签
    递交数
    0
    已通过
    0
    上传者