三维深度学习-多线程读取vtkImageData

在深度学习最常用的卷积神经网络中,要求数据为具有空间局部性的多维矩阵或者说张量。这与广泛应用的三维模型格式例如STL这种保存三角面片的存储方式不一致。因此,采用体素化的方式对输入进行处理。

以VTK为例,在读入了vtkPolyData后,采用vtkPolyDataToImageStencilExample)的方式对三维模型进行转换,类似的转换方法还有vtkVoxelModeller,但相比之下效率极低。

不过,这样的方法还是较为缓慢,尤其是当输出体素模型规模较大时(如128x128x128),在实际使用中,会使模型文件读取占据了大量开销。不过,由于这个转换本身是可以重复利用的,因此在定义数据集时,加入了cache模式,PyTorch样例代码如下:

class Dataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None, cache=False):
        self.frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        if cache:
            self.cache = [None for i in range(len(self.frame))]
            for i in range(len(self.frame)):
                print('Caching record #%d\r' % (count))
                self.cache[count] = self.read(i)
        else:
            self.cache = None

    def __len__(self):
        return len(self.frame)

    def read(self, idx):
        """Read your data here."""
        return sample

    def __getitem__(self, idx):
        if self.cache:
            sample = self.cache[idx]
        else:
            sample = self.read(idx)
        if self.transform:
            sample = self.transform(sample)
        return sample

实践中发现这样建立缓存还是存在读取效率不足的问题,因此再次改写了一下,变成多线程的形式。

def __init__(self, csv_file, root_dir, transform=None, cache=False, thread=4):
    self.landmarks_frame = pd.read_csv(csv_file)
    self.root_dir = root_dir
    self.transform = transform
    if cache:
        self.cache = [None for i in range(len(self.landmarks_frame))]
        pool = multiprocessing.Pool(processes=thread)
        irange = range(len(self.landmarks_frame))
        count = 0
        for sample in pool.imap(self.read, irange):
            print('Caching record #%d\r' % (count))
            self.cache[count] = sample
            count += 1
    else:
        self.cache = None

可惜的是,这样的改写并不能成功,因为在multiprocessing中传递结果时用到了pickle进行数据的传递,而vtkImageData作为比较特殊的对象无法被pickle序列化。为了解决这个问题,简单调用了vtk.util.numpy_support里的一些方法,完成vtkImageData与Numpy array之间的无损转换。

def voxel2array(self, img):
    # Up to support for 3 dimensions for this line
    rows, cols, _ = img.GetDimensions()

    sc = img.GetPointData().GetScalars()
    arr = numpy_support.vtk_to_numpy(sc)
    arr = array.reshape(rows, cols, -1)
    spacing = img.GetSpacing()
    origin = img.GetOrigin()

    return arr, spacing, origin

def array2voxel(self, arr, spacing, origin):

    vtk_data = numpy_support.numpy_to_vtk(
        arr.ravel(), array_type=vtk.VTK_UNSIGNED_CHAR)
    img = vtk.vtkImageData()
    img.SetDimensions(array.shape)
    img.SetSpacing(spacing)
    img.SetOrigin(origin)
    img.GetPointData().SetScalars(vtk_data)

    return img

重点是vtkImageData中还留存着其体素的spacing信息和图像的整体坐标信息。
突然想到,在体素化前利用一些三维模型降采样方法对牙齿模型进行降采样,是否能够大大加速体素化。

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容

  • 文章主要分为:一、深度学习概念;二、国内外研究现状;三、深度学习模型结构;四、深度学习训练算法;五、深度学习的优点...
    艾剪疏阅读 21,943评论 0 58
  • (第一部分 机器学习基础)第01章 机器学习概览第02章 一个完整的机器学习项目(上)第02章 一个完整的机器学习...
    SeanCheney阅读 19,803评论 20 62
  • 在银河系里头,有一个古怪的传说。传说, 在银河系里有不死人的家族。 就是这样,不死人族,就在这一刻,诞生了。 不死...
    王密亮阅读 604评论 0 2
  • 一卷诗书,一回梦,椅案埋头,睡意袭袭。 梅雨时节,行人纷纷,一把纸伞,一袭白衣。 长安城,朱雀街,繁华盛世,谁...
    翛娛阅读 273评论 0 1
  • 榴莲这个水果之王,我是一直想尝试却没机会尝试。碰巧我表姐来看我给我以及我室友带了榴莲,因为她最爱吃榴莲。可惜了我捏...
    冰雅乐阅读 206评论 0 1