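Transfer learning in one sentence: it lets you build an image classifier even with a small number of samples, by standing on the shoulders of giants to get your own job done.
In most cases we use deep learning models to solve a practical problem. In this article, for example, we want an image classification model that recognizes cars. Training such a model from scratch means labeling a large number of samples and then tuning the model extensively to get good results. Is there a simpler way to quickly apply an existing model to our own business data?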
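Introduction
We will build an image classification model with transfer learning. Transfer learning is a popular deep learning technique that quickly adapts an existing high-accuracy model to a different business scenario. Compared with training a model from scratch, it gets you to a robust, accurate model much faster.
We will use ResNet18 as the base model and retrain it to predict 10 car brands. ResNet is a very strong family of models. ResNet18 is an 18-layer neural network trained on about 1.2 million images from the ImageNet dataset, and it recognizes 1000 classes.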
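Data preparation
The images are crawled from xcar.com.cn (爱卡汽车) with Jsoup and stored under /data/cars/<brand name>, one directory per brand. The following 10 brands are crawled (Mercedes-Benz, BMW, Audi, Buick, Nissan, Volkswagen, Ford, Hongqi, Toyota, Honda). The Chinese names are used both as search keywords and as directory names: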
"奔驰","宝马","奥迪","别克","日产","大众","福特","红旗","丰田","本田"
---data
    ---cars
        ---奔驰
            ---bc1.jpg
            ---bc2.jpg
            ....
        ---宝马
            ---bmw1.jpg
            ---bmw2.jpg
            ....
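With this layout, each sub-directory name acts as the class label for the images inside it. The following is a minimal sketch, using only the JDK, of how such a tree can be scanned to check the labels and the number of images per class before training. The class `CarFolderScanner` and its methods are hypothetical and not part of the original project, and the demo uses ASCII folder names in a temporary directory rather than the real /data/cars tree.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Hypothetical helper, not part of the original project.
final class CarFolderScanner {

    // Lists the entries of a directory and closes the underlying stream.
    private static List<Path> children(Path dir) {
        try (Stream<Path> s = Files.list(dir)) {
            return s.collect(Collectors.toList());
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Maps each sub-directory name (the class label) to its number of .jpg files.
    static Map<String, Integer> countImagesPerLabel(Path root) {
        Map<String, Integer> counts = new TreeMap<>();
        for (Path dir : children(root)) {
            if (!Files.isDirectory(dir)) {
                continue; // stray files next to the brand folders are ignored
            }
            int n = 0;
            for (Path f : children(dir)) {
                if (f.getFileName().toString().toLowerCase().endsWith(".jpg")) {
                    n++;
                }
            }
            counts.put(dir.getFileName().toString(), n);
        }
        return counts;
    }

    // Builds a tiny throwaway tree (ASCII names) for demonstration only.
    static Path createSampleTree() {
        try {
            Path root = Files.createTempDirectory("cars");
            Files.createDirectories(root.resolve("benz"));
            Files.createDirectories(root.resolve("bmw"));
            Files.createFile(root.resolve("benz").resolve("bc1.jpg"));
            Files.createFile(root.resolve("bmw").resolve("bmw1.jpg"));
            Files.createFile(root.resolve("bmw").resolve("bmw2.jpg"));
            Files.createFile(root.resolve("bmw").resolve("notes.txt"));
            return root;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(countImagesPerLabel(createSampleTree()));
    }
}
```

Running the same scan against /data/cars after crawling is a quick way to spot empty or badly unbalanced classes.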
Crawler code:
// Imports assumed from the calls used below: Jsoup, fastjson, HtmlCleaner,
// and Hutool (FileUtil and ImgUtil appear to be Hutool utilities).
import cn.hutool.core.img.ImgUtil;
import cn.hutool.core.io.FileUtil;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import org.htmlcleaner.HtmlCleaner;
import org.htmlcleaner.TagNode;
import org.jsoup.Jsoup;
import org.jsoup.nodes.Document;
import java.io.File;
import java.net.URL;

public static void main(String[] args) throws Exception {
    String search="http://sou.xcar.com.cn/XcarSearch/car/find/keyword/%s/pbid/none/chexiLevel/none/priceLevel/70_/sort/down/pageNO/1/limit/50?rand=1591882856094";
    String basePath = "/data/cars/";
    String[] cars = new String[]{"奔驰","宝马","奥迪","别克","日产","大众","福特","红旗","丰田","本田"};
    for (String car : cars) {
        Document doc = Jsoup.connect(String.format(search, car))
                .timeout(5000).get();
        // The search result is JSONP: strip the "findcar(" prefix and the trailing ")" to get the JSON list
        String bd = doc.body().text().replace("findcar(", "");
        bd = bd.substring(0, bd.length() - 1);
        JSONObject json = JSON.parseObject(bd);
        JSONArray jsonCars = json.getJSONArray("spserList");
        for (int i = 0; i < jsonCars.size(); i++) {
            JSONObject jsonCar = jsonCars.getJSONObject(i);
            String imgUrl = jsonCar.getString("purl");
            String persid = jsonCar.getString("persid");
            // Download the cover image into /data/cars/<brand>/
            File f = new File(basePath + car + imgUrl.substring(imgUrl.lastIndexOf("/")));
            FileUtil.mkParentDirs(f);
            ImgUtil.write(ImgUtil.read(new URL(imgUrl)), f);
            // ImageDownloaderUtil.downLoadImage(imgUrl,basePath+car+imgUrl.substring(imgUrl.lastIndexOf("/")));
            // Parse the detail page to collect more exterior photos
            doc = Jsoup.connect(String.format("http://newcar.xcar.com.cn/photo/ps%s-s_1/", persid)).timeout(4000).get();
            HtmlCleaner cleaner = new HtmlCleaner();
            // Convert the page to a TagNode tree
            TagNode node = cleaner.clean(doc.html());
            // Extract the image URLs with XPath
            Object[] ns2 = node.evaluateXPath("//div[@class='pic-wrap']/div[@class='pic-con']/dl/dt/a/img");
            for (Object on : ns2) {
                TagNode n = (TagNode) on;
                // The src attribute is protocol-relative, so prepend the scheme
                imgUrl = "http:" + n.getAttributeByName("src");
                // ImageDownloaderUtil.downLoadImage(imgUrl,basePath+car+imgUrl.substring(imgUrl.lastIndexOf("/")));
                f = new File(basePath + car + imgUrl.substring(imgUrl.lastIndexOf("/")));
                ImgUtil.write(ImgUtil.read(new URL(imgUrl)), f);
            }
        }
        // Pause between brands to avoid hammering the site
        Thread.sleep(3000);
    }
}
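The crawler relies on two small string operations: unwrapping the JSONP response (`findcar(...)`) to get plain JSON, and deriving a local file name from the image URL. Below is a slightly more defensive sketch of both, using only the JDK. `CrawlerStrings`, `unwrapJsonp` and `fileNameFromUrl` are hypothetical names introduced for illustration, and the sample URL is made up.

```java
// Hypothetical helpers, not part of the original crawler.
final class CrawlerStrings {

    // Strips "callback(" ... ")" (and an optional trailing ";") from a JSONP body.
    static String unwrapJsonp(String body, String callback) {
        String s = body.trim();
        if (s.endsWith(";")) {
            s = s.substring(0, s.length() - 1).trim();
        }
        String prefix = callback + "(";
        if (!s.startsWith(prefix) || !s.endsWith(")")) {
            throw new IllegalArgumentException("Not a JSONP response for callback " + callback);
        }
        return s.substring(prefix.length(), s.length() - 1);
    }

    // Returns the last path segment of a URL, ignoring any query string.
    static String fileNameFromUrl(String url) {
        int q = url.indexOf('?');
        String path = q >= 0 ? url.substring(0, q) : url;
        return path.substring(path.lastIndexOf('/') + 1);
    }

    public static void main(String[] args) {
        System.out.println(unwrapJsonp("findcar({\"spserList\":[]})", "findcar"));
        System.out.println(fileNameFromUrl("http://img.example.com/photos/car1.jpg?size=big"));
    }
}
```

Failing fast on an unexpected response body makes it easier to notice when the site changes its format, instead of passing a half-stripped string to the JSON parser.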
Dataset: https://pan.baidu.com/s/16AHf-zJzcjWuGwokECKXYg (extraction code: tqvv)
Download the ResNet model
Download resnet18 from Baidu Netdisk and extract the model to /data/models/resnet
Rebuilding the model
1. Load the existing model
2. Remove the original classification output layer and add a new one
// Load the pretrained resnet18
private Model getModel() throws IOException, MalformedModelException {
    Path modelDir = Paths.get("/data/models/resnet");
    Model model = Model.newInstance(Device.cpu(), "MXNet");
    model.load(modelDir, "resnet18_v1");
    return model;
}
// Remove resnet18's fully connected layer and add a new classification output layer for our own classes
private void prepareModel(Model old) {
    SequentialBlock newBlock = new SequentialBlock();
    SymbolBlock block = (SymbolBlock) old.getBlock();
    // Drop the original 1000-class output layer
    block.removeLastBlock();
    newBlock.add(block);
    // Remove the size-1 dimensions left by the pooling layer
    newBlock.add(x -> new NDList(x.singletonOrThrow().squeeze()));
    // New output layer: one output per car brand
    newBlock.add(Linear.builder().setOutChannels(10).build());
    newBlock.add(Blocks.batchFlattenBlock());
    old.setBlock(newBlock);
}
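To make the role of the new output layer concrete: after the original classifier is removed, ResNet18 produces a 512-dimensional feature vector per image, and the new linear layer maps that vector to 10 scores, one per brand. The sketch below shows that mapping in plain Java, without DJL. `LinearHead` is a hypothetical class for illustration, and the weights in the demo are arbitrary placeholders rather than trained values.

```java
// Hypothetical plain-Java illustration of a fully connected output layer.
final class LinearHead {
    private final float[][] weight; // shape: [outChannels][inFeatures]
    private final float[] bias;     // shape: [outChannels]

    LinearHead(float[][] weight, float[] bias) {
        if (weight.length != bias.length) {
            throw new IllegalArgumentException("weight rows must match bias length");
        }
        this.weight = weight;
        this.bias = bias;
    }

    // Computes scores[o] = bias[o] + sum_i weight[o][i] * features[i].
    float[] forward(float[] features) {
        float[] scores = new float[weight.length];
        for (int o = 0; o < weight.length; o++) {
            if (weight[o].length != features.length) {
                throw new IllegalArgumentException("feature size mismatch");
            }
            float sum = bias[o];
            for (int i = 0; i < features.length; i++) {
                sum += weight[o][i] * features[i];
            }
            scores[o] = sum;
        }
        return scores;
    }

    public static void main(String[] args) {
        // 512 input features (ResNet18's pooled feature size) and 10 outputs (car brands).
        LinearHead head = new LinearHead(new float[10][512], new float[10]);
        System.out.println(head.forward(new float[512]).length);
    }
}
```

Only this small layer starts from scratch; the rest of the network begins from the pretrained weights, which is why a small dataset can be enough.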
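Training the model
1. Specify the number of GPUs and configure training parameters such as the optimizer and evaluator
2. Load the image dataset
3. Set the number of epochs and train the model
4. Evaluate the model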
private void train(Model model, int epoch) throws IOException {
    DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
            .addEvaluator(new Accuracy()) // Use accuracy so we humans can understand how accurate the model is
            .optDevices(Device.getDevices(1)) // Limit to one GPU; using more GPUs can actually slow down convergence
            .addTrainingListeners(TrainingListener.Defaults.logging());
    Trainer trainer = model.newTrainer(config);
    for (int i = 0; i < epoch; ++i) {
        // getImgDataSet and dataPath are not shown in this article
        for (Batch batch : trainer.iterateDataset(getImgDataSet("train", dataPath))) {
            trainer.trainBatch(batch);
            trainer.step();
            batch.close();
        }
        // reset training and validation evaluators at end of epoch
        trainer.endEpoch();
    }
}
As the training output shows, with transfer learning the accuracy is already close to 70% after only two epochs of training.
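For reference, the accuracy figure is simply the fraction of images whose highest-scoring class matches the true label. A minimal plain-Java sketch of that calculation follows. `AccuracySketch` is a hypothetical class, not DJL's `Accuracy` evaluator, and the scores in the demo are made-up numbers.

```java
// Hypothetical illustration of top-1 accuracy, independent of DJL.
final class AccuracySketch {

    // Index of the largest score (first one wins on ties).
    static int argMax(float[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    // Fraction of samples whose predicted class equals the label.
    static double accuracy(float[][] scores, int[] labels) {
        if (scores.length != labels.length || labels.length == 0) {
            throw new IllegalArgumentException("need one label per sample");
        }
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (argMax(scores[i]) == labels[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    public static void main(String[] args) {
        float[][] scores = {{0.1f, 0.9f}, {0.8f, 0.2f}, {0.3f, 0.7f}};
        int[] labels = {1, 0, 0}; // the third sample is misclassified
        System.out.println(accuracy(scores, labels));
    }
}
```

With 10 classes, random guessing would score about 10%, which puts the accuracy reported after two epochs in context.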
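Testing the model
1. Set up the Translator so that prediction data is preprocessed the same way as the training data, and map each label ID to its class name
2. Load the newly trained model into the application
3. Build a predictor, run prediction on a single image or a batch of images, and output the classification result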
// Returns the result as a String so it can be used directly by the web API below
public static String predict(String imagePath) throws IOException, MalformedModelException, TranslateException {
    BufferedImage image;
    if (imagePath.startsWith("http")) {
        image = BufferedImageUtils.fromUrl(new URL(imagePath));
    } else {
        image = BufferedImageUtils.fromFile(Paths.get(imagePath));
    }
    // Preprocess the image; this must match the preprocessing used during training
    Pipeline pipeline = new Pipeline()
            .add(new CenterCrop())
            .add(new Resize(224))
            .add(new ToTensor())
            .add(new Normalize(new float[] {0.4914f, 0.4822f, 0.4465f}, new float[] {0.2023f, 0.1994f, 0.2010f}));
    // synset.txt maps each label ID to its class name
    ImageClassificationTranslator translator = ImageClassificationTranslator.builder()
            .setPipeline(pipeline)
            .setSynsetArtifactName("synset.txt")
            .optApplySoftmax(true)
            .build();
    // modelPath and modelName are fields not shown in this article
    Path modelDir = Paths.get(modelPath);
    Model model = Model.newInstance(Device.cpu(), "MXNet");
    model.load(modelDir, modelName);
    Predictor<BufferedImage, Classifications> predictor = model.newPredictor(translator);
    Classifications classifications = predictor.predict(image);
    System.out.println(classifications);
    return classifications.toString();
}
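The translator's last step turns the 10 raw scores into probabilities with softmax and maps the winning index to a class name from synset.txt. The sketch below reproduces that step in plain Java so the mapping is easy to follow. `SoftmaxSketch` is a hypothetical class, not DJL code, and the demo uses romanized labels where the real synset.txt would hold the directory names used for training.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical illustration of softmax plus label lookup, independent of DJL.
final class SoftmaxSketch {

    // Numerically stable softmax: subtract the maximum before exponentiating.
    static double[] softmax(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : logits) {
            max = Math.max(max, v);
        }
        double[] out = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    // Returns the class name whose probability is highest.
    static String bestLabel(double[] logits, List<String> synset) {
        if (logits.length != synset.size()) {
            throw new IllegalArgumentException("one class name per score is required");
        }
        double[] probs = softmax(logits);
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) {
                best = i;
            }
        }
        return synset.get(best);
    }

    public static void main(String[] args) {
        List<String> synset = Arrays.asList("benz", "bmw", "audi");
        double[] logits = {1.0, 3.0, 2.0};
        System.out.println(bestLabel(logits, synset) + " " + Arrays.toString(softmax(logits)));
    }
}
```

The order of names in synset.txt has to match the label IDs assigned during training, otherwise the predictions are reported under the wrong brand.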
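Online deployment
Here I used the Spark web framework (the Java micro framework, not Apache Spark) to quickly build an image prediction API, so you can see for yourself how well transfer learning generalizes to images it has never seen.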
public static void main(String[] args) {
    String repTemp = "<!DOCTYPE html><html lang=\"en\"><head> <meta charset=\"utf-8\"> <style type=\"text/css\"> .content { color: #ffffff; font-size: 40px; } .bg { background: url('${img}'); background-repeat: no-repeat; background-position: center; background-size: cover; height:600px; text-align: center; line-height: 600px; } </style></head><body><div class=\"bg\"> <div class=\"content\">${txt}</div></div></body></html>";
    port(8899);
    get("/img_classes/cars/predict", (request, response) -> {
        return repTemp.replace("${img}",request.queryParams("img_url")).replace("${txt}",TransferLearning.predict(request.queryParams("img_url")));
    });
}
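The API takes the image address in the `img_url` query parameter, so the address has to be URL-encoded when the request is built. The sketch below shows one way to construct such a request URL with the JDK only. `PredictUrlBuilder` is a hypothetical helper, and the host name and sample image address are assumptions for a local test.

```java
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;

// Hypothetical client-side helper for calling the prediction API.
final class PredictUrlBuilder {

    // Builds the request URL, encoding the image address as a query parameter value.
    static String build(String host, int port, String imageUrl) {
        try {
            return "http://" + host + ":" + port
                    + "/img_classes/cars/predict?img_url="
                    + URLEncoder.encode(imageUrl, "UTF-8");
        } catch (UnsupportedEncodingException e) {
            // UTF-8 is always available on the JVM, so this should not happen
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(build("localhost", 8899, "http://example.com/cars/a b.jpg?x=1&y=2"));
    }
}
```

Without encoding, characters such as `&` or `?` inside the image address would be parsed as part of the API's own query string.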
Checking the result
Can anyone tell what car this is? Download the model and try it yourself.