1 条题解

  • 0
    @ 2025-8-24 22:08:19

    自动搬运

    查看原文

    来自洛谷,原作者为

    avatar 绝顶我为峰
    蓝的盆。

    搬运于2025-08-24 22:08:19,当前版本为作者最后更新于2022-03-03 15:34:53,作者可能在搬运后再次修改,您可在原文处查看最新版

    自动搬运只会搬运当前题目点赞数最高的题解,您可前往洛谷题解查看更多

    以下是正文


    注意到 min{a,a1}1000\min\{a,a^{-1}\}\leq1000,必有蹊跷。

    考虑将连续的一段不受取模影响的 f(i)f(i) 合并,这样每段的长度都应该是 O(na)O(\frac na),那么这样的段有 O(a)O(a) 个。

    这就可以做 a1000a\leq 1000 的部分了,直接算出每一段的首项和长度,然后枚举一对连续段计算逆序对。

    a11000a^{-1}\leq 1000 的怎么办?考虑 f(i)f(i) 的反函数,也就是 i=a1f(i)a1(b+1)i=a^{-1}f(i)-a^{-1}(b+1),这个东西也是只有 O(a)O(a) 个连续段,就也能做了。

    但是这里有个问题,这真的是个函数吗?换句话说, iif(i)f(i) 之间是否有一一对应关系,否则反函数将不存在,这个算法也就不可行了。

    幸运的是在 mm 是质数是这的确是个函数,证明就考虑不妨设 n=mn=m,然后这时 i=1,2,,ni=1,2,\cdots,n 是模 mm 意义下的一个完系,那么 ai=a,2a,,nai=a,2a,\cdots,n 也是一个完系,加上常量相当于整体旋转,所以 f(i)f(i) 是模 mm 意义下的一个完系,即模 mm 意义下互不相同。因此反函数存在。

    然后考虑 n<mn<m 相当于在互不相同的 f(i)f(i) 中取一个子集出来,显然互不相同。

    然后的工作比较简单,和上面一样的方式计算逆序对即可。值得注意的是并非所有位置的值都是合法的,需要把 i=0i=0i>ni>n 的位置去掉。

    求逆序对需要比较多的分类讨论,不是很好写。

    #include<iostream>
    #include<cstdio>
    using namespace std;
    #define int long long
    int n,mod,a,b,ans;
    pair<int,int> v[10001];
    inline int read()
    {
        int x=0;
        char c=getchar();
        while(c<'0'||c>'9')
            c=getchar();
        while(c>='0'&&c<='9')
        {
            x=(x<<1)+(x<<3)+(c^48);
            c=getchar();
        }
        return x;
    }
    inline int pw(int a,int b)
    {
        int res=1;
        while(b)
        {
            if(b&1)
                res=res*a%mod;
            b>>=1;
            a=a*a%mod;
        }
        return res;
    }
    signed main()
    {
        n=read(),mod=read(),a=read()%mod,b=read()%mod;
        if(!a)
        {
            puts("0");
            return 0;
        }
        if(a<=1000)
        {
            for(int i=1,j=1;j<=n;++i)
            {
                v[i]={(a*j%mod+b)%mod+1,(mod-(a*j%mod+b)%mod-1)/a+1};
                if(j+v[i].second>n)
                {
                    v[i].second=n-j+1;
                    j=n+1;
                }
                else
                    j+=v[i].second;
                for(int k=1;k<i;++k)
                    if(v[k].first<v[i].first)
                    {
                        int cnt=v[k].second-((v[i].first-v[k].first)/a+1);
                        ans+=(cnt+cnt-v[i].second+1)*v[i].second/2;
                    }
                    else if(v[k].first>v[i].first)
                    {
                        int cnt=(v[k].first-v[i].first+a-1)/a;
                        if(cnt>=v[i].second)
                        {
                            ans+=v[i].second*v[k].second;
                            continue;
                        }
                        ans+=cnt*v[k].second+(v[k].second-1+v[k].second-v[i].second+cnt)*(v[i].second-cnt)/2;
                    }
                    else
                        ans+=(v[k].second-min(v[k].second,v[i].second)+v[k].second-1)*min(v[i].second,v[k].second)/2;
            }
            cout<<ans<<'\n';
            return 0;
        }
        a=pw(a,mod-2);
        b=(mod-(b+1)%mod*a%mod)%mod;
        for(int i=1,j=1;j<=mod;++i)
        {
            v[i]={(a*j%mod+b-1)%mod+1,(mod-(a*j%mod+b-1)%mod-1)/a+1};
            if(j+v[i].second>mod)
            {
                v[i].second=mod-j+1;
                j=mod+1;
            }
            else
                j+=v[i].second;
            if(v[i].first>n||v[i].second<=0)
            {
                --i;
                continue;
            }
            if(v[i].first+(v[i].second-1)*a>n)
                v[i].second=(n-v[i].first)/a+1;
            if(!v[i].first)
            {
                --v[i].second;
                v[i].first+=a;
            }
            for(int k=1;k<i;++k)
                if(v[k].first<v[i].first)
                {
                    int cnt=v[k].second-((v[i].first-v[k].first)/a+1);
                    ans+=(cnt+cnt-v[i].second+1)*v[i].second/2;
                }
                else if(v[k].first>v[i].first)
                {
                    int cnt=(v[k].first-v[i].first+a-1)/a;
                    if(cnt>=v[i].second)
                    {
                        ans+=v[i].second*v[k].second;
                        continue;
                    }
                    ans+=cnt*v[k].second+(v[k].second-1+v[k].second-v[i].second+cnt)*(v[i].second-cnt)/2;
                }
                else
                    ans+=(v[k].second-min(v[k].second,v[i].second)+v[k].second-1)*min(v[i].second,v[k].second)/2;
        }
        cout<<ans<<'\n';
        return 0;
    }
    
    • 1

    信息

    ID
    4187
    时间
    1000ms
    内存
    128MiB
    难度
    6
    标签
    递交数
    0
    已通过
    0
    上传者