一、数据来源
python模块
#模块
import torch
from torchvision import datasets, transforms
#方法
datasets.ImageFolder(filepath, transform=transform)
注意:自带父目录的标签(建立清晰的分类目录)二、数据整理 transforms()
-
transforms.Compose([ ])
输入为transforms.操作()
的列表:多个transform组合起来使用
- 尺寸:比例缩放
transforms.Resize()
- 二维输入:(height, width)
- 一维输入:较小的边匹配这个输入size,比如height>width时,尺度调整为 (size * height / width, size)
torchvision.transforms.Scale()
不推荐使用:可能有畸变
- 尺寸:裁剪
transforms.CenterCrop()
基于中心裁剪:输出正方形(size, size)
transforms.RandomResizedCrop()
根据比例,裁剪为随机大小
- 位移:旋转,翻转
transforms.RandomRotation()
旋转
transforms.RandomHorizontalFlip(p=0.5)
水平翻转 (概率为0.5)
- 数据转换
transforms.ToTensor()
图像变为pytorch张量;适用情境:灰度值变为彩色图像
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
RGB三个通道的归一化(均值0.5, 标准差0.5)
三、针对模型调整输入数据的格式
-
明确操作对象
如图: