1 条题解

  • 0
    @ 2025-8-24 22:54:22

    自动搬运

    查看原文

    来自洛谷,原作者为

    avatar win114514
    过去的我正在消散,未来的我模糊不清,唯一能相信的只有现在的自己。

    搬运于2025-08-24 22:54:22,当前版本为作者最后更新于2024-02-15 15:54:18,作者可能在搬运后再次修改,您可在原文处查看最新版

    自动搬运只会搬运当前题目点赞数最高的题解,您可前往洛谷题解查看更多

    以下是正文


    非常好题目,使我代码长度起飞。

    思路

    发现 KK 只有三种取值。

    考虑分类讨论。

    k=1k=1

    容易发现只需要求一个端点是 11 的最长链。

    k=3k=3

    考虑这个时候我们将有一个遍历整个树的方案。

    考虑递归的处理整个问题。

    我们从该节点跳到它一个儿子的儿子。

    然后递归处理这个儿子的儿子。

    然后再跳到该节点的这个儿子的另一个儿子。

    然后递归处理。

    将所有儿子的儿子处理完以后,在跳回这个儿子。

    然后继续处理其他的儿子的儿子。

    这样就可以简单找到遍历整个树的方案。

    k=2k=2

    考虑 k=2k=2 怎么做。

    我们可以使用树形 dp。

    设:

    fx,0f_{x,0} 为从 xx 出发往下走,对终止节点无要求的最大贡献。

    fx,1f_{x,1} 为从 xx 出发往下走,对终止节点要求为 xx 的某个儿子或 xx 的最大贡献。

    fx,2f_{x,2} 为从 xx 的某个儿子出发往下走,对终止节点无要求的最大贡献。

    fx,3f_{x,3} 为从 xx 的某个儿子出发往下走,对终止节点要求为 xx 的最大贡献。

    考虑转移式。

    1. fx,0=ax+fy,2f_{x,0}=a_x+f_{y,2}

    表示先走到 xx,在直接从 yy 往下走。

    1. $$f_{x,0}=a_x+f_{y1,3}+\sum_{y\not=y1,y2} a_y+f_{y2,0} $$

    表示先走到 xx,在再把 y1y1 走一圈后回到 y1y1,然后走它的兄弟,最后在某个兄弟往下走。

    1. fx,1=ax+fy1,3+yy1ayf_{x,1}=a_x+f_{y1,3}+\sum_{y\not=y1}a_y

    表示先走到 xx,在再把 y1y1 走一圈后回到 y1y1,然后走它的兄弟。

    1. fx,2=fx,0f_{x,2}=f_{x,0}

    和情况一类似。

    1. $$f_{x,2}=\sum_{y\not=y1,y2} a_y+f_{y1,1}+a_x+f_{y2,2} $$

    表示先走到 xx 的一些儿子,然后走到 y1y1 这个儿子转一圈,然后回到 xx,然后从 y2y2 往下走。

    1. $$f_{x,2}=\sum_{y\not=y1,y2,y3} a_y+f_{y1,1}+a_x+f_{y2,3}+f_{y3,0} $$

    表示先走到 xx 的一些儿子,然后走到 y1y1 这个儿子转一圈,然后回到 xx,然后从 y2y2 往下走一圈,然后从 y3y3 往下走。

    1. fx,3=yy1ay+fy1,1+axf_{x,3}=\sum_{y\not=y1} a_y+f_{y1,1}+a_x

    表示先走到 xx 的一些儿子,然后走到 y1y1 这个儿子转一圈,然后回到 xx

    注意很重要的一点,在记录方案时,这些顺序时不能随便颠倒的,否则容易方案不合法。

    容易发现以上所有 dp 式都可以线性解决。

    时间复杂度:O(n)O(n)

    Code

    #include <bits/stdc++.h>
    using namespace std;
    
    #define x first
    #define y second
    #define int long long
    #define mp(x, y) make_pair(x, y)
    #define eb(...) emplace_back(__VA_ARGS__)
    #define fro(i, x, y) for(int i = (x); i <= (y); i++)
    #define pre(i, x, y) for(int i = (x); i >= (y); i--)
    inline void JYFILE19();
    
    typedef int64_t i64;
    typedef pair<int, int> PII;
    
    bool ST;
    const int N = 2e5 + 10;
    const int mod = 998244353;
    
    int n, m, a[N], dp[N], fa[N], pre[N];
    vector<int> to[N];
    
    namespace subtask1 {
    	inline void dfs(int now, int fa) {
    		dp[now] = a[now];
    		for(auto i : to[now]) {
    			if(i == fa) continue;
    			dfs(i, now);
    			if(dp[i] > dp[pre[now]])
    				pre[now] = i;
    		}
    		dp[now] += dp[pre[now]];
    	}
    	inline void Solve() {
    		dfs(1, 0);
    		vector<int> ans;
    		int now = 1;
    		while(now) ans.eb(now), now = pre[now];
    		cout << dp[1] << "\n";
    		cout << ans.size() << "\n";
    		for(auto i : ans) cout << i << " ";
    		cout << "\n";
    	}
    }
    namespace subtask2 {
    struct Node {
    	int x, op;
    	inline Node(int xx, int opx) {
    		x = xx, op = opx;
    	}
    };
    struct node {
    	int num, id;
    	inline bool operator<(const node &tmp) const {
    		return num < tmp.num;
    	}
    } pr[N], sf[N];
    int tp, stk[N], f[N][4], dp[N][2][2][2][4];
    vector<Node> g[N][4];
    inline void dfs(int now, int fa) {
    	int sum = 0;
    	vector<int> son;
    	for(auto i : to[now])
    		if(i != fa) son.eb(i);
    	for(auto i : son)
    		dfs(i, now), sum += a[i];
    	tp = 0;
    	for(auto i : son) stk[++tp] = i;
    	if(tp == 0) {
    		fro(i, 0, 3) {
    			f[now][i] = a[now];
    			g[now][i].eb(now, 4);
    		}
    		return;
    	}
    	{
    		int idl = 0;
    		for(auto i : son)
    			if(f[i][2] > f[idl][2])
    				idl = i;
    		pr[0] = sf[0] = pr[tp + 1] = sf[tp + 1] = {0, 0};
    		fro(i, 1, tp) {
    			pr[i] = {f[stk[i]][3] - a[stk[i]], stk[i]};
    			sf[i] = {f[stk[i]][3] - a[stk[i]], stk[i]};
    		}
    		fro(i, 1, tp) pr[i] = max(pr[i], pr[i - 1]);
    		pre(i, tp, 1) sf[i] = max(sf[i], sf[i + 1]);
    		int id = 0;
    		auto get = [&](int x) {
    			if(x == 0) return 0ll;
    			return max(pr[x - 1], sf[x + 1]).num + f[stk[x]][0] - a[stk[x]];
    		};
    		fro(i, 1, tp) if(get(id) <= get(i)) id = i;
    		f[now][0] = a[now] + get(id) + sum;
    		if(f[now][0] < a[now] + f[idl][2]) {
    			f[now][0] = a[now] + f[idl][2];
    			g[now][0].eb(now, 4);
    			g[now][0].eb(idl, 2);
    		}
    		else {
    			int id1 = max(pr[id - 1], sf[id + 1]).id;
    			int id2 = stk[id];
    			g[now][0].eb(now, 4);
    			if(id1) g[now][0].eb(id1, 3);
    			for(auto i : son)
    				if(i != id1 && i != id2)
    					g[now][0].eb(i, 4);
    			if(id2) g[now][0].eb(id2, 0);
    		}
    	}
    	{
    		int id = 0;
    		for(auto i : son)
    			if(f[i][3] - a[i] > f[id][3] - a[id])
    				id = i;
    		f[now][1] = a[now] + f[id][3] - a[id] + sum;
    		g[now][1].eb(now, 4);
    		if(id) g[now][1].eb(id, 3);
    		for(auto i : son) if(i != id)
    			g[now][1].eb(i, 4);
    	}
    	{
    		int num1 = f[now][0];
    		pr[0] = sf[0] = pr[tp + 1] = sf[tp + 1] = {0, 0};
    		fro(i, 1, tp) {
    			pr[i] = {f[stk[i]][1] - a[stk[i]], stk[i]};
    			sf[i] = {f[stk[i]][1] - a[stk[i]], stk[i]};
    		}
    		fro(i, 1, tp) pr[i] = max(pr[i], pr[i - 1]);
    		pre(i, tp, 1) sf[i] = max(sf[i], sf[i + 1]);
    		int id = 0;
    		auto get = [&](int x) {
    			if(x == 0) return 0ll;
    			return max(pr[x - 1], sf[x + 1]).num + f[stk[x]][2] - a[stk[x]];
    		};
    		fro(i, 1, tp) if(get(id) <= get(i)) id = i;
    		int num2 = a[now] + get(id) + sum;
    		fro(i, 0, tp) {
    			fro(op1, 0, 1) {
    				fro(op2, 0, 1) {
    					fro(op3, 0, 1) {
    						dp[i][op1][op2][op3][0] = -1e18;
    						dp[i][op1][op2][op3][1] = 0;
    						dp[i][op1][op2][op3][2] = 0;
    						dp[i][op1][op2][op3][3] = 0;
    					}
    				}
    			}
    		}
    		dp[0][0][0][0][0] = 0;
    		fro(i, 1, tp) {
    			fro(op1, 0, 1) { fro(op2, 0, 1) { fro(op3, 0, 1) {
    				fro(k, 0, 3) dp[i][op1][op2][op3][k] = dp[i - 1][op1][op2][op3][k];
    			}}}
    			fro(op1, 0, 1) {
    				fro(op2, 0, 1) {
    					fro(op3, 0, 1) {
    						int num = dp[i - 1][op1][op2][op3][0];
    						int A = dp[i - 1][op1][op2][op3][1];
    						int B = dp[i - 1][op1][op2][op3][2];
    						int C = dp[i - 1][op1][op2][op3][3];
    						if(op1 == 0) {
    							if(dp[i][1][op2][op3][0] < num - a[stk[i]] + f[stk[i]][0]) {
    								dp[i][1][op2][op3][0] = num - a[stk[i]] + f[stk[i]][0];
    								dp[i][1][op2][op3][1] = stk[i];
    								dp[i][1][op2][op3][2] = B;
    								dp[i][1][op2][op3][3] = C;
    							}
    						}
    						if(op2 == 0) {
    							if(dp[i][op1][1][op3][0] < num - a[stk[i]] + f[stk[i]][1]) {
    								dp[i][op1][1][op3][0] = num - a[stk[i]] + f[stk[i]][1];
    								dp[i][op1][1][op3][1] = A;
    								dp[i][op1][1][op3][2] = stk[i];
    								dp[i][op1][1][op3][3] = C;
    							}
    						}
    						if(op3 == 0) {
    							if(dp[i][op1][op2][1][0] < num - a[stk[i]] + f[stk[i]][3]) {
    								dp[i][op1][op2][1][0] = num - a[stk[i]] + f[stk[i]][3];
    								dp[i][op1][op2][1][1] = A;
    								dp[i][op1][op2][1][2] = B;
    								dp[i][op1][op2][1][3] = stk[i];
    							}
    						}
    					}
    				}
    			}
    		}
    		int num3 = 0, f1 = 0, f2 = 0, f3 = 0;
    		fro(op1, 0, 1) {
    			fro(op2, 0, 1) {
    				fro(op3, 0, 1) {
    					if(num3 < dp[tp][op1][op2][op3][0]) {
    						num3 = dp[tp][op1][op2][op3][0];
    						f1 = op1, f2 = op2, f3 = op3;
    					}
    				}
    			}
    		}
    		num3 += sum + a[now];
    		f[now][2] = max({num1, num2, num3});
    		if(num1 >= num2 && num1 >= num3) {
    			g[now][2] = g[now][0];
    		}
    		else if(num2 >= num1 && num2 >= num3) {
    			int id1 = max(pr[id - 1], sf[id + 1]).id;
    			int id2 = stk[id];
    			for(auto i : son)
    				if(i != id1 && i != id2)
    					g[now][2].eb(i, 4);
    			if(id1) g[now][2].eb(id1, 1);
    			g[now][2].eb(now, 4);
    			if(id2) g[now][2].eb(id2, 2);
    		}
    		else {
    			int id1 = dp[tp][f1][f2][f3][1];
    			int id2 = dp[tp][f1][f2][f3][2];
    			int id3 = dp[tp][f1][f2][f3][3];
    			for(auto i : son)
    				if(i != id1 && i != id2 && i != id3)
    					g[now][2].eb(i, 4);
    			if(id2) g[now][2].eb(id2, 1);
    			g[now][2].eb(now, 4);
    			if(id3) g[now][2].eb(id3, 3);
    			if(id1) g[now][2].eb(id1, 0);
    		}
    	}
    	{
    		int id = 0;
    		for(auto i : son)
    			if(f[i][1] - a[i] > f[id][1] - a[id])
    				id = i;
    		f[now][3] = a[now] + f[id][1] - a[id] + sum;
    		for(auto i : son) if(i != id)
    			g[now][3].eb(i, 4);
    		if(id) g[now][3].eb(id, 1);
    		g[now][3].eb(now, 4);
    	}
    }
    vector<int> res;
    inline void print(int x, int op) {
    	for(auto i : g[x][op]) {
    		if(i.op == 4) res.eb(i.x);
    		else print(i.x, i.op);
    	}
    }
    inline void Solve() {
    	dfs(1, 0);
    	int num = max({f[1][0], f[1][1]});
    	fro(i, 0, 1) if(num == f[1][i]) { print(1, i); break; }
    	cout << num <<"\n";
    	cout << res.size() << "\n";
    	for(auto i : res) cout << i << " ";
    	cout << "\n";
    }
    }
    namespace subtask3 {
    	vector<int> ans;
    	inline void dfs(int now) {
    		for(auto i : to[now]) {
    			if(i == fa[now]) continue;
    			fa[i] = now, dfs(i);
    		}
    	}
    	inline void calc(int now) {
    		ans.eb(now);
    		for(auto i : to[now]) {
    			if(i == fa[now]) continue;
    			for(auto j : to[i]) {
    				if(j == fa[i]) continue;
    				calc(j);
    			}
    			ans.eb(i);
    		}
    	}
    	inline void Solve() {
    		dfs(1), calc(1);
    		int num = 0;
    		fro(i, 1, n) num += a[i];
    		cout << num << "\n";
    		cout << ans.size() << "\n";
    		for(auto i : ans) cout << i << " ";
    		cout << "\n";
    	}
    }
    
    signed main() {
    	JYFILE19();
    	cin >> n >> m;
    	fro(i, 1, n - 1) {
    		int x, y;
    		cin >> x >> y;
    		to[x].eb(y);
    		to[y].eb(x);
    	}
    	fro(i, 1, n) cin >> a[i];
    	if(m == 1) subtask1::Solve();
    	if(m == 2) subtask2::Solve();
    	if(m == 3) subtask3::Solve();
    	return 0;
    }
    
    bool ED;
    inline void JYFILE19() {
    	// freopen("", "r", stdin);
    	// freopen("", "w", stdout);
    	ios::sync_with_stdio(0), cin.tie(0);
    	double MIB = fabs((&ED-&ST)/1048576.), LIM = 1024;
    	cerr << "MEMORY: " << MIB << endl, assert(MIB<=LIM);
    }
    
    • 1

    信息

    ID
    9730
    时间
    2000ms
    内存
    1024MiB
    难度
    7
    标签
    递交数
    0
    已通过
    0
    上传者