题目:http://www.lydsy.com/JudgeOnline/problem.php?id=3238
这是我ORZ了网上的题解才知道的555:
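首先做一次后缀数组,求出sa[],height[],然后对于height[2..len(s)]建立Cartesian Tree(小根笛卡尔树),那么每个节点v对于题目中项lcp()之和的贡献为(size[left[v]]+1)*(size[right[v]]+1)*height[v](很好证明:两个后缀的lcp等于它们排名之间height的最小值,而以height[v]为最小值的排名区间,左端点有size[left[v]]+1种取法,右端点有size[right[v]]+1种取法),最后用所有后缀对的长度之和减去2倍的lcp之和,得出答案。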
代码:
速度依旧被虐成渣。。。
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;
// Capacity of every global array in this file.
#define MAXN 500010
// Shorthand for the 64-bit integer type used for sizes, keys and the answer.
#define ll long long
// Accessors for the Cartesian-tree arrays: father, left child, right child,
// key and subtree size of node t.
#define F(t) father[t]
#define L(t) left[t]
#define R(t) right[t]
#define K(t) key[t]
#define S(t) size[t]
// Zero-fills t; relies on sizeof(t), so t must be a real array, not a pointer.
#define clear(t) memset(t,0,sizeof(t))
// Reads the next run of lowercase letters ('a'..'z') from stdin into s.
//
// The letters are stored 1-indexed: the first one goes to s[1] and s[0] is
// left untouched. No terminator is written. Any characters before the run are
// skipped, and the single character that ends the run is consumed as well.
//
// The caller must supply a buffer with room for (run length + 1) chars; this
// function has no way to check the capacity.
//
// Returns the number of letters stored, or 0 if end of input is reached
// before any lowercase letter is seen.
int getstring(char *s) {
    int ch, len = 0;
    // Skip ahead to the first lowercase letter. EOF must be tested explicitly:
    // it is never inside 'a'..'z', so without the check this loop never ends.
    do {
        ch = getchar();
        if (ch == EOF) return 0;
    } while (!(ch >= 'a' && ch <= 'z'));
    // Store letters until the run ends (any other character or EOF).
    do {
        s[++len] = ch;
        ch = getchar();
    } while (ch >= 'a' && ch <= 'z');
    return len;
}
// Input text. main() fills s[1..n] through getstring(); make_sa() overwrites
// s[0] with '$'.
char s[MAXN];
// n      - length of the text (the value returned by getstring()).
// sa     - sa[k] is the 1-based start of the k-th smallest suffix; sa[0] is
//          set to 0 by make_sa().
// rank   - inverse of sa once make_sa() finishes (rank[sa[k]] == k).
// height - height[k] is the length of the common prefix of the suffixes
//          starting at sa[k] and sa[k-1].
// w      - counting-sort buckets used inside make_sa().
// r      - positions ordered by the secondary key during one doubling round.
// x, y   - primary and secondary sort keys of each position for one round.
int n,height[MAXN],sa[MAXN],rank[MAXN],w[MAXN],r[MAXN],x[MAXN],y[MAXN];
// Builds sa[], rank[] and height[] for the text s[1..n] (all 1-indexed).
//
// Suffixes are ranked by prefix doubling: each round sorts positions by the
// pair (rank of position, rank of position + gap) with two counting-sort
// passes, then re-numbers the equivalence classes. The rounds stop once every
// suffix has its own class. height[] is then filled by walking the text left
// to right and reusing the previous match length minus one as a starting
// point.
void make_sa() {
    int classes;       // number of distinct ranks after the current round
    int gap = 1;       // offset of the secondary key
    int maxKey = n;    // largest value any sort key can take

    // Position 0 acts as a sentinel that compares unequal to every letter.
    sa[0] = 0;
    s[0] = '$';

    // Round zero: a suffix is ranked by its first character.
    for (int i = 1; i <= n; i++) {
        rank[i] = s[i];
        maxKey = max(maxKey, rank[i]);
    }

    do {
        // Keys for this round; a secondary key past the end of the text is 0.
        for (int i = 1; i <= n; i++) {
            x[i] = rank[i];
            if (i + gap <= n) {
                y[i] = rank[i + gap];
            } else {
                y[i] = 0;
            }
        }
        gap <<= 1;

        // Pass 1: order positions by the secondary key into r[1..n].
        for (int c = 0; c <= maxKey; c++) {
            w[c] = 0;
        }
        for (int i = 1; i <= n; i++) {
            w[y[i]]++;
        }
        for (int c = 1; c <= maxKey; c++) {
            w[c] += w[c - 1];
        }
        for (int i = 1; i <= n; i++) {
            r[w[y[i]]--] = i;
        }

        // Pass 2: stable sort of r[] by the primary key into sa[1..n]. Going
        // through r[] backwards keeps the secondary-key order inside a bucket.
        for (int c = 0; c <= maxKey; c++) {
            w[c] = 0;
        }
        for (int i = 1; i <= n; i++) {
            w[x[r[i]]]++;
        }
        for (int c = 1; c <= maxKey; c++) {
            w[c] += w[c - 1];
        }
        for (int i = n; i >= 1; i--) {
            sa[w[x[r[i]]]--] = r[i];
        }

        // Re-rank: a new class starts whenever the key pair changes.
        classes = 0;
        for (int i = 1; i <= n; i++) {
            if (i == 1 || x[sa[i]] != x[sa[i - 1]] || y[sa[i]] != y[sa[i - 1]]) {
                classes++;
            }
            rank[sa[i]] = classes;
        }
    } while (classes < n);

    // height[]: compare each suffix with its predecessor in sorted order,
    // starting from the previous match length minus one.
    int carry = 0;
    for (int i = 1; i <= n; i++) {
        const int pos = rank[i];
        const int prev = sa[pos - 1];
        int h = carry;
        while (i + h <= n && prev + h <= n && s[i + h] == s[prev + h]) {
            h++;
        }
        height[pos] = h;
        carry = (h > 0) ? h - 1 : 0;
    }
}
// Cartesian-tree links, indexed by node number; 0 means "no node". roof is the
// index of the current root (0 while the tree is empty).
// Note: Solve() sets father[v] only when v is attached as a right child; it is
// not refreshed when a node is later re-hung as somebody's left child.
int father[MAXN],left[MAXN],right[MAXN],roof=0;
// size[v] - number of nodes in the subtree of v, filled by dfs().
// key[v]  - value the tree is ordered by; Solve() copies height[v] into it.
ll size[MAXN],key[MAXN];
// Fills S(v) with the number of nodes in the subtree of v, for every node
// reachable from t through L()/R() links (a link value of 0 means no child).
//
// Implemented without recursion: the tree can be a single chain with as many
// nodes as the text has characters, and one stack frame per node is enough to
// overflow the call stack at that depth.
void dfs(int t) {
    // Collect nodes so that every parent appears before its children.
    std::vector<int> order;
    order.push_back(t);
    for (int head = 0; head < (int)order.size(); head++) {
        const int v = order[head];
        S(v) = 1;
        if (L(v)) order.push_back(L(v));
        if (R(v)) order.push_back(R(v));
    }
    // Walk the list backwards: both children are final before their parent
    // adds them in.
    for (int idx = (int)order.size() - 1; idx >= 0; idx--) {
        const int v = order[idx];
        if (L(v)) S(v) += S(L(v));
        if (R(v)) S(v) += S(R(v));
    }
}
// Computes the sum over all pairs of suffixes (a < b) of
//     len(a) + len(b) - 2 * lcp(a, b).
//
// The length part is (n - 1) * n * (n + 1) / 2, because each suffix appears in
// n - 1 pairs. The lcp part comes from a Cartesian tree (minimum on top) over
// height[2..n]: node i is the minimum of (S(L(i)) + 1) * (S(R(i)) + 1) rank
// intervals, and each of them contributes K(i).
//
// Expects make_sa() to have run, and the tree arrays and roof to still hold
// their initial zeros.
ll Solve() {
    ll total = (ll)n * (n + 1) / 2;   // sum of all suffix lengths
    total *= (n - 1);

    // Node 0 is the "no node" sentinel; its fields must read as zero.
    S(0) = 0;
    K(0) = 0;
    R(0) = 0;
    L(0) = 0;
    F(0) = 0;

    int last = 0;   // most recently inserted node, i.e. bottom of the right spine
    for (int i = 2; i <= n; i++) {
        K(i) = height[i];
        if (roof == 0) {
            roof = i;
            last = i;
            continue;
        }
        // Climb the right spine until a key not larger than K(i) is found.
        int p = last;
        while (p != 0 && K(p) > K(i)) {
            p = F(p);
        }
        if (p != 0) {
            // i takes over p's right subtree and becomes p's right child.
            L(i) = R(p);
            R(p) = i;
            F(i) = p;
        } else {
            // The climb only reaches 0 after passing the root, whose key was
            // larger than K(i), so i becomes the new root.
            L(i) = roof;
            roof = i;
        }
        last = i;
    }

    dfs(roof);

    ll lcpSum = 0;
    for (int i = 2; i <= n; i++) {
        lcpSum += (S(L(i)) + 1) * (S(R(i)) + 1) * K(i);
    }
    return total - 2 * lcpSum;
}
int main() {
n=getstring(s);
make_sa();
printf("%lld\n",Solve());
return 0;
}