update readme

Showing 5 changed files with 67 additions and 23 deletions.

ReadMe.md (new file, 0 → 100644)
bank_ocr_inference.py

@@ -18,15 +18,18 @@ def path_to_file(file_path):
     return f
 
 
+# Bank-statement OCR interface
 def bill_ocr(image):
     f = image_to_base64(image)
-    resp = requests.post(url=r'http://139.196.149.46:9001/gen_ocr', files={'file': f})
+    resp = requests.post(url=r'http://192.168.10.11:9001/gen_ocr', files={'file': f})
     results = resp.json()
     ocr_results = results['ocr_results']
     return ocr_results
 
 
+# Extract China Minsheng Bank info
 def extract_minsheng_info(ocr_results):
+
     name_prefix = '客户姓名:'
     account_prefix = '客户账号:'
     results = []
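
bill_ocr is the pipeline's only network call: it base64-encodes the image, posts it to the OCR service, and reads the 'ocr_results' field of the JSON reply. Below is a minimal sketch of the same call with a timeout and an HTTP status check added; the JPEG/base64 encoding step and the 30-second timeout are illustrative assumptions, only the endpoint and the response key come from the code above.

import base64

import cv2
import requests


def bill_ocr_checked(image, url='http://192.168.10.11:9001/gen_ocr', timeout=30):
    # JPEG-encode then base64-encode the OpenCV image (assumed to be roughly what
    # the module's image_to_base64 helper does).
    ok, buf = cv2.imencode('.jpg', image)
    if not ok:
        raise ValueError('could not encode image')
    payload = base64.b64encode(buf.tobytes())
    resp = requests.post(url, files={'file': payload}, timeout=timeout)
    resp.raise_for_status()            # surface HTTP errors instead of a KeyError later
    return resp.json()['ocr_results']  # same key bill_ocr reads
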
@@ -71,7 +74,7 @@ def extract_minsheng_info(ocr_results):
             results.append([value[1], value[0]])
     return results
 
-
+# Extract ICBC info
 def extract_gongshang_info(ocr_results):
     name_prefix = '户名:'
     account_prefix = '卡号:'
@@ -117,7 +120,7 @@ def extract_gongshang_info(ocr_results):
             results.append([value[1], value[0]])
     return results
 
-
+# Extract Bank of China info
 def extract_zhongguo_info(ocr_results):
     name_prefix = '客户姓名:'
     account_prefix = '借记卡号:'
@@ -163,7 +166,7 @@ def extract_zhongguo_info(ocr_results):
             results.append([value[1], value[0]])
     return results
 
-
+# Extract China Construction Bank info
 def extract_jianshe_info(ocr_results):
     name_prefixes = ['客户名称:', '户名:']
     account_prefixes = ['卡号/账号:', '卡号:']
@@ -215,7 +218,7 @@ def extract_jianshe_info(ocr_results):
                 break
     return results
 
-
+# Extract Agricultural Bank of China info (fairly complex; all layouts trained so far are supported)
 def extract_nongye_info(ocr_results):
     name_prefixes = ['客户名:', '户名:']
     account_prefixes = ['账号:']
@@ -318,7 +321,7 @@ def extract_nongye_info(ocr_results):
                 break
     return results
 
-
+# Top-level interface for extracting bank-statement info
 def extract_bank_info(ocr_results):
     results = []
     for value in ocr_results.values():
@@ -341,6 +344,7 @@ def extract_bank_info(ocr_results):
 
 
 if __name__ == '__main__':
+
     path = '/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/minsheng/authentic/images/val'
     save_path='/data/situ_invoice_bill_data/new_data/results'
     bank='minsheng'
@@ -362,3 +366,4 @@ if __name__ == '__main__':
         # cv2.waitKey(0)
         cv2.imwrite(os.path.join(save_path,j),img)
         print('spend:{} ocr:{} extract:{}'.format(et2 - st, et1 - st, et2 - et1))
+#
\ No newline at end of file
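
The five extract_*_info helpers differ mainly in which label prefixes they search for ('客户姓名:', '户名:', '卡号:', '账号:' and so on) and in how many layout variants they tolerate; their bodies and the exact OCR result structure are mostly elided in this view. The sketch below is therefore only a hypothetical illustration of the shared prefix-matching idea, not the repository's implementation (the real extractors also return the matched text's box, which the tamper-detection script later crops around).

def extract_fields_by_prefix(text_lines, name_prefix, account_prefix):
    # text_lines: plain recognized strings (hypothetical input shape).
    fields = {}
    for text in text_lines:
        if text.startswith(name_prefix):
            fields['name'] = text[len(name_prefix):].strip()
        elif text.startswith(account_prefix):
            fields['account'] = text[len(account_prefix):].strip()
    return fields


# extract_fields_by_prefix(['客户姓名:张三', '客户账号:6226000000000000'],
#                          '客户姓名:', '客户账号:')
# -> {'name': '张三', 'account': '6226000000000000'}
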
inference.py

@@ -54,6 +54,26 @@ def gen_result_dict(boxes, label_list=[], std=False):
     return result
 
 
+def keep_resize_padding(image):
+    h, w, c = image.shape
+    if h >= w:
+        pad1 = (h - w) // 2
+        pad2 = h - w - pad1
+        p1 = np.ones((h, pad1, 3)) * 114.0
+        p2 = np.ones((h, pad2, 3)) * 114.0
+        p1, p2 = p1.astype(np.uint8), p2.astype(np.uint8)
+        new_image = np.hstack((p1, image, p2))
+    else:
+        pad1 = (w - h) // 2
+        pad2 = w - h - pad1
+        p1 = np.ones((pad1, w, 3)) * 114.0
+        p2 = np.ones((pad2, w, 3)) * 114.0
+        p1, p2 = p1.astype(np.uint8), p2.astype(np.uint8)
+        new_image = np.vstack((p1, image, p2))
+    new_image = cv2.resize(new_image, (640, 640))
+    return new_image
+
+
 class Yolov5:
     def __init__(self, cfg=None):
         self.cfg = cfg
@@ -66,7 +86,16 @@ class Yolov5:
         imgsz = check_img_size(self.cfg.imgsz, s=stride) # check image size
         # Dataloader
         bs = 1 # batch_size
-        im = letterbox(image, imgsz, stride=stride, auto=True)[0] # padded resize
+        # im = letterbox(image, imgsz, stride=stride, auto=True)[0] # padded resize
+        # hh, ww, cc = im.shape
+        # tlen1 = (640 - hh) // 2
+        # tlen2 = 640 - hh - tlen1
+        # t1 = np.zeros((tlen1, ww, cc))
+        # t2 = np.zeros((tlen2, ww, cc))
+        # im = np.vstack((t1, im, t2))
+        im = keep_resize_padding(image)
+
+        # print(im.shape)
         im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
         im = np.ascontiguousarray(im) # contiguous
         # Run inference
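
keep_resize_padding replaces the letterbox call: it pads the short side with the usual 114-gray fill until the image is square, then resizes to a fixed 640x640, so the network input shape no longer depends on the source aspect ratio. A fixed shape is what a plain (non-dynamic) ONNX export expects, and the config change further down switches the weights to best.onnx, which is presumably the motivation here. A quick shape check, assuming NumPy, OpenCV and the keep_resize_padding defined above are in scope:

import numpy as np

wide = (np.random.rand(480, 1280, 3) * 255).astype(np.uint8)  # wider than tall
tall = (np.random.rand(900, 300, 3) * 255).astype(np.uint8)   # taller than wide

print(keep_resize_padding(wide).shape)  # (640, 640, 3)
print(keep_resize_padding(tall).shape)  # (640, 640, 3): output shape is constant
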
@@ -74,13 +103,15 @@ class Yolov5:
         im = torch.from_numpy(im).to(self.model.device)
         im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32
         im /= 255 # 0 - 255 to 0.0 - 1.0
+
         if len(im.shape) == 3:
             im = im[None] # expand for batch dim
         # Inference
         pred = self.model(im, augment=False, visualize=False)
+        # print(pred[0].shape)
+        # exit(0)
         # NMS
         pred = non_max_suppression(pred, self.cfg.conf_thres, self.cfg.iou_thres, None, False, max_det=self.cfg.max_det)
-
         det = pred[0]
         # if len(det):
         det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], image0.shape).round()
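
After non_max_suppression, each row of det follows the standard YOLOv5 layout [x1, y1, x2, y2, confidence, class] in the 640x640 network frame, which is why scale_boxes is applied to the first four columns to map them back onto image0. A typical way to unpack the kept rows (generic YOLOv5 idiom, not code from this repository):

for *xyxy, conf, cls in det.tolist():
    x1, y1, x2, y2 = map(int, xyxy)
    print(int(cls), round(conf, 3), (x1, y1, x2, y2))
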
@@ -95,13 +126,14 @@ class Yolov5:
 
 if __name__ == "__main__":
     img = cv2.imread(
-        '/home/situ/qfs/invoice_tamper/09_project/project/tamper_det/data/images/img_1.png')
+        '/data/situ_invoice_bill_data/qfs_train_val_data/train_data/authentic/gongshang/images/val/_1594890232.0110397page_11_img_0_name_au_gongshang.jpg')
     detector = Yolov5(config)
     result = detector.detect(img)
     for i in result['result']:
-        position=list(i.values())[2:]
+        position = list(i.values())[2:]
         print(position)
-        cv2.rectangle(img,(position[0],position[1]),(position[0]+position[2],position[1]+position[3]),(0,0,255))
-        cv2.imshow('w',img)
+        cv2.rectangle(img, (position[0], position[1]), (position[0] + position[2], position[1] + position[3]),
+                      (0, 0, 255))
+        cv2.imshow('w', img)
     cv2.waitKey(0)
     print(result)
models/yolov5_config.py

@@ -2,7 +2,7 @@ from easydict import EasyDict as edict
 
 config = edict(
     # weights='/home/situ/qfs/invoice_tamper/09_project/project/yolov5_inference/runs/exp2/weights/best.pt', # model path or triton URL
-    weights='runs/train/exp/weights/best.pt', # model path or triton URL
+    weights='runs/train/exp/weights/best.onnx', # model path or triton URL
     data='data/VOC.yaml', # dataset.yaml path
     imgsz=(640, 640), # inference size (height, width)
     conf_thres=0.2, # confidence threshold
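
The weights entry now points at an ONNX export instead of the PyTorch checkpoint; the rest of the config is untouched, since the model is presumably loaded through YOLOv5's multi-backend loader (DetectMultiBackend), which picks the backend from the file suffix. To confirm the export really has the fixed 640x640 input that keep_resize_padding produces, a small onnxruntime check is enough; this assumes onnxruntime is installed and that the path matches your run directory:

import onnxruntime as ort

# Path mirrors the 'weights' entry above; adjust it to the actual run directory.
sess = ort.InferenceSession('runs/train/exp/weights/best.onnx',
                            providers=['CPUExecutionProvider'])
for inp in sess.get_inputs():
    print(inp.name, inp.shape)  # a non-dynamic YOLOv5 export reports e.g. images [1, 3, 640, 640]
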
Tamper-detection entry script

@@ -1,7 +1,5 @@
 import time
-
 import cv2
-
 from bank_ocr_inference import bill_ocr, extract_bank_info
 from inference import Yolov5
 from models.yolov5_config import config
@@ -18,10 +16,9 @@ def enlarge_position(box):
 def tamper_detect(image):
     st = time.time()
     ocr_results = bill_ocr(image)
-    et1=time.time()
+    et1 = time.time()
     info_results = extract_bank_info(ocr_results)
-    et2=time.time()
-    print(info_results)
+    et2 = time.time()
     tamper_results = []
     if len(info_results) != 0:
         for info_result in info_results:
@@ -29,21 +26,21 @@ def tamper_detect(image):
             x1, y1, x2, y2 = enlarge_position(box)
             # x1, y1, x2, y2 = box
             info_image = image[y1:y2, x1:x2, :]
-            cv2.imshow('info_image',info_image)
+            cv2.imshow('info_image', info_image)
             results = detector.detect(info_image)
            print(results)
-            if len(results['result'])!=0:
+            if len(results['result']) != 0:
                 for res in results['result']:
                     left = int(res['left'])
                     top = int(res['top'])
                     width = int(res['width'])
                     height = int(res['height'])
                     absolute_position = [x1 + left, y1 + top, x1 + left + width, y1 + top + height]
-                    tamper_results.append(absolute_position)
+                    tamper_results .append(absolute_position)
     print(tamper_results)
     et3 = time.time()
 
-    print(f'all:{et3-st} ocr:{et1-st} extract:{et2-et1} yolo:{et3-et2}')
+    print(f'all:{et3 - st} ocr:{et1 - st} extract:{et2 - et1} yolo:{et3 - et2}')
     for i in tamper_results:
         cv2.rectangle(image, tuple(i[:2]), tuple(i[2:]), (0, 0, 255), 2)
     cv2.imshow('info', image)
@@ -53,5 +50,5 @@ def tamper_detect(image):
 if __name__ == '__main__':
     detector = Yolov5(config)
     image = cv2.imread(
-        "/home/situ/下载/_1597378020.731796page_33_img_0.jpg")
+        "/data/situ_invoice_bill_data/new_data/qfs_bank_bill_person_ps/gongshang/tampered/images/val/ps3/CH-B006369332_page_67_img_0.jpg")
     tamper_detect(image)
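
tamper_detect runs the detector on a crop (info_image) around each field that extract_bank_info located, so the YOLO boxes come back in crop coordinates; the absolute_position line above shifts them back into full-page coordinates by adding the crop origin (x1, y1). A small worked example with made-up numbers:

# Say the enlarged field crop starts at (x1, y1) = (120, 430) on the full page,
# and the detector reports a box at left=15, top=8, width=60, height=22 inside that crop.
x1, y1 = 120, 430
left, top, width, height = 15, 8, 60, 22

absolute_position = [x1 + left, y1 + top, x1 + left + width, y1 + top + height]
print(absolute_position)  # [135, 438, 195, 460], i.e. the box corners in page pixels
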