ocr_yolo triton-inference-server
0 parents
Showing 11 changed files with 737 additions and 0 deletions
.gitignore
0 → 100644
OCR_Engine @ 3dddc11a
Subproject commit 3dddc11a8a1d369ca4fbd0b69e4e21e6af81cc4c
README.md
0 → 100644
## OCR + yolov5 triton-inference-server service

1. Start the Triton server with Docker

sudo docker run --gpus="device=0" --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v /home/situ/qfs/triton_inference_server/demo/model_repository:/models nvcr.io/nvidia/tritonserver:21.10-py3 tritonserver --model-repository=/models

2. Start the OCR and yolov5 web services separately

cd OCR_Engine/api
python ocr_engine_server.py

cd yolov5_onnx_demo/api
python yolov5_onnx_server.py

3. Run the pipeline test

python triton_pipeline.py
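A quick way to confirm that step 1 worked before starting the web services is to probe the Triton endpoints from Python. The snippet below is not part of the repository; it is a minimal sketch that assumes the defaults used elsewhere in this repo (gRPC exposed on localhost:8001, a model named yolov5) and requires the tritonclient[grpc] package.

import tritonclient.grpc as grpcclient

# Readiness probe for the Triton container started in step 1.
# Assumes gRPC on localhost:8001 and a deployed model named 'yolov5'.
client = grpcclient.InferenceServerClient(url='localhost:8001')
print('server live :', client.is_server_live())
print('server ready:', client.is_server_ready())
print('yolov5 ready:', client.is_model_ready('yolov5'))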
bank_ocr_inference.py
0 → 100644
import base64
import os
import time

import cv2
import numpy as np
import requests
import tqdm


def image_to_base64(image):
    # Note: despite the name, this returns the PNG-encoded bytes (not base64);
    # they are uploaded directly as a file in bill_ocr().
    image = cv2.imencode('.png', image)[1]
    return image


def path_to_file(file_path):
    f = open(file_path, 'rb')
    return f


# Bank-statement OCR endpoint
def bill_ocr(image):
    f = image_to_base64(image)
    resp = requests.post(url=r'http://192.168.10.11:9001/gen_ocr', files={'file': f})
    results = resp.json()
    ocr_results = results['ocr_results']
    return ocr_results


# Extract China Minsheng Bank information
def extract_minsheng_info(ocr_results):
    name_prefix = '客户姓名:'
    account_prefix = '客户账号:'
    results = []
    for value in ocr_results.values():
        if name_prefix in value[1]:
            if name_prefix == value[1]:
                tmp_value, max_dis = [], 999999
                top_right_x = value[0][2]
                top_right_y = value[0][3]
                for tmp in ocr_results.values():
                    if tmp[1] != name_prefix:
                        if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                tmp[0][0] - top_right_x) < max_dis:
                            tmp_value = tmp
                            max_dis = abs(tmp[0][0] - top_right_x)
                    else:
                        continue
                new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                tmp_value[0][5], value[0][6], value[0][7]]
                results.append([value[1] + tmp_value[1], new_position])
            else:
                results.append([value[1], value[0]])
        if account_prefix in value[1]:
            if account_prefix == value[1]:
                tmp_value, max_dis = [], 999999
                top_right_x = value[0][2]
                top_right_y = value[0][3]
                for tmp in ocr_results.values():
                    if tmp[1] != account_prefix:
                        if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                tmp[0][0] - top_right_x) < max_dis:
                            tmp_value = tmp
                            max_dis = abs(tmp[0][0] - top_right_x)
                    else:
                        continue
                new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                tmp_value[0][5], value[0][6], value[0][7]]
                results.append([value[1] + tmp_value[1], new_position])
            else:
                results.append([value[1], value[0]])
    return results


# Extract ICBC (Industrial and Commercial Bank of China) information
def extract_gongshang_info(ocr_results):
    name_prefix = '户名:'
    account_prefix = '卡号:'
    results = []
    for value in ocr_results.values():
        if name_prefix in value[1]:
            if name_prefix == value[1]:
                tmp_value, max_dis = [], 999999
                top_right_x = value[0][2]
                top_right_y = value[0][3]
                for tmp in ocr_results.values():
                    if tmp[1] != name_prefix:
                        if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                tmp[0][0] - top_right_x) < max_dis:
                            tmp_value = tmp
                            max_dis = abs(tmp[0][0] - top_right_x)
                    else:
                        continue
                new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                tmp_value[0][5], value[0][6], value[0][7]]
                results.append([value[1] + tmp_value[1], new_position])
            else:
                results.append([value[1], value[0]])
        if account_prefix in value[1]:
            if account_prefix == value[1]:
                tmp_value, max_dis = [], 999999
                top_right_x = value[0][2]
                top_right_y = value[0][3]
                for tmp in ocr_results.values():
                    if tmp[1] != account_prefix:
                        if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                tmp[0][0] - top_right_x) < max_dis:
                            tmp_value = tmp
                            max_dis = abs(tmp[0][0] - top_right_x)
                    else:
                        continue
                new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                tmp_value[0][5], value[0][6], value[0][7]]
                results.append([value[1] + tmp_value[1], new_position])
            else:
                results.append([value[1], value[0]])
    return results


# Extract Bank of China information
def extract_zhongguo_info(ocr_results):
    name_prefix = '客户姓名:'
    account_prefix = '借记卡号:'
    results = []
    for value in ocr_results.values():
        if name_prefix in value[1]:
            if name_prefix == value[1]:
                tmp_value, max_dis = [], 999999
                top_right_x = value[0][2]
                top_right_y = value[0][3]
                for tmp in ocr_results.values():
                    if tmp[1] != name_prefix:
                        if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                tmp[0][0] - top_right_x) < max_dis:
                            tmp_value = tmp
                            max_dis = abs(tmp[0][0] - top_right_x)
                    else:
                        continue
                new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                tmp_value[0][5], value[0][6], value[0][7]]
                results.append([value[1] + tmp_value[1], new_position])
            else:
                results.append([value[1], value[0]])
        if account_prefix in value[1]:
            if account_prefix == value[1]:
                tmp_value, max_dis = [], 999999
                top_right_x = value[0][2]
                top_right_y = value[0][3]
                for tmp in ocr_results.values():
                    if tmp[1] != account_prefix:
                        if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                tmp[0][0] - top_right_x) < max_dis:
                            tmp_value = tmp
                            max_dis = abs(tmp[0][0] - top_right_x)
                    else:
                        continue
                new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                tmp_value[0][5], value[0][6], value[0][7]]
                results.append([value[1] + tmp_value[1], new_position])
            else:
                results.append([value[1], value[0]])
    return results


# Extract China Construction Bank information
def extract_jianshe_info(ocr_results):
    name_prefixes = ['客户名称:', '户名:']
    account_prefixes = ['卡号/账号:', '卡号:']
    results = []
    for value in ocr_results.values():
        for name_prefix in name_prefixes:
            if name_prefix in value[1]:
                if name_prefix == value[1]:
                    tmp_value, max_dis = [], 999999
                    top_right_x = value[0][2]
                    top_right_y = value[0][3]
                    for tmp in ocr_results.values():
                        if tmp[1] != name_prefix:
                            if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                    tmp[0][0] - top_right_x) < max_dis:
                                tmp_value = tmp
                                max_dis = abs(tmp[0][0] - top_right_x)
                        else:
                            continue
                    new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                    tmp_value[0][5], value[0][6], value[0][7]]
                    results.append([value[1] + tmp_value[1], new_position])
                    break
                else:
                    results.append([value[1], value[0]])
                    break
        for account_prefix in account_prefixes:
            if account_prefix in value[1]:
                if account_prefix == value[1]:
                    tmp_value, max_dis = [], 999999
                    top_right_x = value[0][2]
                    top_right_y = value[0][3]
                    for tmp in ocr_results.values():
                        if tmp[1] != account_prefix:
                            if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                    tmp[0][0] - top_right_x) < max_dis:
                                tmp_value = tmp
                                max_dis = abs(tmp[0][0] - top_right_x)
                        else:
                            continue
                    new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                    tmp_value[0][5], value[0][6], value[0][7]]
                    results.append([value[1] + tmp_value[1], new_position])
                    break
                else:
                    results.append([value[1], value[0]])
                    break
    return results


# Extract Agricultural Bank of China information (fairly complex; all layouts covered in training are supported)
def extract_nongye_info(ocr_results):
    name_prefixes = ['客户名:', '户名:']
    account_prefixes = ['账号:']
    results = []
    is_account = True
    for value in ocr_results.values():
        for name_prefix in name_prefixes:
            if name_prefix in value[1] and account_prefixes[0][:-1] not in value[1]:
                if name_prefix == value[1]:
                    tmp_value, max_dis = [], 999999
                    top_right_x = value[0][2]
                    top_right_y = value[0][3]
                    for tmp in ocr_results.values():
                        if tmp[1] != name_prefix:
                            if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                    tmp[0][0] - top_right_x) < max_dis:
                                tmp_value = tmp
                                max_dis = abs(tmp[0][0] - top_right_x)
                        else:
                            continue
                    new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                    tmp_value[0][5], value[0][6], value[0][7]]
                    results.append([value[1] + tmp_value[1], new_position])
                    break
                else:
                    results.append([value[1], value[0]])
                    break
            if name_prefix in value[1] and account_prefixes[0][:-1] in value[1] and len(value[1].split(":")[0]) <= 5:
                is_account = False
                if len(value[1]) == 5:
                    tmp_value, max_dis = [], 999999
                    top_right_x = value[0][2]
                    top_right_y = value[0][3]
                    tmp_info = {}
                    for tmp in ocr_results.values():
                        if tmp[1] != value[1]:
                            if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2:
                                tmp_info[abs(tmp[0][0] - top_right_x)] = tmp
                        else:
                            continue
                    tmp_info_id = sorted(tmp_info.keys())
                    if not tmp_info[tmp_info_id[0]][1].isdigit() and len(tmp_info[tmp_info_id[0]][1]) > 19:
                        tmp_value = tmp_info[tmp_info_id[0]]
                        new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                        tmp_value[0][5], value[0][6], value[0][7]]
                        results.append([value[1] + tmp_value[1], new_position])
                    if tmp_info[tmp_info_id[0]][1].isdigit():
                        tmp_value = tmp_info[tmp_info_id[1]]
                        new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                        tmp_value[0][5], value[0][6], value[0][7]]
                        results.append([value[1] + tmp_value[1], new_position])
                    break
                elif len(value[1]) < 25:
                    tmp_info = {}
                    top_right_x = value[0][2]
                    top_right_y = value[0][3]
                    for tmp in ocr_results.values():
                        if tmp[1] != value[1]:
                            if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2:
                                tmp_info[abs(tmp[0][0] - top_right_x)] = tmp
                        else:
                            continue
                    tmp_info_id = sorted(tmp_info.keys())
                    tmp_value = tmp_info[tmp_info_id[0]]
                    new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                    tmp_value[0][5], value[0][6], value[0][7]]
                    results.append([value[1] + tmp_value[1], new_position])
                    break
                else:
                    results.append([value[1], value[0]])
                    break
        if is_account:
            for account_prefix in account_prefixes:
                if account_prefix in value[1]:
                    if account_prefix == value[1]:
                        tmp_value, max_dis = [], 999999
                        top_right_x = value[0][2]
                        top_right_y = value[0][3]
                        for tmp in ocr_results.values():
                            if tmp[1] != account_prefix:
                                if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                        tmp[0][0] - top_right_x) < max_dis:
                                    tmp_value = tmp
                                    max_dis = abs(tmp[0][0] - top_right_x)
                            else:
                                continue
                        new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                        tmp_value[0][5], value[0][6], value[0][7]]
                        results.append([value[1] + tmp_value[1], new_position])
                        break
                    else:
                        results.append([value[1], value[0]])
                        break
        else:
            break
    return results


# Top-level entry point for extracting bank-statement information
def extract_bank_info(ocr_results):
    results = []
    for value in ocr_results.values():
        if value[1].__contains__('建设'):
            results = extract_jianshe_info(ocr_results)
            break
        elif value[1].__contains__('民生'):
            results = extract_minsheng_info(ocr_results)
            break
        elif value[1].__contains__('农业'):
            results = extract_nongye_info(ocr_results)
            break
        elif value[1].__contains__('中国银行'):
            results = extract_zhongguo_info(ocr_results)
            break
        elif value[1].__contains__('邮政'):
            results = extract_youchu_info(ocr_results)
    if len(results) == 0:
        results = extract_gongshang_info(ocr_results)

    return results


# Extract Postal Savings Bank of China information
def extract_youchu_info(ocr_results):
    name_prefixes = ['户名:']
    account_prefixes = ['账号:', '卡号:']
    results = []
    for value in ocr_results.values():
        for name_prefix in name_prefixes:
            if name_prefix in value[1]:
                if name_prefix == value[1]:
                    tmp_value, max_dis = [], 999999
                    top_right_x = value[0][2]
                    top_right_y = value[0][3]
                    for tmp in ocr_results.values():
                        if tmp[1] != name_prefix:
                            if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                    tmp[0][0] - top_right_x) < max_dis:
                                tmp_value = tmp
                                max_dis = abs(tmp[0][0] - top_right_x)
                        else:
                            continue
                    new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                    tmp_value[0][5], value[0][6], value[0][7]]
                    results.append([value[1] + tmp_value[1], new_position])
                    break
                else:
                    results.append([value[1], value[0]])
                    break
        for account_prefix in account_prefixes:
            if account_prefix in value[1]:
                if account_prefix == value[1]:
                    tmp_value, max_dis = [], 999999
                    top_right_x = value[0][2]
                    top_right_y = value[0][3]
                    for tmp in ocr_results.values():
                        if tmp[1] != account_prefix:
                            if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                    tmp[0][0] - top_right_x) < max_dis:
                                tmp_value = tmp
                                max_dis = abs(tmp[0][0] - top_right_x)
                        else:
                            continue
                    new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                    tmp_value[0][5], value[0][6], value[0][7]]
                    results.append([value[1] + tmp_value[1], new_position])
                    break
                else:
                    results.append([value[1], value[0]])
                    break
    return results


if __name__ == '__main__':
    img = cv2.imread('/home/situ/下载/邮储对账单/飞书20221020-155202.jpg')
    ocr_results = bill_ocr(img)
    results = extract_youchu_info(ocr_results)
    print(results)
    # path = '/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/minsheng/authentic/images/val'
    # save_path = '/data/situ_invoice_bill_data/new_data/results'
    # bank = 'minsheng'
    # if not os.path.exists(os.path.join(save_path, bank)):
    #     os.makedirs(os.path.join(save_path, bank))
    # save_path = os.path.join(save_path, bank)
    # for j in tqdm.tqdm(os.listdir(path)):
    #     # if True:
    #     img = cv2.imread(os.path.join(path, j))
    #     # img = cv2.imread('/data/situ_invoice_bill_data/new_data/results/nongye/6/_1597382769.6449914page_23_img_0.jpg')
    #     st = time.time()
    #     ocr_result = bill_ocr(img)
    #     et1 = time.time()
    #     result = extract_bank_info(ocr_result)
    #     et2 = time.time()
    #     for i in range(len(result)):
    #         cv2.rectangle(img, (result[i][1][0], result[i][1][1]), (result[i][1][4], result[i][1][5]), (0, 0, 255), 2)
    #     # cv2.imshow('img', img)
    #     # cv2.waitKey(0)
    #     cv2.imwrite(os.path.join(save_path, j), img)
    #     print('spend:{} ocr:{} extract:{}'.format(et2 - st, et1 - st, et2 - et1))
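For reference, the extract_*_info helpers above index deeply into ocr_results, and the OCR service that produces it is not part of this diff. The sketch below shows the structure they appear to assume: a dict keyed by text-line id, each entry holding an eight-value quadrilateral (top-left, top-right, bottom-right, bottom-left x/y) followed by the recognized text. The layout is inferred from the indexing in the code and the sample values are invented, so treat it as an assumption rather than the service's documented contract.

from bank_ocr_inference import extract_minsheng_info

# Assumed shape of the /gen_ocr response payload (inferred from the code above).
# Each entry: key -> [[x1, y1, x2, y2, x3, y3, x4, y4], recognized_text]
ocr_results = {
    '0': [[110, 52, 260, 52, 260, 84, 110, 84], '客户姓名:'],
    '1': [[275, 52, 390, 52, 390, 84, 275, 84], '张三'],   # hypothetical value to the right of the label
}

# A bare label is merged with the nearest text block on the same row to its right.
print(extract_minsheng_info(ocr_results))
# -> [['客户姓名:张三', [110, 52, 390, 52, 390, 84, 110, 84]]]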
triton_pipeline.py
0 → 100644
import base64
import json
from bank_ocr_inference import *


def enlarge_position(box):
    x1, y1, x2, y2 = box
    w, h = abs(x2 - x1), abs(y2 - y1)
    y1, y2 = max(y1 - h // 3, 0), y2 + h // 3
    x1, x2 = max(x1 - w // 8, 0), x2 + w // 8
    return [x1, y1, x2, y2]


def path_base64(file_path):
    f = open(file_path, 'rb')
    file64 = base64.b64encode(f.read())  # base64-encoded bytes of the file
    file64 = file64.decode('utf-8')
    return file64


def bgr_base64(image):
    _, img64 = cv2.imencode('.jpg', image)
    img64 = base64.b64encode(img64)
    return img64.decode('utf-8')


def base64_bgr(img64):
    str_img64 = base64.b64decode(img64)
    image = np.frombuffer(str_img64, np.uint8)
    image = cv2.imdecode(image, cv2.IMREAD_COLOR)
    return image


def tamper_detect_(image):
    img64 = bgr_base64(image)
    resp = requests.post(url=r'http://192.168.10.11:8009/tamper_det', data=json.dumps({'img': img64}))
    results = resp.json()
    return results


if __name__ == '__main__':
    image = cv2.imread(
        '/data/situ_invoice_bill_data/银行流水样本/普通打印-部分格线-竖版-农业银行-8列/_1594626974.367834page_20_img_0.jpg')
    st = time.time()
    ocr_results = bill_ocr(image)
    et1 = time.time()
    info_results = extract_bank_info(ocr_results)
    et2 = time.time()
    tamper_results = []
    if len(info_results) != 0:
        for info_result in info_results:
            box = [info_result[1][0], info_result[1][1], info_result[1][4], info_result[1][5]]
            x1, y1, x2, y2 = enlarge_position(box)
            # x1, y1, x2, y2 = box
            info_image = image[y1:y2, x1:x2, :]
            results = tamper_detect_(info_image)
            print(results)
            if len(results['results']) != 0:
                for res in results['results']:
                    cx = int(res[0])
                    cy = int(res[1])
                    width = int(res[2])
                    height = int(res[3])
                    left = cx - width // 2
                    top = cy - height // 2
                    absolute_position = [x1 + left, y1 + top, x1 + left + width, y1 + top + height]
                    # absolute_position = [x1+left, y1+top, x2, y2]
                    tamper_results.append(absolute_position)
    et3 = time.time()
    print(tamper_results)

    print(f'all time:{et3 - st} ocr time:{et1 - st} extract info time:{et2 - et1} yolo time:{et3 - et2}')
    for i in tamper_results:
        cv2.rectangle(image, tuple(i[:2]), tuple(i[2:]), (0, 0, 255), 2)
    cv2.imshow('info', image)
    cv2.waitKey(0)
yolov5_onnx_demo/api/yolov5_onnx_server.py
0 → 100644
import base64

import cv2
import numpy as np
from sanic import Sanic
from sanic.response import json
from yolov5_onnx_demo.model.yolov5_infer import *


def base64_to_bgr(bs64):
    img_data = base64.b64decode(bs64)
    img_arr = np.frombuffer(img_data, np.uint8)  # frombuffer replaces the deprecated np.fromstring
    img_np = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
    return img_np


app = Sanic('tamper_det')


@app.post('/tamper_det')
def hello(request):
    d = request.json
    print(d['img'])
    img = base64_to_bgr(d['img'])
    result = grpc_detect(img)

    return json({'results': result})


if __name__ == '__main__':
    app.run(host='192.168.10.11', port=8009, workers=10)
yolov5_onnx_demo/api_test.py
0 → 100644
import base64

import requests
import json
from yolov5_onnx_demo.model.yolov5_infer import *


def path_base64(file_path):
    f = open(file_path, 'rb')
    file64 = base64.b64encode(f.read())  # base64-encoded bytes of the file
    file64 = file64.decode('utf-8')
    return file64


res = requests.post('http://192.168.10.11:8009/tamper_det', data=json.dumps(
    {'img': path_base64('/data/situ_invoice_bill_data/qfs_train_val_data/train_data/machine/minsheng/images/train/_1597386625.07514page_20_img_0_machine_name_full_splicing.jpg')}))
results = res.json()
img = cv2.imread(
    '/data/situ_invoice_bill_data/qfs_train_val_data/train_data/machine/minsheng/images/train/_1597386625.07514page_20_img_0_machine_name_full_splicing.jpg')
print(res)
plot_label(img, results['results'])  # the server returns its detections under the 'results' key
yolov5_onnx_demo/model/__init__.py
0 → 100644
File mode changed
yolov5_onnx_demo/model/yolov5_infer.py
0 → 100644
import cv2
import numpy as np
import tritonclient.grpc as grpcclient


def keep_resize_padding(image):
    '''
    The model expects a fixed 640*640 input, but the official inference code uses a
    minimum-scale-ratio resize for speed, which produces inputs of varying size.
    This reimplements resize: pad the image to a square first, then resize to 640*640.
    '''
    h, w, c = image.shape
    if h >= w:
        pad1 = (h - w) // 2
        pad2 = h - w - pad1
        p1 = np.ones((h, pad1, 3)) * 114.0
        p2 = np.ones((h, pad2, 3)) * 114.0
        p1, p2 = p1.astype(np.uint8), p2.astype(np.uint8)
        new_image = np.hstack((p1, image, p2))
        padding_info = [pad1, pad2, 0]
    else:
        pad1 = (w - h) // 2
        pad2 = w - h - pad1
        p1 = np.ones((pad1, w, 3)) * 114.0
        p2 = np.ones((pad2, w, 3)) * 114.0
        p1, p2 = p1.astype(np.uint8), p2.astype(np.uint8)
        new_image = np.vstack((p1, image, p2))
        padding_info = [pad1, pad2, 1]
    new_image = cv2.resize(new_image, (640, 640))
    return new_image, padding_info


# remove padding
def extract_authentic_bboxes(image, padding_info, bboxes):
    '''
    Map the predicted boxes back to coordinates in the original image.
    '''
    pad1, pad2, pad_type = padding_info
    h, w, c = image.shape
    bboxes = np.array(bboxes)
    max_slide = max(h, w)
    scale = max_slide / 640
    bboxes[:, :4] = bboxes[:, :4] * scale
    if pad_type == 0:
        bboxes[:, 0] = bboxes[:, 0] - pad1
    else:
        bboxes[:, 1] = bboxes[:, 1] - pad1
    return bboxes.tolist()


# NMS
def py_nms_cpu(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """
    xc = prediction[..., 4] > conf_thres  # candidates
    prediction = prediction[xc]

    # NMS
    x1 = prediction[..., 0] - prediction[..., 2] / 2
    y1 = prediction[..., 1] - prediction[..., 3] / 2
    x2 = prediction[..., 0] + prediction[..., 2] / 2
    y2 = prediction[..., 1] + prediction[..., 3] / 2

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    score = prediction[..., 5]
    order = np.argsort(score)[::-1]  # process boxes from highest to lowest score
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)

        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        ww, hh = np.maximum(0, xx2 - xx1 + 1), np.maximum(0, yy2 - yy1 + 1)
        inter = ww * hh

        over = inter / (areas[i] + areas[order[1:]] - inter)

        idx = np.where(over < iou_thres)[0]
        order = order[idx + 1]

    return prediction[keep]


def client_init(url='localhost:8001',
                ssl=False,
                private_key=None,
                root_certificates=None,
                certificate_chain=None,
                verbose=False):
    triton_client = grpcclient.InferenceServerClient(
        url=url,
        verbose=verbose,  # verbose logging, defaults to False
        ssl=ssl,
        root_certificates=root_certificates,
        private_key=private_key,
        certificate_chain=certificate_chain,
    )
    return triton_client


triton_client = client_init('localhost:8001')
compression_algorithm = None
input_name = 'images'
output_name = 'output0'
model_name = 'yolov5'


def grpc_detect(img):
    image, padding_info = keep_resize_padding(img)
    image = image.transpose((2, 0, 1))[::-1]  # HWC -> CHW, then reverse channels (BGR -> RGB)
    image = image.astype(np.float32)
    image = image / 255.0
    if len(image.shape) == 3:
        image = image[None]

    outputs, inputs = [], []

    # dynamic input shape
    input_shape = image.shape
    inputs.append(grpcclient.InferInput(input_name, input_shape, 'FP32'))
    outputs.append(grpcclient.InferRequestedOutput(output_name))

    inputs[0].set_data_from_numpy(image.astype(np.float32))

    pred = triton_client.infer(
        model_name=model_name,
        inputs=inputs, outputs=outputs,
        compression_algorithm=compression_algorithm
    )
    pred = pred.as_numpy(output_name).copy()
    result_bboxes = py_nms_cpu(pred)
    result_bboxes = extract_authentic_bboxes(img, padding_info, result_bboxes)
    return result_bboxes


def plot_label(img, result_bboxes):
    print(result_bboxes)
    for bbox in result_bboxes:
        x, y, w, h, conf, cls = bbox
        cv2.rectangle(img, (int(x - w // 2), int(y - h // 2)), (int(x + w // 2), int(y + h // 2)), (0, 0, 255), 2)
    cv2.imshow('im', img)
    cv2.waitKey(0)


if __name__ == '__main__':
    img = cv2.imread(
        '/data/situ_invoice_bill_data/qfs_train_val_data/train_data/authentic/gongshang/images/val/_1594890232.0110397page_11_img_0_name_au_gongshang.jpg')

    result_bboxes = grpc_detect(img)
    plot_label(img, result_bboxes)
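As a side note on the coordinate handling in yolov5_infer.py: keep_resize_padding pads the image to a square before resizing to 640*640, and extract_authentic_bboxes undoes that by rescaling and subtracting the padding offset. The following self-contained sketch mirrors that arithmetic on an invented image size and box, purely to illustrate the round trip; it deliberately does not import the module, which opens a gRPC connection at import time.

import numpy as np

# Hypothetical original image: 400 (h) x 200 (w). keep_resize_padding pads 100 px on
# each side (pad_type 0) before resizing the 400x400 square to 640x640.
h, w = 400, 200
pad1 = (h - w) // 2          # 100
scale = max(h, w) / 640      # the factor used by extract_authentic_bboxes

# An invented detection in the 640x640 network input: [cx, cy, w, h, conf, cls]
box_640 = np.array([[240.0, 320.0, 64.0, 96.0, 0.9, 0.0]])

box_orig = box_640.copy()
box_orig[:, :4] *= scale     # back to the padded-square coordinate system
box_orig[:, 0] -= pad1       # remove the horizontal padding offset (pad_type 0)

print(box_orig[0][:4])       # -> [ 50. 200.  40.  60.], i.e. centre (50, 200) in the original image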