swin-transformer 图片分类 win10

win10 部署

代码和部署方式来自官方:
https://github.com/microsoft/Swin-Transformer

克隆到本地

git clone https://github.com/microsoft/Swin-Transformer.git
cd Swin-Transformer

建立conda 环境
这里因为要使用分布式的代码,虽然只有一个GPU,当然修改代码的分布式支持,但是相对麻烦,而且要改的很多,所以这里直接用新版的pytorch。

conda install pytorch==1.8.1 torchvision==0.9.1

安装其他库

pip install timm==0.3.2
 
pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8

安装apex
下载

git clone https://github.com/NVIDIA/apex
cd apex
python setup.py install

数据集

对transformer的优势和应用很多都有介绍,这里主要是一个需要很多数据训练。一般是基于imagenet数据集,
https://www.bilibili.com/video/BV1qv411n7gN?from=search&seid=8702502621110627235&spm_id_from=333.337.0.0
ImageNet拥有用于分类、定位和检测任务评估的数据。
与分类数据类似,定位任务有1000个类别。准确率是根据最高五项检测结果计算出来的。
所有图像中至少有一个边框。对200个目标的检测问题有470000个图像,平均每个图像有1.1个目标。
http://image-net.org/
官方的模型都是基于imagenet训练,完整的有1T,所以这里部署和验证用的是别人分享的部分子集,200多M
下载后

image.png

image.png

分类数据集主要分为train,val文件夹, 文件夹里有各个类的图片的文件夹,


image.png

转换代码参考

import glob
import os
from shutil import move
from os import rmdir

target_folder = './imagenet/val/'

val_dict = {}
with open('./imagenet/val/val_annotations.txt', 'r') as f:
    for line in f.readlines():
        split_line = line.split('\t')
        val_dict[split_line[0]] = split_line[1]
# print(val_dict)
# print(val_dict.keys())

paths = glob.glob('E:\\workspace\\Swin-Transformer\\imagenet\\val\\images\\*')
for path in paths:
    file = path.split('\\')[-1]
    # print(file)
    folder = val_dict[file]
    if not os.path.exists(target_folder + str(folder)):
        os.mkdir(target_folder + str(folder))
        os.mkdir(target_folder + str(folder) + '/images')

for path in paths:
    file = path.split('\\')[-1]
    folder = val_dict[file]
    dest = target_folder + str(folder) + '/images/' + str(file)
    move(path, dest)

rmdir('./imagenet/val/images')

训练验证

准备好就可以跑通代码了
训练

python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py  --cfg .\configs\swin_tiny_patch4_window7_224.yaml --resume .\pth\swin_tiny_patch4_window7_224.pth --data-path .\imagenet\  --local_rank 0  --batch-size 16  --output output

其中 nproc_per_node 为GPU数量,这里就一个,不修改的话也可以,忽略警告即可;master_port 为端口号,如果此处报错 RuntimeError: Address already in use ,可以更改为其他空闲端口,如23467,88888等。
--resume 后面就是官网的预训练模型, 这里跑tiny模型
--data-path 就是数据集文件夹
测试

python -m torch.distributed.launch --nproc_per_node  1  --master_port 12345 main.py --eval \
--cfg  .\configs\swin_tiny_patch4_window7_224.yaml  --resume   .\pth\swin_tiny_patch4_window7_224.pth  --data-path  .\imagenet\  local_rank 0 

训练自定义数据集

数据处理

这里数据来自比赛
http://data.sd.gov.cn/cmpt/cmptDetail.html?id=61
主要是对烟,火,云,霓虹灯四个类别的图片进行分类识别,分别为0,1,2,3 ,训练数据分别为 157,200,223,194
这里只是跑通训练测试,不做增强和调优,划分80%为训练集,20为验证集。
将图片按照imagenet格式划分文件夹
E:\WORKSPACE\SWIN-TRANSFORMER\YHYW

├─train
│  ├─0
          | 1.jpg
            2.jpg
            ........
│  ├─1
          | 12.jpg
            22.jpg
            ........
│  ├─2
│  └─3
└─val
    ├─0
    ├─1
    ├─2
    └─3

修改代码和配置
config.py 文件

_C.DATA.BATCH_SIZE = 16 ,修改batch的大小
_C.MODEL.NUM_CLASSES = 4 ,修改类别数
_C.SAVE_FREQ = 10 ,每多少个epoch保存一次模型
_C.TRAIN.EPOCHS = 300 总共训练多少个epoch
修改 data/build.py中第69行,num_classes=..自己的类别数

    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
            prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
            label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
def build_dataset(is_train, config):
    transform = build_transform(is_train, config)
    if config.DATA.DATASET == 'imagenet':
        prefix = 'train' if is_train else 'val'
        if config.DATA.ZIP_MODE:
            ann_file = prefix + "_map.txt"
            prefix = prefix + ".zip@/"
            dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform,
                                        cache_mode=config.DATA.CACHE_MODE if is_train else 'part')
        else:
            root = os.path.join(config.DATA.DATA_PATH, prefix)
            dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 4 # 改为自己的类数

由于原始代码验证是基于top5的检测结果来测试,而这里只有4个类别,所以容易超出报错
可以注释掉,或者改为3


@torch.no_grad()
def validate(config, data_loader, model):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc3_meter = AverageMeter()

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        output = model(images)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1,acc3 = accuracy(output,target,topk=(1,3))

        acc1 = reduce_tensor(acc1)
        acc3 = reduce_tensor(acc3)
        loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc3_meter.update(acc3.item(),target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Acc@3 {acc3_meter.val:.3f} ({acc3_meter.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@3 {acc3_meter.avg:3f}')
    return acc1_meter.avg,acc3_meter.avg,loss_meter.avg

可以跑了

python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py  --cfg .\configs\swin_tiny_patch4_window7_224.yaml   --data-path .\yhyw\  --local_rank 0  --batch-size 16  --output out1

测试
config.py 里增加

def just_config():
    config = _C.clone()
    return config

infer.py

import os

import cv2
import numpy as np
import torch.nn as nn
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torchvision.transforms import transforms
from config import just_config
from models.swin_transformer import SwinTransformer
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import PIL.Image as Im
cudnn.benchmark = True
device = 'cuda' if torch.cuda.is_available() else  'cpu'
config = just_config()
model = SwinTransformer(img_size=config.DATA.IMG_SIZE,
                        patch_size=config.MODEL.SWIN.PATCH_SIZE,
                        in_chans=config.MODEL.SWIN.IN_CHANS,
                        num_classes=config.MODEL.NUM_CLASSES,
                        embed_dim=config.MODEL.SWIN.EMBED_DIM,
                        depths=config.MODEL.SWIN.DEPTHS,
                        num_heads=config.MODEL.SWIN.NUM_HEADS,
                        window_size=config.MODEL.SWIN.WINDOW_SIZE,
                        mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
                        qkv_bias=config.MODEL.SWIN.QKV_BIAS,
                        qk_scale=config.MODEL.SWIN.QK_SCALE,
                        drop_rate=config.MODEL.DROP_RATE,
                        drop_path_rate=config.MODEL.DROP_PATH_RATE,
                        ape=config.MODEL.SWIN.APE,
                        patch_norm=config.MODEL.SWIN.PATCH_NORM,
                        use_checkpoint=config.TRAIN.USE_CHECKPOINT)


state_dicts = torch.load('E:\workspace\Swin-Transformer\out1\swin_tiny_patch4_window7_224\default\ckpt_epoch_399.pth')['model']
model.load_state_dict(state_dicts)
# print(model)
model.to(device)
process_img = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224,224)),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN,IMAGENET_DEFAULT_STD),
])
classes_name = {0:'0', 1:'1', 2:'2', 3:'3' }

img_dir = 'E:\workspace\Swin-Transformer\yanhuo\\testA\images'

testres = pd.DataFrame(columns=('image_name','label'))

model.eval()
with torch.no_grad():
    for img_name in os.listdir(img_dir):
        img_path = os.path.join(img_dir,img_name)
        img = cv2.imread(img_path)
        img = Im.fromarray(img)
        img = process_img(img)
        img = img.unsqueeze(0).to(device)
        out = model(img).cpu().numpy()
        # print(out.shape)
        index = np.argmax(out)
        print(img_name)
        print('predict the class of this pic is:{}'.format(classes_name[index]))
        testres = testres.append(pd.DataFrame({'image_name': [img_name], 'label': [classes_name[index]]}),ignore_index=True)

 

可以开始炼丹啦


image.png
tr-775.jpg
tr-282.jpg
tr-8.jpg
tr-562.jpg

参考

https://github.com/yihui8776/Swin-transformer/blob/main/get_started.md
https://blog.csdn.net/u014515463/article/details/80748125
https://blog.csdn.net/qq_33932782/article/details/117013133
https://blog.csdn.net/hi_gril/article/details/118486070?spm=1001.2014.3001.5501
https://blog.csdn.net/qq_42067064/article/details/115597482?spm=1001.2101.3001.6650.14&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7EBlogCommendFromBaidu%7Edefault-14.queryctr&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7EBlogCommendFromBaidu%7Edefault-14.queryctr&utm_relevant_index=20
https://blog.csdn.net/qq_36622589/article/details/117913064
http://data.sd.gov.cn/cmpt/cmptDetail.html?id=61

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,544评论 6 501
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,430评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 162,764评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,193评论 1 292
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,216评论 6 388
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,182评论 1 299
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,063评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,917评论 0 274
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,329评论 1 310
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,543评论 2 332
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,722评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,425评论 5 343
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,019评论 3 326
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,671评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,825评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,729评论 2 368
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,614评论 2 353

推荐阅读更多精彩内容