d479b4ec by 乔峰昇

ocr_yolo triton-inference-server

0 parents
1 model_repository/
2 .idea/
OCR_Engine @ 3dddc11a
1 Subproject commit 3dddc11a8a1d369ca4fbd0b69e4e21e6af81cc4c
1 ## OCR+yolov5 triton-inference-server服务
2
3 1.使用docker启动triton服务
4
5 sudo docker run --gpus="device=0" --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v /home/situ/qfs/triton_inference_server/demo/model_repository:/models nvcr.io/nvidia/tritonserver:21.10-py3 tritonserver --model-repository=/models
6
7 2.分别启动OCR和yolov5的web服务
8
9 cd OCR_Engine/api
10 python ocr_engine_server.py
11
12 cd yolov5_onnx_demo/api
13 python yolov5_onnx_server.py
14
15 3.pipeline测试
16
17 python triton_pipeline.py
18
1 import base64
2 import os
3 import time
4
5 import cv2
6 import numpy as np
7 import requests
8 import tqdm
9
10
def image_to_base64(image):
    """PNG-encode ``image`` with OpenCV and return the encoded buffer.

    Despite the name, no base64 step is performed here: the return value is
    the second element of the tuple produced by ``cv2.imencode``.
    """
    _, encoded = cv2.imencode('.png', image)
    return encoded
14
15
def path_to_file(file_path):
    """Open ``file_path`` for binary reading and return the open handle.

    The caller owns the returned file object and is responsible for closing it.
    """
    return open(file_path, 'rb')
19
20
# Bank-statement OCR endpoint
def bill_ocr(image, url=r'http://192.168.10.11:9001/gen_ocr', timeout=None):
    """Send an image to the OCR web service and return its ``ocr_results``.

    Args:
        image: Image accepted by ``image_to_base64`` (PNG-encoded before upload).
        url: OCR endpoint; defaults to the address the original code hard-coded.
        timeout: Forwarded to ``requests.post``; ``None`` keeps the original
            behaviour of waiting indefinitely.

    Returns:
        The value stored under the ``'ocr_results'`` key of the JSON response.

    Raises:
        requests.HTTPError: If the service answers with a 4xx/5xx status.
        KeyError: If the response JSON has no ``'ocr_results'`` key.
    """
    payload = image_to_base64(image)
    resp = requests.post(url=url, files={'file': payload}, timeout=timeout)
    # Fail with the real HTTP status instead of an unrelated JSON/KeyError later.
    resp.raise_for_status()
    return resp.json()['ocr_results']
28
29
def _nearest_value_on_same_line(label, ocr_results, max_dis=999999):
    """Return the OCR entry closest to the right-hand edge of ``label``.

    An entry qualifies when its text differs from the label text and its
    ``box[1]`` lies within half the label height (``|box[3] - box[5]| / 2``)
    of the label's ``box[3]``.  Among qualifying entries the one whose
    ``box[0]`` is horizontally nearest to the label's ``box[2]`` wins; the
    first one seen wins ties.  Returns ``None`` when nothing qualifies.
    """
    label_box, label_text = label[0], label[1]
    anchor_x, anchor_y = label_box[2], label_box[3]
    tolerance = abs(label_box[3] - label_box[5]) / 2
    nearest = None
    for candidate in ocr_results.values():
        if candidate[1] == label_text:
            continue
        if abs(candidate[0][1] - anchor_y) >= tolerance:
            continue
        distance = abs(candidate[0][0] - anchor_x)
        if distance < max_dis:
            nearest, max_dis = candidate, distance
    return nearest


def _labelled_field(entry, prefixes, ocr_results):
    """Build ``[text, position]`` for ``entry`` if it carries one of ``prefixes``.

    Only the first prefix contained in the entry text is considered.  When the
    text is longer than the prefix, the entry already holds label and value
    and is returned unchanged.  When the text is exactly the prefix, the value
    is looked up in the nearest entry on the same line and both boxes are
    merged.  Returns ``None`` when no prefix occurs in the text.
    """
    text, box = entry[1], entry[0]
    for prefix in prefixes:
        if prefix not in text:
            continue
        if text != prefix:
            return [text, box]
        neighbour = _nearest_value_on_same_line(entry, ocr_results)
        if neighbour is None:
            # The original code raised IndexError here; keep the bare label
            # so one missing value does not abort the whole extraction.
            return [text, box]
        other = neighbour[0]
        merged = [box[0], box[1], other[2], other[3], other[4], other[5], box[6], box[7]]
        return [text + neighbour[1], merged]
    return None


def _extract_fields(ocr_results, name_prefixes, account_prefixes):
    """Collect name and account fields from every OCR entry, in entry order."""
    results = []
    for entry in ocr_results.values():
        for prefixes in (name_prefixes, account_prefixes):
            field = _labelled_field(entry, prefixes, ocr_results)
            if field is not None:
                results.append(field)
    return results


# Extract China Minsheng Bank statement fields
def extract_minsheng_info(ocr_results):
    """Return ``[text, position]`` pairs for the '客户姓名:' and '客户账号:' fields.

    ``ocr_results`` is a mapping whose values are indexed as ``value[0]``
    (a position list with at least 8 numbers) and ``value[1]`` (the text).
    """
    return _extract_fields(ocr_results, ['客户姓名:'], ['客户账号:'])


# Extract ICBC statement fields
def extract_gongshang_info(ocr_results):
    """Return ``[text, position]`` pairs for the '户名:' and '卡号:' fields."""
    return _extract_fields(ocr_results, ['户名:'], ['卡号:'])


# Extract Bank of China statement fields
def extract_zhongguo_info(ocr_results):
    """Return ``[text, position]`` pairs for the '客户姓名:' and '借记卡号:' fields."""
    return _extract_fields(ocr_results, ['客户姓名:'], ['借记卡号:'])


# Extract China Construction Bank statement fields
def extract_jianshe_info(ocr_results):
    """Return ``[text, position]`` pairs for the name and card/account fields.

    Name labels tried in order: '客户名称:', '户名:'.
    Account labels tried in order: '卡号/账号:', '卡号:'.
    """
    return _extract_fields(ocr_results, ['客户名称:', '户名:'], ['卡号/账号:', '卡号:'])
222
223
224 # Extract Agricultural Bank of China statement fields (fairly complex; all layouts trained so far are supported)
# Returns a list of [text, position] pairs for the account-holder name and the
# account number. Each value of ocr_results is indexed as value[0] (a position
# list read at indices 0..7) and value[1] (the recognised text).
# NOTE(review): indentation was lost in this copy of the file, so the nesting
# of several `break` / `else` statements below cannot be determined from the
# text alone -- confirm against the repository before refactoring.
225 def extract_nongye_info(ocr_results):
226 name_prefixes = ['客户名:', '户名:']
227 account_prefixes = ['账号:']
228 results = []
# Cleared once a combined "name + account" label is seen; gates the
# account-prefix scan further down.
229 is_account = True
230 for value in ocr_results.values():
231 for name_prefix in name_prefixes:
# Case 1: the text holds a name prefix but not '账号'
# (account_prefixes[0][:-1] is the account prefix without its colon).
232 if name_prefix in value[1] and account_prefixes[0][:-1] not in value[1]:
# The text is exactly the label: look for the value in another entry.
233 if name_prefix == value[1]:
234 tmp_value, max_dis = [], 999999
235 top_right_x = value[0][2]
236 top_right_y = value[0][3]
# Among entries whose text differs from the prefix and whose [0][1] is
# within half of |value[0][3] - value[0][5]| of top_right_y, keep the one
# whose [0][0] is closest to top_right_x.
237 for tmp in ocr_results.values():
238 if tmp[1] != name_prefix:
239 if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
240 tmp[0][0] - top_right_x) < max_dis:
241 tmp_value = tmp
242 max_dis = abs(tmp[0][0] - top_right_x)
243 else:
244 continue
# Merged position: indices 0, 1, 6, 7 from the label, 2..5 from the match.
# NOTE(review): tmp_value is still [] when no entry qualified, in which
# case tmp_value[0][2] raises IndexError.
245 new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
246 tmp_value[0][5],
247 value[0][6], value[0][7]]
248 results.append([value[1] + tmp_value[1], new_position])
249 break
250 else:
# The text already contains more than the bare label: keep it as is.
251 results.append([value[1], value[0]])
252 break
# Case 2: the text holds a name prefix AND '账号', and the part before the
# first ':' is at most 5 characters long.
253 if name_prefix in value[1] and account_prefixes[0][:-1] in value[1] and len(value[1].split(":")[0]) <= 5:
254 is_account = False
# 2a: the whole text is exactly 5 characters long.
255 if len(value[1]) == 5:
256 tmp_value, max_dis = [], 999999
257 top_right_x = value[0][2]
258 top_right_y = value[0][3]
# Same-line entries keyed by horizontal distance to top_right_x; two
# entries at the same distance overwrite each other.
259 tmp_info = {}
260 for tmp in ocr_results.values():
261 if tmp[1] != value[1]:
262 if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2:
263 tmp_info[abs(tmp[0][0] - top_right_x)] = tmp
264 else:
265 continue
266 tmp_info_id = sorted(tmp_info.keys())
# NOTE(review): tmp_info_id[0] raises IndexError when no entry qualified.
# Nearest entry is not all digits and longer than 19 characters: merge it.
267 if not tmp_info[tmp_info_id[0]][1].isdigit() and len(tmp_info[tmp_info_id[0]][1]) > 19:
268 tmp_value = tmp_info[tmp_info_id[0]]
269 new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
270 tmp_value[0][5],
271 value[0][6], value[0][7]]
272 results.append([value[1] + tmp_value[1], new_position])
# Nearest entry is all digits: merge the second-nearest entry instead.
# NOTE(review): tmp_info_id[1] raises IndexError when only one entry qualified.
273 if tmp_info[tmp_info_id[0]][1].isdigit():
274 tmp_value = tmp_info[tmp_info_id[1]]
275 new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
276 tmp_value[0][5],
277 value[0][6], value[0][7]]
278 results.append([value[1] + tmp_value[1], new_position])
279 break
# 2b: text shorter than 25 characters: merge with the nearest same-line entry.
280 elif len(value[1]) < 25:
281 tmp_info = {}
282 top_right_x = value[0][2]
283 top_right_y = value[0][3]
284 for tmp in ocr_results.values():
285 if tmp[1] != value[1]:
286 if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2:
287 tmp_info[abs(tmp[0][0] - top_right_x)] = tmp
288 else:
289 continue
290 tmp_info_id = sorted(tmp_info.keys())
# NOTE(review): raises IndexError when no entry qualified.
291 tmp_value = tmp_info[tmp_info_id[0]]
292 new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
293 tmp_value[0][5],
294 value[0][6], value[0][7]]
295 results.append([value[1] + tmp_value[1], new_position])
296 break
297 else:
# 2c: 25 characters or more: keep the entry unchanged.
298 results.append([value[1], value[0]])
299 break
# Account-number scan, same nearest-neighbour merge as Case 1, performed
# only while no combined label has been seen.
300 if is_account:
301 for account_prefix in account_prefixes:
302 if account_prefix in value[1]:
303 if account_prefix == value[1]:
304 tmp_value, max_dis = [], 999999
305 top_right_x = value[0][2]
306 top_right_y = value[0][3]
307 for tmp in ocr_results.values():
308 if tmp[1] != account_prefix:
309 if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
310 tmp[0][0] - top_right_x) < max_dis:
311 tmp_value = tmp
312 max_dis = abs(tmp[0][0] - top_right_x)
313 else:
314 continue
# NOTE(review): same IndexError risk as above when tmp_value is still [].
315 new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
316 tmp_value[0][5],
317 value[0][6], value[0][7]]
318 results.append([value[1] + tmp_value[1], new_position])
319 break
320 else:
321 results.append([value[1], value[0]])
322 break
# NOTE(review): with indentation lost it is unclear which statement this
# `else` pairs with (`if is_account`, or `if account_prefix in value[1]`),
# and therefore which loop the `break` leaves -- TODO confirm.
323 else:
324 break
325 return results
326
327
# Entry point for bank-statement field extraction
def extract_bank_info(ocr_results):
    """Pick the bank-specific extractor from the OCR text and run it.

    The entries of ``ocr_results`` are scanned in order; the first entry whose
    text (``value[1]``) contains a known bank keyword selects the extractor.
    Within one entry the keywords are tried in the order listed below.  When
    no keyword matches, or the selected extractor returns nothing, the ICBC
    extractor is used as the fallback.

    Returns:
        The list of ``[text, position]`` pairs produced by the extractor.
    """
    extractors = (
        ('建设', extract_jianshe_info),
        ('民生', extract_minsheng_info),
        ('农业', extract_nongye_info),
        ('中国银行', extract_zhongguo_info),
        ('邮政', extract_youchu_info),
    )
    results = []
    for value in ocr_results.values():
        extractor = next((func for keyword, func in extractors if keyword in value[1]), None)
        if extractor is None:
            continue
        results = extractor(ocr_results)
        # Every bank stops at the first match.  The '邮政' branch used to fall
        # through, letting a later entry that merely mentions another bank
        # replace the result.
        break
    if len(results) == 0:
        results = extract_gongshang_info(ocr_results)

    return results
350
351
def extract_youchu_info(ocr_results):
    """Return ``[text, position]`` pairs for postal-savings statement fields.

    Name label: '户名:'.  Account labels tried in order: '账号:', '卡号:'.

    ``ocr_results`` is a mapping whose values are indexed as ``value[0]``
    (a position list with at least 8 numbers) and ``value[1]`` (the text).
    For each entry only the first matching label of each kind is used.  When
    the text is exactly the label, the value is taken from the nearest entry
    on the same line and the two boxes are merged.  If no such entry exists
    the bare label is returned (the original code raised IndexError there).
    """
    name_prefixes = ['户名:']
    account_prefixes = ['账号:', '卡号:']

    def nearest_on_same_line(label):
        # Closest entry to the label's right edge whose box[1] lies within
        # half the label height of the label's box[3]; first one wins ties.
        box = label[0]
        tolerance = abs(box[3] - box[5]) / 2
        nearest, max_dis = None, 999999
        for candidate in ocr_results.values():
            if candidate[1] == label[1]:
                continue
            if abs(candidate[0][1] - box[3]) >= tolerance:
                continue
            distance = abs(candidate[0][0] - box[2])
            if distance < max_dis:
                nearest, max_dis = candidate, distance
        return nearest

    def build_field(entry, prefixes):
        # [text, position] for the first prefix found in the entry, else None.
        text, box = entry[1], entry[0]
        for prefix in prefixes:
            if prefix not in text:
                continue
            if text != prefix:
                return [text, box]
            neighbour = nearest_on_same_line(entry)
            if neighbour is None:
                return [text, box]
            other = neighbour[0]
            merged = [box[0], box[1], other[2], other[3], other[4], other[5], box[6], box[7]]
            return [text + neighbour[1], merged]
        return None

    results = []
    for entry in ocr_results.values():
        for prefixes in (name_prefixes, account_prefixes):
            field = build_field(entry, prefixes)
            if field is not None:
                results.append(field)
    return results
402
403
if __name__ == '__main__':
    # Manual smoke test: OCR one postal-savings statement and print the
    # extracted fields.
    img = cv2.imread('/home/situ/下载/邮储对账单/飞书20221020-155202.jpg')
    ocr_results = bill_ocr(img)
    results = extract_youchu_info(ocr_results)
    print(results)

    # Batch variant kept for reference (disabled):
    # path = '/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/minsheng/authentic/images/val'
    # save_path = '/data/situ_invoice_bill_data/new_data/results'
    # bank = 'minsheng'
    # if not os.path.exists(os.path.join(save_path, bank)):
    #     os.makedirs(os.path.join(save_path, bank))
    # save_path = os.path.join(save_path, bank)
    # for j in tqdm.tqdm(os.listdir(path)):
    #     # if True:
    #     img = cv2.imread(os.path.join(path, j))
    #     # img = cv2.imread('/data/situ_invoice_bill_data/new_data/results/nongye/6/_1597382769.6449914page_23_img_0.jpg')
    #     st = time.time()
    #     ocr_result = bill_ocr(img)
    #     et1 = time.time()
    #     result = extract_bank_info(ocr_result)
    #     et2 = time.time()
    #     for i in range(len(result)):
    #         cv2.rectangle(img, (result[i][1][0], result[i][1][1]), (result[i][1][4], result[i][1][5]), (0, 0, 255), 2)
    #     # cv2.imshow('img', img)
    #     # cv2.waitKey(0)
    #     cv2.imwrite(os.path.join(save_path, j), img)
    #     print('spend:{} ocr:{} extract:{}'.format(et2 - st, et1 - st, et2 - et1))
1 import base64
2 import json
3 from bank_ocr_inference import *
4
5
def enlarge_position(box):
    """Grow an ``[x1, y1, x2, y2]`` box by 1/8 of its width and 1/3 of its height.

    The margin is added on every side using integer division.  The top-left
    corner is clamped at 0; the bottom-right corner is not clamped.
    """
    left, top, right, bottom = box
    margin_x = abs(right - left) // 8
    margin_y = abs(bottom - top) // 3
    return [
        max(left - margin_x, 0),
        max(top - margin_y, 0),
        right + margin_x,
        bottom + margin_y,
    ]
12
13
def path_base64(file_path):
    """Read the file at ``file_path`` and return its content as a base64 string.

    The file is opened in binary mode and closed before returning (the
    original version leaked the handle).
    """
    with open(file_path, 'rb') as f:
        raw = f.read()
    return base64.b64encode(raw).decode('utf-8')
19
20
def bgr_base64(image):
    """JPEG-encode ``image`` with OpenCV and return the bytes as a base64 string."""
    encoded = cv2.imencode('.jpg', image)[1]
    return base64.b64encode(encoded).decode('utf-8')
25
26
def base64_bgr(img64):
    """Decode a base64-encoded image into an array via ``cv2.imdecode``.

    The image is decoded with ``cv2.IMREAD_COLOR``.
    """
    raw = np.frombuffer(base64.b64decode(img64), np.uint8)
    return cv2.imdecode(raw, cv2.IMREAD_COLOR)
32
33
def tamper_detect_(image):
    """POST ``image`` to the tamper-detection service and return the parsed JSON.

    The image is sent as a base64 JPEG under the ``'img'`` key of a JSON body.
    """
    body = json.dumps({'img': bgr_base64(image)})
    response = requests.post(url=r'http://192.168.10.11:8009/tamper_det', data=body)
    return response.json()
39
40
if __name__ == '__main__':
    # End-to-end run: OCR -> field extraction -> tamper detection on each
    # enlarged field crop, then draw the detections on the full image.
    image = cv2.imread(
        '/data/situ_invoice_bill_data/银行流水样本/普通打印-部分格线-竖版-农业银行-8列/_1594626974.367834page_20_img_0.jpg')
    st = time.time()
    ocr_results = bill_ocr(image)
    et1 = time.time()
    info_results = extract_bank_info(ocr_results)
    et2 = time.time()

    tamper_results = []
    for info_result in info_results:
        position = info_result[1]
        # Corners 0/1 and 4/5 of the 8-value position give the crop rectangle.
        x1, y1, x2, y2 = enlarge_position([position[0], position[1], position[4], position[5]])
        results = tamper_detect_(image[y1:y2, x1:x2, :])
        print(results)
        for res in results['results']:
            cx, cy = int(res[0]), int(res[1])
            width, height = int(res[2]), int(res[3])
            left = cx - width // 2
            top = cy - height // 2
            # Shift the crop-relative centre/size box into full-image coordinates.
            tamper_results.append([x1 + left, y1 + top, x1 + left + width, y1 + top + height])
    et3 = time.time()
    print(tamper_results)

    print(f'all time:{et3 - st} ocr time:{et1 - st} extract info time:{et2 - et1} yolo time:{et3 - et2}')
    for rect in tamper_results:
        cv2.rectangle(image, tuple(rect[:2]), tuple(rect[2:]), (0, 0, 255), 2)
    cv2.imshow('info', image)
    cv2.waitKey(0)
1 import base64
2
3 import cv2
4 import numpy as np
5 from sanic import Sanic
6 from sanic.response import json
7 from yolov5_onnx_demo.model.yolov5_infer import *
8
9
def base64_to_bgr(bs64):
    """Decode a base64-encoded image into an array via ``cv2.imdecode``.

    Uses ``np.frombuffer`` instead of ``np.fromstring``, whose binary mode is
    deprecated in NumPy.

    Returns:
        The result of ``cv2.imdecode(..., cv2.IMREAD_COLOR)``.
    """
    img_data = base64.b64decode(bs64)
    img_arr = np.frombuffer(img_data, np.uint8)
    return cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
15
16
app = Sanic('tamper_det')


@app.post('/tamper_det')
def hello(request):
    """Run tamper detection on the base64 image posted under the ``'img'`` key.

    Responds with ``{'results': <grpc_detect output>}`` on success, or with
    status 400 and an ``'error'`` message when the body is not a JSON object
    containing ``'img'`` or the image cannot be decoded.  The payload is no
    longer printed: it is a full base64 image per request.
    """
    payload = request.json
    if not isinstance(payload, dict) or 'img' not in payload:
        return json({'error': "request body must be a JSON object with an 'img' field"}, status=400)
    try:
        img = base64_to_bgr(payload['img'])
    except (ValueError, TypeError):
        # base64 decoding failed (binascii.Error is a ValueError subclass).
        img = None
    if img is None:
        return json({'error': "'img' is not a decodable base64 image"}, status=400)
    result = grpc_detect(img)

    return json({'results': result})
28
29
if __name__ == '__main__':
    # Start the Sanic server on the fixed address used by the pipeline clients.
    app.run(
        host='192.168.10.11',
        port=8009,
        workers=10,
    )
1 import base64
2
3 import requests
4 import json
5 from yolov5_onnx_demo.model.yolov5_infer import *
6
def path_base64(file_path):
    """Read the file at ``file_path`` and return its content as a base64 string.

    The file is opened in binary mode and closed before returning (the
    original version leaked the handle).
    """
    with open(file_path, 'rb') as f:
        raw = f.read()
    return base64.b64encode(raw).decode('utf-8')
12
13
# Manual check of the /tamper_det service: post one training image and draw
# the returned boxes on it.
image_path = '/data/situ_invoice_bill_data/qfs_train_val_data/train_data/machine/minsheng/images/train/_1597386625.07514page_20_img_0_machine_name_full_splicing.jpg'
res = requests.post('http://192.168.10.11:8009/tamper_det', data=json.dumps(
    {'img': path_base64(image_path)}))
results = res.json()
img = cv2.imread(image_path)
print(res)
# The service answers with {'results': [...]}; the former 'keys' lookup raised KeyError.
plot_label(img, results['results'])
1 import cv2
2 import numpy as np
3 import tritonclient.grpc as grpcclient
4
5
def keep_resize_padding(image):
    '''
    Pad the image to a square with value 114, then resize it to 640x640.

    The model input must be exactly 640*640, while the official inference
    code resizes by the minimum scale ratio and therefore yields a variable
    input size; this resize adds padding so the result is always 640*640.

    Returns the resized image and ``[pad1, pad2, pad_type]``: the padding
    added before/after the image, with pad_type 0 when columns were added
    (h >= w) and 1 when rows were added.
    '''
    h, w, c = image.shape
    gap = abs(h - w)
    pad1 = gap // 2
    pad2 = gap - pad1
    if h >= w:
        stack, pad_type = np.hstack, 0
        before_shape, after_shape = (h, pad1, 3), (h, pad2, 3)
    else:
        stack, pad_type = np.vstack, 1
        before_shape, after_shape = (pad1, w, 3), (pad2, w, 3)
    before = (np.ones(before_shape) * 114.0).astype(np.uint8)
    after = (np.ones(after_shape) * 114.0).astype(np.uint8)
    squared = stack((before, image, after))
    return cv2.resize(squared, (640, 640)), [pad1, pad2, pad_type]
30
31
# remove padding
def extract_authentic_bboxes(image, padding_info, bboxes):
    '''
    Map boxes from the padded 640x640 input back to original image coordinates.

    The first four columns of every box are scaled by ``max(h, w) / 640``;
    then the leading padding ``pad1`` is subtracted from column 0 (when
    ``pad_type`` is 0) or from column 1 (otherwise).

    Args:
        image: Original image; only ``image.shape`` is read.
        padding_info: ``[pad1, pad2, pad_type]`` as returned by
            ``keep_resize_padding``.
        bboxes: Sequence or array of boxes with at least four columns.

    Returns:
        The converted boxes as a list of lists; ``[]`` for empty input.
    '''
    pad1, _pad2, pad_type = padding_info
    h, w, c = image.shape
    bboxes = np.array(bboxes)
    if bboxes.size == 0:
        # An empty Python list becomes a 1-D array, which the 2-D slicing
        # below cannot index.
        return []
    if bboxes.dtype.kind != 'f':
        # Integer input would silently truncate the scaled coordinates.
        bboxes = bboxes.astype(np.float64)
    scale = max(h, w) / 640
    bboxes[:, :4] = bboxes[:, :4] * scale
    if pad_type == 0:
        bboxes[:, 0] = bboxes[:, 0] - pad1
    else:
        bboxes[:, 1] = bboxes[:, 1] - pad1
    return bboxes.tolist()
48
49
# NMS
def py_nms_cpu(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections

    Rows are read as ``[cx, cy, w, h, conf, score, ...]``: column 4 is
    thresholded with ``conf_thres`` and column 5 ranks the surviving boxes.
    Boxes are visited from the highest score down (the original ascending
    sort kept the lowest-scoring box of every overlapping group), and a box
    is dropped when its IoU with an already kept box is at least
    ``iou_thres``.

    Returns:
        The kept rows, unchanged (still centre/size format), ordered by
        descending score.
    """
    prediction = prediction[prediction[..., 4] > conf_thres]  # candidates

    # centre/size -> corner coordinates
    half_w = prediction[..., 2] / 2
    half_h = prediction[..., 3] / 2
    x1 = prediction[..., 0] - half_w
    y1 = prediction[..., 1] - half_h
    x2 = prediction[..., 0] + half_w
    y2 = prediction[..., 1] + half_h

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # argsort is ascending; reverse it so the best box is handled first.
    order = np.argsort(prediction[..., 5])[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        rest = order[1:]

        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])

        inter = np.maximum(0, xx2 - xx1 + 1) * np.maximum(0, yy2 - yy1 + 1)
        over = inter / (areas[i] + areas[rest] - inter)

        order = rest[over < iou_thres]

    return prediction[np.array(keep, dtype=np.intp)]
92
93
def client_init(url='localhost:8001',
                ssl=False,
                private_key=None,
                root_certificates=None,
                certificate_chain=None,
                verbose=False):
    """Create and return a Triton gRPC ``InferenceServerClient``.

    Every argument is forwarded unchanged to the client constructor;
    ``verbose`` (default False) turns on detailed client output.
    """
    options = {
        'url': url,
        'verbose': verbose,
        'ssl': ssl,
        'root_certificates': root_certificates,
        'private_key': private_key,
        'certificate_chain': certificate_chain,
    }
    return grpcclient.InferenceServerClient(**options)
109
110
# Module-level Triton client and model settings read by grpc_detect().
# NOTE: client_init() is called here, i.e. when this module is imported.
111 triton_client = client_init('localhost:8001')
# Passed as compression_algorithm to triton_client.infer(); None here.
112 compression_algorithm = None
# Name of the input tensor passed to grpcclient.InferInput.
113 input_name = 'images'
# Name of the output tensor requested from, and read back out of, the response.
114 output_name = 'output0'
# Model name passed to triton_client.infer().
115 model_name = 'yolov5'
116
117
def grpc_detect(img):
    """Run the Triton-hosted model on ``img`` and return the detected boxes.

    The image is padded and resized to 640x640, converted to a float32
    channel-first batch scaled to [0, 1] and sent over gRPC.  The raw output
    goes through ``py_nms_cpu`` and is mapped back to the coordinates of
    ``img`` by ``extract_authentic_bboxes``.
    """
    padded, padding_info = keep_resize_padding(img)
    # HWC -> CHW, then reverse the channel axis (BGR -> RGB for images loaded
    # with cv2.imread).
    tensor = padded.transpose((2, 0, 1))[::-1].astype(np.float32) / 255.0
    if tensor.ndim == 3:
        tensor = tensor[None]

    # Dynamic input shape
    infer_input = grpcclient.InferInput(input_name, tensor.shape, 'FP32')
    infer_input.set_data_from_numpy(tensor.astype(np.float32))
    requested_output = grpcclient.InferRequestedOutput(output_name)

    response = triton_client.infer(
        model_name=model_name,
        inputs=[infer_input],
        outputs=[requested_output],
        compression_algorithm=compression_algorithm,
    )
    raw = response.as_numpy(output_name).copy()
    return extract_authentic_bboxes(img, padding_info, py_nms_cpu(raw))
144
145
def plot_label(img, result_bboxes):
    """Print the boxes, draw each one on ``img`` in red and show the image.

    Each box must unpack into six values ``(x, y, w, h, conf, cls)`` where
    ``x, y`` is the box centre.  Blocks in ``cv2.waitKey(0)`` until a key is
    pressed.
    """
    print(result_bboxes)
    for cx, cy, bw, bh, _conf, _cls in result_bboxes:
        top_left = (int(cx - bw // 2), int(cy - bh // 2))
        bottom_right = (int(cx + bw // 2), int(cy + bh // 2))
        cv2.rectangle(img, top_left, bottom_right, (0, 0, 255), 2)
    cv2.imshow('im', img)
    cv2.waitKey(0)
153
154
if __name__ == '__main__':
    # Manual check: detect on one validation image and display the boxes.
    img = cv2.imread(
        '/data/situ_invoice_bill_data/qfs_train_val_data/train_data/authentic/gongshang/images/val/_1594890232.0110397page_11_img_0_name_au_gongshang.jpg')

    result_bboxes = grpc_detect(img)
    # plot_label takes (img, result_bboxes); the image argument was missing,
    # which raised TypeError.
    plot_label(img, result_bboxes)
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!