1 条题解

  • 0
    @ 2025-8-24 22:06:51

    自动搬运

    查看原文

    来自洛谷,原作者为

    avatar NaCly_Fish
    北海虽赊,扶摇可接。

    搬运于2025-08-24 22:06:51,当前版本为作者最后更新于2020-03-25 22:09:47,作者可能在搬运后再次修改,您可在原文处查看最新版

    自动搬运只会搬运当前题目点赞数最高的题解,您可前往洛谷题解查看更多

    以下是正文


    怎么这题比较正常的题解都在下面,,
    其实就是一道基础生成函数题。


    首先,设 fnf_n 为一个人分 nn 块糖时的快乐值,这个很容易求出;再设 F(x)F(x) 为它的生成函数。

    考虑其组合意义,不难发现 (F(x)1)k(F(x)-1)^k 就是有 kk 个人都有糖时,「所有情况的快乐值乘积之和」的生成函数。

    那么枚举有糖的人数,答案就直接出来了:

    k=0A(F(x)1)k\sum_{k=0}^A (F(x)-1)^k =1(F(x)1)A+12F(x)= \frac{1-(F(x)-1)^{A+1}}{2-F(x)}

    由此也可以得到一个常数优化:当 AmA \geq m 时,那个多项式的幂在模 xm+1x^{m+1} 下为 00,可以直接忽略掉。这时的时间复杂度是 Θ(mlogm)\Theta(m \log m)

    否则就只能倍增算那个多项式的幂(因为模数为合数,如果有更优的做法请告诉我)。

    另外就是多项式的长度,由于 F(x)1F(x)-1 的常数项为 00,快速幂后前 AA 次都是 00,非 00 项只有 mAm-A 个,于是快速幂可以在模 xmAx^{m-A} 下计算——时间复杂度可以做到 Θ((mA)logAlog(mA)) (A<m)\Theta((m-A) \log A \log (m-A)) \space(A < m)

    加上这些优化,成功跑到最优解第一:

    #pragma GCC optimize ("unroll-loops")
    #pragma GCC optimize ("Ofast")
    #include<cstdio>
    #include<iostream>
    #include<cstring>
    #include<cmath>
    #include<algorithm>
    #define N 32774
    #define ll long long
    #define reg register
    #define add(x,y) (x+y>=p?x+y-p:x+y)
    #define dec(x,y) (x<y?x-y+p:x-y)
    using namespace std;
    
    int p;
    
    int rev[N],rt[N];
    int siz;
    
    #define md 998244353
    
    inline int power(int a,int t){
        int res = 1;
        while(t){
            if(t&1) res = (ll)res*a%md;
            a = (ll)a*a%md;
            t >>= 1; 
        }
        return res;
    }
    
    void init(int n){
        int w,lim = 1;
        while(lim<=n) lim <<= 1,++siz;
        for(reg int i=1;i!=lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1));
        rt[lim>>1] = 1;
        w = power(3,(md-1)>>siz);
        for(reg int i=(lim>>1)+1;i!=lim;++i) rt[i] = (ll)rt[i-1]*w%md;
        for(reg int i=(lim>>1)-1;i;--i) rt[i] = rt[i<<1];
    }
    
    inline void dft(int *f,int lim){
        static unsigned long long a[N];
        reg int x,shift = siz-__builtin_ctz(lim);
        for(reg int i=0;i!=lim;++i) a[rev[i]>>shift] = f[i];
        for(reg int mid=1;mid!=lim;mid<<=1)
        for(reg int j=0;j!=lim;j+=(mid<<1))
        for(reg int k=0;k!=mid;++k){
            x = a[j|k|mid]*rt[mid|k]%md;
            a[j|k|mid] = a[j|k]+md-x;
            a[j|k] += x;
        }
        for(reg int i=0;i!=lim;++i) f[i] = a[i]%md;
    }
    
    inline void idft(int *f,int lim){
        reverse(f+1,f+lim);
        dft(f,lim);
        int x = md-((md-1)>>__builtin_ctz(lim));
        for(reg int i=0;i!=lim;++i) f[i] = (ll)f[i]*x%md;
    }
    
    inline int getlen(int n){
        return 1<<(32-__builtin_clz(n));
    }
    
    void multiply(const int *A,const int *B,int n,int m,int *R,int len){
        static int f[N],g[N];
        memcpy(f,A,(n+1)<<2),memcpy(g,B,(m+1)<<2);
        int lim = getlen(n+m);
        memset(f+n+1,0,(lim-n)<<2);
        memset(g+m+1,0,(lim-m)<<2);
        dft(f,lim),dft(g,lim);
        for(reg int i=0;i!=lim;++i) f[i] = (ll)f[i]*g[i]%md;
        idft(f,lim);
        for(reg int i=0;i<=len;++i) R[i] = f[i]%p;
    }
    
    inline void inverse(const int *f,int n,int *R){
        static int g[N],h[N];
        memset(g,0,getlen(n<<1)<<2);
        int lim = 1,top = 0;
        int s[30];
        while(n){
            s[++top] = n;
            n >>= 1;
        }
        g[0] = 1;
        while(top--){
            n = s[top+1];
            while(lim<=(n<<1)) lim <<= 1;
            memcpy(h,f,(n+1)<<2);
            memset(h+n+1,0,(lim-n)<<2);
            multiply(h,g,n,n,h,n);
            multiply(h,g,n,n,h,n);
            for(reg int i=0;i<=n;++i) g[i] = dec(add(g[i],g[i]),h[i]);
        }
        memcpy(R,g,(n+1)<<2);
    }
    
    inline void power(int *f,int n,int k,int *R){
        int g[N];
        g[0] = 1;
        while(1){
            if(k&1) multiply(g,f,n,n,g,n);
            k >>= 1;
            if(k==0) break;
            multiply(f,f,n,n,f,n);
        }
        memcpy(R,g,(n+1)<<2);
    }
    
    int m,A,o,s,u;
    int f[N],g[N],h[N];
    
    int main(){
        scanf("%d%d%d%d%d%d",&m,&p,&A,&o,&s,&u);
        init(m<<1);
        for(reg int i=1;i<=m;++i) f[i] = (u+i*(s+o*i))%p;
        if(A<m) memcpy(g,f+1,(m-A)<<2);
        f[0] = 1;
        for(reg int i=1;i<=m;++i) f[i] = f[i]==0?0:p-f[i];
        inverse(f,m,f);
        if(A>=m){
            printf("%d",f[m]);
            return 0;
        }
        power(g,m-A-1,A+1,h);
        for(reg int i=m;i>A;--i) h[i] = h[i-A-1];
        for(reg int i=A+1;i<=m;++i) h[i] = h[i]==0?0:p-h[i];
        int ans = f[m];
        for(reg int i=A+1;i<=m;++i)
            if(h[i]!=0) ans = (ans+h[i]*f[m-i])%p;
        printf("%d",ans);    
        return 0;   
    }
    
    • 1

    信息

    ID
    4106
    时间
    1000ms
    内存
    250MiB
    难度
    7
    标签
    递交数
    0
    已通过
    0
    上传者