1 条题解

  • 0
    @ 2025-8-24 22:21:53

    自动搬运

    查看原文

    来自洛谷,原作者为

    avatar yizhiming
    ​ | 最后在线时间: 2025/8/21 13:30

    搬运于2025-08-24 22:21:53,当前版本为作者最后更新于2023-09-11 19:11:18,作者可能在搬运后再次修改,您可在原文处查看最新版

    自动搬运只会搬运当前题目点赞数最高的题解,您可前往洛谷题解查看更多

    以下是正文


    呃呃了,在以为其他题解做麻烦的前提下写了写发现假了,结果优化成了和其他人一样的做法。

    题目大意

    给定 nn 个点的两棵树 A,BA,B,求有多少个点集满足将点集内的点按照树上的边连边后,在 AA 树上形成一个联通块,在 BB 树上形成一条链。

    TT 组数据。

    T=3,1n105T = 3,1\leq n\leq10^5

    题目分析

    先考虑一个性质,对于一个树上的点集 TT,若其内部边的数量为 xx,那么这个点集的联通块数是 Tx|T|-x,证明考虑一开始每个点都单独一个联通块,每次连一条边就是把两个联通块合并成一个。

    有了这个性质如何做呢?这启发了我们维护联通块数。

    首先考虑特殊性质,对于 BB 树是链,等价于要求点集是个区间,所以考虑扫描线,设 sis_i 表示在当前扫描线右端点在 rr,左端点在 ii 时,这个区间点集在 AA 树上有几个联通块,答案显然是区间内 11 的个数,由于区间最小值一定最小为 11,所以可以直接维护区间最小值个数。

    如何转移,考虑由 rr 推到 r+1r+1,此时对于 [1,r+1][1,r+1] 来说都新加入了一个点,所以区间加 11,然后对于 AA 树上的一条边 (u,r+1)(u,r+1) 满足 u<r+1u<r+1 来说,[1,u][1,u]sis_i 对应的点集内一定有这条边,所以区间减 11 即可,答案就是所有版本的 11 的个数和。

    考虑扩展到树上,如何将区间转换成链,不难想到令每个点作为根,求出每个点到根路径形成的点集在 AA 树上的联通块个数,不妨设 fif_i 表示这个,答案会算多,原因是对于一条合法的链 (u,v)(u,v)u,vu,v 为根时都会计算一遍,所以要去掉,注意 (u,u)(u,u) 不会算重。

    接下来的内容默认会换根意义下的区间加减,若不会请去遥远的国度

    假设当前根为 uu,要换到他的儿子 vv,如何转移 sis_i,令 W(x,y)W(x,y) 表示以 xx 为根时,yy 的子树表达的点集。

    首先由于 vv 提到了根的位置,所以除了 W(u,v)W(u,v) 以外的所有点,所对应的点集都插入了一个点,区间加,同理 W(u,v)W(u,v) 整体少了一个点。

    现在考虑新的边的贡献,vv 对于 W(u,v)W(u,v) 的贡献在 uu 为根的时候已经统计过了,所以对于 AA 树边 (v,x)(v,x),若 xW(u,v)x \notin W(u,v) 那么就对 W(v,x)W(v,x) 进行一次子树减,因为这部分都会被这条边影响。同理我们也要删除 uu 的在 W(u,v)W(u,v) 内的 AA 树上邻居的贡献,但是发现每次换根都枚举一圈 AA 树的邻居,总的枚举个数就成了两树度数的平方。

    注意到对于 uu 需要删掉的贡献只有在 W(u,v)W(u,v) 内的,容易发现对于 uu 每个儿子,其子树区间不相交,所以我们可以将 AA 边按照 BB 树的 dfs 序排序,这样的话每个贡献只会增减各一次。

    Code

    注意要做到从 vv 版本回溯到 uu,所以记录下来操作反着做一遍即可。

    对于最开始的 11 号版本,可以暴力预处理出来初始情况。

    #include <iostream>
    #include <algorithm>
    #include <cstdio>
    #include <cstring>
    #include <cmath>
    #include <queue>
    #define int long long
    using namespace std;
    int read(){
    	int x=0,f=1;char ch=getchar();
    	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    	while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    	return x*f;
    }
    const int N = 1e5+5;
    int n;
    vector<int>in[N],ed[N];
    
    int minx,cnt,rt;
    struct seg{
    	struct aa{
    		int lc,rc,mi,sum,tag;
    		void clear(){
    			lc = rc = mi = sum = tag = 0;
    		}
    	}node[N*2];
    	void pushup(int u){
    		aa x = node[node[u].lc],y = node[node[u].rc];
    		node[u].mi = min(x.mi,y.mi);
    		node[u].sum = (x.mi==node[u].mi?x.sum:0)+(y.mi==node[u].mi?y.sum:0);
    	}
    	int tot;
    	int newnode(){
    		int u = ++tot;
    		node[u].clear();
    		return u;
    	}
    	void build(int &u,int l,int r){
    		u = newnode();
    		node[u].sum = (r-l+1);
    		if(l==r){
    			return;
    		}
    		int mid = (l+r)/2;
    	
    		build(node[u].lc,l,mid);
    		build(node[u].rc,mid+1,r);
    	}
    	void lazy_tag(int u,int x){
    		node[u].mi+=x;
    		node[u].tag+=x;
    	}
    	void pushdown(int u){
    		if(!node[u].tag){
    			return;
    		}
    		lazy_tag(node[u].lc,node[u].tag);
    		lazy_tag(node[u].rc,node[u].tag);
    		node[u].tag = 0;
    	}
    	void upd(int u,int l,int r,int ll,int rr,int x){
    		if(l==ll&&r==rr){
    			lazy_tag(u,x);
    			return;
    		}
    		pushdown(u);
    		int mid = (l+r)/2;
    		if(rr<=mid){
    			upd(node[u].lc,l,mid,ll,rr,x);
    		}else if(ll>mid){
    			upd(node[u].rc,mid+1,r,ll,rr,x); 
    		}else{
    			upd(node[u].lc,l,mid,ll,mid,x);
    			upd(node[u].rc,mid+1,r,mid+1,rr,x);
    		}
    		pushup(u);
    	}
    	void ask(int u,int l,int r,int ll,int rr){
    		if(l==ll&&r==rr){
    			if(node[u].mi<minx){
    				minx = node[u].mi;
    				cnt = node[u].sum;
    			}else if(node[u].mi==minx){
    				cnt+=node[u].sum;
    			}
    			return;
    		}
    		pushdown(u);
    		int mid = (l+r)/2;
    		if(rr<=mid){
    			ask(node[u].lc,l,mid,ll,rr);
    		}else if(ll>mid){
    			ask(node[u].rc,mid+1,r,ll,rr);
    		}else{
    			ask(node[u].lc,l,mid,ll,mid);
    			ask(node[u].rc,mid+1,r,mid+1,rr);
    		}
    	}
    }T;
    int siz[N],dep[N],son[N],fa[N],top[N],dfn[N],tt;
    bool cmp(int a,int b){
    	return dfn[a]<dfn[b];
    }
    void dfs1(int u,int f){
    	siz[u] = 1;
    	son[u] = 0;
    	for(auto x:in[u]){
    		if(x==f){
    			continue;
    		}
    		fa[x] = u;
    		dep[x] = dep[u]+1;
    		dfs1(x,u);
    		siz[u]+=siz[x];
    		if(siz[x]>siz[son[u]]){
    			son[u] = x;
    		}
    	}
    }
    void dfs2(int u,int t){
    	top[u] = t;
    	dfn[u] = ++tt;
    	if(!son[u]){
    		return;
    	}
    	dfs2(son[u],t);
    	for(auto x:in[u]){
    		if(x==fa[u]||x==son[u]){
    			continue;
    		}
    		dfs2(x,x);
    	}
    }
    int Lca(int u,int v){
    	while(top[u]!=top[v]){
    		if(dep[top[u]]<dep[top[v]]){
    			swap(u,v);
    		}
    		u = fa[top[u]];
    	}
    	if(dep[u]<dep[v]){
    		swap(u,v);
    	}
    	return v;
    }
    int query(){
    	minx = 1e9;
    	cnt = 0;
    	T.ask(rt,1,n,1,n);
    	if(minx==1){
    		return cnt;
    	}else{
    		return 0;
    	}
    }
    struct bb{
    	int l,r,x;
    };
    vector<bb>op[N];
    void add(int u,int l,int r,int x){
    	op[u].push_back((bb){l,r,x});
    	T.upd(rt,1,n,l,r,x);
    }
    int RT,res;
    int get(int u,int x){
    	while(top[u]!=top[x]){
    		if(fa[top[u]]==x){
    			return top[u];
    		}
    		u = fa[top[u]];
    	}
    	return son[x];
    }
    void dfs(int u){
    	if(u!=1){
    		add(u,1,n,1);
    		add(u,dfn[u],dfn[u]+siz[u]-1,-2);
    		for(auto x:ed[u]){
    			if(dfn[u]<=dfn[x]&&dfn[x]<=dfn[u]+siz[u]-1){
    				continue;
    			}
    			if(dfn[x]<=dfn[u]&&dfn[u]<=dfn[x]+siz[x]-1){
    				int v = get(u,x);
    				add(u,1,n,-1);
    				add(u,dfn[v],dfn[v]+siz[v]-1,1);
    			}else{
    				add(u,dfn[x],dfn[x]+siz[x]-1,-1);
    			}
    		}
    	}
    	res+=query();
    	int r = 0;
    	int sz = ed[u].size(); 
    	while(r<sz){
    		int y = ed[u][r];
    		if(dfn[y]<dfn[u]||dfn[y]>dfn[u]+siz[u]-1){
    			r++;
    		}else{
    			break;
    		}
    	}
    	for(auto x:in[u]){
    		if(x==fa[u]){
    			continue;
    		}
    		int R = r;
    		while(r<sz){
    			int y = ed[u][r];
    			if(dfn[x]<=dfn[y]&&dfn[y]<=dfn[x]+siz[x]-1){
    				T.upd(rt,1,n,dfn[y],dfn[y]+siz[y]-1,1);
    				r++;
    			}else{
    				break;
    			}
    		}
    		dfs(x);
    		for(int i=R;i<r;i++){
    			int y = ed[u][i];
    			T.upd(rt,1,n,dfn[y],dfn[y]+siz[y]-1,-1);
    		}
    	}
    	
    	for(auto x:op[u]){
    		T.upd(rt,1,n,x.l,x.r,-x.x);
    	}
    }
    int U[N],V[N];
    void init(){
    	n = read();
    	T.tot = 0;rt = 0;res = 0;tt = 0;
    	for(int i=1;i<=n;i++){
    		in[i].clear();ed[i].clear();op[i].clear();
    	}
    	for(int i=1;i<n;i++){
    		int u,v;
    		u = read();v = read();
    		U[i] = u;V[i] = v;
    	}
    	for(int i=1;i<n;i++){
    		int u,v;
    		u = read();v = read();
    		in[u].push_back(v);
    		in[v].push_back(u);
    	}
    	dfs1(1,1);
    	dfs2(1,1);
    	T.build(rt,1,n);
    	for(int i=1;i<=n;i++){
    		T.upd(rt,1,n,dfn[i],dfn[i]+siz[i]-1,1);
    	}
    	for(int i=1;i<n;i++){
    		int u,v;
    		u = U[i];v = V[i];
    		ed[u].push_back(v);
    		ed[v].push_back(u);
    		if(dep[u]>dep[v]){
    			swap(u,v);
    		}
    		int L = Lca(u,v);
    		if(u==L){
    			T.upd(rt,1,n,dfn[v],dfn[v]+siz[v]-1,-1);
    		}
    	}
    	for(int i=1;i<=n;i++){
    		sort(ed[i].begin(),ed[i].end(),cmp);
    		sort(in[i].begin(),in[i].end(),cmp);
    	}
    	RT = 1;
    	dfs(1);
    	cout<<(res-n)/2+n<<"\n";
    }
    signed main(){
    	int T = read();
    	while(T--){
    		init();
    	}
    	return 0;
    }
    
    • 1

    信息

    ID
    4899
    时间
    5000ms
    内存
    500MiB
    难度
    6
    标签
    递交数
    0
    已通过
    0
    上传者