tensor2tensor 1.10 框架简介

问题

深度学习炼丹师们大都在面对某项任务是都会在github上搜索SOTA的模型实现,clone下来,尝试魔改一番以适应当前任务,评测指标达标可能就准备上线了,然后遇到下一个任务就再来一遍。这样会遇到两个问题:

  • github的模型实现往往都是基于贡献者的喜好选择自微分框架,不同任务之间不能共用模块,例如优化器,学习率策略等,说的玄幻点儿就是没有技术沉淀。
  • 多人协作更是不太可能,只能各自为战,由于没有统一的构建准则,有了bug也只能自己硬gang了。

介绍

tensor2tensor为以上两个问题提供了很好的解决方案。Attention Is All You Need所提出的Transformer的官方实现就是基于tensor2tensor(T2T)的。T2T将一个深度学习任务抽象成一个T2TExperiment,其中包括ProblemT2TModelhparams

  • Problem主要负责预处理原始数据和输入、输出的数据格式定义。一方面根据tf.Example协议将原始数据转写成TFRecord,通过problem.input_fn为模型train、eval阶段提供dataset。另一方面,利用problem.hparams中预设的Modality转化模型输入与输出和计算损失值。
  • T2TModel通过bottombodytop构建模型的核心运算,其中bottomtopModality转化数据阶段。
  • hparams主要设置模型的超参,包括层数、层宽、优化器、学习率策略等,common_hparams提供了一个基本配置。
  • T2TExperiment是对tf.estimator.Estimator一个封装,根据模型的不同阶段(train, eval, predict)通过T2TModel.make_estimator_model_fn中获取不同的tf.estimator.EstimatorSpec,所以模型的训练是借助Estimatortrain方法。create_run_config设置了模型训练的参数,包括训练步数,分布式训练策略、早听策略、模型保存策略等。

其中T2T是通过工厂模式管理ProblemT2TModelhparams的,自定义的模块可以借助registry注册到相应工厂。同样的方式之后在基于pytorch的fairseq中也得到了应用。

优点

  • 解耦深度学习任务,每个阶段只需要关注特定问题。例如针对一个新的任务,只需要构建相应的Problem,复用已有的Model就可以直接训练了。
  • 通过继承T2T中抽象的很多基类,例如Problem(Text2TextProblemText2ClassProblem),Model(TransformerResnet),来快速构建自定义的任务。
  • 自动管理模型训练中的可视化监控、验证集指标、混合精度训练等,加快任务版本迭代,将有限的精力用在构建模型主干上。
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。