add pipeline inference
Showing
6 changed files
with
90 additions
and
26 deletions
bank_ocr_inference.py
0 → 100644
This diff is collapsed.
Click to expand it.
... | @@ -576,8 +576,8 @@ def run( | ... | @@ -576,8 +576,8 @@ def run( |
576 | 576 | ||
577 | def parse_opt(): | 577 | def parse_opt(): |
578 | parser = argparse.ArgumentParser() | 578 | parser = argparse.ArgumentParser() |
579 | parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path') | 579 | parser.add_argument('--data', type=str, default=ROOT / 'data/VOC.yaml', help='dataset.yaml path') |
580 | parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model.pt path(s)') | 580 | parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'runs/train/exp/weights/best.pt', help='model.pt path(s)') |
581 | parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)') | 581 | parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)') |
582 | parser.add_argument('--batch-size', type=int, default=1, help='batch size') | 582 | parser.add_argument('--batch-size', type=int, default=1, help='batch size') |
583 | parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') | 583 | parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') | ... | ... |
... | @@ -95,7 +95,13 @@ class Yolov5: | ... | @@ -95,7 +95,13 @@ class Yolov5: |
95 | 95 | ||
if __name__ == "__main__":
    # Manual smoke test: run the detector on one sample image and draw the
    # detections on it.
    img_path = '/home/situ/qfs/invoice_tamper/09_project/project/tamper_det/data/images/img_1.png'
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals an unreadable/missing file by returning None.
        raise SystemExit(f'could not read image: {img_path}')
    detector = Yolov5(config)
    result = detector.detect(img)
    for det in result['result']:
        # Read the box by key rather than by dict position, and cast to int
        # because cv2.rectangle requires integer pixel coordinates.
        left, top = int(det['left']), int(det['top'])
        width, height = int(det['width']), int(det['height'])
        position = [left, top, width, height]
        print(position)
        cv2.rectangle(img, (left, top), (left + width, top + height), (0, 0, 255))
    cv2.imshow('w', img)
    cv2.waitKey(0)
    print(result)
1 | from easydict import EasyDict as edict | 1 | from easydict import EasyDict as edict |
2 | 2 | ||
3 | config = edict( | 3 | config = edict( |
4 | # weights='/home/situ/qfs/invoice_tamper/09_project/project/yolov5_inference/runs/exp2/weights/best.pt', # model path or triton URL | ||
4 | weights='runs/train/exp/weights/best.pt', # model path or triton URL | 5 | weights='runs/train/exp/weights/best.pt', # model path or triton URL |
5 | data='data/VOC.yaml', # dataset.yaml path | 6 | data='data/VOC.yaml', # dataset.yaml path |
6 | imgsz=(640, 640), # inference size (height, width) | 7 | imgsz=(640, 640), # inference size (height, width) |
7 | conf_thres=0.5, # confidence threshold | 8 | conf_thres=0.2, # confidence threshold |
8 | iou_thres=0.45, # NMS IOU threshold | 9 | iou_thres=0.45, # NMS IOU threshold |
9 | max_det=1000, # maximum detections per image | 10 | max_det=1000, # maximum detections per image |
10 | device='' # cuda device, i.e. 0 or 0,1,2,3 or cpu | 11 | device='' # cuda device, i.e. 0 or 0,1,2,3 or cpu | ... | ... |
1 | import time | ||
2 | |||
3 | import cv2 | ||
4 | |||
5 | from bank_ocr_inference import bill_ocr, extract_bank_info | ||
6 | from inference import Yolov5 | ||
7 | from models.yolov5_config import config | ||
8 | |||
9 | |||
def enlarge_position(box, image_shape=None):
    """Pad a box so the crop keeps some context around the text.

    The box grows by one third of its height above and below, and by one
    eighth of its width on the left and right. The top/left edges never go
    below 0.

    Args:
        box: ``(x1, y1, x2, y2)`` corner coordinates.
        image_shape: Optional ``(height, width, ...)`` of the source image,
            e.g. ``image.shape``. When given, the bottom/right edges are
            clamped to the image; when omitted they are left unclamped,
            which is the original behaviour.

    Returns:
        ``[x1, y1, x2, y2]`` of the enlarged box.
    """
    x1, y1, x2, y2 = box
    pad_x = abs(x2 - x1) // 8
    pad_y = abs(y2 - y1) // 3
    x1, x2 = max(x1 - pad_x, 0), x2 + pad_x
    y1, y2 = max(y1 - pad_y, 0), y2 + pad_y
    if image_shape is not None:
        max_y, max_x = image_shape[:2]
        x2, y2 = min(x2, max_x), min(y2, max_y)
    return [x1, y1, x2, y2]
16 | |||
17 | |||
def tamper_detect(image, detector=None, *, show=True):
    """Run OCR, locate the bank-info fields and detect tampering inside them.

    Steps: ``bill_ocr`` on the full image, ``extract_bank_info`` on the OCR
    output, then for each extracted field an enlarged crop is passed to
    ``detector.detect`` and every detection is mapped back to full-image
    coordinates.

    Args:
        image: Image array indexed as ``image[y, x, channel]``.
        detector: Object exposing ``detect(crop)`` that returns a dict whose
            ``'result'`` entry is a list of dicts with ``left``, ``top``,
            ``width`` and ``height``. When omitted, a module-level
            ``detector`` is used if one exists (as set by the script entry
            point); otherwise ``Yolov5(config)`` is created.
        show: When true (the default, matching the original behaviour) the
            crops and the annotated image are displayed with OpenCV and the
            call blocks on ``cv2.waitKey(0)``. Boxes are drawn onto ``image``
            in place. Pass ``False`` for headless/batch use; ``image`` is then
            left untouched.

    Returns:
        List of ``[x1, y1, x2, y2]`` boxes in full-image coordinates.
    """
    if detector is None:
        # Backward compatibility: the script entry point publishes the model
        # as a module-level name instead of passing it in.
        detector = globals().get('detector')
        if detector is None:
            detector = Yolov5(config)

    # perf_counter is monotonic, so the stage timings cannot go negative if
    # the system clock is adjusted mid-run.
    st = time.perf_counter()
    ocr_results = bill_ocr(image)
    et1 = time.perf_counter()
    info_results = extract_bank_info(ocr_results)
    et2 = time.perf_counter()
    print(info_results)

    tamper_results = []
    for info_result in info_results:
        # NOTE(review): info_result[1] is presumably a flattened quadrilateral
        # where indexes 0,1 and 4,5 are opposite corners - confirm against
        # extract_bank_info.
        coords = info_result[1]
        box = [coords[0], coords[1], coords[4], coords[5]]
        # Slice indices must be integers.
        x1, y1, x2, y2 = (int(v) for v in enlarge_position(box))
        info_image = image[y1:y2, x1:x2, :]
        if info_image.size == 0:
            # Degenerate or out-of-image box: nothing to run the detector on.
            continue
        if show:
            cv2.imshow('info_image', info_image)
        results = detector.detect(info_image)
        print(results)
        for res in results['result']:
            left = int(res['left'])
            top = int(res['top'])
            width = int(res['width'])
            height = int(res['height'])
            # Detections are relative to the crop; shift by the crop origin.
            tamper_results.append(
                [x1 + left, y1 + top, x1 + left + width, y1 + top + height])
    print(tamper_results)
    et3 = time.perf_counter()

    print(f'all:{et3-st} ocr:{et1-st} extract:{et2-et1} yolo:{et3-et2}')
    if show:
        for x_min, y_min, x_max, y_max in tamper_results:
            cv2.rectangle(image, (x_min, y_min), (x_max, y_max), (0, 0, 255), 2)
        cv2.imshow('info', image)
        cv2.waitKey(0)
    return tamper_results
52 | |||
if __name__ == '__main__':
    import sys

    # Optional CLI override; the original sample path remains the default.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "/home/situ/下载/_1597378020.731796page_33_img_0.jpg"
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread returns None instead of raising when the file cannot be
        # read; fail here, before the model is loaded.
        raise SystemExit(f'could not read image: {image_path}')
    # Kept as a module-level name because tamper_detect() looks up the
    # global `detector`.
    detector = Yolov5(config)
    tamper_detect(image)
... | @@ -10,9 +10,9 @@ def get_source_image_det(crop_position, predict_positions): | ... | @@ -10,9 +10,9 @@ def get_source_image_det(crop_position, predict_positions): |
10 | result = [] | 10 | result = [] |
11 | x1, y1, x2, y2 = crop_position | 11 | x1, y1, x2, y2 = crop_position |
12 | for p in predict_positions: | 12 | for p in predict_positions: |
13 | px1, py1, px2, py2,score = p | 13 | px1, py1, px2, py2, score = p |
14 | w, h = px2 - px1, py2 - py1 | 14 | w, h = px2 - px1, py2 - py1 |
15 | result.append([x1 + px1, y1 + py1, x1 + px1 + w, y1 + py1 + h,score]) | 15 | result.append([x1 + px1, y1 + py1, x1 + px1 + w, y1 + py1 + h, score]) |
16 | return result | 16 | return result |
17 | 17 | ||
18 | 18 | ||
... | @@ -22,9 +22,9 @@ def decode_label(image, label_path): | ... | @@ -22,9 +22,9 @@ def decode_label(image, label_path): |
22 | result = [] | 22 | result = [] |
23 | for d in data: | 23 | for d in data: |
24 | d = [float(i) for i in d.strip().split(' ')] | 24 | d = [float(i) for i in d.strip().split(' ')] |
25 | cls, cx, cy, cw, ch,score = d | 25 | cls, cx, cy, cw, ch, score = d |
26 | cx, cy, cw, ch = cx * w, cy * h, cw * w, ch * h | 26 | cx, cy, cw, ch = cx * w, cy * h, cw * w, ch * h |
27 | result.append([int(cx - cw // 2), int(cy - ch // 2), int(cx + cw // 2), int(cy + ch // 2),score]) | 27 | result.append([int(cx - cw // 2), int(cy - ch // 2), int(cx + cw // 2), int(cy + ch // 2), score]) |
28 | return result | 28 | return result |
29 | 29 | ||
30 | 30 | ||
... | @@ -38,28 +38,28 @@ if __name__ == '__main__': | ... | @@ -38,28 +38,28 @@ if __name__ == '__main__': |
38 | data = pd.read_csv(crop_csv_path) | 38 | data = pd.read_csv(crop_csv_path) |
39 | img_name = data.loc[:, 'img_name'].tolist() | 39 | img_name = data.loc[:, 'img_name'].tolist() |
40 | crop_position1 = data.loc[:, 'name_crop_coord'].tolist() | 40 | crop_position1 = data.loc[:, 'name_crop_coord'].tolist() |
41 | crop_position2 = data.loc[:,'number_crop_coord'].tolist() | 41 | crop_position2 = data.loc[:, 'number_crop_coord'].tolist() |
42 | cc='/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/gongshang/tampered/images/val/ps3' | 42 | cc = '/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/gongshang/tampered/images/val/ps3' |
43 | for im in os.listdir(cc): | 43 | for im in os.listdir(cc): |
44 | print(im) | 44 | print(im) |
45 | img = cv2.imread(os.path.join(cc,im)) | 45 | img = cv2.imread(os.path.join(cc, im)) |
46 | img_=img.copy() | 46 | img_ = img.copy() |
47 | id = img_name.index(im) | 47 | id = img_name.index(im) |
48 | name_crop_position=[int(i) for i in crop_position1[id].split(',')] | 48 | name_crop_position = [int(i) for i in crop_position1[id].split(',')] |
49 | number_crop_position=[int(i) for i in crop_position2[id].split(',')] | 49 | number_crop_position = [int(i) for i in crop_position2[id].split(',')] |
50 | nx1,ny1,nx2,ny2=name_crop_position | 50 | nx1, ny1, nx2, ny2 = name_crop_position |
51 | nux1,nuy1,nux2,nuy2=number_crop_position | 51 | nux1, nuy1, nux2, nuy2 = number_crop_position |
52 | if im[:-4]+'_hname.txt' in predict_labels: | 52 | if im[:-4] + '_hname.txt' in predict_labels: |
53 | 53 | ||
54 | h, w, c = img[ny1:ny2, nx1:nx2, :].shape | 54 | h, w, c = img[ny1:ny2, nx1:nx2, :].shape |
55 | data = open(os.path.join(predict_label_path,im[:-4]+'_hname.txt')).readlines() | 55 | data = open(os.path.join(predict_label_path, im[:-4] + '_hname.txt')).readlines() |
56 | for d in data: | 56 | for d in data: |
57 | cls,cx,cy,cw,ch,score = [float(i) for i in d.strip().split(' ')] | 57 | cls, cx, cy, cw, ch, score = [float(i) for i in d.strip().split(' ')] |
58 | cx,cy,cw,ch=int(cx*w),int(cy*h),int(cw*w),int(ch*h) | 58 | cx, cy, cw, ch = int(cx * w), int(cy * h), int(cw * w), int(ch * h) |
59 | cx1,cy1=cx-cw//2,cy-ch//2 | 59 | cx1, cy1 = cx - cw // 2, cy - ch // 2 |
60 | x1,y1,x2,y2=nx1+cx1,ny1+cy1,nx1+cx1+cw,ny1+cy1+ch | 60 | x1, y1, x2, y2 = nx1 + cx1, ny1 + cy1, nx1 + cx1 + cw, ny1 + cy1 + ch |
61 | cv2.rectangle(img,(x1,y1),(x2,y2),(0,0,255),2) | 61 | cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2) |
62 | cv2.putText(img,f'tampered:{score}',(x1,y1-5),cv2.FONT_HERSHEY_SIMPLEX,0.5,(0,0,255),1) | 62 | cv2.putText(img, f'tampered:{score}', (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1) |
63 | if im[:-4] + '_hnumber.txt' in predict_labels: | 63 | if im[:-4] + '_hnumber.txt' in predict_labels: |
64 | h, w, c = img[nuy1:nuy2, nux1:nux2, :].shape | 64 | h, w, c = img[nuy1:nuy2, nux1:nux2, :].shape |
65 | data = open(os.path.join(predict_label_path, im[:-4] + '_hname.txt')).readlines() | 65 | data = open(os.path.join(predict_label_path, im[:-4] + '_hname.txt')).readlines() |
... | @@ -70,5 +70,5 @@ if __name__ == '__main__': | ... | @@ -70,5 +70,5 @@ if __name__ == '__main__': |
70 | x1, y1, x2, y2 = nux1 + cx1, nuy1 + cy1, nux1 + cx1 + cw, nuy1 + cy1 + ch | 70 | x1, y1, x2, y2 = nux1 + cx1, nuy1 + cy1, nux1 + cx1 + cw, nuy1 + cy1 + ch |
71 | cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2) | 71 | cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2) |
72 | cv2.putText(img, f'tampered:{score}', (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1) | 72 | cv2.putText(img, f'tampered:{score}', (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1) |
73 | result = np.vstack((img_,img)) | 73 | result = np.vstack((img_, img)) |
74 | cv2.imwrite(f'z/{im}',result) | 74 | cv2.imwrite(f'z/{im}', result) | ... | ... |
-
Please register or sign in to post a comment