Training a U-Net for Image Segmentation: A Worked Example

This walkthrough covers model construction, data preprocessing, and model training.
import keras
from keras import layers
import numpy as np
import cv2
import os

batch_size = 2
class_nums = 2


def U_netModel(num_classes, input_shape=(512, 512, 1)):
    inputs = layers.Input(shape=input_shape)
    conv1_1 = layers.Conv2D(filters=64, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                            activation="relu")(inputs)
    conv1_2 = layers.Conv2D(filters=64, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                            activation="relu")(conv1_1)
    pool1 = layers.MaxPooling2D(pool_size=(2, 2))(conv1_2)

    conv2_1 = layers.Conv2D(filters=128, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                            activation="relu")(pool1)
    conv2_2 = layers.Conv2D(filters=128, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                            activation="relu")(conv2_1)
    pool2 = layers.MaxPooling2D(pool_size=(2, 2))(conv2_2)

    conv3_1 = layers.Conv2D(filters=256, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                            activation="relu")(pool2)
    conv3_2 = layers.Conv2D(filters=256, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                            activation="relu")(conv3_1)
    pool3 = layers.MaxPooling2D(pool_size=(2, 2))(conv3_2)

    conv4_1 = layers.Conv2D(filters=512, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                            activation="relu")(pool3)
    conv4_2 = layers.Conv2D(filters=512, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                            activation="relu")(conv4_1)
    pool4 = layers.MaxPooling2D(pool_size=(2, 2))(conv4_2)

    conv5_1 = layers.Conv2D(filters=1024, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                            activation="relu")(pool4)
    conv5_2 = layers.Conv2D(filters=1024, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                            activation="relu")(conv5_1)

    deconv6_up = layers.Conv2D(filters=512, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                               activation="relu")(layers.UpSampling2D((2, 2))(conv5_2))
    merge6 = layers.concatenate([conv4_2, deconv6_up])
    deconv6_1 = layers.Conv2D(filters=512, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                              activation="relu")(merge6)
    deconv6_2 = layers.Conv2D(filters=512, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                              activation="relu")(deconv6_1)

    deconv7_up = layers.Conv2D(filters=256, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                               activation="relu")(layers.UpSampling2D((2, 2))(deconv6_2))
    merge7 = layers.concatenate([conv3_2, deconv7_up])
    deconv7_1 = layers.Conv2D(filters=256, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                              activation="relu")(merge7)
    deconv7_2 = layers.Conv2D(filters=256, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                              activation="relu")(deconv7_1)

    deconv8_up = layers.Conv2D(filters=128, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                               activation="relu")(layers.UpSampling2D((2, 2))(deconv7_2))
    merge8 = layers.concatenate([conv2_2, deconv8_up])
    deconv8_1 = layers.Conv2D(filters=128, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                              activation="relu")(merge8)
    deconv8_2 = layers.Conv2D(filters=128, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                              activation="relu")(deconv8_1)

    deconv9_up = layers.Conv2D(filters=64, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                               activation="relu")(layers.UpSampling2D((2, 2))(deconv8_2))
    merge9 = layers.concatenate([conv1_2, deconv9_up])
    deconv9_1 = layers.Conv2D(filters=64, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                              activation="relu")(merge9)
    deconv9_2 = layers.Conv2D(filters=64, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                              activation="relu")(deconv9_1)
    ########### num_classes is set according to the number of segmentation classes.
    ########### For binary segmentation a sigmoid output activation is used here, with one-hot encoded labels.
    outputs = layers.Conv2D(filters=num_classes, kernel_size=(3, 3), padding="same", kernel_initializer="he_normal",
                            activation="sigmoid")(deconv9_2)

    model = keras.models.Model(inputs=inputs, outputs=outputs)

    return model


########### Read the image/label file names from the index file.
def read_file_names():
    dataSetNames = []
    with open("dataset2/train.txt") as f:
        for line in f:
            line = line.strip()
            if line:
                dataSetNames.append(line)
    return dataSetNames
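# The expected format of dataset2/train.txt is one "image;mask" pair per line,
# where the image lives under dataset2/jpg and the mask under dataset2/png.
# For example (the file names here are placeholders, not from the original data set):
#     0001.jpg;0001.png
#     0002.jpg;0002.png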


############ Split the data set into training and validation subsets.
############ `ratio` is the fraction of samples reserved for validation.
def split_data_set(dataSet, ratio):
    assert type(dataSet) == list
    total_nums = len(dataSet)
    nums_train = int(total_nums * (1 - ratio))
    train_dataSet = dataSet[:nums_train]
    validation_dataSet = dataSet[nums_train:]
    return train_dataSet, validation_dataSet

########### Data generator: yields (image, one-hot mask) batches indefinitely.
def generate_data_from_file(fileNames, batch_size, height, width, class_nums):
    total = len(fileNames)
    i = 0
    while True:
        train_X = []
        train_Y = []
        for _ in range(batch_size):
            if i == 0:
                # Shuffle in place at the start of every pass over the data set.
                np.random.shuffle(fileNames)
            name = fileNames[i]

            # Input image: grayscale, resized to the network input size, scaled to [0, 1].
            x_name = name.split(";")[0]
            x_img = cv2.imread("./dataset2/jpg/" + x_name, 0)
            x_img = cv2.resize(x_img, (width, height)).reshape(height, width, 1)
            x_img = x_img / 255.
            train_X.append(x_img)

            # Label image: each pixel value encodes a class index; convert to one-hot channels.
            y_name = name.split(";")[1].strip()
            y_img = cv2.imread("./dataset2/png/" + y_name, 1)
            y_img = cv2.resize(y_img, (width, height), interpolation=cv2.INTER_NEAREST)
            y_img = cv2.cvtColor(y_img, cv2.COLOR_BGR2RGB)
            seg_labels = np.zeros((height, width, class_nums))
            for c in range(class_nums):
                seg_labels[:, :, c] = (y_img[:, :, 0] == c).astype(int)
            train_Y.append(seg_labels)
            i = (i + 1) % total
        yield np.array(train_X), np.array(train_Y)

model = U_netModel(class_nums, input_shape=(512, 512, 1))
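# Sanity check (added): print the layer summary; with a 512x512x1 input and two
# classes, the final output shape should be (None, 512, 512, 2).
model.summary()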

########## Compile the model.
model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=1e-4),
    loss="binary_crossentropy",
    metrics=["accuracy"]
)

######### Get the data set file names.
dataSetNames = read_file_names()
########### Split into training and validation sets (20% reserved for validation).
train_dataSetNames, validation_dataSetNames = split_data_set(dataSetNames, 0.2)
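# Quick shape check (an added sketch, not part of the original workflow): draw one
# batch from the training generator and confirm it matches the model's input/output.
check_x, check_y = next(generate_data_from_file(train_dataSetNames, batch_size, 512, 512, class_nums))
print(check_x.shape, check_y.shape)  # expected: (2, 512, 512, 1) (2, 512, 512, 2)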

model.fit(
    generate_data_from_file(train_dataSetNames, batch_size, 512, 512, class_nums),
    steps_per_epoch=len(train_dataSetNames) // batch_size,
    epochs=1,
    validation_data=generate_data_from_file(validation_dataSetNames, batch_size, 512, 512, class_nums),
    validation_steps=len(validation_dataSetNames) // batch_size,
)

savedir = './savedir'
if not os.path.exists(savedir):
    os.mkdir(savedir)
model.save_weights(os.path.join(savedir, 'model01.weights.h5'))
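Once training finishes, the saved weights can be loaded into a freshly built model and applied to a single image. The snippet below is a minimal inference sketch under the same assumptions as the training script (512x512 grayscale inputs, the data set layout under ./dataset2); the file name test.jpg is only a placeholder.

########### Minimal inference sketch (added): load the weights and predict one image.
infer_model = U_netModel(class_nums, input_shape=(512, 512, 1))
infer_model.load_weights(os.path.join(savedir, 'model01.weights.h5'))

img = cv2.imread("./dataset2/jpg/test.jpg", 0)              # placeholder file name
img = cv2.resize(img, (512, 512)).reshape(1, 512, 512, 1) / 255.
pred = infer_model.predict(img)                             # shape: (1, 512, 512, class_nums)
mask = np.argmax(pred[0], axis=-1).astype(np.uint8)         # per-pixel class index
cv2.imwrite("prediction.png", mask * 255)                   # scale the 0/1 mask for viewing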