1 条题解

  • 0
    @ 2025-8-24 22:30:26

    自动搬运

    查看原文

    来自洛谷,原作者为

    avatar littleKtian
    MoVe yoUR BoDy

    搬运于2025-08-24 22:30:26,当前版本为作者最后更新于2021-02-05 14:25:18,作者可能在搬运后再次修改,您可在原文处查看最新版

    自动搬运只会搬运当前题目点赞数最高的题解,您可前往洛谷题解查看更多

    以下是正文


    显然我们并不需要具体是哪个磁极,只需要只要磁极是否相同。

    yy 为树的根,设 fi,0/1f_{i,0/1} 表示从 ii 走到 yy,磁极相同/不同时的期望步数。

    有 $\begin{cases}f_{i,0}=pf_{fa_i,1}+(1-p)(f_{i,0}+1)+1,f_{i,1}=(1-p)f_{fa_i,1}+p(f_{i,0}+1)+1&i\text{是叶子节点}\\f_{i,0}=pf_{fa_i,1}+(1-p)\sum\dfrac{f_{j,0}}{so_i}+1,f_{i,1}=(1-p)f_{fa_i,1}+p\sum\dfrac{f_{j,0}}{so_i}+1&i\text{不是叶子节点}\end{cases}$,其中 jjii 的儿子,faifa_iii 的父亲,soiso_iii 的儿子节点个数。特别的,fy,0=fy,1=0f_{y,0}=f_{y,1}=0

    Subtask 1

    直接写式子然后高斯消元。

    Subtask 2&3

    (我感觉这两部分做法是一样的)

    考虑人脑消元。

    对叶子节点的式子变形,有 $f_{i,0}=f_{fa_i,1}+\dfrac{2-p}{p},f_{i,1}=f_{fa_i,1}+3$。

    发现递推式中只涉及到 ffai,1f_{fa_i,1} 以及常数项,考虑归纳证明每个节点的递推式只与其父亲有关。

    对于非叶子节点 ii,假设其后代节点均满足此规律。设 fj,0soi=k1fi,1+k2\sum\dfrac{f_{j,0}}{so_i}=k_1f_{i,1}+k_2

    有 $f_{i,0}=pf_{fa_i,1}+(1-p)(k_1f_{i,1}+k_2)+1,f_{i,1}=(1-p)f_{fa_i,1}+p(k_1f_{i,1}+k_2)+1$。

    变形,得 $f_{i,0}=\dfrac{p+k_1-2pk_1}{1-pk_1}f_{fa_i,1}+\dfrac{k_1+k_2-2pk_1-pk_2+1}{1-pk_1},f_{i,1}=\dfrac{1-p}{1-pk_1}f_{fa_i,1}+\dfrac{1+pk_2}{1-pk_1}$,得证。

    因此我们只需要先进行一次树形dp求出递推式的系数和常数,然后再递推得出结果。

    Subtask 2 直接以每个点为根的情况各进行一次dp,Subtask 3 直接以 11 为根进行dp即可。

    Subtask 4

    容易发现每个点的递推式实际上只和根的方向有关。

    所以可以用换根dp求出所有的递推式再利用一些奇怪的操作来回答每个询问。

    这样做还是很麻烦,所以我们考虑再简化递推式。

    实际上把每个点的递推式的系数打表打出来会发现都为 11

    同样考虑归纳证明容易证得,于是得到 k1=1k_1=1

    于是有递推式 $f_{i,0}=f_{fa_i,1}+k_2+2,f_{i,1}=f_{fa_i,1}+\dfrac{1+pk_2}{1-p}$。

    于是这道题就成了一道换根dp+树上链求和的板题,预处理逆元可以做到 O(n+qlogn)O(n+q\log n)

    当然具体实现还有很多细节(特别是要分清 fi,0f_{i,0}fi,1f_{i,1},以及注意链跳的方向)。

    #include<bits/stdc++.h>
    using namespace std;
    #define mod 998244353
    int lw[1000005],bi[2000005][2],bs;
    int s[1000005],tf[1000005],f[1000005][2],ff[1000005][2],sf[1000005][2];
    int fa[1000005],so[1000005],si[1000005],de[1000005],to[1000005];
    int n,q,p,np,nfp,ny[1000005];
    int dr()
    {
    	int xx=0;char ch=getchar();
    	while(ch<'0'||ch>'9')ch=getchar();
    	while(ch>='0'&&ch<='9')xx=(xx<<1)+(xx<<3)+ch-'0',ch=getchar();
    	return xx;
    }
    int P(int x,int y=mod-2)
    {
    	int z=1;
    	for(;y;x=1ll*x*x%mod,y>>=1)if(y&1)z=1ll*z*x%mod;
    	return z;
    }
    void tj(int u,int v){++bs,bi[bs][0]=lw[u],bi[bs][1]=v,lw[u]=bs;}
    void dfs1(int w,int fath,int d)
    {
    	fa[w]=fath,si[w]=1,de[w]=d;
    	int x;
    	for(int o_o=lw[w];o_o;o_o=bi[o_o][0])
    	{
    		int v=bi[o_o][1];
    		if(v!=fath)
    		{
    			++s[w],dfs1(v,w,d+1),tf[w]=(tf[w]+f[v][0])%mod,si[w]+=si[v];
    			if(si[v]>si[so[w]])so[w]=v;
    		}
    	}
    	if(s[w])x=1ll*tf[w]*ny[s[w]]%mod,f[w][0]=(2+x)%mod,f[w][1]=1ll*(1+1ll*p*x)%mod*nfp%mod;
    	else f[w][0]=1ll*(2-p+mod)*np%mod,f[w][1]=3;
    }
    void dfs2(int w)
    {
    	if(s[w]==0)return;
    	int ns=ny[s[w]-(fa[w]?0:1)],tt=(tf[w]+ff[w][0])%mod,x;
    	for(int o_o=lw[w];o_o;o_o=bi[o_o][0])
    	{
    		int v=bi[o_o][1];
    		if(v!=fa[w])
    		{
    			if(fa[w]||s[w]>1)x=1ll*(tt-f[v][0]+mod)*ns%mod,ff[v][0]=(2+x)%mod,ff[v][1]=1ll*(1+1ll*p*x)%mod*nfp%mod;
    			else ff[v][0]=1ll*(2-p+mod)*np%mod,ff[v][1]=3;
    			dfs2(v);
    		}
    	}
    }
    void dfs(int w,int t)
    {
    	to[w]=t,sf[w][0]=(sf[fa[w]][0]+f[w][1])%mod,sf[w][1]=(sf[fa[w]][1]+ff[w][1])%mod;
    	if(so[w])dfs(so[w],t);
    	for(int o_o=lw[w];o_o;o_o=bi[o_o][0])
    	{
    		int v=bi[o_o][1];
    		if(v!=fa[w]&&v!=so[w])dfs(v,v);
    	}
    }
    int lca(int x,int y)
    {
    	int fx=to[x],fy=to[y];
    	while(fx!=fy)if(de[fx]>de[fy])x=fa[fx],fx=to[x];else y=fa[fy],fy=to[y];
    	return de[x]<de[y]?x:y;
    }
    int gs(int x,int y)
    {
    	int fx=to[x],fy=to[y];
    	while(fx!=fy)
    	{
    		if(fa[fx]==y)return fx;
    		x=fa[fx],fx=to[x];
    	}
    	return so[y];
    }
    int main()
    {
    	n=dr(),q=dr(),p=dr(),np=P(p),nfp=P(mod+1-p);
    	ny[1]=1;
    	for(int i=2;i<=n;i++)ny[i]=1ll*(mod-mod/i)*ny[mod%i]%mod;
    	for(int i=1;i<n;i++)
    	{
    		int u=dr(),v=dr();
    		tj(u,v),tj(v,u);
    	}
    	dfs1(1,0,1),dfs2(1),dfs(1,1);
    	while(q--)
    	{
    		int x=dr(),y;char c1=getchar(),c2;
    		while(c1!='N'&&c1!='S')c1=getchar();
    		y=dr(),c2=getchar();
    		while(c2!='N'&&c2!='S')c2=getchar();
    		int a=lca(x,y),ans=((sf[x][0]-sf[a][0]+sf[y][1]-sf[a][1])%mod+mod)%mod;
    		if(c1==c2)
    		{
    			if(x==a)
    			{
    				int b=gs(y,a);
    				ans=((ans-ff[b][1]+ff[b][0])%mod+mod)%mod;
    			}
    			else ans=((ans-f[x][1]+f[x][0])%mod+mod)%mod;
    		}
    		printf("%d\n",ans);
    	}
    }
    

    关于为什么 p2p\geq 2

    其实就是直接帮你们排除了 p=0p=0 以及 p=1p=1 的情况,因为上面所有的式子变形都基于 p0p\neq 0p1p\neq 1

    • 1

    信息

    ID
    6392
    时间
    2000ms
    内存
    512MiB
    难度
    7
    标签
    递交数
    0
    已通过
    0
    上传者