问题
深度学习炼丹师们大都在面对某项任务是都会在github上搜索SOTA的模型实现,clone下来,尝试魔改一番以适应当前任务,评测指标达标可能就准备上线了,然后遇到下一个任务就再来一遍。这样会遇到两个问题:
- github的模型实现往往都是基于贡献者的喜好选择自微分框架,不同任务之间不能共用模块,例如优化器,学习率策略等,说的玄幻点儿就是没有技术沉淀。
- 多人协作更是不太可能,只能各自为战,由于没有统一的构建准则,有了bug也只能自己硬gang了。
介绍
tensor2tensor为以上两个问题提供了很好的解决方案。Attention Is All You Need所提出的Transformer的官方实现就是基于tensor2tensor(T2T)的。T2T将一个深度学习任务抽象成一个T2TExperiment
,其中包括Problem
、T2TModel
、hparams
。
-
Problem
主要负责预处理原始数据和输入、输出的数据格式定义。一方面根据tf.Example
协议将原始数据转写成TFRecord
,通过problem.input_fn
为模型train、eval阶段提供dataset
。另一方面,利用problem.hparams
中预设的Modality
转化模型输入与输出和计算损失值。 -
T2TModel
通过bottom
、body
、top
构建模型的核心运算,其中bottom
和top
是Modality
转化数据阶段。 -
hparams
主要设置模型的超参,包括层数、层宽、优化器、学习率策略等,common_hparams
提供了一个基本配置。 -
T2TExperiment
是对tf.estimator.Estimator
一个封装,根据模型的不同阶段(train, eval, predict)通过T2TModel.make_estimator_model_fn
中获取不同的tf.estimator.EstimatorSpec
,所以模型的训练是借助Estimator
的train
方法。create_run_config
设置了模型训练的参数,包括训练步数,分布式训练策略、早听策略、模型保存策略等。
其中T2T
是通过工厂模式管理Problem
、T2TModel
和hparams
的,自定义的模块可以借助registry
注册到相应工厂。同样的方式之后在基于pytorch的fairseq中也得到了应用。
优点
- 解耦深度学习任务,每个阶段只需要关注特定问题。例如针对一个新的任务,只需要构建相应的
Problem
,复用已有的Model
就可以直接训练了。 - 通过继承
T2T
中抽象的很多基类,例如Problem
(Text2TextProblem
、Text2ClassProblem
),Model
(Transformer
、Resnet
),来快速构建自定义的任务。 - 自动管理模型训练中的可视化监控、验证集指标、混合精度训练等,加快任务版本迭代,将有限的精力用在构建模型主干上。