1 条题解
-
0
自动搬运
来自洛谷,原作者为

NaCly_Fish
北海虽赊,扶摇可接。搬运于
2025-08-24 22:57:21,当前版本为作者最后更新于2025-04-25 15:37:49,作者可能在搬运后再次修改,您可在原文处查看最新版自动搬运只会搬运当前题目点赞数最高的题解,您可前往洛谷题解查看更多
以下是正文
先来写个时间复杂度 的做法,进一步的优化待补充。
首先我们有一个朴素的 DP 就是
$$\begin{cases} f_{i,j}=g_{i-1,j-1}(m-j+1)+f_{i-1,j}(m-j) \\ g_{i,j}=(f_{i-1,j}+g_{i-1,j})j \end{cases} $$有初始值 ,最后的答案为 。
要考虑优化的话,比较直接的想法就是按 行/列 建立生成函数。如果要按行做的话,哈哈,那你就掉沟里了。
这题比较好的做法是按列来做,设 是 的生成函数, 同理,就能得到
最终答案只和 有关,所以也只用关注 的递推:
$$G_j(x)=\frac{j(m-j+1)x^2}{(1-jx)(1-(m-j)x)}G_{j-1}(x) $$其中 ,而答案就是
$$[x^n]\sum_{i \geq 1}G_i(x)=[x^n]\sum_{i = 1}^{\lfloor n/2 \rfloor} \prod_{j=1}^i \frac{j(m-j+1)}{(1-jx)(1-(m-j)x)}x^2 $$这个东西显然可以分治来计算。设 ,简单来说我们维护
然后就能根据下式来分治计算:
$$S(l,r)=S(l,\text{mid})+S(\text{mid}+1,r)\prod_{i=l}^{\text{mid}}P_i $$注意将幂级数表示为分式的形式,这样分子和分母的度数都是 的,时间复杂度也就是
给个答案对 取模的代码,不想写任意模了,仅供参考。
#include<cstdio> #include<iostream> #include<cstring> #include<vector> #include<algorithm> #define N 524292 #define p 998244353 #define ll long long using namespace std; inline int power(int a,int t){ int res = 1; while(t){ if(t&1) res = (ll)res*a%p; a = (ll)a*a%p; t >>= 1; } return res; } int siz; int rev[N],rt[N]; void init(int n){ int lim = 1; while(lim<=n) lim <<= 1,++siz; for(int i=0;i!=lim;++i) rev[i] = (rev[i>>1]>>1)|((i&1)<<(siz-1)); int w = power(3,(p-1)>>siz); rt[lim>>1] = 1; for(int i=(lim>>1)+1;i!=lim;++i) rt[i] = (ll)rt[i-1]*w%p; for(int i=(lim>>1)-1;i;--i) rt[i] = rt[i<<1]; } inline void dft(int *f,int n){ static unsigned long long a[N]; int x,shift = siz-__builtin_ctz(n); for(int i=0;i!=n;++i) a[rev[i]>>shift] = f[i]; for(int mid=1;mid!=n;mid<<=1) for(int j=0;j!=n;j+=(mid<<1)) for(int k=0;k!=mid;++k){ x = a[j|k|mid]*rt[mid|k]%p; a[j|k|mid] = a[j|k]+p-x; a[j|k] += x; } for(int i=0;i!=n;++i) f[i] = a[i]%p; } inline void idft(int *f,int n){ reverse(f+1,f+n); dft(f,n); int x = p-(p-1)/n; for(int i=0;i!=n;++i) f[i] = (ll)f[i]*x%p; } inline int getlen(int n){ return 1<<(32-__builtin_clz(n)); } inline void _inv(const int *f,int n,int *r){ static int g[N],h[N],st[30]; memset(g,0,getlen(n<<1)<<2); int lim = 1,top = 0; while(n){ st[++top] = n; n >>= 1; } g[0] = power(f[0],p-2); while(top--){ n = st[top+1]; while(lim<=(n<<1)) lim <<= 1; memcpy(h,f,(n+1)<<2); memset(h+n+1,0,(lim-n)<<2); dft(g,lim),dft(h,lim); for(int i=0;i!=lim;++i) g[i] = g[i]*(2-(ll)g[i]*h[i]%p+p)%p; idft(g,lim); memset(g+n+1,0,(lim-n)<<2); } memcpy(r,g,(n+1)<<2); } struct poly{ vector<int> a; inline int operator [] (const int& x) const{ return x<a.size()?a[x]:0; } inline int& operator [] (const int& x){ return a[x]; } inline int deg() const{ return a.size()-1; } inline void resize(int n){ a.resize(n+1); } inline poly inverse(){ static int f[N]; int n = a.size()-1; for(int i=0;i<=n;++i) f[i] = a[i]; _inv(f,n,f); poly res; res.resize(n); memcpy(res.a.begin().base(),f,(n+1)<<2); return res; } }; inline bool operator < (const poly& f,const poly& g){ return f.deg() > g.deg(); } inline poly operator * (const poly& f,const poly& g){ static int A[N],B[N]; int n = f.deg(),m = g.deg(); poly res; res.resize(n+m); if(n<=4||m<=4){ for(int i=0;i<=n;++i) for(int j=0;j<=m;++j) res[i+j] = (res[i+j] + (ll)f[i]*g[j])%p; }else{ memcpy(A,f.a.begin().base(),(n+1)<<2),memcpy(B,g.a.begin().base(),(m+1)<<2); int lim = 1<<(32-__builtin_clz(n+m)); memset(A+n+1,0,(lim-n)<<2),memset(B+m+1,0,(lim-m)<<2); dft(A,lim),dft(B,lim); for(int i=0;i!=lim;++i) A[i] = (ll)A[i]*B[i]%p; idft(A,lim); memcpy(res.a.begin().base(),A,(n+m+1)<<2); } return res; } inline poly operator + (const poly& f,const poly& g){ int n = max(f.deg(),g.deg()); poly res; res.resize(n); for(int i=0;i<=n;++i) res[i] = (f[i]+g[i])%p; return res; } int pd[N]; int n,m; void prod(int l,int r,int u){ if(l==r){ pd[u] = (ll)l*(m-l+1+p)%p; return; } int mid = (l+r)/2; prod(l,mid,u<<1); prod(mid+1,r,u<<1|1); pd[u] = (ll)pd[u<<1]*pd[u<<1|1]%p; } pair<poly,poly> solve(int l,int r,int u){ if(l==r){ poly P,Q; P.resize(2), Q.resize(2); P[0] = P[1] = 0, P[2] = (ll)l*(m-l+1+p)%p; Q[0] = 1, Q[1] = p-m, Q[2] = (ll)l*(m-l+p)%p; return make_pair(P,Q); } int mid = (l+r)/2; pair<poly,poly> L = solve(l,mid,u<<1); pair<poly,poly> R = solve(mid+1,r,u<<1|1); L.first = L.first * R.second; L.second = L.second * R.second; for(int i=0;i<=R.first.deg();++i) R.first[i] = (ll)R.first[i]*pd[u<<1]%p; int k = (mid-l+1)*2; R.first.resize(R.first.deg() + k); for(int i=R.first.deg();i>=k;--i) R.first[i] = R.first[i-k]; for(int i=k-1;i>=0;--i) R.first[i] = 0; return make_pair(L.first + R.first, L.second); } int main(){ scanf("%d%d",&n,&m); m %= p; init(n*2); prod(1,n/2,1); pair<poly,poly> res = solve(1,n/2,1); poly f = res.first, g = res.second; f = f * g.inverse(); printf("%d",f[n]); return 0; }
- 1
信息
- ID
- 10057
- 时间
- 1000ms
- 内存
- 512MiB
- 难度
- 6
- 标签
- 递交数
- 0
- 已通过
- 0
- 上传者