from torch.utils.data import DataLoader
batch_size = 16
data = {
'train':
WheatDataset(
'/home/yons/data/kaggle/global-wheat-detection/global-wheat-detection-coco-format/instances_train2017.json',
'/home/yons/data/kaggle/global-wheat-detection/raw/train',
get_train_transforms()
),
'val':
WheatDataset(
'/home/yons/data/kaggle/global-wheat-detection/global-wheat-detection-coco-format/instances_val2017.json',
'/home/yons/data/kaggle/global-wheat-detection/raw/train',
get_valid_transforms()
)
}
dataloaders = {
'train': DataLoader(data['train'], batch_size=batch_size, shuffle=True),
'val': DataLoader(data['val'], batch_size=batch_size, shuffle=False),
}
test_iter = iter(dataloaders['train'])
inputs, targets = next(test_iter)
inputs.shape, targets.shape
data loader
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。