update readme
Showing 5 changed files with 67 additions and 23 deletions
ReadMe.md 0 → 100644 (new file)
@@ -18,15 +18,18 @@ def path_to_file(file_path):
     return f


+# Bank statement OCR interface
 def bill_ocr(image):
     f = image_to_base64(image)
-    resp = requests.post(url=r'http://139.196.149.46:9001/gen_ocr', files={'file': f})
+    resp = requests.post(url=r'http://192.168.10.11:9001/gen_ocr', files={'file': f})
     results = resp.json()
     ocr_results = results['ocr_results']
     return ocr_results


+# Extract China Minsheng Bank info
 def extract_minsheng_info(ocr_results):
+
     name_prefix = '客户姓名:'
     account_prefix = '客户账号:'
     results = []
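The only functional change in this hunk is the OCR endpoint moving from a public IP to an internal one. If the service address keeps changing, a small wrapper that reads the endpoint from an environment variable keeps the rest of the function untouched. This is a sketch only, not the committed code; the OCR_URL variable name and the 10-second timeout are invented for illustration.

import os

import requests

# Sketch: read the OCR endpoint from an environment variable (OCR_URL is an
# invented name) so endpoint changes no longer require editing the code.
OCR_URL = os.environ.get('OCR_URL', 'http://192.168.10.11:9001/gen_ocr')


def bill_ocr(image):
    f = image_to_base64(image)  # helper defined earlier in this module
    # The explicit timeout is an illustrative addition; the original call has none.
    resp = requests.post(url=OCR_URL, files={'file': f}, timeout=10)
    resp.raise_for_status()
    return resp.json()['ocr_results']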
@@ -71,7 +74,7 @@ def extract_minsheng_info(ocr_results):
             results.append([value[1], value[0]])
     return results

-
+# Extract ICBC (gongshang) info
 def extract_gongshang_info(ocr_results):
     name_prefix = '户名:'
     account_prefix = '卡号:'
@@ -117,7 +120,7 @@ def extract_gongshang_info(ocr_results):
             results.append([value[1], value[0]])
     return results

-
+# Extract Bank of China (zhongguo) info
 def extract_zhongguo_info(ocr_results):
     name_prefix = '客户姓名:'
     account_prefix = '借记卡号:'
@@ -163,7 +166,7 @@ def extract_zhongguo_info(ocr_results):
             results.append([value[1], value[0]])
     return results

-
+# Extract China Construction Bank (jianshe) info
 def extract_jianshe_info(ocr_results):
     name_prefixes = ['客户名称:', '户名:']
     account_prefixes = ['卡号/账号:', '卡号:']
@@ -215,7 +218,7 @@ def extract_jianshe_info(ocr_results):
                 break
     return results

-
+# Extract Agricultural Bank of China (nongye) info (more complex; all currently trained layouts are supported)
 def extract_nongye_info(ocr_results):
     name_prefixes = ['客户名:', '户名:']
     account_prefixes = ['账号:']
@@ -318,7 +321,7 @@ def extract_nongye_info(ocr_results):
                 break
     return results

-
+# Top-level interface for extracting bank statement info
 def extract_bank_info(ocr_results):
     results = []
     for value in ocr_results.values():
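The commit shows only the first lines of extract_bank_info, the top-level entry point the new comment describes. For orientation, a dispatcher over the per-bank extractors could look roughly like the sketch below; this is a guess at the overall shape, not the repository's actual body (which visibly starts by iterating ocr_results.values()).

# Illustrative only: pick a bank by keyword found in the OCR output, then run its extractor.
BANK_EXTRACTORS = {
    '民生': extract_minsheng_info,      # China Minsheng Bank
    '工商': extract_gongshang_info,     # ICBC
    '中国银行': extract_zhongguo_info,  # Bank of China
    '建设': extract_jianshe_info,       # China Construction Bank
    '农业': extract_nongye_info,        # Agricultural Bank of China
}


def extract_bank_info_sketch(ocr_results):
    """Hypothetical dispatcher: match a bank keyword, then delegate to its extractor."""
    all_text = ''.join(str(v) for v in ocr_results.values())
    for keyword, extractor in BANK_EXTRACTORS.items():
        if keyword in all_text:
            return extractor(ocr_results)
    return []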
@@ -341,6 +344,7 @@ def extract_bank_info(ocr_results):


 if __name__ == '__main__':
+
     path = '/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/minsheng/authentic/images/val'
     save_path='/data/situ_invoice_bill_data/new_data/results'
     bank='minsheng'
@@ -362,3 +366,4 @@ if __name__ == '__main__':
         # cv2.waitKey(0)
         cv2.imwrite(os.path.join(save_path,j),img)
         print('spend:{} ocr:{} extract:{}'.format(et2 - st, et1 - st, et2 - et1))
+#
\ No newline at end of file
@@ -54,6 +54,26 @@ def gen_result_dict(boxes, label_list=[], std=False):
     return result


+def keep_resize_padding(image):
+    h, w, c = image.shape
+    if h >= w:
+        pad1 = (h - w) // 2
+        pad2 = h - w - pad1
+        p1 = np.ones((h, pad1, 3)) * 114.0
+        p2 = np.ones((h, pad2, 3)) * 114.0
+        p1, p2 = p1.astype(np.uint8), p2.astype(np.uint8)
+        new_image = np.hstack((p1, image, p2))
+    else:
+        pad1 = (w - h) // 2
+        pad2 = w - h - pad1
+        p1 = np.ones((pad1, w, 3)) * 114.0
+        p2 = np.ones((pad2, w, 3)) * 114.0
+        p1, p2 = p1.astype(np.uint8), p2.astype(np.uint8)
+        new_image = np.vstack((p1, image, p2))
+    new_image = cv2.resize(new_image, (640, 640))
+    return new_image
+
+
 class Yolov5:
     def __init__(self, cfg=None):
         self.cfg = cfg
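keep_resize_padding pads the shorter side with the grey value 114 (the same fill YOLOv5's letterbox uses) until the image is square, then resizes to a fixed 640×640, so every input reaches the network with the same shape. A quick sanity check of that behaviour; a sketch only, assuming keep_resize_padding from the hunk above is in scope (it appears to live in inference.py next to the Yolov5 class).

import cv2
import numpy as np

img = np.random.randint(0, 255, (480, 300, 3), dtype=np.uint8)  # tall, non-square image
padded = keep_resize_padding(img)
assert padded.shape == (640, 640, 3)  # always a fixed square input
assert padded.dtype == np.uint8       # the 114.0 padding is cast back to uint8 before stacking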
@@ -66,7 +86,16 @@ class Yolov5:
         imgsz = check_img_size(self.cfg.imgsz, s=stride)  # check image size
         # Dataloader
         bs = 1  # batch_size
-        im = letterbox(image, imgsz, stride=stride, auto=True)[0]  # padded resize
+        # im = letterbox(image, imgsz, stride=stride, auto=True)[0]  # padded resize
+        # hh, ww, cc = im.shape
+        # tlen1 = (640 - hh) // 2
+        # tlen2 = 640 - hh - tlen1
+        # t1 = np.zeros((tlen1, ww, cc))
+        # t2 = np.zeros((tlen2, ww, cc))
+        # im = np.vstack((t1, im, t2))
+        im = keep_resize_padding(image)
+
+        # print(im.shape)
         im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
         im = np.ascontiguousarray(im)  # contiguous
         # Run inference
@@ -74,13 +103,15 @@ class Yolov5:
         im = torch.from_numpy(im).to(self.model.device)
         im = im.half() if self.model.fp16 else im.float()  # uint8 to fp16/32
         im /= 255  # 0 - 255 to 0.0 - 1.0
+
         if len(im.shape) == 3:
             im = im[None]  # expand for batch dim
         # Inference
         pred = self.model(im, augment=False, visualize=False)
+        # print(pred[0].shape)
+        # exit(0)
         # NMS
         pred = non_max_suppression(pred, self.cfg.conf_thres, self.cfg.iou_thres, None, False, max_det=self.cfg.max_det)
-
         det = pred[0]
         # if len(det):
         det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], image0.shape).round()
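scale_boxes then maps the detections from the 640×640 network input back to the original image. Because keep_resize_padding pads symmetrically to a square and then scales uniformly, the geometry matches what scale_boxes expects (gain = 640 / max(h, w), symmetric padding). A hand-rolled inverse for intuition only; this is a hypothetical helper, and the real code keeps using scale_boxes.

def unscale_box_sketch(box_640, orig_shape):
    """Hypothetical mirror of what scale_boxes undoes for this preprocessing:
    remove the symmetric square padding, then the uniform 640/side resize."""
    h0, w0 = orig_shape[:2]
    side = max(h0, w0)
    gain = 640.0 / side            # uniform scale applied after padding to a square
    pad_x = (640 - w0 * gain) / 2  # horizontal padding measured in 640-space pixels
    pad_y = (640 - h0 * gain) / 2
    x1, y1, x2, y2 = box_640
    return ((x1 - pad_x) / gain, (y1 - pad_y) / gain,
            (x2 - pad_x) / gain, (y2 - pad_y) / gain)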
@@ -95,13 +126,14 @@ class Yolov5:

 if __name__ == "__main__":
     img = cv2.imread(
-        '/home/situ/qfs/invoice_tamper/09_project/project/tamper_det/data/images/img_1.png')
+        '/data/situ_invoice_bill_data/qfs_train_val_data/train_data/authentic/gongshang/images/val/_1594890232.0110397page_11_img_0_name_au_gongshang.jpg')
     detector = Yolov5(config)
     result = detector.detect(img)
     for i in result['result']:
-        position=list(i.values())[2:]
+        position = list(i.values())[2:]
         print(position)
-        cv2.rectangle(img,(position[0],position[1]),(position[0]+position[2],position[1]+position[3]),(0,0,255))
-        cv2.imshow('w',img)
+        cv2.rectangle(img, (position[0], position[1]), (position[0] + position[2], position[1] + position[3]),
+                      (0, 0, 255))
+        cv2.imshow('w', img)
     cv2.waitKey(0)
     print(result)
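The __main__ demo above unpacks each detection positionally with list(i.values())[2:], which depends on dict key order. The tamper-detection script later in this commit reads the same fields by name, so an equivalent, more explicit drawing loop could look like this sketch; it assumes the 'left', 'top', 'width', 'height' keys that tamper_detect() uses.

for det in result['result']:
    # Explicit keys instead of positional values(); the key names are taken from
    # how tamper_detect() consumes detector output elsewhere in this commit.
    left, top = int(det['left']), int(det['top'])
    width, height = int(det['width']), int(det['height'])
    cv2.rectangle(img, (left, top), (left + width, top + height), (0, 0, 255), 2)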
@@ -2,7 +2,7 @@ from easydict import EasyDict as edict

 config = edict(
     # weights='/home/situ/qfs/invoice_tamper/09_project/project/yolov5_inference/runs/exp2/weights/best.pt',  # model path or triton URL
-    weights='runs/train/exp/weights/best.pt',  # model path or triton URL
+    weights='runs/train/exp/weights/best.onnx',  # model path or triton URL
     data='data/VOC.yaml',  # dataset.yaml path
     imgsz=(640, 640),  # inference size (height, width)
     conf_thres=0.2,  # confidence threshold
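The weights path now points at best.onnx rather than best.pt; YOLOv5's DetectMultiBackend picks the backend from the file suffix, so the surrounding inference code stays the same, and the fixed 640×640 input produced by keep_resize_padding suits a fixed-shape ONNX export (typically created with YOLOv5's export.py and --include onnx). A small check of the exported input shape, as a sketch; it assumes onnxruntime is installed and is what ultimately runs the .onnx weights.

import onnxruntime as ort  # assumption: onnxruntime backs the .onnx weights at inference time

sess = ort.InferenceSession('runs/train/exp/weights/best.onnx',
                            providers=['CPUExecutionProvider'])
inp = sess.get_inputs()[0]
print(inp.name, inp.shape)  # e.g. 'images' [1, 3, 640, 640] for a fixed-shape export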
@@ -1,7 +1,5 @@
 import time
-
 import cv2
-
 from bank_ocr_inference import bill_ocr, extract_bank_info
 from inference import Yolov5
 from models.yolov5_config import config
@@ -18,10 +16,9 @@ def enlarge_position(box):
 def tamper_detect(image):
     st = time.time()
     ocr_results = bill_ocr(image)
-    et1=time.time()
+    et1 = time.time()
     info_results = extract_bank_info(ocr_results)
-    et2=time.time()
-    print(info_results)
+    et2 = time.time()
     tamper_results = []
     if len(info_results) != 0:
         for info_result in info_results:
@@ -29,21 +26,21 @@ def tamper_detect(image):
             x1, y1, x2, y2 = enlarge_position(box)
             # x1, y1, x2, y2 = box
             info_image = image[y1:y2, x1:x2, :]
-            cv2.imshow('info_image',info_image)
+            cv2.imshow('info_image', info_image)
             results = detector.detect(info_image)
             print(results)
-            if len(results['result'])!=0:
+            if len(results['result']) != 0:
                 for res in results['result']:
                     left = int(res['left'])
                     top = int(res['top'])
                     width = int(res['width'])
                     height = int(res['height'])
                     absolute_position = [x1 + left, y1 + top, x1 + left + width, y1 + top + height]
                     tamper_results.append(absolute_position)
     print(tamper_results)
     et3 = time.time()

-    print(f'all:{et3-st} ocr:{et1-st} extract:{et2-et1} yolo:{et3-et2}')
+    print(f'all:{et3 - st} ocr:{et1 - st} extract:{et2 - et1} yolo:{et3 - et2}')
     for i in tamper_results:
         cv2.rectangle(image, tuple(i[:2]), tuple(i[2:]), (0, 0, 255), 2)
     cv2.imshow('info', image)
@@ -53,5 +50,5 @@ def tamper_detect(image):
 if __name__ == '__main__':
     detector = Yolov5(config)
     image = cv2.imread(
-        "/home/situ/下载/_1597378020.731796page_33_img_0.jpg")
+        "/data/situ_invoice_bill_data/new_data/qfs_bank_bill_person_ps/gongshang/tampered/images/val/ps3/CH-B006369332_page_67_img_0.jpg")
     tamper_detect(image)
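tamper_detect crops an enlarged region around each extracted name/account box, runs the YOLOv5 detector on the crop, and shifts hits back to full-image coordinates via absolute_position = [x1 + left, y1 + top, ...]. enlarge_position itself is only called in this diff, never shown; a plausible stand-in is sketched below. It is purely hypothetical: the 10 px margin and the extra image_shape argument are invented (the real function appears to take only the box).

def enlarge_position_sketch(box, image_shape, margin=10):
    """Hypothetical version of enlarge_position(): grow the OCR box by a margin so the
    tamper detector sees some surrounding context, clamped to the image borders."""
    x1, y1, x2, y2 = box
    h, w = image_shape[:2]
    return max(0, x1 - margin), max(0, y1 - margin), min(w, x2 + margin), min(h, y2 + margin)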