1 条题解

  • 0
    @ 2025-8-24 23:07:08

    自动搬运

    查看原文

    来自洛谷,原作者为

    avatar Alphas
    **

    搬运于2025-08-24 23:07:08,当前版本为作者最后更新于2024-12-18 23:30:10,作者可能在搬运后再次修改,您可在原文处查看最新版

    自动搬运只会搬运当前题目点赞数最高的题解,您可前往洛谷题解查看更多

    以下是正文


    如何达成以最优解倒数第一通过此题?官解竟然认为该做法只能获得 48 分?

    考虑维护一个 2×22\times 2 的矩阵:

    [xy00]\begin{bmatrix} x & y\\ 0 & 0 \end{bmatrix}

    其中 xix_i 表示当前最后一个的类型是 ii 的期望收益,yiy_i 表示当前最后一个的类型是 ii 的期望分数。那么我们上一个类型是 ii,转移到下一个视频类型为 jj 的矩阵的贡献是:

    $$\begin{bmatrix} (x_i+b_j\cdot d_{i,j}\cdot y_i)\frac 1n & y_i\cdot d_{i,j}\cdot \frac 1n \\ 0 & 0 \end{bmatrix}$$

    那么转移矩阵是:

    $$\begin{bmatrix} \frac 1n & 0 \\ \frac{b_{i,j}\cdot d_j}{n} & \frac{d_{i,j}}{n} \end{bmatrix}$$

    于是我们很自然的将每个 iji\rightarrow j2×22\times 2 的小转移矩阵放到新加入一个位置时的 t×tt\times t 的大转移矩阵的 iijj 列的位置上,转移直接做矩阵乘法就行,需要手写 2×22\times 2 矩阵的加、乘。注意特判 n=1n=1

    时间复杂度 Θ(t3logn)\Theta(t^3\log n)2×22\times 2 矩阵自带 8 倍常数。

    于是官解认为这就过不了了,实测也只有 48 分。记录

    但是真的是这样的吗?我们认为计算过程的极大开销在于取模,考虑能否减少一些这样的操作,并发现可以轻松做到这一点,相信不少习惯好的人本来就是这么写的。

    我们首先将 2×22\times 2 矩阵的乘法中的三层 for 删了,因为只有 4 个位置,所以直接把式子手写上去就行,这样,我们不但省去了 i++ 这样自增操作的运算次数,还将取模次数减半了,因为原来为了不溢出我们是 res.d[i][k] = (res.d[i][k] + d[i][j] * rhs.d[j][k]) % mod 对于每个 (i,k)(i, k) 执行了两遍,现在 res.d[i][k] = (d[i][0] * rhs.d[0][k] + d[i][1] * rhs.d[1][k]) % mod 对于每个 (i,k)(i, k) 执行了一遍。

    乘法操作常数减半,成功获得 65 分。记录

    然后就简单优化一下加法就行了,我们把 res.d[i][j] = (d[i][j] + rhs.d[i][j]) % mod 变成直接相加后 if (res.d[i][j] >= mod) res.d[i][j] -= mod,于是又扔掉了很多的取模操作。

    于是获得了 100 分。记得加 inline,虽然我也不知道为什么加了以后快了 0.2s。

    #include <bits/stdc++.h>
    #define ll long long
    using namespace std;
    
    const int MAXN = 201;
    const int mod = 998244353;
    
    ll ans, invn, inv100, b[MAXN];
    int m, n, K, id, sid;
    
    struct matrix2 {
        ll d00, d01, d10, d11;
        matrix2() {
            d00 = d01 = d10 = d11 = 0;
        }
        matrix2 operator * (const matrix2 &rhs) const {
            matrix2 res;
            res.d00 = (d00 * rhs.d00 + d01 * rhs.d10) % mod;
            res.d01 = (d00 * rhs.d01 + d01 * rhs.d11) % mod;
            res.d10 = (d10 * rhs.d00 + d11 * rhs.d10) % mod;
            res.d11 = (d10 * rhs.d01 + d11 * rhs.d11) % mod;
            return res;
        }
        matrix2 operator + (const matrix2 &rhs) const {
            matrix2 res;
            res.d00 = d00 + rhs.d00;
            if (res.d00 >= mod) 
                res.d00 -= mod;
            res.d01 = d01 + rhs.d01;
            if (res.d01 >= mod) 
                res.d01 -= mod;
            res.d10 = d10 + rhs.d10;
            if (res.d10 >= mod) 
                res.d10 -= mod;
            res.d11 = d11 + rhs.d11;
            if (res.d11 >= mod) 
                res.d11 -= mod;
            return res;
        }
    } one;
    
    struct matrixn {
        matrix2 d[MAXN][MAXN];
        matrixn operator * (const matrixn &rhs) const {
            matrixn res;
            for (int i = 0; i < n; i++)
                for (int j = 0; j < n; j++)
                    for (int k = 0; k < n; k++) 
                        res.d[i][k] = res.d[i][k] + d[i][j] * rhs.d[j][k];
            return res;
        }
    } One, zy, st;
    
    inline ll fpow(ll x, ll y) {
        int t = 1;
        while (y) {
            if (y & 1) 
                t = t * x % mod;
            x = x * x % mod;
            y /= 2;
        }
        return t;
    }
    
    inline matrixn fpowm(ll y) {
        matrixn t = One;
        while (y) {
            if (y & 1) 
                t = t * zy;
            zy = zy * zy;
            y /= 2;
        }
        return t;
    }
    
    int main() {
        ios::sync_with_stdio(0), cin.tie(0);
        cin >> m >> n >> K >> id >> sid, id --;
        one.d00 = one.d11 = 1;
        for (int i = 0; i < n; i++) 
            One.d[i][i] = one;
        invn = fpow(n, mod - 2), inv100 = fpow(100, mod - 2);
        for (int i = 0; i < n; i++) 
            cin >> b[i];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < n; j++) {
                ll x;
                cin >> x, x = x * inv100 % mod;
                matrix2 tmp;
                tmp.d00 = invn;
                tmp.d10 = x * b[j] % mod * invn % mod;
                tmp.d11 = x * invn % mod;
                zy.d[i][j] = tmp;
            }
        matrix2 tmp;
        tmp.d00 = K * b[id] % mod, tmp.d01 = K;
        st.d[0][id] = tmp;
        if (m == 1) {
            cout << "inverse " << tmp.d00 << '\n';
            return 0;
        }
        st = st * fpowm(m - 1);
        for (int i = 0; i < n; i++) 
            ans = (ans + st.d[0][i].d00) % mod;
        cout << "inverse " << ans << '\n';
        return 0;
    }
    
    • 1

    信息

    ID
    10816
    时间
    2000ms
    内存
    512MiB
    难度
    5
    标签
    递交数
    0
    已通过
    0
    上传者