64510cdb by 乔峰昇

update readme

1 parent 5e7dd86a
1 ## 五大银行OCR+关键字段信息提取
2
3 python bank_ocr_inference.py
4 其中,函数 extract_bank_info() 是流水信息提取的总入口,其参数为 bill_ocr() 返回的全部 OCR 识别结果 ocr_results
5
6 ## yolov5推理
7 python inference.py
8
9 ## OCR+yolov5整体Pipeline
10 python pipeline.py
...\ No newline at end of file ...\ No newline at end of file
...@@ -18,15 +18,18 @@ def path_to_file(file_path): ...@@ -18,15 +18,18 @@ def path_to_file(file_path):
18 return f 18 return f
19 19
20 20
21 # 流水OCR接口
21 def bill_ocr(image): 22 def bill_ocr(image):
22 f = image_to_base64(image) 23 f = image_to_base64(image)
23 resp = requests.post(url=r'http://139.196.149.46:9001/gen_ocr', files={'file': f}) 24 resp = requests.post(url=r'http://192.168.10.11:9001/gen_ocr', files={'file': f})
24 results = resp.json() 25 results = resp.json()
25 ocr_results = results['ocr_results'] 26 ocr_results = results['ocr_results']
26 return ocr_results 27 return ocr_results
27 28
28 29
30 # 提取民生银行信息
29 def extract_minsheng_info(ocr_results): 31 def extract_minsheng_info(ocr_results):
32
30 name_prefix = '客户姓名:' 33 name_prefix = '客户姓名:'
31 account_prefix = '客户账号:' 34 account_prefix = '客户账号:'
32 results = [] 35 results = []
...@@ -71,7 +74,7 @@ def extract_minsheng_info(ocr_results): ...@@ -71,7 +74,7 @@ def extract_minsheng_info(ocr_results):
71 results.append([value[1], value[0]]) 74 results.append([value[1], value[0]])
72 return results 75 return results
73 76
74 77 # 提取工商银行信息
75 def extract_gongshang_info(ocr_results): 78 def extract_gongshang_info(ocr_results):
76 name_prefix = '户名:' 79 name_prefix = '户名:'
77 account_prefix = '卡号:' 80 account_prefix = '卡号:'
...@@ -117,7 +120,7 @@ def extract_gongshang_info(ocr_results): ...@@ -117,7 +120,7 @@ def extract_gongshang_info(ocr_results):
117 results.append([value[1], value[0]]) 120 results.append([value[1], value[0]])
118 return results 121 return results
119 122
120 123 # 提取中国银行信息
121 def extract_zhongguo_info(ocr_results): 124 def extract_zhongguo_info(ocr_results):
122 name_prefix = '客户姓名:' 125 name_prefix = '客户姓名:'
123 account_prefix = '借记卡号:' 126 account_prefix = '借记卡号:'
...@@ -163,7 +166,7 @@ def extract_zhongguo_info(ocr_results): ...@@ -163,7 +166,7 @@ def extract_zhongguo_info(ocr_results):
163 results.append([value[1], value[0]]) 166 results.append([value[1], value[0]])
164 return results 167 return results
165 168
166 169 # 提取建设银行信息
167 def extract_jianshe_info(ocr_results): 170 def extract_jianshe_info(ocr_results):
168 name_prefixes = ['客户名称:', '户名:'] 171 name_prefixes = ['客户名称:', '户名:']
169 account_prefixes = ['卡号/账号:', '卡号:'] 172 account_prefixes = ['卡号/账号:', '卡号:']
...@@ -215,7 +218,7 @@ def extract_jianshe_info(ocr_results): ...@@ -215,7 +218,7 @@ def extract_jianshe_info(ocr_results):
215 break 218 break
216 return results 219 return results
217 220
218 221 # 提取农业银行信息(比较复杂,目前训练的版式都支持)
219 def extract_nongye_info(ocr_results): 222 def extract_nongye_info(ocr_results):
220 name_prefixes = ['客户名:', '户名:'] 223 name_prefixes = ['客户名:', '户名:']
221 account_prefixes = ['账号:'] 224 account_prefixes = ['账号:']
...@@ -318,7 +321,7 @@ def extract_nongye_info(ocr_results): ...@@ -318,7 +321,7 @@ def extract_nongye_info(ocr_results):
318 break 321 break
319 return results 322 return results
320 323
321 324 # 提取银行流水信息总接口
322 def extract_bank_info(ocr_results): 325 def extract_bank_info(ocr_results):
323 results = [] 326 results = []
324 for value in ocr_results.values(): 327 for value in ocr_results.values():
...@@ -341,6 +344,7 @@ def extract_bank_info(ocr_results): ...@@ -341,6 +344,7 @@ def extract_bank_info(ocr_results):
341 344
342 345
343 if __name__ == '__main__': 346 if __name__ == '__main__':
347
344 path = '/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/minsheng/authentic/images/val' 348 path = '/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/minsheng/authentic/images/val'
345 save_path='/data/situ_invoice_bill_data/new_data/results' 349 save_path='/data/situ_invoice_bill_data/new_data/results'
346 bank='minsheng' 350 bank='minsheng'
...@@ -362,3 +366,4 @@ if __name__ == '__main__': ...@@ -362,3 +366,4 @@ if __name__ == '__main__':
362 # cv2.waitKey(0) 366 # cv2.waitKey(0)
363 cv2.imwrite(os.path.join(save_path,j),img) 367 cv2.imwrite(os.path.join(save_path,j),img)
364 print('spend:{} ocr:{} extract:{}'.format(et2 - st, et1 - st, et2 - et1)) 368 print('spend:{} ocr:{} extract:{}'.format(et2 - st, et1 - st, et2 - et1))
369 #
...\ No newline at end of file ...\ No newline at end of file
......
...@@ -54,6 +54,26 @@ def gen_result_dict(boxes, label_list=[], std=False): ...@@ -54,6 +54,26 @@ def gen_result_dict(boxes, label_list=[], std=False):
54 return result 54 return result
55 55
56 56
def keep_resize_padding(image, size=640, pad_value=114):
    """Pad an image to a square with a constant border, then resize it.

    The shorter side is padded symmetrically (the extra pixel, if the
    difference is odd, goes to the right/bottom) so that the aspect ratio of
    the original content is preserved by the following resize. This is the
    function ``Yolov5`` calls in place of the former ``letterbox(...)`` step.

    Args:
        image: Image array laid out as (height, width[, channels]). The
            channel count is taken from the array itself instead of being
            assumed to be 3.
        size: Side length, in pixels, of the square output. Defaults to 640,
            the value that was previously hard-coded.
        pad_value: Fill value for the border. Defaults to 114, the value that
            was previously hard-coded.

    Returns:
        The padded image resized to ``(size, size)`` by ``cv2.resize``.
    """
    h, w = image.shape[:2]

    # Split the height/width difference across the two sides of the short axis.
    diff = abs(h - w)
    before = diff // 2
    after = diff - before

    # Pad only the shorter spatial axis: width (axis 1) when the image is
    # taller than wide or square, height (axis 0) otherwise. All other axes,
    # including channels when present, are left untouched.
    pad_width = [(0, 0)] * image.ndim
    pad_width[1 if h >= w else 0] = (before, after)

    padded = np.pad(image, pad_width, mode='constant', constant_values=pad_value)
    return cv2.resize(padded, (size, size))
75
76
57 class Yolov5: 77 class Yolov5:
58 def __init__(self, cfg=None): 78 def __init__(self, cfg=None):
59 self.cfg = cfg 79 self.cfg = cfg
...@@ -66,7 +86,16 @@ class Yolov5: ...@@ -66,7 +86,16 @@ class Yolov5:
66 imgsz = check_img_size(self.cfg.imgsz, s=stride) # check image size 86 imgsz = check_img_size(self.cfg.imgsz, s=stride) # check image size
67 # Dataloader 87 # Dataloader
68 bs = 1 # batch_size 88 bs = 1 # batch_size
69 im = letterbox(image, imgsz, stride=stride, auto=True)[0] # padded resize 89 # im = letterbox(image, imgsz, stride=stride, auto=True)[0] # padded resize
90 # hh, ww, cc = im.shape
91 # tlen1 = (640 - hh) // 2
92 # tlen2 = 640 - hh - tlen1
93 # t1 = np.zeros((tlen1, ww, cc))
94 # t2 = np.zeros((tlen2, ww, cc))
95 # im = np.vstack((t1, im, t2))
96 im = keep_resize_padding(image)
97
98 # print(im.shape)
70 im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB 99 im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
71 im = np.ascontiguousarray(im) # contiguous 100 im = np.ascontiguousarray(im) # contiguous
72 # Run inference 101 # Run inference
...@@ -74,13 +103,15 @@ class Yolov5: ...@@ -74,13 +103,15 @@ class Yolov5:
74 im = torch.from_numpy(im).to(self.model.device) 103 im = torch.from_numpy(im).to(self.model.device)
75 im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32 104 im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32
76 im /= 255 # 0 - 255 to 0.0 - 1.0 105 im /= 255 # 0 - 255 to 0.0 - 1.0
106
77 if len(im.shape) == 3: 107 if len(im.shape) == 3:
78 im = im[None] # expand for batch dim 108 im = im[None] # expand for batch dim
79 # Inference 109 # Inference
80 pred = self.model(im, augment=False, visualize=False) 110 pred = self.model(im, augment=False, visualize=False)
111 # print(pred[0].shape)
112 # exit(0)
81 # NMS 113 # NMS
82 pred = non_max_suppression(pred, self.cfg.conf_thres, self.cfg.iou_thres, None, False, max_det=self.cfg.max_det) 114 pred = non_max_suppression(pred, self.cfg.conf_thres, self.cfg.iou_thres, None, False, max_det=self.cfg.max_det)
83
84 det = pred[0] 115 det = pred[0]
85 # if len(det): 116 # if len(det):
86 det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], image0.shape).round() 117 det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], image0.shape).round()
...@@ -95,13 +126,14 @@ class Yolov5: ...@@ -95,13 +126,14 @@ class Yolov5:
95 126
96 if __name__ == "__main__": 127 if __name__ == "__main__":
97 img = cv2.imread( 128 img = cv2.imread(
98 '/home/situ/qfs/invoice_tamper/09_project/project/tamper_det/data/images/img_1.png') 129 '/data/situ_invoice_bill_data/qfs_train_val_data/train_data/authentic/gongshang/images/val/_1594890232.0110397page_11_img_0_name_au_gongshang.jpg')
99 detector = Yolov5(config) 130 detector = Yolov5(config)
100 result = detector.detect(img) 131 result = detector.detect(img)
101 for i in result['result']: 132 for i in result['result']:
102 position=list(i.values())[2:] 133 position = list(i.values())[2:]
103 print(position) 134 print(position)
104 cv2.rectangle(img,(position[0],position[1]),(position[0]+position[2],position[1]+position[3]),(0,0,255)) 135 cv2.rectangle(img, (position[0], position[1]), (position[0] + position[2], position[1] + position[3]),
105 cv2.imshow('w',img) 136 (0, 0, 255))
137 cv2.imshow('w', img)
106 cv2.waitKey(0) 138 cv2.waitKey(0)
107 print(result) 139 print(result)
......
...@@ -2,7 +2,7 @@ from easydict import EasyDict as edict ...@@ -2,7 +2,7 @@ from easydict import EasyDict as edict
2 2
3 config = edict( 3 config = edict(
4 # weights='/home/situ/qfs/invoice_tamper/09_project/project/yolov5_inference/runs/exp2/weights/best.pt', # model path or triton URL 4 # weights='/home/situ/qfs/invoice_tamper/09_project/project/yolov5_inference/runs/exp2/weights/best.pt', # model path or triton URL
5 weights='runs/train/exp/weights/best.pt', # model path or triton URL 5 weights='runs/train/exp/weights/best.onnx', # model path or triton URL
6 data='data/VOC.yaml', # dataset.yaml path 6 data='data/VOC.yaml', # dataset.yaml path
7 imgsz=(640, 640), # inference size (height, width) 7 imgsz=(640, 640), # inference size (height, width)
8 conf_thres=0.2, # confidence threshold 8 conf_thres=0.2, # confidence threshold
......
1 import time 1 import time
2
3 import cv2 2 import cv2
4
5 from bank_ocr_inference import bill_ocr, extract_bank_info 3 from bank_ocr_inference import bill_ocr, extract_bank_info
6 from inference import Yolov5 4 from inference import Yolov5
7 from models.yolov5_config import config 5 from models.yolov5_config import config
...@@ -18,10 +16,9 @@ def enlarge_position(box): ...@@ -18,10 +16,9 @@ def enlarge_position(box):
18 def tamper_detect(image): 16 def tamper_detect(image):
19 st = time.time() 17 st = time.time()
20 ocr_results = bill_ocr(image) 18 ocr_results = bill_ocr(image)
21 et1=time.time() 19 et1 = time.time()
22 info_results = extract_bank_info(ocr_results) 20 info_results = extract_bank_info(ocr_results)
23 et2=time.time() 21 et2 = time.time()
24 print(info_results)
25 tamper_results = [] 22 tamper_results = []
26 if len(info_results) != 0: 23 if len(info_results) != 0:
27 for info_result in info_results: 24 for info_result in info_results:
...@@ -29,21 +26,21 @@ def tamper_detect(image): ...@@ -29,21 +26,21 @@ def tamper_detect(image):
29 x1, y1, x2, y2 = enlarge_position(box) 26 x1, y1, x2, y2 = enlarge_position(box)
30 # x1, y1, x2, y2 = box 27 # x1, y1, x2, y2 = box
31 info_image = image[y1:y2, x1:x2, :] 28 info_image = image[y1:y2, x1:x2, :]
32 cv2.imshow('info_image',info_image) 29 cv2.imshow('info_image', info_image)
33 results = detector.detect(info_image) 30 results = detector.detect(info_image)
34 print(results) 31 print(results)
35 if len(results['result'])!=0: 32 if len(results['result']) != 0:
36 for res in results['result']: 33 for res in results['result']:
37 left = int(res['left']) 34 left = int(res['left'])
38 top = int(res['top']) 35 top = int(res['top'])
39 width = int(res['width']) 36 width = int(res['width'])
40 height = int(res['height']) 37 height = int(res['height'])
41 absolute_position = [x1 + left, y1 + top, x1 + left + width, y1 + top + height] 38 absolute_position = [x1 + left, y1 + top, x1 + left + width, y1 + top + height]
42 tamper_results.append(absolute_position) 39 tamper_results .append(absolute_position)
43 print(tamper_results) 40 print(tamper_results)
44 et3 = time.time() 41 et3 = time.time()
45 42
46 print(f'all:{et3-st} ocr:{et1-st} extract:{et2-et1} yolo:{et3-et2}') 43 print(f'all:{et3 - st} ocr:{et1 - st} extract:{et2 - et1} yolo:{et3 - et2}')
47 for i in tamper_results: 44 for i in tamper_results:
48 cv2.rectangle(image, tuple(i[:2]), tuple(i[2:]), (0, 0, 255), 2) 45 cv2.rectangle(image, tuple(i[:2]), tuple(i[2:]), (0, 0, 255), 2)
49 cv2.imshow('info', image) 46 cv2.imshow('info', image)
...@@ -53,5 +50,5 @@ def tamper_detect(image): ...@@ -53,5 +50,5 @@ def tamper_detect(image):
53 if __name__ == '__main__': 50 if __name__ == '__main__':
54 detector = Yolov5(config) 51 detector = Yolov5(config)
55 image = cv2.imread( 52 image = cv2.imread(
56 "/home/situ/下载/_1597378020.731796page_33_img_0.jpg") 53 "/data/situ_invoice_bill_data/new_data/qfs_bank_bill_person_ps/gongshang/tampered/images/val/ps3/CH-B006369332_page_67_img_0.jpg")
57 tamper_detect(image) 54 tamper_detect(image)
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!