update
Showing 59 changed files with 2434 additions and 1 deletions
ocr_engine/README.md
0 → 100644
# turnsole
A series of convenience functions to make your machine learning project easier

## Installation

### Latest release
`pip install turnsole`
> The project is not open source yet, so this installation method is not guaranteed to work for now

### Developer mode

`pip install -e .`

## Quick start
### PDF operations
#### Smart PDF-to-image conversion
Intelligently extracts the embedded images from a PDF file: if a page has no embedded images the whole page is rendered as a screenshot, and fragmented images are stitched back together automatically

##### Example:
<pre># pdf_path is the path to the PDF file; the output images are grouped by page number
images = turnsole.pdf_to_images(pdf_path)</pre>
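
Since the output is grouped per page, it can be flattened into a single list of images when page boundaries do not matter (this is what demos/test_pdf_tools.py in this repository does):

<pre>images = sum(images, [])  # flatten the per-page lists into one list</pre>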

### Image toolbox
#### base64_to_bgr / bgr_to_base64
Convert images to and from base64

##### Example:
<pre>image = turnsole.base64_to_bgr(img64)
img64 = turnsole.bgr_to_base64(image)</pre>

### image_crop
Crops a slice out of image according to bbox; if perspective is set to True the slice is cut with a perspective transform (which can crop rotated targets)

##### Example:
<pre>im_slice_no_perspective = turnsole.image_crop(image, bbox)
im_slice = turnsole.image_crop(image, bbox, perspective=True)</pre>

##### Output:

<img src="docs/images/image_crop.png?raw=true" alt="image crop example" style="max-width: 200px;">

### OCR engine module
The OCR engine is a collection of low-level OCR-related models; we provide both function-style interfaces and a standard API for them

- [x] ADC :tada:
- [x] DBNet :tada:
- [x] CRNN :tada:
- [x] Object Detector :tada:
- [x] Signature Detector :tada:

#### Free trial
```python
import requests

results = requests.post(url=r'http://139.196.149.46:9001/gen_ocr', files={'file': open(file_path, 'rb')}).json()
ocr_results = results['ocr_results']
```

#### Prerequisites
The OCR engine module depends on the underlying neural network models, so you must serve those models with Docker first

Put the ./model_repository folder with the models inside it under the project root before starting; if you are missing any models, ask [lvkui](lvkui@situdata.com) for them

Usage is very simple: just start the corresponding Docker container

```bash
docker run --gpus="device=0" --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v $PWD/model_repository:/models nvcr.io/nvidia/tritonserver:21.10-py3 tritonserver --model-repository=/models
```
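
To confirm the models are up before starting the API server, you can probe Triton's standard KServe v2 HTTP endpoints. A quick sanity check, assuming the 8000:8000 port mapping above:

```python
import requests

# Returns HTTP 200 once the server is ready to serve inference requests
print(requests.get('http://localhost:8000/v2/health/ready').status_code)
# Lists the models Triton loaded from /models, together with their load state
print(requests.post('http://localhost:8000/v2/repository/index').json())
```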

#### ADC
General document deskewing algorithm

```
from turnsole.ocr_engine import angle_detector

image_rotated, direction = angle_detector.ADC(image, fine_degree=False)
```

#### DBNet
General text detection algorithm

```
from turnsole.ocr_engine import text_detector

boxes = text_detector.predict(image)
```

#### CRNN
General text recognition algorithm

```
from turnsole.ocr_engine import text_recognizer

ocr_result, ocr_time = text_recognizer.predict_batch(image, boxes)
```

#### Object Detector
General document detection algorithm

```
from turnsole.ocr_engine import object_detector

object_list = object_detector.process(image)
```

#### Signature Detector
Signature, stamp and QR-code detection algorithm

```
from turnsole.ocr_engine import signature_detector

signature_list = signature_detector.process(image)
```

#### Standard API
```
python api/ocr_engine_server.py
```
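
Once the server is up it can be called just like the free-trial endpoint above. A minimal sketch; the host and port here assume the values configured in app.run in api/ocr_engine_server.py, so adjust them to your deployment:

```python
import requests

with open('demos/img_ocr/001.jpg', 'rb') as f:
    resp = requests.post('http://192.168.10.11:9002/gen_ocr_with_rotation', files={'file': f}).json()

print(resp['direction'])    # detected page orientation
print(resp['ocr_results'])  # recognized boxes and text
```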
\ No newline at end of file
ocr_engine/api/nohup.out
0 → 100644
[2022-10-21 14:12:17 +0800] [8546] [INFO] Goin' Fast @ http://192.168.10.11:9001
[2022-10-21 14:12:17 +0800] [8567] [INFO] Starting worker [8567]
[2022-10-21 14:12:17 +0800] [8568] [INFO] Starting worker [8568]
[2022-10-21 14:12:17 +0800] [8569] [INFO] Starting worker [8569]
[2022-10-21 14:12:17 +0800] [8570] [INFO] Starting worker [8570]
[2022-10-21 14:12:17 +0800] [8571] [INFO] Starting worker [8571]
[2022-10-21 14:12:17 +0800] [8572] [INFO] Starting worker [8572]
[2022-10-21 14:12:17 +0800] [8573] [INFO] Starting worker [8573]
[2022-10-21 14:12:17 +0800] [8576] [INFO] Starting worker [8576]
[2022-10-21 14:12:17 +0800] [8574] [INFO] Starting worker [8574]
[2022-10-21 14:12:17 +0800] [8575] [INFO] Starting worker [8575]
[2022-10-21 14:13:51 +0800] [8575] [ERROR] Exception occurred while handling uri: 'http://192.168.10.11:9001/gen_ocr'
Traceback (most recent call last):
  File "/home/situ/miniconda3/envs/workenv/lib/python3.6/site-packages/sanic/app.py", line 944, in handle_request
    response = await response
  File "ocr_engine_server.py", line 37, in ocr_engine
    boxes = text_detector.predict(image)
  File "/home/situ/qfs/invoice_tamper/09_project/project/bank_bill_ocr/OCR_Engine/turnsole/ocr_engine/DBNet/text_detector.py", line 113, in predict
    outputs=outputs
  File "/home/situ/miniconda3/envs/workenv/lib/python3.6/site-packages/tritonclient/grpc/__init__.py", line 1431, in infer
    raise_error_grpc(rpc_error)
  File "/home/situ/miniconda3/envs/workenv/lib/python3.6/site-packages/tritonclient/grpc/__init__.py", line 62, in raise_error_grpc
    raise get_error_grpc(rpc_error) from None
tritonclient.utils.InferenceServerException: [StatusCode.UNAVAILABLE] Request for unknown model: 'dbnet_model' is not found
[2022-10-21 14:13:51 +0800] - (sanic.access)[INFO][192.168.10.11:57260]: POST http://192.168.10.11:9001/gen_ocr 500 735
ocr_engine/api/ocr_engine_server.py
0 → 100644
# -*- coding: utf-8 -*-
# @Author : Lyu Kui
# @Email : 9428.al@gmail.com
# @Create Date : 2022-06-05 20:49:51
# @Last Modified : 2022-08-19 17:24:55
# @Description :

import os

os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

from sanic import Sanic
from sanic.response import json

from turnsole.ocr_engine import angle_detector
from turnsole.ocr_engine import text_detector
from turnsole.ocr_engine import text_recognizer
from turnsole.ocr_engine import object_detector
from turnsole.ocr_engine import signature_detector

from turnsole import bytes_to_bgr

app = Sanic("OCR_ENGINE")
app.config.REQUEST_MAX_SIZE = 1000000000     # Maximum request size in bytes / 1 GB
app.config.REQUEST_BUFFER_QUEUE_SIZE = 1000  # Request stream buffer queue size
app.config.REQUEST_TIMEOUT = 600             # How long a request may take to arrive (seconds)
app.config.RESPONSE_TIMEOUT = 600            # How long the response may take to process (seconds)


@app.post('/gen_ocr')
async def ocr_engine(request):
    # request.files.get() has three attributes: type/body/name
    file = request.files.get('file').body
    # Convert the bytes into a BGR image
    image = bytes_to_bgr(file)
    # Text detection
    boxes = text_detector.predict(image)
    # Text recognition
    res, _ = text_recognizer.predict_batch(image[..., ::-1], boxes)
    resp = {}
    resp["ocr_results"] = res
    return json(resp)


@app.post('/gen_ocr_with_rotation')
async def ocr_engine_with_rotation(request):
    # request.files.get() has three attributes: type/body/name
    file = request.files.get('file').body
    # Convert the bytes into a BGR image
    image = bytes_to_bgr(file)
    # Orientation detection and correction
    image, direction = angle_detector.ADC(image.copy(), fine_degree=False)
    # Text detection
    boxes = text_detector.predict(image)
    # Text recognition
    res, _ = text_recognizer.predict_batch(image[..., ::-1], boxes)

    resp = {}
    resp["ocr_results"] = res
    resp["direction"] = direction
    return json(resp)


@app.post("/object_detect")
async def object_detect(request):
    # request.files.get() has three attributes: type/body/name
    file = request.files.get('file').body
    # Convert the bytes into a BGR image
    image = bytes_to_bgr(file)
    # General document detection
    object_list = object_detector.process(image)
    return json(object_list)


@app.post("/signature_detect")
async def signature_detect(request):
    # request.files.get() has three attributes: type/body/name
    file = request.files.get('file').body
    # Convert the bytes into a BGR image
    image = bytes_to_bgr(file)
    # Signature, stamp, QR-code and barcode detection
    signature_list = signature_detector.process(image)
    return json(signature_list)


if __name__ == "__main__":
    # app.run(host="0.0.0.0", port=9001)
    app.run(host="192.168.10.11", port=9002, workers=10)
    # uvicorn server:app --port 9001 --workers 10
ocr_engine/demos/images/sunflower.bmp
0 → 100644
No preview for this file type
ocr_engine/demos/images/sunflower.gif
0 → 100644

9.68 KB
ocr_engine/demos/images/sunflower.jpg
0 → 100644

405 KB
ocr_engine/demos/images/sunflower.png
0 → 100644

461 KB
ocr_engine/demos/images/sunflower.tif
0 → 100644
No preview for this file type
ocr_engine/demos/img_ocr/001.jpg
0 → 100644

1.62 MB
ocr_engine/demos/img_ocr/002.jpg
0 → 100644

97.7 KB
ocr_engine/demos/img_ocr/003.jpg
0 → 100644

112 KB
ocr_engine/demos/img_ocr/004.jpg
0 → 100644

24.4 KB
ocr_engine/demos/img_ocr/005.jpg
0 → 100644

77.5 KB
ocr_engine/demos/read_frames_fast.py
0 → 100644
# Modified from:
# https://www.pyimagesearch.com/2017/02/06/faster-video-file-fps-with-cv2-videocapture-and-opencv/

# Performance:
# Python 2.7: 105.78 --> 131.75
# Python 3.7: 15.36 --> 50.13

# USAGE
# python read_frames_fast.py --video videos/jurassic_park_intro.mp4

# import the necessary packages
from turnsole.video import FileVideoStream
from turnsole.video import FPS
import numpy as np
import argparse
import imutils
import time
import cv2

def filterFrame(frame):
    frame = imutils.resize(frame, width=450)
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    frame = np.dstack([frame, frame, frame])
    return frame

# construct the argument parse and parse the arguments
ap = argparse.ArgumentParser()
ap.add_argument("-v", "--video", required=True,
    help="path to input video file")
args = vars(ap.parse_args())

# start the file video stream thread and allow the buffer to
# start to fill
print("[INFO] starting video file thread...")
fvs = FileVideoStream(args["video"], transform=filterFrame).start()
time.sleep(1.0)

# start the FPS timer
fps = FPS().start()

# loop over frames from the video file stream
while fvs.running():
    # grab the frame from the threaded video file stream, resize
    # it, and convert it to grayscale (while still retaining 3
    # channels)
    frame = fvs.read()

    # Relocated filtering into producer thread with transform=filterFrame
    # Python 2.7: FPS 92.11 -> 131.36
    # Python 3.7: FPS 41.44 -> 50.11
    # frame = filterFrame(frame)

    # display the size of the queue on the frame
    cv2.putText(frame, "Queue Size: {}".format(fvs.Q.qsize()),
        (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

    # show the frame and update the FPS counter
    cv2.imshow("Frame", frame)

    cv2.waitKey(1)
    if fvs.Q.qsize() < 2:  # If we are low on frames, give time to producer
        time.sleep(0.001)  # Ensures producer runs now, so 2 is sufficient
    fps.update()

# stop the timer and display FPS information
fps.stop()
print("[INFO] elapsed time: {:.2f}".format(fps.elapsed()))
print("[INFO] approx. FPS: {:.2f}".format(fps.fps()))

# do a bit of cleanup
cv2.destroyAllWindows()
fvs.stop()
\ No newline at end of file
ocr_engine/demos/test_convenience.py
0 → 100644
ocr_engine/demos/test_model.py
0 → 100644
# -*- coding: utf-8 -*-
# @Author : Lyu Kui
# @Email : 9428.al@gmail.com
# @Created Date : 2021-03-05 16:51:22
# @Last Modified : 2021-03-05 18:15:53
# @Description :

from turnsole.model import EasyDet

if __name__ == '__main__':
    model = EasyDet(phi=0)
    model.summary()

    import time
    import numpy as np

    x = np.random.random_sample((1, 640, 640, 3))
    # warm up
    output = model.predict(x)

    print('\n[INFO] Test start')
    time_start = time.time()
    for i in range(1000):
        output = model.predict(x)

    time_end = time.time()
    print('[INFO] Time used: {:.2f} ms'.format((time_end - time_start)*1000/(i+1)))
\ No newline at end of file
ocr_engine/demos/test_ocr_function.py
0 → 100644
# -*- coding: utf-8 -*-
# @Author : Lyu Kui
# @Email : 9428.al@gmail.com
# @Create Date : 2022-07-22 13:10:47
# @Last Modified : 2022-09-08 19:03:24
# @Description :

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import cv2
# from turnsole.ocr_engine import angle_detector
from turnsole.ocr_engine import object_detector
import matplotlib.pyplot as plt


if __name__ == "__main__":

    base_dir = '/home/lk/MyProject/BMW/数据集/文件分类/身份证'

    for (rootDir, dirNames, filenames) in os.walk(base_dir):

        for filename in filenames:

            if not filename.endswith('.jpg'):
                continue

            img_path = os.path.join(rootDir, filename)
            print(img_path)

            image = cv2.imread(img_path)

            results = object_detector.process(image)

            print(results)

            for item in results:
                xmin = item['location']['xmin']
                ymin = item['location']['ymin']
                xmax = item['location']['xmax']
                ymax = item['location']['ymax']
                cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)

            plt.imshow(image[...,::-1])
            plt.show()
\ No newline at end of file
ocr_engine/demos/test_pdf_tools.py
0 → 100644
# -*- coding: utf-8 -*-
# @Author : Lyu Kui
# @Email : 9428.al@gmail.com
# @Create Date : 2022-07-22 13:10:47
# @Last Modified : 2022-08-24 15:39:55
# @Description :


import os
import cv2
import fitz
from turnsole import pdf_to_images  # pip install turnsole PyMuPDF opencv-python==4.4.0.44

if __name__ == "__main__":

    base_dir = '/PATH/TO/YOUR/WORKDIR'

    for (rootDir, dirNames, filenames) in os.walk(base_dir):

        for filename in filenames:

            if not filename.endswith('.pdf'):
                continue

            pdf_path = os.path.join(rootDir, filename)
            print(pdf_path)

            images = pdf_to_images(pdf_path)
            images = sum(images, [])

            image_dir = os.path.join(rootDir, filename.replace('.pdf', ''))
            if not os.path.exists(image_dir):
                os.makedirs(image_dir)

            for index, image in enumerate(images):

                save_path = os.path.join(image_dir, filename.replace('.pdf', '')+'-'+str(index)+'.jpg')
                cv2.imwrite(save_path, image)
ocr_engine/docs/images/image_crop.png
0 → 100644

193 KB
ocr_engine/scripts/api_test.py
0 → 100644
# -*- coding: utf-8 -*-
# @Author : Lyu Kui
# @Email : 9428.al@gmail.com
# @Create Date : 2022-05-06 22:02:01
# @Last Modified : 2022-08-03 14:59:51
# @Description :


import os
import time
import random
import requests
import numpy as np
from threading import Thread


class API_test:
    def __init__(self, file_dir, test_time, num_request):

        self.file_paths = []
        for fn in os.listdir(file_dir):
            file_path = os.path.join(file_dir, fn)
            self.file_paths.append(file_path)

        # Shared counters must exist before the worker threads start using them
        self.results = list()
        self.index = 0

        self.time_start = time.time()
        self.test_time = test_time * 60  # in seconds
        threads = []
        for i in range(num_request):
            t = Thread(target=self.update, args=())
            threads.append(t)
        for t in threads:
            print(f'[INFO] {t} is running')
            t.start()

    def update(self):
        while True:
            file_path = random.choice(self.file_paths)

            # Open the image file in binary mode
            with open(file_path, 'rb') as data:
                t0 = time.time()
                response = requests.post(url=r'http://localhost:9001/gen_ocr_with_rotation', files={'file': data})

            # Count failed requests
            if response.status_code != 200:
                print(response)

            t1 = time.time()
            self.results.append((t1-t0))

            time_cost = (time.time() - self.time_start)
            time_remaining = self.test_time - time_cost

            self.index += 1

            if time_remaining > 0:
                print(f'\r[INFO] Time remaining {time_remaining} s, mean response time {np.mean(self.results)} s, TPS {len(self.results)/time_cost}, throughput {self.index}', end=' ', flush=True)
            else:
                break


if __name__ == '__main__':

    imageDir = './demos/img_ocr'  # Test data directory
    testTime = 10                 # Load duration, in minutes
    numRequest = 10               # Number of concurrent requests

    API_test(imageDir, testTime, numRequest)
71 | API_test(imageDir, testTime, numRequest) |
ocr_engine/setup.cfg
0 → 100644
[metadata]
name = turnsole
version = 0.0.27
author = Kui Lyu
author_email = 9428.al@gmail.com
description = A series of convenience functions to make your machine learning project easier
long_description = file: README.md
long_description_content_type = text/markdown
url = https://github.com/Antonio-hi/turnsole
project_urls =
    Bug Tracker = https://github.com/Antonio-hi/turnsole/issues
classifiers =
    Programming Language :: Python :: 3
    License :: OSI Approved :: MIT License
    Operating System :: OS Independent

[options]
packages = find:
python_requires = >=3.6
\ No newline at end of file
ocr_engine/setup.py
0 → 100644
ocr_engine/turnsole.egg-info/PKG-INFO
0 → 100644
Metadata-Version: 2.1
Name: turnsole
Version: 0.0.27
Summary: A series of convenience functions to make your machine learning project easier
Home-page: https://github.com/Antonio-hi/turnsole
Author: Kui Lyu
Author-email: 9428.al@gmail.com
License: UNKNOWN
Project-URL: Bug Tracker, https://github.com/Antonio-hi/turnsole/issues
Platform: UNKNOWN
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.6
Description-Content-Type: text/markdown
License-File: LICENSE

# turnsole
A series of convenience functions to make your machine learning project easier

## Installation

### Latest release
`pip install turnsole`
> The project is not open source yet, so this installation method is not guaranteed to work for now

### Developer mode

`pip install -e .`

## Quick start
### PDF operations
#### Smart PDF-to-image conversion
Intelligently extracts the embedded images from a PDF file: if a page has no embedded images the whole page is rendered as a screenshot, and fragmented images are stitched back together automatically

##### Example:
<pre># pdf_path is the path to the PDF file; the output images are grouped by page number
images = turnsole.pdf_to_images(pdf_path)</pre>

### Image toolbox
#### base64_to_bgr / bgr_to_base64
Convert images to and from base64

##### Example:
<pre>image = turnsole.base64_to_bgr(img64)
img64 = turnsole.bgr_to_base64(image)</pre>

### image_crop
Crops a slice out of image according to bbox; if perspective is set to True the slice is cut with a perspective transform (which can crop rotated targets)

##### Example:
<pre>im_slice_no_perspective = turnsole.image_crop(image, bbox)
im_slice = turnsole.image_crop(image, bbox, perspective=True)</pre>

##### Output:

<img src="docs/images/image_crop.png?raw=true" alt="image crop example" style="max-width: 200px;">

### OCR engine module
The OCR engine is a collection of low-level OCR-related models; we provide both function-style interfaces and a standard API for them

- [x] ADC :tada:
- [x] DBNet :tada:
- [x] CRNN :tada:
- [x] Object Detector :tada:
- [x] Signature Detector :tada:

#### Free trial
```python
import requests

results = requests.post(url=r'http://139.196.149.46:9001/gen_ocr', files={'file': open(file_path, 'rb')}).json()
ocr_results = results['ocr_results']
```

#### Prerequisites
The OCR engine module depends on the underlying neural network models, so you must serve those models with Docker first

Put the ./model_repository folder with the models inside it under the project root before starting; if you are missing any models, ask [lvkui](lvkui@situdata.com) for them

Usage is very simple: just start the corresponding Docker container

```bash
docker run --gpus="device=0" --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v $PWD/model_repository:/models nvcr.io/nvidia/tritonserver:21.10-py3 tritonserver --model-repository=/models
```

#### ADC
General document deskewing algorithm

```
from turnsole.ocr_engine import angle_detector

image_rotated, direction = angle_detector.ADC(image, fine_degree=False)
```

#### DBNet
General text detection algorithm

```
from turnsole.ocr_engine import text_detector

boxes = text_detector.predict(image)
```

#### CRNN
General text recognition algorithm

```
from turnsole.ocr_engine import text_recognizer

ocr_result, ocr_time = text_recognizer.predict_batch(image, boxes)
```

#### Object Detector
General document detection algorithm

```
from turnsole.ocr_engine import object_detector

object_list = object_detector.process(image)
```

#### Signature Detector
Signature, stamp and QR-code detection algorithm

```
from turnsole.ocr_engine import signature_detector

signature_list = signature_detector.process(image)
```

#### Standard API
```
python api/ocr_engine_server.py
```
ocr_engine/turnsole.egg-info/SOURCES.txt
0 → 100644
LICENSE
README.md
setup.cfg
setup.py
turnsole/__init__.py
turnsole/convenience.py
turnsole/encodings.py
turnsole/model.py
turnsole/paths.py
turnsole/pdf_tools.py
turnsole.egg-info/PKG-INFO
turnsole.egg-info/SOURCES.txt
turnsole.egg-info/dependency_links.txt
turnsole.egg-info/top_level.txt
turnsole/face_utils/__init__.py
turnsole/face_utils/agedetector.py
turnsole/face_utils/facedetector.py
turnsole/nets/__init__.py
turnsole/nets/efficientnet.py
turnsole/ocr_engine/__init__.py
turnsole/ocr_engine/ADC/__init__.py
turnsole/ocr_engine/ADC/angle_detector.py
turnsole/ocr_engine/CRNN/__init__.py
turnsole/ocr_engine/CRNN/alphabets.py
turnsole/ocr_engine/CRNN/text_rec.py
turnsole/ocr_engine/DBNet/__init__.py
turnsole/ocr_engine/DBNet/text_detector.py
turnsole/ocr_engine/object_det/__init__.py
turnsole/ocr_engine/object_det/utils.py
turnsole/ocr_engine/signature_det/__init__.py
turnsole/ocr_engine/signature_det/utils.py
turnsole/ocr_engine/utils/__init__.py
turnsole/ocr_engine/utils/read_data.py
turnsole/video/__init__.py
turnsole/video/count_frames.py
turnsole/video/filevideostream.py
turnsole/video/fps.py
turnsole/video/pivideostream.py
turnsole/video/videostream.py
turnsole/video/webcamvideostream.py
\ No newline at end of file
ocr_engine/turnsole.egg-info/top_level.txt
0 → 100644
turnsole
ocr_engine/turnsole/__init__.py
0 → 100644
try:
    from . import ocr_engine
except Exception:
    # print('[INFO] OCR engine could not be imported')
    pass
from .convenience import resize
from .convenience import resize_with_pad
from .convenience import image_crop
from .encodings import bytes_to_bgr
from .encodings import base64_to_image
from .encodings import base64_encode_file
from .encodings import base64_encode_image
from .encodings import base64_decode_image
from .encodings import base64_to_bgr
from .encodings import bgr_to_base64
from .pdf_tools import pdf_to_images
\ No newline at end of file
ocr_engine/turnsole/convenience.py
0 → 100644
import cv2
import numpy as np

def resize(image, width=None, height=None, inter=cv2.INTER_AREA):
    # initialize the dimensions of the image to be resized and grab the image size
    dim = None
    (h, w) = image.shape[:2]

    # if both the width and height are None, then return the original image
    if width is None and height is None:
        return image

    # check to see if the width is None
    if width is None:
        # calculate the ratio of the height and construct the dimensions
        r = height / float(h)
        dim = (int(w * r), height)

    # otherwise, the height is None
    else:
        # calculate the ratio of the width and construct the dimensions
        r = width / float(w)
        dim = (width, int(h * r))

    # resize the image
    resized = cv2.resize(image, dim, interpolation=inter)

    # return the resized image
    return resized

def resize_with_pad(image, target_width, target_height):
    """Resizes and pads an image to a target width and height.

    Resizes an image to the target width and height while keeping the aspect
    ratio unchanged, so there is no distortion. The scaling ratio is capped at
    1.0, i.e. the image is never upscaled; the unused width and height are
    padded with zeroes.

    Args:
        image (Array): RGB/BGR
        target_width (Int): Target width.
        target_height (Int): Target height.

    Returns:
        Array: Resized and padded image. The image is padded with zeroes.
        Float: Image resize ratio, at most 1.0.
    """
    height, width, _ = image.shape

    min_ratio = min(target_height/height, target_width/width)
    ratio = min_ratio if min_ratio < 1.0 else 1.0

    # To shrink an image, it will generally look best with INTER_AREA interpolation.
    resized = cv2.resize(image, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_AREA)
    h, w, _ = resized.shape
    canvas = np.zeros((target_height, target_width, 3), image.dtype)
    canvas[:h, :w, :] = resized
    return canvas, ratio

def image_crop(image, bbox, perspective=False):
    """Crop a slice out of image according to bbox; if perspective is True the
    slice is cut with a perspective transform (which can crop rotated targets).

    Args:
        image (array): 3-channel image; the slice keeps the original color channels
        bbox (array/list): supports 2-point upright boxes and 4-point rotated boxes,
            in either of the following formats:
            1. bbox = [xmin, ymin, xmax, ymax]
            2. bbox = [x0, y0, x1, y1, x2, y2, x3, y3]
        perspective (bool, optional): whether to crop a rotated target. Defaults to False.

    Returns:
        array: the crop, same color channels as the input
    """
    # Crop by the axis-aligned bounding rectangle of bbox
    bbox = np.array(bbox, dtype=np.int32).reshape((-1, 2))
    xmin, ymin, xmax, ymax = [min(bbox[:, 0]),
                              min(bbox[:, 1]),
                              max(bbox[:, 0]),
                              max(bbox[:, 1])]
    xmin, ymin = max(0, xmin), max(0, ymin)
    im_slice = image[ymin:ymax, xmin:xmax, :]

    if perspective and bbox.shape[0] == 4:
        # Width and height of the rotated rectangle
        w, h = [int(np.linalg.norm(bbox[0] - bbox[1])),
                int(np.linalg.norm(bbox[3] - bbox[0]))]
        # Translate bbox into the coordinate frame of the axis-aligned crop
        bbox[:, 0] -= xmin
        bbox[:, 1] -= ymin
        # Perform the perspective crop
        pts1 = np.float32(bbox)
        pts2 = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
        M = cv2.getPerspectiveTransform(pts1, pts2)
        im_slice = cv2.warpPerspective(im_slice, M, (w, h))

    return im_slice
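

if __name__ == '__main__':
    # Minimal self-check of image_crop (an illustrative sketch, not part of the
    # library API): crop the same scene as an upright box and as a rotated
    # 4-point box using the perspective transform.
    canvas = np.full((200, 200, 3), 255, np.uint8)
    cv2.line(canvas, (40, 120), (150, 60), (0, 0, 255), 3)  # a slanted stroke to crop

    rect = [30, 50, 160, 130]                    # [xmin, ymin, xmax, ymax]
    quad = [40, 120, 150, 60, 160, 80, 50, 140]  # [x0, y0, ..., x3, y3], clockwise

    print(image_crop(canvas, rect).shape)                    # upright slice
    print(image_crop(canvas, quad, perspective=True).shape)  # deskewed slice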
ocr_engine/turnsole/encodings.py
0 → 100644
# -*- coding: utf-8 -*-
# @Author : Antonio-hi
# @Email : 9428.al@gmail.com
# @Create Date : 2021-08-09 19:08:49
# @Last Modified : 2021-08-10 10:11:06
# @Description :

# import the necessary packages
import numpy as np
import base64
import json
import sys
import cv2
import os

def base64_encode_image(a):
    # return a JSON-encoded list of the base64 encoded image, image data type, and image shape
    # return json.dumps([base64_encode_array(a), str(a.dtype), a.shape])
    return json.dumps([base64_encode_array(a).decode("utf-8"), str(a.dtype),
        a.shape])

def base64_decode_image(a):
    # grab the array, data type, and shape from the JSON-decoded object
    (a, dtype, shape) = json.loads(a)

    # set the correct data type and reshape the matrix into an image
    a = base64_decode_array(a, dtype).reshape(shape)

    # return the loaded image
    return a

def base64_encode_array(a):
    # return the base64 encoded array
    return base64.b64encode(a)

def base64_decode_array(a, dtype):
    # decode and return the array
    return np.frombuffer(base64.b64decode(a), dtype=dtype)

def base64_encode_file(image_path):
    filename = os.path.basename(image_path)
    # encode image file to base64 string
    with open(image_path, 'rb') as f:
        buffer = f.read()
    # convert the bytes buffer to a base64-encoded string
    img64_bytes = base64.b64encode(buffer)
    img64_str = img64_bytes.decode('utf-8')  # bytes to str
    return json.dumps({"filename" : filename, "img64": img64_str})

def base64_to_image(img64):
    image_buffer = base64_decode_array(img64, dtype=np.uint8)
    # In the case of color images, the decoded images will have the channels stored in B G R order.
    image = cv2.imdecode(image_buffer, cv2.IMREAD_COLOR)
    return image

def bytes_to_bgr(buffer: bytes):
    """Read a byte stream as an OpenCV image

    Args:
        buffer (TYPE): bytes of an encoded image
    """
    img_array = np.frombuffer(buffer, np.uint8)
    image = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    return image

def base64_to_bgr(img64):
    """Convert base64 into an image.
    Single-channel grayscale and 4-channel transparent images are automatically
    converted to 3-channel BGR.

    Args:
        img64 (TYPE): base64-encoded image

    Returns:
        TYPE: image, a 3-D uint8 Tensor of shape [height, width, channels] where channels is BGR
    """
    encoded_image = base64.b64decode(img64)
    img_array = np.frombuffer(encoded_image, np.uint8)
    image = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    return image

def bgr_to_base64(image):
    """Convert an image into base64 format. The image is compressed as JPEG
    along the way, which usually degrades image quality.

    Args:
        image (TYPE): image, a 3-D uint8 or uint16 Tensor of shape [height, width, channels] where channels is BGR

    Returns:
        TYPE: the image in base64 format
    """
    retval, encoded_image = cv2.imencode('.jpg', image)  # Encodes an image(BGR) into a memory buffer.
    img64 = base64.b64encode(encoded_image)
    return img64.decode('utf-8')


if __name__ == '__main__':

    image_path = '/home/lk/Repository/Project/turnsole/demos/images/sunflower.jpg'

    # 1) Encode an image file as a base64 string (in principle this works for any file)
    json_str = base64_encode_file(image_path)

    img64_dict = json.loads(json_str)

    suffix = os.path.splitext(img64_dict['filename'])[-1].lower()
    if suffix not in ['.jpg', '.jpeg', '.png', '.bmp']:
        print(f'[INFO] Files of type {suffix} are not supported yet!')

    # 2) Decode the base64 string back into an image
    image = base64_to_image(img64_dict['img64'])

    inputs = image/255.

    # 3) Home-grown: encode an array to base64 and back without going through an
    #    image codec, which preserves the array's dtype
    base64_encode_json_string = base64_encode_image(inputs)

    inputs = base64_decode_image(base64_encode_json_string)

    print(inputs)

    # Note: prefixing a string literal with b
    # e.g. response = b'<h1>Hello World!</h1>'  # b'...' marks a bytes object

    # The b"" prefix means the literal that follows is of type bytes.

    # Why it matters: in network programming, servers and browsers only deal in
    # bytes, e.g. the argument of send() and the return value of recv() are bytes.

    # In Python 3, bytes and str are converted with
    # str.encode('utf-8')
    # bytes.decode('utf-8')
ocr_engine/turnsole/face_utils/__init__.py
0 → 100644
File mode changed
ocr_engine/turnsole/face_utils/agedetector.py
0 → 100644
# -*- coding: utf-8 -*-
# @Author : lk
# @Email : 9428.al@gmail.com
# @Create Date : 2021-08-11 17:10:16
# @Last Modified : 2021-08-12 16:14:53
# @Description :

import os
import tensorflow as tf

class AgeDetector:
    def __init__(self, model_path):
        self.age_map = {
            0: '0-2',
            1: '4-6',
            2: '8-13',
            3: '15-20',
            4: '25-32',
            5: '38-43',
            6: '48-53',
            7: '60+'
        }

        self.model = tf.keras.models.load_model(filepath=model_path,
                                                compile=False)
        self.inference_model = self.build_inference_model()

    def build_inference_model(self):
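        # Bake the MobileNetV2 preprocessing into the graph, so the inference
        # model can be fed raw (resized) images without manual normalization.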
        image = self.model.input
        x = tf.keras.applications.mobilenet_v2.preprocess_input(image)
        predictions = self.model(x, training=False)
        inference_model = tf.keras.Model(inputs=image, outputs=predictions)
        return inference_model

    def predict_batch(self, images):
        # Takes a list of face images; the list must not be empty
        images = tf.stack([tf.image.resize(image, [96, 96]) for image in images], axis=0)
        preds = self.inference_model.predict(images)
        indexes = tf.argmax(preds, axis=-1)
        classes = [self.age_map[index.numpy()] for index in indexes]
        return classes

if __name__ == '__main__':

    import cv2
    from turnsole import paths

    age_det = AgeDetector(model_path='./ckpt/age_detector.h5')

    data_dir = '/home/lk/Project/Face_Age_Gender/data/Emotion/emotion/010003_female_yellow_22'

    for image_path in paths.list_images(data_dir):
        image = cv2.imread(image_path)
        classes = age_det.predict_batch([image])

        print(classes)
ocr_engine/turnsole/model.py
0 → 100644
# -*- coding: utf-8 -*-
# @Author : Lyu Kui
# @Email : 9428.al@gmail.com
# @Created Date : 2021-02-24 13:58:46
# @Last Modified : 2021-03-05 18:14:17
# @Description :

import tensorflow as tf

from .nets.efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3
from .nets.efficientnet import EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7

def load_backbone(phi, input_tensor, weights='imagenet'):
    if phi == 0:
        model = EfficientNetB0(include_top=False,
                               weights=weights,
                               input_tensor=input_tensor)
        # Extract features from these layers
        layer_names = [
            'block2b_add',         # 1/4
            'block3b_add',         # 1/8
            'block5c_add',         # 1/16
            'block7a_project_bn',  # 1/32
        ]
    elif phi == 1:
        model = EfficientNetB1(include_top=False,
                               weights=weights,
                               input_tensor=input_tensor)
        layer_names = [
            'block2c_add',  # 1/4
            'block3c_add',  # 1/8
            'block5d_add',  # 1/16
            'block7b_add',  # 1/32
        ]
    elif phi == 2:
        model = EfficientNetB2(include_top=False,
                               weights=weights,
                               input_tensor=input_tensor)
        layer_names = [
            'block2c_add',  # 1/4
            'block3c_add',  # 1/8
            'block5d_add',  # 1/16
            'block7b_add',  # 1/32
        ]
    elif phi == 3:
        model = EfficientNetB3(include_top=False,
                               weights=weights,
                               input_tensor=input_tensor)
        layer_names = [
            'block2c_add',  # 1/4
            'block3c_add',  # 1/8
            'block5e_add',  # 1/16
            'block7b_add',  # 1/32
        ]
    elif phi == 4:
        model = EfficientNetB4(include_top=False,
                               weights=weights,
                               input_tensor=input_tensor)
        layer_names = [
            'block2c_add',  # 1/4
            'block3d_add',  # 1/8
            'block5f_add',  # 1/16
            'block7b_add',  # 1/32
        ]
    elif phi == 5:
        model = EfficientNetB5(include_top=False,
                               weights=weights,
                               input_tensor=input_tensor)
        layer_names = [
            'block2e_add',  # 1/4
            'block3e_add',  # 1/8
            'block5g_add',  # 1/16
            'block7c_add',  # 1/32
        ]
    elif phi == 6:
        model = EfficientNetB6(include_top=False,
                               weights=weights,
                               input_tensor=input_tensor)
        layer_names = [
            'block2f_add',  # 1/4
            'block3f_add',  # 1/8
            'block5h_add',  # 1/16
            'block7c_add',  # 1/32
        ]
    elif phi == 7:
        model = EfficientNetB7(include_top=False,
                               weights=weights,
                               input_tensor=input_tensor)
        layer_names = [
            'block2g_add',  # 1/4
            'block3g_add',  # 1/8
            'block5j_add',  # 1/16
            'block7d_add',  # 1/32
        ]

    skips = [model.get_layer(name).output for name in layer_names]
    return model, skips

def EasyDet(phi=0, input_size=(None, None, 3), weights='imagenet'):
    image_input = tf.keras.layers.Input(shape=input_size)

    backbone, skips = load_backbone(phi=phi, input_tensor=image_input, weights=weights)
    C2, C3, C4, C5 = skips

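    # FPN-style decoder (this describes the ops below): project each skip to 256
    # channels with 1x1 convs, fuse top-down by upsampling and adding, then bring
    # every level to 1/4 resolution and concatenate four 64-channel maps into one
    # 256-channel feature map.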
    in2 = tf.keras.layers.Conv2D(256, (1, 1), padding='same', kernel_initializer='he_normal', name='in2')(C2)
    in3 = tf.keras.layers.Conv2D(256, (1, 1), padding='same', kernel_initializer='he_normal', name='in3')(C3)
    in4 = tf.keras.layers.Conv2D(256, (1, 1), padding='same', kernel_initializer='he_normal', name='in4')(C4)
    in5 = tf.keras.layers.Conv2D(256, (1, 1), padding='same', kernel_initializer='he_normal', name='in5')(C5)

    # 1 / 32 * 8 = 1 / 4
    P5 = tf.keras.layers.UpSampling2D(size=(8, 8))(
        tf.keras.layers.Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal')(in5))
    # 1 / 16 * 4 = 1 / 4
    out4 = tf.keras.layers.Add()([in4, tf.keras.layers.UpSampling2D(size=(2, 2))(in5)])
    P4 = tf.keras.layers.UpSampling2D(size=(4, 4))(
        tf.keras.layers.Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal')(out4))
    # 1 / 8 * 2 = 1 / 4
    out3 = tf.keras.layers.Add()([in3, tf.keras.layers.UpSampling2D(size=(2, 2))(out4)])
    P3 = tf.keras.layers.UpSampling2D(size=(2, 2))(
        tf.keras.layers.Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal')(out3))
    # 1 / 4
    P2 = tf.keras.layers.Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal')(
        tf.keras.layers.Add()([in2, tf.keras.layers.UpSampling2D(size=(2, 2))(out3)]))
    # (b, 1/4, 1/4, 256)
    fuse = tf.keras.layers.Concatenate()([P2, P3, P4, P5])

    model = tf.keras.models.Model(inputs=image_input, outputs=fuse)
    return model


if __name__ == '__main__':
    model = EasyDet(phi=0)
    model.summary()

    import time
    import numpy as np

    x = np.random.random_sample((1, 640, 640, 3))
    # warm up
    output = model.predict(x)

    print('\n[INFO] Test start')
    time_start = time.time()
    for i in range(1000):
        output = model.predict(x)

    time_end = time.time()
    print('[INFO] Time used: {:.2f} ms'.format((time_end - time_start)*1000/(i+1)))
ocr_engine/turnsole/nets/__init__.py
0 → 100644
File mode changed
ocr_engine/turnsole/nets/efficientnet.py
0 → 100644
ocr_engine/turnsole/ocr_engine/ADC/__init__.py
0 → 100644
from . import angle_detector
\ No newline at end of file
ocr_engine/turnsole/ocr_engine/ADC/angle_detector.py
0 → 100644
# -*- coding: utf-8 -*-
# @Author : lk
# @Email : 9428.al@gmail.com
# @Created Date : 2019-09-03 15:40:54
# @Last Modified : 2022-07-18 16:10:36
# @Description :

import os
import cv2
import time
import numpy as np
# import tensorflow as tf

# import grpc
# from tensorflow_serving.apis import predict_pb2
# from tensorflow_serving.apis import prediction_service_pb2_grpc

import tritonclient.grpc as grpcclient


def resize(image, width=None, height=None, inter=cv2.INTER_AREA):
    '''
    Resize the input image to the given dimensions while keeping its aspect ratio
    '''
    dim = None
    (h, w) = image.shape[:2]

    # if both the width and height are None, then return the original image
    if width is None and height is None:
        return image

    # check to see if the width is None
    if width is None:
        # calculate the ratio of the height and construct the dimensions
        r = height / float(h)
        dim = (int(w * r), height)

    # otherwise, the height is None
    else:
        # calculate the ratio of the width and construct the dimensions
        r = width / float(w)
        dim = (width, int(h * r))

    # resize the image
    resized = cv2.resize(image, dim, interpolation=inter)

    return resized

def predict(image):

    ROTATE = [0, 90, 180, 270]

    # pre-process the image for classification
    # Test 1: resize directly to the target size
    # image = cv2.resize(image, (512, 512))

    # Test 2: resize the short side to the target size, scaling the long side proportionally
    short_side = 768
    if min(image.shape[:2]) > short_side:
        image = resize(image, width=short_side) if image.shape[0] > image.shape[1] else resize(image, height=short_side)

    # Test 3: resize-with-padding strategy
    # image = resize_image_with_pad(image, 1024, 1024)

    # Test 4: use the original image as-is
    # image = image

    image = np.array(image, dtype="float32")
    image = 2 * (image / 255.0) - 1  # Let data input to be normalized to the [-1,1] range
    input_data = np.expand_dims(image, 0)

    # options = [('grpc.max_send_message_length', 1000 * 1024 * 1024),
    #            ('grpc.max_receive_message_length', 1000 * 1024 * 1024)]
    # channel = grpc.insecure_channel('localhost:8500', options=options)
    # stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

    # request = predict_pb2.PredictRequest()
    # request.model_spec.name = 'adc_model'
    # request.model_spec.signature_name = 'serving_default'
    # request.inputs['input_1'].CopyFrom(tf.make_tensor_proto(inputs))

    # result = stub.Predict(request, 100.0)  # 100 secs timeout

    # preds = tf.make_ndarray(result.outputs['dense'])

    triton_client = grpcclient.InferenceServerClient("localhost:8001")

    # Initialize the data
    inputs = [grpcclient.InferInput('input_1', input_data.shape, "FP32")]  # InferInput describes an input tensor of the inference request
    inputs[0].set_data_from_numpy(input_data)  # bind the tensor data of the given numpy array to this input
    outputs = [grpcclient.InferRequestedOutput("dense")]

    # Inference
    results = triton_client.infer(
        model_name="adc_model",
        inputs=inputs,
        outputs=outputs
    )
    # Get the output arrays from the results
    preds = results.as_numpy("dense")

    index = np.argmax(preds, axis=-1)[0]

    return index
    # return ROTATE[index]

def DegreeTrans(theta):
    '''
    Convert radians to degrees
    '''
    res = theta / np.pi * 180
    return res

def rotateImage(src, degree):
    '''
    Calculate the rotation matrix and rotate the image
    param src: image after rot90
    param degree: the Hough degree
    '''
    h, w = src.shape[:2]
    RotateMatrix = cv2.getRotationMatrix2D((w/2.0, h/2.0), degree, 1)
    # affine transformation, background color fills white
    rotate = cv2.warpAffine(src, RotateMatrix, (w, h), borderValue=(255, 255, 255))
    return rotate

def CalcDegree(srcImage):
    '''
    Calculate the skew angle via the Hough transform
    param srcImage: image after rot90
    '''
    midImage = cv2.cvtColor(srcImage, cv2.COLOR_BGR2GRAY)
    dstImage = cv2.Canny(midImage, 100, 300, 3)
    lineimage = srcImage.copy()

    # Detect straight lines with the Hough transform
    # The 4th parameter (th) is the accumulator threshold: the higher it is, the stricter the detection
    th = 500
    while True:
        if th > 0:
            lines = cv2.HoughLines(dstImage, 1, np.pi/180, th)
        else:
            lines = None
            break
        if lines is not None:
            if len(lines) > 10:
                break
            else:
                th -= 50
                # print('threshold:', th)
        else:
            th -= 100
            # print('threshold:', th)
            continue

    sum_theta = 0
    num_theta = 0
    if lines is not None:
        for i in range(len(lines)):
            for rho, theta in lines[i]:
                # control the angle of line between -30 to +30
                if theta > 1 and theta < 2.1:
                    sum_theta += theta
                    num_theta += 1
    # Average all angles
    if num_theta == 0:
        average = np.pi/2
    else:
        average = sum_theta / num_theta

    return DegreeTrans(average) - 90

def ADC(image, fine_degree=False):
    '''
    return param rotate: corrected image
    return param angle_degree: the rotation offset of the input image
    '''

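    # Two-stage correction: a CNN classifier first snaps the image to the nearest
    # multiple of 90°; the optional Hough-based refinement then removes the
    # residual skew based on text-line features.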
    # Return a wide angle index
    img = np.copy(image)
    angle_index = predict(img)
    img_rot = np.rot90(img, -angle_index)

    # if fine_degree, the image is corrected more accurately based on character line features.
    if fine_degree:
        degree = CalcDegree(img_rot)
        angle_degree = (angle_index * 90 - degree) % 360
        rotate = rotateImage(img_rot, degree)
        return rotate, angle_degree

    return img_rot, int(angle_index*90)
ocr_engine/turnsole/ocr_engine/CRNN/text_rec.py
0 → 100644
import cv2
import time
import numpy as np
from .alphabets import alphabet
import tritonclient.grpc as grpcclient


def sort_poly(p):
    # Find the minimum coordinate using (Xi+Yi)
    min_axis = np.argmin(np.sum(p, axis=1))
    # Sort the box coordinates
    p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]]
    if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
        return p
    else:
        return p[[0, 3, 2, 1]]

def client_init(url="localhost:8001",
                ssl=False, private_key=None, root_certificates=None, certificate_chain=None,
                verbose=False):
    triton_client = grpcclient.InferenceServerClient(
        url=url,
        verbose=verbose,
        ssl=ssl,
        root_certificates=root_certificates,
        private_key=private_key,
        certificate_chain=certificate_chain)
    return triton_client

class textRecServer:
    """CRNN text recognizer served through Triton."""
    def __init__(self):
        super().__init__()
        self.charactersS = ' ' + alphabet
        self.batchsize = 8

        self.input_name = 'INPUT__0'
        self.output_name = 'OUTPUT__0'
        self.model_name = 'text_rec_torch'
        self.np_type = np.float32
        self.quant_type = "FP32"
        self.compression_algorithm = None
        self.outputs = []
        self.outputs.append(grpcclient.InferRequestedOutput(self.output_name))

    def preprocess_one_image(self, image):
        _, w, _ = image.shape
        image = self._transform(image, w)
        return image

    def predict_batch(self, im, boxes):
        """Recognize the text inside every box of an image.

        Args:
            im (TYPE): RGB image
            boxes (TYPE): list of 4-point text boxes

        Returns:
            TYPE: (results, recognition time)
        """

        triton_client = client_init("localhost:8001")
        count_boxes = len(boxes)
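        # Sort boxes by their width after rescaling to a height of 32 px (descending),
        # so each batch groups crops of similar width and padding waste stays small.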
        boxes = sorted(boxes,
                       key=lambda box: int(32.0 * (np.linalg.norm(box[0] - box[1])) / (np.linalg.norm(box[3] - box[0]))),
                       reverse=True)

        results = {}
        labels = []
        rectime = 0.0
        if len(boxes) != 0:
            for i in range(len(boxes) // self.batchsize + int(len(boxes) % self.batchsize != 0)):
                box = boxes[min(len(boxes)-1, i * self.batchsize)]
                w, h = [int(np.linalg.norm(box[0] - box[1])), int(np.linalg.norm(box[3] - box[0]))]
                width = max(32, min(int(32.0 * w / h), 960))
                if width < 32:
                    continue
                slices = []
                for index, box in enumerate(boxes[i * self.batchsize:(i + 1) * self.batchsize]):
                    _box = [n for a in box for n in a]
                    if i * self.batchsize + index < count_boxes:
                        results[i * self.batchsize + index] = [list(map(int, _box))]
                    w, h = [int(np.linalg.norm(box[0] - box[1])), int(np.linalg.norm(box[3] - box[0]))]
                    pts1 = np.float32(box)
                    pts2 = np.float32([[0, 0], [w, 0], [w, h], [0, h]])

                    # Preprocessing optimization: crop the bounding rect first, then warp
                    xmin, ymin, _w, _h = cv2.boundingRect(pts1)
                    xmax, ymax = xmin+_w, ymin+_h
                    xmin, ymin = max(0, xmin), max(0, ymin)
                    im_slice = im[int(ymin):int(ymax), int(xmin):int(xmax), :]
                    pts1[:, 0] -= xmin
                    pts1[:, 1] -= ymin

                    M = cv2.getPerspectiveTransform(pts1, pts2)
                    im_crop = cv2.warpPerspective(im_slice, M, (w, h))
                    im_crop = self._transform(im_crop, width)
                    slices.append(im_crop)
                start_rec = time.time()
                slices = self.np_type(slices)
                slices = slices.transpose(0, 3, 1, 2)
                slices = slices/127.5-1.
                inputs = []
                inputs.append(grpcclient.InferInput(self.input_name, list(slices.shape), self.quant_type))
                inputs[0].set_data_from_numpy(slices)

                # inference
                preds = triton_client.infer(
                    model_name=self.model_name,
                    inputs=inputs,
                    outputs=self.outputs,
                    compression_algorithm=self.compression_algorithm
                )
                preds = preds.as_numpy(self.output_name).copy()
                preds = preds.transpose(1, 0)
                tmp_labels = self.decode(preds)
                rectime += (time.time() - start_rec)
                labels.extend(tmp_labels)

            for index, label in enumerate(labels[:count_boxes]):
                label = label.replace(' ', '').replace('¥', '¥')
                if label == '':
                    del results[index]
                    continue
                results[index].append(label)
            # Re-sort the results
            results = list(results.values())
            results = sorted(results, key=lambda x: x[0][1], reverse=False)  # sort by y0, ascending
            keys = [str(i) for i in range(len(results))]
            results = dict(zip(keys, results))
        else:
            results = dict()
            rectime = -1
        return results, rectime

    def decode(self, preds):
        res = []
        for t in preds:
            length = len(t)
            char_list = []
            for i in range(length):
                if t[i] != 0 and (not (i > 0 and t[i-1] == t[i])):
                    char_list.append(self.charactersS[t[i]])
            res.append(u''.join(char_list))
        return res

    def _transform(self, im, width):
        height = 32

        ori_h, ori_w = im.shape[:2]
        ratio1 = width * 1.0 / ori_w
        ratio2 = height * 1.0 / ori_h
        if ratio1 < ratio2:
            ratio = ratio1
        else:
            ratio = ratio2
        new_w, new_h = int(ori_w * ratio), int(ori_h * ratio)
        if new_w < 4:
            new_w = 4
        im = cv2.resize(im, (new_w, new_h))
        img = np.ones((height, width, 3), dtype=np.uint8)*230
        img[:im.shape[0], :im.shape[1], :] = im
        return img
ocr_engine/turnsole/ocr_engine/DBNet/__init__.py
0 → 100644
from . import text_detector
\ No newline at end of file
1 | # -*- coding: utf-8 -*- | ||
2 | # @Author : Lyu Kui | ||
3 | # @Email : 9428.al@gmail.com | ||
4 | # @Create Date : 2022-06-01 19:00:18 | ||
5 | # @Last Modified : 2022-07-15 11:41:25 | ||
6 | # @Description : | ||
7 | |||
8 | import os | ||
9 | import cv2 | ||
10 | import time | ||
11 | import pyclipper | ||
12 | import numpy as np | ||
13 | # import tensorflow as tf | ||
14 | from shapely.geometry import Polygon | ||
15 | |||
16 | # import grpc | ||
17 | # from tensorflow_serving.apis import predict_pb2 | ||
18 | # from tensorflow_serving.apis import prediction_service_pb2_grpc | ||
19 | |||
20 | import tritonclient.grpc as grpcclient | ||
21 | |||
22 | |||
23 | def resize_with_padding(src, limit_max=1024): | ||
24 | '''限制长边不大于 limit_max 短边等比例缩放,以 0 填充''' | ||
25 | img = src.copy() | ||
26 | |||
27 | h, w, _ = img.shape | ||
28 | max_side = max(h, w) | ||
29 | ratio = limit_max / max_side if max_side > limit_max else 1 | ||
30 | h, w = int(h * ratio), int(w * ratio) | ||
31 | proc = cv2.resize(img, (w, h)) | ||
32 | |||
33 | canvas = np.zeros((limit_max, limit_max, 3), dtype=np.float32) | ||
34 | canvas[0:h, 0:w, :] = proc | ||
35 | return canvas, ratio | ||
36 | |||
37 | def rectangle_boxes_zoom(boxes, offset=1): | ||
38 | '''Scale the rectangle boxes via offset | ||
39 | Input: | ||
40 | boxes: with shape (-1, 4, 2) | ||
41 | offset: number of pixels to zoom by; we recommend less than 5 | ||
42 | Output: | ||
43 | boxes: zoomed | ||
44 | ''' | ||
45 | boxes = np.array(boxes) | ||
46 | boxes += [[[-offset,-offset], [offset,-offset], [offset,offset], [-offset,offset]]] | ||
47 | return boxes | ||
48 | |||
49 | def polygons_from_probmap(preds, ratio): | ||
50 | # Binarize the probability map | ||
51 | prob_map_pred = np.array(preds, dtype=np.uint8)[0,:,:,0] | ||
52 | # Inputs: binary map, contour retrieval (hierarchy) mode, contour approximation method | ||
53 | # Outputs: contours, hierarchy | ||
54 | contours, hierarchy = cv2.findContours(prob_map_pred, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) | ||
55 | |||
56 | boxes = [] | ||
57 | for contour in contours: | ||
58 | if len(contour) < 4: | ||
59 | continue | ||
60 | |||
61 | # Vatti clipping | ||
62 | polygon = Polygon(np.array(contour).reshape((-1, 2))).buffer(0) | ||
63 | polygon = polygon.convex_hull if polygon.type == 'MultiPolygon' else polygon # Note: intentional, not a bug | ||
64 | |||
65 | if polygon.area < 10: | ||
66 | continue | ||
67 | |||
68 | distance = polygon.area * 1.5 / polygon.length | ||
69 | offset = pyclipper.PyclipperOffset() | ||
70 | offset.AddPath(list(polygon.exterior.coords), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) | ||
71 | expanded = np.array(offset.Execute(distance)[0]) # Note: intentional, not a bug | ||
72 | |||
73 | # Convert polygon to rectangle | ||
74 | rect = cv2.minAreaRect(expanded) | ||
75 | box = cv2.boxPoints(rect) | ||
76 | # make clock-wise order | ||
77 | box = np.roll(box, 4-box.sum(axis=1).argmin(), 0) | ||
78 | box = np.array(box/ratio, dtype=np.int32) | ||
79 | boxes.append(box) | ||
80 | |||
81 | return boxes | ||
82 | |||
83 | def predict(image): | ||
84 | |||
85 | image_resized, ratio = resize_with_padding(image, limit_max=1280) | ||
86 | input_data = np.expand_dims(image_resized/255., axis=0) | ||
87 | |||
88 | # options = [('grpc.max_send_message_length', 1000 * 1024 * 1024), | ||
89 | # ('grpc.max_receive_message_length', 1000 * 1024 * 1024)] | ||
90 | # channel = grpc.insecure_channel('localhost:8500', options=options) | ||
91 | # stub = prediction_service_pb2_grpc.PredictionServiceStub(channel) | ||
92 | |||
93 | # request = predict_pb2.PredictRequest() | ||
94 | # request.model_spec.name = 'dbnet_model' | ||
95 | # request.model_spec.signature_name = 'serving_default' | ||
96 | # request.inputs['input_1'].CopyFrom(tf.make_tensor_proto(inputs)) | ||
97 | |||
98 | # result = stub.Predict(request, 100.0) # 100 secs timeout | ||
99 | |||
100 | # preds = tf.make_ndarray(result.outputs['tf.math.greater']) | ||
101 | |||
102 | triton_client = grpcclient.InferenceServerClient("localhost:8001") | ||
103 | |||
104 | # Initialize the data | ||
105 | inputs = [grpcclient.InferInput('input_1', input_data.shape, "FP32")] | ||
106 | inputs[0].set_data_from_numpy(input_data) | ||
107 | outputs = [grpcclient.InferRequestedOutput("tf.math.greater")] | ||
108 | |||
109 | # Inference | ||
110 | results = triton_client.infer( | ||
111 | model_name="dbnet_model", | ||
112 | inputs=inputs, | ||
113 | outputs=outputs | ||
114 | ) | ||
115 | # Get the output arrays from the results | ||
116 | preds = results.as_numpy("tf.math.greater") | ||
117 | |||
118 | boxes = polygons_from_probmap(preds, ratio) | ||
119 | #boxes = rectangle_boxes_zoom(boxes, offset=0) | ||
120 | |||
121 | return boxes |
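`predict` returns each detection as a 4-point quadrilateral already mapped back to original-image coordinates (note the division by `ratio`), so the boxes can be drawn directly on the input. A minimal visualization sketch, assuming a Triton server is running on `localhost:8001` and using a hypothetical input file:

```python
import cv2
import numpy as np
from turnsole.ocr_engine import text_detector

image = cv2.imread('sample.jpg')  # hypothetical input file
boxes = text_detector.predict(image)

# Each box is a (4, 2) int32 array in clockwise order; draw it as a closed polygon.
for box in boxes:
    cv2.polylines(image, [np.asarray(box, dtype=np.int32)], isClosed=True,
                  color=(0, 255, 0), thickness=2)
cv2.imwrite('sample_boxes.jpg', image)
```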
ocr_engine/turnsole/ocr_engine/__init__.py
0 → 100644
1 | # import grpc | ||
2 | import turnsole | ||
3 | import numpy as np | ||
4 | # import tensorflow as tf | ||
5 | # from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc | ||
6 | |||
7 | import tritonclient.grpc as grpcclient | ||
8 | |||
9 | |||
10 | class ObjectDetection(): | ||
11 | |||
12 | """通用文件检测算法 | ||
13 | 输入图片输出检测结果 | ||
14 | |||
15 | API 文档请参阅: | ||
16 | """ | ||
17 | |||
18 | def __init__(self, confidence_threshold=0.5): | ||
19 | """初始化检测对象 | ||
20 | |||
21 | Args: | ||
22 | confidence_threshold (float, optional): 目标检测模型的分类置信度 | ||
23 | """ | ||
24 | |||
25 | self.label2index = { | ||
26 | 'id_card_info': 0, | ||
27 | 'id_card_guohui': 1, | ||
28 | 'lssfz_front': 2, | ||
29 | 'lssfz_back': 3, | ||
30 | 'jzz_front': 4, | ||
31 | 'jzz_back': 5, | ||
32 | 'txz_front': 6, | ||
33 | 'txz_back': 7, | ||
34 | 'bank_card': 8, | ||
35 | 'vehicle_license_front': 9, | ||
36 | 'vehicle_license_back': 10, | ||
37 | 'driving_license_front': 11, | ||
38 | 'driving_license_back': 12, | ||
39 | 'vrc_page_12': 13, | ||
40 | 'vrc_page_34': 14, | ||
41 | } | ||
42 | self.index2label = list(self.label2index.keys()) | ||
43 | |||
44 | # def resize_and_pad_to_384(self, image, jitter=True): | ||
45 | # """长边在 256-384 之间随机取一个数,四边 pad 到 384 | ||
46 | |||
47 | # Args: | ||
48 | # image (TYPE): An image represented as a numpy ndarray. | ||
49 | # """ | ||
50 | # image_shape = tf.cast(tf.shape(image)[:2], dtype=tf.float32) | ||
51 | # max_side = tf.random.uniform( | ||
52 | # (), 256, 384, dtype=tf.float32) if jitter else 384. | ||
53 | # ratio = max_side / tf.reduce_max(image_shape) | ||
54 | # image_shape = tf.cast(ratio * image_shape, dtype=tf.int32) | ||
55 | # image = tf.image.resize(image, image_shape) | ||
56 | # image = tf.image.pad_to_bounding_box(image, 0, 0, 384, 384) | ||
57 | # return image, ratio | ||
58 | |||
59 | def process(self, image): | ||
60 | """Processes an image and returns a list of the detected object location and classes data. | ||
61 | |||
62 | Args: | ||
63 | image (TYPE): An image represented as a numpy ndarray. | ||
64 | """ | ||
65 | h, w, _ = image.shape | ||
66 | # image, ratio = self.resize_and_pad_to_384(image, jitter=False) | ||
67 | image, ratio = turnsole.resize_with_pad(image, target_height=384, target_width=384) | ||
68 | input_data = np.expand_dims(image/255., axis=0) | ||
69 | |||
70 | # options = [('grpc.max_send_message_length', 1000 * 1024 * 1024), | ||
71 | # ('grpc.max_receive_message_length', 1000 * 1024 * 1024)] | ||
72 | # channel = grpc.insecure_channel('localhost:8500', options=options) | ||
73 | # stub = prediction_service_pb2_grpc.PredictionServiceStub(channel) | ||
74 | |||
75 | # request = predict_pb2.PredictRequest() | ||
76 | # request.model_spec.name = 'object_detection' | ||
77 | # request.model_spec.signature_name = 'serving_default' | ||
78 | # request.inputs['image'].CopyFrom(tf.make_tensor_proto(inputs, dtype='float32')) | ||
79 | # # 100 secs timeout | ||
80 | # result = stub.Predict(request, 100.0) | ||
81 | |||
82 | # # saved_model_cli show --dir saved_model/ --all # 查看 saved model 的输入输出 | ||
83 | # boxes = tf.make_ndarray(result.outputs['decode_predictions']) | ||
84 | # scores = tf.make_ndarray(result.outputs['decode_predictions_1']) | ||
85 | # classes = tf.make_ndarray(result.outputs['decode_predictions_2']) | ||
86 | # valid_detections = tf.make_ndarray( | ||
87 | # result.outputs['decode_predictions_3']) | ||
88 | |||
89 | triton_client = grpcclient.InferenceServerClient("localhost:8001") | ||
90 | |||
91 | # Initialize the data | ||
92 | inputs = [grpcclient.InferInput('image', input_data.shape, "FP32")] | ||
93 | inputs[0].set_data_from_numpy(input_data.astype('float32')) | ||
94 | outputs = [ | ||
95 | grpcclient.InferRequestedOutput("decode_predictions"), | ||
96 | grpcclient.InferRequestedOutput("decode_predictions_1"), | ||
97 | grpcclient.InferRequestedOutput("decode_predictions_2"), | ||
98 | grpcclient.InferRequestedOutput("decode_predictions_3") | ||
99 | ] | ||
100 | |||
101 | # Inference | ||
102 | results = triton_client.infer( | ||
103 | model_name="object_detection", | ||
104 | inputs=inputs, | ||
105 | outputs=outputs | ||
106 | ) | ||
107 | # Get the output arrays from the results | ||
108 | boxes = results.as_numpy("decode_predictions") | ||
109 | scores = results.as_numpy("decode_predictions_1") | ||
110 | classes = results.as_numpy("decode_predictions_2") | ||
111 | valid_detections = results.as_numpy("decode_predictions_3") | ||
112 | |||
113 | boxes = boxes[0][:valid_detections[0]] | ||
114 | scores = scores[0][:valid_detections[0]] | ||
115 | classes = classes[0][:valid_detections[0]] | ||
116 | |||
117 | object_list = [] | ||
118 | for box, score, class_index in zip(boxes, scores, classes): | ||
119 | xmin, ymin, xmax, ymax = box / ratio | ||
120 | xmin = max(0, int(xmin)) | ||
121 | ymin = max(0, int(ymin)) | ||
122 | xmax = min(w, int(xmax)) | ||
123 | ymax = min(h, int(ymax)) | ||
124 | class_label = self.index2label[int(class_index)] | ||
125 | item = { | ||
126 | "label": class_label, | ||
127 | "confidence": float(score), | ||
128 | "location": { | ||
129 | "xmin": xmin, | ||
130 | "ymin": ymin, | ||
131 | "xmax": xmax, | ||
132 | "ymax": ymax | ||
133 | } | ||
134 | } | ||
135 | object_list.append(item) | ||
136 | |||
137 | return object_list | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
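Note that `confidence_threshold` is accepted by `__init__` but never stored or used in `process`, so every valid detection is returned and any thresholding has to happen on the caller's side. A usage sketch under that assumption; it presumes `object_detector` is a module-level `ObjectDetection()` instance, analogous to the `signature_detector` instance below:

```python
import cv2
from turnsole.ocr_engine import object_detector

image = cv2.imread('id_card.jpg')  # hypothetical input file
object_list = object_detector.process(image)

# Client-side filtering, since the class does not apply its own threshold.
for obj in (o for o in object_list if o['confidence'] >= 0.5):
    loc = obj['location']
    print(obj['label'], round(obj['confidence'], 3),
          (loc['xmin'], loc['ymin'], loc['xmax'], loc['ymax']))
```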
1 | # -*- coding: utf-8 -*- | ||
2 | # @Author : lk | ||
3 | # @Email : 9428.al@gmail.com | ||
4 | # @Create Date : 2022-06-28 14:38:57 | ||
5 | # @Last Modified : 2022-09-06 14:37:47 | ||
6 | # @Description : | ||
7 | |||
8 | from .utils import SignatureDetection | ||
9 | |||
10 | signature_detector = SignatureDetection() | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
1 | # -*- coding: utf-8 -*- | ||
2 | # @Author : lk | ||
3 | # @Email : 9428.al@gmail.com | ||
4 | # @Create Date : 2022-02-08 14:10:00 | ||
5 | # @Last Modified : 2022-09-06 14:45:10 | ||
6 | # @Description : | ||
7 | |||
8 | import turnsole | ||
9 | import numpy as np | ||
10 | # import tensorflow as tf | ||
11 | |||
12 | # import grpc | ||
13 | # from tensorflow_serving.apis import predict_pb2 | ||
14 | # from tensorflow_serving.apis import prediction_service_pb2_grpc | ||
15 | |||
16 | import tritonclient.grpc as grpcclient | ||
17 | |||
18 | |||
19 | # def resize_and_pad_to_1024(image, jitter=True): | ||
20 | # # Pick a random long side in [512, 1024] and pad all four sides to 1024 | ||
21 | # image_shape = tf.cast(tf.shape(image)[:2], dtype=tf.float32) | ||
22 | # max_side = tf.random.uniform((), 512, 1024, dtype=tf.float32) if jitter else 1024. | ||
23 | # ratio = max_side / tf.reduce_max(image_shape) | ||
24 | # image_shape = tf.cast(ratio * image_shape, dtype=tf.int32) | ||
25 | # image = tf.image.resize(image, image_shape) | ||
26 | # image = tf.image.pad_to_bounding_box(image, 0, 0, 1024, 1024) | ||
27 | # return image, ratio | ||
28 | |||
29 | class SignatureDetection(): | ||
30 | |||
31 | """签字盖章检测算法 | ||
32 | 输入图片输出检测结果 | ||
33 | |||
34 | API 文档请参阅: | ||
35 | """ | ||
36 | |||
37 | def __init__(self, confidence_threshold=0.5): | ||
38 | """初始化检测对象 | ||
39 | |||
40 | Args: | ||
41 | confidence_threshold (float, optional): 目标检测模型的分类置信度 | ||
42 | """ | ||
43 | |||
44 | self.label2index = { | ||
45 | 'circle': 0, | ||
46 | 'ellipse': 1, | ||
47 | 'rectangle': 2, | ||
48 | 'signature': 3, | ||
49 | 'qr_code': 4, | ||
50 | 'bar_code': 5 | ||
51 | } | ||
52 | self.index2label = { | ||
53 | 0: 'circle', | ||
54 | 1: 'ellipse', | ||
55 | 2: 'rectangle', | ||
56 | 3: 'signature', | ||
57 | 4: 'qr_code', | ||
58 | 5: 'bar_code' | ||
59 | } | ||
60 | |||
61 | |||
62 | def process(self, image): | ||
63 | """Processes an image and returns a list of the detected signature location and classes data. | ||
64 | |||
65 | Args: | ||
66 | image (TYPE): An image represented as a numpy ndarray. | ||
67 | """ | ||
68 | h, w, _ = image.shape | ||
69 | |||
70 | # image, ratio = resize_and_pad_to_1024(image, jitter=False) | ||
71 | image, ratio = turnsole.resize_with_pad(image, target_height=1024, target_width=1024) | ||
72 | input_data = np.expand_dims(np.float32(image/255.), axis=0) | ||
73 | |||
74 | # options = [('grpc.max_send_message_length', 1000 * 1024 * 1024), | ||
75 | # ('grpc.max_receive_message_length', 1000 * 1024 * 1024)] | ||
76 | # channel = grpc.insecure_channel('localhost:8500', options=options) | ||
77 | # stub = prediction_service_pb2_grpc.PredictionServiceStub(channel) | ||
78 | |||
79 | # request = predict_pb2.PredictRequest() | ||
80 | # request.model_spec.name = 'signature_model' | ||
81 | # request.model_spec.signature_name = 'serving_default' | ||
82 | # request.inputs['image'].CopyFrom(tf.make_tensor_proto(inputs, dtype='float32')) | ||
83 | # result = stub.Predict(request, 100.0) # 100 secs timeout | ||
84 | |||
85 | # # saved_model_cli show --dir saved_model/ --all # 查看 saved model 的输入输出 | ||
86 | # boxes = tf.make_ndarray(result.outputs['decode_predictions']) | ||
87 | # scores = tf.make_ndarray(result.outputs['decode_predictions_1']) | ||
88 | # classes = tf.make_ndarray(result.outputs['decode_predictions_2']) | ||
89 | # valid_detections = tf.make_ndarray(result.outputs['decode_predictions_3']) | ||
90 | |||
91 | triton_client = grpcclient.InferenceServerClient("localhost:8001") | ||
92 | |||
93 | # Initialize the data | ||
94 | inputs = [grpcclient.InferInput('image', input_data.shape, "FP32")] | ||
95 | inputs[0].set_data_from_numpy(input_data) | ||
96 | outputs = [ | ||
97 | grpcclient.InferRequestedOutput("decode_predictions"), | ||
98 | grpcclient.InferRequestedOutput("decode_predictions_1"), | ||
99 | grpcclient.InferRequestedOutput("decode_predictions_2"), | ||
100 | grpcclient.InferRequestedOutput("decode_predictions_3") | ||
101 | ] | ||
102 | |||
103 | # Inference | ||
104 | results = triton_client.infer( | ||
105 | model_name="signature_model", | ||
106 | inputs=inputs, | ||
107 | outputs=outputs | ||
108 | ) | ||
109 | # Get the output arrays from the results | ||
110 | boxes = results.as_numpy("decode_predictions") | ||
111 | scores = results.as_numpy("decode_predictions_1") | ||
112 | classes = results.as_numpy("decode_predictions_2") | ||
113 | valid_detections = results.as_numpy("decode_predictions_3") | ||
114 | |||
115 | boxes = boxes[0][:valid_detections[0]] | ||
116 | scores = scores[0][:valid_detections[0]] | ||
117 | classes = classes[0][:valid_detections[0]] | ||
118 | |||
119 | signature_list = [] | ||
120 | for box, score, class_index in zip(boxes, scores, classes): | ||
121 | xmin, ymin, xmax, ymax = box / ratio | ||
122 | class_label = self.index2label[int(class_index)] # int cast: class indices arrive as numpy floats | ||
123 | item = { | ||
124 | "label": class_label, | ||
125 | "confidence": float(score), | ||
126 | "location": { | ||
127 | "xmin": max(0, int(xmin)), | ||
128 | "ymin": max(0, int(ymin)), | ||
129 | "xmax": min(w, int(xmax)), | ||
130 | "ymax": min(h, int(ymax)) | ||
131 | } | ||
132 | } | ||
133 | signature_list.append(item) | ||
134 | |||
135 | return signature_list |
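A sketch that crops each detected region to its own file, assuming the module-level `signature_detector` instance created in the package `__init__.py` above and a hypothetical input image:

```python
import cv2
from turnsole.ocr_engine import signature_detector

image = cv2.imread('contract.jpg')  # hypothetical input file
for i, item in enumerate(signature_detector.process(image)):
    loc = item['location']
    crop = image[loc['ymin']:loc['ymax'], loc['xmin']:loc['xmax']]
    cv2.imwrite(f"{item['label']}_{i}.jpg", crop)  # e.g. signature_0.jpg
```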
1 | # -*- coding: utf-8 -*- | ||
2 | # @Author : Lyu Kui | ||
3 | # @Email : 9428.al@gmail.com | ||
4 | # @Create Date : 2022-06-16 11:01:36 | ||
5 | # @Last Modified : 2022-07-15 10:57:06 | ||
6 | # @Description : | ||
7 | |||
8 | from .read_data import base64_to_bgr | ||
9 | from .read_data import bytes_to_bgr | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
1 | # -*- coding: utf-8 -*- | ||
2 | # @Author : Lyu Kui | ||
3 | # @Email : 9428.al@gmail.com | ||
4 | # @Create Date : 2022-06-16 10:59:50 | ||
5 | # @Last Modified : 2022-08-03 14:59:15 | ||
6 | # @Description : | ||
7 | |||
8 | import cv2 | ||
9 | import base64 | ||
10 | import numpy as np | ||
11 | # import tensorflow as tf # unused here; kept commented out like the tf fallback below | ||
12 | |||
13 | |||
14 | def base64_to_bgr(img64): | ||
15 | """把 base64 转换成图片 | ||
16 | 单通道的灰度图或四通道的透明图都将自动转换成三通道的 BGR 图 | ||
17 | |||
18 | Args: | ||
19 | img64 (str or bytes): base64-encoded image | ||
20 | |||
21 | Returns: | ||
22 | numpy.ndarray: a 3-D uint8 array of shape [height, width, channels] where channels is BGR | ||
23 | """ | ||
24 | encoded_image = base64.b64decode(img64) | ||
25 | img_array = np.frombuffer(encoded_image, np.uint8) | ||
26 | image = cv2.imdecode(img_array, cv2.IMREAD_COLOR) | ||
27 | return image | ||
28 | |||
29 | def bytes_to_bgr(buffer: bytes): | ||
30 | """Read a byte stream as a OpenCV image | ||
31 | |||
32 | Args: | ||
33 | buffer (TYPE): bytes of a decoded image | ||
34 | """ | ||
35 | img_array = np.frombuffer(buffer, np.uint8) | ||
36 | image = cv2.imdecode(img_array, cv2.IMREAD_COLOR) | ||
37 | |||
38 | # image = tf.io.decode_image(buffer, channels=3) | ||
39 | # image = np.array(image)[...,::-1] | ||
40 | return image | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
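Both helpers decode through `cv2.imdecode(..., cv2.IMREAD_COLOR)`, so grayscale or transparent inputs come back as 3-channel BGR. A round-trip sketch, assuming the helpers are re-exported as shown in the `__init__.py` above (the import path is an assumption) and that `sample.png` exists:

```python
import base64
from turnsole.image_tools import base64_to_bgr, bytes_to_bgr  # hypothetical import path

with open('sample.png', 'rb') as f:  # hypothetical input file
    raw = f.read()

image_a = bytes_to_bgr(raw)                     # from raw bytes
image_b = base64_to_bgr(base64.b64encode(raw))  # from a base64 payload
assert image_a.shape == image_b.shape           # both (H, W, 3) uint8 BGR
```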
ocr_engine/turnsole/paths.py
0 → 100644
1 | # -*- coding: utf-8 -*- | ||
2 | # @Author : Lyu Kui | ||
3 | # @Email : 9428.al@gmail.com | ||
4 | # @Created Date : 2021-03-04 17:50:09 | ||
5 | # @Last Modified : 2021-03-10 14:03:02 | ||
6 | # @Description : | ||
7 | |||
8 | import os | ||
9 | |||
10 | image_types = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff") | ||
11 | |||
12 | |||
13 | def list_images(basePath, contains=None): | ||
14 | # return the set of files that are valid | ||
15 | return list_files(basePath, validExts=image_types, contains=contains) | ||
16 | |||
17 | def list_files(basePath, validExts=None, contains=None): | ||
18 | # loop over the directory structure | ||
19 | for (rootDir, dirNames, filenames) in os.walk(basePath): | ||
20 | # loop over the filenames in the current directory | ||
21 | for filename in filenames: | ||
22 | # if the contains string is not none and the filename does not contain | ||
23 | # the supplied string, then ignore the file | ||
24 | if contains is not None and filename.find(contains) == -1: | ||
25 | continue | ||
26 | |||
27 | # determine the file extension of the current file | ||
28 | ext = filename[filename.rfind("."):].lower() | ||
29 | |||
30 | # check to see if the file is an image and should be processed | ||
31 | if validExts is None or ext.endswith(validExts): | ||
32 | # construct the path to the image and yield it | ||
33 | imagePath = os.path.join(rootDir, filename) | ||
34 | yield imagePath | ||
35 | |||
36 | def get_filename(filePath): | ||
37 | basename = os.path.basename(filePath) | ||
38 | fname, fextension = os.path.splitext(basename) | ||
39 | return fname | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
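`list_files` is a lazy generator; `list_images` just restricts it to the extensions in `image_types`. A small usage sketch with a hypothetical dataset directory:

```python
from turnsole.paths import list_images, get_filename

# Walk ./dataset recursively, keeping only image files whose name contains 'train'.
for image_path in list_images('./dataset', contains='train'):
    print(get_filename(image_path), '->', image_path)
```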
ocr_engine/turnsole/pdf_tools.py
0 → 100644
1 | import cv2 | ||
2 | import fitz | ||
3 | import numpy as np | ||
4 | |||
5 | def pdf_to_images(pdf_path: str): | ||
6 | """PDF 转 OpenCV Image | ||
7 | |||
8 | Args: | ||
9 | pdf_path (str): Path to the PDF file | ||
10 | |||
11 | Returns: | ||
12 | list: one list per page, each holding one or more OpenCV BGR images | ||
13 | """ | ||
14 | images = [] | ||
15 | doc = fitz.open(pdf_path) | ||
16 | # producer = doc.metadata.get('producer') | ||
17 | |||
18 | for pno in range(doc.page_count): | ||
19 | page = doc.load_page(pno) | ||
20 | |||
21 | all_texts = page.get_text().replace('\n', '').strip() | ||
22 | # Heuristic filter for PDF-XChange trial watermarks; note str.strip removes a *character set*, so a page containing only watermark text strips down to empty | ||
23 | all_texts = all_texts.strip('Click to buy NOW!PDF-XChangewww.docu-track.comClick to buy NOW!PDF-XChangewww.docu-track.com') | ||
24 | blocks = page.get_text("dict")["blocks"] | ||
25 | imgblocks = [b for b in blocks if b["type"] == 1] | ||
26 | |||
27 | page_images = [] | ||
28 | # If the page contains no text at all, | ||
29 | if len(all_texts) == 0 and len(imgblocks) != 0: | ||
30 | # # These producers emit fragmented images; if they really are fragments, stitch them back together | ||
31 | # if producer in ['Microsoft: Print To PDF', | ||
32 | # 'GPL Ghostscript 8.71', | ||
33 | # 'doPDF Ver 7.3 Build 398 (Windows 7 Business Edition (SP 1) - Version: 6.1.7601 (x64))', | ||
34 | # '福昕阅读器PDF打印机 版本 11.0.114.4386']: | ||
35 | patches = [] | ||
36 | for imgblock in imgblocks: | ||
37 | contents = imgblock["image"] | ||
38 | img_array = np.frombuffer(contents, dtype=np.uint8) | ||
39 | image = cv2.imdecode(img_array, cv2.IMREAD_COLOR) | ||
40 | patches.append(image) | ||
41 | try: | ||
42 | try: | ||
43 | image = np.concatenate(patches, axis=0) | ||
44 | page_images.append(image) | ||
45 | except ValueError: # widths differ; try stitching horizontally | ||
46 | image = np.concatenate(patches, axis=1) | ||
47 | page_images.append(image) | ||
48 | except ValueError: | ||
49 | # When two patches cannot be stitched together we can treat them as two images; with more than two it is less certain | ||
50 | if len(patches) == 2: | ||
51 | page_images = patches | ||
52 | else: | ||
53 | pix = page.get_pixmap(dpi=350) | ||
54 | contents = pix.tobytes(output="png") | ||
55 | img_array = np.frombuffer(contents, dtype=np.uint8) | ||
56 | image = cv2.imdecode(img_array, cv2.IMREAD_COLOR) | ||
57 | page_images.append(image) | ||
58 | # else: | ||
59 | # for imgblock in imgblocks: | ||
60 | # contents = imgblock["image"] | ||
61 | # img_array = np.frombuffer(contents, dtype=np.uint8) | ||
62 | # image = cv2.imdecode(img_array, cv2.IMREAD_COLOR) | ||
63 | # page_images.append(image) | ||
64 | else: | ||
65 | pix = page.get_pixmap(dpi=350) | ||
66 | contents = pix.tobytes(output="png") | ||
67 | img_array = np.frombuffer(contents, dtype=np.uint8) | ||
68 | image = cv2.imdecode(img_array, cv2.IMREAD_COLOR) | ||
69 | page_images.append(image) | ||
70 | images.append(page_images) | ||
71 | return images | ||
72 |
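The return value is one list per page, each holding one or more BGR arrays: either the embedded images (stitched when possible) or a full-page render at 350 dpi. A sketch that flattens the result to disk, with a hypothetical import path and input file:

```python
import cv2
from turnsole.pdf_tools import pdf_to_images  # hypothetical import path

images = pdf_to_images('document.pdf')  # hypothetical input file
for pno, page_images in enumerate(images):
    for i, image in enumerate(page_images):
        cv2.imwrite(f'page_{pno:03d}_{i}.png', image)
```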
ocr_engine/turnsole/video/__init__.py
0 → 100644
ocr_engine/turnsole/video/count_frames.py
0 → 100644
1 | # import the necessary packages | ||
2 | # from ..convenience import is_cv3 | ||
3 | import cv2 | ||
4 | |||
5 | def count_frames(path, override=False): | ||
6 | # grab a pointer to the video file and initialize the total | ||
7 | # number of frames read | ||
8 | video = cv2.VideoCapture(path) | ||
9 | total = 0 | ||
10 | |||
11 | # if the override flag is passed in, revert to the manual | ||
12 | # method of counting frames | ||
13 | if override: | ||
14 | total = count_frames_manual(video) | ||
15 | |||
16 | # otherwise, let's try the fast way first | ||
17 | else: | ||
18 | # lets try to determine the number of frames in a video | ||
19 | # via video properties; this method can be very buggy | ||
20 | # and might throw an error based on your OpenCV version | ||
21 | # or may fail entirely based on your which video codecs | ||
22 | # you have installed | ||
23 | try: | ||
24 | # # check if we are using OpenCV 3 | ||
25 | # if is_cv3(): | ||
26 | # total = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) | ||
27 | |||
28 | # # otherwise, we are using OpenCV 2.4 | ||
29 | # else: | ||
30 | # total = int(video.get(cv2.cv.CV_CAP_PROP_FRAME_COUNT)) | ||
31 | |||
32 | total = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) # OpenCV 3+ constant; the cv2.cv namespace only exists in 2.4 | ||
33 | |||
34 | # uh-oh, we got an error -- revert to counting manually | ||
35 | except: | ||
36 | total = count_frames_manual(video) | ||
37 | |||
38 | # release the video file pointer | ||
39 | video.release() | ||
40 | |||
41 | # return the total number of frames in the video | ||
42 | return total | ||
43 | |||
44 | def count_frames_manual(video): | ||
45 | # initialize the total number of frames read | ||
46 | total = 0 | ||
47 | |||
48 | # loop over the frames of the video | ||
49 | while True: | ||
50 | # grab the current frame | ||
51 | (grabbed, frame) = video.read() | ||
52 | |||
53 | # check to see if we have reached the end of the | ||
54 | # video | ||
55 | if not grabbed: | ||
56 | break | ||
57 | |||
58 | # increment the total number of frames read | ||
59 | total += 1 | ||
60 | |||
61 | # return the total number of frames in the video file | ||
62 | return total | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
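Since the property-based count can be wrong for some codecs, `override=True` forces the slow but reliable frame-by-frame count. Usage sketch (hypothetical import path and file):

```python
from turnsole.video.count_frames import count_frames  # hypothetical import path

fast = count_frames('clip.mp4')                  # metadata-based, may be inaccurate
exact = count_frames('clip.mp4', override=True)  # decodes every frame
print(f'metadata: {fast} frames, manual: {exact} frames')
```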
ocr_engine/turnsole/video/filevideostream.py
0 → 100644
1 | # import the necessary packages | ||
2 | from threading import Thread | ||
3 | import sys | ||
4 | import cv2 | ||
5 | import time | ||
6 | |||
7 | # import the Queue class from Python 3 | ||
8 | if sys.version_info >= (3, 0): | ||
9 | from queue import Queue | ||
10 | |||
11 | # otherwise, import the Queue class for Python 2.7 | ||
12 | else: | ||
13 | from Queue import Queue | ||
14 | |||
15 | |||
16 | class FileVideoStream: | ||
17 | def __init__(self, path, transform=None, queue_size=128): | ||
18 | # initialize the file video stream along with the boolean | ||
19 | # used to indicate if the thread should be stopped or not | ||
20 | self.stream = cv2.VideoCapture(path) | ||
21 | self.stopped = False | ||
22 | self.transform = transform | ||
23 | |||
24 | # initialize the queue used to store frames read from | ||
25 | # the video file | ||
26 | self.Q = Queue(maxsize=queue_size) | ||
27 | # initialize the reader thread | ||
28 | self.thread = Thread(target=self.update, args=()) | ||
29 | self.thread.daemon = True | ||
30 | |||
31 | def start(self): | ||
32 | # start a thread to read frames from the file video stream | ||
33 | self.thread.start() | ||
34 | return self | ||
35 | |||
36 | def update(self): | ||
37 | # keep looping infinitely | ||
38 | while True: | ||
39 | # if the thread indicator variable is set, stop the | ||
40 | # thread | ||
41 | if self.stopped: | ||
42 | break | ||
43 | |||
44 | # otherwise, ensure the queue has room in it | ||
45 | if not self.Q.full(): | ||
46 | # read the next frame from the file | ||
47 | (grabbed, frame) = self.stream.read() | ||
48 | |||
49 | # if the `grabbed` boolean is `False`, then we have | ||
50 | # reached the end of the video file | ||
51 | if not grabbed: | ||
52 | self.stopped = True | ||
53 | break | ||
54 | |||
55 | # if there are transforms to be done, might as well | ||
56 | # do them on producer thread before handing back to | ||
57 | # consumer thread. ie. Usually the producer is so far | ||
58 | # ahead of consumer that we have time to spare. | ||
59 | # | ||
60 | # Python is not parallel but the transform operations | ||
61 | # are usually OpenCV native so release the GIL. | ||
62 | # | ||
63 | # Really just trying to avoid spinning up additional | ||
64 | # native threads and overheads of additional | ||
65 | # producer/consumer queues since this one was generally | ||
66 | # idle grabbing frames. | ||
67 | if self.transform: | ||
68 | frame = self.transform(frame) | ||
69 | |||
70 | # add the frame to the queue | ||
71 | self.Q.put(frame) | ||
72 | else: | ||
73 | time.sleep(0.1) # Rest for 100ms, we have a full queue | ||
74 | |||
75 | self.stream.release() | ||
76 | |||
77 | def read(self): | ||
78 | # return next frame in the queue | ||
79 | return self.Q.get() | ||
80 | |||
81 | # Insufficient to have consumer use while(more()) which does | ||
82 | # not take into account if the producer has reached end of | ||
83 | # file stream. | ||
84 | def running(self): | ||
85 | return self.more() or not self.stopped | ||
86 | |||
87 | def more(self): | ||
88 | # return True if there are still frames in the queue. If stream is not stopped, try to wait a moment | ||
89 | tries = 0 | ||
90 | while self.Q.qsize() == 0 and not self.stopped and tries < 5: | ||
91 | time.sleep(0.1) | ||
92 | tries += 1 | ||
93 | |||
94 | return self.Q.qsize() > 0 | ||
95 | |||
96 | def stop(self): | ||
97 | # indicate that the thread should be stopped | ||
98 | self.stopped = True | ||
99 | # wait until stream resources are released (producer thread might be still grabbing frame) | ||
100 | self.thread.join() |
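The producer thread keeps the queue topped up while the consumer drains it, so a consumer loops on `running()` rather than polling the stream itself. A minimal consumer sketch (hypothetical import path and file):

```python
import cv2
from turnsole.video.filevideostream import FileVideoStream  # hypothetical import path

fvs = FileVideoStream('clip.mp4').start()  # hypothetical input file
while fvs.running():
    frame = fvs.read()  # blocks until the producer queues a frame
    cv2.imshow('frame', frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
fvs.stop()
cv2.destroyAllWindows()
```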
ocr_engine/turnsole/video/fps.py
0 → 100644
1 | # import the necessary packages | ||
2 | import datetime | ||
3 | |||
4 | class FPS: | ||
5 | def __init__(self): | ||
6 | # store the start time, end time, and total number of frames | ||
7 | # that were examined between the start and end intervals | ||
8 | self._start = None | ||
9 | self._end = None | ||
10 | self._numFrames = 0 | ||
11 | |||
12 | def start(self): | ||
13 | # start the timer | ||
14 | self._start = datetime.datetime.now() | ||
15 | return self | ||
16 | |||
17 | def stop(self): | ||
18 | # stop the timer | ||
19 | self._end = datetime.datetime.now() | ||
20 | |||
21 | def update(self): | ||
22 | # increment the total number of frames examined during the | ||
23 | # start and end intervals | ||
24 | self._numFrames += 1 | ||
25 | |||
26 | def elapsed(self): | ||
27 | # return the total number of seconds between the start and | ||
28 | # end interval | ||
29 | return (self._end - self._start).total_seconds() | ||
30 | |||
31 | def fps(self): | ||
32 | # compute the (approximate) frames per second | ||
33 | return self._numFrames / self.elapsed() | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
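`fps()` is simply the number of `update()` calls divided by the wall-clock seconds between `start()` and `stop()`. Sketch (the sleep stands in for real per-frame work):

```python
import time
from turnsole.video.fps import FPS  # hypothetical import path

fps = FPS().start()
for _ in range(100):   # stand-in for a frame-processing loop
    time.sleep(0.01)   # pretend each frame takes ~10 ms
    fps.update()
fps.stop()
print(f'approx. {fps.fps():.1f} FPS over {fps.elapsed():.2f}s')
```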
ocr_engine/turnsole/video/pivideostream.py
0 → 100644
1 | # import the necessary packages | ||
2 | from picamera.array import PiRGBArray | ||
3 | from picamera import PiCamera | ||
4 | from threading import Thread | ||
5 | import cv2 | ||
6 | |||
7 | class PiVideoStream: | ||
8 | def __init__(self, resolution=(320, 240), framerate=32, **kwargs): | ||
9 | # initialize the camera | ||
10 | self.camera = PiCamera() | ||
11 | |||
12 | # set camera parameters | ||
13 | self.camera.resolution = resolution | ||
14 | self.camera.framerate = framerate | ||
15 | |||
16 | # set optional camera parameters (refer to PiCamera docs) | ||
17 | for (arg, value) in kwargs.items(): | ||
18 | setattr(self.camera, arg, value) | ||
19 | |||
20 | # initialize the stream | ||
21 | self.rawCapture = PiRGBArray(self.camera, size=resolution) | ||
22 | self.stream = self.camera.capture_continuous(self.rawCapture, | ||
23 | format="bgr", use_video_port=True) | ||
24 | |||
25 | # initialize the frame and the variable used to indicate | ||
26 | # if the thread should be stopped | ||
27 | self.frame = None | ||
28 | self.stopped = False | ||
29 | |||
30 | def start(self): | ||
31 | # start the thread to read frames from the video stream | ||
32 | t = Thread(target=self.update, args=()) | ||
33 | t.daemon = True | ||
34 | t.start() | ||
35 | return self | ||
36 | |||
37 | def update(self): | ||
38 | # keep looping infinitely until the thread is stopped | ||
39 | for f in self.stream: | ||
40 | # grab the frame from the stream and clear the stream in | ||
41 | # preparation for the next frame | ||
42 | self.frame = f.array | ||
43 | self.rawCapture.truncate(0) | ||
44 | |||
45 | # if the thread indicator variable is set, stop the thread | ||
46 | # and release camera resources | ||
47 | if self.stopped: | ||
48 | self.stream.close() | ||
49 | self.rawCapture.close() | ||
50 | self.camera.close() | ||
51 | return | ||
52 | |||
53 | def read(self): | ||
54 | # return the frame most recently read | ||
55 | return self.frame | ||
56 | |||
57 | def stop(self): | ||
58 | # indicate that the thread should be stopped | ||
59 | self.stopped = True |
ocr_engine/turnsole/video/videostream.py
0 → 100644
1 | # import the necessary packages | ||
2 | from .webcamvideostream import WebcamVideoStream | ||
3 | |||
4 | class VideoStream: | ||
5 | def __init__(self, src=0, usePiCamera=False, resolution=(320, 240), | ||
6 | framerate=32, **kwargs): | ||
7 | # check to see if the picamera module should be used | ||
8 | if usePiCamera: | ||
9 | # only import the picamera packages unless we are | ||
10 | # explicitly told to do so -- this helps remove the | ||
11 | # requirement of `picamera[array]` from desktops or | ||
12 | # laptops that still want to use the `imutils` package | ||
13 | from .pivideostream import PiVideoStream | ||
14 | |||
15 | # initialize the picamera stream and allow the camera | ||
16 | # sensor to warmup | ||
17 | self.stream = PiVideoStream(resolution=resolution, | ||
18 | framerate=framerate, **kwargs) | ||
19 | |||
20 | # otherwise, we are using OpenCV so initialize the webcam | ||
21 | # stream | ||
22 | else: | ||
23 | self.stream = WebcamVideoStream(src=src) | ||
24 | |||
25 | def start(self): | ||
26 | # start the threaded video stream | ||
27 | return self.stream.start() | ||
28 | |||
29 | def update(self): | ||
30 | # grab the next frame from the stream | ||
31 | self.stream.update() | ||
32 | |||
33 | def read(self): | ||
34 | # return the current frame | ||
35 | return self.stream.read() | ||
36 | |||
37 | def stop(self): | ||
38 | # stop the thread and release any resources | ||
39 | self.stream.stop() |
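`VideoStream` gives the webcam and the Pi camera one interface; `usePiCamera=True` imports `picamera` lazily so desktop installs do not need it. Sketch for the webcam path (hypothetical import path):

```python
import cv2
from turnsole.video.videostream import VideoStream  # hypothetical import path

vs = VideoStream(src=0).start()  # usePiCamera=True would use the same loop
while True:
    frame = vs.read()
    cv2.imshow('frame', frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break
vs.stop()
cv2.destroyAllWindows()
```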
1 | # import the necessary packages | ||
2 | from threading import Thread | ||
3 | import cv2 | ||
4 | |||
5 | class WebcamVideoStream: | ||
6 | def __init__(self, src=0, name="WebcamVideoStream"): | ||
7 | # initialize the video camera stream and read the first frame | ||
8 | # from the stream | ||
9 | self.stream = cv2.VideoCapture(src) | ||
10 | (self.grabbed, self.frame) = self.stream.read() | ||
11 | |||
12 | # initialize the thread name | ||
13 | self.name = name | ||
14 | |||
15 | # initialize the variable used to indicate if the thread should | ||
16 | # be stopped | ||
17 | self.stopped = False | ||
18 | |||
19 | def start(self): | ||
20 | # start the thread to read frames from the video stream | ||
21 | t = Thread(target=self.update, name=self.name, args=()) | ||
22 | t.daemon = True | ||
23 | t.start() | ||
24 | return self | ||
25 | |||
26 | def update(self): | ||
27 | # keep looping infinitely until the thread is stopped | ||
28 | while True: | ||
29 | # if the thread indicator variable is set, stop the thread | ||
30 | if self.stopped: | ||
31 | return | ||
32 | |||
33 | # otherwise, read the next frame from the stream | ||
34 | (self.grabbed, self.frame) = self.stream.read() | ||
35 | |||
36 | def read(self): | ||
37 | # return the frame most recently read | ||
38 | return self.frame | ||
39 | |||
40 | def stop(self): | ||
41 | # indicate that the thread should be stopped | ||
42 | self.stopped = True |