基于现有 TensorFlow 模型构建 Android 应用

在之前写的一篇文章 TensorFlow,从一个 Android Demo 开始 中通过编译官方的 Demo 接触到了 TensorFlow 实际使用场景。这篇文章打算从一个Android 开发者的角度切入,看看构建一个基于 TensorFlow 的 Android 应用的完整流程。

相关代码可查看:GitHub 项目地址

通过 TensorFlow 用已有模型构建 Android 应用

在 Google 的 TensorFlow examples project 中,有一个 Sample 叫作 TF Classify,它通过使用 Google Inception 模型对实时的相机图像帧进行分类,并显示展示当前图像的分类推断结果。

TF Classify

下面我们就基于这个现有模型,在 Android 平台上实现一个可以对物品进行分类的图像识别应用。

获取数据模型

这里可以直接下载 Google 提供的一个数据模型 inception5h.zip ,其中 .pb 后缀的文件是已经训练好的模型,而 .txt 对应的是训练数据包含的所有标签。

这个模型可对 1008 种物品识别分类,具体有哪些类可以查看标签信息,至于每个类别到底训练了多少张图片就不得而知了。

在 Android 项目中引入 TensorFlow

跟在项目中集成其他第三库一样,先在 build.gradle 中添加对 TensorFlow 的依赖。

compile 'org.tensorflow:tensorflow-android:1.6.0'

这里我们直接使用了 Google 为我们编译好的 TensorFlow 现成库了,如果你想自行对 TensorFlow 进行 NDK 交叉编译得到库文件也可以。

图像识别功能的实现

复制模型文件到项目 assets 文件夹:
如下图所示,我们在项目 assets 文件夹下创建一个 model 文件夹,并把之前下载的 inception5h.zip 解压后的全部文件复制到该文件夹下。

添加模型调用的相关类
因为我们要实现的功能和官方 demo 相似,只是训练的有所模型不同。既然对模型的使用方式是一样的,那这里就直接使用 Google demo 项目中提供的 Classifier.javaTensorFlowImageClassifier.java 这两个类来实现。

我们可以先跳过这部分内容的具体实现,等到对整体流程有个大致认识后再回过头来消化掉,这样可以更好地去理解。

这里我们重点关注下面两个方法,一个是 TensorFlowImageClassifier 的静态方法 create 方法:

   /**
     * Initializes a native TensorFlow session for classifying images.
     *
     * @param assetManager The asset manager to be used to load assets.
     * @param modelFilename The filepath of the model GraphDef protocol buffer.
     * @param labelFilename The filepath of label file for classes.
     * @param inputSize The input size. A square image of inputSize x inputSize is assumed.
     * @param imageMean The assumed mean of the image values.
     * @param imageStd The assumed std of the image values.
     * @param inputName The label of the image input node.
     * @param outputName The label of the output node.
     * @throws IOException
     */
    public static Classifier create(AssetManager assetManager, String modelFilename, String labelFilename,
            int inputSize, int imageMean, float imageStd, String inputName, String outputName) 

该方法需要传入模型相关的参数进行初始化,完成后返回一个 Classifier 实例。

通过 Classifier 对象,我们可以调用其 recognizeImage 方法来识别我们传入的 bitmap 图像数据,该方法会返回图像类别后对物品类别进行推断的标签结果:

/**
 * 进行图片识别
 */
 public List<Recognition> recognizeImage(final Bitmap bitmap) 

相关主要功能代码的实现:
相关代码可查看:GitHub 项目地址

public class MainActivity extends AppCompatActivity implements View.OnClickListener {
    ...
    
    // 模型相关配置
    private static final int INPUT_SIZE = 224;
    private static final int IMAGE_MEAN = 117;
    private static final float IMAGE_STD = 1;
    private static final String INPUT_NAME = "input";
    private static final String OUTPUT_NAME = "output";
    private static final String MODEL_FILE = "file:///android_asset/model/tensorflow_inception_graph.pb";
    private static final String LABEL_FILE = "file:///android_asset/model/imagenet_comp_graph_label_strings.txt";

    private Executor executor;
    private Uri currentTakePhotoUri;

    private TextView result;
    private ImageView ivPicture;
    private Classifier classifier;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);

        if (!isTaskRoot()) {
            finish();
        }

        setContentView(R.layout.activity_main);

        findViewById(R.id.iv_choose_picture).setOnClickListener(this);
        findViewById(R.id.iv_take_photo).setOnClickListener(this);

        ivPicture = findViewById(R.id.iv_picture);
        result = findViewById(R.id.tv_classifier_info);

        // 避免耗时任务占用 CPU 时间片造成UI绘制卡顿,提升启动页面加载速度
        Looper.myQueue().addIdleHandler(idleHandler);

    }

    /**
     *  主线程消息队列空闲时(视图第一帧绘制完成时)处理耗时事件
     */
    MessageQueue.IdleHandler idleHandler = new MessageQueue.IdleHandler() {
        @Override
        public boolean queueIdle() {
            // 初始化 Classifier
            if (classifier == null) {
                // 创建 TensorFlowImageClassifier
               classifier = TensorFlowImageClassifier.create(MainActivity.this.getAssets(),
                       MODEL_FILE, LABEL_FILE, INPUT_SIZE, IMAGE_MEAN, IMAGE_STD, INPUT_NAME, OUTPUT_NAME);
            }

            // 初始化线程池
            executor = new ScheduledThreadPoolExecutor(1, new ThreadFactory() {
                @Override
                public Thread newThread(@NonNull Runnable r) {
                    Thread thread = new Thread(r);
                    thread.setDaemon(true);
                    thread.setName("ThreadPool-ImageClassifier");
                    return thread;
                }
            });
            // 请求权限
            requestMultiplePermissions();
            // 返回 false 时只会回调一次
            return false;
        }
    };

    @Override
    public void onClick(View view) {
        switch (view.getId()) {
            case R.id.iv_choose_picture :
                choosePicture();
                break;
            case R.id.iv_take_photo :
                takePhoto();
                break;
            default:break;
        }
    }

    /**
     * 选择一张图片并裁剪获得一个小图
     */
    private void choosePicture() {
        Intent intent = new Intent(Intent.ACTION_GET_CONTENT);
        intent.setType("image/*");
        startActivityForResult(intent, PICTURE_REQUEST_CODE);
    }

    /**
     * 使用系统相机拍照
     */
    private void takePhoto() {
        if (ContextCompat.checkSelfPermission(this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
            ActivityCompat.requestPermissions(this, new String[]{Manifest.permission.CAMERA}, CAMERA_PERMISSIONS_REQUEST_CODE);
        } else {
            openSystemCamera();
        }
    }

    /**
     * 打开系统相机
     */
    private void openSystemCamera() {
        //调用系统相机
        Intent takePhotoIntent = new Intent();
        takePhotoIntent.setAction(MediaStore.ACTION_IMAGE_CAPTURE);

        //这句作用是如果没有相机则该应用不会闪退,要是不加这句则当系统没有相机应用的时候该应用会闪退
        if (takePhotoIntent.resolveActivity(getPackageManager()) == null) {
            Toast.makeText(this, "当前系统没有可用的相机应用", Toast.LENGTH_SHORT).show();
            return;
        }

        String fileName = "TF_" + System.currentTimeMillis() + ".jpg";
        File photoFile = new File(FileUtil.getPhotoCacheFolder(), fileName);

        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.N) {
            //通过FileProvider创建一个content类型的Uri
            currentTakePhotoUri = FileProvider.getUriForFile(this, "gdut.bsx.tensorflowtraining.fileprovider", photoFile);
            //对目标应用临时授权该 Uri 所代表的文件
            takePhotoIntent.addFlags(Intent.FLAG_GRANT_READ_URI_PERMISSION);
        } else {
            currentTakePhotoUri = Uri.fromFile(photoFile);
        }

        //将拍照结果保存至 outputFile 的Uri中,不保留在相册中
        takePhotoIntent.putExtra(MediaStore.EXTRA_OUTPUT, currentTakePhotoUri);
        startActivityForResult(takePhotoIntent, TAKE_PHOTO_REQUEST_CODE);
    }

    /**
     * 处理图片
     * @param imageUri
     */
    private void handleInputPhoto(Uri imageUri) {
        // 加载图片
        GlideApp.with(MainActivity.this).asBitmap().listener(new RequestListener<Bitmap>() {

            @Override
            public boolean onLoadFailed(@Nullable GlideException e, Object model, Target<Bitmap> target, boolean isFirstResource) {
                Log.d(TAG,"handleInputPhoto onLoadFailed");
                Toast.makeText(MainActivity.this, "图片加载失败", Toast.LENGTH_SHORT).show();
                return false;
            }

            @Override
            public boolean onResourceReady(Bitmap resource, Object model, Target<Bitmap> target, DataSource dataSource, boolean isFirstResource) {
                Log.d(TAG,"handleInputPhoto onResourceReady");
                startImageClassifier(resource);
                return false;
            }
        }).load(imageUri).into(ivPicture);

        result.setText("Processing...");
    }

    /**
     * 开始图片识别匹配
     * @param bitmap
     */
    private void startImageClassifier(final Bitmap bitmap) {
        executor.execute(new Runnable() {
            @Override
            public void run() {
                try {
                    Log.i(TAG, Thread.currentThread().getName() + " startImageClassifier");
                    Bitmap croppedBitmap = getScaleBitmap(bitmap, INPUT_SIZE);

                    final List<Classifier.Recognition> results = classifier.recognizeImage(croppedBitmap);
                    Log.i(TAG, "startImageClassifier results: " + results);
                    runOnUiThread(new Runnable() {
                        @Override
                        public void run() {
                            result.setText(String.format("results: %s", results));
                        }
                    });
                } catch (IOException e) {
                    Log.e(TAG, "startImageClassifier getScaleBitmap " + e.getMessage());
                }
            }
        });
    }

   /**
     * 请求相机和外部存储权限
     */
    private void requestMultiplePermissions() {

        String storagePermission = Manifest.permission.WRITE_EXTERNAL_STORAGE;
        String cameraPermission = Manifest.permission.CAMERA;

        int hasStoragePermission = ActivityCompat.checkSelfPermission(this, storagePermission);
        int hasCameraPermission = ActivityCompat.checkSelfPermission(this, cameraPermission);

        List<String> permissions = new ArrayList<>();
        if (hasStoragePermission != PackageManager.PERMISSION_GRANTED) {
            permissions.add(storagePermission);
        }

        if (hasCameraPermission != PackageManager.PERMISSION_GRANTED) {
            permissions.add(cameraPermission);
        }
        
        if (!permissions.isEmpty()) {
            String[] params = permissions.toArray(new String[permissions.size()]);
            ActivityCompat.requestPermissions(this, params, PERMISSIONS_REQUEST);
        }
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        super.onActivityResult(requestCode, resultCode, data);

        if (resultCode == RESULT_OK) {
            if (requestCode == PICTURE_REQUEST_CODE) {
                // 处理选择的图片
                handleInputPhoto(data.getData());
            } else if (requestCode == OPEN_SETTING_REQUEST_COED){
                requestMultiplePermissions();
            } else if (requestCode == TAKE_PHOTO_REQUEST_CODE) {
                // 如果拍照成功,加载图片并识别
                handleInputPhoto(currentTakePhotoUri);
            }
        }
    }

    /**
     * 对图片进行缩放
     * @param bitmap
     * @param size
     * @return
     * @throws IOException
     */
    private static Bitmap getScaleBitmap(Bitmap bitmap, int size) throws IOException {
        int width = bitmap.getWidth();
        int height = bitmap.getHeight();
        float scaleWidth = ((float) size) / width;
        float scaleHeight = ((float) size) / height;
        Matrix matrix = new Matrix();
        matrix.postScale(scaleWidth, scaleHeight);
        return Bitmap.createBitmap(bitmap, 0, 0, width, height, matrix, true);
    }
}

运行效果

图片选择和拍照获取界面:

物品识别结果展示界面:

相关代码可查看:GitHub 项目地址

是不是觉得通过 TensorFlow 在现有的数据模型基础下,我们可以很简单就完成了一个简单的图像识别应用。

在使用这个模型来推断物品类型的过程中,发现好像有时候准确率不是那么高,这时候改怎么办。如果说只是想识别一些特定种类的物品,哪有又该怎么办?

在之前一篇文章中我有提到过,机器学习是依靠对大量有标签的样本数据进行反复训练后才逐步得到的最佳模型。对未知无标签样本的推断依赖这个模型的准确程度。所以我们可以通过对现有模型进行迁移训练(retrain)来定制我们自己的模型。

下面就通过对现有的 Google Inception-V3 模型进行 retrain ,对 5 种花朵样本数据的进行训练,来完成一个可以识别五种花朵的模型。

具体实现方式可以参考我的另外一篇文章:通过迁移训练来定制 TensorFlow 模型

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 219,701评论 6 508
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 93,649评论 3 396
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 166,037评论 0 356
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 58,994评论 1 295
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 68,018评论 6 395
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 51,796评论 1 308
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 40,481评论 3 420
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 39,370评论 0 276
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 45,868评论 1 319
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 38,014评论 3 338
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 40,153评论 1 352
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 35,832评论 5 346
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 41,494评论 3 331
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 32,039评论 0 22
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 33,156评论 1 272
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 48,437评论 3 373
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 45,131评论 2 356

推荐阅读更多精彩内容