编码器和解码器

(一)编码器和解码器

这是最近几年比较新的概念。

首先我们重新考察一下CNN,我们之前说,我们经过很多神经网络层,对原始图片进行特征抽取压缩,最终进行分类。那么我们换一种角度进行思考,也就是说机器通过卷积层对特征提取,变成适合机器理解的变量,整个过程我们可以将其抽象成对图片进行编码的过程。然而解码过程就是再从高维特征转变成人类能理解的含义。

  • 编码器:将输入变成中间表达式(特征)
  • 解码器:将中间表示解码成输出

那我们再看看RNN是怎么样的。

  • 编码器:将文本表示成向量
  • 解码器:向量表示成输出
(1)编码器-解码器架构

也就是说,一个模型可以被分为两块:

  • 编码器处理输入
  • 解码器生成输出

后面的网络架构都会采取这样的形式来定义。

(二)代码实现

我们来看一下重新定义这样的架构之后,神经网络怎么实现。下面是伪代码:

from torch import nn


# 编码器
class Encoder(nn.Module):
    """编码器-解码器架构的基本编码器接口"""
    def __init__(self, **kwargs):
        super(Encoder, self).__init__(**kwargs)

    def forward(self, X, *args):
        raise NotImplementedError
# 解码器
class Decoder(nn.Module):
    """编码器-解码器架构的基本解码器接口"""
    def __init__(self, **kwargs):
        super(Decoder, self).__init__(**kwargs)

    # 初始化状态
    def init_state(self, enc_outputs, *args):
        # 这里的输入是编码器的输出,这里是对上面的输入要做一些处理。
        raise NotImplementedError

    def forward(self, X, state):
        # 解码器也会有自己的输入,他的输出是编码器输出和解码器输入共同的结果。
        raise NotImplementedError
# 编码器解码器架构
class EncoderDecoder(nn.Module):
    """编码器-解码器架构的基类"""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        enc_outputs = self.encoder(enc_X, *args)
        dec_state = self.decoder.init_state(enc_outputs, *args) # 初始化状态
        return self.decoder(dec_X, dec_state)
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

相关阅读更多精彩内容

友情链接更多精彩内容