d479b4ec by 乔峰昇

ocr_yolo triton-inference-server

0 parents
1 model_repository/
2 .idea/
OCR_Engine @ 3dddc11a
1 Subproject commit 3dddc11a8a1d369ca4fbd0b69e4e21e6af81cc4c
1 ## OCR+yolov5 triton-inference-server服务
2
3 1.使用docker启动triton服务
4
5 sudo docker run --gpus="device=0" --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v /home/situ/qfs/triton_inference_server/demo/model_repository:/models nvcr.io/nvidia/tritonserver:21.10-py3 tritonserver --model-repository=/models
6
7 2.分别启动OCR和yolov5的web服务
8
9 cd OCR_Engine/api
10 python ocr_engine_server.py
11
12 cd yolov5_onnx_demo/api
13 python yolov5_onnx_server.py
14
15 3.pipeline测试
16
17 python triton_pipeline.py
18
1 import base64
2 import json
3 from bank_ocr_inference import *
4
5
def enlarge_position(box):
    """Grow an ``[x1, y1, x2, y2]`` box by a margin proportional to its size.

    The box is extended by a third of its height above and below, and by an
    eighth of its width to the left and right. The top-left corner is clamped
    at 0; the bottom-right corner is not clamped.

    Returns:
        The enlarged box as ``[x1, y1, x2, y2]``.
    """
    left, top, right, bottom = box
    pad_x = abs(right - left) // 8
    pad_y = abs(bottom - top) // 3
    return [max(left - pad_x, 0), max(top - pad_y, 0), right + pad_x, bottom + pad_y]
12
13
def path_base64(file_path):
    """Read a file and return its contents as a base64-encoded text string.

    Args:
        file_path: Path of the file to read (opened in binary mode).

    Returns:
        The base64 encoding of the file's bytes, decoded as UTF-8 text.
    """
    # The context manager closes the handle; the original left it open.
    with open(file_path, 'rb') as f:
        return base64.b64encode(f.read()).decode('utf-8')
19
20
def bgr_base64(image):
    """JPEG-encode an image array with OpenCV and return it as base64 text."""
    jpeg_buffer = cv2.imencode('.jpg', image)[1]
    return base64.b64encode(jpeg_buffer).decode('utf-8')
25
26
def base64_bgr(img64):
    """Decode base64 image data into a colour image via ``cv2.imdecode``."""
    raw_bytes = base64.b64decode(img64)
    encoded = np.frombuffer(raw_bytes, np.uint8)
    return cv2.imdecode(encoded, cv2.IMREAD_COLOR)
32
33
def tamper_detect_(image):
    """POST an image to the tamper-detection service and return its JSON reply.

    The image is sent base64-encoded under the ``'img'`` key of a JSON body.
    """
    payload = json.dumps({'img': bgr_base64(image)})
    return requests.post(url=r'http://192.168.10.11:8009/tamper_det', data=payload).json()
39
40
if __name__ == '__main__':
    # Demo run: OCR the page, extract the info fields, then send an enlarged
    # crop of each field to the tamper-detection service.
    image = cv2.imread(
        '/data/situ_invoice_bill_data/银行流水样本/普通打印-部分格线-竖版-农业银行-8列/_1594626974.367834page_20_img_0.jpg')
    st = time.time()
    ocr_results = bill_ocr(image)
    et1 = time.time()
    info_results = extract_bank_info(ocr_results)
    et2 = time.time()

    tamper_results = []
    for info_result in info_results:
        # Presumably info_result[1] is an 8-value quadrilateral and indices
        # 0/1 and 4/5 are opposite corners -- verify against extract_bank_info.
        coords = info_result[1]
        x1, y1, x2, y2 = enlarge_position([coords[0], coords[1], coords[4], coords[5]])
        results = tamper_detect_(image[y1:y2, x1:x2, :])
        print(results)
        for res in results['results']:
            cx, cy, width, height = (int(v) for v in res[:4])
            # Convert the centre-based box to a top-left corner, then shift it
            # from crop coordinates back into full-image coordinates.
            left = cx - width // 2
            top = cy - height // 2
            tamper_results.append([x1 + left, y1 + top, x1 + left + width, y1 + top + height])
    et3 = time.time()
    print(tamper_results)

    print(f'all time:{et3 - st} ocr time:{et1 - st} extract info time:{et2 - et1} yolo time:{et3 - et2}')
    for rect in tamper_results:
        cv2.rectangle(image, tuple(rect[:2]), tuple(rect[2:]), (0, 0, 255), 2)
    cv2.imshow('info', image)
    cv2.waitKey(0)
1 import base64
2
3 import cv2
4 import numpy as np
5 from sanic import Sanic
6 from sanic.response import json
7 from yolov5_onnx_demo.model.yolov5_infer import *
8
9
def base64_to_bgr(bs64):
    """Decode base64 image data into a colour image via ``cv2.imdecode``.

    Args:
        bs64: Base64-encoded bytes or text of an encoded image file.

    Returns:
        Whatever ``cv2.imdecode`` returns for the decoded bytes (``None`` when
        the bytes are not a decodable image).
    """
    img_data = base64.b64decode(bs64)
    # np.fromstring is deprecated for binary input; np.frombuffer is the
    # documented replacement and avoids the extra copy.
    img_arr = np.frombuffer(img_data, np.uint8)
    return cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
15
16
app = Sanic('tamper_det')


@app.post('/tamper_det')
def hello(request):
    """Handle ``POST /tamper_det``: run detection on a base64-encoded image.

    Expects a JSON object body with an ``'img'`` field holding the
    base64-encoded image. Responds with ``{'results': <detections>}`` on
    success, or a 400 response with an ``'error'`` message when the body is
    malformed or the image cannot be decoded.
    """
    payload = request.json
    if not isinstance(payload, dict) or 'img' not in payload:
        return json({'error': "request body must be a JSON object with an 'img' field"}, status=400)

    # The payload is deliberately not printed: it is the full base64 image and
    # would flood stdout on every request.
    try:
        img = base64_to_bgr(payload['img'])
    except (ValueError, TypeError):
        # base64 decoding rejects malformed input with these exception types.
        return json({'error': "'img' is not valid base64 data"}, status=400)
    if img is None:
        # cv2.imdecode yields None when the bytes are not a decodable image.
        return json({'error': "'img' could not be decoded as an image"}, status=400)

    result = grpc_detect(img)
    return json({'results': result})
28
29
if __name__ == '__main__':
    # Start the Sanic server on the fixed LAN address with workers=10.
    app.run(host='192.168.10.11', port=8009,workers=10)
1 import base64
2
3 import requests
4 import json
5 from yolov5_onnx_demo.model.yolov5_infer import *
6
def path_base64(file_path):
    """Read a file and return its contents as a base64-encoded text string.

    Args:
        file_path: Path of the file to read (opened in binary mode).

    Returns:
        The base64 encoding of the file's bytes, decoded as UTF-8 text.
    """
    # The context manager closes the handle; the original left it open.
    with open(file_path, 'rb') as f:
        return base64.b64encode(f.read()).decode('utf-8')
12
13
# Smoke test for the /tamper_det endpoint: post one sample image and draw the
# returned boxes on it.
image_path = '/data/situ_invoice_bill_data/qfs_train_val_data/train_data/machine/minsheng/images/train/_1597386625.07514page_20_img_0_machine_name_full_splicing.jpg'
res = requests.post('http://192.168.10.11:8009/tamper_det', data=json.dumps(
    {'img': path_base64(image_path)}))
results = res.json()
img = cv2.imread(image_path)
print(res)
# The endpoint replies with a 'results' field; the previous 'keys' lookup
# raised KeyError.
plot_label(img, results['results'])
1 import cv2
2 import numpy as np
3 import tritonclient.grpc as grpcclient
4
5
def keep_resize_padding(image):
    """Pad an image to a square with grey borders, then resize to 640x640.

    The model input must be exactly 640x640, whereas the official inference
    code resizes by the minimum scale ratio and so yields variable input
    sizes. This version pads the shorter side (value 114) on both ends before
    resizing, so the aspect ratio is preserved.

    Returns:
        A ``(resized_image, padding_info)`` pair, where ``padding_info`` is
        ``[pad1, pad2, pad_type]``: the border sizes before and after the
        image, and ``0`` when the width was padded or ``1`` when the height
        was padded.
    """
    h, w, _ = image.shape
    gap = abs(h - w)
    pad1 = gap // 2
    pad2 = gap - pad1

    if h >= w:
        # Portrait or square: add columns on the left and right.
        pad_type, axis = 0, 1
        borders = [np.full((h, n, 3), 114, dtype=np.uint8) for n in (pad1, pad2)]
    else:
        # Landscape: add rows on the top and bottom.
        pad_type, axis = 1, 0
        borders = [np.full((n, w, 3), 114, dtype=np.uint8) for n in (pad1, pad2)]

    squared = np.concatenate((borders[0], image, borders[1]), axis=axis)
    return cv2.resize(squared, (640, 640)), [pad1, pad2, pad_type]
30
31
32 # remove padding
def extract_authentic_bboxes(image, padding_info, bboxes):
    """Map boxes from the padded 640x640 input back to original image coordinates.

    The first four columns of every box are multiplied by
    ``max(h, w) / 640`` (undoing the resize), then the leading padding is
    subtracted from column 0 (``pad_type == 0``) or column 1 (otherwise).

    Args:
        image: The original image; only its height and width are used.
        padding_info: ``[pad1, pad2, pad_type]`` as returned by
            ``keep_resize_padding``.
        bboxes: Sequence or 2-D array of boxes with at least four columns.

    Returns:
        The rescaled boxes as a list of lists; ``[]`` when there are no boxes.
    """
    pad1, _pad2, pad_type = padding_info
    h, w = image.shape[:2]
    bboxes = np.array(bboxes)
    if bboxes.size == 0:
        # An empty list becomes a 1-D array, which the 2-D indexing below
        # would reject with IndexError.
        return []
    if not np.issubdtype(bboxes.dtype, np.floating):
        # Integer boxes would silently truncate the scaled coordinates.
        bboxes = bboxes.astype(np.float64)

    scale = max(h, w) / 640
    bboxes[:, :4] = bboxes[:, :4] * scale
    if pad_type == 0:
        bboxes[:, 0] = bboxes[:, 0] - pad1
    else:
        bboxes[:, 1] = bboxes[:, 1] - pad1
    return bboxes.tolist()
48
49
50 # NMS
def py_nms_cpu(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
):
    """Greedy Non-Maximum Suppression (NMS) that rejects overlapping detections.

    Rows whose column 4 is not above ``conf_thres`` are discarded first. The
    remaining rows are treated as centre-format boxes ``[cx, cy, w, h]`` in
    columns 0-3 and ranked by column 5. The highest-ranked row is kept, every
    row whose IoU with it reaches ``iou_thres`` is dropped, and the process
    repeats on what is left.

    Args:
        prediction: Array whose last axis holds at least six values per row.
        conf_thres: Minimum column-4 value for a row to be considered.
        iou_thres: Rows overlapping a kept row with IoU >= this are dropped.

    Returns:
        A 2-D array of the surviving rows in their original column layout
        (boxes are NOT converted to corner format), highest score first.
    """
    prediction = prediction[prediction[..., 4] > conf_thres]  # candidates

    # Centre/size -> corner coordinates, used only for the overlap computation.
    half_w = prediction[..., 2] / 2
    half_h = prediction[..., 3] / 2
    x1 = prediction[..., 0] - half_w
    y1 = prediction[..., 1] - half_h
    x2 = prediction[..., 0] + half_w
    y2 = prediction[..., 1] + half_h

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    score = prediction[..., 5]
    # argsort is ascending; reverse it so the best-scoring box is visited first.
    # Without the reversal the weakest box of each overlapping group was kept.
    order = np.argsort(score)[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        rest = order[1:]

        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])

        inter = np.maximum(0, xx2 - xx1 + 1) * np.maximum(0, yy2 - yy1 + 1)
        over = inter / (areas[i] + areas[rest] - inter)

        order = rest[over < iou_thres]

    # An explicit integer dtype keeps indexing valid when nothing survives.
    return prediction[np.array(keep, dtype=np.intp)]
92
93
def client_init(url='localhost:8001',
                ssl=False,
                private_key=None,
                root_certificates=None,
                certificate_chain=None,
                verbose=False):
    """Create a Triton gRPC ``InferenceServerClient`` from the given settings.

    Every argument is forwarded unchanged to the client constructor;
    ``verbose`` enables detailed output and defaults to False.
    """
    client_options = {
        'url': url,
        'verbose': verbose,
        'ssl': ssl,
        'root_certificates': root_certificates,
        'private_key': private_key,
        'certificate_chain': certificate_chain,
    }
    return grpcclient.InferenceServerClient(**client_options)
109
110
# Shared Triton gRPC client, constructed when this module is imported.
triton_client = client_init('localhost:8001')
# Forwarded to ``triton_client.infer`` by grpc_detect.
compression_algorithm = None
# Tensor and model names used by grpc_detect; presumably these must match the
# model configuration deployed in the Triton model repository -- verify.
input_name = 'images'
output_name = 'output0'
model_name = 'yolov5'
116
117
def grpc_detect(img):
    """Run the Triton-hosted model on an image and return its detections.

    The image is padded and resized to 640x640, converted to a channel-first
    float32 batch scaled to [0, 1], and sent over gRPC. The raw output goes
    through NMS and is mapped back to the original image's coordinates.

    Returns:
        The surviving detections as a list of lists.
    """
    padded, padding_info = keep_resize_padding(img)

    # HWC -> CHW, then reverse the channel axis (BGR -> RGB when the input
    # came from cv2, which loads images as BGR).
    tensor = padded.transpose((2, 0, 1))[::-1]
    tensor = tensor.astype(np.float32) / 255.0
    if tensor.ndim == 3:
        tensor = tensor[None]  # add the batch dimension

    # Dynamic input: the shape is taken from the tensor itself.
    infer_input = grpcclient.InferInput(input_name, tensor.shape, 'FP32')
    requested_output = grpcclient.InferRequestedOutput(output_name)
    infer_input.set_data_from_numpy(tensor.astype(np.float32))

    response = triton_client.infer(
        model_name=model_name,
        inputs=[infer_input], outputs=[requested_output],
        compression_algorithm=compression_algorithm
    )
    raw_pred = response.as_numpy(output_name).copy()
    kept = py_nms_cpu(raw_pred)
    return extract_authentic_bboxes(img, padding_info, kept)
144
145
def plot_label(img, result_bboxes):
    """Print the boxes, draw them in red on ``img`` and show it in a window.

    Each box is a six-value row; the first four values are read as a
    centre-format rectangle ``(cx, cy, w, h)``. Blocks until a key is pressed.
    """
    print(result_bboxes)
    red = (0, 0, 255)
    for cx, cy, box_w, box_h, conf, cls in result_bboxes:
        half_w, half_h = box_w // 2, box_h // 2
        top_left = (int(cx - half_w), int(cy - half_h))
        bottom_right = (int(cx + half_w), int(cy + half_h))
        cv2.rectangle(img, top_left, bottom_right, red, 2)
    cv2.imshow('im', img)
    cv2.waitKey(0)
153
154
if __name__ == '__main__':
    # Manual check: run detection on one sample image and display the boxes.
    img = cv2.imread(
        '/data/situ_invoice_bill_data/qfs_train_val_data/train_data/authentic/gongshang/images/val/_1594890232.0110397page_11_img_0_name_au_gongshang.jpg')

    result_bboxes = grpc_detect(img)
    # plot_label takes the image first; the original call omitted it and
    # raised TypeError.
    plot_label(img, result_bboxes)
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!