pytorch实现一个CVAE对话系统

CVAE模型结构

CVAE模型结构

如上图所示,CVAE模型在seq2seq的基础上多了一个先验网络,一个识别网络。在训练时,从识别网络中采样隐变量用于解码,而测试时从先验网络采样隐变量。这里不考虑图中的dialog act和bow预测,即图中蓝色和黄色部分。
所以基本模块主要包括Embedding,Encoder,PriorNet,RecognizeNet,Decoder。

Embedding

import torch.nn as nn


class Embedding(nn.Module):
    def __init__(self, num_vocab,
                 embedding_size,
                 pad_id=0,
                 dropout=0.1):
        super(Embedding, self).__init__()
        self.embedding = nn.Embedding(num_vocab, embedding_size, padding_idx=pad_id)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):  # [batch, seq]
        return self.dropout(self.embedding(x))  # [batch, seq, embedding_size]

参数分别是词汇表大小,词嵌入维度,用于pad句子的符号在词汇表中的id和dropout的概率。主要就是封装了nn.Embedding模块。

Encoder

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class Encoder(nn.Module):
    def __init__(self, cell_type,  # rnn类型
                 input_size,  # 输入维度
                 output_size,  # 输出维度
                 num_layers,  # rnn层数
                 bidirectional=False,  # 是否双向
                 dropout=0.1):  # dropout
        super(Encoder, self).__init__()
        assert cell_type in ['GRU', 'LSTM']  # 限定rnn类型

        if bidirectional:  # 如果双向
            assert output_size % 2 == 0
            cell_size = output_size // 2  # rnn维度
        else:
            cell_size = output_size

        self.bidirectional = bidirectional
        self.cell_type = cell_type
        self.rnn_cell = getattr(nn, cell_type)(input_size=input_size,
                                               hidden_size=cell_size,
                                               num_layers=num_layers,
                                               bidirectional=bidirectional,
                                               dropout=dropout)

    def forward(self, x,  # [seq, batch, dim]
                length):  # [batch]
        x = pack_padded_sequence(x, length, enforce_sorted=False)

        # output: [seq, batch, dim*directions] 每个时间步的输出
        # final_state = [layers*directions, batch, dim] 每一层的最终状态
        output, final_state = self.rnn_cell(x)
        output = pad_packed_sequence(output)[0]

        if self.bidirectional:  # 如果是双向的,对双向进行拼接作为每层的最终状态
            if self.cell_type == 'GRU':
                final_state_forward = final_state[0::2, :, :]  # [layers, batch, dim]
                final_state_back = final_state[1::2, :, :]  # [layers, batch, dim]
                final_state = torch.cat([final_state_forward, final_state_back], 2)  # [layers, batch, dim*2]
            else:
                final_state_h, final_state_c = final_state
                final_state_h = torch.cat([final_state_h[0::2, :, :], final_state_h[1::2, :, :]], 2)
                final_state_c = torch.cat([final_state_c[0::2, :, :], final_state_c[1::2, :, :]], 2)
                final_state = (final_state_h, final_state_c)

        # output = [seq, batch, dim]
        # final_state = [layers, batch, dim]
        return output, final_state

具体参数都写在了注释中,其他值得注意的就是pack_padded_sequence和pad_packed_sequence的作用。pack_padded_sequence是将句子中的pad压缩,因为数据是按batch封装的,所有的输入都会用pad补齐到这个batch中最长句子的长度,这部分是没必要计算的,只需计算到pad之前的最后一个字符就可以了,通过pack_padded_sequence就会忽略这个的影响。传入的参数length就是这个batch中每句句子的长度,告诉每句句子需要计算多少的长度。另外,如果这个batch的所有句子都按长度进行排序(好像是逆序),enforce_sorted这个参数就可以设置为True来加快计算速度,否则就会报错,如果没排序直接设置False。pad_packed_sequence是一个反向的操作,返回值是一个包含2个值的元组,第一个就是需要的输出,第二个是句子的长度,也就是之前传进去的length又传了回来。通常取第一个值就可以了。
如果是双向的编码器,需要将正向的状态和反向的状态做一个拼接作为最终的状态输出。

先验网络

import torch.nn as nn


class PriorNet(nn.Module):
    r""" 计算先验概率p(z|x)的网络,x为解码器最后一步的输出 """
    def __init__(self, x_size,  # post编码维度
                 latent_size,  # 隐变量维度
                 dims):  # 隐藏层维度
        super(PriorNet, self).__init__()
        assert len(dims) >= 1  # 至少两层感知机

        dims = [x_size] + dims + [latent_size*2]
        dims_input = dims[:-1]
        dims_output = dims[1:]

        self.latent_size = latent_size
        self.mlp = nn.Sequential()
        for idx, (x, y) in enumerate(zip(dims_input[:-1], dims_output[:-1])):
            self.mlp.add_module(f'linear{idx}', nn.Linear(x, y))  # 线性层
            self.mlp.add_module(f'activate{idx}', nn.Tanh())  # 激活层
        self.mlp.add_module('output', nn.Linear(dims_input[-1], dims_output[-1]))

    def forward(self, x):  # [batch, x_size]
        predict = self.mlp(x)  # [batch, latent_size*2]
        mu, logvar = predict.split([self.latent_size]*2, 1)
        return mu, logvar

先验网络本质上就是一个多层感知机目的是计算先验概率p(z|x)z的均值和log方差,因为先验分布通常假设为一个高斯分布\mathcal N(\mu,\sigma^2),包含两个参数\mu\sigma^2。那为啥不直接预测方差呢,那是因为从高斯分布中采样的操作是不可微的,需要通过重参数化实现,即采样\mathcal N(\mu,\sigma^2)等于从\mathcal N(0, 1)采样\varepsilon并计算\mu+\varepsilon\sigma,所以通常预测log方差比较方便计算。

识别网络

import torch
import torch.nn as nn


class RecognizeNet(nn.Module):
    r""" 计算后验概率p(z|x,y)的网络;x,y为编码器最后一步的输出 """
    def __init__(self, x_size,  # post编码维度
                 y_size,  # response编码维度
                 latent_size,  # 隐变量维度
                 dims):  # 隐藏层维度
        super(RecognizeNet, self).__init__()
        assert len(dims) >= 1  # 至少两层感知机

        dims = [x_size+y_size] + dims + [latent_size*2]
        dims_input = dims[:-1]
        dims_output = dims[1:]

        self.latent_size = latent_size
        self.mlp = nn.Sequential()
        for idx, (x, y) in enumerate(zip(dims_input[:-1], dims_output[:-1])):
            self.mlp.add_module(f'linear{idx}', nn.Linear(x, y))  # 线性层
            self.mlp.add_module(f'activate{idx}', nn.Tanh())  # 激活层
        self.mlp.add_module('output', nn.Linear(dims_input[-1], dims_output[-1]))

    def forward(self, x,  # [batch, x_size]
                y):  # [batch, y_size]
        x = torch.cat([x, y], 1)  # [batch, x_size+y_size]
        predict = self.mlp(x)  # [batch, latent_size*2]
        mu, logvar = predict.split([self.latent_size]*2, 1)
        return mu, logvar

识别网络本质上也是一个多层感知机,只是多加了回复编码产生的后验信息,和先验网络一起作为一个模块其实都是可以的。

解码器

import torch.nn as nn


class Decoder(nn.Module):
    def __init__(self, cell_type,  # rnn类型
                 input_size,  # 输入维度
                 output_size,  # 输出维度
                 num_layer,  # rnn层数
                 dropout=0.1):  # dropout
        super(Decoder, self).__init__()
        assert cell_type in ['GRU', 'LSTM']  # 限定rnn类型

        self.cell_type = cell_type
        self.rnn_cell = getattr(nn, cell_type)(
            input_size=input_size,
            hidden_size=output_size,
            num_layers=num_layer,
            dropout=dropout)

    def forward(self, x,  # 输入 [seq, batch, dim] 或者单步输入 [1, batch, dim]
                state):  # 初始状态 [layers*directions, batch, dim]
        # output: [seq, batch, dim*directions] 每个时间步的输出
        # final_state: [layers*directions, batch, dim] 每一层的最终状态
        output, final_state = self.rnn_cell(x, state)
        return output, final_state

和编码器没什么区别。

其余模块

使用编码器最终状态和隐变量z初始化解码器初始状态

import torch.nn as nn


class PrepareState(nn.Module):
    r""" 准备解码器的初始状态,使用隐变量和编码器输入进行初始化 """
    def __init__(self, input_size,  # 用于初始化状态的向量维度
                 decoder_cell_type,  # 解码器类型
                 decoder_output_size,  # 解码器隐藏层大小
                 decoder_num_layers):  # 解码器层数
        super(PrepareState, self).__init__()
        assert decoder_cell_type in ['GRU', 'LSTM']

        self.decoder_cell_type = decoder_cell_type
        self.num_layers = decoder_num_layers
        self.linear = nn.Linear(input_size, decoder_output_size)

    def forward(self, x):  # [batch, dim]
        if self.num_layers > 1:
            states = self.linear(x).unsqueeze(0).repeat(self.num_layers, 1, 1)  # [num_layers, batch, output_size]
        else:
            states = self.linear(x).unsqueeze(0)
        if self.decoder_cell_type == 'LSTM':
            return states, states  # (h, c)
        else:
            return states

就是将编码器的最后一层的最终状态和隐变量拼接传入一个线性网络,根据解码器的状态的维度进行多次的复制。

整个模型

各个部分的初始化

    def __init__(self, config):
        super(Model, self).__init__()
        self.config = config

        # 定义嵌入层
        self.embedding = Embedding(config.num_vocab,  # 词汇表大小
                                   config.embedding_size,  # 嵌入层维度
                                   config.pad_id,  # pad_id
                                   config.dropout)

        # post编码器
        self.post_encoder = Encoder(config.post_encoder_cell_type,  # rnn类型
                                    config.embedding_size,  # 输入维度
                                    config.post_encoder_output_size,  # 输出维度
                                    config.post_encoder_num_layers,  # rnn层数
                                    config.post_encoder_bidirectional,  # 是否双向
                                    config.dropout)  # dropout概率

        # response编码器
        self.response_encoder = Encoder(config.response_encoder_cell_type,
                                        config.embedding_size,  # 输入维度
                                        config.response_encoder_output_size,  # 输出维度
                                        config.response_encoder_num_layers,  # rnn层数
                                        config.response_encoder_bidirectional,  # 是否双向
                                        config.dropout)  # dropout概率

        # 先验网络
        self.prior_net = PriorNet(config.post_encoder_output_size,  # post输入维度
                                  config.latent_size,  # 隐变量维度
                                  config.dims_prior)  # 隐藏层维度

        # 识别网络
        self.recognize_net = RecognizeNet(config.post_encoder_output_size,  # post输入维度
                                          config.response_encoder_output_size,  # response输入维度
                                          config.latent_size,  # 隐变量维度
                                          config.dims_recognize)  # 隐藏层维度

        # 初始化解码器状态
        self.prepare_state = PrepareState(config.post_encoder_output_size+config.latent_size,
                                          config.decoder_cell_type,
                                          config.decoder_output_size,
                                          config.decoder_num_layers)

        # 解码器
        self.decoder = Decoder(config.decoder_cell_type,  # rnn类型
                               config.embedding_size,  # 输入维度
                               config.decoder_output_size,  # 输出维度
                               config.decoder_num_layers,  # rnn层数
                               config.dropout)  # dropout概率

        # 输出层
        self.projector = nn.Sequential(
            nn.Linear(config.decoder_output_size, config.num_vocab),
            nn.Softmax(-1)
        )

需要注意的就是最后有个projector层,将解码器输出映射到词汇表维度,用于预测每个单词概率。

定义前向传播

    def forward(self, inputs, inference=False, max_len=60, gpu=True):
        if not inference:  # 训练
            id_posts = inputs['posts']  # [batch, seq]
            len_posts = inputs['len_posts']  # [batch]
            id_responses = inputs['responses']  # [batch, seq]
            len_responses = inputs['len_responses']  # [batch, seq]
            sampled_latents = inputs['sampled_latents']  # [batch, latent_size]
            len_decoder = id_responses.size(1) - 1

            embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]
            embed_responses = self.embedding(id_responses)  # [batch, seq, embed_size]
            # state: [layers, batch, dim]
            _, state_posts = self.post_encoder(embed_posts.transpose(0, 1), len_posts)
            _, state_responses = self.response_encoder(embed_responses.transpose(0, 1), len_responses)
            if isinstance(state_posts, tuple):
                state_posts = state_posts[0]
            if isinstance(state_responses, tuple):
                state_responses = state_responses[0]
            x = state_posts[-1, :, :]  # [batch, dim]
            y = state_responses[-1, :, :]  # [batch, dim]

            # p(z|x)
            _mu, _logvar = self.prior_net(x)  # [batch, latent]
            # p(z|x,y)
            mu, logvar = self.recognize_net(x, y)  # [batch, latent]
            # 重参数化
            z = mu + (0.5 * logvar).exp() * sampled_latents  # [batch, latent]

            # 解码器的输入为回复去掉end_id
            decoder_inputs = embed_responses[:, :-1, :].transpose(0, 1)  # [seq-1, batch, embed_size]
            decoder_inputs = decoder_inputs.split([1] * len_decoder, 0)  # 解码器每一步的输入 seq-1个[1, batch, embed_size]
            first_state = self.prepare_state(torch.cat([z, x], 1))  # [num_layer, batch, dim_out]

            outputs = []
            for idx in range(len_decoder):
                if idx == 0:
                    state = first_state  # 解码器初始状态
                decoder_input = decoder_inputs[idx]  # 当前时间步输入 [1, batch, embed_size]
                # output: [1, batch, dim_out]
                # state: [num_layer, batch, dim_out]
                output, state = self.decoder(decoder_input, state)
                outputs.append(output)

            outputs = torch.cat(outputs, 0).transpose(0, 1)  # [batch, seq-1, dim_out]
            output_vocab = self.projector(outputs)  # [batch, seq-1, num_vocab]

            return output_vocab, _mu, _logvar, mu, logvar
        else:  # 测试
            id_posts = inputs['posts']  # [batch, seq]
            len_posts = inputs['len_posts']  # [batch]
            sampled_latents = inputs['sampled_latents']  # [batch, latent_size]
            batch_size = id_posts.size(0)

            embed_posts = self.embedding(id_posts)  # [batch, seq, embed_size]
            # state = [layers, batch, dim]
            _, state_posts = self.post_encoder(embed_posts.transpose(0, 1), len_posts)
            if isinstance(state_posts, tuple):  # 如果是lstm则取h
                state_posts = state_posts[0]  # [layers, batch, dim]
            x = state_posts[-1, :, :]  # 取最后一层 [batch, dim]

            # p(z|x)
            _mu, _logvar = self.prior_net(x)  # [batch, latent]
            # 重参数化
            z = _mu + (0.5 * _logvar).exp() * sampled_latents  # [batch, latent]

            first_state = self.prepare_state(torch.cat([z, x], 1))  # [num_layer, batch, dim_out]
            done = torch.tensor([0] * batch_size).bool()
            first_input_id = (torch.ones((1, batch_size)) * self.config.start_id).long()
            if gpu:
                done = done.cuda()
                first_input_id = first_input_id.cuda()

            outputs = []
            for idx in range(max_len):
                if idx == 0:  # 第一个时间步
                    state = first_state  # 解码器初始状态
                    decoder_input = self.embedding(first_input_id)  # 解码器初始输入 [1, batch, embed_size]
                else:
                    decoder_input = self.embedding(next_input_id)  # [1, batch, embed_size]
                # output: [1, batch, dim_out]
                # state: [num_layers, batch, dim_out]
                output, state = self.decoder(decoder_input, state)
                outputs.append(output)

                vocab_prob = self.projector(output)  # [1, batch, num_vocab]
                next_input_id = torch.argmax(vocab_prob, 2)  # 选择概率最大的词作为下个时间步的输入 [1, batch]

                _done = next_input_id.squeeze(0) == self.config.end_id  # 当前时间步完成解码的 [batch]
                done = done | _done  # 所有完成解码的
                if done.sum() == batch_size:  # 如果全部解码完成则提前停止
                    break

            outputs = torch.cat(outputs, 0).transpose(0, 1)  # [batch, seq, dim_out]
            output_vocab = self.projector(outputs)  # [batch, seq, num_vocab]

            return output_vocab, _mu, _logvar, None, None

id_posts是输入的id表示,len_posts是每个输入的长度,sampled_latents 是从标准正态分布中采样的隐变量。
需要注意的有,输入和回复的编码表示x和y都是采用的lstm的短时记忆h(而不是长时记忆c)或gru的h,并且是最后一层的。这是采用的开头那张图片的论文里源码的做法。

计算损失

def compute_loss(outputs, labels, masks, global_step):
    def gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar):  # [batch, latent]
        """ 两个高斯分布之间的kl散度公式 """
        kld = 0.5 * torch.sum(prior_logvar - recog_logvar - 1
                              + recog_logvar.exp() / prior_logvar.exp()
                              + (prior_mu - recog_mu).pow(2) / prior_logvar.exp(), 1)
        return kld  # [batch]

    # output_vocab: [batch, len_decoder, num_vocab] 对每个单词的softmax概率
    output_vocab, _mu, _logvar, mu, logvar = outputs  # 先验的均值、log方差,后验的均值、log方差

    token_per_batch = masks.sum(1)  # 每个样本要计算损失的token数 [batch]
    len_decoder = masks.size(1)  # 解码长度

    output_vocab = output_vocab.reshape(-1, config.num_vocab)  # [batch*len_decoder, num_vocab]
    labels = labels.reshape(-1)  # [batch*len_decoder]
    masks = masks.reshape(-1)  # [batch*len_decoder]

    # nll_loss需要自己求log,它只是把label指定下标的损失取负并拿出来,reduction='none'代表只是拿出来,而不需要求和或者求均值
    _nll_loss = F.nll_loss(output_vocab.clamp_min(1e-12).log(), labels, reduction='none')  # 每个token的-log似然 [batch*len_decoder]
    _nll_loss = _nll_loss * masks  # 忽略掉不需要计算损失的token [batch*len_decoder]

    nll_loss = _nll_loss.reshape(-1, len_decoder).sum(1)  # 每个batch的nll损失 [batch]

    ppl = nll_loss / token_per_batch.clamp_min(1e-12)  # ppl的计算需要平均到每个有效的token上 [batch]

    # kl散度损失 [batch]
    kld_loss = gaussian_kld(mu, logvar, _mu, _logvar)

    # kl退火
    kld_weight = min(1.0 * (global_step % (2*config.kl_step)) / config.kl_step, 1)  # 周期性退火

    # 损失
    loss = nll_loss + kld_weight * kld_loss

    return loss, nll_loss, kld_loss, ppl, kld_weight

需要注意的是nll损失只要计算一个batch的nll损失,而ppl的计算是要将一个batch的nll损失平均到每个需要计算字符上的。另外不要忘了乘上mask,忽略不要计算损失的字符例如pad的损失。
github:https://github.com/Kirito0918/cvae-dialog
参考论文: Learning Discourse-level Diversity for Neural Dialog Models using Conditional Variational Autoencoders

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 219,589评论 6 508
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,615评论 3 396
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 165,933评论 0 356
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,976评论 1 295
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,999评论 6 393
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,775评论 1 307
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,474评论 3 420
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,359评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,854评论 1 317
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 38,007评论 3 338
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,146评论 1 351
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,826评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,484评论 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 32,029评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,153评论 1 272
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,420评论 3 373
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 45,107评论 2 356

推荐阅读更多精彩内容