1 条题解

  • 0
    @ 2025-8-24 22:55:56

    自动搬运

    查看原文

    来自洛谷,原作者为

    avatar 0x3F
    Wir müssen wissen, wir werden wissen.

    搬运于2025-08-24 22:55:56,当前版本为作者最后更新于2024-04-27 13:59:57,作者可能在搬运后再次修改,您可在原文处查看最新版

    自动搬运只会搬运当前题目点赞数最高的题解,您可前往洛谷题解查看更多

    以下是正文


    神秘题。

    首先如果把餐厅记作 +1+1,将甜品店记作 1-1,那么如果一条边的某一侧的数字和为 x(x>0)x(x>0),那么这条边至少需要经过 2x2x 次。

    x=0x=0,如果这条边的两侧都有关键点(11 号点或者餐厅或者甜品店),那么这条边仍然至少需要经过 22 次。

    否则,这条边可以不被经过。

    但是怎么证明这样做是对的呢?我们考虑一个神秘的构造方法:如果一个点的子树里 ++- 多,那么可以把该子树内的 ++- 排列成若干个形如 ++++-+\cdots+ 的段(称为 ++ 段),每段后留一个空位。

    类似的,如果一个点的子树里 -++ 多,那么可以把该子树内的 ++- 排列成若干个形如 +-+-\cdots- 的段(称为 - 段),每段前留一个空位。

    如果 ++- 一样多,那么可以排成单独的一个形如 ++++-+-\cdots+- 的段(称为 00 段)。

    对于子树信息的合并,我们只需要将 ++ 段和 - 段交错合并即可,而 00 段可以和任意一个 +,,0+,-,0 段合并,直至合并成若干 ++ 段或若干 - 段或一个 00 段。

    容易证明,使用这种方法可以使得每条边被经过的次数都取得最小值。

    实现的时候,可以用链表套链表,外层链表维护一个子树包含哪些段,内层链表维护段内的 ++ 点和 - 点的编号。

    时间复杂度为 O(n)\mathcal{O}(n)

    代码如下:

    #include <bits/stdc++.h>
    using namespace std;
    const int _ = 3e5 + 10;
    int n, m, arr[_], brr[_], pos[_], neg[_], cntpos[_], cntneg[_], pnex[_], nnex[_], e, hd[_], nx[600010], to[600010], lef[_], rig[_], ans[_], bns[_];
    int pcnt, ptop, pbin[_], pl[_], pr[_], pt[_];
    int ncnt, ntop, nbin[_], nl[_], nr[_], nt[_];
    long long len;
    inline int pget(void) {
        if (ptop) {
            return pbin[ptop--];
        } else {
            return ++pcnt;
        }
    }
    inline void pdel(int x) {
        pbin[++ptop] = x;
    }
    inline int nget(void) {
        if (ntop) {
            return nbin[ntop--];
        } else {
            return ++ncnt;
        }
    }
    inline void ndel(int x) {
        nbin[++ntop] = x;
    }
    inline void pmerge(int& L, int& R, int l, int r) {
        if (l == 0 && r == 0) {
        } else if (L == 0 && R == 0) {
            L = l;
            R = r;
        } else {
            pt[R] = l;
            R = r;
        }
    }
    inline void nmerge(int& L, int& R, int l, int r) {
        if (l == 0 && r == 0) {
        } else if (L == 0 && R == 0) {
            L = l;
            R = r;
        } else {
            nt[R] = l;
            R = r;
        }
    }
    inline void mmerge(int& L, int& R, int l, int r) {
        if (l == 0 && r == 0) {
        } else if (L == 0 && R == 0) {
            L = l;
            R = r;
        } else {
            nnex[R] = l;
            R = r;
        }
    }
    inline void three_way_merge(int& L, int& R, int lp, int rp, int ln, int rn, int lm, int rm) {
        if (lp == 0 && rp == 0 && ln == 0 && rn == 0) {
            L = lm;
            R = rm;
        } else if (lp == 0 && rp == 0) {
            if (!(lm == 0 && rm == 0)) {
                nnex[nr[rn]] = lm;
                nr[rn] = rm;
            }
            L = ln;
            R = rn;
        } else if (ln == 0 && rn == 0) {
            if (!(lm == 0 && rm == 0)) {
                nnex[rm] = pl[rp];
                pl[rp] = lm;
            }
            L = lp;
            R = rp;
        } else {
            while (lp != -1 && ln != -1) {
                pnex[pr[lp]] = nl[ln];
                if (lm == 0 && rm == 0) {
                    lm = pl[lp];
                    rm = nr[ln];
                } else {
                    nnex[rm] = pl[lp];
                    rm = nr[ln];
                }
                pdel(lp);
                ndel(ln);
                lp = pt[lp];
                ln = nt[ln];
            }
            if (lp == -1 && ln == -1) {
                L = lm;
                R = rm;
            } else if (lp == -1) {
                nnex[nr[rn]] = lm;
                nr[rn] = rm;
                L = ln;
                R = rn;
            } else {
                nnex[rm] = pl[rp];
                pl[rp] = lm;
                L = lp;
                R = rp;
            }
        }
    }
    inline void add(int u, int v) {
        e++;
        nx[e] = hd[u];
        to[e] = v;
        hd[u] = e;
    }
    void dfs(int x, int f) {
        if (pos[x]) cntpos[x]++;
        if (neg[x]) cntneg[x]++;
        for (int i = hd[x]; i; i = nx[i]) {
            int y = to[i];
            if (y == f) continue;
            dfs(y, x);
            cntpos[x] += cntpos[y];
            cntneg[x] += cntneg[y];
            if (cntpos[y] > cntneg[y]) {
                len += cntpos[y] - cntneg[y];
            } else if (cntneg[y] > cntpos[y]) {
                len += cntneg[y] - cntpos[y];
            } else if (cntpos[y]) {
                len++;
            }
        }
    }
    void solve(int x, int f) {
        int lp = 0, rp = 0;
        int ln = 0, rn = 0;
        int lm = 0, rm = 0;
        if (pos[x]) {
            int a = pget();
            lp = rp = a;
            pl[a] = pos[x];
            pr[a] = pos[x];
            pt[a] = -1;
        }
        if (neg[x]) {
            int a = nget();
            ln = rn = a;
            nl[a] = neg[x];
            nr[a] = neg[x];
            nt[a] = -1;
        }
        for (int i = hd[x]; i; i = nx[i]) {
            int y = to[i];
            if (y == f) continue;
            solve(y, x);
            if (cntpos[y] > cntneg[y]) {
                pmerge(lp, rp, lef[y], rig[y]);
            } else if (cntpos[y] < cntneg[y]) {
                nmerge(ln, rn, lef[y], rig[y]);
            } else {
                mmerge(lm, rm, lef[y], rig[y]);
            }
        }
        three_way_merge(lef[x], rig[x], lp, rp, ln, rn, lm, rm);
    }
    int main() {
        ios::sync_with_stdio(0);
        cin.tie(0);
        cin >> n >> m;
        for (int i = 1; i <= m; i++) {
            cin >> arr[i];
            pos[arr[i]] = i;
        }
        for (int i = 1; i <= m; i++) {
            cin >> brr[i];
            neg[brr[i]] = i;
        }
        for (int i = 1; i < n; i++) {
            int u, v;
            cin >> u >> v;
            add(u, v);
            add(v, u);
        }
        dfs(1, 0);
        (len <<= 1LL);
        solve(1, 0);
        ans[1] = lef[1];
        bns[1] = pnex[ans[1]];
        for (int i = 2; i <= m; i++) {
            ans[i] = nnex[bns[i-1]];
            bns[i] = pnex[ans[i]];
        }
        cout << len << endl;
        for (int i = 1; i <= m; i++) {
            cout << ans[i] << ' ' << bns[i];
            if (i != m) cout << ' ';
        }
        cout << endl;
        return 0;
    }
    
    • 1

    信息

    ID
    9873
    时间
    2000ms
    内存
    512MiB
    难度
    6
    标签
    递交数
    0
    已通过
    0
    上传者