1 条题解

  • 0
    @ 2025-8-24 22:44:18

    自动搬运

    查看原文

    来自洛谷,原作者为

    avatar Demeanor_Roy
    小时候我们总想去改变别人,后来发现,比起改变,筛选是性价比更高的事。

    搬运于2025-08-24 22:44:18,当前版本为作者最后更新于2023-01-07 21:47:19,作者可能在搬运后再次修改,您可在原文处查看最新版

    自动搬运只会搬运当前题目点赞数最高的题解,您可前往洛谷题解查看更多

    以下是正文


    • 出题人题解。

    不妨思考极远点对的本质是什么。发现若 (u,v)(u,v) 为极远点对,可以看作从任意点 uu 找到距离其最远点 vv,再从点 vv 找到距离其最远点且此点恰好为 uu,不难察觉这符合dfs两遍求树的直径的过程,故 u,vu,v 为同一直径两端点。

    所以题目意转化为对于每个点,求出其在多少条直径上。

    将树以 11 为根定型,然后求解答案。

    不妨先树形 DP 求出总直径数 sumsum,以 xx 为端点直径数的子树和 lxl_x,挂在 xx 上的直径数 pxp_xpxp_x 的子树和 sxs_x。其中除了 lxl_x 的求解略显繁琐外,其他的求解都是朴素的(注意细节,一定要想清楚再写,数组的含义千万不能弄混)。

    考虑经过 xx 的直径数,不难得到:vx=px+(lxsx×2)v_x =p_x + (l_x - s_x \times 2)。前者是挂在 xx 上的直径,后者是简单容斥得到的横穿 xx 子树内外的直径。同样,这样的计算是直观的。

    至此整道题解决完毕,时间复杂度线性。

    下面阐述一下部分分缘由:

    1. n300n \leq 300,暴力找极远点对,暴力标记即可。O(n3)O(n^3)

    2. n2000n \leq 2000,暴力找极远点对,树上差分标记即可。O(n2logn)O(n^2 \log n)

    3. n105n \leq 10^5,给数据结构学傻的人。

    4. k=1k=1,不难发现贡献可以合起来算,令 lenlen 为直径长度,答案即为 len×sumlen \times sum

    下附代码:

    #include<bits/stdc++.h>
    using namespace std;
    #define LL long long
    const int N=5e6+10,mod=998244353;
    int n,k,len,ans,s[N],l[N];
    int h[N],e[N<<1],ne[N<<1],idx;
    struct node
    {
    	int val,cnt;
    }p[N],fi[N],se[N];
    inline int read()
    {
    	int x=0;char ch=getchar();
    	while(ch<'0'||ch>'9') ch=getchar();
    	while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
    	return x;
    }
    inline void add(int a,int b)
    {		
    	e[idx]=b,ne[idx]=h[a],h[a]=idx++;
    }
    inline void dfs(int u,int fa)
    {
    	fi[u].val=fi[u].cnt=1;
    	for(int i=h[u];~i;i=ne[i])
    	{
    		if(e[i]==fa)	continue;
    		dfs(e[i],u);
    		if(fi[u].val+fi[e[i]].val>p[u].val) p[u]={fi[u].val+fi[e[i]].val,(int)((LL)fi[u].cnt*fi[e[i]].cnt%mod)};	
    		else if(fi[u].val+fi[e[i]].val==p[u].val) p[u].cnt=(p[u].cnt+(LL)fi[u].cnt*fi[e[i]].cnt%mod)%mod;
    		if(fi[e[i]].val+1>fi[u].val) se[u]=fi[u],fi[u]={fi[e[i]].val+1,fi[e[i]].cnt};
    		else if(fi[e[i]].val+1==fi[u].val) fi[u].cnt+=fi[e[i]].cnt;
    		else if(fi[e[i]].val+1>se[u].val) se[u]={fi[e[i]].val+1,fi[e[i]].cnt};
    		else if(fi[e[i]].val+1==se[u].val) se[u].cnt+=fi[e[i]].cnt;
    	} 
    	len=max(len,p[u].val);
    }
    inline void DFS(int u,int fa,int up,int num)
    {
    	s[u]=p[u].cnt=(p[u].val==len?p[u].cnt:0);
    	for(int i=h[u];~i;i=ne[i])
    	{
    		if(e[i]==fa)	continue;
    		int cur=max(up+1,(fi[e[i]].val+1==fi[u].val&&fi[e[i]].cnt==fi[u].cnt)?se[u].val:fi[u].val);
    		if(fi[e[i]].val+1==fi[u].val&&fi[e[i]].cnt==fi[u].cnt) DFS(e[i],u,cur,num*((up+1)==cur)+se[u].cnt*((se[u].val)==cur));
    		else if(fi[e[i]].val+1==fi[u].val) DFS(e[i],u,cur,num*((up+1)==cur)+(fi[u].cnt-fi[e[i]].cnt)*((fi[u].val)==cur));
    		else DFS(e[i],u,cur,num*((up+1)==cur)+fi[u].cnt*((fi[u].val)==cur));
    		l[u]=(l[u]+l[e[i]])%mod;s[u]=(s[u]+s[e[i]])%mod;
    	}
    	if(up+1==len) l[u]=(l[u]+num)%mod;
    	if(fi[u].val==len) l[u]=(l[u]+fi[u].cnt)%mod;
    }
    inline int pwr(int x,int y){return y==1?x:(LL)x*x%mod;}
    int main()
    {
    	memset(h,-1,sizeof h);
    	n=read(),k=read();
    	for(int i=1;i<n;i++)	
    	{
    		int u,v;
    		u=read(),v=read();
    		add(u,v),add(v,u);
    	}
    	dfs(1,-1);DFS(1,-1,0,0);
    	for(int i=1;i<=n;i++) ans=(ans+pwr(p[i].cnt+(l[i]-s[i]*2),k))%mod;
    	printf("%d",ans);
    	return 0;
    }
    

    似乎验题人有更简单的做法。。。

    • 1

    信息

    ID
    8275
    时间
    1000~3000ms
    内存
    512MiB
    难度
    5
    标签
    递交数
    0
    已通过
    0
    上传者