实现 AutoEncoder 模型

最近在 kaggle 上学习些 keras 的使用方法,这里总结下 AutoEncoder 使用方式

模型定义

对于 AutoEncoder 模型定义有两种方式:

  • Encoder 和 Decoder 分开定义,然后通过 Model 进行合并
  • Encoder 和 Decoder 同一个 Model 进行定义,在 Encoder 最后一层设置特定名称,然后在取出直接使用即可

分开定义

from operator import mul   
from functools import reduce     

def product(dset):
    return reduce(mul, dset)

def encoder_model(x_shape, y_shape):
    """
    定义 Encoder 部分模型,这里将最后一层的数据维度记录下来,作为 Decoder 输入层后的接下来的一层大小
    """
    inp = Input(shape=x_shape)
    m = Conv2D(16, (3, 3), activation='relu', padding='same')(inp)
    m = MaxPooling2D((2, 2), padding='same')(m)
    m = Conv2D(8, (3, 3), activation='relu', padding='same')(m)
    m = MaxPooling2D((2, 2), padding='same')(m)
    m = Conv2D(8, (3, 3), activation='relu', padding='same')(m)
    shape = m.shape
    m = Flatten()(m)
    outp = Dense(y_shape)(m)
    return inp, outp, shape[1:]

def decoder_model(x_shape, x_shape_2d):
    """
    定义 Decoder 部分模型
    """
    inp = Input(shape=x_shape)
    m = Dense(product(x_shape_2d))(inp)
    # 数据维度的转换 1D 转为 2D
    m = Reshape(x_shape_2d)(m)
    m = Conv2D(8, (3, 3), activation='relu', padding='same')(m)
    m = UpSampling2D((2, 2))(m)
    m = Conv2D(8, (3, 3), activation='relu', padding='same')(m)
    m = UpSampling2D((2, 2))(m)
    outp = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(m)
    return inp, outp

# 定义 encoder 模型
encoder_inp, encoder_outp, shape = encoder_model((28, 28, 1), 100)
encoder = Model(inputs=encoder_inp, outputs=encoder_outp)
encoder.summary()

print(shape)

# 定义 decoder 模型
decoder_inp, decoder_outp = decoder_model(encoder_outp.shape[1:], shape)
decoder = Model(inputs=decoder_inp, outputs=decoder_outp)
decoder.summary()

# 定义 autoencoder 模型
autoencoder = Model(inputs=encoder_inp, outputs=decoder(encoder(encoder_inp)), name='autoencoder')
autoencoder.compile(loss='binary_crossentropy', optimizer='adam')
autoencoder.summary()

noise = np.random.normal(loc=0.5, scale=0.5, size=x_train.shape)
x_train_noisy = x_train + noise

noise = np.random.normal(loc=0.5, scale=0.5, size=x_cv.shape)
x_cv_noisy = x_cv + noise
autoencoder.fit(x_train_noisy, x_train, validation_data=(x_cv_noisy, x_cv),
                epochs=100, batch_size=128)

同时定义

input_img = Input(shape=(28, 28, 1))  # adapt this if using `channels_first` image data format

x = Conv2D(16, (3, 3), activation='relu', padding='same')(input_img)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
encoded = MaxPooling2D((2, 2), padding='same', name='encoder')(x)

# at this point the representation is (4, 4, 8) i.e. 128-dimensional

x = Conv2D(8, (3, 3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Conv2D(8, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
x = Conv2D(16, (3, 3), activation='relu')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)

autoencoder = Model(input_img, decoded)
autoencoder.compile(loss='binary_crossentropy', optimizer='adam')
autoencoder.summary()
autoencoder.fit(x_train, x_train, batch_size=64, epochs=100, verbose=1, validation_data=(x_cv, x_cv))

encoder = Model(inputs=autoencoder.input,
                        outputs=autoencoder.get_layer('encoder').output)

这样通过 AutoEncoder 模型的训练来训练 Encoder 模型

使用 Encoder

x_encoded = encoder.predict(x_cv_noisy)
print(x_encoded[:1])

常见问题

  1. Loss 高居不下

可以通过如下几个方面进行改善:

  • 对输入数据进行归一化,输出的数据很容易拟合
x_train = normalize(x_train).reshape(-1, 28, 28, 1).astype(float32) / 255

# 在 Decoder 输出的内容需要展示为图片时候,简单的乘以 255 即可
x_decoded = autoencoder.predict(x_cv_noisy)

imgs = np.concatenate([x_test[:num], x_cv_noisy[:num], x_decoded[:num]])
imgs = imgs.reshape((rows * 3, cols, image_size, image_size))
imgs = np.vstack(np.split(imgs, rows, axis=1))
imgs = imgs.reshape((rows * 3, -1, image_size, image_size))
imgs = np.vstack([np.hstack(i) for i in imgs])
# 将数据进行还原即可
imgs = (imgs * 255).astype(np.uint8)
  • 降低学习率,归一化之后,数据基本都是小数,学习率高了会不容易拟合
sgd = SGD(lr=1e-4, decay=1e-6, momentum=0.4, nesterov=True)
  • 修改权重初始化的方式
    m = Conv2D(16, 3, activation=LeakyReLU(alpha=0.2), padding='same', kernel_initializer='glorot_normal')(m)

·「参考」
Why my training and validation loss is not changing?
Building Autoencoders in Keras

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 215,923评论 6 498
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,154评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 161,775评论 0 351
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,960评论 1 290
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,976评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,972评论 1 295
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,893评论 3 416
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,709评论 0 271
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,159评论 1 308
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,400评论 2 331
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,552评论 1 346
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,265评论 5 341
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,876评论 3 325
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,528评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,701评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,552评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,451评论 2 352

推荐阅读更多精彩内容

  • 摘要:在深度学习之前已经有很多生成模型,但苦于生成模型难以描述难以建模,科研人员遇到了很多挑战,而深度学习的出现帮...
    肆虐的悲傷阅读 11,277评论 1 21
  • Spring Cloud为开发人员提供了快速构建分布式系统中一些常见模式的工具(例如配置管理,服务发现,断路器,智...
    卡卡罗2017阅读 134,649评论 18 139
  • 今日,再一次折服于一部印度神剧——《功夫小蝇》。为此,在坐轻轨的时候还不小心坐过了站点,虽然我又很淡定地坐回去了,...
    右月阅读 800评论 0 0
  • 有两种商业模式,第一种是将单位时间售价提高,第二种是将单位时间卖出很多份 不过一般来说,有能力将自己单位时间售价提...
    周书恒阅读 410评论 0 0
  • 姑娘,累吗? 累。 姑娘,苦吗? 苦。 既然这么累,这么苦,那你想不想要一场不为爱情的爱情?比如说和一位作家,呃,...
    梦萦春秋阅读 7,120评论 25 33