1 条题解

  • 0
    @ 2025-8-24 22:57:21

    自动搬运

    查看原文

    来自洛谷,原作者为

    avatar NaCly_Fish
    北海虽赊,扶摇可接。

    搬运于2025-08-24 22:57:21,当前版本为作者最后更新于2025-04-25 15:37:49,作者可能在搬运后再次修改,您可在原文处查看最新版

    自动搬运只会搬运当前题目点赞数最高的题解,您可前往洛谷题解查看更多

    以下是正文


    先来写个时间复杂度 Θ(nlog2n)\Theta(n \log^2 n) 的做法,进一步的优化待补充。

    首先我们有一个朴素的 DP 就是

    $$\begin{cases} f_{i,j}=g_{i-1,j-1}(m-j+1)+f_{i-1,j}(m-j) \\ g_{i,j}=(f_{i-1,j}+g_{i-1,j})j \end{cases} $$

    有初始值 f1,1=mf_{1,1}=m,最后的答案为 ign,i\sum_i g_{n,i}


    要考虑优化的话,比较直接的想法就是按 行/列 建立生成函数。如果要按行做的话,哈哈,那你就掉沟里了。

    这题比较好的做法是按列来做,设 Fj(x)F_j(x){fi,j}i0\{ f_{i,j} \}_{i \geq 0} 的生成函数,Gj(x)G_j(x) 同理,就能得到

    Fj(x)=(mj+1)xGj1(x)+(mj)xFj(x)F_j(x)=(m-j+1)x G_{j-1}(x)+(m-j)x F_j(x) Gj(x)=jx(Fj(x)+Gj(x))G_j(x)=jx (F_j(x)+G_j(x))

    最终答案只和 gg 有关,所以也只用关注 Gj(x)G_j(x) 的递推:

    $$G_j(x)=\frac{j(m-j+1)x^2}{(1-jx)(1-(m-j)x)}G_{j-1}(x) $$

    其中 G0(x)=1G_0(x)=1,而答案就是

    $$[x^n]\sum_{i \geq 1}G_i(x)=[x^n]\sum_{i = 1}^{\lfloor n/2 \rfloor} \prod_{j=1}^i \frac{j(m-j+1)}{(1-jx)(1-(m-j)x)}x^2 $$

    这个东西显然可以分治来计算。设 Pj=Gj(x)/Gj1(x)P_j=G_j(x)/G_{j-1}(x),简单来说我们维护

    S(l,r)=i=lrj=liPjS(l,r)=\sum_{i=l}^r \prod_{j=l}^i P_j

    然后就能根据下式来分治计算:

    $$S(l,r)=S(l,\text{mid})+S(\text{mid}+1,r)\prod_{i=l}^{\text{mid}}P_i $$

    注意将幂级数表示为分式的形式,这样分子和分母的度数都是 Θ(rl)\Theta(r-l) 的,时间复杂度也就是

    T(n)=2T(n/2)+Θ(nlogn)=Θ(nlog2n)T(n)=2T(n/2)+\Theta(n \log n)=\Theta(n \log^2 n)

    给个答案对 998244353\color{red}998244353 取模的代码,不想写任意模了,仅供参考。

    #include<cstdio>
    #include<iostream>
    #include<cstring>
    #include<vector>
    #include<algorithm>
    #define N 524292
    #define p 998244353
    #define ll long long
    using namespace std;
    
    inline int power(int a,int t){
        int res = 1;
        while(t){
            if(t&1) res = (ll)res*a%p;
            a = (ll)a*a%p;
            t >>= 1;
        }
        return res;
    }
    
    int siz;
    int rev[N],rt[N];
    
    void init(int n){
        int lim = 1;
        while(lim<=n) lim <<= 1,++siz;
        for(int i=0;i!=lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1));
        int w = power(3,(p-1)>>siz);
        rt[lim>>1] = 1;
        for(int i=(lim>>1)+1;i!=lim;++i) rt[i] = (ll)rt[i-1]*w%p;
        for(int i=(lim>>1)-1;i;--i) rt[i] = rt[i<<1];
    }
    
    inline void dft(int *f,int n){
        static unsigned long long a[N];
        int x,shift = siz-__builtin_ctz(n);
        for(int i=0;i!=n;++i) a[rev[i]>>shift] = f[i];
        for(int mid=1;mid!=n;mid<<=1)
        for(int j=0;j!=n;j+=(mid<<1))
        for(int k=0;k!=mid;++k){
            x = a[j|k|mid]*rt[mid|k]%p;
            a[j|k|mid] = a[j|k]+p-x;
            a[j|k] += x;
        }
        for(int i=0;i!=n;++i) f[i] = a[i]%p;
    }
    
    inline void idft(int *f,int n){
        reverse(f+1,f+n);
        dft(f,n);
        int x = p-(p-1)/n;
        for(int i=0;i!=n;++i) f[i] = (ll)f[i]*x%p;
    }
    
    inline int getlen(int n){
        return 1<<(32-__builtin_clz(n));
    }
    
    inline void _inv(const int *f,int n,int *r){
        static int g[N],h[N],st[30];
        memset(g,0,getlen(n<<1)<<2);
        int lim = 1,top = 0;
        while(n){
            st[++top] = n;
            n >>= 1;
        }
        g[0] = power(f[0],p-2);
        while(top--){
            n = st[top+1];
            while(lim<=(n<<1)) lim <<= 1;
            memcpy(h,f,(n+1)<<2);
            memset(h+n+1,0,(lim-n)<<2);
            dft(g,lim),dft(h,lim);
            for(int i=0;i!=lim;++i) g[i] = g[i]*(2-(ll)g[i]*h[i]%p+p)%p;
            idft(g,lim);
            memset(g+n+1,0,(lim-n)<<2);
        }
        memcpy(r,g,(n+1)<<2);
    }
    
    struct poly{
        vector<int> a;
        inline int operator [] (const int& x) const{ return x<a.size()?a[x]:0; }
        inline int& operator [] (const int& x){ return a[x]; }
        inline int deg() const{ return a.size()-1; }
    	inline void resize(int n){ a.resize(n+1); }
    
    	inline poly inverse(){
            static int f[N];
            int n = a.size()-1;
            for(int i=0;i<=n;++i) f[i] = a[i];
            _inv(f,n,f);
            poly res;
            res.resize(n);
            memcpy(res.a.begin().base(),f,(n+1)<<2);
            return res;
        }
    };
    inline bool operator < (const poly& f,const poly& g){ return f.deg() > g.deg(); }
    
    inline poly operator * (const poly& f,const poly& g){
        static int A[N],B[N];
        int n = f.deg(),m = g.deg();
    	poly res;
    	res.resize(n+m);
    	if(n<=4||m<=4){
    		for(int i=0;i<=n;++i)
    		for(int j=0;j<=m;++j)
    			res[i+j] = (res[i+j] + (ll)f[i]*g[j])%p;
    	}else{
    		memcpy(A,f.a.begin().base(),(n+1)<<2),memcpy(B,g.a.begin().base(),(m+1)<<2);
    		int lim = 1<<(32-__builtin_clz(n+m));
    		memset(A+n+1,0,(lim-n)<<2),memset(B+m+1,0,(lim-m)<<2);
    		dft(A,lim),dft(B,lim);
    		for(int i=0;i!=lim;++i) A[i] = (ll)A[i]*B[i]%p;
    		idft(A,lim);
    		memcpy(res.a.begin().base(),A,(n+m+1)<<2);
    	}
        return res;
    }
    
    inline poly operator + (const poly& f,const poly& g){
    	int n = max(f.deg(),g.deg());
    	poly res;
    	res.resize(n);
    	for(int i=0;i<=n;++i) res[i] = (f[i]+g[i])%p;
    	return res;
    }
    
    int pd[N];
    int n,m;
    
    void prod(int l,int r,int u){
    	if(l==r){
    		pd[u] = (ll)l*(m-l+1+p)%p;
    		return;
    	}
    	int mid = (l+r)/2;
    	prod(l,mid,u<<1);
    	prod(mid+1,r,u<<1|1);
    	pd[u] = (ll)pd[u<<1]*pd[u<<1|1]%p;
    }
    
    pair<poly,poly> solve(int l,int r,int u){
    	if(l==r){
    		poly P,Q;
    		P.resize(2), Q.resize(2);
    		P[0] = P[1] = 0, P[2] = (ll)l*(m-l+1+p)%p;
    		Q[0] = 1, Q[1] = p-m, Q[2] = (ll)l*(m-l+p)%p;
    		return make_pair(P,Q);
    	}
    	int mid = (l+r)/2;
    	pair<poly,poly> L = solve(l,mid,u<<1);
    	pair<poly,poly> R = solve(mid+1,r,u<<1|1);
    	L.first = L.first * R.second;
    	L.second = L.second * R.second;
    	for(int i=0;i<=R.first.deg();++i) R.first[i] = (ll)R.first[i]*pd[u<<1]%p;
    	int k = (mid-l+1)*2;
    	R.first.resize(R.first.deg() + k);
    	for(int i=R.first.deg();i>=k;--i) R.first[i] = R.first[i-k];
    	for(int i=k-1;i>=0;--i) R.first[i] = 0;
    	return make_pair(L.first + R.first, L.second);
    }
    
    int main(){
    	scanf("%d%d",&n,&m);
        m %= p;
    	init(n*2);
    	prod(1,n/2,1);
    	pair<poly,poly> res = solve(1,n/2,1);
    	poly f = res.first, g = res.second;
    	f = f * g.inverse();
    	printf("%d",f[n]);
    	return 0;
    }
    
    • 1

    [AHOI2024 初中组 / 科大国创杯初中组 2024] 计数

    信息

    ID
    10057
    时间
    1000ms
    内存
    512MiB
    难度
    6
    标签
    递交数
    0
    已通过
    0
    上传者