ocr_yolo triton-inference-server
Showing 11 changed files with 308 additions and 0 deletions
.gitignore
0 → 100644
OCR_Engine @ 3dddc11a
Subproject commit 3dddc11a8a1d369ca4fbd0b69e4e21e6af81cc4c
README.md
0 → 100644
## OCR + yolov5 triton-inference-server service

1. Start the Triton server with Docker (a readiness-check sketch follows this README):

sudo docker run --gpus="device=0" --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v /home/situ/qfs/triton_inference_server/demo/model_repository:/models nvcr.io/nvidia/tritonserver:21.10-py3 tritonserver --model-repository=/models

2. Start the OCR and yolov5 web services separately:

cd OCR_Engine/api
python ocr_engine_server.py

cd yolov5_onnx_demo/api
python yolov5_onnx_server.py

3. Run the pipeline test:

python triton_pipeline.py
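Before running the pipeline it is worth confirming that the Triton container is actually serving. A minimal readiness-check sketch, not part of this commit, assuming the gRPC port 8001 from the docker command and the model name 'yolov5' used in yolov5_infer.py:

import tritonclient.grpc as grpcclient

# Hypothetical check: confirm the server is live and the 'yolov5' model
# from model_repository has finished loading.
client = grpcclient.InferenceServerClient(url='localhost:8001')
assert client.is_server_live()
assert client.is_model_ready('yolov5')
print('Triton is up and yolov5 is ready')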
bank_ocr_inference.py
0 → 100644
triton_pipeline.py
0 → 100644
import base64
import json
import time

import cv2
import numpy as np
import requests

from bank_ocr_inference import *


def enlarge_position(box):
    # Expand a box by 1/3 of its height vertically and 1/8 of its width
    # horizontally, clamping the top-left corner at the image origin.
    x1, y1, x2, y2 = box
    w, h = abs(x2 - x1), abs(y2 - y1)
    y1, y2 = max(y1 - h // 3, 0), y2 + h // 3
    x1, x2 = max(x1 - w // 8, 0), x2 + w // 8
    return [x1, y1, x2, y2]


def path_base64(file_path):
    with open(file_path, 'rb') as f:
        file64 = base64.b64encode(f.read())  # base64-encoded image bytes
    return file64.decode('utf-8')


def bgr_base64(image):
    _, img64 = cv2.imencode('.jpg', image)
    img64 = base64.b64encode(img64)
    return img64.decode('utf-8')


def base64_bgr(img64):
    str_img64 = base64.b64decode(img64)
    image = np.frombuffer(str_img64, np.uint8)
    image = cv2.imdecode(image, cv2.IMREAD_COLOR)
    return image


def tamper_detect_(image):
    img64 = bgr_base64(image)
    resp = requests.post(url=r'http://192.168.10.11:8009/tamper_det', data=json.dumps({'img': img64}))
    return resp.json()


if __name__ == '__main__':
    image = cv2.imread(
        '/data/situ_invoice_bill_data/银行流水样本/普通打印-部分格线-竖版-农业银行-8列/_1594626974.367834page_20_img_0.jpg')
    st = time.time()
    ocr_results = bill_ocr(image)
    et1 = time.time()
    info_results = extract_bank_info(ocr_results)
    et2 = time.time()
    tamper_results = []
    if len(info_results) != 0:
        for info_result in info_results:
            box = [info_result[1][0], info_result[1][1], info_result[1][4], info_result[1][5]]
            x1, y1, x2, y2 = enlarge_position(box)
            info_image = image[y1:y2, x1:x2, :]
            results = tamper_detect_(info_image)
            print(results)
            if len(results['results']) != 0:
                for res in results['results']:
                    # The detector returns (cx, cy, w, h) boxes relative to the
                    # crop; shift them back into full-image coordinates.
                    cx, cy = int(res[0]), int(res[1])
                    width, height = int(res[2]), int(res[3])
                    left = cx - width // 2
                    top = cy - height // 2
                    absolute_position = [x1 + left, y1 + top, x1 + left + width, y1 + top + height]
                    tamper_results.append(absolute_position)
    et3 = time.time()
    print(tamper_results)

    print(f'all time:{et3 - st} ocr time:{et1 - st} extract info time:{et2 - et1} yolo time:{et3 - et2}')
    for i in tamper_results:
        cv2.rectangle(image, tuple(i[:2]), tuple(i[2:]), (0, 0, 255), 2)
    cv2.imshow('info', image)
    cv2.waitKey(0)
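To make the coordinate bookkeeping in the loop above concrete: each detection arrives in crop coordinates as a (cx, cy, w, h) box and is shifted by the crop's top-left corner into full-image coordinates. A small illustrative helper with made-up numbers (hypothetical, not part of this commit):

def crop_to_absolute(crop_origin, det):
    # crop_origin: (x1, y1) of the crop in the full image
    # det: (cx, cy, w, h) detection in crop coordinates
    x1, y1 = crop_origin
    cx, cy, w, h = det
    left, top = cx - w // 2, cy - h // 2
    return [x1 + left, y1 + top, x1 + left + w, y1 + top + h]

# A crop taken at (100, 200) with a detection centered at (50, 30) of size
# 40x20 maps to [130, 220, 170, 240] in the full image.
assert crop_to_absolute((100, 200), (50, 30, 40, 20)) == [130, 220, 170, 240]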
yolov5_onnx_demo/api/yolov5_onnx_server.py
0 → 100644
import base64

import cv2
import numpy as np
from sanic import Sanic
from sanic.response import json
from yolov5_onnx_demo.model.yolov5_infer import *


def base64_to_bgr(bs64):
    img_data = base64.b64decode(bs64)
    img_arr = np.frombuffer(img_data, np.uint8)  # np.fromstring is deprecated
    img_np = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
    return img_np


app = Sanic('tamper_det')


@app.post('/tamper_det')
def tamper_det(request):
    d = request.json
    img = base64_to_bgr(d['img'])
    result = grpc_detect(img)

    return json({'results': result})


if __name__ == '__main__':
    app.run(host='192.168.10.11', port=8009, workers=10)
yolov5_onnx_demo/api_test.py
0 → 100644
import base64
import json

import cv2
import requests

from yolov5_onnx_demo.model.yolov5_infer import *


def path_base64(file_path):
    with open(file_path, 'rb') as f:
        file64 = base64.b64encode(f.read())  # base64-encoded image bytes
    return file64.decode('utf-8')


img_path = '/data/situ_invoice_bill_data/qfs_train_val_data/train_data/machine/minsheng/images/train/_1597386625.07514page_20_img_0_machine_name_full_splicing.jpg'
res = requests.post('http://192.168.10.11:8009/tamper_det',
                    data=json.dumps({'img': path_base64(img_path)}))
results = res.json()
img = cv2.imread(img_path)
print(res)
plot_label(img, results['results'])  # the server responds with {'results': [...]}
yolov5_onnx_demo/model/__init__.py
0 → 100644
File mode changed
No preview for this file type (two binary files)
yolov5_onnx_demo/model/yolov5_infer.py
0 → 100644
import cv2
import numpy as np
import tritonclient.grpc as grpcclient


def keep_resize_padding(image):
    '''
    The model expects a fixed 640x640 input, while the official inference
    code uses minimum-ratio letterboxing for speed, which produces inputs of
    varying size. This resize instead pads the image to a square with the
    value 114 (YOLOv5's padding color) and then scales it to 640x640.
    '''
    h, w, c = image.shape
    if h >= w:
        # pad left/right up to a square
        pad1 = (h - w) // 2
        pad2 = h - w - pad1
        p1 = np.ones((h, pad1, 3)) * 114.0
        p2 = np.ones((h, pad2, 3)) * 114.0
        p1, p2 = p1.astype(np.uint8), p2.astype(np.uint8)
        new_image = np.hstack((p1, image, p2))
        padding_info = [pad1, pad2, 0]
    else:
        # pad top/bottom up to a square
        pad1 = (w - h) // 2
        pad2 = w - h - pad1
        p1 = np.ones((pad1, w, 3)) * 114.0
        p2 = np.ones((pad2, w, 3)) * 114.0
        p1, p2 = p1.astype(np.uint8), p2.astype(np.uint8)
        new_image = np.vstack((p1, image, p2))
        padding_info = [pad1, pad2, 1]
    new_image = cv2.resize(new_image, (640, 640))
    return new_image, padding_info


# remove padding
def extract_authentic_bboxes(image, padding_info, bboxes):
    '''
    Map (cx, cy, w, h) boxes from the padded 640x640 space back to
    original-image coordinates.
    '''
    pad1, pad2, pad_type = padding_info
    h, w, c = image.shape
    bboxes = np.array(bboxes)
    if bboxes.size == 0:
        return []
    max_side = max(h, w)
    scale = max_side / 640
    bboxes[:, :4] = bboxes[:, :4] * scale
    if pad_type == 0:
        bboxes[:, 0] = bboxes[:, 0] - pad1  # padding was horizontal
    else:
        bboxes[:, 1] = bboxes[:, 1] - pad1  # padding was vertical
    return bboxes.tolist()


# NMS
def py_nms_cpu(
        prediction,
        conf_thres=0.25,
        iou_thres=0.45,
):
    """Non-Maximum Suppression (NMS) on inference results to reject overlapping detections

    Returns:
        (n, 6) array of kept detections [xywh, conf, cls]
    """
    xc = prediction[..., 4] > conf_thres  # candidates by objectness
    prediction = prediction[xc]

    # NMS: convert center/size boxes to corners
    x1 = prediction[..., 0] - prediction[..., 2] / 2
    y1 = prediction[..., 1] - prediction[..., 3] / 2
    x2 = prediction[..., 0] + prediction[..., 2] / 2
    y2 = prediction[..., 1] + prediction[..., 3] / 2

    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    score = prediction[..., 5]  # single-class score
    order = np.argsort(score)[::-1]  # highest score first
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)

        # intersection of the current box with all remaining boxes
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])

        ww, hh = np.maximum(0, xx2 - xx1 + 1), np.maximum(0, yy2 - yy1 + 1)
        inter = ww * hh

        over = inter / (areas[i] + areas[order[1:]] - inter)  # IoU

        idx = np.where(over < iou_thres)[0]
        order = order[idx + 1]

    return prediction[keep]


def client_init(url='localhost:8001',
                ssl=False,
                private_key=None,
                root_certificates=None,
                certificate_chain=None,
                verbose=False):
    triton_client = grpcclient.InferenceServerClient(
        url=url,
        verbose=verbose,  # verbose output, False by default
        ssl=ssl,
        root_certificates=root_certificates,
        private_key=private_key,
        certificate_chain=certificate_chain,
    )
    return triton_client


triton_client = client_init('localhost:8001')
compression_algorithm = None
input_name = 'images'
output_name = 'output0'
model_name = 'yolov5'


def grpc_detect(img):
    image, padding_info = keep_resize_padding(img)
    image = image.transpose((2, 0, 1))[::-1]  # HWC BGR -> CHW RGB
    image = image.astype(np.float32)
    image = image / 255.0  # normalize to [0, 1]
    if len(image.shape) == 3:
        image = image[None]  # add batch dimension

    outputs, inputs = [], []

    # dynamic input shape
    input_shape = image.shape
    inputs.append(grpcclient.InferInput(input_name, input_shape, 'FP32'))
    outputs.append(grpcclient.InferRequestedOutput(output_name))

    inputs[0].set_data_from_numpy(image.astype(np.float32))

    pred = triton_client.infer(
        model_name=model_name,
        inputs=inputs, outputs=outputs,
        compression_algorithm=compression_algorithm
    )
    pred = pred.as_numpy(output_name).copy()
    result_bboxes = py_nms_cpu(pred)
    result_bboxes = extract_authentic_bboxes(img, padding_info, result_bboxes)
    return result_bboxes


def plot_label(img, result_bboxes):
    print(result_bboxes)
    for bbox in result_bboxes:
        x, y, w, h, conf, cls = bbox  # boxes are center/size
        cv2.rectangle(img, (int(x - w // 2), int(y - h // 2)), (int(x + w // 2), int(y + h // 2)), (0, 0, 255), 2)
    cv2.imshow('im', img)
    cv2.waitKey(0)


if __name__ == '__main__':
    img = cv2.imread(
        '/data/situ_invoice_bill_data/qfs_train_val_data/train_data/authentic/gongshang/images/val/_1594890232.0110397page_11_img_0_name_au_gongshang.jpg')

    result_bboxes = grpc_detect(img)
    plot_label(img, result_bboxes)
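As a sanity check on the letterbox math, a round trip through the padding should recover the original box exactly. A minimal sketch, assuming keep_resize_padding and extract_authentic_bboxes from this file are in scope; the numbers are illustrative:

import numpy as np

h, w = 480, 640                     # landscape image: padding goes top/bottom
pad1 = (w - h) // 2                 # 80, matching keep_resize_padding
pad2 = (w - h) - pad1               # 80
scale = max(h, w) / 640             # 1.0 here, to keep the arithmetic visible
cx, cy, bw, bh = 320, 240, 100, 50  # a box in original-image coordinates
# forward-map the box into the padded 640x640 space...
model_box = [cx / scale, (cy + pad1) / scale, bw / scale, bh / scale, 0.9, 0]
# ...and extract_authentic_bboxes should map it straight back
img = np.zeros((h, w, 3), np.uint8)
recovered = extract_authentic_bboxes(img, [pad1, pad2, 1], [model_box])
assert recovered[0][:4] == [cx, cy, bw, bh]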