Pytorch中使用SubsetRandomSampler做数据集划分

实现了一个小demo演示SubsetRandomSampler的用法

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler


dataset = TensorDataset(torch.tensor(list(range(20))))  # 构造一个数据集(0到19)
idx = list(range(len(dataset)))  # 创建索引,SubsetRandomSampler会自动乱序
# idx = torch.zeros(len(dataset)).long()  # 传入相同的索引,SubsetRandomSampler只会采样相同结果
n = len(dataset)
split = n//5
train_sampler = SubsetRandomSampler(idx[split::])  # 随机取80%的数据做训练集
test_sampler = SubsetRandomSampler(idx[::split])  # 随机取20%的数据做测试集
train_loader = DataLoader(dataset, sampler=train_sampler)
test_loader = DataLoader(dataset, sampler=test_sampler)

print('data for training:')
for i in train_loader:
    print(i)
print('data for testing:')
for i in test_loader:
    print(i)

注意train_loader和test_loader的dataset都是一样的,比如要获取loader的样本总数,应该len(sampler)而不是len(dataset)

len(train_loader.sampler)
len(test_loader.sampler)

关于pytorch的其他sampler的文档:
https://blog.csdn.net/aiwanghuan5017/article/details/102147825

©著作权归作者所有,转载或内容合作请联系作者
平台声明:文章内容(如有图片或视频亦包括在内)由作者上传并发布,文章内容仅代表作者本人观点,简书系信息发布平台,仅提供信息存储服务。

推荐阅读更多精彩内容