Deployment on Win10
The code and deployment steps come from the official repository:
https://github.com/microsoft/Swin-Transformer
Clone the repository locally:
git clone https://github.com/microsoft/Swin-Transformer.git
cd Swin-Transformer
Set up a conda environment
The repository's code is written for distributed launch. Even with only a single GPU, stripping the distributed support out of the code would be tedious and touch many places, so it is easier to simply install a recent PyTorch that supports it:
conda install pytorch==1.8.1 torchvision==0.9.1
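If conda cannot resolve these packages, or you need GPU support, the official previous-versions install command also pins the CUDA toolkit and channel (10.2 below is only an example; match it to your driver):
conda install pytorch==1.8.1 torchvision==0.9.1 cudatoolkit=10.2 -c pytorch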
Install the other dependencies:
pip install timm==0.3.2
pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8
Install apex
Download and build:
git clone https://github.com/NVIDIA/apex
cd apex
python setup.py install
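A quick way to confirm the environment is usable is to import the key packages (a minimal sketch; the printed versions should match the ones installed above):

import torch
import timm
import cv2
# raises ImportError here if the apex build above failed
from apex import amp

print(torch.__version__, torch.cuda.is_available())  # expect 1.8.1 and True on a GPU machine
print(timm.__version__, cv2.__version__)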
Dataset
The strengths and applications of Transformers are introduced in many other places; the main point here is that they need a lot of data to train, usually ImageNet:
https://www.bilibili.com/video/BV1qv411n7gN
ImageNet provides data for evaluating classification, localization, and detection tasks.
Like the classification data, the localization task has 1000 classes, with accuracy computed over the top five predictions.
Every image has at least one bounding box. The detection task covers 200 object categories across 470,000 images, with an average of 1.1 objects per image.
http://image-net.org/
The official models are all trained on ImageNet. The full dataset is about 1 TB, so for deployment and verification here we use a subset shared by others, a bit over 200 MB.
After downloading
The classification dataset is split into train and val folders, each of which should contain one subfolder of images per class. In this subset the val images arrive as a flat images folder plus a val_annotations.txt mapping, so they need to be reorganized. A reference conversion script:
import glob
import os
from shutil import move
from os import rmdir

target_folder = './imagenet/val/'

# read the image-name -> class-name mapping from val_annotations.txt
val_dict = {}
with open('./imagenet/val/val_annotations.txt', 'r') as f:
    for line in f.readlines():
        split_line = line.split('\t')
        val_dict[split_line[0]] = split_line[1]

paths = glob.glob('E:\\workspace\\Swin-Transformer\\imagenet\\val\\images\\*')

# create one folder (with an images subfolder) per class
for path in paths:
    file = path.split('\\')[-1]
    folder = val_dict[file]
    if not os.path.exists(target_folder + str(folder)):
        os.mkdir(target_folder + str(folder))
        os.mkdir(target_folder + str(folder) + '/images')

# move each image into its class folder
for path in paths:
    file = path.split('\\')[-1]
    folder = val_dict[file]
    dest = target_folder + str(folder) + '/images/' + str(file)
    move(path, dest)

rmdir('./imagenet/val/images')
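To sanity-check the result, a quick sketch with torchvision's ImageFolder (the same loader the repo uses in non-zip mode):

from torchvision import datasets

ds = datasets.ImageFolder('./imagenet/val')
print(f'{len(ds.classes)} classes, {len(ds)} images')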
Training and Validation
Once the data are ready, the code can be run end to end.
Training
python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --cfg .\configs\swin_tiny_patch4_window7_224.yaml --resume .\pth\swin_tiny_patch4_window7_224.pth --data-path .\imagenet\ --local_rank 0 --batch-size 16 --output output
Here nproc_per_node is the number of GPUs (only one here; leaving it as-is also works if you ignore the warnings), and master_port is the port number. If this fails with RuntimeError: Address already in use, switch to another free port, e.g. 23467.
--resume points to the official pretrained weights; the tiny model is used here.
--data-path is the dataset folder.
Evaluation
python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --eval \
--cfg .\configs\swin_tiny_patch4_window7_224.yaml --resume .\pth\swin_tiny_patch4_window7_224.pth --data-path .\imagenet\ --local_rank 0
Training on a Custom Dataset
Data Preparation
The data here come from a competition:
http://data.sd.gov.cn/cmpt/cmptDetail.html?id=61
The task is to classify images into four categories, smoke, fire, cloud, and neon light, labeled 0, 1, 2, 3, with 157, 200, 223, and 194 training images respectively.
The goal here is only to get training and testing running, with no augmentation or tuning; 80% of the images go to the training set and 20% to the validation set (a split-script sketch follows the directory tree below).
Organize the images into folders following the ImageNet layout:
E:\WORKSPACE\SWIN-TRANSFORMER\YHYW
├─train
│  ├─0
│  │    1.jpg
│  │    2.jpg
│  │    ......
│  ├─1
│  │    12.jpg
│  │    22.jpg
│  │    ......
│  ├─2
│  └─3
└─val
   ├─0
   ├─1
   ├─2
   └─3
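A minimal 80/20 split sketch, assuming the raw competition images start out in one folder per class (the .\raw source path is a placeholder for this guide, not part of the repo):

import os
import random
import shutil

src_root = r'.\raw'   # placeholder: subfolders 0..3 holding all images of each class
dst_root = r'.\yhyw'  # dataset root used by the training command below
random.seed(0)        # fixed seed so the split is reproducible

for cls in os.listdir(src_root):
    files = os.listdir(os.path.join(src_root, cls))
    random.shuffle(files)
    n_train = int(0.8 * len(files))
    for split, names in (('train', files[:n_train]), ('val', files[n_train:])):
        dst = os.path.join(dst_root, split, cls)
        os.makedirs(dst, exist_ok=True)
        for name in names:
            shutil.copy(os.path.join(src_root, cls, name), os.path.join(dst, name))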
Modify Code and Configuration
In config.py:
_C.DATA.BATCH_SIZE = 16   # batch size
_C.MODEL.NUM_CLASSES = 4  # number of classes
_C.SAVE_FREQ = 10         # save a checkpoint every N epochs
_C.TRAIN.EPOCHS = 300     # total number of training epochs
In data/build.py (around line 69), num_classes must match your own class count; the Mixup setup already reads it from config.MODEL.NUM_CLASSES, so the change above covers it:
if mixup_active:
    mixup_fn = Mixup(
        mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
        prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
        label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
The hardcoded nb_classes in build_dataset(), however, does need editing:
def build_dataset(is_train, config):
    transform = build_transform(is_train, config)
    if config.DATA.DATASET == 'imagenet':
        prefix = 'train' if is_train else 'val'
        if config.DATA.ZIP_MODE:
            ann_file = prefix + "_map.txt"
            prefix = prefix + ".zip@/"
            dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
                                        cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
        else:
            root = os.path.join(config.DATA.DATA_PATH, prefix)
            dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 4  # change to your own number of classes
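Note that ImageFolder assigns label indices by sorting the class folder names, so the folders 0-3 here map to indices 0-3 as intended; a quick check (a sketch, using the dataset path from above):

from torchvision import datasets

ds = datasets.ImageFolder(r'.\yhyw\train')
print(ds.class_to_idx)  # expect {'0': 0, '1': 1, '2': 2, '3': 3}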
The original validation code reports top-5 accuracy, but with only 4 classes topk=(1, 5) goes out of range and raises an error. Either comment it out or change the 5 to a 3; the modified validate() in main.py:
@torch.no_grad()
def validate(config, data_loader, model):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc3_meter = AverageMeter()

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        output = model(images)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc3 = accuracy(output, target, topk=(1, 3))

        acc1 = reduce_tensor(acc1)
        acc3 = reduce_tensor(acc3)
        loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc3_meter.update(acc3.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Acc@3 {acc3_meter.val:.3f} ({acc3_meter.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@3 {acc3_meter.avg:.3f}')
    return acc1_meter.avg, acc3_meter.avg, loss_meter.avg
With that, training can be started:
python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --cfg .\configs\swin_tiny_patch4_window7_224.yaml --data-path .\yhyw\ --local_rank 0 --batch-size 16 --output out1
Inference
Add a helper to config.py; unlike get_config(), it returns the default config without going through command-line argument parsing:
def just_config():
    config = _C.clone()
    return config
infer.py
import os
import cv2
import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
from torchvision.transforms import transforms
from config import just_config
from models.swin_transformer import SwinTransformer
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import PIL.Image as Im

cudnn.benchmark = True
device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = just_config()

# build the tiny model with the same hyper-parameters as training
model = SwinTransformer(img_size=config.DATA.IMG_SIZE,
                        patch_size=config.MODEL.SWIN.PATCH_SIZE,
                        in_chans=config.MODEL.SWIN.IN_CHANS,
                        num_classes=config.MODEL.NUM_CLASSES,
                        embed_dim=config.MODEL.SWIN.EMBED_DIM,
                        depths=config.MODEL.SWIN.DEPTHS,
                        num_heads=config.MODEL.SWIN.NUM_HEADS,
                        window_size=config.MODEL.SWIN.WINDOW_SIZE,
                        mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
                        qkv_bias=config.MODEL.SWIN.QKV_BIAS,
                        qk_scale=config.MODEL.SWIN.QK_SCALE,
                        drop_rate=config.MODEL.DROP_RATE,
                        drop_path_rate=config.MODEL.DROP_PATH_RATE,
                        ape=config.MODEL.SWIN.APE,
                        patch_norm=config.MODEL.SWIN.PATCH_NORM,
                        use_checkpoint=config.TRAIN.USE_CHECKPOINT)

# raw strings avoid backslash escapes in Windows paths
state_dicts = torch.load(r'E:\workspace\Swin-Transformer\out1\swin_tiny_patch4_window7_224\default\ckpt_epoch_399.pth')['model']
model.load_state_dict(state_dicts)
model.to(device)

process_img = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

classes_name = {0: '0', 1: '1', 2: '2', 3: '3'}
img_dir = r'E:\workspace\Swin-Transformer\yanhuo\testA\images'
testres = pd.DataFrame(columns=('image_name', 'label'))

model.eval()
with torch.no_grad():
    for img_name in os.listdir(img_dir):
        img_path = os.path.join(img_dir, img_name)
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the normalization stats assume RGB
        img = Im.fromarray(img)
        img = process_img(img)
        img = img.unsqueeze(0).to(device)
        out = model(img).cpu().numpy()
        index = np.argmax(out)
        print(img_name)
        print('predict the class of this pic is: {}'.format(classes_name[index]))
        testres = testres.append(pd.DataFrame({'image_name': [img_name], 'label': [classes_name[index]]}),
                                 ignore_index=True)
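The loop fills testres but never writes it to disk; a one-line follow-up (result.csv is an assumed output name, e.g. for a competition submission):

testres.to_csv('result.csv', index=False)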
And now the "alchemy" of model tuning can begin.
References
https://github.com/yihui8776/Swin-transformer/blob/main/get_started.md
https://blog.csdn.net/u014515463/article/details/80748125
https://blog.csdn.net/qq_33932782/article/details/117013133
https://blog.csdn.net/hi_gril/article/details/118486070
https://blog.csdn.net/qq_42067064/article/details/115597482
https://blog.csdn.net/qq_36622589/article/details/117913064
http://data.sd.gov.cn/cmpt/cmptDetail.html?id=61