1 条题解
-
0
自动搬运
来自洛谷,原作者为

NaCly_Fish
北海虽赊,扶摇可接。搬运于
2025-08-24 22:28:57,当前版本为作者最后更新于2021-02-13 20:28:52,作者可能在搬运后再次修改,您可在原文处查看最新版自动搬运只会搬运当前题目点赞数最高的题解,您可前往洛谷题解查看更多
以下是正文
教程:如何正确地使用 OEIS首先当然是考虑每个数被算进答案多少次,设 是长为 的序列中,第 个数被算的次数。
随便打个暴力,在 OEIS 上可以找到这个数列,然后眼尖的你一眼就看到了递推式:
$$T_{n,k}=2(T_{n-1,k-1}+T_{n-1,k})-(2-[n=2])T_{n-2,k-1}\ (n\ge 2,k\ge1) $$上面没给每行的生成函数,只好我们自己算了,,
这个 的条件有点麻烦,如果假定 就简单很多。设 为 的生成函数,就有:
按二阶递推数列解出来就是
$$T_n(x)=\frac{(1+x+\sqrt{x^2+1})^{n+1}-(1+x-\sqrt{x^2+1})^{n+1}}{4\sqrt{1+x^2}} $$根据数列 的构造方法,我们知道答案是
而这实际上就是 与序列 的点积。要计算的话,就是维护一个 形式的多项式;而最后一步除 实际上就是提取其中 的系数。
顺便一提,容易证明 和 的 部分是互为相反数的,故可以只算一遍快速幂,常数减半。
代码奉上:
#include<cstdio> #include<iostream> #include<algorithm> #include<cstring> #include<cmath> #define N 262147 #define ll long long #define reg register #define p 998244353 using namespace std; inline int read(){ int x = 0; char c = getchar(); while(c<'0'||c>'9') c = getchar(); while(c>='0'&&c<='9'){ x = (x<<3)+(x<<1)+(c^48); c = getchar(); } return x; } inline int add(const int& x,const int& y){ return x+y>=p?x+y-p:x+y; } inline int dec(const int& x,const int& y){ return x<y?x-y+p:x-y; } inline int rec(const int& x){ return x>=p?x-p:x; } inline int power(int a,int t){ int res = 1; while(t){ if(t&1) res = (ll)res*a%p; a = (ll)a*a%p; t >>= 1; } return res; } int siz; int rev[N],rt[N],inv[N],fac[N],ifac[N]; void init(int n){ int lim = 1; while(lim<=n) lim <<= 1,++siz; for(reg int i=0;i!=lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1)); int w = power(3,(p-1)>>siz); fac[0] = fac[1] = ifac[0] = ifac[1] = inv[1] = rt[lim>>1] = 1; for(reg int i=(lim>>1)+1;i!=lim;++i) rt[i] = (ll)rt[i-1]*w%p; for(reg int i=(lim>>1)-1;i;--i) rt[i] = rt[i<<1]; for(reg int i=2;i<=n;++i) fac[i] = (ll)fac[i-1]*i%p; ifac[n] = power(fac[n],p-2); for(reg int i=n-1;i;--i) ifac[i] = (ll)ifac[i+1]*(i+1)%p; for(reg int i=2;i<=n;++i) inv[i] = (ll)fac[i-1]*ifac[i]%p; } inline void dft(int *f,int n){ static unsigned long long a[N]; reg int x,shift = siz-__builtin_ctz(n); for(reg int i=0;i!=n;++i) a[rev[i]>>shift] = f[i]; for(reg int mid=1;mid!=n;mid<<=1) for(reg int j=0;j!=n;j+=(mid<<1)) for(reg int k=0;k!=mid;++k){ x = a[j|k|mid]*rt[mid|k]%p; a[j|k|mid] = a[j|k]+p-x; a[j|k] += x; } for(reg int i=0;i!=n;++i) f[i] = a[i]%p; } inline void idft(int *f,int n){ reverse(f+1,f+n); dft(f,n); int x = p-((p-1)>>__builtin_ctz(n)); for(reg int i=0;i!=n;++i) f[i] = (ll)f[i]*x%p; } inline int getlen(int n){ return 1<<(32-__builtin_clz(n)); } struct poly{ int a[N],b[N]; int t; inline poly(int _t=0):t(_t){ memset(a,0,sizeof(a)); memset(b,0,sizeof(b)); } }; inline void getv(int lim,int *v){ int rt = power(3,(p-1)/lim),x = 1; for(reg int i=0;i!=lim;++i){ v[i] = ((ll)x*x+1)%p; x = (ll)x*rt%p; } } inline poly multiply(poly f,poly g){ static int A[N],B[N],C[N],D[N],R[N],S[N],v[N]; int n = f.t,lim = getlen(f.t<<1); getv(lim,v); // 这里是线性计算 (1+x^2) 的 dft,以减少此后的一次 idft poly h = poly(n); memcpy(A,f.a,n<<2),memcpy(B,f.b,n<<2); memcpy(C,g.a,n<<2),memcpy(D,g.b,n<<2); memset(A+n,0,(lim-n+2)<<2),memset(B+n,0,(lim-n+2)<<2); memset(C+n,0,(lim-n+2)<<2),memset(D+n,0,(lim-n+2)<<2); dft(A,lim),dft(B,lim),dft(C,lim),dft(D,lim); for(reg int i=0;i!=lim;++i){ R[i] = ((ll)A[i]*C[i]+(ll)B[i]*D[i]%p*v[i])%p; S[i] = ((ll)A[i]*D[i]+(ll)B[i]*C[i])%p; } idft(R,lim),idft(S,lim); for(reg int i=0;i!=n;++i){ h.a[i] = add(R[i],R[i+n]); h.b[i] = add(S[i],S[i+n]); } h.a[0] = add(h.a[0],R[n<<1]); h.a[1] = add(h.a[1],R[n<<1|1]); return h; } inline poly square(poly f){ // 算平方的时候可以减少 dft 次数 static int A[N],B[N],R[N],S[N],v[N]; int n = f.t,lim = getlen(f.t<<1); getv(lim,v); poly h = poly(n); memcpy(A,f.a,n<<2),memcpy(B,f.b,n<<2); memset(A+n,0,(lim-n+2)<<2),memset(B+n,0,(lim-n+2)<<2); dft(A,lim),dft(B,lim); for(reg int i=0;i!=lim;++i){ R[i] = ((ll)A[i]*A[i]+(ll)B[i]*B[i]%p*v[i])%p; S[i] = ((ll)A[i]*B[i]<<1)%p; } idft(R,lim),idft(S,lim); for(reg int i=0;i!=n;++i){ h.a[i] = add(R[i],R[i+n]); h.b[i] = add(S[i],S[i+n]); } h.a[0] = add(h.a[0],R[n<<1]); h.a[1] = add(h.a[1],R[n<<1|1]); return h; } inline poly power(poly f,int n,int t){ poly g = poly(n); g.a[0] = 1; while(t){ if(t&1) g = multiply(f,g); t >>= 1; if(t==0) break; f = square(f); } return g; } int n,k,ans; poly f; int a[N]; int main(){ n = read(),k = read(); init(k<<1); f.t = k; f.a[0] = f.a[1] = f.b[0] = 1; f = power(f,k,n); for(reg int i=0;i!=k;++i) ans = (ans+(ll)f.b[i]*read())%p; printf("%lld",(ll)ans*power(2,p-2)%p); return 0; }
- 1
信息
- ID
- 6012
- 时间
- 2000ms
- 内存
- 500MiB
- 难度
- 7
- 标签
- 递交数
- 0
- 已通过
- 0
- 上传者