循环主体(209-410行)
1.顺序读取特征,和相应的target
while(!feature_reader.Done){}
SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
typedef SequentialTableReader<KaldiObjectHolder<Matrix<BaseFloat> > > SequrentialBaseFloatMatrixReader;
template<class Holder> class SequentialTableReader (util/kaldi-table.h 中276行)
成员变量:SequentialTableReaderImplBase<Holder> *impl_;
SequentialTableReaderScriptImpl和SequentialTableReaderArchiveImpl继承SequentialTableReaderImplBase(在util/kaldi-table-inl.h)
template<class KaldiType> class KaldiObjectHolder{ typedef KaldiType T;}
2.randomizer中添加数据AddData函数
- Randomizer中所有成员变量都是Cu形式的,也就是都在显存中
- 在AddData函数中data_begin_被置为0;逐渐添加数据的过程调整data_end_
- (1)isFull函数的结束条件要求data_end_超过randomizer_size,kaldi会读入完整的句子,所以实际大小可以略微超出randomizer_size
(2)data_begin_在逐步从randomizer中读取数据后会增加,判断data_begin_!=0且data_end_>size是读取完minibatch的情况。当重新添加时data_begin_会被置为0IsFull() { return ((data_begin_ == 0) && (data_end_ > conf_.randomizer_size )); }
- 最终一轮数据添加结束的时:data_中为具体数据;data_begin_=0;data_end_为结束位置。
- 循环添加直至等于或刚超过randomizer_size
3.添加randomizer的顺序(209-310行)
- 判断是否randomizer.IsFull()
- utt=feature_reader.Key()
- num_no_tgt_mat记录无target的量;num_other_error记录无frame_weights、keep_frames的量
- 获取feature和target pair(weights默认为1.0)
Matrix<BaseFloat> mat = feature_reader.Value(); Posterior targets = targets_reader.Value(utt); weights.Resize(mat.NumRows()).Set(1.0);
- 可能会处理某些长度的mismatch
- 如果有拼帧等处理或者特征变换,利用的是nnet_transf.
nnet_transf.Feedforward(CuMatrix<BaseFloat>(mat),&feats_transf);
- 获取相应的spkid
std::vector<int32> spkid; if (utt2spk != "") { if (map_utt2spk.find(utt) != map_utt2spk.end()) { spkid.resize(feats_transf.NumRows(), map_utt2spk[utt]); } else { KALDI_WARN << utt << ", spkid is unknown"; continue; } } else { spkid.resize(feats_transf.NumRows(), 0); }
- 最后向randomizer中加入数据,准备进行混合,这样算添加完1句,num_done用于计数,每5000句会报告一次速度
KALDI_ASSERT(feats_transf.NumRows() == targets.size()); feature_randomizer.AddData(feats_transf); targets_randomizer.AddData(targets); weights_randomizer.AddData(weights); spkids_randomizer.AddData(spkid); num_done++;