# Augmentation pipeline that ends in a ten-crop stack: each input image
# becomes one tensor holding all of its crops.

# Build the per-crop converters once; they are reused for every crop.
_to_tensor = transforms.ToTensor()
# NOTE(review): these look like the standard ImageNet channel mean/std -- confirm.
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)


def _stack_as_tensors(crops):
    """Convert every crop with ToTensor and stack the results into one tensor."""
    return torch.stack([_to_tensor(crop) for crop in crops])


def _stack_normalized(crops):
    """Normalize every crop channel-wise and stack the results again."""
    return torch.stack([_normalize(crop) for crop in crops])


transform = Compose([
    # Random crop resized to 1.2x the final size, leaving room for TenCrop.
    transforms.RandomResizedCrop(size=int(image_size * 1.2)),
    # disabled: transforms.ToPILImage()
    transforms.RandomAffine(degrees=15),
    # disabled: transforms.RandomHorizontalFlip()
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomGrayscale(),
    # TenCrop hands back a sequence of crops rather than a single image,
    # so the remaining stages operate crop by crop.
    transforms.TenCrop(size=image_size),
    transforms.Lambda(_stack_as_tensors),
    transforms.Lambda(_stack_normalized),
])
- 注意:使用 TenCrop 之后,一个 batch 中的 image 是一个 5 维 tensor (batch_size, 10, channels, H, W),而训练时模型通常需要 4 维 tensor (batch_size, channels, H, W)。因此实际使用时还需要做一次形状转换:把 batch 维和 crop 维合并,即把每张原始图片裁剪出的 10 张图片都当作独立样本,组成一个大小为 batch_size × 10 的新 batch。
# Minimal ten-crop pipeline: crop first, then turn the crops into one tensor.
_pil_to_tensor = ToTensor()  # one converter instance shared by all crops

transform = Compose([
    # TenCrop produces a sequence of PIL images, not a single image.
    TenCrop(size),
    # Convert each crop and stack them, giving a 4D tensor per sample.
    Lambda(lambda crops: torch.stack(tuple(map(_pil_to_tensor, crops)))),
])
# Inside the test loop: fold the crop axis into the batch axis for the
# forward pass, then average the model outputs over each sample's crops.
input, target = batch  # input must be 5-D for the unpacking below
bs, ncrops, c, h, w = input.shape
# Merge batch and crop dimensions so the model sees an ordinary 4-D batch.
flat_input = input.view(bs * ncrops, c, h, w)
result = model(flat_input)
# Restore the (sample, crop) layout and average across the crop axis.
result_avg = result.view(bs, ncrops, -1).mean(dim=1)