1 条题解

  • 0
    @ 2025-8-24 22:24:06

    自动搬运

    查看原文

    来自洛谷,原作者为

    avatar littleKtian
    MoVe yoUR BoDy

    搬运于2025-08-24 22:24:06,当前版本为作者最后更新于2020-12-03 19:39:44,作者可能在搬运后再次修改,您可在原文处查看最新版

    自动搬运只会搬运当前题目点赞数最高的题解,您可前往洛谷题解查看更多

    以下是正文


    upd on 2021.3.16:修改了原做法和代码中的错误,现在能通过 hack 数据了。


    显然这是道树形 dp,为了方便处理信息,我们考虑将这 mm 条链放到 lca 上维护。

    fi,0/1f_{i,0/1} 表示以 ii 为根的子树中没有选取/有选取以 ii 为 lca 的链时,选取链的方案数,gig_i 为两者之和。

    易得转移方程 $f_{i,0}=\prod\limits_{j\text{是}i\text{的孩子}}g_j,f_{i,1}=\sum\limits_{\text{链}L\text{的lca是}i}\prod\limits_{j\text{不在}L\text{上,且}\atop j\text{的父亲在}L\text{上}}g_j$。

    fi,0f_{i,0} 很容易计算出,考虑如何计算 fi,1f_{i,1}

    发现连乘的形式和 fi,0f_{i,0} 很像,于是考虑往这个方向变形。

    得 $f_{i,1}=\sum\limits_{\text{链}L\text{的lca是}i}\dfrac{\prod\limits_{j\text{的父亲在}L\text{上}}g_j}{\prod\limits_{k\text{在}L\text{上且}k\neq i}g_k}=\sum f_{i,0}\prod\limits_{j\text{在}L\text{上且}j\neq i}\dfrac{f_{j,0}}{g_j}$。

    显然这是一个树上的区间求乘积,最简单的方法是 树剖+线段树,但并不能通过此题。

    我们可以对树剖的每条重链维护后缀积,因为整个 dp 的过程是从下往上的,而且我们只需要用到下面的信息点,所以对于每个点我们都可以做到 O(1)O(1) 维护后缀积+每条链 O(lognlogp)O(\log n\log p) 查询,预处理逆元后可以每个点 O(logp)O(\log p) 维护+每条链 O(logn)O(\log n) 查询,可以通过。

    以上做法已被 hack


    上面做法的最大问题是 fi,0f_{i,0}gig_i 在对 998244353998244353 取模意义下可能为 00,导致逆元无法使用。但我们也可以从中获得一些启示。

    我们把乘积式中的元素分为 jjii 的儿子以及不是 ii 的儿子两部分,发现对于 jjii 的儿子的部分,最多只会有两个元素缺失,所以我们可以给 ii 的每个儿子一个标号,用线段树处理出区间乘积,对于每条链最多只会有 33 次询问,所以对于每条链复杂度是 O(logn)O(\log n)

    对于不是 ii 的儿子的部分,可以发现对于所有链 LL 上的点 jj,都最多只会缺失一个元素,并且缺失的元素就是链 LL 上的点。所以我们可以另设 hih_i 表示 faifa_i 的儿子中除了 ii 以外所有 gg 的乘积,那么我们相当于需要查询一条链上所有元素的乘积,我们可以转化为每次对以某个点为根的子树内所有元素乘上一个数,然后进行单点查询。因此复杂度还是 O(logn)O(\log n) 的。

    总复杂度应该是 O((n+m)logn)O\left((n+m)\log n\right),细节相对较多。

    #include<bits/stdc++.h>
    using namespace std;
    #define p 998244353
    int lsw[500005],bi[1000005][2],bs;
    int fa[500005],so[500005],si[500005],de[500005],to[500005],xh[500005],dfn;
    int lw[500005],la[500005][3];
    int n,m,ff[2][500005],f[500005];
    int dr()
    {
        int xx=0;char ch=getchar();
        while(ch<'0'||ch>'9')ch=getchar();
        while(ch>='0'&&ch<='9')xx=xx*10+ch-'0',ch=getchar();
        return xx;
    }
    int P(int x,int y=p-2)
    {
        int z=1;
        for(;y;x=1ll*x*x%p,y>>=1)if(y&1)z=1ll*z*x%p;
        return z;
    }
    void tj(int u,int v){++bs,bi[bs][0]=lsw[u],bi[bs][1]=v,lsw[u]=bs;}
    void dfs_1(int w,int f)
    {
        fa[w]=f,si[w]=1,de[w]=de[f]+1;
        for(int o_o=lsw[w];o_o;o_o=bi[o_o][0])
        {
            int v=bi[o_o][1];
            if(v!=f)
            {
                dfs_1(v,w),si[w]+=si[v];
                if(si[v]>si[so[w]])so[w]=v;
            }
        }
    }
    void dfs_2(int w,int t)
    {
        to[w]=t,xh[w]=++dfn;
        if(so[w])dfs_2(so[w],t);
        for(int o_o=lsw[w];o_o;o_o=bi[o_o][0])
        {
            int v=bi[o_o][1];
            if(v!=fa[w]&&v!=so[w])dfs_2(v,v);
        }
    }
    int fi(int x,int y)
    {
        int fx=to[x],fy=to[y];
        while(fx!=fy)
        {
            if(de[fx]>=de[fy])x=fa[fx],fx=to[x];
            else y=fa[fy],fy=to[y];
        }
        return de[x]<de[y]? x:y;
    }
    int ju(int x,int y)
    {
    	int fx=to[x],fy=to[y];
    	while(fx!=fy)
    	{
    		if(fa[fx]==y)return fx;
    		x=fa[fx],fx=to[x];
    	}
    	return so[y];
    }
    int tree[2000005],ltree[2000005],a[500005],lxh[500005];
    #define ls(w) (w<<1)
    #define rs(w) (ls(w)^1)
    void d(int w){if(tree[w]!=1)tree[ls(w)]=1ll*tree[ls(w)]*tree[w]%p,tree[rs(w)]=1ll*tree[rs(w)]*tree[w]%p,tree[w]=1;}
    void csh1(int w,int l,int r)
    {
    	tree[w]=1;
    	if(l==r)return;
    	int mid=(l+r)>>1;
    	csh1(ls(w),l,mid),csh1(rs(w),mid+1,r);
    }
    void xg1(int w,int l,int r,int L,int R,int x)
    {
    	if(L<=l&&r<=R){tree[w]=1ll*tree[w]*x%p;return;}
    	d(w);
    	int mid=(l+r)>>1;
    	if(L<=mid)xg1(ls(w),l,mid,L,R,x);
    	if(mid<R)xg1(rs(w),mid+1,r,L,R,x);
    }
    int cx1(int w,int l,int r,int xh)
    {
    	if(l==r)return tree[w];
    	d(w);
    	int mid=(l+r)>>1;
    	if(xh<=mid)return cx1(ls(w),l,mid,xh);
    	else return cx1(rs(w),mid+1,r,xh);
    }
    void u(int w){ltree[w]=1ll*ltree[ls(w)]*ltree[rs(w)]%p;}
    void csh2(int w,int l,int r)
    {
    	if(l==r){ltree[w]=a[l];return;}
    	int mid=(l+r)>>1;
    	csh2(ls(w),l,mid),csh2(rs(w),mid+1,r);
    	u(w);
    }
    int cx2(int w,int l,int r,int L,int R)
    {
    	if(L>R)return 1;
    	if(L<=l&&r<=R)return ltree[w];
    	int mid=(l+r)>>1,x=1;
    	if(L<=mid)x=1ll*x*cx2(ls(w),l,mid,L,R)%p;
    	if(mid<R)x=1ll*x*cx2(rs(w),mid+1,r,L,R)%p;
    	return x;
    }
    void dp(int w)
    {
        ff[0][w]=1;int tot=0;
        for(int o_o=lsw[w];o_o;o_o=bi[o_o][0])
        {
            int v=bi[o_o][1];
            if(v!=fa[w])dp(v),lxh[v]=++tot,ff[0][w]=1ll*ff[0][w]*f[v]%p;
        }
        for(int o_o=lsw[w];o_o;o_o=bi[o_o][0])
        {
        	int v=bi[o_o][1];
        	if(v!=fa[w])a[lxh[v]]=f[v];
    	}
    	if(tot)csh2(1,1,tot);
        for(int o_o=lw[w];o_o;o_o=la[o_o][0])
    	{
    		int u=la[o_o][1],v=la[o_o][2];
    		if(u!=w&&v!=w)
    		{
    			int x=ju(u,w),y=ju(v,w);
    			if(lxh[x]>lxh[y])swap(x,y);
    			ff[1][w]=(ff[1][w]+(1ll*cx1(1,1,n,xh[u])*ff[0][u]%p)*(1ll*cx1(1,1,n,xh[v])*ff[0][v]%p)%p*(1ll*cx2(1,1,tot,1,lxh[x]-1)*cx2(1,1,tot,lxh[x]+1,lxh[y]-1)%p*cx2(1,1,tot,lxh[y]+1,tot)%p))%p;
    		}
    		else if(u==w&&v==w)
    		{
    			ff[1][w]=(ff[1][w]+ff[0][w])%p;
    		}
    		else
    		{
    			if(v==w)swap(u,v);
    			int x=ju(v,w);
    			ff[1][w]=(ff[1][w]+1ll*cx1(1,1,n,xh[v])*ff[0][v]%p*cx2(1,1,tot,1,lxh[x]-1)%p*cx2(1,1,tot,lxh[x]+1,tot))%p;
    		}
    	}
    	for(int o_o=lsw[w];o_o;o_o=bi[o_o][0])
    	{
    		int v=bi[o_o][1];
        	if(v!=fa[w])
        	{
        		int x=1ll*cx2(1,1,tot,1,lxh[v]-1)*cx2(1,1,tot,lxh[v]+1,tot)%p;
        		xg1(1,1,n,xh[v],xh[v]+si[v]-1,x);
    		}
    	}
        f[w]=(ff[0][w]+ff[1][w])%p;
    }
    int main()
    {
        n=dr(),m=dr();
        for(int i=1;i<n;i++)
        {
            int u=dr(),v=dr();
            tj(u,v),tj(v,u);
        }
        dfs_1(1,0),dfs_2(1,1);
        for(int i=1;i<=m;i++)
        {
            int u=dr(),v=dr(),g=fi(u,v);
            la[i][0]=lw[g],la[i][1]=u,la[i][2]=v,lw[g]=i;
        }
        csh1(1,1,n),dp(1);
        printf("%d",f[1]);
    }
    
    
    • 1

    信息

    ID
    5689
    时间
    3000ms
    内存
    256MiB
    难度
    6
    标签
    递交数
    0
    已通过
    0
    上传者