import torch
from torch.utils.data import Dataset,DataLoader
import linecache
import random
from PIL import Image
class MyDataset(Dataset):
    """Siamese-pair dataset read from a list file of ``<image path> <label>`` lines.

    Every sample pairs a randomly chosen anchor image with a second image.
    With probability 1/2 the second image is drawn from the anchor's own
    label (a positive pair). Otherwise it is drawn from a different label (a
    negative pair) whenever the list contains more than one label.

    The ``index`` passed to ``__getitem__`` is only bounds-checked. Pairs are
    sampled at random, as in the original design, so ``dataset[i]`` is not
    reproducible unless :mod:`random` is seeded.
    """

    def __init__(self, txt_file, transform=None):
        """Parse and index the list file.

        Args:
            txt_file: path to a UTF-8 text file. Each non-blank line holds an
                image path and a class label separated by whitespace; any
                further fields are ignored.
            transform: optional callable applied to each loaded PIL image.

        Raises:
            ValueError: if a non-blank line has fewer than two fields.
        """
        self.transform = transform
        self.txt = txt_file
        # Parse the list once up front. Re-reading the file on every
        # __len__ call was O(file size) per call, and __getitem__ made many
        # such calls per sample.
        self._records = []
        with open(txt_file, 'r', encoding='utf-8') as f:
            for lineno, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    continue  # skip blank lines, e.g. a trailing empty line
                if len(fields) < 2:
                    raise ValueError(
                        f'{txt_file}:{lineno}: expected "<path> <label>", got {line!r}')
                self._records.append((fields[0], fields[1]))
        # label -> list of (path, label) records carrying that label.
        self._by_label = {}
        for record in self._records:
            self._by_label.setdefault(record[1], []).append(record)

    def __getitem__(self, index):
        """Return a random image pair and whether both images share a label.

        Args:
            index: sample index; must satisfy ``-len(self) <= index < len(self)``.

        Returns:
            A dict with ``'image'`` -> ``[img0, img1]`` (transformed if a
            transform was given) and ``'label'`` -> ``torch.tensor(1)`` for a
            same-label pair, ``torch.tensor(0)`` otherwise.

        Raises:
            IndexError: if ``index`` is out of range. Plain
                ``for sample in dataset`` iteration relies on this to stop.
        """
        size = len(self._records)
        if not -size <= index < size:
            raise IndexError(f'index {index} out of range for dataset of size {size}')
        record0, record1 = self._sample_pair()
        img0 = self._load_image(record0[0])
        img1 = self._load_image(record1[0])
        if self.transform:
            img0 = self.transform(img0)
            img1 = self.transform(img1)
        label = 1 if record0[1] == record1[1] else 0
        return {'image': [img0, img1], 'label': torch.tensor(label)}

    def _sample_pair(self):
        """Pick a random (anchor, partner) pair of ``(path, label)`` records."""
        anchor = random.choice(self._records)
        # Fair coin: 1 -> same person, 0 -> different person.
        if random.randint(0, 1):
            # Uniform over all records with the anchor's label (anchor included).
            partner = random.choice(self._by_label[anchor[1]])
        elif len(self._by_label) > 1:
            # Rejection-sample so negatives are uniform over other-label images.
            partner = random.choice(self._records)
            while partner[1] == anchor[1]:
                partner = random.choice(self._records)
        else:
            # Only one label in the list, so a true negative pair cannot exist.
            partner = random.choice(self._records)
        return anchor, partner

    @staticmethod
    def _load_image(path):
        """Read the image at *path* fully into memory and close its file."""
        with Image.open(path) as img:
            # copy() forces the pixel data to load before the file is closed.
            return img.copy()

    def __len__(self):
        """Return the number of non-blank records in the list file."""
        return len(self._records)
import torch
from torch.utils.data import Dataset,DataLoader
from torchvision.transforms import transforms
from torchvision.utils import make_grid
import linecache
import random
from PIL import Image
import dataset
import matplotlib.pyplot as plt
import numpy as np
# https://www.cnblogs.com/king-lps/p/8342452.html
def show_image(sample):
    """Draw a sample's two images side by side on the current matplotlib axes.

    Args:
        sample: a dict with ``'image'`` holding two same-sized image tensors
            in the (C, H, W) layout that ``make_grid`` expects, and
            ``'label'`` holding a 0-d tensor that is shown as the title.
    """
    image0, image1 = sample['image'][0], sample['image'][1]
    # Pad with "white" in the images' own scale: 1.0 for float tensors
    # (e.g. ToTensor output in [0, 1]) and 255 for integer tensors. A pad of
    # 255 on float data is out of range, and imshow clips it with a warning.
    pad_value = 1.0 if image0.is_floating_point() else 255
    grid = make_grid([image0, image1], pad_value=pad_value)
    # make_grid returns (C, H, W); imshow wants (H, W, C).
    plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
    plt.axis('off')
    plt.title(sample['label'].item())
def main():
    """Print and display every sample of the att_faces pair dataset in turn.

    Iteration is bounded explicitly with ``range(len(...))``. Iterating the
    dataset object directly would only stop if its ``__getitem__`` raised
    ``IndexError`` for an out-of-range index.
    """
    my_dataset = dataset.MyDataset('./data/att_faces/list.txt', transform=transforms.ToTensor())
    plt.figure()
    for i in range(len(my_dataset)):
        sample = my_dataset[i]
        print(sample)
        # Display the pair with its same/different label as the title.
        show_image(sample)
        # With an interactive backend this blocks until the window is closed.
        plt.show()


if __name__ == '__main__':
    main()