djl入门-迁移学习实战之车型识别(附源码)

一句话解释什么是迁移学习: 即小样本下我们也可以搞图片分类识别, 让你站在巨人的肩膀上去做你想干的事。
很多情况 下我们使用深度学习模型是要解决实际问题的,比如我们接下来要用图片分类模型去识别车型,如果从0去训练模型,不仅仅要解决大量样本标注问题,还要进行各种模型调参以使模型最优,那么有没有比较简单的方式将已有 的模型快速应用到我们的业务数据上进行预测呢?

介绍

现在我们将通过使用迁移学习的方式做一个图片分类模型,迁移学习是一种流行的深度学习技术,它可以快速将已有的高精度模型应用在其它的业务场景里,与从头训练一个模型相比,这种方式能让你快速实现一个健壮的、准确的模型
接下来我们会用ResNet18做迁移学习 去预测10款汽车,resnet是一个非常强悍的模型,它包含18层神经网络,使用ImageNet数据集经过120万张图片训练得到,支持1000个类别的识别预测

数据准备

这里将通过Jsoup爬取爱卡汽车图片数据,将爬取的数据分别放到 /data/cars/车名 目录下;将会爬取如下10款汽车

"奔驰","宝马","奥迪","别克","日产","大众","福特","红旗","丰田","本田"

---data
   ---cars
       ---奔驰
            ---bc1.jpg
            ---bc2.jpg
            ....
       ---宝马
            ---bmw1.jpg
            ---bmw2.jpg
            ....

爬取代码:


  public static void main(String[] args) throws Exception {
        String search="http://sou.xcar.com.cn/XcarSearch/car/find/keyword/%s/pbid/none/chexiLevel/none/priceLevel/70_/sort/down/pageNO/1/limit/50?rand=1591882856094";

        String basePath = "/data/cars/";
        String[] cars = new String[]{"奔驰","宝马","奥迪","别克","日产","大众","福特","红旗","丰田","本田"};

        for(String car: cars) {
            Document doc = Jsoup.connect(String.format(search,car))
                    .timeout(5000).get();

            //获取搜索结果页Json列表
            String bd =  doc.body().text().replace("findcar(","");
            bd = bd.substring(0,bd.length()-1);
            JSONObject json= JSON.parseObject(bd);
            
            JSONArray jsonCars = json.getJSONArray("spserList");
            for(int i=0; i<jsonCars.size();i++){
                JSONObject jsonCar = jsonCars.getJSONObject(i);
                String imgUrl = jsonCar.getString("purl");
                String persid = jsonCar.getString("persid");

                //下载图片
                File f = new File(basePath+car+imgUrl.substring(imgUrl.lastIndexOf("/")));
                FileUtil.mkParentDirs(f);
                ImgUtil.write(ImgUtil.read(new URL(imgUrl)), f);

//                ImageDownloaderUtil.downLoadImage(imgUrl,basePath+car+imgUrl.substring(imgUrl.lastIndexOf("/")));

                //解析详情页 获取更多外观图片数据
                doc = Jsoup.connect(String.format("http://newcar.xcar.com.cn/photo/ps%s-s_1/",persid)).timeout(4000).get();
                HtmlCleaner cleaner = new HtmlCleaner();
                //转化成TagNode
                TagNode node = cleaner.clean(doc.html());
                //通过XPath解析出图片地址
                Object[] ns2 = node.evaluateXPath("//div[@class='pic-wrap']/div[@class='pic-con']/dl/dt/a/img");
                for (Object on : ns2) {
                    TagNode n = (TagNode) on;
                    imgUrl = "http:"+n.getAttributeByName("src");
//                    ImageDownloaderUtil.downLoadImage(imgUrl,basePath+car+imgUrl.substring(imgUrl.lastIndexOf("/")));
                    f = new File(basePath+car+imgUrl.substring(imgUrl.lastIndexOf("/")));
                    ImgUtil.write(ImgUtil.read(new URL(imgUrl)), f);
                }
            }
            Thread.sleep(3000);

        }
    }

数据地址:链接: https://pan.baidu.com/s/16AHf-zJzcjWuGwokECKXYg 提取码: tqvv

下载ResNet模型

百度网盘下载resnet18,并将下载好的模型解压到/data/models/resnet


扫码下载
resnet18

重构模型

1、加载已有模型
2、移除原有模型分类输出层,并添加新的分类输出层

//加载resnet18
private Model getModel() throws IOException, MalformedModelException {
        Path modelDir = Paths.get("/data/models/resnet");
        Model model = Model.newInstance(Device.cpu(),"MXNet");
        model.load(modelDir, "resnet18_v1");
        return model;
 }
//删除resnet18的全连接层,并根据自己需要添加新的分类输出层
 private void prepareModel(Model old){
        SequentialBlock newBlock = new SequentialBlock();
        SymbolBlock block = (SymbolBlock) old.getBlock();
        block.removeLastBlock();
        newBlock.add(block);
        newBlock.add(x -> new NDList(x.singletonOrThrow().squeeze()));
        newBlock.add(Linear.builder().setOutChannels(10).build());
        newBlock.add(Blocks.batchFlattenBlock());
        old.setBlock(newBlock);
 }

模型训练

1、指定GPU个数、配置模型参数 如优化器和评估函数等
2、读取图片数据集
3、设置迭代次数进行模型训练
4、模型评估

private void train(Model model ,int epoch) throws IOException {
        DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy()) // Use accuracy so we humans can understand how accurate the model is
                .optDevices(Device.getDevices(1)) // Limit your GPU, using more GPU actually will slow down coverging
                .addTrainingListeners(TrainingListener.Defaults.logging());

        Trainer trainer = model.newTrainer(config);

        for (int i = 0; i < epoch; ++i) {
            int index = 0;
            for (Batch batch : trainer.iterateDataset(getImgDataSet("train",dataPath))) {
                trainer.trainBatch(batch);
                trainer.step();
                batch.close();
            }
            // reset training and validation evaluators at end of epoch
            trainer.endEpoch();
        }
}

根据下图可以看出使用迁移学习的方式进行模型训练,仅仅执行两轮训练准度已接近70%


image.png

模型测试

1、设置Translator,使的预测数据和训练数据处理方式保持一致,并将 lable Id 映射为分类
2、将刚才训练好的模型加载到应用 程序中
3、构造预测器,并对单一或批量图片进行预测,输出分类结果

public static void predic(String imagePath) throws IOException, MalformedModelException, TranslateException {
        BufferedImage image;
        if (imagePath.startsWith("http")) {
            image = BufferedImageUtils.fromUrl(new URL(imagePath));
        } else {
            image = BufferedImageUtils.fromFile(Paths.get(imagePath));
        }

        Pipeline pipeline = new Pipeline()
                .add(new CenterCrop())
                .add(new Resize(224))
                .add(new ToTensor())
                .add(new Normalize(new float[] {0.4914f, 0.4822f, 0.4465f}, new float[] {0.2023f, 0.1994f, 0.2010f}));
        //对图片数据进行预处理
        ImageClassificationTranslator translator = ImageClassificationTranslator.builder()
                .setPipeline(pipeline)
                .setSynsetArtifactName("synset.txt")
                .optApplySoftmax(true)
                .build();

        Path modelDir = Paths.get(modelPath);
        Model model = Model.newInstance(Device.cpu(),"MXNet");
        model.load(modelDir, modelName);
        Predictor<BufferedImage, Classifications> predictor = model.newPredictor(translator);
        Classifications classifications = predictor.predict(image);
        System.out.println(classifications);
}

线上部署

这里我使用spark web框架快速开发了一个图片预测Api,让大家直观感受一下迁移学习在未知场景的泛化效果

   public static void main(String[] args) {

        String repTemp = "<!DOCTYPE html><html lang=\"en\"><head>    <meta charset=\"utf-8\">    <style type=\"text/css\">        .content {            color: #ffffff;            font-size: 40px;        }        .bg {            background: url('${img}');            background-repeat: no-repeat;            background-position: center;            background-size: cover;            height:600px;            text-align: center;            line-height: 600px;        }    </style></head><body><div class=\"bg\">    <div class=\"content\">${txt}</div></div></body></html>";

        port(8899);
         
        get("/img_classes/cars/predict", (request, response) -> {
           
            return repTemp.replace("${img}",request.queryParams("img_url")).replace("${txt}",TransferLearning.predict(request.queryParams("img_url")));
            
        });
    }

查看效果

image.png

谁知道下面这辆是什么车? 下载模型跑跑试试看看

image.png

http://localhost:8899/img_classes/cars/predict?img_url=https://car3.autoimg.cn/cardfs/product/g24/M06/89/22/1024x0_1_q95_autohomecar__ChwFjl6zaqOAEtu7AAa92Uw57ys354.jpg

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 216,744评论 6 502
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 92,505评论 3 392
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 163,105评论 0 353
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,242评论 1 292
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 67,269评论 6 389
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,215评论 1 299
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,096评论 3 418
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 38,939评论 0 274
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,354评论 1 311
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 37,573评论 2 333
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 39,745评论 1 348
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,448评论 5 344
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,048评论 3 327
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 31,683评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,838评论 1 269
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 47,776评论 2 369
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 44,652评论 2 354