【技术博客】对抗性域适应

域适应简介

域适应是迁移学习中最常见的问题之一,域不同但任务相同,且源域数据有标签,目标域数据没有标签或者很少数据有标签。
域适应通过将源域和目标域的特征投影到相似的特征空间,这样就可以拿源域的分类器对目标域进行分类了

下面拿二分类做说明,如下图:


域.PNG
域.PNG

图中红圈是源域,蓝圈是目标域,圆圈和叉是不同特征的数据,源域的分类器将源域的数据分为两类,即虚线所示。
此时如果拿源域的分类器在目标域上分类,从图中可以看到,效果很差。

那怎么办呢,有一种方法就是把源域和目标域的分布对齐,如图片右边所示,源域目标域的分布相似(即相似特征的数据分布在相近的位置),这样就可以直接拿源域的分类器对目标域进行分类了。

训练过程域对抗生成网络 GAN 相似
同时训练两个模型:一个用来提取目标域特征 MT,和一个用来判断特征来自源域还是目标域的域辨别器 D,MT 的训练过程是最大化 D 产生错误的过程,即MT提取的特征让 D 分辨不出来是来自源域还是目标域。

目标域特征提取器 MT 和域判别器 D 互为对手:D 学习去判别特征是来自源域还是目标域,MT 学习让自己提取的特征更接近源域提取出的特征。目标域特征提取器 MT 可以被认为是一个伪造团队,试图产生假货并在不被发现的情况下使用它,而域判别器 D 类似于警察,试图检测假币。在这个游戏中的竞争驱使两个团队改进他们的方法,直到真假难分为止。

对抗性域适应

数据的选取

为了效果好,训练简单,我选取 mnist 数据集中 0、1 的数据作为源域,2、3 的数据作为目标域。源域和目标域的数据各 10000 个。
在训练时,源域可获得数据和标签,而目标域只能获得数据,没有标签,来模拟域适应的背景。目标域的标签仅在测试精度时使用。

网络

1.源域特征提取器 MS、目标域特征提取器 MT。所谓特征提取器,实际上就是将识别 mnist 的网络去掉最后一层分类层。

        (encoder): Sequential (
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
    (1): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (2): ReLU ()
    (3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
    (4): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (5): ReLU ()
    )
    (fc1): Linear (64 * 4 * 4 -> 512)

把这个网络的输出看作是提取出的特征

2.分类器C。实际就是识别 mnist 的网络最后一层分类层,一个简单的全连接网络。

        Classifier (
    (fc2): Linear (512 -> 2)
    )

3.域识别器 D。根据特征提取器的输出来判别数据来自源域还是目标域,输出 0 代表来自源域,输出 1 代表来自目标域。

        Discriminator (
     (layer): Sequential (
    (0): Linear (512 -> 512)
    (1): Linear (512 -> 512)
    (2): Linear (512 -> 2)
    ))

训练过程

训练MS、C

首先,在源域上训练特征提取器 MS 和分类器 C


过程1.PNG
过程1.PNG

训练过程和一般训练过程相似,只不过把整个网络分成了两部分来训练、优化。

def train_MS_C(loader_ms):
    # 模型
    MS = Encoder()
    C = Classifier()
    # 优化器
    o_ms = optim.SGD(MS.parameters(), lr=0.03)
    o_c = optim.SGD(C.parameters(), lr=0.03)
    criterion = nn.CrossEntropyLoss()  # 计算损失
    for j in range(1):
        print(j)
        # 训练
        for i, (images, labels) in enumerate(loader_ms):
            o_ms.zero_grad()
            o_c.zero_grad()
            outputs_mid = MS(images)
            outputs = C(outputs_mid)

            loss = criterion(outputs, labels)
            loss.backward()

            o_ms.step()  # 优化参数
            o_c.step()

            if i % 100 == 0:
                print(i)
                print('current loss : %.5f' % loss.data.item())
    # 保存模型
    np.save(params.MS_save_dir, MS.get_w())
    np.save(params.C_save_dir, C.get_w())

训练完成后,在源域的精确度为 0.9985
如果直接拿源域的特征提取器和分类器对目标域进行分类的话,精确度只有 0.5840


acc1.PNG
acc1.PNG

固定MS和C,训练MT和D

接着,固定 MS 和 C 不变,即不改变它们的网络权重,在源域和目标域上对抗式学习目标域特征提取器 MT 和域识别器 D
1.用 MS 初始化 MT,这样开始目标域会获得一个不错的精度 0.5840,接着在这个基础上训练,更容易收敛到好的方向,并且收敛过程也快了。

MT.update_w(np.load(params.MS_save_dir, encoding='bytes', allow_pickle=True).item())
过程2.PNG
过程2.PNG
def train_MT_D(loader_ms, loader_mt):
    # 模型
    MS = Encoder()
    MT = Encoder()
    D = Discriminator()
    # 加载模型
    MS.update_w(np.load(params.MS_save_dir, encoding='bytes', allow_pickle=True).item())

    if params.first_train:
        params.first_train = False
        # 第一次训练
        # MT用MS的权重初始化
        MT.update_w(np.load(params.MS_save_dir, encoding='bytes', allow_pickle=True).item())
    else:
        MT.update_w(np.load(params.MT_save_dir, encoding='bytes', allow_pickle=True).item())
        D.update_w(np.load(params.D_save_dir, encoding='bytes', allow_pickle=True).item())

    # 优化器
    o_mt = optim.SGD(MT.parameters(), lr=0.00001)
    o_d = optim.SGD(D.parameters(), lr=0.00001)
    criterion = nn.CrossEntropyLoss()  # 计算损失
    # 训练
    for j in range(1):
        print(j)
        # 训练D 域辨别器
        data_zip = zip(loader_ms, loader_mt)
        for i, ((images_s, labels_s), (images_t, labels_t)) in enumerate(data_zip):
            ################对域辨别器D的训练
            # 提取的特征
            f_s = MS(images_s)
            f_t = MT(images_t)
            f_cat = torch.cat((f_s, f_t), 0)
            # 域辨别器辨别结果
            out_D = D(f_cat.detach())

            predicts_D = torch.max(out_D.data, 1)[1]
            if i == 0:
                print('域辨别器的辨别结果')
                print(predicts_D)

            # 构造损失对比用的标签
            len_s = len(labels_s)
            len_t = len(labels_t)

            temp1 = torch.zeros(len_s)
            temp2 = torch.ones(len_t)

            lab_D = torch.cat((temp1, temp2), 0).long()

            # 梯度置0
            o_d.zero_grad()
            # 计算loss
            loss_D = criterion(out_D, lab_D)
            # 反向传播
            loss_D.backward()
            # 优化网络
            o_d.step()
            ##############################对目标域特征提取器MT的训练
            # 提取的特征
            f_t = MT(images_t)
            # 域辨别器辨别结果
            d_t = D(f_t)
            # 构造计算损失的outputs、labels
            out_MT = d_t

            predicts_MT = torch.max(out_MT.data, 1)[1]

            lab_MT = torch.zeros(len_t).long()
            # 梯度置0
            o_mt.zero_grad()
            # 计算loss
            loss_MT = criterion(out_MT, lab_MT)
            # 反向传播
            loss_MT.backward()
            # 优化网络
            o_mt.step()

            if i % 100 == 0:
                print(i)
                print('current loss_D : %.5f' % loss_D.data.item())
                print('current loss_MT : %.5f' % loss_MT.data.item())
    # 保存模型
    np.save(params.MT_save_dir, MT.get_w())
    np.save(params.D_save_dir, D.get_w())

用MT和C在目标域上分类

最后用训练好的目标域特征提取器 MT 和分类器 C 来在目标域上分类


过程3.PNG
过程3.PNG
def test_MT_C(loader_mt):
    MT = Encoder()
    C = Classifier()
    # 加载模型
    MT.update_w(np.load(params.MT_save_dir, encoding='bytes', allow_pickle=True).item())
    C.update_w(np.load(params.C_save_dir, encoding='bytes', allow_pickle=True).item())
    correct = 0
    for images, labels in loader_mt:
        outputs_mid = MT(images)
        outputs = C(outputs_mid)
        _, predicts = torch.max(outputs.data, 1)
        correct += (predicts == labels).sum()
    total = len(loader_mt.dataset)
    print('MT+C  Accuracy: %.4f' % (1.0 * correct / total))

实验结果

拿源域的特征提取器和分类器对目标域进行分类的话,精确度只有 0.5840


acc1.PNG
acc1.PNG

下图是域辨别器 D 的结果,前半部分的输入是源域的特征,后半部分的输入是目标域的特征,现在 D 大部分都能判断正确。


捕获.PNG
捕获.PNG

训练几轮后,精确度上升了一点


acc2.PNG
acc2.PNG

D 对域的分辨能力下降了,大部分目标域的输入都判断为源域的。


捕获2.PNG
捕获2.PNG

在训练 40 轮后,精确度在 0.9 附近波动,与开始的 0.5840 相比,精确度提升了很多


acc3.PNG
acc3.PNG

D 无法分辨源域和目标域了,将所有输入都识别为源域的。


捕获3.PNG
捕获3.PNG

代码地址

https://momodel.cn/explore/5f1574360a2fac574eb9c3f6?type=app

参考

Adversarial Discriminative Domain Adaptation
https://blog.csdn.net/sinat_29381299/article/details/73504196
https://github.com/corenel/pytorch-adda

©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 213,254评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,875评论 3 387
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,682评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,896评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,015评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,152评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,208评论 3 412
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,962评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,388评论 1 304
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,700评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,867评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,551评论 4 335
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,186评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,901评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,142评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,689评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,757评论 2 351