1 条题解

  • 0
    @ 2025-8-24 22:55:29

    自动搬运

    查看原文

    来自洛谷,原作者为

    avatar 云浅知处

    搬运于2025-08-24 22:55:29,当前版本为作者最后更新于2024-02-19 20:10:39,作者可能在搬运后再次修改,您可在原文处查看最新版

    自动搬运只会搬运当前题目点赞数最高的题解,您可前往洛谷题解查看更多

    以下是正文


    O(nq)O(n\sqrt{q}) 做法

    不难发现本题实际上是要算:

    • 只保留树上颜色为 xx 或者 yy 的点,求此时树上所有连通块大小的平方之和。

    cxc_x 表示颜色 xx 出现的次数。

    那么,对于一组询问 x,yx,y,如果我们能设计出 O(min(cx,cy))O(\min(c_x,c_y)) 的算法并进行记忆化,总的复杂度就不会超过 O(nq)O(n\sqrt{q})

    • min(cx,cy)nq\min(c_x,c_y)\le \frac{n}{\sqrt{q}},则这部分复杂度不超过 q×nq=nqq\times \frac{n}{\sqrt{q}}=n\sqrt{q}
    • min(cx,cy)>nq\min(c_x,c_y)>\frac{n}{\sqrt{q}},这样的 xx 至多 q\sqrt{q} 个,对于每个 xx,由于 yy 同样也只有 q\sqrt{q} 个,对所有 cy>nqc_y>\frac{n}{\sqrt{q}}yy 求和 min(cx,cy)\min(c_x,c_y) 的结果不会超过 ycx+cyn+cx×q\sum_y c_x+c_y\le n+c_x\times \sqrt{q},再把这一结果对 xx 求和就得到了这部分的复杂度同样是 O(nq)O(n\sqrt{q})

    对每组询问 (x,y)(x,y),若 cx<cyc_x<c_y 则交换 x,yx,y,我们将这组询问 (x,y)(x,y) 挂在颜色 xx 上。接下来对每种颜色 xx 分别处理:考虑维护树上的点集形成的连通块,我们将每个连通块的信息放在这个连通块的根的位置,则插入一个点时只需考虑它的若干儿子,以及它的父亲处可能存在的信息合并。

    考虑提前预处理出每个颜色为 xx 的连通块的根,接下来按照深度从大到小依次插入所有颜色为 yy 的点。这样插入一个点时,父亲处若存在信息合并,必然是完整的 xx 连通块,我们可以在 O(1)O(1) 的时间内将其若干儿子的连通块的信息提到总的连通块的根上面,就在 O(1)O(1) 时间内完成了插入。

    总的时间复杂度为 O(nq)O(n\sqrt{q})

    O(n+q)O(n+q) 做法

    可以发现,对于给出的颜色 xi,yix_i,y_i ,若树上不存在一条边 ii 满足 ii 两端点的颜色分别是 xi,yix_i,y_i ,那么 xi,yix_i,y_i 的点会形成若干独立的连通块,可以预处理后直接计算。

    其余的颜色对一共最多只有 n1n-1 种,就是每条边端点的颜色对集合。

    首先把两端点颜色相同的边缩点。

    枚举每一种可能的颜色对 xi,yix_i,y_i,把树上这样的边连上,计算形成的连通块的大小的平方和即可。

    使用哈希表存答案,然后建边后再树上跑 DFS,可以做到 O(n+m)O(n+m),如果使用 std::map 或者可撤销并查集也能通过。

    #include<bits/stdc++.h>
    using namespace std;
    const int N = 1e6+7;
    int n,m,q;
    int c[N];
    int idx=0;struct dsu
    {
    	int fa[N],siz[N];
    	int find(int x)
    	{
    		if(x==fa[x])return x;
    		return fa[x]=find(fa[x]);
    	}
    	void merge(int x,int y)
    	{
    		if(find(x)==find(y))return;
    		x=find(x);y=find(y);
    		fa[x]=y;
    		siz[y]+=siz[x];
    	}
    }A,B;
    struct edge 
    {
    	int a,b,next,id;
    }e[N];
    const int M = 1e6+7;
    int flink[M],t=0;
    int get(int a,int b)
    {
    	int h=(1ll*a*131%M+b)%M;
    	for(int i=flink[h];i;i=e[i].next)
    	if(e[i].a==a&&e[i].b==b)return e[i].id;
    	e[++t].a=a;
    	e[t].b=b;
    	e[t].id=++idx;
    	e[t].next=flink[h];
    	flink[h]=t;
    	return idx;
    }
    int qry(int a,int b)
    {
    	int h=(1ll*a*131%M+b)%M;
    	for(int i=flink[h];i;i=e[i].next)
    	if(e[i].a==a&&e[i].b==b)return e[i].id;
    	return 0;
    }
    #define PII pair<int,int>
    #define mk(x,y) make_pair(x,y)
    #define X(x) x.first
    #define Y(x) x.second
    typedef long long LL;
    inline int read() {
    	char ch = getchar(); int x = 0;
    	while (!isdigit(ch)) {ch = getchar();}
    	while (isdigit(ch)) {x = x * 10 + ch - 48; ch = getchar();}
    	return x;
    }
    void write(LL x) {
    	if (!x) return;
    	write(x / 10); putchar(x % 10 + '0');
    }
    inline void print(LL x, char ch = '\n') {
    	if (!x) putchar('0');
    	else write(x);
    	putchar(ch);
    }
    vector<int> E[N];
    LL ans[N];
    int U[N],V[N];
    int seq[2*N],tot=0;
    bool mark[N];
    LL ext[N];
    int vis[N],tag;
    int main()
    {
    	n = read(); q = read();
    	for(int i=1;i<=n;i++)
    	{
    		c[i] = read();
    		A.fa[i]=i;
    		A.siz[i]=1;
    	}
    	for(int i=2;i<=n;i++)
    	{
    		int x;
    		x = read();
    		if(c[i]==c[x]) A.merge(x,i);
    		else 
    		{
    			int cx=c[x],cy=c[i];
    			if(cx>cy)swap(cx,cy);
    			++m;
    			U[m]=x;
    			V[m]=i;
    			E[get(cx,cy)].push_back(m);
    		}
    	}
    	for(int i=1;i<=m;i++)
    	{
    		U[i]=A.find(U[i]);
    		V[i]=A.find(V[i]);
    		mark[U[i]]=1;
    		mark[V[i]]=1;
    	}
    	for(int i=1;i<=n;i++)
    	if(A.find(i)==i)
    	ext[c[i]]+=1ll*A.siz[i]*A.siz[i];
    	for(int r=1;r<=idx;r++)
    	{
    		tot=0;++tag; 
    		for(auto p:E[r])
    		{
    			int x=U[p],y=V[p];
    			if(vis[x]!=tag)vis[x]=tag,seq[++tot]=x;
    			if(vis[y]!=tag)vis[y]=tag,seq[++tot]=y;
    		}
    		LL res=0;
    		for(int i=1;i<=tot;i++)
    		{
    			B.fa[seq[i]]=seq[i];
    			B.siz[seq[i]]=A.siz[seq[i]];
    			res-=1ll*A.siz[seq[i]]*A.siz[seq[i]];
    		}
    		for(auto p:E[r])
    		{
    			int x=U[p],y=V[p];
    			B.merge(x,y);
    		}
    		for(int i=1;i<=tot;i++)
    		{
    			int x=seq[i];
    			if(B.find(x)==x)
    			res+=1ll*B.siz[x]*B.siz[x];
    		}
    		ans[r]=res;
    	}
    	while(q--)
    	{
    		int x,y;
    		x = read(); y = read();
    		assert(x != y);
    		if(x>y)swap(x,y);
    		LL res=ext[x]+ext[y];
    		if(qry(x,y))res+=ans[qry(x,y)];
    		print(res);
    	}
    	return 0;
    }
    
    • 1

    信息

    ID
    8981
    时间
    2000ms
    内存
    512MiB
    难度
    5
    标签
    递交数
    0
    已通过
    0
    上传者