data loader

from torch.utils.data import DataLoader
batch_size = 16
data = {
    'train':
    WheatDataset(
    '/home/yons/data/kaggle/global-wheat-detection/global-wheat-detection-coco-format/instances_train2017.json',
    '/home/yons/data/kaggle/global-wheat-detection/raw/train',
    get_train_transforms()
    ),
    'val':
    WheatDataset(
    '/home/yons/data/kaggle/global-wheat-detection/global-wheat-detection-coco-format/instances_val2017.json',
    '/home/yons/data/kaggle/global-wheat-detection/raw/train',
    get_valid_transforms()
    )
}
dataloaders = {
    'train': DataLoader(data['train'], batch_size=batch_size, shuffle=True),
    'val': DataLoader(data['val'], batch_size=batch_size, shuffle=False),
}
test_iter = iter(dataloaders['train'])
inputs, targets = next(test_iter)
inputs.shape, targets.shape
©著作权归作者所有,转载或内容合作请联系作者
【社区内容提示】社区部分内容疑似由AI辅助生成,浏览时请结合常识与多方信息审慎甄别。
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

友情链接更多精彩内容