在深度学习最常用的卷积神经网络中,要求数据为具有空间局部性的多维矩阵或者说张量。这与广泛应用的三维模型格式例如STL这种保存三角面片的存储方式不一致。因此,采用体素化的方式对输入进行处理。
以VTK为例,在读入了vtkPolyData后,采用vtkPolyDataToImageStencil(Example)的方式对三维模型进行转换,类似的转换方法还有vtkVoxelModeller,但相比之下效率极低。
不过,这样的方法还是较为缓慢,尤其是当输出体素模型规模较大时(如128x128x128),在实际使用中,会使模型文件读取占据了大量开销。不过,由于这个转换本身是可以重复利用的,因此在定义数据集时,加入了cache模式,PyTorch样例代码如下:
class Dataset(Dataset):
def __init__(self, csv_file, root_dir, transform=None, cache=False):
self.frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
if cache:
self.cache = [None for i in range(len(self.frame))]
for i in range(len(self.frame)):
print('Caching record #%d\r' % (count))
self.cache[count] = self.read(i)
else:
self.cache = None
def __len__(self):
return len(self.frame)
def read(self, idx):
"""Read your data here."""
return sample
def __getitem__(self, idx):
if self.cache:
sample = self.cache[idx]
else:
sample = self.read(idx)
if self.transform:
sample = self.transform(sample)
return sample
实践中发现这样建立缓存还是存在读取效率不足的问题,因此再次改写了一下,变成多线程的形式。
def __init__(self, csv_file, root_dir, transform=None, cache=False, thread=4):
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
if cache:
self.cache = [None for i in range(len(self.landmarks_frame))]
pool = multiprocessing.Pool(processes=thread)
irange = range(len(self.landmarks_frame))
count = 0
for sample in pool.imap(self.read, irange):
print('Caching record #%d\r' % (count))
self.cache[count] = sample
count += 1
else:
self.cache = None
可惜的是,这样的改写并不能成功,因为在multiprocessing中传递结果时用到了pickle进行数据的传递,而vtkImageData作为比较特殊的对象无法被pickle序列化。为了解决这个问题,简单调用了vtk.util.numpy_support
里的一些方法,完成vtkImageData与Numpy array之间的无损转换。
def voxel2array(self, img):
# Up to support for 3 dimensions for this line
rows, cols, _ = img.GetDimensions()
sc = img.GetPointData().GetScalars()
arr = numpy_support.vtk_to_numpy(sc)
arr = array.reshape(rows, cols, -1)
spacing = img.GetSpacing()
origin = img.GetOrigin()
return arr, spacing, origin
def array2voxel(self, arr, spacing, origin):
vtk_data = numpy_support.numpy_to_vtk(
arr.ravel(), array_type=vtk.VTK_UNSIGNED_CHAR)
img = vtk.vtkImageData()
img.SetDimensions(array.shape)
img.SetSpacing(spacing)
img.SetOrigin(origin)
img.GetPointData().SetScalars(vtk_data)
return img
重点是vtkImageData中还留存着其体素的spacing信息和图像的整体坐标信息。
突然想到,在体素化前利用一些三维模型降采样方法对牙齿模型进行降采样,是否能够大大加速体素化。