trie树+kmp。。fail指针其实就是相当于kmp那个未优化的next数组,考虑到fail是有方向的,方向可以理解成当前这个(到这个节点为止)的后缀是之前一个短串的后缀,是比root到当前这个短的,所以在统计的时候,才能保证fail往前走就可以。
统计时候,fail数组一直往前,一路相加(这里设置了last数组,就是fail数组中节点是模式串结尾的点)
插入:
void Insert(int v){
int u = 0,len = strlen(ss[v]);
for(int i = 0 ; i < len ; i++){
int tmp = ss[v][i]-'A';
if(!ch[u][tmp]){
memset(ch[node],0,sizeof(ch[node]));
val[node] = 0;
ch[u][tmp] = node++;
}
u = ch[u][tmp];
}
val[u] = v;
}
fail:
fail指针指向的问题和KMP算法中构造next数组的方式如出一辙。具体方法如下
1)将根结点的所有孩子结点的fail指向根结点,然后将根结点的所有孩子结点依次入列。
2)若队列不为空:
2.1)出列,我们将出列的结点记为curr, failTo表示curr的fail指向的结点,即failTo = curr.fail
2.2) a.判断curr.child[i] == failTo.child[i]是否成立,
成立:curr.child[i].fail = failTo.child[i],
不成立:判断 failTo == null是否成立
成立: curr.child[i].fail == root
不成立:执行failTo = failTo.fail,继续执行2.2)
b.curr.child[i]入列,再次执行再次执行步骤2)
若队列为空:结束
//ch【】【】:第二维ascii码范围,开260.
//last数组,就是fail数组中节点是模式串结尾的点)
void getfail(){
queue<int> q;
fail[0] = 0;
for(int i = 0 ; i < ascs ; i++){
if(ch[0][i]){
fail[ch[0][i]] = 0;
q.push(ch[0][i]);
last[ch[0][i]] = 0;
}
}
while(!q.empty()){
int tmp = q.front();
q.pop();
for(int i = 0 ; i < ascs ; i++){
int u = ch[tmp][i];
if(u != 0){
q.push(u);
int v = fail[tmp];
while(v && ch[v][i]== 0) v = fail[v];
fail[u] = ch[v][i];
last[u] = val[fail[u]]?fail[u]:last[fail[u]];
}
}
}
}
查询:
void Find(){
int len = strlen(s),j = 0;
for(int i = 0 ; i < len ; i++){
// if(s[i] <'A' ||s[i] > 'Z') continue;
int tmp = s[i]-'A';
while(j && ch[j][tmp] == 0) j = fail[j];
j = ch[j][tmp];
if(val[j]) cal(j);
else if(last[j]) cal(last[j]);
}
}
Problem Description
小t非常感谢大家帮忙解决了他的上一个问题。然而病毒侵袭持续中。在小t的不懈努力下,他发现了网路中的“万恶之源”。这是一个庞大的病毒网站,他有着好多好多的病毒,但是这个网站包含的病毒很奇怪,这些病毒的特征码很短,而且只包含“英文大写字符”。当然小t好想好想为民除害,但是小t从来不打没有准备的战争。知己知彼,百战不殆,小t首先要做的是知道这个病毒网站特征:包含多少不同的病毒,每种病毒出现了多少次。大家能再帮帮他吗?
Input
第一行,一个整数N(1<=N<=1000),表示病毒特征码的个数。
接下来N行,每行表示一个病毒特征码,特征码字符串长度在1—50之间,并且只包含“英文大写字符”。任意两个病毒特征码,不会完全相同。
在这之后一行,表示“万恶之源”网站源码,源码字符串长度在2000000之内。字符串中字符都是ASCII码可见字符(不包括回车)。
Output
按以下格式每行一个,输出每个病毒出现次数。未出现的病毒不需要输出。
病毒特征码: 出现次数
冒号后有一个空格,按病毒特征码的输入顺序进行输出。
Sample Input
3
AA
BB
CC
ooxxCC%dAAAoen….END
#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>
using namespace std;
int n,node;
const int maxn = 50*1000+10;
const int ascs = 200;
const int mm = 2e6+10;
char s[mm],ss[1010][51];
int ch[maxn][ascs],fail[maxn],val[maxn],last[maxn],cnt[mm];
void Insert(int v){
int u = 0,len = strlen(ss[v]);
for(int i = 0 ; i < len ; i++){
int tmp = ss[v][i]-'A';
if(!ch[u][tmp]){
memset(ch[node],0,sizeof(ch[node]));
val[node] = 0;
ch[u][tmp] = node++;
}
u = ch[u][tmp];
}
val[u] = v;
}
void getfail(){
queue<int> q;
fail[0] = 0;
for(int i = 0 ; i < ascs ; i++){
if(ch[0][i]){
fail[ch[0][i]] = 0;
q.push(ch[0][i]);
last[ch[0][i]] = 0;
}
}
while(!q.empty()){
int tmp = q.front();
q.pop();
for(int i = 0 ; i < ascs ; i++){
int u = ch[tmp][i];
if(u != 0){
q.push(u);
int v = fail[tmp];
while(v && ch[v][i]== 0) v = fail[v];
fail[u] = ch[v][i];
last[u] = val[fail[u]]?fail[u]:last[fail[u]];
}
}
}
}
void cal(int j){
if(j){
cnt[val[j]]++;
cal(last[j]);
}
}
void Find(){
int len = strlen(s),j = 0;
for(int i = 0 ; i < len ; i++){
// if(s[i] <'A' ||s[i] > 'Z') continue;
int tmp = s[i]-'A';
while(j && ch[j][tmp] == 0) j = fail[j];
j = ch[j][tmp];
if(val[j]) cal(j);
else if(last[j]) cal(last[j]);
}
}
void init(){
node = 1;
memset(ch[0],0,sizeof(ch[0]));
memset(cnt, 0, sizeof(cnt[0])*(n+2));
for(int i = 1 ;i <= n ; i++){
scanf("%s",ss[i]);
//cout << ss[i]<<endl;
Insert(i);
}
getfail();
}
void sov(){
scanf("%s",s);
Find();
for(int i = 1; i <= n ; i ++){
if(cnt[i] == 0) continue;
printf("%s: %d\n",ss[i],cnt[i]);
}
}
int main(){
while(~scanf("%d",&n)){
init();
sov();
}
return 0;
}