python充分利用多核性能预处理ImageNet数据集

TL;DR:用multiprocessing库解决python单线程处理大量图片缓慢的问题。

最近想试试HSI色彩空间的图片对卷积网络有没有帮助,就在每次加载数据的时候对每张图片做RGB到HSI的色彩空间变换。跑了几个epoch之后寻思着不对头,网络训练速度比原来慢了不少。这应该是因为数据预处理太占用CPU,感觉很不爽,于是想把整个ImageNet数据集提前处理好存下来,一劳永逸。

于是简单用python写了个脚本,遍历数据集,然后每张图片做好变换后按照原来的目录结构保存到新的根目录下。

实现很简单,但是跑了下一看,整个数据集跑完一遍竟然要17个小时......显然是因为python的GIL而无法充分利用CPU的多核性能。解决思路自然是利用真“多线程”来让程序跑起来。

python的thread是假线程,适合用在IO密集型的场景,对这种计算密集型的任务毫无帮助,而另一个multiprocessing自然就是解决方案了。运行机制简单地说就是产生一个进程池pool,pool提供一个map接口,把处理数据的函数接口和待处理的数据迭代器丢进去,进程池会自动分配多个进程执行,达到多进程的目的。当然multiprocessing库不止这么简单,还有更复杂的用法,这里并不需要所以不再深入。

最后附上代码:

import os
import tqdm
import itertools
import numpy as np
import multiprocessing as mp

from PIL import Image


def rgb2hsi(rgb):
    rgb /= 255.
    r, g, b = list(map(np.squeeze, np.split(rgb, 3, 2)))
    hsi = np.zeros_like(rgb)
    theta = np.arccos(((r - g) + (r - b)) / (2 * np.sqrt((r - g) ** 2 + (r - b) * (g - b))))
    pi_2 = 2 * np.pi
    hsi[:, :, 0] = np.where(g >= b, theta, pi_2 - theta) / pi_2
    hsi[:, :, 1] = 1 - 3 * np.min(rgb, 2) / np.sum(rgb, 2)
    hsi[:, :, 2] = np.sum(rgb, 2) / 3.

    return hsi * 255.


def resize_create_hsi_img(dir_pair):
    src_path, target_path_rgb, target_path_hsi = dir_pair
    rgb_not_exist = not os.path.exists(target_path_rgb)
    hsi_not_exist = not os.path.exists(target_path_hsi)
    try:
        if rgb_not_exist or hsi_not_exist:
            org_pic = Image.open(src_path)
            new_size = int(org_pic.size[0] * 0.7), int(org_pic.size[1] * 0.7)
            pic = org_pic.resize(new_size, Image.ANTIALIAS)
            if rgb_not_exist:
                pic.save(target_path_rgb, quality=75)
            if hsi_not_exist:
                if pic.mode == 'RGB':
                    rgb_img = np.asarray(pic, np.float32)
                    hsi_img = rgb2hsi(rgb_img).astype(np.uint8)
                    pic = Image.fromarray(hsi_img)
                pic.save(target_path_hsi, quality=75)
        return None
    except Exception as exc:
        print(exc)
        return src_path


def walk_all_pic():
    root = 'D:\Datasets\ImageNet\ILSVRC2017_CLS-LOC\ILSVRC\Data\CLS-LOC'
    targets = ['val', 'train', ]
    root_new1 = 'E:\Imagenet\cls_rgb'
    root_new2 = 'E:\Imagenet\cls_hsi'
    if not os.path.exists(root_new1):
        os.mkdir(root_new1)

    if not os.path.exists(root_new2):
        os.mkdir(root_new2)

    for t in targets:
        sub1 = os.path.join(root, t)
        sub1_new1 = os.path.join(root_new1, t)
        sub1_new2 = os.path.join(root_new2, t)
        folders = os.listdir(sub1)

        if not os.path.exists(sub1_new1):
            os.mkdir(sub1_new1)
        if not os.path.exists(sub1_new2):
            os.mkdir(sub1_new2)

        for subfolder in folders:
            sub2 = os.path.join(sub1, subfolder)
            sub2_new1 = os.path.join(sub1_new1, subfolder)
            sub2_new2 = os.path.join(sub1_new2, subfolder)
            if os.path.isdir(sub2):
                if not os.path.exists(sub2_new1):
                    os.mkdir(sub2_new1)
            if os.path.isdir(sub2):
                if not os.path.exists(sub2_new2):
                    os.mkdir(sub2_new2)

                files = os.listdir(sub2)
                for file in files:
                    fpath = os.path.join(sub2, file)
                    fpath_new_rgb = os.path.join(sub2_new1, file)
                    fpath_new_hsi = os.path.join(sub2_new2, file)
                    if os.path.isfile(fpath):
                        yield fpath, fpath_new_rgb, fpath_new_hsi

# Multiple process version
def run_multiprocess():

    print('Processing pictures with multiple processors...')
    error_pics = []
    with mp.Pool(processes=mp.cpu_count()) as pool:
        for ep in pool.imap_unordered(resize_create_hsi_img, tqdm.tqdm(walk_all_pic(), total=1331167, ncols=65)):
            error_pics.append(ep)
    with open('./error_pics.log', mode='w') as f:
        if error_pics is not None:
            f.writelines(error_pics)
    print('All pictures cannot be processed have been writen into \'error_pics.log\'')

# Single process version
def run():
    for f, fnew1, fnew2 in tqdm.tqdm(walk_all_pic()):
        resize_create_hsi_img((f, fnew1, fnew2))


if __name__ == '__main__':
    run_multiprocess()

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • 一. 操作系统概念 操作系统位于底层硬件与应用软件之间的一层.工作方式: 向下管理硬件,向上提供接口.操作系统进行...
    月亮是我踢弯得阅读 11,218评论 3 28
  • 1.进程和线程 1.1系统多任务机制 多任务操作机制的引入主要是在相同的硬件资源下怎么提高任务处理效率的!多任务的...
    _宁采臣阅读 4,668评论 0 6
  • 进程、进程的使用、进程注意点、进程间通信-Queue、进程池Pool、进程与线程对比、文件夹拷贝器-多任务 1.进...
    Cestine阅读 4,601评论 0 0
  • 多进程 要让python程序实现多进程,我们先了解操作系统的相关知识。 Unix、Linux操作系统提供了一个fo...
    蓓蓓的万能男友阅读 3,775评论 0 1
  • 周五的晚上校区里前台边上固定坐着两个家长,一个是大班男生妈妈,一个是大班女生妈妈。今天讨论的话题是幼儿园的教育和外...
    许你一个诺言阅读 2,523评论 0 1

友情链接更多精彩内容