keras-二分类、多分类

旨在使用keras构建出二分类和多分类模型,给出相关代码。

机器学习问题中,二分类和多分类问题是最为常见,下面使用keras在imdb和newswires数据上进行相应的实验。

[code: https://github.com/zylhub/More_Python/blob/master/keras_TOT/simple-network-imdb.py ]

imdb 二分类

  1. 文本获取
  2. 文本预处理
  3. 定义模型/神经网络
  4. 定义评价指标和优化方法
  5. 划分数据集,进行训练
  6. 测试集对模型进行测试
# coding=utf-8
# @Time   : 17-12-22 下午11:01
# @Author : knight
# @File   : simple-network-imdb.py

from keras import layers
from keras import models
from keras.datasets import imdb
import numpy as np

# 1. 加载数据
(train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)

# 2. 定义模型
model = models.Sequential()
model.add(layers.Dense(32, activation='relu', input_shape=(10000,)))
model.add(layers.Dense(16, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))

######################
# input_tensor = layers.Input(shape=(784,))
# x = layers.Dense(32, activation='relu')(input_tensor)
# output_tensor = layers.Dense(10, activation='softmax')(x)
# model = models.Model(input=input_tensor, output=output_tensor)

# 3.1  文本数据预处理
word_index = imdb.get_word_index()  # 获取词索引  id to word
reverse_word_index = dict((value, key) for key, value in word_index.items())  # word to id
decoded_review = ' '.join([reverse_word_index.get(i-3, '?') for i in train_data[0]])

# 3.2 特征向量化
def vectorize_sequences(sequences, dim=10000):
    """
    词袋模型,获取词向量
    :param sequences: [[sentence1], [sentence2], [...]]
    :param dim: 词袋大小,词特征维度
    :return:
    """
    results = np.zeros((len(sequences), dim))
    for i, sequence in enumerate(sequences):
        results[i, sequence] = 1.0
    return results

x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)

y_train = np.asarray(train_labels).astype('float32')
y_test = np.asarray(test_labels).astype('float32')

from keras import metrics, losses
model.compile(optimizer='rmsprop',
              loss=losses.binary_crossentropy,
              metrics=[metrics.binary_accuracy])  # metrics 传入list,可以使用多种评价方式

# 划分验证集
x_val = x_train[:10000]
partial_x_train = x_train[10000:]

y_val = y_train[:10000]
partial_y_train = y_train[10000:]

# 4. 训练模型
history = model.fit(partial_x_train, partial_y_train,
                      epochs=20,
                      batch_size=512,
                      validation_data=(x_val, y_val))  # 验证集

# history 获取训练过程的acc loss val_acc val_loss


history_dict = history.history
print(history_dict.keys())

# 5. Plotting the training and validation loss

import matplotlib.pyplot as plt

# 画出训练集和验证集的损失和精度变化,分析模型状态

acc = history.history['binary_accuracy']  # 训练集acc
val_acc = history.history['val_binary_accuracy']  # 验证集 acc
loss = history.history['loss']  # 训练损失
val_loss = history.history['val_loss']  # 验证损失

epochs = range(1, len(acc)+1)  # 迭代次数

plt.plot(epochs, loss, 'bo', label='Training loss')  # bo for blue dot 蓝色点
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.show()

plt.clf()  # clar figure


plt.plot(epochs, acc, 'bo', label='Training acc')  # bo for blue dot 蓝色点
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.show()
# ----------------------------------------使用model 定义网络
input_tensor = layers.Input(shape=(10000,))
x = layers.Dense(32, activation='relu')(input_tensor)
x = layers.Dense(16, activation='relu')(x)
x = layers.Dense(16, activation='relu')(x)
output_tensor = layers.Dense(1, activation='sigmoid')(x)
network = models.Model(inputs=input_tensor, outputs=output_tensor)

network.compile(optimizer='rmsprop',
              loss='binary_crossentropy',
              metrics=['accuracy'])
network.fit(x_train, y_train, epochs=4, batch_size=512)

print("test data evaluate, epochs=20", model.evaluate(x_test, y_test))
print("test data evaluate, epochs=4 ", network.evaluate(x_test, y_test))
epochs=20_acc.png
epochs=20_loss.png

根据训练集和验证集的loss和acc,可以判断出模型过拟合了,此时应该尝试早停止,比如在Epochs=4的时候(迭代次数到4之后,训练集的loss继续降低,验证集loss增加了,同时验证集的acc开始下降,模型学习的目的是对未来的数据有较好的泛化能力,因此选择验证集较好的时候作为可用的模型,会得到较好的结果 )。


newswires 多标签多分类

构建一个基本的神经网络来解决文本分类问题,先给出网络结构

Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 64)                640064    
_________________________________________________________________
dense_2 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_3 (Dense)              (None, 46)                2990      
=================================================================
Total params: 647,214
Trainable params: 647,214
Non-trainable params: 0
_________________________________________________________________
# coding=utf-8
# ---------------------------
# @Time   : 17-12-23 下午10:05
# @Author : knight
# @File   : simple-network-newswires.py       
# ---------------------------
from __future__ import absolute_import

from keras.datasets import reuters
from keras_TOT.utils_text import vectorize_sequences, to_one_hot

(train_data, train_labels), (test_data, test_labels) = reuters.load_data(num_words=10000)

# 8,982 training examples and 2,246 test examples:
print('train len :', len(train_data))
print('test len: ', len(test_data))

word_index = reuters.get_word_index()  # 获取词袋
revserse_word_index = dict((value, key) for key, value in word_index.items())
decoded_newswire = ' '.join(revserse_word_index.get(i-3, '?') for i in train_data[0])

print('decoded_newswire', decoded_newswire)

x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)

###------------------------------------###
one_hot_train_labels = to_one_hot(train_labels)
one_hot_test_labels = to_one_hot(test_labels)

# 也可以直接使用keras提供的one-hot 方法
# from keras.utils.np_utils import to_categorical
# one_hot_train_labels = to_categorical(train_labels)
# one_hot_test_labels = to_categorical(test_labels)

###------------------------------------###


from keras import models
from keras import layers

network = models.Sequential()  # 序列方式构建模型
network.add(layers.Dense(64, activation='relu', input_shape=(10000,)))
network.add(layers.Dense(64, activation='relu'))
network.add(layers.Dense(46, activation='softmax'))  # 多分类问题常用的激活函数softmax

network.compile(optimizer='rmsprop',
                loss='categorical_crossentropy',  # 交叉熵
                metrics=['accuracy'])

x_val = x_train[:1000]
partial_x_train = x_train[1000:]

y_val = one_hot_train_labels[:1000]
partial_y_train = one_hot_train_labels[1000:]

history = network.fit(partial_x_train, partial_y_train, epochs=20, validation_data=(x_val, y_val))


import matplotlib.pyplot as plt

loss = history.history['loss']
val_loss = history.history['val_loss']

epoch = range(1, len(loss)+1)

plt.plot(epoch, loss, 'bo', label='Training loss')
plt.plot(epoch, val_loss, 'b', label='Training val_loss')
plt.title("Training and validation loss")
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.show()

plt.clf()

acc = history.history['acc']
val_acc = history.history['val_acc']

plt.plot(epoch, acc, 'bo', label='Training acc')
plt.plot(epoch, val_loss, 'b', label='Training val_loss')
plt.title('Training and validation acc')
plt.xlabel('Epochs')
plt.ylabel('acc')
plt.legend()
plt.show()
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 219,490评论 6 508
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,581评论 3 395
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 165,830评论 0 356
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,957评论 1 295
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,974评论 6 393
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,754评论 1 307
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,464评论 3 420
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,357评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,847评论 1 317
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,995评论 3 338
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,137评论 1 351
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,819评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,482评论 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 32,023评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,149评论 1 272
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,409评论 3 373
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 45,086评论 2 355

推荐阅读更多精彩内容