1 条题解

  • 0
    @ 2025-8-24 22:02:48

    自动搬运

    查看原文

    来自洛谷,原作者为

    avatar Jacob233
    **

    搬运于2025-08-24 22:02:48,当前版本为作者最后更新于2018-12-27 21:26:16,作者可能在搬运后再次修改,您可在原文处查看最新版

    自动搬运只会搬运当前题目点赞数最高的题解,您可前往洛谷题解查看更多

    以下是正文


    前置技能:NTT,多项式求逆,多项式对数函数,多项式求导。

    首先我们按题意写出答案的式子:$$ans_k=\frac{\sum_{i = 1}^n\sum_{j = 1}^m(a_i+b_j)^k}{nm}$$

    先不管分母,把分子用二项式定理展开:

    $$ans_k=\sum_{r = 0}^k\sum_{i = 1}^n\sum_{j = 1}^m\binom r ka_i^rb_j^{k-r} $$$$ans_k=k!\times\sum_{r = 0}^k(\sum_{i = 1}^n\frac{a_i^r}{r!})(\sum_j^m\frac{b_j^{k-r}}{(k-r)!}) $$

    这样子就写成了卷积的形式,那么我们只要能快速求出k=1nak\sum_{k=1}^na^k,就可以用NTT计算卷积了。

    我们写出这个东西的生成函数:$$1+a^1x+a^2x^2+...+a^\infty x^\infty$$

    xkx^k的系数表示kk次幂和,那么我们用等比数列求和公式解出来可以得到:$$\frac{1}{1-ax}$$

    f(x)f(x)为这些生成函数的和,那么:

    f(x)=i=1n11aixf(x) = \sum_{i = 1}^n\frac{1}{1-a_ix}

    这个东西不好算,我们发现$$\ln'(1-a_ix)=\displaystyle\frac{1}{1-a_ix}$$

    我们从对数函数角度考虑,又可以发现:

    (ln(1aix))=ai1aix(\ln(1-a_ix))'=\frac{-a_i}{1-a_ix}

    设$g(x)=\displaystyle\sum_{i = 1}^n\displaystyle\frac{-a_i}{1-a_ix}$,那么f(x)=x×g(x)+nf(x)=-x\times g(x)+n

    g(x)g(x)也很好算,化一下式子就变成了:

    g(x)=i=1n(ln(1aix))g(x)=\sum_{i = 1}^n(\ln(1-a_ix))' =(ln(i=1n(1aix)))=(\ln(\prod_{i = 1}^n(1-a_ix)))'

    这样子gg就可以用分治+NTT算,算出gg后再推出ffffxix_i的系数就是j=1naji\displaystyle\sum_{j = 1}^na_j^i,那么我们代回原式再做一遍卷积就好了,总的复杂度为O(nlog2n)O(n\log^2n)

    贴一下代码:

    #include <bits/stdc++.h>
    
    using namespace std;
    
    const int N = 1 << 19 | 1; 
    const int M = log2(N) + 3; 
    const int mod = 998244353;
    
    int n, m, t, cnt = -1, tp[M << 1][N];
    int a[N], b[N], ans[N], Sa[N], Sb[N];
    int rev[N], A[N], B[N], fac[N], ifac[N];
    
    inline int inv(int x) { return 1ll * ifac[x] * fac[x - 1] % mod; }
    inline int add(int x, int y) { return (x += y) < mod ? x : x - mod; }
    
    inline int qpow(int _, int __) {
    	int ___ = 1; 
    	for (; __; _ = 1ll * _ * _ % mod, __ >>= 1) 
    		if (__ & 1) ___ = 1ll * ___ * _ % mod;
    	return ___;
    }
    
    inline void Math_Init(int n) {
    	fac[0] = ifac[0] = 1; 
    	for (int i = 1; i <= n; ++ i) 
    		fac[i] = 1ll * fac[i - 1] * i % mod;
    	ifac[n] = qpow(fac[n], mod - 2);
    	for (int i = n; i; -- i) 
    		ifac[i - 1] = 1ll * ifac[i] * i % mod;
    }
    
    inline void NTT(int *a, int n, int fh) {
    	for (int i = 0; i < n; ++ i) 
    		if (i < rev[i]) swap(a[i], a[rev[i]]);;
    	for (int Wn, limit = 2; limit <= n; limit <<= 1) {
    		Wn = qpow(fh ^ 1 ? qpow(3, mod - 2) : 3, (mod - 1) / limit);
    		for (int W = 1, j = 0; j < n; j += limit, W = 1) 
    			for (int i = j; i < j + (limit >> 1); ++ i, W = 1ll * W * Wn % mod) {
    				int a1 = a[i], a2 = 1ll * W * a[i + (limit >> 1)] % mod;
    				a[i] = add(a1, a2), a[i + (limit >> 1)] = add(a1, mod - a2);
    			}
    	}
    	if (fh ^ 1) for (int i = 0; i < n; ++ i) 
    		a[i] = 1ll * a[i] * inv(n) % mod;
    }
    
    inline void Invpoly(int *a, int *b, int len) {
    	int limit = 1, k = 0;
    	if (len ^ 1) {
    		Invpoly(a, b, len >> 1);
    		while (limit < len * 2) limit <<= 1, ++ k;
    		for (int i = 0; i < limit; ++ i) {
    			A[i] = B[i] = 0;
    			rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));
    		}
    		for (int i = 0; i < len; ++ i) 
    			A[i] = a[i], B[i] = b[i];
    		NTT(A, limit, 1), NTT(B, limit, 1);
    		for (int i = 0; i < limit; ++ i) 
    			A[i] = 1ll * A[i] * B[i] % mod * B[i] % mod;
    		NTT(A, limit, -1);
    		for (int i = 0; i < len; ++ i) 
    			b[i] = add(b[i], add(b[i], mod - A[i]));
    	}	
    	else b[0] = qpow(a[0], mod - 2);
    }
    
    inline void Derpoly(int *a, int len) {
    	for (int i = 0; i < len - 1; ++ i) 
    		a[i] = 1ll * (i + 1) * a[i + 1] % mod;
    	a[len - 1] = 0;
    }
    
    inline void Solve(int l, int r, int *a, int *b) {
    	if (l == r) return (void) (a[0] = 1, a[1] = mod - b[l]); 
    	int mid = (l + r) >> 1, *a1 = tp[++ cnt], *a2 = tp[++ cnt], limit = 1, k = 0; 
    	Solve(l, mid, a1, b), Solve(mid + 1, r, a2, b);
    	while (limit <= r - l + 1) limit <<= 1, ++ k;
    	for (int i = 0; i < limit; ++ i) 
    		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));
    	NTT(a1, limit, 1), NTT(a2, limit, 1);
    	for (int i = 0; i < limit; ++ i) 
    		a[i] = 1ll * a1[i] * a2[i] % mod, a1[i] = a2[i] = 0;
    	NTT(a, limit, -1), cnt -= 2;
    }
    
    inline void Get_S(int *a, int *f, int n) {
    	int invf[N] = {0}, limit = 1, k = 0;
    	while (limit < max(n, t) * 2) limit <<= 1, ++ k;
    	Solve(1, n, f, a), Invpoly(f, invf, limit >> 1), Derpoly(f, limit);
    	for (int i = 0; i < limit; ++ i) 
    		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));
    	NTT(f, limit, 1), NTT(invf, limit, 1);
    	for (int i = 0; i < limit; ++ i) 
    		f[i] = 1ll * f[i] * invf[i] % mod;
    	NTT(f, limit, -1);
    	for (int i = limit - 2; ~i; -- i) 
    		f[i + 1] = 1ll * f[i] * (mod - 1) % mod;
    	f[0] = n;
    	for (int i = 0; i < limit; ++ i) {
    		if (i > t) f[i] = 0;
    		f[i] = 1ll * f[i] * ifac[i] % mod;
    	}
    }
    
    int main() {
    	int limit = 1, k = 0;
    
    	Math_Init(N - 5), scanf("%d%d", &n, &m);
    	for (int i = 1; i <= n; ++ i) scanf("%d", &a[i]);
    	for (int i = 1; i <= m; ++ i) scanf("%d", &b[i]);
    	scanf("%d", &t), Get_S(a, Sa, n), Get_S(b, Sb, m);
    
    	while (limit <= t * 2) limit <<= 1, ++ k;
    	for (int i = 0; i < limit; ++ i) 
    		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1));
    	NTT(Sa, limit, 1), NTT(Sb, limit, 1);
    	for (int i = 0; i < limit; ++ i) 
    		Sa[i] = 1ll * Sa[i] * Sb[i] % mod;
    	NTT(Sa, limit, -1);
    
    	for (int i = 1; i <= t; ++ i) 
    		printf("%lld\n", 1ll * Sa[i] * fac[i] % mod * inv(n) % mod * inv(m) % mod);
    
    	return 0;
    }
    
    
    • 1

    信息

    ID
    3580
    时间
    3000ms
    内存
    500MiB
    难度
    7
    标签
    递交数
    0
    已通过
    0
    上传者