1 条题解

  • 0
    @ 2025-8-24 22:28:57

    自动搬运

    查看原文

    来自洛谷,原作者为

    avatar NaCly_Fish
    北海虽赊,扶摇可接。

    搬运于2025-08-24 22:28:57,当前版本为作者最后更新于2021-02-13 20:28:52,作者可能在搬运后再次修改,您可在原文处查看最新版

    自动搬运只会搬运当前题目点赞数最高的题解,您可前往洛谷题解查看更多

    以下是正文


    教程:如何正确地使用 OEIS

    首先当然是考虑每个数被算进答案多少次,设 Tn,k (k[0,n])T_{n,k} \ (k\in[0,n]) 是长为 n+1n+1 的序列中,第 k+1k+1 个数被算的次数。

    随便打个暴力,在 OEIS 上可以找到这个数列,然后眼尖的你一眼就看到了递推式:

    $$T_{n,k}=2(T_{n-1,k-1}+T_{n-1,k})-(2-[n=2])T_{n-2,k-1}\ (n\ge 2,k\ge1) $$

    上面没给每行的生成函数,只好我们自己算了,,

    这个 [n=2][n=2] 的条件有点麻烦,如果假定 T0,0=12T_{0,0} = \dfrac12 就简单很多。设 Tn(x)T_n(x){Tn,k}k=0n\left\{ T_{n,k}\right\}_{k=0}^n 的生成函数,就有:

    Tn(x)=2(1+x)Tn1(x)2xTn2(x)T_n(x)=2(1+x)T_{n-1}(x)-2xT_{n-2}(x)

    按二阶递推数列解出来就是

    $$T_n(x)=\frac{(1+x+\sqrt{x^2+1})^{n+1}-(1+x-\sqrt{x^2+1})^{n+1}}{4\sqrt{1+x^2}} $$

    根据数列 aa 的构造方法,我们知道答案是

    i=0n1bimodkTn1,i\sum_{i=0}^{n-1}b_{i \bmod k}T_{n-1,i}

    而这实际上就是 Tn1(x)mod(xk1)T_{n-1}(x) \bmod (x^k-1) 与序列 bb 的点积。要计算的话,就是维护一个 F(x)=f(x)+1+x2g(x)F(x)=f(x)+\sqrt{1+x^2}g(x) 形式的多项式;而最后一步除 1+x2\sqrt{1+x^2} 实际上就是提取其中 g(x)g(x) 的系数。

    顺便一提,容易证明 (1+x+x2+1)n(1+x+\sqrt{x^2+1})^n(1+xx2+1)n(1+x-\sqrt{x^2+1})^n1+x2\sqrt{1+x^2} 部分是互为相反数的,故可以只算一遍快速幂,常数减半。

    代码奉上:

    #include<cstdio>
    #include<iostream>
    #include<algorithm>
    #include<cstring>
    #include<cmath>
    #define N 262147
    #define ll long long
    #define reg register
    #define p 998244353
    using namespace std;
    
    inline int read(){
    	int x = 0;
    	char c = getchar();
    	while(c<'0'||c>'9') c = getchar();
    	while(c>='0'&&c<='9'){
    		x = (x<<3)+(x<<1)+(c^48);
    		c = getchar();
    	}
    	return x;
    }
    
    inline int add(const int& x,const int& y){ return x+y>=p?x+y-p:x+y; }
    inline int dec(const int& x,const int& y){ return x<y?x-y+p:x-y; }
    inline int rec(const int& x){ return x>=p?x-p:x; }
    
    inline int power(int a,int t){
        int res = 1;
        while(t){
            if(t&1) res = (ll)res*a%p;
            a = (ll)a*a%p;
            t >>= 1;
        }
        return res;
    }
    
    int siz;
    int rev[N],rt[N],inv[N],fac[N],ifac[N];
    
    void init(int n){
        int lim = 1;
        while(lim<=n) lim <<= 1,++siz;
        for(reg int i=0;i!=lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1));
        int w = power(3,(p-1)>>siz);
        fac[0] = fac[1] = ifac[0] = ifac[1] = inv[1] = rt[lim>>1] = 1;
        for(reg int i=(lim>>1)+1;i!=lim;++i) rt[i] = (ll)rt[i-1]*w%p;
        for(reg int i=(lim>>1)-1;i;--i) rt[i] = rt[i<<1];
        for(reg int i=2;i<=n;++i) fac[i] = (ll)fac[i-1]*i%p;
        ifac[n] = power(fac[n],p-2);
        for(reg int i=n-1;i;--i) ifac[i] = (ll)ifac[i+1]*(i+1)%p;
        for(reg int i=2;i<=n;++i) inv[i] = (ll)fac[i-1]*ifac[i]%p;
    }
    
    inline void dft(int *f,int n){
        static unsigned long long a[N];
        reg int x,shift = siz-__builtin_ctz(n);
        for(reg int i=0;i!=n;++i) a[rev[i]>>shift] = f[i];
        for(reg int mid=1;mid!=n;mid<<=1)
        for(reg int j=0;j!=n;j+=(mid<<1))
        for(reg int k=0;k!=mid;++k){
            x = a[j|k|mid]*rt[mid|k]%p;
            a[j|k|mid] = a[j|k]+p-x;
            a[j|k] += x;
        }
        for(reg int i=0;i!=n;++i) f[i] = a[i]%p;
    }
    
    inline void idft(int *f,int n){
        reverse(f+1,f+n);
        dft(f,n);
        int x = p-((p-1)>>__builtin_ctz(n));
        for(reg int i=0;i!=n;++i) f[i] = (ll)f[i]*x%p;
    }
    
    inline int getlen(int n){
        return 1<<(32-__builtin_clz(n));
    }
    
    struct poly{
    	int a[N],b[N];
    	int t;
    	
    	inline poly(int _t=0):t(_t){
    		memset(a,0,sizeof(a));
    		memset(b,0,sizeof(b));
    	}
    };
    
    inline void getv(int lim,int *v){
    	int rt = power(3,(p-1)/lim),x = 1;
    	for(reg int i=0;i!=lim;++i){
    		v[i] = ((ll)x*x+1)%p;
    		x = (ll)x*rt%p;
    	}
    }
    
    inline poly multiply(poly f,poly g){
    	static int A[N],B[N],C[N],D[N],R[N],S[N],v[N];
    	int n = f.t,lim = getlen(f.t<<1);
    	getv(lim,v); // 这里是线性计算 (1+x^2) 的 dft,以减少此后的一次 idft
    	poly h = poly(n);
    	memcpy(A,f.a,n<<2),memcpy(B,f.b,n<<2);
    	memcpy(C,g.a,n<<2),memcpy(D,g.b,n<<2);
    	memset(A+n,0,(lim-n+2)<<2),memset(B+n,0,(lim-n+2)<<2);
    	memset(C+n,0,(lim-n+2)<<2),memset(D+n,0,(lim-n+2)<<2);
    	dft(A,lim),dft(B,lim),dft(C,lim),dft(D,lim);
    	for(reg int i=0;i!=lim;++i){
    		R[i] = ((ll)A[i]*C[i]+(ll)B[i]*D[i]%p*v[i])%p;
    		S[i] = ((ll)A[i]*D[i]+(ll)B[i]*C[i])%p;
    	}
    	idft(R,lim),idft(S,lim);
    	for(reg int i=0;i!=n;++i){
    		h.a[i] = add(R[i],R[i+n]);
    		h.b[i] = add(S[i],S[i+n]);
    	}
    	h.a[0] = add(h.a[0],R[n<<1]);
    	h.a[1] = add(h.a[1],R[n<<1|1]);
    	return h;
    }
    
    inline poly square(poly f){ // 算平方的时候可以减少 dft 次数
    	static int A[N],B[N],R[N],S[N],v[N];
    	int n = f.t,lim = getlen(f.t<<1);
    	getv(lim,v);
    	poly h = poly(n);
    	memcpy(A,f.a,n<<2),memcpy(B,f.b,n<<2);
    	memset(A+n,0,(lim-n+2)<<2),memset(B+n,0,(lim-n+2)<<2);
    	dft(A,lim),dft(B,lim);
    	for(reg int i=0;i!=lim;++i){
    		R[i] = ((ll)A[i]*A[i]+(ll)B[i]*B[i]%p*v[i])%p;
    		S[i] = ((ll)A[i]*B[i]<<1)%p;
    	}
    	idft(R,lim),idft(S,lim);
    	for(reg int i=0;i!=n;++i){
    		h.a[i] = add(R[i],R[i+n]);
    		h.b[i] = add(S[i],S[i+n]);
    	}
    	h.a[0] = add(h.a[0],R[n<<1]);
    	h.a[1] = add(h.a[1],R[n<<1|1]);
    	return h;
    }
    
    inline poly power(poly f,int n,int t){
    	poly g = poly(n);
    	g.a[0] = 1;
    	while(t){
    		if(t&1) g = multiply(f,g);
    		t >>= 1;
    		if(t==0) break;
    		f = square(f);
    	}
    	return g;
    }
    
    int n,k,ans;
    poly f;
    int a[N];
    
    int main(){
        n = read(),k = read();
        init(k<<1);
        f.t = k;
        f.a[0] = f.a[1] = f.b[0] = 1;
        f = power(f,k,n);
        for(reg int i=0;i!=k;++i) ans = (ans+(ll)f.b[i]*read())%p;
        printf("%lld",(ll)ans*power(2,p-2)%p);
        return 0;   
    }
    
    • 1

    信息

    ID
    6012
    时间
    2000ms
    内存
    500MiB
    难度
    7
    标签
    递交数
    0
    已通过
    0
    上传者