Tensorflow2单机多GPU数据准备与训练说明

前言

能看到这篇文章的,都是富贵让我们相遇。
现在这光景,单GPU都困难,何况多GPU训练。。。

几个需要注意的点

  1. 模型生成部分需要使用tf.distribute.MirroredStrategy
  2. 为了将batch size的数据均等分配给各个GPU的显存,需要通过tf.data.Dataset.from_generator托管数据,从迭代器加载,同时显式关闭AutoShardPolicy。如果不做这一步,显存分配可能会出问题,不仅显存会爆,还可能过程中的validation loss计算会出问题。
  3. 为了避免触发tensorflow2在完成以上步骤,训练过程中metrics的计算bug,需要做到如下几点!这个地方是痛点,如果不仔细跟踪,是很难发现的!
    metrics一定设置为binary_accuracy,或者sparse_categorical_accuracy
    不能简单设置为acc
    否则之后会报:as_list() is not defined on an unknown TensorShape的错误
  4. 之所以使用生成器动态产生训练数据,不仅仅是为了避免一次性加载训练数据,直接吃爆显存,还因为需要实时对训练数据做数据增强与变换,增加模型的鲁棒性。

代码部分

模型生成与编译部分

直接看tf.distribute.MirroredStrategy的用法,损失函数,优化函数的根据自己习惯来。但是metrics一定不能选择acc!

gpus = tf.config.list_physical_devices('GPU')
batchsize = 8
print('apply: Adam + weighted_bce_dice_loss_v1_7_3')
if len(gpus) > 1:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(device=gpu, enable=True)
    batchsize *= len(gpus)
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = table_line.get_model(input_shape=(512, 512, 3),
                                     is_resnest_unet=is_resnest_unet,
                                     is_swin_unet=is_swin_unet,
                                     resnest_pretrain_model=resnest_pretrain_model)
        # apply custom loss
        model.compile(
            optimizer=Adam(
                lr=0.0001),
            loss=weighted_bce_dice_loss_v1_7_3,
            metrics=['binary_accuracy'])
else:
    model = table_line.get_model(input_shape=(512, 512, 3),
                                 is_resnest_unet=is_resnest_unet,
                                 is_swin_unet=is_swin_unet,
                                 resnest_pretrain_model=resnest_pretrain_model)
    model.compile(
        optimizer=Adam(
            lr=0.0001),
        loss=weighted_bce_dice_loss_v1_7_3,
        metrics=['binary_accuracy'])
print('batch size: {0}, GPUs: {1}'.format(batchsize, gpus))

数据迭代器生成部分

def makeDataset(generator_func,
                data_list,
                line_path,
                batchsize,
                draw_line,
                is_raw,
                need_rotate,
                only_flip,
                is_wide_line,
                strategy=None):
    # Get amount of files
    ds = tf.data.Dataset.from_generator(generator_func,
                                        args=[data_list, line_path, batchsize,
                                              draw_line, is_raw, need_rotate,
                                              only_flip, is_wide_line],
                                        output_types=(tf.float64, tf.float64))
    # Make a dataset from the generator. MAKE SURE TO SPECIFY THE DATA TYPE!!!
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
    ds = ds.with_options(options)

    # Optional: Make it a distributed dataset if you're using a strategy
    if strategy is not None:
        ds = strategy.experimental_distribute_dataset(ds)

    return ds

获取training与validation数据获取的迭代器
其中gen是生成数据的方程,其余参数, 除了最后一个strategy参数,都是生成数据方程所需的参数

training_ds = makeDataset(gen,
                          data_list=trainP,
                          line_path=line_path,
                          batchsize=batchsize,
                          draw_line=False,
                          is_raw=is_raw,
                          need_rotate=need_rotate,
                          only_flip=only_flip,
                          is_wide_line=is_wide_line,
                          strategy=None)
validation_ds = makeDataset(gen,
                            data_list=testP,
                            line_path=line_path,
                            batchsize=batchsize,
                            draw_line=False,
                            is_raw=is_raw,
                            need_rotate=need_rotate,
                            only_flip=only_flip,
                            is_wide_line=is_wide_line,
                            strategy=None)

生成数据方程的示例,学过iterate的都明白在说啥

def gen(paths,
        line_path,
        batchsize=2,
        draw_line=True,
        is_raw=False,
        need_rotate=False,
        only_flip: bool = True,
        is_wide_line=False):
    num = len(paths)
    i = 0
    while True:
        # sizes = [512,512,512,512,640,1024] ##多尺度训练
        # size = np.random.choice(sizes,1)[0]
        size = 512
        X = np.zeros((batchsize, size, size, 3))
        Y = np.zeros((batchsize, size, size, 2))
        print(i)
        for j in range(batchsize):
            while True:
                if i >= num:
                    i = 0
                    np.random.shuffle(paths)
                p = paths[i]
                i += 1
                try:
                    if is_raw:
                        img, lines, labelImg = get_img_label_raw(p,
                                                                 line_path,
                                                                 size=(size, size),
                                                                 draw_line=draw_line,
                                                                 is_wide_line=is_wide_line)
                    else:
                        img, lines, labelImg = get_img_label_transform(p,
                                                                       line_path,
                                                                       size=(size, size),
                                                                       draw_line=draw_line,
                                                                       need_rotate=need_rotate,
                                                                       only_flip=only_flip,
                                                                       is_wide_line=is_wide_line)
                    break
                except Exception as e:
                    print(e)
            X[j] = img
            Y[j] = labelImg
        yield X, Y

模型训练部分的代码

训练方法:fit

之前调用数据生成器的训练方法是fit_generator,TF2之后统一用fit方程了

steps参数的写法,重点!

注意steps_per_epoch与validation_steps的写法,batchsize必须与调用makeDataset时,传入的batchsize的值相同,否则无法计算出正确的steps

model.fit(training_ds,
          callbacks=[checkpointer, earlyStopping],
          steps_per_epoch=max(1, len(trainP) // batchsize),
          validation_data=validation_ds,
          validation_steps=max(1, len(testP) // batchsize),
          epochs=300)
最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 203,937评论 6 478
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 85,503评论 2 381
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 150,712评论 0 337
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 54,668评论 1 276
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 63,677评论 5 366
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 48,601评论 1 281
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 37,975评论 3 396
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 36,637评论 0 258
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 40,881评论 1 298
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 35,621评论 2 321
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 37,710评论 1 329
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 33,387评论 4 319
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 38,971评论 3 307
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 29,947评论 0 19
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 31,189评论 1 260
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 44,805评论 2 349
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 42,449评论 2 342

推荐阅读更多精彩内容