1 条题解

  • 0
    @ 2025-8-24 22:10:46

    自动搬运

    查看原文

    来自洛谷,原作者为

    avatar Elegia
    irony

    搬运于2025-08-24 22:10:46,当前版本为作者最后更新于2019-06-29 21:51:44,作者可能在搬运后再次修改,您可在原文处查看最新版

    自动搬运只会搬运当前题目点赞数最高的题解,您可前往洛谷题解查看更多

    以下是正文


    题目大意

    给两颗树 T1,T2T_1, T_2 以及映射 f:T1T2f: T_1 \rightarrow T_2,其中 T2T_2 有根,记 [u,v][u,v]T1T_1 上两点 u,vu, v 的简单路径上点集,求

    $$\sum_{x \in T_1} \sum_{y \in T_1} \sum_{u\in T_1} \sum_{v\in T_1} [x<y][u<v][[u,v] \subseteq [x, y]] (\operatorname{dep}(\operatorname{lca}(f(u),f(v))) - 1) $$

    其中求 LCA 和深度均为在 T2T_2 上。

    前置知识

    • 点分治
    • 虚树

    题解

    改变求和顺序,考虑将与值主要相关的 u,vu, v 放到前面:

    $$\sum_{u\in T_1} \sum_{v\in T_1} [u<v] (\operatorname{dep}(\operatorname{lca}(f(u),f(v))) - 1)\sum_{x \in T_1} \sum_{y \in T_1} [x<y][[u,v] \subseteq [x, y]] $$

    注意 [u<v]+[v<u]=[uv]2[u<v]+[v<u] = \frac{[u\neq v]}2,不妨将问题转化为求不同点之间关系:

    $$\frac14 \sum_{u\in T_1} \sum_{v\in T_1} [u\neq v] (\operatorname{dep}(\operatorname{lca}(f(u),f(v))) - 1)\sum_{x \in T_1} \sum_{y \in T_1} [x\neq y][[u,v] \subseteq [x, y]] $$

    考虑进行点分治,对穿过重心 cc 的所有 [u,v][u,v] 对答案的贡献进行统计。记当前分治时 cc 为根,sub(u)\operatorname{sub}(u)uu 节点在整个 T1T_1 的子树大小。首先考虑 u,vu, v 均不为 cc 的情况,若它们在不同子树内,则贡献为

    $$(\operatorname{dep}(\operatorname{lca}(f(u),f(v))) - 1) \operatorname{sub} (u) \operatorname{sub} (v) $$

    考虑按照此定义将整个分治的连通块的贡献进行计算,然后容斥掉每个子树内的贡献。即我们要解决对于 T1T_1 上一个点集 SS,求出

    $$\left(\sum_{u\in S} \sum_{v\in S} (\operatorname{dep}(\operatorname{lca}(f(u),f(v))) - 1) \operatorname{sub} (u) \operatorname{sub} (v) \right) - \sum_{u\in S} (\operatorname{dep}(f(u)) - 1)\operatorname{sub}(u)^2 $$

    考虑在 T2T_2 上 $(\operatorname{dep}(\operatorname{lca}(f(u),f(v))) - 1) \operatorname{sub} (u) \operatorname{sub} (v)$ 的意义,这等价于将 f(u)f(u)T2T_2 的根上的每一个节点加上 sub(u)\operatorname{sub}(u) 的权值,询问 f(v)f(v) 的父亲到根节点的路径上权值和与 sub(v)\operatorname{sub}(v) 的乘积。如果通过数据结构进行计算会多出一个 log\log 的代价,考虑该问题是离线的,通过虚树优化复杂度。

    将 $\{f(u) | u\in S\} \cup \{ \operatorname{root}(T_2) \}$ 的虚树建出,考虑先将每个点映射到的点加上对应权值 w(f(u))w(f(u))+sub(u)w(f(u)) \leftarrow w(f(u)) + \operatorname{sub}(u),这维护了实际修改对应的差分,第二步将虚树从下至上进行累加,这时每个点与原本的点权概念一致,第三步从上至下累加,记 p(x)p(x) 为在虚树上的父节点,$s(f(u)) = s(p(f(u))) + w(p(f(u)))[\operatorname{dep} (f(u)) - \operatorname{dep}(p(f(u)))]$。

    这时 vSs(f(v))sub(v)\sum_{v\in S} s(f(v)) \operatorname{sub}(v) 即为所求。虚树计算的时间复杂度为 Θ(S)\Theta(|S|)。但注意建立虚树时通常有一个对 DFS 序排序的问题,考虑到本题所求全部为离线问题,所以离线进行多数组的基数排序即可做到严格的 $\Theta\left(m + \sum |S|\right) = \Theta(m + n\log n)$。

    对于 uu 为重心的情况,可以枚举 vv 进行计算。

    综上所述,点分治内部所消耗的时间复杂度为 Θ(S)\Theta(|S|),故点分治的总体复杂度为 Θ(nlogn)\Theta(n\log n)。为了辅助求出虚树,可能需要 Θ(1)\Theta(1) LCA 计算,预处理的理论复杂度可以做到 Θ(m)\Theta(m),使用朴素的 ST 表可以做到 Θ(mlogm)\Theta(m\log m)

    因此,本题的理论时间复杂度为 Θ(m+nlogn)\Theta(m + n\log n),需要 Θ(m+nlogn)\Theta(m + n\log n) 的空间。本题 std(by PinkRabbit) 的实现时间复杂度为 Θ(mlogm+nlogn)\Theta(m\log m + n\log n),空间复杂度为 Θ(mlogm+nlogn)\Theta(m\log m + n\log n)

    取决于求权值这一部分的不同实现,可能有例如以下一些复杂度的算法可能卡过时限:

    树链剖分:Θ(m+nlognlog2m)\Theta(m + n\log n\log^2 m)

    LCT / 全局平衡二叉树优化:Θ(mlogm+nlognlogm)\Theta(m\log m + n\log n \log m)

    使用 std::sort() 进行 DFS 序的排序:Θ(mlogm+nlog2n)\Theta(m\log m + n\log^2 n),空间 Θ(mlogm+n)\Theta(m\log m + n)

    代码

    #include <cstdio>
    #include <algorithm>
    // #include <ctime>
    #define il inline
    typedef long long LL;
    const int MN = 300005;
    const int MV = 12000005; // max virtual tree nodes < 2 N log N
    const int Mod = 998244353, Inv2 = (Mod + 1) / 2;
    il void ad(int &x, int y) { x -= (x += y) >= Mod ? Mod : 0; }
    
    int N, M, Ans;
    int id[MN], h[MN], g[MN], nxt[MN * 3], to[MN * 3], tot;
    il void ins(int *a, int x, int y) { nxt[++tot] = a[x], to[tot] = y, a[x] = tot; }
    il void init() {
    	int x;
    	scanf("%d%d", &N, &M);
    	for (int i = 1; i <= N; ++i) {
    		scanf("%d", &x);
    		if (x) ins(g, x, i), ins(g, i, x);
    	}
    	scanf("%*d");
    	for (int i = 2; i <= M; ++i) scanf("%d", &x), ins(h, x, i);
    	scanf("%*s");
    	for (int i = 1; i <= N; ++i) scanf("%d", &id[i]);
    } // read & edge-linking part
    
    int idf[MN], dfn[MN], dfc;
    int dep[MN], lg[MN * 2], st[MN * 2][20], ldf[MN], rdf[MN], ecn;
    void tdfs(int u) {
    	idf[++dfc] = u, dfn[u] = dfc;
    	st[++ecn][0] = u, ldf[u] = ecn;
    	for (int i = h[u]; i; i = nxt[i])
    		dep[to[i]] = dep[u] + 1, tdfs(to[i]), st[++ecn][0] = u;
    	rdf[u] = ecn;
    }
    il int chkdep(int i, int j) { return dep[i] < dep[j] ? i : j; }
    il void _st() {
    	lg[0] = -1;
    	for (int i = 1; i <= ecn; ++i) lg[i] = lg[i >> 1] + 1;
    	for (int j = 0; j < lg[ecn]; ++j)
    		for (int i = 2 << j; i <= ecn; ++i)
    			st[i][j + 1] = chkdep(st[i - (1 << j)][j], st[i][j]);
    }
    il int qurdep(int l, int r) { int b = lg[r - l + 1]; return chkdep(st[l + (1 << b) - 1][b], st[r][b]); }
    il int lca(int x, int y) {
    	if (rdf[x] < ldf[y]) return qurdep(rdf[x], ldf[y]);
    	else if (rdf[y] < ldf[x]) return qurdep(rdf[y], ldf[x]);
    	else return dep[x] < dep[y] ? x : y;
    } // dfn & O(M log M) - O(1) RMQ-LCA part
    
    int Q, ty[MN * 2], qh[MN * 2], buk[MN], nx[MV * 2], b[MV * 2], w[MV * 2], vcn;
    il void add(int *h, int x, int y, int z) { nx[++vcn] = h[x], b[vcn] = y, w[vcn] = z, h[x] = vcn; }
    // queries part
    
    int faz[MN], siz[MN], _siz[MN];
    void _dfs(int u, int fz) {
    	faz[u] = fz, siz[u] = 1;
    	for (int i = g[u]; i; i = nxt[i]) if (to[i] != fz)
    		_dfs(to[i], u), siz[u] += siz[to[i]];
    	_siz[u] = siz[u];
    }
    int vis[MN], sz[MN], ts, rsz, rt;
    void getrt(int u, int fz) {
    	int mxs = 0; sz[u] = 1;
    	for (int i = g[u]; i; i = nxt[i]) {
    		if (to[i] == fz || vis[to[i]]) continue;
    		getrt(to[i], u), sz[u] += sz[to[i]];
    		mxs = std::max(mxs, sz[to[i]]);
    	}
    	mxs = std::max(mxs, ts - sz[u]);
    	if (mxs < rsz) rt = u, rsz = mxs;
    }
    void fors(int u, int fz) {
    	add(buk, dfn[id[u]], Q, siz[u]);
    	for (int i = g[u]; i; i = nxt[i]) if (to[i] != fz && !vis[to[i]]) fors(to[i], u);
    }
    void dfs(int u) {
    	if (ts == 1) return ; // no-op
    	siz[u] = N;
    	for (int x = u; !vis[faz[x]]; x = faz[x]) siz[faz[x]] = N - _siz[x];
    	ty[++Q] = 0, fors(u, 0);
    	for (int i = g[u]; i; i = nxt[i]) if (!vis[to[i]])
    		ty[++Q] = 1, add(buk, dfn[id[u]], Q, siz[to[i]]), fors(to[i], u);
    	for (int x = u; !vis[x]; x = faz[x]) siz[x] = _siz[x];
    	vis[u] = 1;
    	int nsz = ts;
    	for (int i = g[u]; i; i = nxt[i]) if (!vis[to[i]]) {
    		rsz = ts = sz[to[i]] < sz[u] ? sz[to[i]] : nsz - sz[u];
    		getrt(to[i], 0), dfs(rt);
    	}
    } // O(N log N) tree decomposition part
    
    int Sum, sum[MN], rch[MN], tp;
    il void C(int x, int y) { if (x) ad(Sum, (LL)y * y % Mod * (dep[x] - dep[rch[tp]]) % Mod); }
    
    int main() {
    //	freopen("eternal-8-3.in", "r", stdin);
    //	int tim1, tim2, tim3, tim4, tim5, tim6, tim7;
    //	tim1 = clock();
    	init();
    //	tim2 = clock();
    	dep[1] = 0, tdfs(1);
    //	tim3 = clock();
    	_st();
    //	tim4 = clock();
    	_dfs(1, 0), vis[0] = 1, rsz = ts = N, getrt(1, 0), dfs(rt);
    //	tim5 = clock();
    	for (int i = M; i >= 1; --i)
    		for (int j = buk[i]; j; j = nx[j])
    			add(qh, b[j], idf[i], w[j]);
    	// O(N log N) radix sort part
    //	tim6 = clock();
    	for (int q = 1; q <= Q; ++q) {
    		int s = 0, x, y; Sum = 0;
    		rch[tp = 1] = 1, sum[1] = 0;
    		for (int i = qh[q]; i; i = nx[i]) {
    			int u = b[i];
    			ad(Sum, Mod - (LL)w[i] * w[i] * dep[u] % Mod);
    			ad(s, w[i]);
    			if (nx[i] && b[nx[i]] == u) continue;
    			int lc = lca(rch[tp], u);
    			for (x = y = 0; dep[rch[tp]] > dep[lc]; ad(y, sum[tp]), x = rch[tp--]) C(x, y);
    			if (dep[rch[tp]] < dep[lc]) rch[++tp] = lc, sum[tp] = 0;
    			ad(sum[tp], y);
    			C(x, y);
    			rch[++tp] = u, sum[tp] = s;
    			s = 0;
    		}
    		for (x = y = 0; tp; ad(y, sum[tp]), x = rch[tp--]) C(x, y);
    		ad(Ans, ty[q] ? Mod - Sum : Sum);
    	} // virtual tree part
    //	tim7 = clock();
    	Ans = (LL)Ans * Inv2 % Mod, printf("%d\n", Ans);
    /*	printf(" Init : %dms\n", tim2 - tim1);
    	printf(" Trie : %dms\n", tim3 - tim2);
    	printf(" ST   : %dms\n", tim4 - tim3);
    	printf(" DFZ  : %dms\n", tim5 - tim4);
    	printf(" sort : %dms\n", tim6 - tim5);
    	printf(" vt   : %dms\n", tim7 - tim6);*/
    	return 0;
    }
    
    • 1

    信息

    ID
    4420
    时间
    1000~10000ms
    内存
    500MiB
    难度
    7
    标签
    递交数
    0
    已通过
    0
    上传者