利用keras从实例掌握深度学习图像分类

0. 背景

之前写过一个 keras 进行图像分类的教程,同时也便于自己使用,进行了开源。经过一段时间的学习,虽然已不再使用 keras 和 tensorflow 作为深度学习框架进行项目开发,但是 keras 的简洁性还是值得新手选择使用的。这里完善一下教程和代码。建议:框架只是工具而已,还是多学理论和论文的好

完整代码:keras_image_classifier

完整教程:利用 keras 从实例掌握深度学习图像分类

个人博客:超杰

要求:

keras==2.2.0 tensorflow==1.8.0 (尽量保持一致,版本不同带来问题,请自行谷歌)

1. 更新

2018年12月29日 第一次更新:更新全部文档和代码

2. 声明

开源只是帮助一些需要帮助的人,如果有疑问欢迎咨询,但是代码已做过线下调试,确认无误才进行发布的,在使用过程中如果遇到一些 bug 请自行谷歌或百度,如果仍有疑问欢迎联系我,联系方式:zhuchaojie@buaa.edu.cn

另外深度学习也好机器学习也好,请学点编程吧。。。。

3. 数据格式

为了便于大家使用相同的数据进行训练,从而熟悉整个过程,然后再转而到自己的实际项目上,这里了提供公开数据集交通标志数据集 ,下载链接 traffic-sign
数据存储格式:

  • data/
    • train/
      • 00000/
      • 00001/
      • 00002/
      • ...
    • test/
      • 00000/
      • 00001/
      • 00002/

关于数据集的划分:为了验证模型的效果,我们需要再另设验证集,以便在训练过程中验证模型效果,但是最终的评测结果,需要在测试集上进行,所以我们使用 sklearn.model_selection.train_test_split() 函数对训练集进行划分,比例采用 训练集:验证集 = 7:3

4. 项目结构介绍

  • checkpoints/
  • config.py
  • model.py
  • data.py
  • main.py

4.1 checkpoints

主要用来存放训练好的模型权重,默认只保存模型权重不保存整个网络结构。

4.2 config.py

超参数文件,本文件定义了整个项目要用到的超参数,具体如下:

### define global configs  ###

class DefaultConfigs(object):
    #1. string configs
    train_data = "../data/train/"
    val_data = "../data/val/"  # if exists else use train_test_split to generate val dataset
    test_data = "../data/all/traffic-sign/test/00003/"  # for competitions
    model_name = "NASNetMobile"
    weights_path = "./checkpoints/model.h5"#save weights for predict
    
    #2. numerical configs
    lr = 0.001
    epochs = 50
    num_classes = 62
    image_size = 224
    batch_size = 16
    channels = 3
    gpu = "0"
    
config = DefaultConfigs()

4.3 model.py

模型搭建文件 。由于日常任务中常使用预训练好的模型并进行 finetune ,这里只提供使用该版本,至于自己搭建模型,不是本教程的重点所在,如果有需要的请自行谷歌。

提醒:这里为了方便进行更复杂的分类任务,提供了一个 ensemble 版本的模型搭建过程。


from keras.applications import NASNetMobile
from keras.applications import resnet50
from keras.layers import GlobalAveragePooling2D,GlobalMaxPooling2D,Dense,Flatten,Input,Concatenate,Dropout
from keras.models import Model
from keras.losses import mae, sparse_categorical_crossentropy, binary_crossentropy,categorical_crossentropy
from keras.optimizers import Adam
from config import config

def get_model():
    inputs = Input((config.image_size, config.image_size, config.channels))
    base_model = NASNetMobile(include_top=False, input_shape=(config.image_size, config.image_size, config.channels))#, weights=None
    x = base_model(inputs)
    out1 = GlobalMaxPooling2D()(x)  # GMP feature
    out2 = GlobalAveragePooling2D()(x) # GAP feature
    out3 = Flatten()(x)                # Flatten feature
    out = Concatenate(axis=-1)([out1, out2, out3])  #concate all feature
    out = Dropout(0.5)(out)
    out = Dense(config.num_classes, activation="softmax", name="classifier")(out)
    model = Model(inputs, out)
    model.compile(optimizer=Adam(0.0001), loss=categorical_crossentropy, metrics=['acc'])

    return model

4.4 data.py

自定义的data generator模块。在之前,个人喜欢将所有数据直接加载到内存中,然后再进训练,这样做的好处就是训练过程减弱了大量数据的频繁调度问题,但是如果数据量过大,内存吃不消就行不通了,而且在调试过程中也很麻烦,这里使用 python 的 yield 机制,能够避免小机无法运行的问题。

本模块共包含三个部分:get_files() augument() create_train()

  1. get_files() 详情如下:

主要功能是循环读取每个文件夹下的图片,并根据路径信息提取类别信息。

提醒:windows系统请自行修改类别获取方法,如果不会,请百度或谷歌

修改位置 labels.append(int(file.split("/")[-2]))

def get_files(root,mode):
    #for test
    if mode == "test":
        files = []
        for img in os.listdir(root):
            files.append(root + img)
        files = pd.DataFrame({"filename":files})
        return files
    elif mode != "test":
        #for train and val       
        all_data_path,labels = [],[]
        image_folders = list(map(lambda x:root+x,os.listdir(root)))
        all_images = list(chain.from_iterable(list(map(lambda x:glob(x+"/*"),image_folders))))
        print("loading train dataset")
        for file in tqdm(all_images):
            all_data_path.append(file)
            labels.append(int(file.split("/")[-2]))
        all_files = pd.DataFrame({"filename":all_data_path,"label":labels})
        return all_files
    else:
        print("check the mode please!")

  1. augument()

使用的线上数据增强方式,见开源 imgaug,此处不做重复介绍。

  1. create_train()

使用python的yield机制对数据进行加载,函数定义如下:


class data_generator:
    def __init__(self,data_lists,mode,augument=True):
        self.mode = mode
        self.augment = augument
        self.all_files = data_lists
    def create_train(self):
        images = []
        dataset_info = self.all_files.values
        #embed()
        """
        if not self.mode == "test":
            for index,row in all_files.iterrows():
                images.append((row["filename"],row["label"]))
        else:
            for index,row in all_files.iterrows():
                images.append((row["filename"]))
        """
        while 1:
            shuffle(dataset_info)
            #print(dataset_info)
            for start in range(0,len(dataset_info),config.batch_size):
                end = min(start + config.batch_size,len(dataset_info))
                batch_images = []
                X_train_batch = dataset_info[start:end]
                batch_labels = np.zeros((len(X_train_batch),config.num_classes))
                for i in range(len(X_train_batch)):
                    #print(X_train_batch[i])
                    image = cv2.imread(X_train_batch[i][0])
                    image = cv2.resize(image,(config.image_size,config.image_size),interpolation=cv2.INTER_NEAREST)

                    if self.augument:
                        image = self.augument(image)
                    batch_images.append(image / 255.)
                    if not self.mode == "test":
                        batch_labels[i][X_train_batch[i][1]] = 1
                        #print(np.array(batch_images).shape)
                yield np.array(batch_images, np.float32), batch_labels

4.5 main.py

包含功能:

  • create_callbacks(),主要实现提前停止训练权重保存学习率衰减
  • train(),训练函数 ,执行训练
  • test(),测试函数,返回每个图片的预测类别
from keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from keras.preprocessing.image import img_to_array
from model import get_model
from config import config
from data import data_generator,get_files
import os
import warnings
import numpy as np
from IPython import embed
warnings.filterwarnings("ignore")
os.environ["CUDA_VISIBLE_DEVICES"] = config.gpu

def create_callbacks():

    early_stop = EarlyStopping(
        monitor         =       "val_acc",
        mode            =       "auto",
        patience        =       30,
        verbose         =       1
    )

    checkpoint = ModelCheckpoint(
        filepath            =       config.weights_path,
        monitor             =       "val_acc",
        save_best_only      =       True,
        save_weights_only   =       True,
        mode                =       "max",
        verbose             =        1
    )

    lr_reducer = ReduceLROnPlateau(
        monitor         =       "val_acc",
        mode            =       "max",
        epsilon         =       0.01,
        factor          =       0.1,
        patience        =       5,
    )
    return [early_stop,checkpoint,lr_reducer]

def train(callbacks):
    #1. compile
    print("--> Compiling the model...")
    model = get_model()
    # load raw train data
    raw_train_data_lists = get_files(config.train_data,"train")
    #split raw train data to train and val
    train_data_lists,val_data_lists = train_test_split(raw_train_data_lists,test_size=0.3)
    # for train
    train_datagen = data_generator(train_data_lists,"train",augument=True).create_train()
    #embed()
    # val data
    val_datagen = data_generator(val_data_lists,"val",augument=True).create_train()  # if model can predict better on augumented data ,the model should be more reboust
    history = model.fit_generator(
        train_datagen,
        validation_data = val_datagen,
        epochs = config.epochs,
        verbose = 1,
        callbacks = callbacks,
        steps_per_epoch=len(train_data_lists) // config.batch_size,
        validation_steps=len(val_data_lists) // config.batch_size
    )
def test():
    test_data_lists = get_files(config.test_data,"test")
    test_datagen = data_generator(test_data_lists,"test",augument=False).create_train()
    model = get_model()
    model.load_weights(config.weights_path)
    predicted_labels = np.argmax(model.predict_generator(test_datagen,steps=len(test_data_lists) / 16),axis=-1)  
    print(predicted_labels) 
if __name__ == "__main__":
    if not os.path.exists("./checkpoints/"):
        os.mkdir("./checkpoints/")
    callbacks = create_callbacks()
    mode = "test"
    if mode == "train":
        train(callbacks)
    elif mode =="test":
        test()
    else:
        print("check mode!")

5 使用方法

  1. 下载数据集并解压 traffic-sign
  2. 修改 config.pytrain_datatest_data 路径,其中 test_data 示例:test_data = "../data/all/traffic-sign/test/00003/"
  3. 训练:python main.py
  4. 预测:修改 main.py 中 76 行的 mode = "test" 并执行 python main.py
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 214,504评论 6 496
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 91,434评论 3 389
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 160,089评论 0 349
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 57,378评论 1 288
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 66,472评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,506评论 1 292
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,519评论 3 413
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,292评论 0 270
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,738评论 1 307
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,022评论 2 329
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,194评论 1 342
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,873评论 5 338
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,536评论 3 322
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,162评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,413评论 1 268
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,075评论 2 365
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,080评论 2 352