1 条题解

  • 0
    @ 2025-8-24 23:16:45

    自动搬运

    查看原文

    来自洛谷,原作者为

    avatar diandian2020
    Q

    搬运于2025-08-24 23:16:45,当前版本为作者最后更新于2025-05-29 13:31:36,作者可能在搬运后再次修改,您可在原文处查看最新版

    自动搬运只会搬运当前题目点赞数最高的题解,您可前往洛谷题解查看更多

    以下是正文


    m62m\le 62 为字符集大小。

    如 B 站视频建出 DFA,注意手模样例可以发现本题没有“空”这个节点,sns_n 的后面应当紧接着 s1s_1

    fi,cf_{i,c} 表示已经匹配了 ii 个字符,当前在DFA的 cc 这个节点的答案,O((nm)3)\mathcal{O}((nm)^3)

    注意到当 i1i\ge 1 时,只有 c=sic=s_i 的位置有意义,O((n+m)3)\mathcal{O}((n+m)^3),具体地:


    $f_c=g_{c,s_1}f_1+\sum_{c'\neq s_1} g_{c,c'}f_{c'}+1(I)$

    对于 1i<n1\le i<n,$f_i=g_{s_i,s_{i+1}}f_{i+1}+\sum_{c\ne s_{i+1}\land kmp_{i,c}\neq 0}g_{s_i,c}f_{kmp_{i,c}}+\sum_{c\neq s_{i+1}\land kmp_{i,c}=0}g_{s_i,c}f_c+1(II)$

    其中 kmpi,ckmp_{i,c} 表示当前匹配了 ii 位,又匹配了一个字符 cc,当前最多匹配多少。可以 O(nm)\mathcal{O}(nm) 递推预处理。

    fn=0f_n=0


    以下简记记 f0,cf_{0,c}fcf_cfi,sif_{i,s_i}fif_i

    观察 II 类方程,发现 fcf_c 只可能依赖 fc,f1f_{c'},f_1 的值,这意味着如果我们知道所有 fcf_c,则我们可以直接推出 f1f_1 的值。

    观察 IIII 类方程,发现 fi(i1)f_i(i\ge 1) 只可能依赖 fi+1,f<i,fcf_{i+1},f_{<i},f_{c} 的值,这意味着如果我们知道所有 fc,fif_c,f_{\le i} 的值,则我们可以直接推出 fi+1f_{i+1} 的值。

    这意味着,对于每个字符集里的元素 cc,我们可以先用所有 fcf_{c'} 表示出 f1f_1,接着表示出 f2,f3,,fnf_2,f_3,\dots,f_n,而 fn=0f_n=0,这意味着每个 cc 都可以造出一个关于所有 fcf_{c'} 的方程,且这样的方程可以造出 mm 个,解方程部分复杂度降到 O(m3)\mathcal{O}(m^3)

    然而对一个 cc,造方程需要 O(nm2)\mathcal{O}(nm^2),也即我们造方程组需要 O(nm3)\mathcal{O}(nm^3),可以获得 98pts

    #include<cstdio>
    #include<string>
    #include<vector>
    #include<cassert>
    #include<cstring>
    #include<iostream>
    #include<algorithm>
    #define fi first
    #define se second
    using namespace std;
    typedef long long LL;
    typedef pair<int,int> PII;
    const int N=3e5+9,P=998244353;
    int qmi(int a,int b){
    	int res=1;
    	while(b){
    		if(b&1) res=(LL)res*a%P;
    		a=(LL)a*a%P;
    		b>>=1;
    	}
    	return res;
    }
    string str;
    char s[N]; int n,m,id[128],g[128][128],deg[128],a[128][128];
    int kmp[N][128],nxt[N],tmp[N][128],f[N];
    void add(int a,int b){
    	g[a][b]++;
    	deg[a]++;
    }
    void gauss(){
    	for(int i=1;i<=m;i++){
    		int id=i;
    		for(int j=i;j<=m;j++) if(a[j][i]) id=j;
    		if(id^i) swap(a[id],a[i]);
    		if(!a[i][i]) assert(0);
    		int inv=qmi(a[i][i],P-2);
    		for(int j=i;j<=m+1;j++) a[i][j]=(LL)a[i][j]*inv%P;
    		for(int k=1;k<=m;k++) if(k!=i&&a[k][i])
    			for(int j=m+1;j>=i;j--) a[k][j]=(a[k][j]-(LL)a[k][i]*a[i][j]%P+P)%P;
    	}
    }
    int main(){
    	getline(cin,str);
    	for(int i=0,len=str.size();i<len;i++) s[++n]=str[i];
    	memset(id,-1,sizeof(id));
    	for(int i=1;i<=n;i++) if(!~id[s[i]]) id[s[i]]=m++; m--;
    	if(!m) return printf("%d\n",n),0;
    	for(int i=1;i<n;i++) add(id[s[i]],id[s[i+1]]); add(id[s[n]],id[s[1]]);
    	for(int i=0;i<=m;i++){
    		int inv=qmi(deg[i],P-2);
    		for(int j=0;j<=m;j++) g[i][j]=(LL)inv*g[i][j]%P;
    	}
    	for(int i=2,j=0;i<=n;i++){
    		while(j&&s[j+1]!=s[i]) j=nxt[j];
    		if(s[j+1]==s[i]) j++;
    		nxt[i]=j;
    	}
    	for(int i=0;i<n;i++) for(int c=0;c<=m;c++){
    		if(id[s[i+1]]==c) kmp[i][c]=i+1;
    		else kmp[i][c]=kmp[nxt[i]][c];
    //		printf("kmp %d %d %d\n",i,c,kmp[i][c]);
    	}
    	for(int i=1;i<=m;i++){
    		if(!g[i][0]){
    			for(int j=1;j<=m;j++) if(i^j) a[i][j]=(P-g[i][j])%P;
    			a[i][i]=(P-g[i][i]+1)%P;
    			a[i][m+1]=1;
    		}
    		else{
    			int inv=qmi(g[i][0],P-2);
    			for(int j=1;j<=m;j++) if(i^j) tmp[1][j]=(LL)(P-g[i][j])*inv%P;
    			tmp[1][i]=(LL)(P-g[i][i]+1)*inv%P;
    			tmp[1][m+1]=(LL)(P-1)*inv%P;
    //			for(int p=1;p<=m+1;p++) printf("%d ",tmp[1][p]); puts("");
    			for(int j=2;j<=n;j++){
    				int inv=qmi(g[id[s[j-1]]][id[s[j]]],P-2);
    				for(int c=1;c<=m+1;c++) tmp[j][c]=(LL)tmp[j-1][c]*inv%P;
    //				for(int p=1;p<=m+1;p++) printf("%d ",tmp[j][p]); puts("");
    				for(int c=0;c<=m;c++) if(c!=id[s[j]]){
    					int coef=(LL)(P-g[id[s[j-1]]][c])*inv%P;
    //					printf("kmp %d %d %d %d\n",j-1,c,kmp[j-1][c],coef);
    					if(!kmp[j-1][c]) tmp[j][c]=(tmp[j][c]+coef)%P;
    					else{
    						int k=kmp[j-1][c];
    						for(int c=1;c<=m+1;c++) tmp[j][c]=(tmp[j][c]+(LL)coef*tmp[k][c])%P;
    					}
    				}
    				tmp[j][m+1]=(tmp[j][m+1]+(LL)(P-1)*inv)%P;
    //				for(int p=1;p<=m+1;p++) printf("%d ",tmp[j][p]); puts("");
    			}
    			for(int j=1;j<=m;j++) a[i][j]=tmp[n][j];
    			a[i][m+1]=(P-tmp[n][m+1])%P;
    		}
    	}
    //	for(int i=1;i<=m;i++,puts("")) for(int j=1;j<=m+1;j++) printf("%d ",a[i][j]);
    	gauss();
    	int c=1;
    	while(!g[c][0]) c++;
    	int f1=(a[c][m+1]-1+P)%P;
    	for(int j=1;j<=m;j++) f1=(f1+(LL)(P-g[c][j])*a[j][m+1])%P;
    	f1=(LL)f1*qmi(g[c][0],P-2)%P;
    	printf("%d\n",(f1+1)%P);
        return 0;
    }
    

    进一步考察上述过程,我们发现我们是将每个 fif_i 用一个向量 tmpi,1,tmpi,2,,tmpi,m,tmpi,m+1tmp_{i,1},tmp_{i,2},\dots,tmp_{i,m},tmp_{i,m+1} 表示,表示 fi=c[1,m]tmpi,cfc+tmpi,m+1f_i=\sum_{c'\in[1,m]}tmp_{i,c'}f_{c'}+tmp_{i,m+1}。求 fif_i 的过程是 O(m)\mathcal{O}(m)fj(j<i)f_{j}(j<i) 的线性组合再改 O(m)O(m) 项,而无论是线性组合的系数,还是改的常数,都是与 cc 无关的。

    这意味着无论 cc 取何值,tmpi,ctmp_{i,c'} 必然可以表示为 k1,ctmp1,c+b1,ck_{1,c'}tmp_{1,c'}+b_{1,c'} 其中 k,bk,b 为常数组,与 cc 无关。

    所以我们其实只要 O(nm2)\mathcal{O}(nm^2) 递推一次求出 k,bk,b,对于每个 cc 就可以单次 O(m)\mathcal{O}(m) 的造出方程。

    总时间复杂度 O(nm2)\mathcal{O}(nm^2),空间复杂度 O(nm)\mathcal{O}(nm),可以通过:

    #include<cstdio>
    #include<string>
    #include<vector>
    #include<cassert>
    #include<cstring>
    #include<iostream>
    #include<algorithm>
    #define fi first
    #define se second
    using namespace std;
    typedef long long LL;
    typedef pair<int,int> PII;
    const int N=3e5+9,P=998244353;
    int qmi(int a,int b){
    	int res=1;
    	while(b){
    		if(b&1) res=(LL)res*a%P;
    		a=(LL)a*a%P;
    		b>>=1;
    	}
    	return res;
    }
    string str;
    char s[N]; int n,m,id[128],g[128][128],deg[128],a[128][128];
    int kmp[N][128],nxt[N],f[N];
    PII tmp[N][128];
    void add(int a,int b){
    	g[a][b]++;
    	deg[a]++;
    }
    void gauss(){
    	for(int i=1;i<=m;i++){
    		int id=i;
    		for(int j=i;j<=m;j++) if(a[j][i]) id=j;
    		if(id^i) swap(a[id],a[i]);
    		if(!a[i][i]) assert(0);
    		int inv=qmi(a[i][i],P-2);
    		for(int j=i;j<=m+1;j++) a[i][j]=(LL)a[i][j]*inv%P;
    		for(int k=1;k<=m;k++) if(k!=i&&a[k][i])
    			for(int j=m+1;j>=i;j--) a[k][j]=(a[k][j]-(LL)a[k][i]*a[i][j]%P+P)%P;
    	}
    }
    int main(){
    	getline(cin,str);
    	for(int i=0,len=str.size();i<len;i++) s[++n]=str[i];
    	memset(id,-1,sizeof(id));
    	for(int i=1;i<=n;i++) if(!~id[s[i]]) id[s[i]]=m++; m--;
    	if(!m) return printf("%d\n",n),0;
    	for(int i=1;i<n;i++) add(id[s[i]],id[s[i+1]]); add(id[s[n]],id[s[1]]);
    	for(int i=0;i<=m;i++){
    		int inv=qmi(deg[i],P-2);
    		for(int j=0;j<=m;j++) g[i][j]=(LL)inv*g[i][j]%P;
    	}
    	for(int i=2,j=0;i<=n;i++){
    		while(j&&s[j+1]!=s[i]) j=nxt[j];
    		if(s[j+1]==s[i]) j++;
    		nxt[i]=j;
    	}
    	for(int i=0;i<n;i++) for(int c=0;c<=m;c++){
    		if(id[s[i+1]]==c) kmp[i][c]=i+1;
    		else kmp[i][c]=kmp[nxt[i]][c];
    //		printf("kmp %d %d %d\n",i,c,kmp[i][c]);
    	}
    	for(int i=1;i<=m+1;i++) tmp[1][i]={1,0};
    	for(int j=2;j<=n;j++){
    		int inv=qmi(g[id[s[j-1]]][id[s[j]]],P-2);
    		for(int c=1;c<=m+1;c++) tmp[j][c].fi=(LL)tmp[j-1][c].fi*inv%P,tmp[j][c].se=(LL)tmp[j-1][c].se*inv%P;
    		for(int c=0;c<=m;c++) if(c!=id[s[j]]){
    			int coef=(LL)(P-g[id[s[j-1]]][c])*inv%P;
    			if(!kmp[j-1][c]) tmp[j][c].se=(tmp[j][c].se+coef)%P;
    			else{
    				int k=kmp[j-1][c];
    				for(int c=1;c<=m+1;c++) tmp[j][c].fi=(tmp[j][c].fi+(LL)tmp[k][c].fi*coef)%P,tmp[j][c].se=(tmp[j][c].se+(LL)tmp[k][c].se*coef)%P;
    			}
    		}
    		tmp[j][m+1].se=(tmp[j][m+1].se+(LL)(P-1)*inv)%P;
    	}
    	for(int i=1;i<=m;i++){
    		if(!g[i][0]){
    			for(int j=1;j<=m;j++) if(i^j) a[i][j]=(P-g[i][j])%P;
    			a[i][i]=(P-g[i][i]+1)%P;
    			a[i][m+1]=1;
    		}
    		else{
    			int inv=qmi(g[i][0],P-2);
    			for(int j=1;j<=m;j++) if(i^j) a[i][j]=((LL)(P-g[i][j])*inv%P*tmp[n][j].fi+tmp[n][j].se)%P;
    			a[i][i]=((LL)(P-g[i][i]+1)*inv%P*tmp[n][i].fi+tmp[n][i].se)%P;
    			a[i][m+1]=(P-((LL)(P-1)*inv%P*tmp[n][m+1].fi+tmp[n][m+1].se)%P)%P;
    		}
    	}
    //	for(int i=1;i<=m;i++,puts("")) for(int j=1;j<=m+1;j++) printf("%d ",a[i][j]);
    	gauss();
    	int c=1;
    	while(!g[c][0]) c++;
    	int f1=(a[c][m+1]-1+P)%P;
    	for(int j=1;j<=m;j++) f1=(f1+(LL)(P-g[c][j])*a[j][m+1])%P;
    	f1=(LL)f1*qmi(g[c][0],P-2)%P;
    	printf("%d\n",(f1+1)%P);
        return 0;
    }
    
    • 1

    信息

    ID
    12359
    时间
    1000ms
    内存
    512MiB
    难度
    6
    标签
    (无)
    递交数
    0
    已通过
    0
    上传者