import traceback
import base64
import tensorflow as tf
import cv2
import numpy as np
from flask import Flask, request, jsonify
from tensorflow.python.saved_model import tag_constants
import collections
# Flask application; the /predict route below is registered on it.
app = Flask(__name__)

# Thresholds handed to combined non-max suppression in detector().
score_thresh = 0.3  # minimum score for a box to be kept
iou_thresh = 0.5    # IoU above which overlapping boxes are suppressed

# Load the YOLOv4 model (converted to TensorFlow SavedModel / .pb format)
# and keep a handle to its default serving signature for inference.
_model_dir = "./futi/yolov4-416"
model = tf.saved_model.load(_model_dir, tags=[tag_constants.SERVING])
infer = model.signatures['serving_default']
def bbox2points(bbox):
    """Return the four coordinates of *bbox* as a tuple, in unchanged order.

    No coordinate conversion happens: the input is unpacked into exactly four
    values and repacked as a tuple, so a sequence of any other length raises
    ValueError.
    """
    y_min, x_min, y_max, x_max = bbox
    corners = (y_min, x_min, y_max, x_max)
    return corners
def detector(image, input_size=608):
    """Run the YOLOv4 SavedModel on one image and return its detections.

    The image is resized to ``input_size`` x ``input_size``, scaled to [0, 1]
    and fed to the module-level ``infer`` signature. The raw predictions are
    reduced with ``tf.image.combined_non_max_suppression`` using the
    module-level ``iou_thresh`` and ``score_thresh``.

    Args:
        image: decoded image array accepted by ``cv2.resize``; the caller in
            this file passes an RGB array produced by ``cv2.imdecode``.
        input_size: side length of the square network input. Defaults to 608
            to preserve the original behaviour. NOTE(review): the model is
            loaded from a directory named ``yolov4-416``, which suggests it
            was exported for 416x416 inputs; confirm which size the serving
            signature expects.

    Returns:
        ``{"outputs": {...}}`` where each value is a one-element list (one
        image per batch):
        ``detection_boxes`` holds 4-tuples in the model's coordinate order
        (presumably normalised ymin, xmin, ymax, xmax; TODO confirm).
        ``detection_scores`` holds floats.
        ``detection_classes`` holds 1-based integer class IDs.
        The misspelled legacy key ``"detection_score:"`` carries the same
        scores so existing clients keep working.

    Raises:
        ValueError: if the serving signature returns no outputs.
    """
    # Resize to the square network input, scale pixels to [0, 1] and add a
    # leading batch axis of size 1.
    resized = cv2.resize(image, (input_size, input_size))
    batch = tf.constant((resized / 255.)[np.newaxis, ...].astype(np.float32))

    outputs = infer(batch)
    if not outputs:
        raise ValueError("serving signature returned no outputs")
    # As the original loop over outputs.items() did, use the last output.
    # Its last axis holds 4 box coordinates followed by per-class scores.
    pred = list(outputs.values())[-1]
    boxes = pred[:, :, 0:4]
    pred_conf = pred[:, :, 4:]

    nms_boxes, nms_scores, nms_classes, valid_detections = (
        tf.image.combined_non_max_suppression(
            # q=1: one box per prediction, shared across all classes.
            boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
            scores=tf.reshape(
                pred_conf,
                (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),
            max_output_size_per_class=100,
            max_total_size=200,
            iou_threshold=iou_thresh,
            score_threshold=score_thresh,
        )
    )

    # NMS output is padded to max_total_size; only the first `count` rows of
    # the single image in the batch are real detections.
    count = int(valid_detections.numpy()[0])
    boxes_filtered = [
        bbox2points(b) for b in nms_boxes.numpy()[0][:count].tolist()
    ]
    scores_filtered = nms_scores.numpy()[0][:count].tolist()
    # NMS returns 0-based class indices; the response uses 1-based IDs.
    box_ids_filtered = [
        int(c) + 1 for c in nms_classes.numpy()[0][:count].tolist()
    ]

    return {"outputs": {
        "detection_boxes": [boxes_filtered],
        "detection_scores": [scores_filtered],
        # Legacy key with a stray colon, kept for backward compatibility.
        "detection_score:": [scores_filtered],
        "detection_classes": [box_ids_filtered],
    }}
# POST /predict: base64 image in, detections out.
@app.route('/predict', methods=['POST'])
def post_data_ladder():
    """Decode a base64-encoded image from the request and return detections.

    Expects a JSON body of the form ``{"inputs": [{"b64": "<base64>"}, ...]}``;
    only the first input is used. The decoded image is converted from
    OpenCV's BGR order to RGB and passed to ``detector``.

    Returns:
        ``jsonify(detector(...))`` on success, or ``jsonify(None)`` (JSON
        ``null``) when the image is missing or undecodable, or when any other
        error occurs.
    """
    try:
        print("[INFO]: Enacting api..")
        # force=True parses the body as JSON even if the Content-Type header
        # is not application/json; otherwise get_json() may yield None and the
        # subscript below fails with "'NoneType' object is not subscriptable".
        # The route only accepts POST, so no GET fallback is needed.
        data = request.get_json(force=True)
        base64img = data["inputs"][0]["b64"]
        # Reject a missing image before reporting success.
        if base64img is None:
            print("[ERROR]: Image cant be None")
            return jsonify(None)
        print('[INFO]: Image receive successfully')
        # Inference
        img = base64.b64decode(base64img)
        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        np_array = np.frombuffer(img, np.uint8)
        img_np = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
        # cv2.imdecode returns None when the bytes are not a decodable image.
        if img_np is None:
            print("[ERROR]: Image could not be decoded")
            return jsonify(None)
        detect_img = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
        result = detector(detect_img)
    except Exception as e:
        # Top-level request boundary: log the full traceback and answer with
        # JSON null, matching the existing failure contract for clients.
        traceback.print_exc()
        print(repr(e))
        print('Analysis.Detect fail')
        return jsonify(None)
    return jsonify(result)
if __name__ == '__main__':
    # Host '0.0.0.0' listens on every network interface; switch it to
    # '127.0.0.1' to accept local connections only.
    listen_host, listen_port = '0.0.0.0', 8501
    app.run(host=listen_host, port=listen_port)
# References:
#   https://pytorch.apachecn.org/docs/1.4/28.html
#   https://blog.csdn.net/weixin_42902669/article/details/84590282
#   https://blog.csdn.net/newbeixue/article/details/103482399