1 条题解

  • 0
    @ 2025-8-24 22:38:08

    自动搬运

    查看原文

    来自洛谷,原作者为

    avatar EuphoricStar
    Remember.

    搬运于2025-08-24 22:38:08,当前版本为作者最后更新于2025-03-12 20:19:41,作者可能在搬运后再次修改,您可在原文处查看最新版

    自动搬运只会搬运当前题目点赞数最高的题解,您可前往洛谷题解查看更多

    以下是正文


    F(x,y)=D(x,y)+D(y,x)F(x, y) = D(x, y) + D(y, x),那么当 xxyy 祖先时,F(x,y)=szxszyF(x, y) = sz_x - sz_y;当 yyxx 祖先时,F(x,y)=szyszxF(x, y) = sz_y - sz_x;当 x,yx, y 不互为祖孙关系时,F(x,y)=szx+szyF(x, y) = sz_x + sz_y。图上 i,ji, j 的边权即为 F(xi,xj)+F(yi,yj)F(x_i, x_j) + F(y_i, y_j)

    完全图 MST 容易想到 Boruvka,问题转化为求一端为 ii 且另一端与 ii 不在同一个连通块的边权最小值。然后是 Boruvka 的经典套路,考虑直接求出一个信息:一端为 ii 的边权最小值和次小值,钦定这两条边的另一个端点不在同一个连通块。这个信息是可以合并的,所以可以当成没有不在同一个连通块的限制然后做。

    99 种情况讨论:

    • 没有任何限制,直接取 min\min
    • xjx_jxix_i 的祖先,yy 无限制:求出每个点到根的链 min\min 即可。
    • yjy_jyiy_i 的祖先,xx 无限制:同上。
    • xix_ixjx_j 的祖先,yy 无限制:求出每个点的子树 min\min 即可。
    • yiy_iyjy_j 的祖先,xx 无限制:同上。
    • xix_ixjx_j 的祖先,yiy_iyjy_j 的祖先:直接线段树合并,xx 的限制已经满足,用线段树满足 yy 的限制即可。可以线段树维护 yy 的 dfn 序。
    • xix_ixjx_j 的祖先,yjy_jyiy_i 的祖先:考虑所有 xj=ux_j = u,对 yjy_j 单点修改,然后再求所有 xi=ux_i = u 的答案,就是查询 yiy_i 子树的 min\minuu 出栈时把所有 xj=ux_j = u 更新的信息撤回。需要一个支持单点 checkmin、区间查询、撤销修改操作的线段树,可以写 zkw 线段树。
    • xjx_jxix_i 的祖先,yiy_iyjy_j 的祖先:同上。
    • xjx_jxix_i 的祖先,yjy_jyiy_i 的祖先:每 dfs 到一个点 uu,考虑所有 xj=ux_j = u,更新 yjy_j 子树内所有点的信息,然后再求所有 xi=ux_i = u 的答案,就是直接对 yiy_i 单点查询,uu 出栈时把所有 xj=ux_j = u 更新的信息撤回。需要一个支持区间 checkmin、单点查询、撤销修改操作的线段树。

    时间复杂度 O(nlogm+mlognlogm)O(n \log m + m \log n \log m)。实现时可以预处理出 dfs 序,就不用每次都 dfs 了。

    代码看起来很长,但是很多内容都是重复的。

    // Problem: P8336 [Ynoi2004] 2stmst
    // Contest: Luogu
    // URL: https://www.luogu.com.cn/problem/P8336
    // Memory Limit: 512 MB
    // Time Limit: 6000 ms
    // 
    // Powered by CP Editor (https://cpeditor.org)
    
    #include <bits/stdc++.h>
    #define pb emplace_back
    #define fst first
    #define scd second
    #define mkp make_pair
    #define mems(a, x) memset((a), (x), sizeof(a))
    
    using namespace std;
    typedef long long ll;
    typedef double db;
    typedef unsigned long long ull;
    typedef long double ldb;
    typedef pair<int, int> pii;
    
    namespace IO {
    	const int maxn = 1 << 20;
    	
        char ibuf[maxn], *iS, *iT, obuf[maxn], *oS = obuf;
    
    	inline char gc() {
    		return (iS == iT ? iT = (iS = ibuf) + fread(ibuf, 1, maxn, stdin), (iS == iT ? EOF : *iS++) : *iS++);
    	}
    
    	template<typename T = int>
    	inline T read() {
    		char c = gc();
    		T x = 0;
    		bool f = 0;
    		while (c < '0' || c > '9') {
    			f |= (c == '-');
    			c = gc();
    		}
    		while (c >= '0' && c <= '9') {
    			x = (x << 1) + (x << 3) + (c ^ 48);
    			c = gc();
    		}
    		return f ? ~(x - 1) : x;
    	}
    
    	inline void flush() {
    		fwrite(obuf, 1, oS - obuf, stdout);
    		oS = obuf;
    	}
    	
    	struct Flusher {
    		~Flusher() {
    			flush();
    		}
    	} AutoFlush;
    
    	inline void pc(char ch) {
    		if (oS == obuf + maxn) {
    			flush();
    		}
    		*oS++ = ch;
    	}
    
    	template<typename T>
    	inline void write(T x) {
    		static char stk[64], *tp = stk;
    		if (x < 0) {
    			x = ~(x - 1);
    			pc('-');
    		}
    		do {
    			*tp++ = x % 10;
    			x /= 10;
    		} while (x);
    		while (tp != stk) {
    			pc((*--tp) | 48);
    		}
    	}
    	
    	template<typename T>
    	inline void writesp(T x) {
    		write(x);
    		pc(' ');
    	}
    	
    	template<typename T>
    	inline void writeln(T x) {
    		write(x);
    		pc('\n');
    	}
    }
    
    using IO::read;
    using IO::write;
    using IO::pc;
    using IO::writesp;
    using IO::writeln;
    
    const int maxn = 1000100;
    const int inf = 0x3f3f3f3f;
    
    int n, m, fa[maxn], pa[maxn];
    
    struct que {
    	int x, y;
    } a[maxn];
    
    struct graph {
    	int hd[maxn], to[maxn], nxt[maxn], len;
    	
    	inline void add_edge(int u, int v) {
    		to[++len] = v;
    		nxt[len] = hd[u];
    		hd[u] = len;
    	}
    } G;
    
    int find(int x) {
    	return fa[x] == x ? x : fa[x] = find(fa[x]);
    }
    
    inline bool merge(int x, int y) {
    	x = find(x);
    	y = find(y);
    	if (x != y) {
    		fa[x] = y;
    		return 1;
    	} else {
    		return 0;
    	}
    }
    
    int st[maxn], ed[maxn], tim, rnk[maxn], sz[maxn];
    int tot, in[maxn], out[maxn], ord[maxn << 1];
    
    void dfs(int u) {
    	st[u] = ++tim;
    	in[u] = ++tot;
    	ord[tot] = u;
    	sz[u] = 1;
    	rnk[tim] = u;
    	for (int i = G.hd[u]; i; i = G.nxt[i]) {
    		int v = G.to[i];
    		dfs(v);
    		sz[u] += sz[v];
    	}
    	ed[u] = tim;
    	out[u] = ++tot;
    	ord[tot] = u;
    }
    
    struct node {
    	int x1, f1, x2, f2;
    	node(int a = 0, int b = 0, int c = 0, int d = 0) : x1(a), f1(b), x2(c), f2(d) {}
    } c[maxn];
    
    pii b[maxn];
    
    inline node operator + (node a, node b) {
    	if (a.x1 > b.x1) {
    		swap(a, b);
    	}
    	node res = a;
    	if (b.x1 < res.x2 && b.f1 != a.f1) {
    		res.x2 = b.x1;
    		res.f2 = b.f1;
    	} else if (b.x2 < res.x2 && b.f2 != a.f1) {
    		res.x2 = b.x2;
    		res.f2 = b.f2;
    	}
    	return res;
    }
    
    struct List {
    	int hd[maxn], to[maxn], nxt[maxn], len;
    	
    	inline void add(int x, int y) {
    		to[++len] = y;
    		nxt[len] = hd[x];
    		hd[x] = len;
    	}
    } L1, L2;
    
    int rt[maxn];
    
    struct SGT1 {
    	int nt, ls[maxn * 3], rs[maxn * 3];
    	node a[maxn * 3];
    	
    	inline void init() {
    		for (int i = 0; i <= nt; ++i) {
    			ls[i] = rs[i] = 0;
    			a[i] = node();
    		}
    		a[0] = node(inf, 0, inf, 0);
    		nt = 0;
    	}
    	
    	void update(int &rt, int l, int r, int x, const node &y) {
    		if (!rt) {
    			rt = ++nt;
    			a[rt] = node(inf, 0, inf, 0);
    		}
    		a[rt] = a[rt] + y;
    		if (l == r) {
    			return;
    		}
    		int mid = (l + r) >> 1;
    		(x <= mid) ? update(ls[rt], l, mid, x, y) : update(rs[rt], mid + 1, r, x, y);
    	}
    	
    	void query(int rt, int l, int r, int ql, int qr, node &res) {
    		if (!rt) {
    			return;
    		}
    		if (ql <= l && r <= qr) {
    			res = res + a[rt];
    			return;
    		}
    		int mid = (l + r) >> 1;
    		if (ql <= mid) {
    			query(ls[rt], l, mid, ql, qr, res);
    		}
    		if (qr > mid) {
    			query(rs[rt], mid + 1, r, ql, qr, res);
    		}
    	}
    	
    	int merge(int u, int v, int l, int r) {
    		if (!u || !v) {
    			return u | v;
    		}
    		if (l == r) {
    			a[u] = a[u] + a[v];
    			return u;
    		}
    		int mid = (l + r) >> 1;
    		ls[u] = merge(ls[u], ls[v], l, mid);
    		rs[u] = merge(rs[u], rs[v], mid + 1, r);
    		a[u] = a[ls[u]] + a[rs[u]];
    		return u;
    	}
    } T1;
    
    pair<node*, node> stk[maxn * 3];
    int top, tp[maxn];
    
    struct SGT2 {
    	node a[maxn * 3];
    	int N;
    	
    	inline void init() {
    		N = 1;
    		while (N < n + 2) {
    			N <<= 1;
    		}
    		for (int i = 1; i <= N + n; ++i) {
    			a[i] = node(inf, 0, inf, 0);
    		}
    	}
    	
    	inline void update(int x, node y) {
    		x += N;
    		while (x) {
    			stk[++top] = mkp(a + x, a[x]);
    			a[x] = a[x] + y;
    			x >>= 1;
    		}
    	}
    	
    	inline node query(int l, int r) {
    		node res(inf, 0, inf, 0);
    		for (l += N - 1, r += N + 1; l ^ r ^ 1; l >>= 1, r >>= 1) {
    			if (!(l & 1)) {
    				res = res + a[l ^ 1];
    			}
    			if (r & 1) {
    				res = res + a[r ^ 1];
    			}
    		}
    		return res;
    	}
    } T2;
    
    struct SGT3 {
    	node a[maxn << 2];
    	
    	void build(int rt, int l, int r) {
    		a[rt] = node(inf, 0, inf, 0);
    		if (l == r) {
    			return;
    		}
    		int mid = (l + r) >> 1;
    		build(rt << 1, l, mid);
    		build(rt << 1 | 1, mid + 1, r);
    	}
    	
    	void update(int rt, int l, int r, int ql, int qr, const node &x) {
    		if (ql <= l && r <= qr) {
    			stk[++top] = mkp(a + rt, a[rt]);
    			a[rt] = a[rt] + x;
    			return;
    		}
    		int mid = (l + r) >> 1;
    		if (ql <= mid) {
    			update(rt << 1, l, mid, ql, qr, x);
    		}
    		if (qr > mid) {
    			update(rt << 1 | 1, mid + 1, r, ql, qr, x);
    		}
    	}
    	
    	void query(int rt, int l, int r, int x, node &res) {
    		res = res + a[rt];
    		if (l == r) {
    			return;
    		}
    		int mid = (l + r) >> 1;
    		(x <= mid) ? query(rt << 1, l, mid, x, res) : query(rt << 1 | 1, mid + 1, r, x, res);
    	}
    } T3;
    
    void solve() {
    	n = read();
    	m = read();
    	for (int i = 2; i <= n; ++i) {
    		pa[i] = read();
    		G.add_edge(pa[i], i);
    	}
    	for (int i = 1; i <= m; ++i) {
    		a[i].x = read();
    		a[i].y = read();
    		fa[i] = i;
    		L1.add(a[i].x, i);
    		L2.add(a[i].y, i);
    	}
    	dfs(1);
    	ll ans = 0;
    	while (1) {
    		bool fl = 1;
    		for (int i = 1; i <= m; ++i) {
    			fl &= (find(i) == find(1));
    			b[i] = mkp(inf, 0);
    		}
    		if (fl) {
    			break;
    		}
    		node p(inf, 0, inf, 0);
    		for (int i = 1; i <= m; ++i) {
    			p = p + node(sz[a[i].x] + sz[a[i].y], fa[i], inf, 0);
    		}
    		for (int i = 1; i <= m; ++i) {
    			if (p.f1 != fa[i]) {
    				b[fa[i]] = min(b[fa[i]], mkp(p.x1 + sz[a[i].x] + sz[a[i].y], p.f1));
    			} else {
    				b[fa[i]] = min(b[fa[i]], mkp(p.x2 + sz[a[i].x] + sz[a[i].y], p.f2));
    			}
    		}
    		for (int i = 1; i <= n; ++i) {
    			int u = rnk[i];
    			c[u] = node(inf, 0, inf, 0);
    			if (u > 1) {
    				c[u] = c[pa[u]];
    			}
    			for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
    				int j = L1.to[_];
    				c[u] = c[u] + node(sz[a[j].x] + sz[a[j].y], fa[j], inf, 0);
    			}
    			for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
    				int j = L1.to[_];
    				if (c[u].f1 != fa[j]) {
    					b[fa[j]] = min(b[fa[j]], mkp(c[u].x1 + sz[a[j].y] - sz[u], c[u].f1));
    				} else {
    					b[fa[j]] = min(b[fa[j]], mkp(c[u].x2 + sz[a[j].y] - sz[u], c[u].f2));
    				}
    			}
    		}
    		for (int i = 1; i <= n; ++i) {
    			int u = rnk[i];
    			c[u] = node(inf, 0, inf, 0);
    			if (u > 1) {
    				c[u] = c[pa[u]];
    			}
    			for (int _ = L2.hd[u]; _; _ = L2.nxt[_]) {
    				int j = L2.to[_];
    				c[u] = c[u] + node(sz[a[j].x] + sz[a[j].y], fa[j], inf, 0);
    			}
    			for (int _ = L2.hd[u]; _; _ = L2.nxt[_]) {
    				int j = L2.to[_];
    				if (c[u].f1 != fa[j]) {
    					b[fa[j]] = min(b[fa[j]], mkp(c[u].x1 + sz[a[j].x] - sz[u], c[u].f1));
    				} else {
    					b[fa[j]] = min(b[fa[j]], mkp(c[u].x2 + sz[a[j].x] - sz[u], c[u].f2));
    				}
    			}
    		}
    		for (int i = n; i; --i) {
    			int u = rnk[i];
    			c[u] = node(inf, 0, inf, 0);
    			for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
    				int j = L1.to[_];
    				c[u] = c[u] + node(sz[a[j].y] - sz[a[j].x], fa[j], inf, 0);
    			}
    			for (int _ = G.hd[u]; _; _ = G.nxt[_]) {
    				int v = G.to[_];
    				c[u] = c[u] + c[v];
    			}
    			for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
    				int j = L1.to[_];
    				if (c[u].f1 != fa[j]) {
    					b[fa[j]] = min(b[fa[j]], mkp(c[u].x1 + sz[a[j].x] + sz[a[j].y], c[u].f1));
    				} else {
    					b[fa[j]] = min(b[fa[j]], mkp(c[u].x2 + sz[a[j].x] + sz[a[j].y], c[u].f2));
    				}
    			}
    		}
    		T1.init();
    		for (int i = n; i; --i) {
    			int u = rnk[i];
    			c[u] = node(inf, 0, inf, 0);
    			rt[u] = 0;
    			for (int _ = L2.hd[u]; _; _ = L2.nxt[_]) {
    				int j = L2.to[_];
    				c[u] = c[u] + node(sz[a[j].x] - sz[u], fa[j], inf, 0);
    				T1.update(rt[u], 1, n, st[a[j].x], node(-sz[a[j].x] - sz[a[j].y], fa[j], inf, 0));
    			}
    			for (int _ = G.hd[u]; _; _ = G.nxt[_]) {
    				int v = G.to[_];
    				c[u] = c[u] + c[v];
    				rt[u] = T1.merge(rt[u], rt[v], 1, n);
    			}
    			for (int _ = L2.hd[u]; _; _ = L2.nxt[_]) {
    				int j = L2.to[_];
    				if (c[u].f1 != fa[j]) {
    					b[fa[j]] = min(b[fa[j]], mkp(c[u].x1 + sz[a[j].x] + sz[a[j].y], c[u].f1));
    				} else {
    					b[fa[j]] = min(b[fa[j]], mkp(c[u].x2 + sz[a[j].x] + sz[a[j].y], c[u].f2));
    				}
    				node res(inf, 0, inf, 0);
    				T1.query(rt[u], 1, n, st[a[j].x], ed[a[j].x], res);
    				if (res.f1 != fa[j]) {
    					b[fa[j]] = min(b[fa[j]], mkp(res.x1 + sz[a[j].x] + sz[a[j].y], res.f1));
    				} else {
    					b[fa[j]] = min(b[fa[j]], mkp(res.x2 + sz[a[j].x] + sz[a[j].y], res.f2));
    				}
    			}
    		}
    		T2.init();
    		top = 0;
    		for (int i = 1; i <= tot; ++i) {
    			int u = ord[i];
    			if (in[u] == i) {
    				tp[u] = top;
    				for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
    					int j = L1.to[_];
    					T2.update(st[a[j].y], node(sz[a[j].x] - sz[a[j].y], fa[j], inf, 0));
    				}
    				for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
    					int j = L1.to[_];
    					node res = T2.query(st[a[j].y], ed[a[j].y]);
    					if (res.f1 != fa[j]) {
    						b[fa[j]] = min(b[fa[j]], mkp(res.x1 - sz[a[j].x] + sz[a[j].y], res.f1));
    					} else {
    						b[fa[j]] = min(b[fa[j]], mkp(res.x2 - sz[a[j].x] + sz[a[j].y], res.f2));
    					}
    				}
    			} else {
    				while (top > tp[u]) {
    					*stk[top].fst = stk[top].scd;
    					--top;
    				}
    			}
    		}
    		T2.init();
    		top = 0;
    		for (int i = 1; i <= tot; ++i) {
    			int u = ord[i];
    			if (in[u] == i) {
    				tp[u] = top;
    				for (int _ = L2.hd[u]; _; _ = L2.nxt[_]) {
    					int j = L2.to[_];
    					T2.update(st[a[j].x], node(sz[a[j].y] - sz[a[j].x], fa[j], inf, 0));
    				}
    				for (int _ = L2.hd[u]; _; _ = L2.nxt[_]) {
    					int j = L2.to[_];
    					node res = T2.query(st[a[j].x], ed[a[j].x]);
    					if (res.f1 != fa[j]) {
    						b[fa[j]] = min(b[fa[j]], mkp(res.x1 + sz[a[j].x] - sz[a[j].y], res.f1));
    					} else {
    						b[fa[j]] = min(b[fa[j]], mkp(res.x2 + sz[a[j].x] - sz[a[j].y], res.f2));
    					}
    				}
    			} else {
    				while (top > tp[u]) {
    					*stk[top].fst = stk[top].scd;
    					--top;
    				}
    			}
    		}
    		T3.build(1, 1, n);
    		top = 0;
    		for (int i = 1; i <= tot; ++i) {
    			int u = ord[i];
    			if (in[u] == i) {
    				tp[u] = top;
    				for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
    					int j = L1.to[_];
    					T3.update(1, 1, n, st[a[j].y], ed[a[j].y], node(sz[a[j].x] + sz[a[j].y], fa[j], inf, 0));
    				}
    				for (int _ = L1.hd[u]; _; _ = L1.nxt[_]) {
    					int j = L1.to[_];
    					node res(inf, 0, inf, 0);
    					T3.query(1, 1, n, st[a[j].y], res);
    					if (res.f1 != fa[j]) {
    						b[fa[j]] = min(b[fa[j]], mkp(res.x1 - sz[a[j].x] - sz[a[j].y], res.f1));
    					} else {
    						b[fa[j]] = min(b[fa[j]], mkp(res.x2 - sz[a[j].x] - sz[a[j].y], res.f2));
    					}
    				}
    			} else {
    				while (top > tp[u]) {
    					*stk[top].fst = stk[top].scd;
    					--top;
    				}
    			}
    		}
    		for (int i = 1; i <= m; ++i) {
    			if (fa[i] == i && merge(i, b[i].scd)) {
    				ans += b[i].fst;
    			}
    		}
    	}
    	writeln(ans);
    }
    
    int main() {
    	int T = 1;
    	// scanf("%d", &T);
    	while (T--) {
    		solve();
    	}
    	return 0;
    }
    
    
    • 1

    信息

    ID
    7657
    时间
    6000ms
    内存
    512MiB
    难度
    7
    标签
    递交数
    0
    已通过
    0
    上传者