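Background
TensorFlow Lite is a framework released by Google that lets app developers deploy machine-learning models on mobile devices, for example for image recognition or face recognition.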
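Deployment steps
1. Copy the tflite models into the project
I found two models online: face_landmark.tflite, a face landmark (key point) detection model, and coco_ssd_mobilenet_v1_1.0_quant.tflite, an object detection model. Put the model and its label file coco_ssd_mobilenet_v1_1.0_labels.txt under src/main/assets, because FileUtil.loadMappedFile and FileUtil.loadLabels (used in step 4) read from the assets folder.
2. Add the TFLite and CameraX remote dependencies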
def camerax_version = "1.1.0-beta01"
implementation "androidx.camera:camera-core:${camerax_version}"
implementation "androidx.camera:camera-camera2:${camerax_version}"
implementation "androidx.camera:camera-lifecycle:${camerax_version}"
implementation "androidx.camera:camera-view:${camerax_version}"
implementation "androidx.camera:camera-video:${camerax_version}"
// Tensorflow lite dependencies
implementation 'org.tensorflow:tensorflow-lite:2.9.0'
implementation 'org.tensorflow:tensorflow-lite-gpu:2.9.0'
implementation 'org.tensorflow:tensorflow-lite-support:0.4.2'
implementation 'androidx.lifecycle:lifecycle-runtime-ktx:2.3.1'
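For image recognition, CameraX (a wrapper around Camera2) analyzes the camera frames, and each frame is then passed to TFLite for detection.
3. Wrap detection in a helper class, DetectHelper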
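import android.graphics.RectF;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.support.image.TensorImage;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// ObjectPrediction (not shown in this article) holds a location, a label and a score.
public class DetectHelper {
    // Number of detections the model returns per image.
    private static final int OBJECT_COUNT = 10;
    private final Interpreter mTFLite;
    private final List<String> mLabels;

    public DetectHelper(Interpreter tflite, List<String> labels) {
        this.mTFLite = tflite;
        this.mLabels = labels;
    }

    public List<ObjectPrediction> recognizeImage(TensorImage image) {
        // Output buffers; their shapes must match the model's output tensors.
        float[][][] locations = new float[1][OBJECT_COUNT][4];
        float[][] labelIndices = new float[1][OBJECT_COUNT];
        float[][] scores = new float[1][OBJECT_COUNT];
        Map<Integer, Object> outputBuffer = new HashMap<>();
        outputBuffer.put(0, locations);
        outputBuffer.put(1, labelIndices);
        outputBuffer.put(2, scores);
        outputBuffer.put(3, new float[1]); // fourth output tensor, unused here
        mTFLite.runForMultipleInputsOutputs(new ByteBuffer[]{image.getBuffer()}, outputBuffer);

        List<ObjectPrediction> predictions = new ArrayList<>();
        for (int i = 0; i < OBJECT_COUNT; i++) {
            // Indices 1, 0, 3, 2 map the model's box onto RectF(left, top, right, bottom).
            RectF box = new RectF(locations[0][i][1], locations[0][i][0],
                    locations[0][i][3], locations[0][i][2]);
            String label = mLabels.get(1 + ((int) labelIndices[0][i]));
            predictions.add(new ObjectPrediction(box, label, scores[0][i]));
        }
        return predictions;
    }
}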
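The index order used in recognizeImage (1, 0, 3, 2) suggests that the model writes each box as [top, left, bottom, right], while RectF expects (left, top, right, bottom). The following is a minimal plain-Java sketch of that reordering. Box and fromModelOutput are hypothetical names that stand in for android.graphics.RectF and the inline constructor call, so the snippet runs without the Android SDK.

```java
// Sketch only: Box is a hypothetical stand-in for android.graphics.RectF.
class BoxOrder {
    static final class Box {
        final float left, top, right, bottom;

        Box(float left, float top, float right, float bottom) {
            this.left = left;
            this.top = top;
            this.right = right;
            this.bottom = bottom;
        }
    }

    // Assumes the raw box is laid out as [top, left, bottom, right].
    static Box fromModelOutput(float[] raw) {
        if (raw.length != 4) {
            throw new IllegalArgumentException("expected 4 values, got " + raw.length);
        }
        return new Box(raw[1], raw[0], raw[3], raw[2]);
    }

    public static void main(String[] args) {
        Box b = fromModelOutput(new float[]{0.1f, 0.2f, 0.5f, 0.8f});
        System.out.println(b.left + ", " + b.top + ", " + b.right + ", " + b.bottom);
    }
}
```

If a different model uses another layout, only fromModelOutput would need to change.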
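Pay attention to the outputBuffer map in recognizeImage. It defines the set of outputs: each key is the index of an output tensor, and each tensor has its own shape. For example, 0 holds the bounding-box coordinates and 1 holds the detected class (as an index into the label list). The keys and array shapes must match the model exactly; otherwise inference throws an exception or nothing is detected. The details depend on the model. A typical error of this kind:
[Cannot copy from a TensorFlowLite tensor (Identity_1) with shape [1, 3087, 2] to a Java object with shape [1, 3087]](https://stackoverflow.com/questions/63457635/cannot-copy-from-a-tensorflowlite-tensor-identity-1-with-shape-1-3087-2-to)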
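One way to catch such a mismatch early is to compare each Java output buffer with the shape the model reports before running inference. The sketch below is an illustration under assumptions, not part of the original project: ShapeCheck, shapeOf and requireSameShape are hypothetical helpers, and the tensor shapes are hard-coded. In the app they could be read with mTFLite.getOutputTensor(i).shape().

```java
import java.lang.reflect.Array;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper: compares a Java output buffer against a tensor shape.
class ShapeCheck {
    // Returns the dimensions of a rectangular nested array,
    // following the first element at each level.
    static int[] shapeOf(Object array) {
        List<Integer> dims = new ArrayList<>();
        Object current = array;
        while (current != null && current.getClass().isArray()) {
            int length = Array.getLength(current);
            dims.add(length);
            current = length > 0 ? Array.get(current, 0) : null;
        }
        int[] shape = new int[dims.size()];
        for (int i = 0; i < shape.length; i++) {
            shape[i] = dims.get(i);
        }
        return shape;
    }

    // Throws if the buffer's shape differs from the shape reported by the model.
    static void requireSameShape(String name, int[] tensorShape, Object javaBuffer) {
        int[] bufferShape = shapeOf(javaBuffer);
        if (!Arrays.equals(tensorShape, bufferShape)) {
            throw new IllegalArgumentException(name + ": model reports "
                    + Arrays.toString(tensorShape) + " but the Java buffer is "
                    + Arrays.toString(bufferShape));
        }
    }

    public static void main(String[] args) {
        // Matching shapes pass silently.
        requireSameShape("locations", new int[]{1, 10, 4}, new float[1][10][4]);
        try {
            // Same mismatch as in the linked question.
            requireSameShape("Identity_1", new int[]{1, 3087, 2}, new float[1][3087]);
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

Running such a check once, right after the Interpreter is created, would point at the wrong buffer by name instead of failing during the first frame.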
4. Call it from the Activity; the full code is below
public class Camera2Activity extends AppCompatActivity {
private final String TAG = Camera2Activity.class.getSimpleName();
private final float ACCURACY_THRESHOLD = 0.5f;
private final String MODEL_PATH = "coco_ssd_mobilenet_v1_1.0_quant.tflite";
private final String LABELS_PATH = "coco_ssd_mobilenet_v1_1.0_labels.txt";
private ActivityCamera2Binding mBinding;
private ExecutorService executor = Executors.newSingleThreadExecutor();
private String[] permissions = new String[]{Manifest.permission.CAMERA};
private int lensFacing = CameraSelector.LENS_FACING_BACK;
private boolean isFrontFacing = lensFacing == CameraSelector.LENS_FACING_FRONT;
private boolean pauseAnalysis = false;
private int imageRotationDegrees = 0;
private TensorImage tfImageBuffer = new TensorImage(DataType.UINT8);
private Bitmap bitmapBuffer;
private ImageProcessor tfImageProcessor;
private Size tfInputSize;
private Interpreter tflite;
private DetectHelper detector;
private NnApiDelegate nnApiDelegate = new NnApiDelegate();
private ProcessCameraProvider cameraProvider;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
mBinding = ActivityCamera2Binding.inflate(getLayoutInflater());
setContentView(mBinding.getRoot());
initTfLite();
initDetector();
initTfInputSize();
mBinding.cameraCaptureButton.setOnClickListener(view -> {
view.setEnabled(false);
if (pauseAnalysis) {
pauseAnalysis = false;
mBinding.imagePredicted.setVisibility(View.GONE);
} else {
pauseAnalysis = true;
Matrix matrix = new Matrix();
matrix.postRotate((float) imageRotationDegrees);
if (isFrontFacing) {
matrix.postScale(-1f, 1f);
}
Bitmap uprightImage = Bitmap.createBitmap(
bitmapBuffer, 0, 0, bitmapBuffer.getWidth(), bitmapBuffer.getHeight(), matrix, true
);
mBinding.imagePredicted.setImageBitmap(uprightImage);
mBinding.imagePredicted.setVisibility(View.VISIBLE);
}
view.setEnabled(true);
});
if (!PermissionsUtils.checkSelfPermission(this, permissions)) {
PermissionsUtils.requestPermissions(this, permissions);
} else {
setUpCamera();
}
}
private void initDetector() {
try {
detector = new DetectHelper(tflite, FileUtil.loadLabels(this, LABELS_PATH));
} catch (Exception e) {
e.printStackTrace();
}
}
private void initTfLite() {
try {
tflite = new Interpreter(FileUtil.loadMappedFile(this, MODEL_PATH), new Interpreter.Options().addDelegate(nnApiDelegate));
} catch (Exception e) {
e.printStackTrace();
}
}
private void initTfInputSize() {
int inputIndex = 0;
int[] inputShape = tflite.getInputTensor(inputIndex).shape();
tfInputSize = new Size(inputShape[2], inputShape[1]);
}
private void initImageProcessor() {
int cropSize = Math.min(bitmapBuffer.getWidth(), bitmapBuffer.getHeight());
tfImageProcessor = new ImageProcessor.Builder()
.add(new ResizeWithCropOrPadOp(cropSize, cropSize))
.add(new ResizeOp(tfInputSize.getHeight(), tfInputSize.getWidth(), ResizeOp.ResizeMethod.NEAREST_NEIGHBOR))
.add(new Rot90Op(-imageRotationDegrees / 90))
.add(new NormalizeOp(0f, 1f))
.build();
}
@Override
protected void onDestroy() {
super.onDestroy();
try {
executor.shutdown();
executor.awaitTermination(1000, TimeUnit.MILLISECONDS);
} catch (Exception e) {
e.printStackTrace();
}
tflite.close();
nnApiDelegate.close();
}
@Override
public void onRequestPermissionsResult(int requestCode, @NonNull String[] permissions, @NonNull int[] grantResults) {
super.onRequestPermissionsResult(requestCode, permissions, grantResults);
if (requestCode == PermissionsUtils.REQUEST_CODE_PERMISSIONS && PermissionsUtils.checkHasAllPermission(permissions, grantResults)) {
setUpCamera();
} else {
finish();
}
}
private void setUpCamera() {
Log.d("viewFinder", "bindCameraUseCases");
ListenableFuture<ProcessCameraProvider> cameraProviderFuture = ProcessCameraProvider.getInstance(Camera2Activity.this);
cameraProviderFuture.addListener(() -> {
Log.d("viewFinder", "cameraListener");
try {
cameraProvider = cameraProviderFuture.get();
bindCameraUseCases();
} catch (Exception e) {
e.printStackTrace();
}
}, ContextCompat.getMainExecutor(this));
}
private void bindCameraUseCases() {
Preview preview = new Preview.Builder()
.setTargetAspectRatio(AspectRatio.RATIO_4_3)
.setTargetRotation(mBinding.viewFinder.getDisplay().getRotation())
.build();
ImageAnalysis imageAnalysis = new ImageAnalysis.Builder()
.setTargetAspectRatio(AspectRatio.RATIO_4_3)
.setTargetRotation(mBinding.viewFinder.getDisplay().getRotation())
.setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST)
.setOutputImageFormat(ImageAnalysis.OUTPUT_IMAGE_FORMAT_RGBA_8888)
.build();
final int[] frameCounter = {0};
final long[] lastFpsTimestamp = {System.currentTimeMillis()};
imageAnalysis.setAnalyzer(executor, image -> {
Log.d("viewFinder", "setAnalyzer");
if (bitmapBuffer == null) {
imageRotationDegrees = image.getImageInfo().getRotationDegrees();
bitmapBuffer = Bitmap.createBitmap(
image.getWidth(), image.getHeight(), Bitmap.Config.ARGB_8888
);
initImageProcessor();
}
if (pauseAnalysis) {
image.close();
return;
}
try {
// Copy the RGBA pixels of the frame into the reusable bitmap.
bitmapBuffer.copyPixelsFromBuffer(image.getPlanes()[0].getBuffer());
} finally {
// Always release the frame; CameraX delivers no new frame until this one is closed.
image.close();
}
tfImageBuffer.load(bitmapBuffer);
TensorImage tfImage = tfImageProcessor.process(tfImageBuffer);
List<ObjectPrediction> predictions = detector.recognizeImage(tfImage);
// Sort by score, highest first. Float.compare returns 0 for equal scores, as the Comparator contract requires.
Collections.sort(predictions, (o1, o2) -> Float.compare(o2.score, o1.score));
reportPrediction(predictions.get(0));
int frameCount = 10;
if (++frameCounter[0] % frameCount == 0) {
frameCounter[0] = 0;
long now = System.currentTimeMillis();
long delta = now - lastFpsTimestamp[0];
float fps = (float) 1000 * frameCount / delta;
Log.d(
TAG,
"FPS: " + fps + " with tensorSize: " + tfImage.getWidth() + " x " + tfImage.getHeight()
);
lastFpsTimestamp[0] = now;
}
});
CameraSelector cameraSelector = new CameraSelector.Builder().requireLensFacing(lensFacing).build();
cameraProvider.unbindAll();
cameraProvider.bindToLifecycle(
this,
cameraSelector,
preview,
imageAnalysis
);
preview.setSurfaceProvider(mBinding.viewFinder.getSurfaceProvider());
}
private void reportPrediction(ObjectPrediction prediction) {
runOnUiThread(() -> {
if (prediction == null || prediction.score < ACCURACY_THRESHOLD) {
mBinding.boxPrediction.setVisibility(View.GONE);
mBinding.textPrediction.setVisibility(View.GONE);
return;
}
RectF location = mapOutputCoordinates(prediction.location);
mBinding.textPrediction.setText(prediction.score + " " + prediction.label);
ConstraintLayout.LayoutParams params = (ConstraintLayout.LayoutParams) mBinding.boxPrediction.getLayoutParams();
params.height = Math.min(
mBinding.viewFinder.getHeight(),
(int) location.bottom - (int) location.top
);
params.width = Math.min(
mBinding.viewFinder.getWidth(),
(int) location.right - (int) location.left
);
params.topMargin = (int) location.top;
params.leftMargin = (int) location.left;
mBinding.boxPrediction.setLayoutParams(params);
mBinding.boxPrediction.setVisibility(View.VISIBLE);
mBinding.textPrediction.setVisibility(View.VISIBLE);
});
}
private RectF mapOutputCoordinates(RectF location) {
RectF previewLocation = new RectF(
location.left * mBinding.viewFinder.getWidth(),
location.top * mBinding.viewFinder.getHeight(),
location.right * mBinding.viewFinder.getWidth(),
location.bottom * mBinding.viewFinder.getHeight()
);
boolean isFrontFacing = lensFacing == CameraSelector.LENS_FACING_FRONT;
RectF correctedLocation = null;
if (isFrontFacing) {
correctedLocation = new RectF(
mBinding.viewFinder.getWidth() - previewLocation.right,
previewLocation.top,
mBinding.viewFinder.getWidth() - previewLocation.left,
previewLocation.bottom
);
} else {
correctedLocation = previewLocation;
}
float margin = 0.1f;
float requestedRatio = 4f / 3f;
float midX = (correctedLocation.left + correctedLocation.right) / 2f;
float midY = (correctedLocation.top + correctedLocation.bottom) / 2f;
if (mBinding.viewFinder.getWidth() < mBinding.viewFinder.getHeight()) {
return new RectF(
midX - (1f + margin) * requestedRatio * correctedLocation.width() / 2f,
midY - (1f - margin) * correctedLocation.height() / 2f,
midX + (1f + margin) * requestedRatio * correctedLocation.width() / 2f,
midY + (1f - margin) * correctedLocation.height() / 2f
);
} else {
return new RectF(
midX - (1f - margin) * correctedLocation.width() / 2f,
midY - (1f + margin) * requestedRatio * correctedLocation.height() / 2f,
midX + (1f - margin) * correctedLocation.width() / 2f,
midY + (1f + margin) * requestedRatio * correctedLocation.height() / 2f
);
}
}
}
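The analyzer above sorts all predictions and reports only the first one, which reportPrediction then hides if its score is below ACCURACY_THRESHOLD. Since only the single best result is used, a linear scan could replace the sort. The sketch below shows that idea in plain Java. Scored and best are hypothetical names, with Scored standing in for ObjectPrediction, whose definition is not shown in this article.

```java
import java.util.Arrays;
import java.util.List;

// Sketch only: Scored is a hypothetical stand-in for ObjectPrediction.
class BestPrediction {
    static final class Scored {
        final String label;
        final float score;

        Scored(String label, float score) {
            this.label = label;
            this.score = score;
        }
    }

    // Returns the highest-scoring item whose score is at least the threshold,
    // or null when no item qualifies.
    static Scored best(List<Scored> items, float threshold) {
        Scored best = null;
        for (Scored item : items) {
            if (item.score >= threshold && (best == null || item.score > best.score)) {
                best = item;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        List<Scored> items = Arrays.asList(
                new Scored("cat", 0.4f),
                new Scored("dog", 0.9f),
                new Scored("cup", 0.6f));
        Scored top = best(items, 0.5f);
        System.out.println(top == null ? "nothing above threshold" : top.label);
    }
}
```

With only 10 detections per frame the difference is negligible; the scan mainly keeps the threshold check and the selection in one place.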
That is essentially all of the logic. Don't forget to add the camera permission to AndroidManifest.xml:
<!-- Declare features -->
<uses-feature android:name="android.hardware.camera" />
<!-- Declare permissions -->
<uses-permission android:name="android.permission.CAMERA" />
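This was my first time integrating TFLite, and I found little documentation online, so I hope this helps. You are welcome to learn and discuss together.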