后缀数组 SA 学习笔记

· · 算法·理论

我是一个先学 SAM,再学 SA 的神秘的人。

我比较菜,本文只介绍倍增法,将来会了 SA-IS 再更。

SA 核心是两个数组,\text{rk}\text{sa},首先有一个字符串,叫 S,我们记 S_{[i,n]}\text{Suf}(i),那么就有 n 个字符串。

其中:

怎么求呢?我们考虑找出一个方法给所有的后缀排序。

比较好的方法是倍增加拼接。

怎么个倍增法?我们考虑排序过程中,一开始只看第一个字符,第二次看前两个字符,第三次看前四个字符......

那假设这一次看前 2l 个字符,则目前的排名是以前 l 个字符排名的,那么原来的排名看做是第一关键字,新加的 l 个字符看做是第二关键字,当然比较实际用的是已经求出的排名。

于是你拿长度翻倍之前的两项拼出了翻倍之后的一项。 我们来看看,模拟这个求解的过程如下: ![](https://oi-wiki.org/string/images/sa2.png) 以下是一份 $\mathcal O(n\log^2 n)$ 的实现,排序使用的是 STL Sort,可以通过模板题 [P3809 【模板】后缀排序](https://www.luogu.com.cn/problem/P3809)。 ```cpp #include<bits/stdc++.h> #define LL long long const LL N=2e6+5; const LL M=75; using namespace std; LL n,rk[N],sa[N],b[M+5],tmp[N]; char s[N]; // sa rk-> pos // rk pos-> rk LL l; int main() { scanf("%s",s+1); n=strlen(s+1); for(int i=1;i<=n;i++)sa[i]=i,rk[i]=s[i]-'0'+1;//初始化排名用自己的值,用于比较 for(l=1;l<n;l*=2) { sort(sa+1,sa+n+1,[](LL x,LL y){ if(rk[x]==rk[y])return rk[x+l]<rk[y+l]; return rk[x]<rk[y]; }); for(int i=1;i<=n;i++)tmp[i]=rk[i];//考虑rk在使用同时被更新,所以用个临时数组。 LL t=0; for(int i=1;i<=n;i++) { if(tmp[sa[i]]==tmp[sa[i-1]]&&tmp[sa[i]+l]==tmp[sa[i-1]+l])rk[sa[i]]=t;//相同的两项排名相同。 else rk[sa[i]]=++t; } } for(int i=1;i<=n;i++) { printf("%lld ",sa[i]); } } ``` 但是这份代码太慢了。 排序这里 $\mathcal O(n\log n)$ 多少带点难受,因为排名的值域显然会很小,由于有两个关键字,可以考虑写一发计数排序。 学习这个算法可以看看 OI Wiki 或者知乎上 Pecco 大佬的文章,还是很简单的。 ```cpp #include<bits/stdc++.h> #define LL long long const LL N=2e6+5; const LL M=1e6; using namespace std; LL n,rk[N],sa[N],b[M+5],tmp[N]; char s[N]; // sa rk-> pos // rk pos-> rk LL l; int main() { scanf("%s",s+1); n=strlen(s+1); for(int i=1;i<=n;i++)++b[rk[i]=s[i]]; for(int i=1;i<=M;i++)b[i]+=b[i-1]; for(int i=n;i>=1;i--)sa[b[rk[i]]--]=i; for(int i=1;i<=n;i++)tmp[i]=rk[i]; LL t=0; for(int i=1;i<=n;i++) { if(tmp[sa[i]]==tmp[sa[i-1]])rk[sa[i]]=t; else rk[sa[i]]=++t; } for(l=1;l<n;l*=2) { memset(b,0,sizeof(b)); for(int i=1;i<=n;i++)tmp[i]=sa[i]; for(int i=1;i<=n;i++)b[rk[tmp[i]+l]]++; for(int i=1;i<=M;i++)b[i]+=b[i-1]; for(int i=n;i>=1;i--) { sa[b[rk[tmp[i]+l]]--]=tmp[i]; } memset(b,0,sizeof(b)); for(int i=1;i<=n;i++)tmp[i]=sa[i]; for(int i=1;i<=n;i++)b[rk[tmp[i]]]++; for(int i=1;i<=M;i++)b[i]+=b[i-1]; for(int i=n;i>=1;i--) { sa[b[rk[tmp[i]]]--]=tmp[i]; } for(int i=1;i<=n;i++)tmp[i]=rk[i]; LL t=0; for(int i=1;i<=n;i++) { if(tmp[sa[i]]==tmp[sa[i-1]]&&tmp[sa[i]+l]==tmp[sa[i-1]+l])rk[sa[i]]=t; else rk[sa[i]]=++t; } } for(int i=1;i<=n;i++) { printf("%lld ",sa[i]); } } ``` 这个代码的理论时间复杂度是 $\mathcal O(n\log n)$,但是实际提交的时候,这份代码时超了,虽然有一发过了,但是也跑了整整五秒,很极端。 所以考虑加些优化。 首先值域是不断变化的,所以考虑实时更新值域,减少枚举量。 然后我们发现没必要以第二关键字排序,第二关键字本质上是一段单调不降的数列加上一些 $0$,所以把 $0$ 提到前面就行了。 ```cpp #include<bits/stdc++.h> #define LL long long const LL N=2e6+5; const LL M=1e6+5; using namespace std; LL n,rk[N],sa[N],b[M],tmp[N],m=128; char s[N]; // sa rk-> pos // rk pos-> rk LL l; int main() { scanf("%s",s+1); n=strlen(s+1); for(int i=1;i<=n;i++)++b[rk[i]=s[i]]; for(int i=1;i<=m;i++)b[i]+=b[i-1]; for(int i=n;i>=1;i--)sa[b[rk[i]]--]=i; for(int i=1;i<=n;i++)tmp[i]=rk[i]; LL t=0; for(int i=1;i<=n;i++) { if(tmp[sa[i]]==tmp[sa[i-1]])rk[sa[i]]=t; else rk[sa[i]]=++t; } m=t; for(l=1;l<n;l*=2) { LL t=0; for(int i=n-l+1;i<=n;i++)tmp[++t]=i; for(int i=1;i<=n;i++)if(sa[i]>l)tmp[++t]=sa[i]-l; for(int i=1;i<=m;i++)b[i]=0; for(int i=1;i<=n;i++)b[rk[tmp[i]]]++; for(int i=1;i<=m;i++)b[i]+=b[i-1]; for(int i=n;i>=1;i--) { sa[b[rk[tmp[i]]]--]=tmp[i]; } for(int i=1;i<=n;i++)tmp[i]=rk[i]; t=0; for(int i=1;i<=n;i++) { if(tmp[sa[i]]==tmp[sa[i-1]]&&tmp[sa[i]+l]==tmp[sa[i-1]+l])rk[sa[i]]=t; else rk[sa[i]]=++t; } m=t; } for(int i=1;i<=n;i++) { printf("%lld ",sa[i]); } } ``` 于是代码只需要三秒,似乎有犇犇写到了一秒以内。 逆天,我把 `long long` 换成 `int` 快成 $1.8s$ 了。 其实还有一些优化,可以更快!更快! 比如,如果排名已经互不相同,那就可以不用排序了。 ```cpp #include<bits/stdc++.h> #define LL int const LL N=2e6+5; const LL M=1e6+5; using namespace std; LL n,rk[N],sa[N],b[M],tmp[N],m=128; char s[N]; // sa rk-> pos // rk pos-> rk LL l; int main() { scanf("%s",s+1); n=strlen(s+1); for(int i=1;i<=n;i++)++b[rk[i]=s[i]]; for(int i=1;i<=m;i++)b[i]+=b[i-1]; for(int i=n;i>=1;i--)sa[b[rk[i]]--]=i; for(int i=1;i<=n;i++)tmp[i]=rk[i]; LL t=0; for(int i=1;i<=n;i++) { if(tmp[sa[i]]==tmp[sa[i-1]])rk[sa[i]]=t; else rk[sa[i]]=++t; } m=t; for(l=1;l<n;l*=2) { LL t=0; for(int i=n;i>=n-l+1;i--)tmp[++t]=i; for(int i=1;i<=n;i++)if(sa[i]>l)tmp[++t]=sa[i]-l; for(int i=1;i<=m;i++)b[i]=0; for(int i=1;i<=n;i++)b[rk[tmp[i]]]++; for(int i=1;i<=m;i++)b[i]+=b[i-1]; for(int i=n;i>=1;i--) { sa[b[rk[tmp[i]]]--]=tmp[i]; } for(int i=1;i<=n;i++)tmp[i]=rk[i]; t=0; for(int i=1;i<=n;i++) { if(tmp[sa[i]]==tmp[sa[i-1]]&&tmp[sa[i]+l]==tmp[sa[i-1]+l])rk[sa[i]]=t; else rk[sa[i]]=++t; } m=t; if(n==m)break; } for(int i=1;i<=n;i++) { printf("%d ",sa[i]); } } ``` 飞快,$621ms$。 ## 例题 P4051 [JSOI2007] 字符加密 把字符串复制一份,每一种循环移位就是一个长度为 $n$ 的滑动窗口,那么就是对于 $n$ 个滑动窗口排一个序。 我们发现问题可以直接转换成后缀排序,然后直接摁造就没了。 ```cpp #include<bits/stdc++.h> #define LL long long const LL N=2e6+5; const LL M=1e6+5; using namespace std; LL n,rk[N],sa[N],b[M],tmp[N],m=M-5; char s[N]; // sa rk-> pos // rk pos-> rk LL l; int main() { scanf("%s",s+1); n=strlen(s+1); for(int i=1;i<=n-1;i++)s[i+n]=s[i]; n=2*n-1; for(int i=1;i<=n;i++)++b[rk[i]=s[i]]; for(int i=1;i<=m;i++)b[i]+=b[i-1]; for(int i=n;i>=1;i--)sa[b[rk[i]]--]=i; for(int i=1;i<=n;i++)tmp[i]=rk[i]; LL t=0; for(int i=1;i<=n;i++) { if(tmp[sa[i]]==tmp[sa[i-1]])rk[sa[i]]=t; else rk[sa[i]]=++t; } m=t; for(l=1;l<n;l*=2) { LL t=0; for(int i=n-l+1;i<=n;i++)tmp[++t]=i; for(int i=1;i<=n;i++)if(sa[i]>l)tmp[++t]=sa[i]-l; for(int i=1;i<=m;i++)b[i]=0; for(int i=1;i<=n;i++)b[rk[tmp[i]]]++; for(int i=1;i<=m;i++)b[i]+=b[i-1]; for(int i=n;i>=1;i--) { sa[b[rk[tmp[i]]]--]=tmp[i]; } for(int i=1;i<=n;i++)tmp[i]=rk[i]; t=0; for(int i=1;i<=n;i++) { if(tmp[sa[i]]==tmp[sa[i-1]]&&tmp[sa[i]+l]==tmp[sa[i-1]+l])rk[sa[i]]=t; else rk[sa[i]]=++t; } m=t; } for(int i=1;i<=n;i++) { if(sa[i]>(n+1)/2)continue; putchar(s[sa[i]+(n+1)/2-1]); } } ``` 类似的有 P1368 【模板】最小表示法。 本质上真不是和上面那题一样吗 要我说,感觉不如 SAM。 # 引入神秘的 $height$ 数组,获得相对成功的人生 让我们来引进 $height$ 数组,这个数组是唯一真神,就这么说吧,$\text{sa}$ 和 $\text{rk}$ 就是 $\text{Mg}$ 和 $\text{O}_2$,两者燃烧形成 $\text{MgO}$ 发出耀眼白光!天啊!那是神圣的! 这个数组用处多多,我们先来认识认识,定义是: $$ height(i)=\begin{cases}\text{LCP}(\text{Suf}(\text{sa}_i),\text{Suf}(\text{sa}_{i-1})) &i\not=1\\ 0&i=1 \end{cases} $$ 在这当中$\text{LCP}$ 是最长公共前缀的意思 咋求呢?我们发现一个重要性质: $$ height(\text{rk}_i)\geq height(\text{rk}_{i-1})-1 $$ 推导如下: $height(rk_{i-1})=\text{LCP}(\text{Suf}(i-1),prev(\text{Suf}(i-1)))$,其中 $prev$ 表示排序后的上一项的后缀。 对于 $height(\text{rk}_i)$ 而言,就是 $\text{LCP}({\text{Suf}(i),prev(i)})$。 考虑 $\text{Suf}(i)$ 实质上就是去掉 $\text{Suf}(i-1)$ 的第一位,那么显然 $prev(\text{Suf}(i-1))$ 也能匹配上一段,长度减小了$1$。 但这不一定是最终匹配方案,我们只能确定: $$ height(\text{rk}_i)\geq height(\text{rk}_{i-1})-1 $$ 至此,结论得证。 这是一个经典的套路,注意到 $height$ 是一个类递增的函数,可以记录之前的答案,减 $1$ 得到 $height(\text{rk}_i)$ 的最小值,然后尝试扩大即可,时间复杂度显然是 $\mathcal O(n)$ 的。 ## 子串最长公共前缀 再来看一个性质: $$ \text{LCP}(\text{Suf}(\text{sa}_i),\text{Suf}(\text{sa}_j))=\min_{k=i+1}^{j}height(k) $$ 理解应该比较简单,证明比较抽象,我就不证明了。 ## 不同子串数量 正难则反,统计有多少子串是相同的。 子串,本质就是一个后缀的前缀,也就是我们统计后缀之间的相同前缀的数量。 那么显然,统计和自己排名相邻的一项可以做到不重不漏。 所以得答案为: $$ \frac 1 2n(n+1)-\sum_{i=1}^nheight(i) $$ ```cpp #include<bits/stdc++.h> #define LL long long #define LF long double #define pLL pair<LL,LL> #define pb push_back #define fir first #define sec second using namespace std; //const LL inf; const LL N=1e5+5; //const LL M; //const LL mod; //const LF eps; //const LL P; LL T,n,rk[N],b[N],sa[N],tmp[N],m,h[N],tmp2[N]; char c[N]; int main() { scanf("%lld",&T); while(T--) { scanf("%s",c+1); n=strlen(c+1); m=200; for(int i=1;i<=m;i++)b[i]=0; for(int i=1;i<=n;i++)b[rk[i]=c[i]]++,h[i]=0; for(int i=1;i<=m;i++)b[i]+=b[i-1]; for(int i=n;i>=1;i--)sa[b[rk[i]]--]=i; for(int i=1;i<=n;i++)tmp[i]=rk[i]; LL t=0; for(int i=1;i<=n;i++) { if(tmp[sa[i]]==tmp[sa[i-1]])rk[sa[i]]=t; else rk[sa[i]]=++t; } m=t; for(LL l=1;l<n;l<<=1) { LL t=0; for(int i=n-l+1;i<=n;i++)tmp[++t]=i; for(int i=1;i<=n;i++)if(sa[i]>l)tmp[++t]=sa[i]-l; for(int i=1;i<=m;i++)b[i]=0; for(int i=1;i<=n;i++)++b[tmp2[i]=rk[tmp[i]]]; for(int i=1;i<=m;i++)b[i]+=b[i-1]; for(int i=n;i>=1;i--)sa[b[tmp2[i]]--]=tmp[i]; for(int i=1;i<=n;i++)tmp[i]=rk[i]; t=0; for(int i=1;i<=n;i++) { if(tmp[sa[i]]==tmp[sa[i-1]]&&tmp[sa[i]+l]==tmp[sa[i-1]+l])rk[sa[i]]=t; else rk[sa[i]]=++t; } m=t; if(m==n)break; } t=0; LL ans=n*(n+1)/2; for(int i=1;i<=n;i++) { if(rk[i]==1)continue; if(t)t--; while(c[i+t]==c[sa[rk[i]-1]+t])t++; h[rk[i]]=t; } for(int i=1;i<=n;i++) { ans-=h[rk[i]]; // cout<<rk[i]<<' '<<h[rk[i]]<<' '<<rk[i]-1<<' '<<sa[rk[i]-1]<<endl; } printf("%lld\n",ans); } return 0; } //RP++ ``` # 最长公共子串 考虑把两个字符串放在一起处理出后缀数组,中间放个特殊符号分割。 然后你对于所有的 $height$,如果对应的两个后缀分别在两个字符串,那么就可以更新答案。 容易证明这样是不重不漏的。 需要注意,加了特殊符号之后值域不是 $[\text{a},\text{z}]$,所以值域设大一点。 ```cpp #include <bits/stdc++.h> #define LL long long using namespace std; const LL N = 1e6 + 5; LL n, n2, sa[N], rk[N], h[N], b[N], tmp[N], m = 200; char a[N], s[N]; int main() { scanf("%s%s", a + 1, s + 1); n = strlen(a + 1), n2 = strlen(s + 1); a[n + 1] = '#'; for (int i = 1; i <= n2; i++) { a[n + i + 1] = s[i]; } n = n + n2 + 1; for (int i = 1; i <= n; i++) b[rk[i] = a[i]]++; for (int i = 1; i <= m; i++) b[i] += b[i - 1]; for (int i = n; i >= 1; i--) sa[b[rk[i]]--] = i; LL t = 0; for (int i = 1; i <= n; i++) tmp[i] = rk[i]; for (int i = 1; i <= n; i++) { if (tmp[sa[i]] == tmp[sa[i - 1]]) rk[sa[i]] = t; else rk[sa[i]] = ++t; } m = t; for (LL l = 1; l < n; l *= 2) { LL t = 0; for (int i = n - l + 1; i <= n; i++) tmp[++t] = i; for (int i = 1; i <= n; i++) if (sa[i] > l) tmp[++t] = sa[i] - l; for (int i = 1; i <= m; i++) b[i] = 0; for (int i = 1; i <= n; i++) b[rk[tmp[i]]]++; for (int i = 1; i <= m; i++) b[i] += b[i - 1]; for (int i = n; i >= 1; i--) sa[b[rk[tmp[i]]]--] = tmp[i]; for (int i = 1; i <= n; i++) tmp[i] = rk[i]; t = 0; for (int i = 1; i <= n; i++) { if (tmp[sa[i]] == tmp[sa[i - 1]] && tmp[sa[i] + l] == tmp[sa[i - 1] + l]) rk[sa[i]] = t; else rk[sa[i]] = ++t; } m = t; } t = 0; for (int i = 1; i <= n; i++) { if (rk[i] == 1) continue; if (t) t--; while (a[i + t] == a[sa[rk[i] - 1] + t]) t++; h[rk[i]] = t; } LL ans = 0; for (int i = 1; i <= n; i++) { LL x = sa[i - 1], y = sa[i]; if (x > y) swap(x, y); if (x <= n - n2 - 1 && n - n2 + 1 <= y) ans = max(ans, h[i]); } printf("%lld", ans); } ``` # 多字符串最长公共子串 其实二分完全是可以的,卡卡常跑得飞快。 首先拼接成一个字符串,中间放特殊符号是显然的。 套路是二分,然后变成判断是否存在满足条件的长度为 $k$ 的子串。 由于需要满足长度为 $k$,我们就可以以 $height$ 的值将后缀分组,如果存在 $height(i)<k$ 就断开,$i$ 往后作为一个新的组。 这样可以保证组内任意两个后缀的最长公共前缀长度至少为 $k$,所以我们考虑是否存在组满足条件即可,对于这题,我们只需要保证每个字符串都有对应的后缀在组内。 时间复杂度为 $\mathcal O(n\log n)$。 ```cpp #include<bits/stdc++.h> #define LL int using namespace std; const LL N=2e6+5; LL T,n,L[N],R[N],sa[N],rk[N],h[N],b[N],tmp[N],tmp2[N],m=200,hav[12]; char s[N],a[N]; bool pd(LL x) { for(int i=1;i<=n;i++) { if(h[i]<x) { LL flg=1; for(int j=1;j<=T;j++) { if(!hav[j]) { flg=0; break; } }for(int i=1;i<=T;i++)hav[i]=0; if(flg)return 1; } for(int j=1;j<=T;j++) { if(L[j]<=sa[i]&&sa[i]<=R[j]) { hav[j]=1; break; } } } LL flg=1; for(int j=1;j<=T;j++) { if(!hav[j]) { flg=0; break; } }for(int i=1;i<=T;i++)hav[i]=0; if(flg)return 1; return 0; } int main() { while(~scanf("%s",s+1)) { L[++T]=n+1,R[T]=n+strlen(s+1); for(int j=L[T];j<=R[T];j++)a[j]=s[j-L[T]+1]; n=R[T]+1; a[n]='#'+T-1; } for(int i=1;i<=n;i++)b[rk[i]=a[i]]++; for(int i=1;i<=m;i++)b[i]+=b[i-1]; for(int i=n;i>=1;i--)sa[b[rk[i]]--]=i; LL t=0; for(int i=1;i<=n;i++)tmp[i]=rk[i]; for(int i=1;i<=n;i++) { if(tmp[sa[i]]==tmp[sa[i-1]])rk[sa[i]]=t; else rk[sa[i]]=++t; } m=t; for(LL l=1;l<n;l<<=1) { LL t=0; for(int i=n-l+1;i<=n;i++)tmp[++t]=i; for(int i=1;i<=n;i++)if(sa[i]>l)tmp[++t]=sa[i]-l; for(int i=1;i<=m;i++)b[i]=0; for(int i=1;i<=n;i++)++b[tmp2[i]=rk[tmp[i]]]; for(int i=1;i<=m;i++)b[i]+=b[i-1]; for(int i=n;i>=1;i--)sa[b[tmp2[i]]--]=tmp[i]; for(int i=1;i<=n;i++)tmp[i]=rk[i]; t=0; for(int i=1;i<=n;i++) { if(tmp[sa[i]]==tmp[sa[i-1]]&&tmp[sa[i]+l]==tmp[sa[i-1]+l])rk[sa[i]]=t; else rk[sa[i]]=++t; } m=t; if(m==n)break; } t=0; for(int i=1;i<=n;i++) { if(rk[i]==1)continue; if(t)t--; while(a[i+t]==a[sa[rk[i]-1]+t])t++; h[rk[i]]=t; } LL l=1,r=n,ans=0; while(l<=r) { LL mid=(l+r)/2; if(pd(mid))ans=mid,l=mid+1; else r=mid-1; } printf("%d",ans); } ``` # 长度不小于 $k$ 的公共子串个数 转换成后缀数组的经典形式,这题就是求长度不小于 $k$ 的公共前缀的数量,那就是你找到最长公共前缀,然后减去 $k-1$。 根据求 $\text{LCP}$ 的式子,显然就是维护多个区间最小值,最小值要大于 $k-1$。 容易想到不断更新右端点,用单调队列维护左端点来自第一个字符串的位置的所有符合条件的最小值。 然后遇到第二个字符串的位置就统计一下,正着做一次,倒着做一次,统计总答案就行了。 ```cpp #include<bits/stdc++.h> #define LL long long #define pLL pair<LL,LL> #define fir first #define sec second using namespace std; const LL N=1e6+5; LL k,n,n2,sa[N],rk[N],h[N],b[N],tmp[N],m; char a[N],s[N]; int main() { while(1) { scanf("%lld",&k); if(k==0)return 0; memset(sa,0,sizeof(sa)); memset(h,0,sizeof(h)); memset(b,0,sizeof(b)); memset(tmp,0,sizeof(tmp)); memset(rk,0,sizeof(rk)); memset(a,0,sizeof(a)); memset(s,0,sizeof(s)); scanf("%s%s",a+1,s+1); n=strlen(a+1),n2=strlen(s+1); a[n+1]='#'; for(int i=1;i<=n2;i++) { a[n+i+1]=s[i]; } n=n+n2+1; a[n+1]='$'; m=200; for(int i=1;i<=n;i++)b[rk[i]=a[i]]++; for(int i=1;i<=m;i++)b[i]+=b[i-1]; for(int i=n;i>=1;i--)sa[b[rk[i]]--]=i; LL t=0; for(int i=1;i<=n;i++)tmp[i]=rk[i]; for(int i=1;i<=n;i++) { if(tmp[sa[i]]==tmp[sa[i-1]])rk[sa[i]]=t; else rk[sa[i]]=++t; } m=t; for(LL l=1;l<n;l*=2) { LL t=0; for(int i=n-l+1;i<=n;i++)tmp[++t]=i; for(int i=1;i<=n;i++)if(sa[i]>l)tmp[++t]=sa[i]-l; for(int i=1;i<=m;i++)b[i]=0; for(int i=1;i<=n;i++)b[rk[tmp[i]]]++; for(int i=1;i<=m;i++)b[i]+=b[i-1]; for(int i=n;i>=1;i--)sa[b[rk[tmp[i]]]--]=tmp[i]; for(int i=1;i<=n;i++)tmp[i]=rk[i]; t=0; for(int i=1;i<=n;i++) { if(tmp[sa[i]]==tmp[sa[i-1]]&&tmp[sa[i]+l]==tmp[sa[i-1]+l])rk[sa[i]]=t; else rk[sa[i]]=++t; } m=t; } t=0; for(int i=1;i<=n;i++) { if(rk[i]==1)continue; if(t)t--; while(a[i+t]==a[sa[rk[i]-1]+t])t++; h[rk[i]]=t; } LL ans=0,sum=0; deque<pLL>q; for(int i=1;i<=n;i++) { LL num=0; if(n-n2+1<=sa[i])ans+=sum; else num++; while(!q.empty()&&q.back().fir>=h[i+1]-k+1)num+=q.back().sec,sum-=q.back().fir*q.back().sec,q.pop_back(); if(num&&h[i+1]>=k) { q.push_back({h[i+1]-k+1,num}); sum+=q.back().fir*q.back().sec; } } sum=0; q.clear(); for(int i=n;i>=1;i--) { LL num=0; if(n-n2+1<=sa[i])ans+=sum; else num++; while(!q.empty()&&q.back().fir>=h[i]-k+1)num+=q.back().sec,sum-=q.back().fir*q.back().sec,q.pop_back(); if(num&&h[i]>=k) { q.push_back({h[i]-k+1,num}); sum+=q.back().fir*q.back().sec; } } printf("%lld\n",ans); } } ```