POJ 2774 (Long Long Message)
求两个字符串的最长公共子串：可以二分答案长度mid（若存在长度为L的公共子串，则所有更短的长度也必然存在公共子串，因此答案具有单调性，可以二分），把A串中所有长度为mid的子串的hash值存入hash table（用set/map也可以），再枚举B串中长度为mid的子串，判断其hash值是否存在于hash table中。hash做法的常数较大，比后缀数组、后缀自动机的解法慢，模板长度也不短（我的代码使用双hash：第一个hash值用于定位table中的下标，第二个hash值用于在发生冲突时判断是否相等；hash_table中的tim作为时间戳计数器，用于在多组数据之间重置table，避免每组数据都memset导致超时）。
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <iostream>
#include <algorithm>
using namespace std;
// Capacity of the 1-indexed string, prefix-hash and power arrays.
const int maxn = 100000 + 100;
// Bases of the two polynomial hashes.
const int ha1 = 443;
const int ha2 = 317;
// Modulus shared by both hashes; also the number of slots in hash_table.
const int mod = 1000099;
// A pair of residues modulo mod (this file uses one component per hash base).
// All operators act component-wise and expect operands already in [0, mod).
struct Point
{
    int x, y;

    Point(int a = 0, int b = 0) : x(a), y(b) {}

    // (a + b) mod mod, per component.
    friend Point operator + (const Point &a, const Point &b)
    {
        const int sx = a.x + b.x;
        const int sy = a.y + b.y;
        return Point(sx % mod, sy % mod);
    }

    // (a - b) mod mod, per component; adding mod first keeps the result
    // non-negative for operands in [0, mod).
    friend Point operator - (const Point &a, const Point &b)
    {
        const int dx = a.x - b.x + mod;
        const int dy = a.y - b.y + mod;
        return Point(dx % mod, dy % mod);
    }

    // (a * b) mod mod, per component; the product is formed in long long so it
    // cannot overflow int before the reduction.
    friend Point operator * (const Point &a, const Point &b)
    {
        const long long px = static_cast<long long>(a.x) * b.x;
        const long long py = static_cast<long long>(a.y) * b.y;
        return Point(static_cast<int>(px % mod), static_cast<int>(py % mod));
    }

    // True when both components are equal.
    friend bool operator == (const Point &a, const Point &b)
    {
        return a.x == b.x && a.y == b.y;
    }
};
struct hash_table
{
int tim;
int vis[mod],key2[mod];
hash_table()
{
tim=0;
memset(vis,0,sizeof(vis));
memset(key2,0,sizeof(key2));
}
void init()
{
tim++;
}
int get_pos(const Point &a)
{
int pos=a.x,v=a.y;
for (; vis[pos]==tim && key2[pos]!=v; pos+=11,pos%=mod);
return pos;
}
void insert(const Point &a)
{
int pos=get_pos(a);
vis[pos]=tim;
key2[pos]=a.y;
}
bool find(const Point &a)
{
int pos=get_pos(a);
return (vis[pos]==tim);
}
};
// po[i] = (ha1^i, ha2^i) modulo mod; filled once by Prepare_hash().
Point po[maxn];
// Prefix hashes of A and B (1-indexed), built by get_hash().
Point sum1[maxn];
Point sum2[maxn];
// Table shared by every check() call; reset per call via init().
hash_table HT;
// Input strings, stored starting at index 1 (index 0 unused).
char A[maxn];
char B[maxn];
// Lengths of A and B.
int N;
int M;
void Prepare_hash()
{
po[0]=Point(1,1);
po[1]=Point(ha1,ha2);
for (int i=2; i<maxn; i++)
po[i]=po[i-1]*po[1];
}
// Builds prefix hashes of the 1-indexed string s[1..len]:
// sum[0] = (0, 0) and sum[i] = sum[i-1] + s[i] * po[i] (component-wise, modulo mod),
// so the hash of s[l..r] is sum[r] - sum[l-1].
// Each character is read as unsigned char: plain char may be signed, and a byte
// >= 0x80 would otherwise yield negative residues that later become a negative
// slot index in hash_table::get_pos.
void get_hash(char *s,Point *sum,int len)
{
    sum[0]=Point(0,0);
    for (int i=1; i<=len; i++)
    {
        const int c=static_cast<unsigned char>(s[i]);
        sum[i]=sum[i-1]+po[i]*Point(c,c);
    }
}
bool check(int len)
{
int U=max(N,M);
HT.init();
for (int i=1; i<=N-len+1; i++)
{
Point p=sum1[i+len-1]-sum1[i-1];
p=p*po[U-i];
HT.insert(p);
}
for (int i=1; i<=M-len+1; i++)
{
Point p=sum2[i+len-1]-sum2[i-1];
p=p*po[U-i];
if (HT.find(p)) return true;
}
return false;
}
int main()
{
Prepare_hash();
for (; scanf("%s%s",A+1,B+1)!=EOF; )
{
N=strlen(A+1);
M=strlen(B+1);
get_hash(A,sum1,N);
get_hash(B,sum2,M);
int le=0,ri=min(N,M)+1;
while(le+1<ri)
{
int mid=(le+ri)>>1;
check(mid) ? le=mid : ri=mid;
}
printf("%d\n",le);
}
return 0;
}