实体命名识别详解(十九)

终于又回转到train()函数里了,剩下的就简单些。

train.py -> model.train(train, dev) -> base_model.py -> train(self, train, dev)

    def train(self, train, dev):
        """Performs training with early stopping and lr exponential decay

        Args:
            train: dataset that yields tuple of (sentences, tags)
            dev: dataset

        """
        best_score = 0
        nepoch_no_imprv = 0 # for early stopping
        self.add_summary() # tensorboard

        for epoch in range(self.config.nepochs):
            self.logger.info("Epoch {:} out of {:}".format(epoch + 1,
                        self.config.nepochs))

            score = self.run_epoch(train, dev, epoch)
            self.config.lr *= self.config.lr_decay # decay learning rate

            # early stopping and saving best parameters
            if score >= best_score:
                nepoch_no_imprv = 0
                self.save_session()
                best_score = score
                self.logger.info("- new best score!")
            else:
                nepoch_no_imprv += 1
                if nepoch_no_imprv >= self.config.nepoch_no_imprv:
                    self.logger.info("- early stopping {} epochs without "\
                            "improvement".format(nepoch_no_imprv))
                    break

这里一个for循环遍历每个epoch,在循环里将分数返回给score,接下来进行学习率衰减,这里

            self.config.lr *= self.config.lr_decay # decay learning rate

也就是每一个epoch后学习率变为原先的90%,为什么要进行学习率衰减呢?就好比进行冲刺跑,前面我们当然要加足马力,但是如果我们到后面不知道逐渐减慢速度,反倒会冲出终点线好远,同样的,到了后面,模型在最优点附近震荡,学习率太高的话,震荡幅度太大,无法进行接下来更合理的优化。
接下来一个判断函数,用于程序提前停止:如果当前score大于best_score,首先将nepoch_no_imprv置0,这个是计算已经连续多少轮没有提升了,然后呢保存模型,使用了save_session()函数。

    def save_session(self):
        """Saves session = weights"""
        if not os.path.exists(self.config.dir_model):
            os.makedirs(self.config.dir_model)
        self.saver.save(self.sess, self.config.dir_model)

以下是Config()类下设置的参数。

    # general config
    dir_output = "results/test/"
    dir_model  = dir_output + "model.weights/"
    path_log   = dir_output + "log.txt"

之后更新best_score并打印。
否则的话,如果新的score并不比best_score好,那么nepoch_no_imprv自增1,如果nepoch_no_imprv大于了config之前预设的停止值(这里是3),那么函数退出。

    nepoch_no_imprv  = 3

至此,train()函数就聊完了,train.py也告一段落,我们准备看最后的evaluate.py

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。