add pipeline inference
Showing 6 changed files with 454 additions and 26 deletions
bank_ocr_inference.py
0 → 100644
| 1 | import base64 | ||
| 2 | import os | ||
| 3 | import time | ||
| 4 | |||
| 5 | import cv2 | ||
| 6 | import numpy as np | ||
| 7 | import requests | ||
| 8 | import tqdm | ||
| 9 | |||
| 10 | |||
| 11 | def image_to_base64(image): | ||
| 12 | image = cv2.imencode('.png', image)[1]  # note: despite the name, this returns PNG-encoded bytes, not a base64 string | ||
| 13 | return image | ||
| 14 | |||
| 15 | |||
| 16 | def path_to_file(file_path): | ||
| 17 | f = open(file_path, 'rb') | ||
| 18 | return f | ||
| 19 | |||
| 20 | |||
| 21 | def bill_ocr(image): | ||
| 22 | f = image_to_base64(image) | ||
| 23 | resp = requests.post(url=r'http://139.196.149.46:9001/gen_ocr', files={'file': f}) | ||
| 24 | results = resp.json() | ||
| 25 | ocr_results = results['ocr_results'] | ||
| 26 | return ocr_results | ||
| 27 | |||
| 28 | |||
| 29 | def extract_minsheng_info(ocr_results): | ||
| 30 | name_prefix = '客户姓名:' | ||
| 31 | account_prefix = '客户账号:' | ||
| 32 | results = [] | ||
| 33 | for value in ocr_results.values(): | ||
| 34 | if name_prefix in value[1]: | ||
| 35 | if name_prefix == value[1]: | ||
| 36 | tmp_value, max_dis = [], 999999 | ||
| 37 | top_right_x = value[0][2] | ||
| 38 | top_right_y = value[0][3] | ||
| 39 | for tmp in ocr_results.values(): | ||
| 40 | if tmp[1] != name_prefix: | ||
| 41 | if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs( | ||
| 42 | tmp[0][0] - top_right_x) < max_dis: | ||
| 43 | tmp_value = tmp | ||
| 44 | max_dis = abs(tmp[0][0] - top_right_x) | ||
| 45 | else: | ||
| 46 | continue | ||
| 47 | new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4], | ||
| 48 | tmp_value[0][5], | ||
| 49 | value[0][6], value[0][7]] | ||
| 50 | results.append([value[1] + tmp_value[1], new_position]) | ||
| 51 | else: | ||
| 52 | results.append([value[1], value[0]]) | ||
| 53 | if account_prefix in value[1]: | ||
| 54 | if account_prefix == value[1]: | ||
| 55 | tmp_value, max_dis = [], 999999 | ||
| 56 | top_right_x = value[0][2] | ||
| 57 | top_right_y = value[0][3] | ||
| 58 | for tmp in ocr_results.values(): | ||
| 59 | if tmp[1] != account_prefix: | ||
| 60 | if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs( | ||
| 61 | tmp[0][0] - top_right_x) < max_dis: | ||
| 62 | tmp_value = tmp | ||
| 63 | max_dis = abs(tmp[0][0] - top_right_x) | ||
| 64 | else: | ||
| 65 | continue | ||
| 66 | new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4], | ||
| 67 | tmp_value[0][5], | ||
| 68 | value[0][6], value[0][7]] | ||
| 69 | results.append([value[1] + tmp_value[1], new_position]) | ||
| 70 | else: | ||
| 71 | results.append([value[1], value[0]]) | ||
| 72 | return results | ||
| 73 | |||
| 74 | |||
| 75 | def extract_gongshang_info(ocr_results): | ||
| 76 | name_prefix = '户名:' | ||
| 77 | account_prefix = '卡号:' | ||
| 78 | results = [] | ||
| 79 | for value in ocr_results.values(): | ||
| 80 | if name_prefix in value[1]: | ||
| 81 | if name_prefix == value[1]: | ||
| 82 | tmp_value, max_dis = [], 999999 | ||
| 83 | top_right_x = value[0][2] | ||
| 84 | top_right_y = value[0][3] | ||
| 85 | for tmp in ocr_results.values(): | ||
| 86 | if tmp[1] != name_prefix: | ||
| 87 | if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs( | ||
| 88 | tmp[0][0] - top_right_x) < max_dis: | ||
| 89 | tmp_value = tmp | ||
| 90 | max_dis = abs(tmp[0][0] - top_right_x) | ||
| 91 | else: | ||
| 92 | continue | ||
| 93 | new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4], | ||
| 94 | tmp_value[0][5], | ||
| 95 | value[0][6], value[0][7]] | ||
| 96 | results.append([value[1] + tmp_value[1], new_position]) | ||
| 97 | else: | ||
| 98 | results.append([value[1], value[0]]) | ||
| 99 | if account_prefix in value[1]: | ||
| 100 | if account_prefix == value[1]: | ||
| 101 | tmp_value, max_dis = [], 999999 | ||
| 102 | top_right_x = value[0][2] | ||
| 103 | top_right_y = value[0][3] | ||
| 104 | for tmp in ocr_results.values(): | ||
| 105 | if tmp[1] != account_prefix: | ||
| 106 | if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs( | ||
| 107 | tmp[0][0] - top_right_x) < max_dis: | ||
| 108 | tmp_value = tmp | ||
| 109 | max_dis = abs(tmp[0][0] - top_right_x) | ||
| 110 | else: | ||
| 111 | continue | ||
| 112 | new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4], | ||
| 113 | tmp_value[0][5], | ||
| 114 | value[0][6], value[0][7]] | ||
| 115 | results.append([value[1] + tmp_value[1], new_position]) | ||
| 116 | else: | ||
| 117 | results.append([value[1], value[0]]) | ||
| 118 | return results | ||
| 119 | |||
| 120 | |||
| 121 | def extract_zhongguo_info(ocr_results): | ||
| 122 | name_prefix = '客户姓名:' | ||
| 123 | account_prefix = '借记卡号:' | ||
| 124 | results = [] | ||
| 125 | for value in ocr_results.values(): | ||
| 126 | if name_prefix in value[1]: | ||
| 127 | if name_prefix == value[1]: | ||
| 128 | tmp_value, max_dis = [], 999999 | ||
| 129 | top_right_x = value[0][2] | ||
| 130 | top_right_y = value[0][3] | ||
| 131 | for tmp in ocr_results.values(): | ||
| 132 | if tmp[1] != name_prefix: | ||
| 133 | if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs( | ||
| 134 | tmp[0][0] - top_right_x) < max_dis: | ||
| 135 | tmp_value = tmp | ||
| 136 | max_dis = abs(tmp[0][0] - top_right_x) | ||
| 137 | else: | ||
| 138 | continue | ||
| 139 | new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4], | ||
| 140 | tmp_value[0][5], | ||
| 141 | value[0][6], value[0][7]] | ||
| 142 | results.append([value[1] + tmp_value[1], new_position]) | ||
| 143 | else: | ||
| 144 | results.append([value[1], value[0]]) | ||
| 145 | if account_prefix in value[1]: | ||
| 146 | if account_prefix == value[1]: | ||
| 147 | tmp_value, max_dis = [], 999999 | ||
| 148 | top_right_x = value[0][2] | ||
| 149 | top_right_y = value[0][3] | ||
| 150 | for tmp in ocr_results.values(): | ||
| 151 | if tmp[1] != account_prefix: | ||
| 152 | if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs( | ||
| 153 | tmp[0][0] - top_right_x) < max_dis: | ||
| 154 | tmp_value = tmp | ||
| 155 | max_dis = abs(tmp[0][0] - top_right_x) | ||
| 156 | else: | ||
| 157 | continue | ||
| 158 | new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4], | ||
| 159 | tmp_value[0][5], | ||
| 160 | value[0][6], value[0][7]] | ||
| 161 | results.append([value[1] + tmp_value[1], new_position]) | ||
| 162 | else: | ||
| 163 | results.append([value[1], value[0]]) | ||
| 164 | return results | ||
| 165 | |||
| 166 | |||
| 167 | def extract_jianshe_info(ocr_results): | ||
| 168 | name_prefixes = ['客户名称:', '户名:'] | ||
| 169 | account_prefixes = ['卡号/账号:', '卡号:'] | ||
| 170 | results = [] | ||
| 171 | for value in ocr_results.values(): | ||
| 172 | for name_prefix in name_prefixes: | ||
| 173 | if name_prefix in value[1]: | ||
| 174 | if name_prefix == value[1]: | ||
| 175 | tmp_value, max_dis = [], 999999 | ||
| 176 | top_right_x = value[0][2] | ||
| 177 | top_right_y = value[0][3] | ||
| 178 | for tmp in ocr_results.values(): | ||
| 179 | if tmp[1] != name_prefix: | ||
| 180 | if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs( | ||
| 181 | tmp[0][0] - top_right_x) < max_dis: | ||
| 182 | tmp_value = tmp | ||
| 183 | max_dis = abs(tmp[0][0] - top_right_x) | ||
| 184 | else: | ||
| 185 | continue | ||
| 186 | new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4], | ||
| 187 | tmp_value[0][5], | ||
| 188 | value[0][6], value[0][7]] | ||
| 189 | results.append([value[1] + tmp_value[1], new_position]) | ||
| 190 | break | ||
| 191 | else: | ||
| 192 | results.append([value[1], value[0]]) | ||
| 193 | break | ||
| 194 | for account_prefix in account_prefixes: | ||
| 195 | if account_prefix in value[1]: | ||
| 196 | if account_prefix == value[1]: | ||
| 197 | tmp_value, max_dis = [], 999999 | ||
| 198 | top_right_x = value[0][2] | ||
| 199 | top_right_y = value[0][3] | ||
| 200 | for tmp in ocr_results.values(): | ||
| 201 | if tmp[1] != account_prefix: | ||
| 202 | if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs( | ||
| 203 | tmp[0][0] - top_right_x) < max_dis: | ||
| 204 | tmp_value = tmp | ||
| 205 | max_dis = abs(tmp[0][0] - top_right_x) | ||
| 206 | else: | ||
| 207 | continue | ||
| 208 | new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4], | ||
| 209 | tmp_value[0][5], | ||
| 210 | value[0][6], value[0][7]] | ||
| 211 | results.append([value[1] + tmp_value[1], new_position]) | ||
| 212 | break | ||
| 213 | else: | ||
| 214 | results.append([value[1], value[0]]) | ||
| 215 | break | ||
| 216 | return results | ||
| 217 | |||
| 218 | |||
| 219 | def extract_nongye_info(ocr_results): | ||
| 220 | name_prefixes = ['客户名:', '户名:'] | ||
| 221 | account_prefixes = ['账号:'] | ||
| 222 | results = [] | ||
| 223 | is_account = True | ||
| 224 | for value in ocr_results.values(): | ||
| 225 | for name_prefix in name_prefixes: | ||
| 226 | if name_prefix in value[1] and account_prefixes[0][:-1] not in value[1]: | ||
| 227 | if name_prefix == value[1]: | ||
| 228 | tmp_value, max_dis = [], 999999 | ||
| 229 | top_right_x = value[0][2] | ||
| 230 | top_right_y = value[0][3] | ||
| 231 | for tmp in ocr_results.values(): | ||
| 232 | if tmp[1] != name_prefix: | ||
| 233 | if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs( | ||
| 234 | tmp[0][0] - top_right_x) < max_dis: | ||
| 235 | tmp_value = tmp | ||
| 236 | max_dis = abs(tmp[0][0] - top_right_x) | ||
| 237 | else: | ||
| 238 | continue | ||
| 239 | new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4], | ||
| 240 | tmp_value[0][5], | ||
| 241 | value[0][6], value[0][7]] | ||
| 242 | results.append([value[1] + tmp_value[1], new_position]) | ||
| 243 | break | ||
| 244 | else: | ||
| 245 | results.append([value[1], value[0]]) | ||
| 246 | break | ||
| 247 | if name_prefix in value[1] and account_prefixes[0][:-1] in value[1] and len(value[1].split(":")[0]) <= 5: | ||
| 248 | is_account = False | ||
| 249 | if len(value[1]) == 5: | ||
| 250 | tmp_value, max_dis = [], 999999 | ||
| 251 | top_right_x = value[0][2] | ||
| 252 | top_right_y = value[0][3] | ||
| 253 | tmp_info = {} | ||
| 254 | for tmp in ocr_results.values(): | ||
| 255 | if tmp[1] != value[1]: | ||
| 256 | if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2: | ||
| 257 | tmp_info[abs(tmp[0][0] - top_right_x)] = tmp | ||
| 258 | else: | ||
| 259 | continue | ||
| 260 | tmp_info_id = sorted(tmp_info.keys()) | ||
| 261 | if not tmp_info[tmp_info_id[0]][1].isdigit() and len(tmp_info[tmp_info_id[0]][1]) > 19: | ||
| 262 | tmp_value = tmp_info[tmp_info_id[0]] | ||
| 263 | new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4], | ||
| 264 | tmp_value[0][5], | ||
| 265 | value[0][6], value[0][7]] | ||
| 266 | results.append([value[1] + tmp_value[1], new_position]) | ||
| 267 | if tmp_info[tmp_info_id[0]][1].isdigit(): | ||
| 268 | tmp_value = tmp_info[tmp_info_id[1]] | ||
| 269 | new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4], | ||
| 270 | tmp_value[0][5], | ||
| 271 | value[0][6], value[0][7]] | ||
| 272 | results.append([value[1] + tmp_value[1], new_position]) | ||
| 273 | break | ||
| 274 | elif len(value[1]) < 25: | ||
| 275 | tmp_info = {} | ||
| 276 | top_right_x = value[0][2] | ||
| 277 | top_right_y = value[0][3] | ||
| 278 | for tmp in ocr_results.values(): | ||
| 279 | if tmp[1] != value[1]: | ||
| 280 | if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2: | ||
| 281 | tmp_info[abs(tmp[0][0] - top_right_x)] = tmp | ||
| 282 | else: | ||
| 283 | continue | ||
| 284 | tmp_info_id = sorted(tmp_info.keys()) | ||
| 285 | tmp_value = tmp_info[tmp_info_id[0]] | ||
| 286 | new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4], | ||
| 287 | tmp_value[0][5], | ||
| 288 | value[0][6], value[0][7]] | ||
| 289 | results.append([value[1] + tmp_value[1], new_position]) | ||
| 290 | break | ||
| 291 | else: | ||
| 292 | results.append([value[1], value[0]]) | ||
| 293 | break | ||
| 294 | if is_account: | ||
| 295 | for account_prefix in account_prefixes: | ||
| 296 | if account_prefix in value[1]: | ||
| 297 | if account_prefix == value[1]: | ||
| 298 | tmp_value, max_dis = [], 999999 | ||
| 299 | top_right_x = value[0][2] | ||
| 300 | top_right_y = value[0][3] | ||
| 301 | for tmp in ocr_results.values(): | ||
| 302 | if tmp[1] != account_prefix: | ||
| 303 | if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs( | ||
| 304 | tmp[0][0] - top_right_x) < max_dis: | ||
| 305 | tmp_value = tmp | ||
| 306 | max_dis = abs(tmp[0][0] - top_right_x) | ||
| 307 | else: | ||
| 308 | continue | ||
| 309 | new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4], | ||
| 310 | tmp_value[0][5], | ||
| 311 | value[0][6], value[0][7]] | ||
| 312 | results.append([value[1] + tmp_value[1], new_position]) | ||
| 313 | break | ||
| 314 | else: | ||
| 315 | results.append([value[1], value[0]]) | ||
| 316 | break | ||
| 317 | else: | ||
| 318 | break | ||
| 319 | return results | ||
| 320 | |||
| 321 | |||
| 322 | def extract_bank_info(ocr_results): | ||
| 323 | results = [] | ||
| 324 | for value in ocr_results.values(): | ||
| 325 | if value[1].__contains__('建设'): | ||
| 326 | results = extract_jianshe_info(ocr_results) | ||
| 327 | break | ||
| 328 | elif value[1].__contains__('民生'): | ||
| 329 | results = extract_minsheng_info(ocr_results) | ||
| 330 | break | ||
| 331 | elif value[1].__contains__('农业'): | ||
| 332 | results = extract_nongye_info(ocr_results) | ||
| 333 | break | ||
| 334 | elif value[1].__contains__('中国银行'): | ||
| 335 | results = extract_zhongguo_info(ocr_results) | ||
| 336 | break | ||
| 337 | if len(results) == 0: | ||
| 338 | results = extract_gongshang_info(ocr_results) | ||
| 339 | |||
| 340 | return results | ||
| 341 | |||
| 342 | |||
| 343 | if __name__ == '__main__': | ||
| 344 | path = '/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/minsheng/authentic/images/val' | ||
| 345 | save_path='/data/situ_invoice_bill_data/new_data/results' | ||
| 346 | bank='minsheng' | ||
| 347 | if not os.path.exists(os.path.join(save_path,bank)): | ||
| 348 | os.makedirs(os.path.join(save_path,bank)) | ||
| 349 | save_path=os.path.join(save_path,bank) | ||
| 350 | for j in tqdm.tqdm(os.listdir(path)): | ||
| 351 | # if True: | ||
| 352 | img=cv2.imread(os.path.join(path,j)) | ||
| 353 | # img = cv2.imread('/data/situ_invoice_bill_data/new_data/results/nongye/6/_1597382769.6449914page_23_img_0.jpg') | ||
| 354 | st = time.time() | ||
| 355 | ocr_result = bill_ocr(img) | ||
| 356 | et1 = time.time() | ||
| 357 | result = extract_bank_info(ocr_result) | ||
| 358 | et2 = time.time() | ||
| 359 | for i in range(len(result)): | ||
| 360 | cv2.rectangle(img, (result[i][1][0], result[i][1][1]), (result[i][1][4], result[i][1][5]), (0, 0, 255), 2) | ||
| 361 | # cv2.imshow('img',img) | ||
| 362 | # cv2.waitKey(0) | ||
| 363 | cv2.imwrite(os.path.join(save_path,j),img) | ||
| 364 | print('spend:{} ocr:{} extract:{}'.format(et2 - st, et1 - st, et2 - et1)) | 
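Note on the extractors above: the five per-bank `extract_*_info` functions differ only in their key prefixes. They all locate the OCR box whose text matches a prefix and, when the value was recognised as a separate box, attach the nearest box to its right on the same line (vertical offset under half the key box height, smallest horizontal gap), then merge the two eight-point boxes. A hedged refactoring sketch of that shared rule, assuming the OCR service returns a dict mapping ids to `[eight_point_box, text]` pairs exactly as the code above indexes them; the helper names are hypothetical:

```python
# Sketch of the shared lookup rule used by the extract_*_info functions above.
# Assumption: ocr_results maps ids to [box, text], where box holds the eight
# corner coordinates (x1, y1, x2, y2, x3, y3, x4, y4). Helper names are hypothetical.
def find_value_box(ocr_results, key_box, key_text):
    """Return [box, text] of the nearest box to the right of the key box on the same line."""
    best, best_dx = None, float('inf')
    right_x, right_y = key_box[2], key_box[3]        # top-right corner of the key box
    half_h = abs(key_box[3] - key_box[5]) / 2        # half of the key box height
    for box, text in ocr_results.values():
        if text == key_text:
            continue
        if abs(box[1] - right_y) < half_h and abs(box[0] - right_x) < best_dx:
            best, best_dx = [box, text], abs(box[0] - right_x)
    return best


def extract_by_prefixes(ocr_results, prefixes):
    """Collect 'prefix + value' strings with a merged eight-point box."""
    results = []
    for box, text in ocr_results.values():
        for prefix in prefixes:
            if prefix not in text:
                continue
            if text == prefix:                       # value sits in a separate box
                found = find_value_box(ocr_results, box, text)
                if found:
                    v_box, v_text = found
                    merged = [box[0], box[1], v_box[2], v_box[3],
                              v_box[4], v_box[5], box[6], box[7]]
                    results.append([text + v_text, merged])
            else:                                    # prefix and value share one box
                results.append([text, box])
            break                                    # stop after the first matching prefix
    return results
```

With such a helper, `extract_minsheng_info(ocr_results)` would reduce to roughly `extract_by_prefixes(ocr_results, ['客户姓名:']) + extract_by_prefixes(ocr_results, ['客户账号:'])`, while the 农业 variant keeps its extra special cases for merged "户名/账号" lines.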
| ... | @@ -576,8 +576,8 @@ def run( | ... | @@ -576,8 +576,8 @@ def run( | 
| 576 | 576 | ||
| 577 | def parse_opt(): | 577 | def parse_opt(): | 
| 578 | parser = argparse.ArgumentParser() | 578 | parser = argparse.ArgumentParser() | 
| 579 | parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path') | 579 | parser.add_argument('--data', type=str, default=ROOT / 'data/VOC.yaml', help='dataset.yaml path') | 
| 580 | parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model.pt path(s)') | 580 | parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'runs/train/exp/weights/best.pt', help='model.pt path(s)') | 
| 581 | parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)') | 581 | parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)') | 
| 582 | parser.add_argument('--batch-size', type=int, default=1, help='batch size') | 582 | parser.add_argument('--batch-size', type=int, default=1, help='batch size') | 
| 583 | parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') | 583 | parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') | ... | ... | 
| ... | @@ -95,7 +95,13 @@ class Yolov5: | ... | @@ -95,7 +95,13 @@ class Yolov5: | 
| 95 | 95 | ||
| 96 | if __name__ == "__main__": | 96 | if __name__ == "__main__": | 
| 97 | img = cv2.imread( | 97 | img = cv2.imread( | 
| 98 | '/home/situ/qfs/invoice_tamper/09_project/project/yolov5_inference/data/images/crop_img/_1594890230.8032346page_10_img_0_hname.jpg') | 98 | '/home/situ/qfs/invoice_tamper/09_project/project/tamper_det/data/images/img_1.png') | 
| 99 | detector = Yolov5(config) | 99 | detector = Yolov5(config) | 
| 100 | result = detector.detect(img) | 100 | result = detector.detect(img) | 
| 101 | for i in result['result']: | ||
| 102 | position=list(i.values())[2:] | ||
| 103 | print(position) | ||
| 104 | cv2.rectangle(img,(position[0],position[1]),(position[0]+position[2],position[1]+position[3]),(0,0,255)) | ||
| 105 | cv2.imshow('w',img) | ||
| 106 | cv2.waitKey(0) | ||
| 101 | print(result) | 107 | print(result) | ... | ... | 
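The drawing loop added here recovers the box via `list(i.values())[2:]`, which depends on the insertion order of each result dict. Since the pipeline script later in this diff reads the same detections through `left`/`top`/`width`/`height` keys, explicit lookups express the same thing more robustly; a small sketch, assuming those keys are present on every detection:

```python
import cv2


def draw_detections(img, detections):
    """Draw each detection's (left, top, width, height) box on img in red."""
    for det in detections:
        left, top = int(det['left']), int(det['top'])
        width, height = int(det['width']), int(det['height'])
        cv2.rectangle(img, (left, top), (left + width, top + height), (0, 0, 255), 2)
    return img


# e.g. draw_detections(img, detector.detect(img)['result'])
```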
| 1 | from easydict import EasyDict as edict | 1 | from easydict import EasyDict as edict | 
| 2 | 2 | ||
| 3 | config = edict( | 3 | config = edict( | 
| 4 | # weights='/home/situ/qfs/invoice_tamper/09_project/project/yolov5_inference/runs/exp2/weights/best.pt', # model path or triton URL | ||
| 4 | weights='runs/train/exp/weights/best.pt', # model path or triton URL | 5 | weights='runs/train/exp/weights/best.pt', # model path or triton URL | 
| 5 | data='data/VOC.yaml', # dataset.yaml path | 6 | data='data/VOC.yaml', # dataset.yaml path | 
| 6 | imgsz=(640, 640), # inference size (height, width) | 7 | imgsz=(640, 640), # inference size (height, width) | 
| 7 | conf_thres=0.5, # confidence threshold | 8 | conf_thres=0.2, # confidence threshold |
| 8 | iou_thres=0.45, # NMS IOU threshold | 9 | iou_thres=0.45, # NMS IOU threshold | 
| 9 | max_det=1000, # maximum detections per image | 10 | max_det=1000, # maximum detections per image | 
| 10 | device='' # cuda device, i.e. 0 or 0,1,2,3 or cpu | 11 | device='' # cuda device, i.e. 0 or 0,1,2,3 or cpu | ... | ... | 
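For reference, the config is an EasyDict, so the detector reads these fields as attributes. A minimal usage sketch, assuming `Yolov5` only consumes the fields listed above; the per-run override of `conf_thres` is purely illustrative:

```python
from easydict import EasyDict as edict

from inference import Yolov5               # as imported in the pipeline script later in this diff
from models.yolov5_config import config

cfg = edict(dict(config))                  # shallow copy so the shared config keeps its 0.2 threshold
cfg.conf_thres = 0.5                       # tighter threshold for a one-off experiment
detector = Yolov5(cfg)
```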
| 1 | import time | ||
| 2 | |||
| 3 | import cv2 | ||
| 4 | |||
| 5 | from bank_ocr_inference import bill_ocr, extract_bank_info | ||
| 6 | from inference import Yolov5 | ||
| 7 | from models.yolov5_config import config | ||
| 8 | |||
| 9 | |||
| 10 | def enlarge_position(box): | ||
| 11 | x1, y1, x2, y2 = box | ||
| 12 | w, h = abs(x2 - x1), abs(y2 - y1) | ||
| 13 | y1, y2 = max(y1 - h // 3, 0), y2 + h // 3 | ||
| 14 | x1, x2 = max(x1 - w // 8, 0), x2 + w // 8 | ||
| 15 | return [x1, y1, x2, y2] | ||
| 16 | |||
| 17 | |||
| 18 | def tamper_detect(image): | ||
| 19 | st = time.time() | ||
| 20 | ocr_results = bill_ocr(image) | ||
| 21 | et1=time.time() | ||
| 22 | info_results = extract_bank_info(ocr_results) | ||
| 23 | et2=time.time() | ||
| 24 | print(info_results) | ||
| 25 | tamper_results = [] | ||
| 26 | if len(info_results) != 0: | ||
| 27 | for info_result in info_results: | ||
| 28 | box = [info_result[1][0], info_result[1][1], info_result[1][4], info_result[1][5]] | ||
| 29 | x1, y1, x2, y2 = enlarge_position(box) | ||
| 30 | # x1, y1, x2, y2 = box | ||
| 31 | info_image = image[y1:y2, x1:x2, :] | ||
| 32 | cv2.imshow('info_image',info_image) | ||
| 33 | results = detector.detect(info_image) | ||
| 34 | print(results) | ||
| 35 | if len(results['result'])!=0: | ||
| 36 | for res in results['result']: | ||
| 37 | left = int(res['left']) | ||
| 38 | top = int(res['top']) | ||
| 39 | width = int(res['width']) | ||
| 40 | height = int(res['height']) | ||
| 41 | absolute_position = [x1 + left, y1 + top, x1 + left + width, y1 + top + height] | ||
| 42 | tamper_results.append(absolute_position) | ||
| 43 | print(tamper_results) | ||
| 44 | et3 = time.time() | ||
| 45 | |||
| 46 | print(f'all:{et3-st} ocr:{et1-st} extract:{et2-et1} yolo:{et3-et2}') | ||
| 47 | for i in tamper_results: | ||
| 48 | cv2.rectangle(image, tuple(i[:2]), tuple(i[2:]), (0, 0, 255), 2) | ||
| 49 | cv2.imshow('info', image) | ||
| 50 | cv2.waitKey(0) | ||
| 51 | |||
| 52 | |||
| 53 | if __name__ == '__main__': | ||
| 54 | detector = Yolov5(config) | ||
| 55 | image = cv2.imread( | ||
| 56 | "/home/situ/下载/_1597378020.731796page_33_img_0.jpg") | ||
| 57 | tamper_detect(image) | ... | ... | 
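The new pipeline script runs OCR, extracts the name/account regions, enlarges each region (by a third of its height vertically and an eighth of its width horizontally), runs YOLO on the crop, and maps the detected boxes back to full-image coordinates by adding the crop's (x1, y1) offset. `enlarge_position` never clamps x2/y2; NumPy slicing tolerates out-of-range stops, but a bounded variant makes the intent explicit. A small sketch, assuming the image is a standard H×W×C array:

```python
def enlarge_position_clamped(box, image_shape):
    """Like enlarge_position, but keeps the enlarged box inside the image."""
    img_h, img_w = image_shape[:2]
    x1, y1, x2, y2 = box
    w, h = abs(x2 - x1), abs(y2 - y1)
    y1, y2 = max(y1 - h // 3, 0), min(y2 + h // 3, img_h)
    x1, x2 = max(x1 - w // 8, 0), min(x2 + w // 8, img_w)
    return [x1, y1, x2, y2]
```

The `cv2.imshow`/`cv2.waitKey` calls also make the script display-bound; writing the annotated image with `cv2.imwrite` instead would let it run headless over a directory, as the bank OCR script above already does.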
| ... | @@ -10,9 +10,9 @@ def get_source_image_det(crop_position, predict_positions): | ... | @@ -10,9 +10,9 @@ def get_source_image_det(crop_position, predict_positions): | 
| 10 | result = [] | 10 | result = [] | 
| 11 | x1, y1, x2, y2 = crop_position | 11 | x1, y1, x2, y2 = crop_position | 
| 12 | for p in predict_positions: | 12 | for p in predict_positions: | 
| 13 | px1, py1, px2, py2,score = p | 13 | px1, py1, px2, py2, score = p | 
| 14 | w, h = px2 - px1, py2 - py1 | 14 | w, h = px2 - px1, py2 - py1 | 
| 15 | result.append([x1 + px1, y1 + py1, x1 + px1 + w, y1 + py1 + h,score]) | 15 | result.append([x1 + px1, y1 + py1, x1 + px1 + w, y1 + py1 + h, score]) | 
| 16 | return result | 16 | return result | 
| 17 | 17 | ||
| 18 | 18 | ||
| ... | @@ -22,9 +22,9 @@ def decode_label(image, label_path): | ... | @@ -22,9 +22,9 @@ def decode_label(image, label_path): | 
| 22 | result = [] | 22 | result = [] | 
| 23 | for d in data: | 23 | for d in data: | 
| 24 | d = [float(i) for i in d.strip().split(' ')] | 24 | d = [float(i) for i in d.strip().split(' ')] | 
| 25 | cls, cx, cy, cw, ch,score = d | 25 | cls, cx, cy, cw, ch, score = d | 
| 26 | cx, cy, cw, ch = cx * w, cy * h, cw * w, ch * h | 26 | cx, cy, cw, ch = cx * w, cy * h, cw * w, ch * h | 
| 27 | result.append([int(cx - cw // 2), int(cy - ch // 2), int(cx + cw // 2), int(cy + ch // 2),score]) | 27 | result.append([int(cx - cw // 2), int(cy - ch // 2), int(cx + cw // 2), int(cy + ch // 2), score]) | 
| 28 | return result | 28 | return result | 
| 29 | 29 | ||
| 30 | 30 | ||
| ... | @@ -38,28 +38,28 @@ if __name__ == '__main__': | ... | @@ -38,28 +38,28 @@ if __name__ == '__main__': | 
| 38 | data = pd.read_csv(crop_csv_path) | 38 | data = pd.read_csv(crop_csv_path) | 
| 39 | img_name = data.loc[:, 'img_name'].tolist() | 39 | img_name = data.loc[:, 'img_name'].tolist() | 
| 40 | crop_position1 = data.loc[:, 'name_crop_coord'].tolist() | 40 | crop_position1 = data.loc[:, 'name_crop_coord'].tolist() | 
| 41 | crop_position2 = data.loc[:,'number_crop_coord'].tolist() | 41 | crop_position2 = data.loc[:, 'number_crop_coord'].tolist() | 
| 42 | cc='/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/gongshang/tampered/images/val/ps3' | 42 | cc = '/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/gongshang/tampered/images/val/ps3' | 
| 43 | for im in os.listdir(cc): | 43 | for im in os.listdir(cc): | 
| 44 | print(im) | 44 | print(im) | 
| 45 | img = cv2.imread(os.path.join(cc,im)) | 45 | img = cv2.imread(os.path.join(cc, im)) | 
| 46 | img_=img.copy() | 46 | img_ = img.copy() | 
| 47 | id = img_name.index(im) | 47 | id = img_name.index(im) | 
| 48 | name_crop_position=[int(i) for i in crop_position1[id].split(',')] | 48 | name_crop_position = [int(i) for i in crop_position1[id].split(',')] | 
| 49 | number_crop_position=[int(i) for i in crop_position2[id].split(',')] | 49 | number_crop_position = [int(i) for i in crop_position2[id].split(',')] | 
| 50 | nx1,ny1,nx2,ny2=name_crop_position | 50 | nx1, ny1, nx2, ny2 = name_crop_position | 
| 51 | nux1,nuy1,nux2,nuy2=number_crop_position | 51 | nux1, nuy1, nux2, nuy2 = number_crop_position | 
| 52 | if im[:-4]+'_hname.txt' in predict_labels: | 52 | if im[:-4] + '_hname.txt' in predict_labels: | 
| 53 | 53 | ||
| 54 | h, w, c = img[ny1:ny2, nx1:nx2, :].shape | 54 | h, w, c = img[ny1:ny2, nx1:nx2, :].shape | 
| 55 | data = open(os.path.join(predict_label_path,im[:-4]+'_hname.txt')).readlines() | 55 | data = open(os.path.join(predict_label_path, im[:-4] + '_hname.txt')).readlines() | 
| 56 | for d in data: | 56 | for d in data: | 
| 57 | cls,cx,cy,cw,ch,score = [float(i) for i in d.strip().split(' ')] | 57 | cls, cx, cy, cw, ch, score = [float(i) for i in d.strip().split(' ')] | 
| 58 | cx,cy,cw,ch=int(cx*w),int(cy*h),int(cw*w),int(ch*h) | 58 | cx, cy, cw, ch = int(cx * w), int(cy * h), int(cw * w), int(ch * h) | 
| 59 | cx1,cy1=cx-cw//2,cy-ch//2 | 59 | cx1, cy1 = cx - cw // 2, cy - ch // 2 | 
| 60 | x1,y1,x2,y2=nx1+cx1,ny1+cy1,nx1+cx1+cw,ny1+cy1+ch | 60 | x1, y1, x2, y2 = nx1 + cx1, ny1 + cy1, nx1 + cx1 + cw, ny1 + cy1 + ch | 
| 61 | cv2.rectangle(img,(x1,y1),(x2,y2),(0,0,255),2) | 61 | cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2) | 
| 62 | cv2.putText(img,f'tampered:{score}',(x1,y1-5),cv2.FONT_HERSHEY_SIMPLEX,0.5,(0,0,255),1) | 62 | cv2.putText(img, f'tampered:{score}', (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1) | 
| 63 | if im[:-4] + '_hnumber.txt' in predict_labels: | 63 | if im[:-4] + '_hnumber.txt' in predict_labels: | 
| 64 | h, w, c = img[nuy1:nuy2, nux1:nux2, :].shape | 64 | h, w, c = img[nuy1:nuy2, nux1:nux2, :].shape | 
| 65 | data = open(os.path.join(predict_label_path, im[:-4] + '_hnumber.txt')).readlines() | 65 | data = open(os.path.join(predict_label_path, im[:-4] + '_hnumber.txt')).readlines() |
| ... | @@ -70,5 +70,5 @@ if __name__ == '__main__': | ... | @@ -70,5 +70,5 @@ if __name__ == '__main__': | 
| 70 | x1, y1, x2, y2 = nux1 + cx1, nuy1 + cy1, nux1 + cx1 + cw, nuy1 + cy1 + ch | 70 | x1, y1, x2, y2 = nux1 + cx1, nuy1 + cy1, nux1 + cx1 + cw, nuy1 + cy1 + ch | 
| 71 | cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2) | 71 | cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2) | 
| 72 | cv2.putText(img, f'tampered:{score}', (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1) | 72 | cv2.putText(img, f'tampered:{score}', (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1) | 
| 73 | result = np.vstack((img_,img)) | 73 | result = np.vstack((img_, img)) | 
| 74 | cv2.imwrite(f'z/{im}',result) | 74 | cv2.imwrite(f'z/{im}', result) | ... | ... | 
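The two branches in this last file repeat the same work for the name and number crops: read a YOLO-format prediction file, convert the normalized (cx, cy, w, h) values to pixel coordinates of the crop, shift by the crop's top-left corner, and draw the box with its score. A hedged sketch of a shared helper (names hypothetical) that each branch could call with its own crop box and its matching `_hname` / `_hnumber` label file:

```python
import os

import cv2


def draw_crop_predictions(img, crop_box, label_file):
    """Draw YOLO-format predictions (cls cx cy w h score, normalized to the crop) on the full image."""
    nx1, ny1, nx2, ny2 = crop_box
    h, w = img[ny1:ny2, nx1:nx2, :].shape[:2]
    if not os.path.exists(label_file):
        return img
    for line in open(label_file).readlines():
        cls, cx, cy, cw, ch, score = [float(v) for v in line.strip().split(' ')]
        cx, cy, cw, ch = int(cx * w), int(cy * h), int(cw * w), int(ch * h)
        x1, y1 = nx1 + cx - cw // 2, ny1 + cy - ch // 2   # crop-relative -> full-image coords
        x2, y2 = x1 + cw, y1 + ch
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)
        cv2.putText(img, f'tampered:{score}', (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
    return img


# e.g. draw_crop_predictions(img, name_crop_position,
#                            os.path.join(predict_label_path, im[:-4] + '_hname.txt'))
```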