A Simple TFLite Integration for Image Recognition on Android

Background

TensorFlow Lite, released by Google, lets app developers deploy AI models on mobile devices, enabling on-device features such as image recognition and face recognition.

Integration Steps

1. Copy the .tflite models into the project

I grabbed two models from the web: face_landmark.tflite for face landmark detection, and coco_ssd_mobilenet_v1_1.0_quant.tflite for object detection. Put them in the app's assets directory so FileUtil can load them.
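One gotcha when bundling the model: Gradle compresses assets by default, but FileUtil.loadMappedFile needs the .tflite file stored uncompressed so it can be memory-mapped. A common fix (the exact DSL may vary with your Android Gradle Plugin version) is to disable compression for that extension:

```groovy
android {
    aaptOptions {
        // Keep .tflite assets uncompressed so they can be memory-mapped at load time
        noCompress "tflite"
    }
}
```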

2. Add the TFLite and CameraX dependencies
    def camerax_version = "1.1.0-beta01"
    implementation "androidx.camera:camera-core:${camerax_version}"
    implementation "androidx.camera:camera-camera2:${camerax_version}"
    implementation "androidx.camera:camera-lifecycle:${camerax_version}"
    implementation "androidx.camera:camera-view:${camerax_version}"
    implementation "androidx.camera:camera-video:${camerax_version}"


    // Tensorflow lite dependencies
    implementation 'org.tensorflow:tensorflow-lite:2.9.0'
    implementation 'org.tensorflow:tensorflow-lite-gpu:2.9.0'
    implementation 'org.tensorflow:tensorflow-lite-support:0.4.2'

    implementation 'androidx.lifecycle:lifecycle-runtime-ktx:2.3.1'

For the recognition pipeline, CameraX (a Jetpack wrapper over Camera2) supplies the analysis frames, which are then fed into TFLite for inference.

3. Wrap inference in a utility class, DetectHelper
public class DetectHelper {

    private final int OBJECT_COUNT = 10;

    private Interpreter mTFLite;
    private List<String> mLabels;
    private float[][][] locations;
    private float[][] labelIndices;
    private float[][] scores;

    public DetectHelper(Interpreter tflite, List<String> labels) {
        this.mTFLite = tflite;
        this.mLabels = labels;
    }

    public List<ObjectPrediction> recognizeImage(TensorImage image) {
        // SSD MobileNet exposes four output tensors; the map keys must match
        // the model's output order:
        //   0 = bounding boxes      [1][N][4], normalized (top, left, bottom, right)
        //   1 = class indices       [1][N]
        //   2 = confidence scores   [1][N]
        //   3 = number of detections [1]
        locations = new float[1][OBJECT_COUNT][4];
        labelIndices = new float[1][OBJECT_COUNT];
        scores = new float[1][OBJECT_COUNT];
        Map<Integer, Object> outputBuffer = new HashMap<>();
        outputBuffer.put(0, locations);
        outputBuffer.put(1, labelIndices);
        outputBuffer.put(2, scores);
        outputBuffer.put(3, new float[1]);
        mTFLite.runForMultipleInputsOutputs(new ByteBuffer[]{image.getBuffer()}, outputBuffer);

        List<ObjectPrediction> predictions = new ArrayList<>();
        for (int i = 0; i < OBJECT_COUNT; i++) {
            // The model reports (top, left, bottom, right); RectF expects (left, top, right, bottom).
            RectF box = new RectF(
                    locations[0][i][1], locations[0][i][0],
                    locations[0][i][3], locations[0][i][2]);
            // The label file starts with a background entry, hence the +1 offset.
            String label = mLabels.get(1 + (int) labelIndices[0][i]);
            predictions.add(new ObjectPrediction(box, label, scores[0][i]));
        }
        return predictions;
    }

}

Note the outputBuffer map: it defines the set of outputs, and each key maps to a different output tensor. For example, key 0 holds the bounding-box coordinates of detected objects and key 1 holds the detected class. The keys must line up with the model's actual outputs; if they don't, inference will throw an exception or simply fail to detect anything. Check the details for your specific model.

When the shapes don't line up, you get an error such as: "Cannot copy from a TensorFlowLite tensor (Identity_1) with shape [1, 3087, 2] to a Java object with shape [1, 3087]" (discussed at https://stackoverflow.com/questions/63457635/cannot-copy-from-a-tensorflowlite-tensor-identity-1-with-shape-1-3087-2-to).
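To make the shape bookkeeping concrete, here is a small standalone sketch (plain Java, no Android or TFLite required; the class name and values are made up for illustration) that allocates output arrays with the same shapes DetectHelper uses and picks the best-scoring detection:

```java
public class SsdOutputDemo {
    static final int OBJECT_COUNT = 10;

    // Returns the index of the highest-scoring detection in a [1][N] scores array.
    public static int bestDetection(float[][] scores) {
        int best = 0;
        for (int i = 1; i < scores[0].length; i++) {
            if (scores[0][i] > scores[0][best]) best = i;
        }
        return best;
    }

    public static void main(String[] args) {
        // Same shapes the helper allocates: [1][N][4], [1][N], [1][N].
        float[][][] locations = new float[1][OBJECT_COUNT][4];
        float[][] labelIndices = new float[1][OBJECT_COUNT];
        float[][] scores = new float[1][OBJECT_COUNT];

        scores[0][3] = 0.87f;       // pretend detection #3 scored highest
        labelIndices[0][3] = 16f;   // its class index into the label list
        locations[0][3] = new float[]{0.1f, 0.2f, 0.5f, 0.8f}; // (top, left, bottom, right)

        int best = bestDetection(scores);
        System.out.println("best=" + best + " score=" + scores[0][best]);
    }
}
```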

4. Wire it up in the Activity; the full code is below
public class Camera2Activity extends AppCompatActivity {

    private final String TAG = Camera2Activity.class.getSimpleName();

    private final float ACCURACY_THRESHOLD = 0.5f;
    private final String MODEL_PATH = "coco_ssd_mobilenet_v1_1.0_quant.tflite";
    private final String LABELS_PATH = "coco_ssd_mobilenet_v1_1.0_labels.txt";

    private ActivityCamera2Binding mBinding;

    private ExecutorService executor = Executors.newSingleThreadExecutor();
    private String[] permissions = new String[]{Manifest.permission.CAMERA};
    private int lensFacing = CameraSelector.LENS_FACING_BACK;
    private boolean isFrontFacing = lensFacing == CameraSelector.LENS_FACING_FRONT;
    private boolean pauseAnalysis = false;
    private int imageRotationDegrees = 0;
    private TensorImage tfImageBuffer = new TensorImage(DataType.UINT8);

    private Bitmap bitmapBuffer;

    private ImageProcessor tfImageProcessor;
    private Size tfInputSize;
    private Interpreter tflite;
    private DetectHelper detector;

    private NnApiDelegate nnApiDelegate = new NnApiDelegate();
    private ProcessCameraProvider cameraProvider;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        mBinding = ActivityCamera2Binding.inflate(getLayoutInflater());
        setContentView(mBinding.getRoot());
        initTfLite();
        initDetector();
        initTfInputSize();
        mBinding.cameraCaptureButton.setOnClickListener(view -> {
            view.setEnabled(false);

            if (pauseAnalysis) {
                pauseAnalysis = false;
                mBinding.imagePredicted.setVisibility(View.GONE);
            } else {
                pauseAnalysis = true;
                Matrix matrix = new Matrix();
                matrix.postRotate((float) imageRotationDegrees);
                if (isFrontFacing) {
                    matrix.postScale(-1f, 1f);
                }

                Bitmap uprightImage = Bitmap.createBitmap(
                        bitmapBuffer, 0, 0, bitmapBuffer.getWidth(), bitmapBuffer.getHeight(), matrix, true
                );
                mBinding.imagePredicted.setImageBitmap(uprightImage);
                mBinding.imagePredicted.setVisibility(View.VISIBLE);
            }

            view.setEnabled(true);


        });

        if (!PermissionsUtils.checkSelfPermission(this, permissions)) {
            PermissionsUtils.requestPermissions(this, permissions);
        } else {
            setUpCamera();
        }
    }

    private void initDetector() {
        try {
            detector = new DetectHelper(tflite, FileUtil.loadLabels(this, LABELS_PATH));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    private void initTfLite() {
        try {
            tflite = new Interpreter(FileUtil.loadMappedFile(this, MODEL_PATH), new Interpreter.Options().addDelegate(nnApiDelegate));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    private void initTfInputSize() {
        int inputIndex = 0;
        int[] inputShape = tflite.getInputTensor(inputIndex).shape();
        tfInputSize = new Size(inputShape[2], inputShape[1]);
    }

    private void initImageProcessor() {
        int cropSize = Math.min(bitmapBuffer.getWidth(), bitmapBuffer.getHeight());
        tfImageProcessor = new ImageProcessor.Builder()
                .add(new ResizeWithCropOrPadOp(cropSize, cropSize))
                .add(new ResizeOp(tfInputSize.getHeight(), tfInputSize.getWidth(), ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
                .add(new Rot90Op(-imageRotationDegrees / 90))
                .add(new NormalizeOp(0f, 1f))
                .build();
    }

    @Override
    protected void onDestroy() {
        super.onDestroy();
        try {
            executor.shutdown();
            executor.awaitTermination(1000, TimeUnit.MILLISECONDS);
        } catch (Exception e) {
            e.printStackTrace();
        }
        tflite.close();
        nnApiDelegate.close();
    }

    @Override
    public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
        super.onRequestPermissionsResult(requestCode, permissions, grantResults);
        if (requestCode == PermissionsUtils.REQUEST_CODE_PERMISSIONS && PermissionsUtils.checkHasAllPermission(permissions, grantResults)) {
            setUpCamera();
        } else {
            finish();
        }

    }

    private void setUpCamera() {
        Log.d("viewFinder", "bindCameraUseCases");
        ListenableFuture<ProcessCameraProvider> cameraProviderFuture = ProcessCameraProvider.getInstance(Camera2Activity.this);
        cameraProviderFuture.addListener(() -> {
            Log.d("viewFinder", "cameraListener");
            try {
                cameraProvider = cameraProviderFuture.get();
                bindCameraUseCases();
            } catch (Exception e) {
                e.printStackTrace();
            }

        }, ContextCompat.getMainExecutor(this));
    }

    private void bindCameraUseCases() {
        Preview preview = new Preview.Builder()
                .setTargetAspectRatio(AspectRatio.RATIO_4_3)
                .setTargetRotation(mBinding.viewFinder.getDisplay().getRotation())
                .build();

        ImageAnalysis imageAnalysis = new ImageAnalysis.Builder()
                .setTargetAspectRatio(AspectRatio.RATIO_4_3)
                .setTargetRotation(mBinding.viewFinder.getDisplay().getRotation())
                .setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST)
                .setOutputImageFormat(ImageAnalysis.OUTPUT_IMAGE_FORMAT_RGBA_8888)
                .build();

        final int[] frameCounter = {0};
        final long[] lastFpsTimestamp = {System.currentTimeMillis()};

        imageAnalysis.setAnalyzer(executor, image -> {

            Log.d("viewFinder", "setAnalyzer");
            if (bitmapBuffer == null) {
                imageRotationDegrees = image.getImageInfo().getRotationDegrees();
                bitmapBuffer = Bitmap.createBitmap(
                        image.getWidth(), image.getHeight(), Bitmap.Config.ARGB_8888
                );
                initImageProcessor();
            }

            if (pauseAnalysis) {
                image.close();
                return;
            }

            try {
                bitmapBuffer.copyPixelsFromBuffer(image.getPlanes()[0].getBuffer());
            } finally {
                // Always release the frame back to CameraX, even if the copy fails.
                image.close();
            }
            tfImageBuffer.load(bitmapBuffer);
            TensorImage tfImage = tfImageProcessor.process(tfImageBuffer);
            List<ObjectPrediction> predictions = detector.recognizeImage(tfImage);
            // Sort detections by score, highest first.
            Collections.sort(predictions, (o1, o2) -> Float.compare(o2.score, o1.score));
            reportPrediction(predictions.get(0));
            int frameCount = 10;
            if (++frameCounter[0] % frameCount == 0) {
                frameCounter[0] = 0;
                long now = System.currentTimeMillis();
                long delta = now - lastFpsTimestamp[0];
                float fps = (float) 1000 * frameCount / delta;
                Log.d(
                        TAG,
                        "FPS: " + fps + " with tensorSize: " + tfImage.getWidth() + " x " + tfImage.getHeight()
                );
                lastFpsTimestamp[0] = now;
            }
        });


        CameraSelector cameraSelector = new CameraSelector.Builder().requireLensFacing(lensFacing).build();

        cameraProvider.unbindAll();
        cameraProvider.bindToLifecycle(
                this,
                cameraSelector,
                preview,
                imageAnalysis
        );
        preview.setSurfaceProvider(mBinding.viewFinder.getSurfaceProvider());
    }

    private void reportPrediction(ObjectPrediction prediction) {
        runOnUiThread(() -> {
            if (prediction == null || prediction.score < ACCURACY_THRESHOLD) {
                mBinding.boxPrediction.setVisibility(View.GONE);
                mBinding.textPrediction.setVisibility(View.GONE);
                return;
            }
            RectF location = mapOutputCoordinates(prediction.location);

            mBinding.textPrediction.setText(prediction.score + " " + prediction.label);
            ConstraintLayout.LayoutParams params = (ConstraintLayout.LayoutParams) mBinding.boxPrediction.getLayoutParams();
            params.height = Math.min(
                    mBinding.viewFinder.getHeight(),
                    (int) location.bottom - (int) location.top
            );
            params.width = Math.min(
                    mBinding.viewFinder.getWidth(),
                    (int) location.right - (int) location.left
            );
            params.topMargin = (int) location.top;
            params.leftMargin = (int) location.left;
            mBinding.boxPrediction.setLayoutParams(params);

            mBinding.boxPrediction.setVisibility(View.VISIBLE);
            mBinding.textPrediction.setVisibility(View.VISIBLE);
        });
    }

    private RectF mapOutputCoordinates(RectF location) {
        RectF previewLocation = new RectF(
                location.left * mBinding.viewFinder.getWidth(),
                location.top * mBinding.viewFinder.getHeight(),
                location.right * mBinding.viewFinder.getWidth(),
                location.bottom * mBinding.viewFinder.getHeight()
        );
        boolean isFrontFacing = lensFacing == CameraSelector.LENS_FACING_FRONT;
        RectF correctedLocation = null;
        if (isFrontFacing) {
            correctedLocation = new RectF(
                    mBinding.viewFinder.getWidth() - previewLocation.right,
                    previewLocation.top,
                    mBinding.viewFinder.getWidth() - previewLocation.left,
                    previewLocation.bottom
            );
        } else {
            correctedLocation = previewLocation;
        }
        float margin = 0.1f;
        float requestedRatio = 4f / 3f;
        float midX = (correctedLocation.left + correctedLocation.right) / 2f;
        float midY = (correctedLocation.top + correctedLocation.bottom) / 2f;
        if (mBinding.viewFinder.getWidth() < mBinding.viewFinder.getHeight()) {
            return new RectF(
                    midX - (1f + margin) * requestedRatio * correctedLocation.width() / 2f,
                    midY - (1f - margin) * correctedLocation.height() / 2f,
                    midX + (1f + margin) * requestedRatio * correctedLocation.width() / 2f,
                    midY + (1f - margin) * correctedLocation.height() / 2f
            );
        } else {
            return new RectF(
                    midX - (1f - margin) * correctedLocation.width() / 2f,
                    midY - (1f + margin) * requestedRatio * correctedLocation.height() / 2f,
                    midX + (1f - margin) * correctedLocation.width() / 2f,
                    midY + (1f + margin) * requestedRatio * correctedLocation.height() / 2f
            );
        }
    }


}
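As a quick sanity check on the coordinate mapping above: the detector returns normalized [0, 1] coordinates, and mapOutputCoordinates first scales them to preview pixels. That scaling step can be exercised standalone (plain Java; the class and method names are illustrative):

```java
public class BoxMapDemo {
    // Scales a normalized (left, top, right, bottom) box to view pixels,
    // mirroring the first step of mapOutputCoordinates.
    public static float[] toViewCoords(float[] box, int viewW, int viewH) {
        return new float[]{
                box[0] * viewW,   // left
                box[1] * viewH,   // top
                box[2] * viewW,   // right
                box[3] * viewH    // bottom
        };
    }

    public static void main(String[] args) {
        float[] mapped = toViewCoords(new float[]{0.25f, 0.1f, 0.75f, 0.9f}, 400, 300);
        System.out.println(java.util.Arrays.toString(mapped));
    }
}
```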

That covers essentially all of the logic. Don't forget to declare the camera feature and permission in AndroidManifest.xml:

    <!-- Declare features -->
    <uses-feature android:name="android.hardware.camera" />

    <!-- Declare permissions -->
    <uses-permission android:name="android.permission.CAMERA" />

This was my first time integrating TFLite, and there isn't much documentation online, so I hope this helps. Questions and discussion are welcome!

© Copyright belongs to the author. For reprints or content cooperation, please contact the author.