解码就是输入音频,利用声学模型、构建好的WFST解码网络,输出最优状态序列的过程。以Kaldi中LatticeFasterOnlineDecoder为例,解析解码代码。
示例程序:
online2-wav-nnet3-latgen-faster --do-endpointing=false --online=false --frame-subsampling-factor=3
--config=conf/online.conf --max-active=7000 --beam=15.0 --frames-per-chunk=50 --lattice-beam=6.0
--acoustic-scale=1.0 --word-symbol-table=words.txt final.mdl HCLG.fst ark:spk2utt.txt scp:test.scp ark,t:lat.debug.txt
声学模型:final.mdl Kaldi Chain model 文件解析
WFST:HCLG.fst
spk2utt.txt 内容如下:
wav10 wav10
wav9 wav9
test.scp 内容如下:
wav10 data/wav/00030/2017_03_07_16.57.22_1175.wav
wav9 data/wav/00030/2017_03_07_16.57.40_2562.wav
主要数据结构:
- Token
struct Token {
BaseFloat tot_cost; // 到该状态的累计最优cost
BaseFloat extra_cost; //token所有ForwardLinks中和最优路径的cost差的最小值,PruneActiveTokens 用到
ForwardLink *links; // 链表,表示现在时刻到下一时刻的那条跳转边
Token *next; // 指向同一时刻的下一个token
Token *backpointer; // 指向上一时刻的最佳token,相当于一个回溯指针
};
- ForwardLink
struct ForwardLink {
Token *next_tok; // 这条链接指向的token
Label ilabel; // 这下面的四个量取自解码图中的跳转/弧/边,因为每一个状态
Label olabel; // 维护一个token,那么token到token之间的连接信息和状态到状态之间的信息
BaseFloat graph_cost; // 应该保持一致,所以会有输入(tid),输出,权值(就是graph_cost)
BaseFloat acoustic_cost; // acoustic_cost就是tid对应的pdf_id的在声学模型中的后验
ForwardLink *next; // 链表结构,指向下一个
};
- TokenList
struct TokenList {
Token *toks; // 同一时刻的token链表头
bool must_prune_forward_links; // 这两个是Lattice剪枝标记,起始默认设置为true
bool must_prune_tokens;
};
- HashList
template<class I, class T> class HashList {
struct Elem {
I key; // state
T val; // Token
Elem *tail;
};
struct HashBucket {
size_t prev_bucket; // 指向下一个桶,最后一个指向-1
Elem *last_elem; // 指向挂在桶上的最后一个元素,空桶指向NULL
};
Elem *list_head_; // 链表头
size_t bucket_list_tail_; // 当前活跃桶最后一个下标
size_t hash_size_; // 当前活跃桶个数
std::vector<HashBucket> buckets_; //存储实际活跃的桶
Elem *freed_head_; // head of list of currently freed elements. [ready for allocation]
std::vector<Elem*> allocated_; // list of allocated blocks.
};
解码过程中上述数据结构对应的一些重要变量如下(来自decoder/lattice-faster-online-decoder.h)
HashList<StateId, Token*> toks_;
std::vector<TokenList> active_toks_; // 每一帧对应其中一个TokenList,等于frame+1,
std::vector<StateId> queue_; // 临时变量,用于ProcessNonemitting,保存的是下一时刻state
std::vector<BaseFloat> tmp_array_; // used in GetCutoff.
解码整体流程:
- 模型、文件加载,配置生成;
- 三层循环
for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) { //循环speaker
...
const std::vector<std::string> &uttlist = spk2utt_reader.Value();
for (size_t i = 0; i < uttlist.size(); i++) { //循环某个speaker的所有wav
SingleUtteranceNnet3Decoder decoder(decoder_opts, trans_model, decodable_info, *decode_fst, &feature_pipeline); //构造函数中调用InitDecoding()
//循环某个wav的chunk,比如说一帧一帧,online=false的时候一次加载整个wav
while (samp_offset < data.Dim()) {
decoder.AdvanceDecoding();
}
decoder.FinalizeDecoding();
decoder.GetLattice(end_of_utterance, &clat);
GetDiagnosticsAndPrintOutput(utt, word_syms, clat,&num_frames, &tot_like);
}
}
对于单个wav,最主要流程就是三个函数:
void InitDecoding();
void LatticeFasterOnlineDecoder::AdvanceDecoding(DecodableInterface *decodable, int32 max_num_frames=-1)
void FinalizeDecoding();
其中AdvanceDecoding主流程如下图,每帧数据处理流程包括:
BaseFloat ProcessEmittingWrapper(DecodableInterface *decodable);
实际调用LatticeFasterOnlineDecoder::ProcessEmitting<fst::VectorFst<Arc>>(decodable);
处理输入非空跳转(ilabel != 0),主体两层循环,外层循环现在时刻所有Token,内层循环每个现在时刻的state能够跳转的下一时刻所有state。
ProcessEmitting 函数中vector active_toks_ 加1(active_toks_.resize(active_toks_.size() + 1);),另外,NumFramesDecoded() 返回值等于active_toks_.size() - 1。void ProcessNonemittingWrapper(BaseFloat cost_cutoff);
实际调用LatticeFasterOnlineDecoder::ProcessNonemitting<fst::VectorFst<Arc>>(cost_cutoff);
处理输入空跳转(ilabel == 0),主体两层循环,外层循环下一时刻所有Token,内层循环每个下一时刻的state能够跳转到的的state。可以这样理解,下一时刻的空跳转还是现在时刻通过一帧能够到达的时刻。void PruneActiveTokens(BaseFloat delta);
lattice beam 剪枝,默认25帧一次,包括两部分:剪枝ForwardLinks(PruneForwardLinks函数),剪枝Tokens(PruneTokensForFrame函数)
- 打印统计信息
主要函数解析:
- ProcessEmitting (decoder/lattice-faster-online-decoder.cc)
template <typename FstType>
BaseFloat LatticeFasterOnlineDecoder::ProcessEmitting(
DecodableInterface *decodable) {
KALDI_ASSERT(active_toks_.size() > 0);
int32 frame = active_toks_.size() - 1;
active_toks_.resize(active_toks_.size() + 1); //每帧+1,外层调用的while循环也是
Elem *final_toks = toks_.Clear(); // 此处clear的是bucket,返回链表头,遍历可得现在时刻所有state的链表
Elem *best_elem = NULL;
BaseFloat adaptive_beam;
size_t tok_cnt;
// Beam prune 参数获取,包括cur_cutoff,adaptive_beam, best_elem。 后两者用来确定next_cutoff
// 主要是两个条件,默认是best_weight + config_.beam,同时用config_.max_active、config_.min_active 做了加强,希望state数目在[config_.min_active, config_.max_active]之间
BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem);
PossiblyResizeHash(tok_cnt); // This makes sure the hash is always big enough.
BaseFloat next_cutoff = std::numeric_limits<BaseFloat>::infinity();
BaseFloat cost_offset = 0.0; // Used to keep probabilities in a good
const FstType &fst = dynamic_cast<const FstType&>(fst_);
// 下面这个块只是为了得到next_cutoff and cost_offset.
// next_cutoff 用于下一时刻state的beam prune。等于现在时刻最优state到下一时刻对应所有state中最优的tot_cost
// cost_offset 只是为了计算方面的考虑,相当于同时减了一个最小数。
if (best_elem) {
StateId state = best_elem->key;
Token *tok = best_elem->val;
cost_offset = - tok->tot_cost;
for (fst::ArcIterator<FstType> aiter(fst, state);
!aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel != 0) { // propagate..
BaseFloat new_weight = arc.weight.Value() + cost_offset -
decodable->LogLikelihood(frame, arc.ilabel) + tok->tot_cost; // 这一步cost_offset + tok_tot_cost === 0,可以不要
if (new_weight + adaptive_beam < next_cutoff)
next_cutoff = new_weight + adaptive_beam;
}
}
}
...
// the tokens are now owned here, in final_toks, and the hash is empty.
// 'owned' is a complex thing here; the point is we need to call DeleteElem
// on each elem 'e' to let toks_ know we're done with them.
for (Elem *e = final_toks, *e_tail; e != NULL; e = e_tail) { //外层循环,遍历现在时刻state
// loop this way because we delete "e" as we go.
StateId state = e->key;
Token *tok = e->val;
if (tok->tot_cost <= cur_cutoff) { // 现在时刻beam prune,tot_cost控制在cur_cutoff阈值以内,cur_cutoff=现在时刻最优state tot_cost+beam
for (fst::ArcIterator<FstType> aiter(fst, state); // 内层循环,遍历现在时刻某个state的所有跳转
!aiter.Done();
aiter.Next()) {
const Arc &arc = aiter.Value();
if (arc.ilabel != 0) { // 输入非空跳转
BaseFloat ac_cost = cost_offset -
decodable->LogLikelihood(frame, arc.ilabel),
graph_cost = arc.weight.Value(),
cur_cost = tok->tot_cost,
tot_cost = cur_cost + ac_cost + graph_cost;
if (tot_cost > next_cutoff) continue;
// 下一时刻beam prune,下一时刻tot_cost控制在阈值next_cutoff之内。
// next_cutoff,初始值为:现在时刻最优state到下一时刻所有state中最优cost+adaptive_beam。注意不是下一时刻所有state中最优cost+adaptive_beam,后面再动态调整。
else if (tot_cost + adaptive_beam < next_cutoff)
next_cutoff = tot_cost + adaptive_beam;
//扩展下一时刻token,存取在toks_中,这一帧的ProcessNonemitting就是在toks_对应的list中循环。所以说ProcessNonemitting循环的是下一时刻的state以及下一时刻state的扩展跳转。
Token *next_tok = FindOrAddToken(arc.nextstate,frame + 1, tot_cost, tok, NULL);
// 加边。Add ForwardLink from tok to next_tok (put on head of list tok->links)
tok->links = new ForwardLink(next_tok, arc.ilabel, arc.olabel,
graph_cost, ac_cost, tok->links);
}
} // for all arcs
}
e_tail = e->tail;
toks_.Delete(e); // delete Elem
}
return next_cutoff;
}
主体流程是双层循环,也就是Viterbi解码,外层循环现在时刻所有state,内层循环每个state对应的每个跳转,确定下一时刻所有state。过程中生成state对应的Token以及ForwardLink。同时用到了Beam Prune,现在时刻和下一时刻都有应用。
ProcessNonemitting(BaseFloat cutoff) (decoder/lattice-faster-online-decoder.cc)
首先遍历前面ProcessEmitting函数生成的HashList,得到现在时刻state 队列 queue_
然后两层遍历:外层遍历queue_,内层遍历stata的空跳转;
注意一点的是:frame = static_cast<int32>(active_toks_.size()) - 2 ,这个如果不注意,理解内循环中的FindOrAddToken函数会出现偏差。FindOrAddToken
构造Token,插入到active_toks_[frame_plus_one].toks指向的Token list中,插入到HashList toks_中
inline LatticeFasterOnlineDecoder::Token *LatticeFasterOnlineDecoder::FindOrAddToken(
StateId state, int32 frame_plus_one, BaseFloat tot_cost,
Token *backpointer, bool *changed) {
// Returns the Token pointer. Sets "changed" (if non-NULL) to true
// if the token was newly created or the cost changed.
KALDI_ASSERT(frame_plus_one < active_toks_.size());
Token *&toks = active_toks_[frame_plus_one].toks; // 引用,注意后面的改变其实改变了右边的值
Elem *e_found = toks_.Find(state); //HashList中查找
if (e_found == NULL) { // no such token presently.
const BaseFloat extra_cost = 0.0;
Token *new_tok = new Token (tot_cost, extra_cost, NULL, toks, backpointer); //构造Token,头插
toks = new_tok;
num_toks_++;
toks_.Insert(state, new_tok); //toks_是一个HashList,ProcessNonemitting函数或者下一帧会用到
if (changed) *changed = true;
return new_tok;
} else {
Token *tok = e_found->val; // There is an existing Token for this state.
if (tok->tot_cost > tot_cost) { // replace old token
tok->tot_cost = tot_cost;
tok->backpointer = backpointer;
if (changed) *changed = true;
} else {
if (changed) *changed = false;
}
return tok;
}
}
- GetCutoff(Elem *list_head, size_t *tok_count, BaseFloat *adaptive_beam, Elem **best_elem)
Viterbi解码中涉及到现在时刻state数目以及下一时刻state数目,如果我们想要提高解码速度,需要对这两个数值都做缩减。实际做法是设置阈值,减少语音识别中现在时刻以及下一时刻状态数目,具体做法是:** 首先求现在时刻最优路径得分,加上beam,得到现在时刻得分阈值;然后求下一时刻最优路径得分,加上beam,得到下一时刻得分阈值**;具体步骤是:
- 对所有状态排序,最优状态放最前面,最优状态得分=best_weight
- 设置一个beam,设置阈值1=cur_cutoff,cur_cutoff=best_weight+beam,所有得分在cur_cutoff以内的,保留,反之丢弃,现在时刻的state数目减少。
- 计算到下一时刻的最优路径得分new_weight。
- 设置一个adaptive_beam, 设置阈值2=next_cutoff,next_cutoff=new_weight+adaptive_beam,所有得分在next_cutoff以内的,保留,反之丢弃,下一时刻的state数目减少。
注意上述步骤中的beam不是参数传递进去的config_.beam;因为我们如果直接用config_.beam,有可能卡出的state数目太多(大于config_.max_active)或者太少(少于config_.min_active)。所以需要分类讨论,确定最终的beam值,adaptive_beam类似。
cur_cutoff,adaptive_beam 都是来自GetCutoff函数:
// BaseFloat cur_cutoff = GetCutoff(final_toks, &tok_cnt, &adaptive_beam, &best_elem);
// 输入final_toks,HashList对应的list,toks_.Clear() 操作后的得头结点指向
// 输出 cur_cutoff,返回值,用于现在时刻Beam Prune
// 输出 adaptive_beam, best_elem 得到next_cutoff,用于下一时刻Beam Prune
// 输出 tok_cnt 用于重置HashList toks_大小 ,足够大,减少内存分配时间
PossiblyResizeHash(tok_cnt)
BaseFloat LatticeFasterOnlineDecoder::GetCutoff(Elem *list_head, size_t *tok_count,
BaseFloat *adaptive_beam, Elem **best_elem)
- PruneActiveTokens
从后向前,主要做两步操作:
PruneForwardLinks,删减Token的ForwordLinks,
PruneTokensForFrame,删减Token本身,如果该Token对应的所有的ForwardLinks 都没有了,那Token本身也可以删除,判断条件tok->extra_cost == std::numeric_limits<BaseFloat>::infinity(),extra_cost代表该tok所有ForwardLinks到的next state 的tot_cost和到达该next state最优路径的tot_cost差的最小值,如果是无穷大(最小值都是无穷大)代表所有ForwordLinks都删除了。
Reference
http://www.funcwj.cn/2017/08/02/kaldi-online-decoder/
https://blog.csdn.net/u013677156/article/details/78930532