update
Showing 59 changed files with 2434 additions and 1 deletion
ocr_engine/README.md
0 → 100644
| 1 | # turnsole | ||
| 2 | A series of convenience functions to make your machine learning project easier | ||
| 3 | |||
| 4 | ## Installation | ||
| 5 | |||
| 6 | ### Latest release | ||
| 7 | `pip install turnsole` | ||
| 8 | > The project is not open source yet, so this installation method is not guaranteed to work for now | ||
| 9 | |||
| 10 | ### Developer mode | ||
| 11 | |||
| 12 | `pip install -e .` | ||
| 13 | |||
| 14 | ## Quick start | ||
| 15 | ### PDF utilities | ||
| 16 | #### Smart PDF-to-image conversion | ||
| 17 | Intelligently extracts the illustrations embedded in a PDF file: pages without illustrations are captured as full-page screenshots, and fragmented images are stitched back together automatically | ||
| 18 | |||
| 19 | ##### Example: | ||
| 20 | <pre># pdf_path is the path of the PDF file; the output images are grouped by page number | ||
| 21 | images = turnsole.pdf_to_images(pdf_path)</pre> | ||
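| | The returned list contains one sub-list of images per page; to get a flat list of images, concatenate the pages as demos/test_pdf_tools.py does: | ||
| | <pre>images_per_page = turnsole.pdf_to_images(pdf_path) # [[page 0 images], [page 1 images], ...] | ||
| | images_flat = sum(images_per_page, []) # flatten across pages</pre> | ||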
| 22 | |||
| 23 | ### Image toolbox | ||
| 24 | #### base64_to_bgr / bgr_to_base64 | ||
| 25 | Convert between images and base64 | ||
| 26 | |||
| 27 | ##### Example: | ||
| 28 | <pre>image = turnsole.base64_to_bgr(img64) | ||
| 29 | img64 = turnsole.bgr_to_base64(image)</pre> | ||
| 30 | |||
| 31 | #### image_crop | ||
| 32 | Crops a slice of image according to bbox; if perspective is True, the slice is cut with a perspective transform (can crop rotated targets) | ||
| 33 | |||
| 34 | ##### Example: | ||
| 35 | <pre>im_slice_no_perspective = turnsole.image_crop(image, bbox) | ||
| 36 | im_slice = turnsole.image_crop(image, bbox, perspective=True)</pre> | ||
| 37 | |||
| 38 | ##### Output: | ||
| 39 | |||
| 40 | <img src="docs/images/image_crop.png?raw=true" alt="image crop example" style="max-width: 200px;"> | ||
| 41 | |||
| 42 | ### OCR engine module | ||
| 43 | The OCR engine is a collection of low-level OCR-related models; we provide a functional calling interface and a standard API for these models | ||
| 44 | |||
| 45 | - [x] ADC :tada: | ||
| 46 | - [x] DBNet :tada: | ||
| 47 | - [x] CRNN :tada: | ||
| 48 | - [x] Object Detector :tada: | ||
| 49 | - [x] Signature Detector :tada: | ||
| 50 | |||
| 51 | #### Free trial | ||
| 52 | ```python | ||
| 53 | import requests | ||
| 54 | |||
| 55 | results = requests.post(url=r'http://139.196.149.46:9001/gen_ocr', files={'file': open(file_path, 'rb')}).json() | ||
| 56 | ocr_results = results['ocr_results'] | ||
| 57 | ``` | ||
| 58 | |||
| 59 | #### Prerequisites | ||
| 60 | The OCR engine module depends on the underlying neural-network models, so the models must first be served with Docker | ||
| 61 | |||
| 62 | First put the ./model_repository folder and the models inside it under the project root, then start; if you don't have the models, ask [lvkui](lvkui@situdata.com) for them | ||
| 63 | |||
| 64 | Usage is very simple: you only need to start the corresponding Docker container | ||
| 65 | |||
| 66 | ```bash | ||
| 67 | docker run --gpus="device=0" --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v $PWD/model_repository:/models nvcr.io/nvidia/tritonserver:21.10-py3 tritonserver --model-repository=/models | ||
| 68 | ``` | ||
| 69 | |||
| 70 | #### ADC | ||
| 71 | General document orientation-correction (deskewing) algorithm | ||
| 72 | |||
| 73 | ``` | ||
| 74 | from turnsole.ocr_engine import angle_detector | ||
| 75 | |||
| 76 | image_rotated, direction = angle_detector.ADC(image, fine_degree=False) | ||
| 77 | ``` | ||
| 78 | |||
| 79 | #### DBNet | ||
| 80 | General text detection algorithm | ||
| 81 | |||
| 82 | ``` | ||
| 83 | from turnsole.ocr_engine import text_detector | ||
| 84 | |||
| 85 | boxes = text_detector.predict(image) | ||
| 86 | ``` | ||
| 87 | |||
| 88 | #### CRNN | ||
| 89 | General text recognition algorithm | ||
| 90 | |||
| 91 | ``` | ||
| 92 | from turnsole.ocr_engine import text_recognizer | ||
| 93 | |||
| 94 | ocr_result, ocr_time = text_recognizer.predict_batch(image, boxes) | ||
| 95 | ``` | ||
| 96 | |||
| 97 | #### Object Detector | ||
| 98 | General document detection algorithm | ||
| 99 | |||
| 100 | ``` | ||
| 101 | from turnsole.ocr_engine import object_detector | ||
| 102 | |||
| 103 | object_list = object_detector.process(image) | ||
| 104 | ``` | ||
| 105 | |||
| 106 | #### Signature Detector | ||
| 107 | Signature, stamp, and QR-code detection algorithm | ||
| 108 | |||
| 109 | ``` | ||
| 110 | from turnsole.ocr_engine import signature_detector | ||
| 111 | |||
| 112 | signature_list = signature_detector.process(image) | ||
| 113 | ``` | ||
| 114 | |||
| 115 | #### Standard API | ||
| 116 | ``` | ||
| 117 | python api/ocr_engine_server.py | ||
| 118 | ``` | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
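| | A minimal client sketch for the standard API, assuming the server is listening on localhost:9001 (adjust the host and port to match the app.run(...) call in api/ocr_engine_server.py; demo.jpg is a hypothetical sample image): | ||
| | ```python | ||
| | import requests | ||
| | |||
| | with open('demo.jpg', 'rb') as f: | ||
| | resp = requests.post('http://localhost:9001/gen_ocr_with_rotation', files={'file': f}).json() | ||
| | print(resp['direction']) # one of 0 / 90 / 180 / 270 | ||
| | print(resp['ocr_results']) # {index: [box coordinates, recognized text], ...} | ||
| | ``` | ||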
ocr_engine/api/nohup.out
0 → 100644
| 1 | [2022-10-21 14:12:17 +0800] [8546] [INFO] Goin' Fast @ http://192.168.10.11:9001 | ||
| 2 | [2022-10-21 14:12:17 +0800] [8567] [INFO] Starting worker [8567] | ||
| 3 | [2022-10-21 14:12:17 +0800] [8568] [INFO] Starting worker [8568] | ||
| 4 | [2022-10-21 14:12:17 +0800] [8569] [INFO] Starting worker [8569] | ||
| 5 | [2022-10-21 14:12:17 +0800] [8570] [INFO] Starting worker [8570] | ||
| 6 | [2022-10-21 14:12:17 +0800] [8571] [INFO] Starting worker [8571] | ||
| 7 | [2022-10-21 14:12:17 +0800] [8572] [INFO] Starting worker [8572] | ||
| 8 | [2022-10-21 14:12:17 +0800] [8573] [INFO] Starting worker [8573] | ||
| 9 | [2022-10-21 14:12:17 +0800] [8576] [INFO] Starting worker [8576] | ||
| 10 | [2022-10-21 14:12:17 +0800] [8574] [INFO] Starting worker [8574] | ||
| 11 | [2022-10-21 14:12:17 +0800] [8575] [INFO] Starting worker [8575] | ||
| 12 | [2022-10-21 14:13:51 +0800] [8575] [ERROR] Exception occurred while handling uri: 'http://192.168.10.11:9001/gen_ocr' | ||
| 13 | Traceback (most recent call last): | ||
| 14 | File "/home/situ/miniconda3/envs/workenv/lib/python3.6/site-packages/sanic/app.py", line 944, in handle_request | ||
| 15 | response = await response | ||
| 16 | File "ocr_engine_server.py", line 37, in ocr_engine | ||
| 17 | boxes = text_detector.predict(image) | ||
| 18 | File "/home/situ/qfs/invoice_tamper/09_project/project/bank_bill_ocr/OCR_Engine/turnsole/ocr_engine/DBNet/text_detector.py", line 113, in predict | ||
| 19 | outputs=outputs | ||
| 20 | File "/home/situ/miniconda3/envs/workenv/lib/python3.6/site-packages/tritonclient/grpc/__init__.py", line 1431, in infer | ||
| 21 | raise_error_grpc(rpc_error) | ||
| 22 | File "/home/situ/miniconda3/envs/workenv/lib/python3.6/site-packages/tritonclient/grpc/__init__.py", line 62, in raise_error_grpc | ||
| 23 | raise get_error_grpc(rpc_error) from None | ||
| 24 | tritonclient.utils.InferenceServerException: [StatusCode.UNAVAILABLE] Request for unknown model: 'dbnet_model' is not found | ||
| 25 | [2022-10-21 14:13:51 +0800] - (sanic.access)[INFO][192.168.10.11:57260]: POST http://192.168.10.11:9001/gen_ocr 500 735 |
ocr_engine/api/ocr_engine_server.py
0 → 100644
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | # @Author : Lyu Kui | ||
| 3 | # @Email : 9428.al@gmail.com | ||
| 4 | # @Create Date : 2022-06-05 20:49:51 | ||
| 5 | # @Last Modified : 2022-08-19 17:24:55 | ||
| 6 | # @Description : | ||
| 7 | |||
| 8 | import os | ||
| 9 | |||
| 10 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' | ||
| 11 | |||
| 12 | from sanic import Sanic | ||
| 13 | from sanic.response import json | ||
| 14 | |||
| 15 | from turnsole.ocr_engine import angle_detector | ||
| 16 | from turnsole.ocr_engine import text_detector | ||
| 17 | from turnsole.ocr_engine import text_recognizer | ||
| 18 | from turnsole.ocr_engine import object_detector | ||
| 19 | from turnsole.ocr_engine import signature_detector | ||
| 20 | |||
| 21 | from turnsole import bytes_to_bgr | ||
| 22 | |||
| 23 | app = Sanic("OCR_ENGINE") | ||
| 24 | app.config.REQUEST_MAX_SIZE = 1000000000 # Max request size (bytes) / 1 GB | ||
| 25 | app.config.REQUEST_BUFFER_QUEUE_SIZE = 1000 # Request stream buffer queue size | ||
| 26 | app.config.REQUEST_TIMEOUT = 600 # How long a request may take to arrive (seconds) | ||
| 27 | app.config.RESPONSE_TIMEOUT = 600 # How long the response may take to process (seconds) | ||
| 28 | |||
| 29 | |||
| 30 | @app.post('/gen_ocr') | ||
| 31 | async def ocr_engine(request): | ||
| 32 | # request.files.get() has three attributes: type/body/name | ||
| 33 | file = request.files.get('file').body | ||
| 34 | # Convert the bytes to a BGR image | ||
| 35 | image = bytes_to_bgr(file) | ||
| 36 | # Text detection | ||
| 37 | boxes = text_detector.predict(image) | ||
| 38 | # Text recognition | ||
| 39 | res, _ = text_recognizer.predict_batch(image[..., ::-1], boxes) | ||
| 40 | resp = {} | ||
| 41 | resp["ocr_results"] = res | ||
| 42 | return json(resp) | ||
| 43 | |||
| 44 | |||
| 45 | @app.post('/gen_ocr_with_rotation', ) | ||
| 46 | async def ocr_engine_with_rotation(request): | ||
| 47 | # request.files.get() has three attributes: type/body/name | ||
| 48 | file = request.files.get('file').body | ||
| 49 | # Convert the bytes to a BGR image | ||
| 50 | image = bytes_to_bgr(file) | ||
| 51 | # Orientation detection | ||
| 52 | image, direction = angle_detector.ADC(image.copy(), fine_degree=False) | ||
| 53 | # Text detection | ||
| 54 | boxes = text_detector.predict(image) | ||
| 55 | # Text recognition | ||
| 56 | res, _ = text_recognizer.predict_batch(image[..., ::-1], boxes) | ||
| 57 | |||
| 58 | resp = {} | ||
| 59 | resp["ocr_results"] = res | ||
| 60 | resp["direction"] = direction | ||
| 61 | return json(resp) | ||
| 62 | |||
| 63 | |||
| 64 | @app.post("/object_detect") | ||
| 65 | async def object_detect(request): | ||
| 66 | # request.files.get() has three attributes: type/body/name | ||
| 67 | file = request.files.get('file').body | ||
| 68 | # Convert the bytes to a BGR image | ||
| 69 | image = bytes_to_bgr(file) | ||
| 70 | # General document detection | ||
| 71 | object_list = object_detector.process(image) | ||
| 72 | return json(object_list) | ||
| 73 | |||
| 74 | |||
| 75 | @app.post("/signature_detect") | ||
| 76 | async def signature_detect(request): | ||
| 77 | # request.files.get() has three attributes: type/body/name | ||
| 78 | file = request.files.get('file').body | ||
| 79 | # Convert the bytes to a BGR image | ||
| 80 | image = bytes_to_bgr(file) | ||
| 81 | # Signature / stamp / QR-code / barcode detection | ||
| 82 | signature_list = signature_detector.process(image) | ||
| 83 | return json(signature_list) | ||
| 84 | |||
| 85 | |||
| 86 | if __name__ == "__main__": | ||
| 87 | # app.run(host="0.0.0.0", port=9001) | ||
| 88 | app.run(host="192.168.10.11", port=9002, workers=10) | ||
| 89 | # uvicorn server:app --port 9001 --workers 10 |
ocr_engine/demos/images/sunflower.bmp
0 → 100644
No preview for this file type
ocr_engine/demos/images/sunflower.gif
0 → 100644
9.68 KB
ocr_engine/demos/images/sunflower.jpg
0 → 100644
405 KB
ocr_engine/demos/images/sunflower.png
0 → 100644
461 KB
ocr_engine/demos/images/sunflower.tif
0 → 100644
No preview for this file type
ocr_engine/demos/img_ocr/001.jpg
0 → 100644
1.62 MB
ocr_engine/demos/img_ocr/002.jpg
0 → 100644
97.7 KB
ocr_engine/demos/img_ocr/003.jpg
0 → 100644
112 KB
ocr_engine/demos/img_ocr/004.jpg
0 → 100644
24.4 KB
ocr_engine/demos/img_ocr/005.jpg
0 → 100644
77.5 KB
ocr_engine/demos/read_frames_fast.py
0 → 100644
| 1 | # Modified from: | ||
| 2 | # https://www.pyimagesearch.com/2017/02/06/faster-video-file-fps-with-cv2-videocapture-and-opencv/ | ||
| 3 | |||
| 4 | # Performance: | ||
| 5 | # Python 2.7: 105.78 --> 131.75 | ||
| 6 | # Python 3.7: 15.36 --> 50.13 | ||
| 7 | |||
| 8 | # USAGE | ||
| 9 | # python read_frames_fast.py --video videos/jurassic_park_intro.mp4 | ||
| 10 | |||
| 11 | # import the necessary packages | ||
| 12 | from turnsole.video import FileVideoStream | ||
| 13 | from turnsole.video import FPS | ||
| 14 | import numpy as np | ||
| 15 | import argparse | ||
| 16 | import imutils | ||
| 17 | import time | ||
| 18 | import cv2 | ||
| 19 | |||
| 20 | def filterFrame(frame): | ||
| 21 | frame = imutils.resize(frame, width=450) | ||
| 22 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) | ||
| 23 | frame = np.dstack([frame, frame, frame]) | ||
| 24 | return frame | ||
| 25 | |||
| 26 | # construct the argument parse and parse the arguments | ||
| 27 | ap = argparse.ArgumentParser() | ||
| 28 | ap.add_argument("-v", "--video", required=True, | ||
| 29 | help="path to input video file") | ||
| 30 | args = vars(ap.parse_args()) | ||
| 31 | |||
| 32 | # start the file video stream thread and allow the buffer to | ||
| 33 | # start to fill | ||
| 34 | print("[INFO] starting video file thread...") | ||
| 35 | fvs = FileVideoStream(args["video"], transform=filterFrame).start() | ||
| 36 | time.sleep(1.0) | ||
| 37 | |||
| 38 | # start the FPS timer | ||
| 39 | fps = FPS().start() | ||
| 40 | |||
| 41 | # loop over frames from the video file stream | ||
| 42 | while fvs.running(): | ||
| 43 | # grab the frame from the threaded video file stream, resize | ||
| 44 | # it, and convert it to grayscale (while still retaining 3 | ||
| 45 | # channels) | ||
| 46 | frame = fvs.read() | ||
| 47 | |||
| 48 | # Relocated filtering into producer thread with transform=filterFrame | ||
| 49 | # Python 2.7: FPS 92.11 -> 131.36 | ||
| 50 | # Python 3.7: FPS 41.44 -> 50.11 | ||
| 51 | #frame = filterFrame(frame) | ||
| 52 | |||
| 53 | # display the size of the queue on the frame | ||
| 54 | cv2.putText(frame, "Queue Size: {}".format(fvs.Q.qsize()), | ||
| 55 | (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2) | ||
| 56 | |||
| 57 | # show the frame and update the FPS counter | ||
| 58 | cv2.imshow("Frame", frame) | ||
| 59 | |||
| 60 | cv2.waitKey(1) | ||
| 61 | if fvs.Q.qsize() < 2: # If we are low on frames, give time to producer | ||
| 62 | time.sleep(0.001) # Ensures producer runs now, so 2 is sufficient | ||
| 63 | fps.update() | ||
| 64 | |||
| 65 | # stop the timer and display FPS information | ||
| 66 | fps.stop() | ||
| 67 | print("[INFO] elasped time: {:.2f}".format(fps.elapsed())) | ||
| 68 | print("[INFO] approx. FPS: {:.2f}".format(fps.fps())) | ||
| 69 | |||
| 70 | # do a bit of cleanup | ||
| 71 | cv2.destroyAllWindows() | ||
| 72 | fvs.stop() | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
ocr_engine/demos/test_convenience.py
0 → 100644
ocr_engine/demos/test_model.py
0 → 100644
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | # @Author : Lyu Kui | ||
| 3 | # @Email : 9428.al@gmail.com | ||
| 4 | # @Created Date : 2021-03-05 16:51:22 | ||
| 5 | # @Last Modified : 2021-03-05 18:15:53 | ||
| 6 | # @Description : | ||
| 7 | |||
| 8 | from turnsole.model import EasyDet | ||
| 9 | |||
| 10 | if __name__ == '__main__': | ||
| 11 | model = EasyDet(phi=0) | ||
| 12 | model.summary() | ||
| 13 | |||
| 14 | import time | ||
| 15 | import numpy as np | ||
| 16 | |||
| 17 | x = np.random.random_sample((1, 640, 640, 3)) | ||
| 18 | # warm up | ||
| 19 | output = model.predict(x) | ||
| 20 | |||
| 21 | print('\n[INFO] Test start') | ||
| 22 | time_start = time.time() | ||
| 23 | for i in range(1000): | ||
| 24 | output = model.predict(x) | ||
| 25 | |||
| 26 | time_end = time.time() | ||
| 27 | print('[INFO] Time used: {:.2f} ms'.format((time_end - time_start)*1000/(i+1))) | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
ocr_engine/demos/test_ocr_function.py
0 → 100644
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | # @Author : Lyu Kui | ||
| 3 | # @Email : 9428.al@gmail.com | ||
| 4 | # @Create Date : 2022-07-22 13:10:47 | ||
| 5 | # @Last Modified : 2022-09-08 19:03:24 | ||
| 6 | # @Description : | ||
| 7 | |||
| 8 | import os | ||
| 9 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' | ||
| 10 | |||
| 11 | import cv2 | ||
| 12 | # from turnsole.ocr_engine import angle_detector | ||
| 13 | from turnsole.ocr_engine import object_detector | ||
| 14 | import matplotlib.pyplot as plt | ||
| 15 | |||
| 16 | |||
| 17 | if __name__ == "__main__": | ||
| 18 | |||
| 19 | base_dir = '/home/lk/MyProject/BMW/数据集/文件分类/身份证' | ||
| 20 | |||
| 21 | for (rootDir, dirNames, filenames) in os.walk(base_dir): | ||
| 22 | |||
| 23 | for filename in filenames: | ||
| 24 | |||
| 25 | if not filename.endswith('.jpg'): | ||
| 26 | continue | ||
| 27 | |||
| 28 | img_path = os.path.join(rootDir, filename) | ||
| 29 | print(img_path) | ||
| 30 | |||
| 31 | image = cv2.imread(img_path) | ||
| 32 | |||
| 33 | results = object_detector.process(image) | ||
| 34 | |||
| 35 | print(results) | ||
| 36 | |||
| 37 | for item in results: | ||
| 38 | xmin = item['location']['xmin'] | ||
| 39 | ymin = item['location']['ymin'] | ||
| 40 | xmax = item['location']['xmax'] | ||
| 41 | ymax = item['location']['ymax'] | ||
| 42 | cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2) | ||
| 43 | |||
| 44 | plt.imshow(image[...,::-1]) | ||
| 45 | plt.show() | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
ocr_engine/demos/test_pdf_tools.py
0 → 100644
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | # @Author : Lyu Kui | ||
| 3 | # @Email : 9428.al@gmail.com | ||
| 4 | # @Create Date : 2022-07-22 13:10:47 | ||
| 5 | # @Last Modified : 2022-08-24 15:39:55 | ||
| 6 | # @Description : | ||
| 7 | |||
| 8 | |||
| 9 | import os | ||
| 10 | import cv2 | ||
| 11 | import fitz | ||
| 12 | from turnsole import pdf_to_images # pip install turnsole PyMuPDF opencv-python==4.4.0.44 | ||
| 13 | |||
| 14 | if __name__ == "__main__": | ||
| 15 | |||
| 16 | base_dir = '/PATH/TO/YOUR/WORKDIR' | ||
| 17 | |||
| 18 | for (rootDir, dirNames, filenames) in os.walk(base_dir): | ||
| 19 | |||
| 20 | for filename in filenames: | ||
| 21 | |||
| 22 | if not filename.endswith('.pdf'): | ||
| 23 | continue | ||
| 24 | |||
| 25 | pdf_path = os.path.join(rootDir, filename) | ||
| 26 | print(pdf_path) | ||
| 27 | |||
| 28 | images = pdf_to_images(pdf_path) | ||
| 29 | images = sum(images, []) | ||
| 30 | |||
| 31 | image_dir = os.path.join(rootDir, filename.replace('.pdf', '')) | ||
| 32 | if not os.path.exists(image_dir): | ||
| 33 | os.makedirs(image_dir) | ||
| 34 | |||
| 35 | for index, image in enumerate(images): | ||
| 36 | |||
| 37 | save_path = os.path.join(image_dir, filename.replace('.pdf', '')+'-'+str(index)+'.jpg') | ||
| 38 | cv2.imwrite(save_path, image) |
ocr_engine/docs/images/image_crop.png
0 → 100644
193 KB
ocr_engine/scripts/api_test.py
0 → 100644
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | # @Author : Lyu Kui | ||
| 3 | # @Email : 9428.al@gmail.com | ||
| 4 | # @Create Date : 2022-05-06 22:02:01 | ||
| 5 | # @Last Modified : 2022-08-03 14:59:51 | ||
| 6 | # @Description : | ||
| 7 | |||
| 8 | |||
| 9 | import os | ||
| 10 | import time | ||
| 11 | import random | ||
| 12 | import requests | ||
| 13 | import numpy as np | ||
| 14 | from threading import Thread | ||
| 15 | |||
| 16 | |||
| 17 | class API_test: | ||
| 18 | def __init__(self, file_dir, test_time, num_request): | ||
| 19 | |||
| 20 | self.file_paths = [] | ||
| 21 | for fn in os.listdir(file_dir): | ||
| 22 | file_path = os.path.join(file_dir, fn) | ||
| 23 | self.file_paths.append(file_path) | ||
| 24 | |||
| 25 | self.time_start = time.time() | ||
| 26 | self.test_time = test_time * 60 # unit: seconds | ||
| 27 | self.results = list() # shared state must exist before the worker threads start | ||
| 28 | self.index = 0 | ||
| 29 | threads = [] | ||
| 30 | for i in range(num_request): | ||
| 31 | t = Thread(target=self.update, args=()) | ||
| 32 | threads.append(t) | ||
| 33 | for t in threads: | ||
| 34 | print(f'[INFO] {t} is running') | ||
| 35 | t.start() | ||
| 36 | |||
| 37 | def update(self): | ||
| 38 | while True: | ||
| 39 | file_path = random.choice(self.file_paths) | ||
| 40 | |||
| 41 | # Read the image file in binary mode | ||
| 42 | with open(file_path, 'rb') as f: | ||
| 43 | data = f.read() | ||
| 44 | t0 = time.time() | ||
| 45 | response = requests.post(url=r'http://localhost:9001/gen_ocr_with_rotation', files={'file': data}) | ||
| 46 | |||
| 47 | # Count failed requests | ||
| 48 | if response.status_code != 200: | ||
| 49 | print(response) | ||
| 50 | |||
| 51 | t1 = time.time() | ||
| 52 | self.results.append((t1-t0)) | ||
| 53 | |||
| 54 | time_cost = (time.time() - self.time_start) | ||
| 55 | time_remaining = self.test_time - time_cost | ||
| 56 | |||
| 57 | self.index += 1 | ||
| 58 | |||
| 59 | if time_remaining > 0: | ||
| 60 | print(f'\r[INFO] time remaining {time_remaining} s, mean response time {np.mean(self.results)} s, TPS {len(self.results)/time_cost}, completed requests {self.index}', end=' ', flush=True) | ||
| 61 | else: | ||
| 62 | break | ||
| 63 | |||
| 64 | |||
| 65 | if __name__ == '__main__': | ||
| 66 | |||
| 67 | imageDir = './demos/img_ocr' # test data directory | ||
| 68 | testTime = 10 # load-test duration, unit: minutes | ||
| 69 | numRequest = 10 # number of concurrent requests | ||
| 70 | |||
| 71 | API_test(imageDir, testTime, numRequest) |
ocr_engine/setup.cfg
0 → 100644
| 1 | [metadata] | ||
| 2 | name = turnsole | ||
| 3 | version = 0.0.27 | ||
| 4 | author = Kui Lyu | ||
| 5 | author_email = 9428.al@gmail.com | ||
| 6 | description = A series of convenience functions to make your machine learning project easier | ||
| 7 | long_description = file: README.md | ||
| 8 | long_description_content_type = text/markdown | ||
| 9 | url = https://github.com/Antonio-hi/turnsole | ||
| 10 | project_urls = | ||
| 11 | Bug Tracker = https://github.com/Antonio-hi/turnsole/issues | ||
| 12 | classifiers = | ||
| 13 | Programming Language :: Python :: 3 | ||
| 14 | License :: OSI Approved :: MIT License | ||
| 15 | Operating System :: OS Independent | ||
| 16 | |||
| 17 | [options] | ||
| 18 | packages = find: | ||
| 19 | python_requires = >=3.6 | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
ocr_engine/setup.py
0 → 100644
ocr_engine/turnsole.egg-info/PKG-INFO
0 → 100644
| 1 | Metadata-Version: 2.1 | ||
| 2 | Name: turnsole | ||
| 3 | Version: 0.0.27 | ||
| 4 | Summary: A series of convenience functions to make your machine learning project easier | ||
| 5 | Home-page: https://github.com/Antonio-hi/turnsole | ||
| 6 | Author: Kui Lyu | ||
| 7 | Author-email: 9428.al@gmail.com | ||
| 8 | License: UNKNOWN | ||
| 9 | Project-URL: Bug Tracker, https://github.com/Antonio-hi/turnsole/issues | ||
| 10 | Platform: UNKNOWN | ||
| 11 | Classifier: Programming Language :: Python :: 3 | ||
| 12 | Classifier: License :: OSI Approved :: MIT License | ||
| 13 | Classifier: Operating System :: OS Independent | ||
| 14 | Requires-Python: >=3.6 | ||
| 15 | Description-Content-Type: text/markdown | ||
| 16 | License-File: LICENSE | ||
| 17 | |||
| 18 | # turnsole | ||
| 19 | A series of convenience functions to make your machine learning project easier | ||
| 20 | |||
| 21 | ## Installation | ||
| 22 | |||
| 23 | ### Latest release | ||
| 24 | `pip install turnsole` | ||
| 25 | > The project is not open source yet, so this installation method is not guaranteed to work for now | ||
| 26 | |||
| 27 | ### Developer mode | ||
| 28 | |||
| 29 | `pip install -e .` | ||
| 30 | |||
| 31 | ## Quick start | ||
| 32 | ### PDF utilities | ||
| 33 | #### Smart PDF-to-image conversion | ||
| 34 | Intelligently extracts the illustrations embedded in a PDF file: pages without illustrations are captured as full-page screenshots, and fragmented images are stitched back together automatically | ||
| 35 | |||
| 36 | ##### Example: | ||
| 37 | <pre># pdf_path is the path of the PDF file; the output images are grouped by page number | ||
| 38 | images = turnsole.pdf_to_images(pdf_path)</pre> | ||
| 39 | |||
| 40 | ### Image toolbox | ||
| 41 | #### base64_to_bgr / bgr_to_base64 | ||
| 42 | Convert between images and base64 | ||
| 43 | |||
| 44 | ##### Example: | ||
| 45 | <pre>image = turnsole.base64_to_bgr(img64) | ||
| 46 | img64 = turnsole.bgr_to_base64(image)</pre> | ||
| 47 | |||
| 48 | #### image_crop | ||
| 49 | Crops a slice of image according to bbox; if perspective is True, the slice is cut with a perspective transform (can crop rotated targets) | ||
| 50 | |||
| 51 | ##### Example: | ||
| 52 | <pre>im_slice_no_perspective = turnsole.image_crop(image, bbox) | ||
| 53 | im_slice = turnsole.image_crop(image, bbox, perspective=True)</pre> | ||
| 54 | |||
| 55 | ##### Output: | ||
| 56 | |||
| 57 | <img src="docs/images/image_crop.png?raw=true" alt="image crop example" style="max-width: 200px;"> | ||
| 58 | |||
| 59 | ### OCR engine module | ||
| 60 | The OCR engine is a collection of low-level OCR-related models; we provide a functional calling interface and a standard API for these models | ||
| 61 | |||
| 62 | - [x] ADC :tada: | ||
| 63 | - [x] DBNet :tada: | ||
| 64 | - [x] CRNN :tada: | ||
| 65 | - [x] Object Detector :tada: | ||
| 66 | - [x] Signature Detector :tada: | ||
| 67 | |||
| 68 | #### Free trial | ||
| 69 | ```python | ||
| 70 | import requests | ||
| 71 | |||
| 72 | results = requests.post(url=r'http://139.196.149.46:9001/gen_ocr', files={'file': open(file_path, 'rb')}).json() | ||
| 73 | ocr_results = results['ocr_results'] | ||
| 74 | ``` | ||
| 75 | |||
| 76 | #### Prerequisites | ||
| 77 | The OCR engine module depends on the underlying neural-network models, so the models must first be served with Docker | ||
| 78 | |||
| 79 | First put the ./model_repository folder and the models inside it under the project root, then start; if you don't have the models, ask [lvkui](lvkui@situdata.com) for them | ||
| 80 | |||
| 81 | Usage is very simple: you only need to start the corresponding Docker container | ||
| 82 | |||
| 83 | ```bash | ||
| 84 | docker run --gpus="device=0" --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v $PWD/model_repository:/models nvcr.io/nvidia/tritonserver:21.10-py3 tritonserver --model-repository=/models | ||
| 85 | ``` | ||
| 86 | |||
| 87 | #### ADC | ||
| 88 | General document orientation-correction (deskewing) algorithm | ||
| 89 | |||
| 90 | ``` | ||
| 91 | from turnsole.ocr_engine import angle_detector | ||
| 92 | |||
| 93 | image_rotated, direction = angle_detector.ADC(image, fine_degree=False) | ||
| 94 | ``` | ||
| 95 | |||
| 96 | #### DBNet | ||
| 97 | General text detection algorithm | ||
| 98 | |||
| 99 | ``` | ||
| 100 | from turnsole.ocr_engine import text_detector | ||
| 101 | |||
| 102 | boxes = text_detector.predict(image) | ||
| 103 | ``` | ||
| 104 | |||
| 105 | #### CRNN | ||
| 106 | General text recognition algorithm | ||
| 107 | |||
| 108 | ``` | ||
| 109 | from turnsole.ocr_engine import text_recognizer | ||
| 110 | |||
| 111 | ocr_result, ocr_time = text_recognizer.predict_batch(image, boxes) | ||
| 112 | ``` | ||
| 113 | |||
| 114 | #### Object Detector | ||
| 115 | General document detection algorithm | ||
| 116 | |||
| 117 | ``` | ||
| 118 | from turnsole.ocr_engine import object_detector | ||
| 119 | |||
| 120 | object_list = object_detector.process(image) | ||
| 121 | ``` | ||
| 122 | |||
| 123 | #### Signature Detector | ||
| 124 | Signature, stamp, and QR-code detection algorithm | ||
| 125 | |||
| 126 | ``` | ||
| 127 | from turnsole.ocr_engine import signature_detector | ||
| 128 | |||
| 129 | signature_list = signature_detector.process(image) | ||
| 130 | ``` | ||
| 131 | |||
| 132 | #### Standard API | ||
| 133 | ``` | ||
| 134 | python api/ocr_engine_server.py | ||
| 135 | ``` | ||
| 136 |
ocr_engine/turnsole.egg-info/SOURCES.txt
0 → 100644
| 1 | LICENSE | ||
| 2 | README.md | ||
| 3 | setup.cfg | ||
| 4 | setup.py | ||
| 5 | turnsole/__init__.py | ||
| 6 | turnsole/convenience.py | ||
| 7 | turnsole/encodings.py | ||
| 8 | turnsole/model.py | ||
| 9 | turnsole/paths.py | ||
| 10 | turnsole/pdf_tools.py | ||
| 11 | turnsole.egg-info/PKG-INFO | ||
| 12 | turnsole.egg-info/SOURCES.txt | ||
| 13 | turnsole.egg-info/dependency_links.txt | ||
| 14 | turnsole.egg-info/top_level.txt | ||
| 15 | turnsole/face_utils/__init__.py | ||
| 16 | turnsole/face_utils/agedetector.py | ||
| 17 | turnsole/face_utils/facedetector.py | ||
| 18 | turnsole/nets/__init__.py | ||
| 19 | turnsole/nets/efficientnet.py | ||
| 20 | turnsole/ocr_engine/__init__.py | ||
| 21 | turnsole/ocr_engine/ADC/__init__.py | ||
| 22 | turnsole/ocr_engine/ADC/angle_detector.py | ||
| 23 | turnsole/ocr_engine/CRNN/__init__.py | ||
| 24 | turnsole/ocr_engine/CRNN/alphabets.py | ||
| 25 | turnsole/ocr_engine/CRNN/text_rec.py | ||
| 26 | turnsole/ocr_engine/DBNet/__init__.py | ||
| 27 | turnsole/ocr_engine/DBNet/text_detector.py | ||
| 28 | turnsole/ocr_engine/object_det/__init__.py | ||
| 29 | turnsole/ocr_engine/object_det/utils.py | ||
| 30 | turnsole/ocr_engine/signature_det/__init__.py | ||
| 31 | turnsole/ocr_engine/signature_det/utils.py | ||
| 32 | turnsole/ocr_engine/utils/__init__.py | ||
| 33 | turnsole/ocr_engine/utils/read_data.py | ||
| 34 | turnsole/video/__init__.py | ||
| 35 | turnsole/video/count_frames.py | ||
| 36 | turnsole/video/filevideostream.py | ||
| 37 | turnsole/video/fps.py | ||
| 38 | turnsole/video/pivideostream.py | ||
| 39 | turnsole/video/videostream.py | ||
| 40 | turnsole/video/webcamvideostream.py | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
ocr_engine/turnsole.egg-info/top_level.txt
0 → 100644
| 1 | turnsole |
ocr_engine/turnsole/__init__.py
0 → 100644
| 1 | try: | ||
| 2 | from . import ocr_engine | ||
| 3 | except Exception: | ||
| 4 | # print('[INFO] OCR engine could not be imported') | ||
| 5 | pass | ||
| 6 | from .convenience import resize | ||
| 7 | from .convenience import resize_with_pad | ||
| 8 | from .convenience import image_crop | ||
| 9 | from .encodings import bytes_to_bgr | ||
| 10 | from .encodings import base64_to_image | ||
| 11 | from .encodings import base64_encode_file | ||
| 12 | from .encodings import base64_encode_image | ||
| 13 | from .encodings import base64_decode_image | ||
| 14 | from .encodings import base64_to_bgr | ||
| 15 | from .encodings import bgr_to_base64 | ||
| 16 | from .pdf_tools import pdf_to_images | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
ocr_engine/turnsole/convenience.py
0 → 100644
| 1 | import cv2 | ||
| 2 | import numpy as np | ||
| 3 | |||
| 4 | def resize(image, width=None, height=None, inter=cv2.INTER_AREA): | ||
| 5 | # initialize the dimensions of the image to be resized and grab the image size | ||
| 6 | dim = None | ||
| 7 | (h, w) = image.shape[:2] | ||
| 8 | |||
| 9 | # if both the width and height are None, then return the original image | ||
| 10 | if width is None and height is None: | ||
| 11 | return image | ||
| 12 | |||
| 13 | # check to see if the width is None | ||
| 14 | if width is None: | ||
| 15 | # calculate the ratio of the height and construct the dimensions | ||
| 16 | r = height / float(h) | ||
| 17 | dim = (int(w * r), height) | ||
| 18 | |||
| 19 | # otherwise, the height is None | ||
| 20 | else: | ||
| 21 | # calculate the ratio of the width and construct the dimensions | ||
| 22 | r = width / float(w) | ||
| 23 | dim = (width, int(h * r)) | ||
| 24 | |||
| 25 | # resize the image | ||
| 26 | resized = cv2.resize(image, dim, interpolation=inter) | ||
| 27 | |||
| 28 | # return the resized image | ||
| 29 | return resized | ||
| 30 | |||
| 31 | def resize_with_pad(image, target_width, target_height): | ||
| 32 | """Resuzes and pads an image to a target width and height. | ||
| 33 | |||
| 34 | Resizes an image to a target width and height by keeping the aspect ratio the same | ||
| 35 | without distortion. | ||
| 36 | ratio must be less than 1.0. | ||
| 37 | width and height will pad with zeroes. | ||
| 38 | |||
| 39 | Args: | ||
| 40 | image (Array): RGB/BGR | ||
| 41 | target_width (Int): Target width. | ||
| 42 | target_height (Int): Target height. | ||
| 43 | |||
| 44 | Returns: | ||
| 45 | Array: Resized and padded image. The image paded with zeroes. | ||
| 46 | Float: Image resized ratio. The ratio must be less than 1.0. | ||
| 47 | """ | ||
| 48 | height, width, _ = image.shape | ||
| 49 | |||
| 50 | min_ratio = min(target_height/height, target_width/width) | ||
| 51 | ratio = min_ratio if min_ratio < 1.0 else 1.0 | ||
| 52 | |||
| 53 | # To shrink an image, it will generally look best with INTER_AREA interpolation. | ||
| 54 | resized = cv2.resize(image, None, fx=ratio, fy=ratio, interpolation=cv2.INTER_AREA) | ||
| 55 | h, w, _ = resized.shape | ||
| 56 | canvas = np.zeros((target_height, target_width, 3), image.dtype) | ||
| 57 | canvas[:h, :w, :] = resized | ||
| 58 | return canvas, ratio | ||
| 59 | |||
| 60 | def image_crop(image, bbox, perspective=False): | ||
| 61 | """根据 Bbox 在 image 上进行切片,如果指定 perspective 为 True 则切片方式为透视变换(可以切旋转目标) | ||
| 62 | |||
| 63 | Args: | ||
| 64 | image (array): 三通道图片,切片结果保持原图颜色通道 | ||
| 65 | bbox (array/list): 支持两点矩形框和四点旋转矩形框 | ||
| 66 | 支持以下两种格式: | ||
| 67 | 1. bbox = [xmin, ymin, xmax, ymax] | ||
| 68 | 2. bbox = [x0, y0, x1, y1, x2, y2, x3, y3] | ||
| 69 | perspective (bool, optional): 是否切出旋转目标. Defaults to False. | ||
| 70 | |||
| 71 | Returns: | ||
| 72 | array: 小切图,和原图颜色通道一致 | ||
| 73 | """ | ||
| 74 | # Crop by the axis-aligned bounding rectangle of bbox | ||
| 75 | bbox = np.array(bbox, dtype=np.int32).reshape((-1, 2)) | ||
| 76 | xmin, ymin, xmax, ymax = [min(bbox[:, 0]), | ||
| 77 | min(bbox[:, 1]), | ||
| 78 | max(bbox[:, 0]), | ||
| 79 | max(bbox[:, 1])] | ||
| 80 | xmin, ymin = max(0, xmin), max(0, ymin) | ||
| 81 | im_slice = image[ymin:ymax, xmin:xmax, :] | ||
| 82 | |||
| 83 | if perspective and bbox.shape[0] == 4: | ||
| 84 | # Get the width and height of the rotated rectangle | ||
| 85 | w, h = [int(np.linalg.norm(bbox[0] - bbox[1])), | ||
| 86 | int(np.linalg.norm(bbox[3] - bbox[0]))] | ||
| 87 | # Translate bbox into the coordinate frame of the cropped slice | ||
| 88 | bbox[:, 0] -= xmin | ||
| 89 | bbox[:, 1] -= ymin | ||
| 90 | # Perform the perspective crop | ||
| 91 | pts1 = np.float32(bbox) | ||
| 92 | pts2 = np.float32([[0, 0], [w, 0], [w, h], [0, h]]) | ||
| 93 | M = cv2.getPerspectiveTransform(pts1, pts2) | ||
| 94 | im_slice = cv2.warpPerspective(im_slice, M, (w, h)) | ||
| 95 | |||
| 96 | return im_slice |
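| | |||
| | A small usage sketch for image_crop with made-up coordinates (the four-point form cuts a deskewed patch out of a rotated box): | ||
| | ```python | ||
| | import numpy as np | ||
| | import turnsole | ||
| | |||
| | image = np.zeros((200, 200, 3), dtype=np.uint8) # dummy image | ||
| | # four-point rotated box, clockwise from the top-left corner: x0, y0, ..., x3, y3 | ||
| | rbox = [50, 40, 150, 60, 140, 110, 40, 90] | ||
| | patch = turnsole.image_crop(image, rbox, perspective=True) | ||
| | print(patch.shape) # (h, w, 3) with w ~ |p0 - p1|, h ~ |p3 - p0| | ||
| | ``` | ||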
ocr_engine/turnsole/encodings.py
0 → 100644
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | # @Author : Antonio-hi | ||
| 3 | # @Email : 9428.al@gmail.com | ||
| 4 | # @Create Date : 2021-08-09 19:08:49 | ||
| 5 | # @Last Modified : 2021-08-10 10:11:06 | ||
| 6 | # @Description : | ||
| 7 | |||
| 8 | # import the necessary packages | ||
| 9 | import numpy as np | ||
| 10 | import base64 | ||
| 11 | import json | ||
| 12 | import sys | ||
| 13 | import cv2 | ||
| 14 | import os | ||
| 15 | |||
| 16 | def base64_encode_image(a): | ||
| 17 | # return a JSON-encoded list of the base64 encoded image, image data type, and image shape | ||
| 18 | # return json.dumps([base64_encode_array(a), str(a.dtype), a.shape]) | ||
| 19 | return json.dumps([base64_encode_array(a).decode("utf-8"), str(a.dtype), | ||
| 20 | a.shape]) | ||
| 21 | |||
| 22 | def base64_decode_image(a): | ||
| 23 | # grab the array, data type, and shape from the JSON-decoded object | ||
| 24 | (a, dtype, shape) = json.loads(a) | ||
| 25 | |||
| 26 | # set the correct data type and reshape the matrix into an image | ||
| 27 | a = base64_decode_array(a, dtype).reshape(shape) | ||
| 28 | |||
| 29 | # return the loaded image | ||
| 30 | return a | ||
| 31 | |||
| 32 | def base64_encode_array(a): | ||
| 33 | # return the base64 encoded array | ||
| 34 | return base64.b64encode(a) | ||
| 35 | |||
| 36 | def base64_decode_array(a, dtype): | ||
| 37 | # decode and return the array | ||
| 38 | return np.frombuffer(base64.b64decode(a), dtype=dtype) | ||
| 39 | |||
| 40 | def base64_encode_file(image_path): | ||
| 41 | filename = os.path.basename(image_path) | ||
| 42 | # encode image file to base64 string | ||
| 43 | with open(image_path, 'rb') as f: | ||
| 44 | buffer = f.read() | ||
| 45 | # convert bytes buffer string then encode to base64 string | ||
| 46 | img64_bytes = base64.b64encode(buffer) | ||
| 47 | img64_str = img64_bytes.decode('utf-8') # bytes to str | ||
| 48 | return json.dumps({"filename" : filename, "img64": img64_str}) | ||
| 49 | |||
| 50 | def base64_to_image(img64): | ||
| 51 | image_buffer = base64_decode_array(img64, dtype=np.uint8) | ||
| 52 | # In the case of color images, the decoded images will have the channels stored in B G R order. | ||
| 53 | image = cv2.imdecode(image_buffer, cv2.IMREAD_COLOR) | ||
| 54 | return image | ||
| 55 | |||
| 56 | def bytes_to_bgr(buffer: bytes): | ||
| 57 | """Read a byte stream as a OpenCV image | ||
| 58 | |||
| 59 | Args: | ||
| 60 | buffer (TYPE): bytes of a decoded image | ||
| 61 | """ | ||
| 62 | img_array = np.frombuffer(buffer, np.uint8) | ||
| 63 | image = cv2.imdecode(img_array, cv2.IMREAD_COLOR) | ||
| 64 | return image | ||
| 65 | |||
| 66 | def base64_to_bgr(img64): | ||
| 67 | """把 base64 转换成图片 | ||
| 68 | 单通道的灰度图或四通道的透明图都将自动转换成三通道的 BGR 图 | ||
| 69 | |||
| 70 | Args: | ||
| 71 | img64 (TYPE): Description | ||
| 72 | |||
| 73 | Returns: | ||
| 74 | TYPE: image is a 3-D uint8 Tensor of shape [height, width, channels] where channels is BGR | ||
| 75 | """ | ||
| 76 | encoded_image = base64.b64decode(img64) | ||
| 77 | img_array = np.frombuffer(encoded_image, np.uint8) | ||
| 78 | image = cv2.imdecode(img_array, cv2.IMREAD_COLOR) | ||
| 79 | return image | ||
| 80 | |||
| 81 | def bgr_to_base64(image): | ||
| 82 | """ 把图片转换成 base64 格式,过程中把图片以 JPEG 格式进行了压缩,通常这会导致图像质量变差 | ||
| 83 | |||
| 84 | Args: | ||
| 85 | image (TYPE): image is a 3-D uint8 or uint16 Tensor of shape [height, width, channels] where channels is BGR | ||
| 86 | |||
| 87 | Returns: | ||
| 88 | TYPE: the image encoded as base64 | ||
| 89 | """ | ||
| 90 | retval, encoded_image = cv2.imencode('.jpg', image) # Encodes an image(BGR) into a memory buffer. | ||
| 91 | img64 = base64.b64encode(encoded_image) | ||
| 92 | return img64.decode('utf-8') | ||
| 93 | |||
| 94 | |||
| 95 | if __name__ == '__main__': | ||
| 96 | |||
| 97 | image_path = '/home/lk/Repository/Project/turnsole/demos/images/sunflower.jpg' | ||
| 98 | |||
| 99 | # 1) Convert an image file to a base64-encoded string (in principle any file type works) | ||
| 100 | json_str = base64_encode_file(image_path) | ||
| 101 | |||
| 102 | img64_dict = json.loads(json_str) | ||
| 103 | |||
| 104 | suffix = os.path.splitext(img64_dict['filename'])[-1].lower() | ||
| 105 | if suffix not in ['.jpg', '.jpeg', '.png', '.bmp']: | ||
| 106 | print(f'[INFO] Files with suffix {suffix} are not supported yet!') | ||
| 107 | |||
| 108 | # 2) Convert the base64-encoded string back to an image | ||
| 109 | image = base64_to_image(img64_dict['img64']) | ||
| 110 | |||
| 111 | inputs = image/255. | ||
| 112 | |||
| 113 | # 3) Home-grown: encode an array to base64 and back without going through an image codec, preserving the array's dtype | ||
| 114 | base64_encode_json_string = base64_encode_image(inputs) | ||
| 115 | |||
| 116 | inputs = base64_decode_image(base64_encode_json_string) | ||
| 117 | |||
| 118 | print(inputs) | ||
| 119 | |||
| 120 | # 3. Prefixing a string literal with b | ||
| 121 | # e.g. response = b'<h1>Hello World!</h1>' # b'...' marks a bytes object | ||
| 122 | | ||
| 123 | # Meaning: | ||
| 124 | | ||
| 125 | # the b"" prefix indicates that the string that follows is of type bytes. | ||
| 126 | | ||
| 127 | # Why it matters: | ||
| 128 | | ||
| 129 | # in network programming, servers and browsers only accept bytes data. | ||
| 130 | | ||
| 131 | # e.g. the argument of send and the return value of recv are both bytes | ||
| 132 | | ||
| 133 | # Appendix: | ||
| 134 | | ||
| 135 | # in Python 3, bytes and str are converted to each other via | ||
| 136 | # str.encode('utf-8') | ||
| 137 | # bytes.decode('utf-8') | ||
ocr_engine/turnsole/face_utils/__init__.py
0 → 100644
File mode changed
ocr_engine/turnsole/face_utils/agedetector.py
0 → 100644
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | # @Author : lk | ||
| 3 | # @Email : 9428.al@gmail.com | ||
| 4 | # @Create Date : 2021-08-11 17:10:16 | ||
| 5 | # @Last Modified : 2021-08-12 16:14:53 | ||
| 6 | # @Description : | ||
| 7 | |||
| 8 | import os | ||
| 9 | import tensorflow as tf | ||
| 10 | |||
| 11 | class AgeDetector: | ||
| 12 | def __init__(self, model_path): | ||
| 13 | self.age_map = { | ||
| 14 | 0: '0-2', | ||
| 15 | 1: '4-6', | ||
| 16 | 2: '8-13', | ||
| 17 | 3: '15-20', | ||
| 18 | 4: '25-32', | ||
| 19 | 5: '38-43', | ||
| 20 | 6: '48-53', | ||
| 21 | 7: '60+' | ||
| 22 | } | ||
| 23 | |||
| 24 | self.model = tf.keras.models.load_model(filepath=model_path, | ||
| 25 | compile=False) | ||
| 26 | self.inference_model = self.build_inference_model() | ||
| 27 | |||
| 28 | def build_inference_model(self): | ||
| 29 | image = self.model.input | ||
| 30 | x = tf.keras.applications.mobilenet_v2.preprocess_input(image) | ||
| 31 | predictions = self.model(x, training=False) | ||
| 32 | inference_model = tf.keras.Model(inputs=image, outputs=predictions) | ||
| 33 | return inference_model | ||
| 34 | |||
| 35 | def predict_batch(self, images): | ||
| 36 | # Takes a list of face images; the list must not be empty | ||
| 37 | images = tf.stack([tf.image.resize(image, [96, 96]) for image in images], axis=0) | ||
| 38 | preds = self.inference_model.predict(images) | ||
| 39 | indexes = tf.argmax(preds, axis=-1) | ||
| 40 | classes = [self.age_map[index.numpy()] for index in indexes] | ||
| 41 | return classes | ||
| 42 | |||
| 43 | if __name__ == '__main__': | ||
| 44 | |||
| 45 | import cv2 | ||
| 46 | from turnsole import paths | ||
| 47 | |||
| 48 | age_det = AgeDetector(model_path='./ckpt/age_detector.h5') | ||
| 49 | |||
| 50 | data_dir = '/home/lk/Project/Face_Age_Gender/data/Emotion/emotion/010003_female_yellow_22' | ||
| 51 | |||
| 52 | for image_path in paths.list_images(data_dir): | ||
| 53 | image = cv2.imread(image_path) | ||
| 54 | classes = age_det.predict_batch([image]) | ||
| 55 | |||
| 56 | print(classes) | ||
| 57 |
ocr_engine/turnsole/face_utils/facedetector.py
0 → 100644
This diff is collapsed.
ocr_engine/turnsole/model.py
0 → 100644
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | # @Author : Lyu Kui | ||
| 3 | # @Email : 9428.al@gmail.com | ||
| 4 | # @Created Date : 2021-02-24 13:58:46 | ||
| 5 | # @Last Modified : 2021-03-05 18:14:17 | ||
| 6 | # @Description : | ||
| 7 | |||
| 8 | import tensorflow as tf | ||
| 9 | |||
| 10 | from .nets.efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3 | ||
| 11 | from .nets.efficientnet import EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7 | ||
| 12 | |||
| 13 | def load_backbone(phi, input_tensor, weights='imagenet'): | ||
| 14 | if phi == 0: | ||
| 15 | model = EfficientNetB0(include_top=False, | ||
| 16 | weights=weights, | ||
| 17 | input_tensor=input_tensor) | ||
| 18 | # Extract features from these layers | ||
| 19 | layer_names = [ | ||
| 20 | 'block2b_add', # 1/4 | ||
| 21 | 'block3b_add', # 1/8 | ||
| 22 | 'block5c_add', # 1/16 | ||
| 23 | 'block7a_project_bn', # 1/32 | ||
| 24 | ] | ||
| 25 | elif phi == 1: | ||
| 26 | model = EfficientNetB1(include_top=False, | ||
| 27 | weights=weights, | ||
| 28 | input_tensor=input_tensor) | ||
| 29 | layer_names = [ | ||
| 30 | 'block2c_add', # 1/4 | ||
| 31 | 'block3c_add', # 1/8 | ||
| 32 | 'block5d_add', # 1/16 | ||
| 33 | 'block7b_add', # 1/32 | ||
| 34 | ] | ||
| 35 | elif phi == 2: | ||
| 36 | model = EfficientNetB2(include_top=False, | ||
| 37 | weights=weights, | ||
| 38 | input_tensor=input_tensor) | ||
| 39 | layer_names = [ | ||
| 40 | 'block2c_add', # 1/4 | ||
| 41 | 'block3c_add', # 1/8 | ||
| 42 | 'block5d_add', # 1/16 | ||
| 43 | 'block7b_add', # 1/32 | ||
| 44 | ] | ||
| 45 | elif phi == 3: | ||
| 46 | model = EfficientNetB3(include_top=False, | ||
| 47 | weights=weights, | ||
| 48 | input_tensor=input_tensor) | ||
| 49 | layer_names = [ | ||
| 50 | 'block2c_add', # 1/4 | ||
| 51 | 'block3c_add', # 1/8 | ||
| 52 | 'block5e_add', # 1/16 | ||
| 53 | 'block7b_add', # 1/32 | ||
| 54 | ] | ||
| 55 | elif phi == 4: | ||
| 56 | model = EfficientNetB4(include_top=False, | ||
| 57 | weights=weights, | ||
| 58 | input_tensor=input_tensor) | ||
| 59 | layer_names = [ | ||
| 60 | 'block2c_add', # 1/4 | ||
| 61 | 'block3d_add', # 1/8 | ||
| 62 | 'block5f_add', # 1/16 | ||
| 63 | 'block7b_add', # 1/32 | ||
| 64 | ] | ||
| 65 | elif phi == 5: | ||
| 66 | model = EfficientNetB5(include_top=False, | ||
| 67 | weights=weights, | ||
| 68 | input_tensor=input_tensor) | ||
| 69 | layer_names = [ | ||
| 70 | 'block2e_add', # 1/4 | ||
| 71 | 'block3e_add', # 1/8 | ||
| 72 | 'block5g_add', # 1/16 | ||
| 73 | 'block7c_add', # 1/32 | ||
| 74 | ] | ||
| 75 | elif phi == 6: | ||
| 76 | model = EfficientNetB6(include_top=False, | ||
| 77 | weights=weights, | ||
| 78 | input_tensor=input_tensor) | ||
| 79 | layer_names = [ | ||
| 80 | 'block2f_add', # 1/4 | ||
| 81 | 'block3f_add', # 1/8 | ||
| 82 | 'block5h_add', # 1/16 | ||
| 83 | 'block7c_add', # 1/32 | ||
| 84 | ] | ||
| 85 | elif phi == 7: | ||
| 86 | model = EfficientNetB7(include_top=False, | ||
| 87 | weights=weights, | ||
| 88 | input_tensor=input_tensor) | ||
| 89 | layer_names = [ | ||
| 90 | 'block2g_add', # 1/4 | ||
| 91 | 'block3g_add', # 1/8 | ||
| 92 | 'block5j_add', # 1/16 | ||
| 93 | 'block7d_add', # 1/32 | ||
| 94 | ] | ||
| 95 | |||
| 96 | skips = [model.get_layer(name).output for name in layer_names] | ||
| 97 | return model, skips | ||
| 98 | |||
| 99 | def EasyDet(phi=0, input_size=(None, None, 3), weights='imagenet'): | ||
| 100 | image_input = tf.keras.layers.Input(shape=input_size) | ||
| 101 | |||
| 102 | backbone, skips = load_backbone(phi=phi, input_tensor=image_input, weights=weights) | ||
| 103 | C2, C3, C4, C5 = skips | ||
| 104 | |||
| 105 | in2 = tf.keras.layers.Conv2D(256, (1, 1), padding='same', kernel_initializer='he_normal', name='in2')(C2) | ||
| 106 | in3 = tf.keras.layers.Conv2D(256, (1, 1), padding='same', kernel_initializer='he_normal', name='in3')(C3) | ||
| 107 | in4 = tf.keras.layers.Conv2D(256, (1, 1), padding='same', kernel_initializer='he_normal', name='in4')(C4) | ||
| 108 | in5 = tf.keras.layers.Conv2D(256, (1, 1), padding='same', kernel_initializer='he_normal', name='in5')(C5) | ||
| 109 | |||
| 110 | # 1 / 32 * 8 = 1 / 4 | ||
| 111 | P5 = tf.keras.layers.UpSampling2D(size=(8, 8))( | ||
| 112 | tf.keras.layers.Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal')(in5)) | ||
| 113 | # 1 / 16 * 4 = 1 / 4 | ||
| 114 | out4 = tf.keras.layers.Add()([in4, tf.keras.layers.UpSampling2D(size=(2, 2))(in5)]) | ||
| 115 | P4 = tf.keras.layers.UpSampling2D(size=(4, 4))( | ||
| 116 | tf.keras.layers.Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal')(out4)) | ||
| 117 | # 1 / 8 * 2 = 1 / 4 | ||
| 118 | out3 = tf.keras.layers.Add()([in3, tf.keras.layers.UpSampling2D(size=(2, 2))(out4)]) | ||
| 119 | P3 = tf.keras.layers.UpSampling2D(size=(2, 2))( | ||
| 120 | tf.keras.layers.Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal')(out3)) | ||
| 121 | # 1 / 4 | ||
| 122 | P2 = tf.keras.layers.Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal')( | ||
| 123 | tf.keras.layers.Add()([in2, tf.keras.layers.UpSampling2D(size=(2, 2))(out3)])) | ||
| 124 | # (b, 1/4, 1/4, 256) | ||
| 125 | fuse = tf.keras.layers.Concatenate()([P2, P3, P4, P5]) | ||
| 126 | |||
| 127 | model = tf.keras.models.Model(inputs=image_input, outputs=fuse) | ||
| 128 | return model | ||
| 129 | |||
| 130 | |||
| 131 | if __name__ == '__main__': | ||
| 132 | model = EasyDet(phi=0) | ||
| 133 | model.summary() | ||
| 134 | |||
| 135 | import time | ||
| 136 | import numpy as np | ||
| 137 | |||
| 138 | x = np.random.random_sample((1, 640, 640, 3)) | ||
| 139 | # warm up | ||
| 140 | output = model.predict(x) | ||
| 141 | |||
| 142 | print('\n[INFO] Test start') | ||
| 143 | time_start = time.time() | ||
| 144 | for i in range(1000): | ||
| 145 | output = model.predict(x) | ||
| 146 | |||
| 147 | time_end = time.time() | ||
| 148 | print('[INFO] Time used: {:.2f} ms'.format((time_end - time_start)*1000/(i+1))) |
ocr_engine/turnsole/nets/__init__.py
0 → 100644
File mode changed
ocr_engine/turnsole/nets/efficientnet.py
0 → 100644
This diff is collapsed.
ocr_engine/turnsole/ocr_engine/ADC/__init__.py
0 → 100644
| 1 | from . import angle_detector | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
ocr_engine/turnsole/ocr_engine/ADC/angle_detector.py
0 → 100644
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | # @Author : lk | ||
| 3 | # @Email : 9428.al@gmail.com | ||
| 4 | # @Created Date : 2019-09-03 15:40:54 | ||
| 5 | # @Last Modified : 2022-07-18 16:10:36 | ||
| 6 | # @Description : | ||
| 7 | |||
| 8 | import os | ||
| 9 | import cv2 | ||
| 10 | import time | ||
| 11 | import numpy as np | ||
| 12 | # import tensorflow as tf | ||
| 13 | |||
| 14 | # import grpc | ||
| 15 | # from tensorflow_serving.apis import predict_pb2 | ||
| 16 | # from tensorflow_serving.apis import prediction_service_pb2_grpc | ||
| 17 | |||
| 18 | import tritonclient.grpc as grpcclient | ||
| 19 | |||
| 20 | |||
| 21 | def resize(image, width=None, height=None, inter=cv2.INTER_AREA): | ||
| 22 | ''' | ||
| 23 | Resize the input image to the given dimensions while keeping its aspect ratio | ||
| 24 | ''' | ||
| 25 | dim = None | ||
| 26 | (h, w) = image.shape[:2] | ||
| 27 | |||
| 28 | # if both the width and height are None, then return the original image | ||
| 29 | if width is None and height is None: | ||
| 30 | return image | ||
| 31 | |||
| 32 | # check to see if the width is None | ||
| 33 | if width is None: | ||
| 34 | # calculate the ratio of the height and construct the dimensions | ||
| 35 | r = height / float(h) | ||
| 36 | dim = (int(w * r), height) | ||
| 37 | |||
| 38 | # otherwise, the height is None | ||
| 39 | else: | ||
| 40 | # calculate the ratio of the width and construct the dimensions | ||
| 41 | r = width / float(w) | ||
| 42 | dim = (width, int(h * r)) | ||
| 43 | |||
| 44 | # resize the image | ||
| 45 | resized = cv2.resize(image, dim, interpolation=inter) | ||
| 46 | |||
| 47 | return resized | ||
| 48 | |||
| 49 | def predict(image): | ||
| 50 | |||
| 51 | ROTATE = [0, 90, 180, 270] | ||
| 52 | |||
| 53 | # pre-process the image for classification | ||
| 54 | # Test 1: resize directly to the target size | ||
| 55 | # image = cv2.resize(image, (512, 512)) | ||
| 56 | |||
| 57 | # Test 2: resize the short side to the target size and scale the long side proportionally | ||
| 58 | short_side = 768 | ||
| 59 | if min(image.shape[:2]) > short_side: | ||
| 60 | image = resize(image, width=short_side) if image.shape[0] > image.shape[1] else resize(image, height=short_side) | ||
| 61 | |||
| 62 | # Test 3: resize with padding | ||
| 63 | # image = resize_image_with_pad(image, 1024, 1024) | ||
| 64 | |||
| 65 | # Test 4: use the original image directly | ||
| 66 | # image = image | ||
| 67 | |||
| 68 | image = np.array(image, dtype="float32") | ||
| 69 | image = 2 * (image / 255.0) - 1 # Let data input to be normalized to the [-1,1] range | ||
| 70 | input_data = np.expand_dims(image, 0) | ||
| 71 | |||
| 72 | # options = [('grpc.max_send_message_length', 1000 * 1024 * 1024), | ||
| 73 | # ('grpc.max_receive_message_length', 1000 * 1024 * 1024)] | ||
| 74 | # channel = grpc.insecure_channel('localhost:8500', options=options) | ||
| 75 | # stub = prediction_service_pb2_grpc.PredictionServiceStub(channel) | ||
| 76 | |||
| 77 | # request = predict_pb2.PredictRequest() | ||
| 78 | # request.model_spec.name = 'adc_model' | ||
| 79 | # request.model_spec.signature_name = 'serving_default' | ||
| 80 | # request.inputs['input_1'].CopyFrom(tf.make_tensor_proto(inputs)) | ||
| 81 | |||
| 82 | # result = stub.Predict(request, 100.0) # 100 secs timeout | ||
| 83 | |||
| 84 | # preds = tf.make_ndarray(result.outputs['dense']) | ||
| 85 | |||
| 86 | triton_client = grpcclient.InferenceServerClient("localhost:8001") | ||
| 87 | |||
| 88 | # Initialize the data | ||
| 89 | inputs = [grpcclient.InferInput('input_1', input_data.shape, "FP32")] # InferInput describes an input tensor of the inference request | ||
| 90 | inputs[0].set_data_from_numpy(input_data) # Bind the tensor data of the given numpy array to this input | ||
| 91 | outputs = [grpcclient.InferRequestedOutput("dense")] | ||
| 92 | |||
| 93 | # Inference | ||
| 94 | results = triton_client.infer( | ||
| 95 | model_name="adc_model", | ||
| 96 | inputs=inputs, | ||
| 97 | outputs=outputs | ||
| 98 | ) | ||
| 99 | # Get the output arrays from the results | ||
| 100 | preds = results.as_numpy("dense") | ||
| 101 | |||
| 102 | index = np.argmax(preds, axis=-1)[0] | ||
| 103 | |||
| 104 | return index | ||
| 105 | # return ROTATE[index] | ||
| 106 | |||
| 107 | def DegreeTrans(theta): | ||
| 108 | ''' | ||
| 109 | Convert radians to angles | ||
| 110 | ''' | ||
| 111 | res = theta / np.pi * 180 | ||
| 112 | return res | ||
| 113 | |||
| 114 | def rotateImage(src, degree): | ||
| 115 | ''' | ||
| 116 | Calculate the rotation matrix and rotate the image | ||
| 117 | param src:image after rot90 | ||
| 118 | param degree:the Hough degree | ||
| 119 | ''' | ||
| 120 | h, w = src.shape[:2] | ||
| 121 | RotateMatrix = cv2.getRotationMatrix2D((w/2.0, h/2.0), degree, 1) | ||
| 122 | # affine transformation, background color fills white | ||
| 123 | rotate = cv2.warpAffine(src, RotateMatrix, (w, h), borderValue=(255, 255, 255)) | ||
| 124 | return rotate | ||
| 125 | |||
| 126 | def CalcDegree(srcImage): | ||
| 127 | ''' | ||
| 128 | Calculating angles by Hough transform | ||
| 129 | param srcImage:image after rot90 | ||
| 130 | ''' | ||
| 131 | midImage = cv2.cvtColor(srcImage, cv2.COLOR_BGR2GRAY) | ||
| 132 | dstImage = cv2.Canny(midImage, 100, 300, 3) | ||
| 133 | lineimage = srcImage.copy() | ||
| 134 | |||
| 135 | # Detect straight lines via the Hough transform | ||
| 136 | # The 4th argument (th) is the accumulator threshold; the larger it is, the more precise the detection | ||
| 137 | th = 500 | ||
| 138 | while True: | ||
| 139 | if th > 0: | ||
| 140 | lines = cv2.HoughLines(dstImage, 1, np.pi/180, th) | ||
| 141 | else: | ||
| 142 | lines = None | ||
| 143 | break | ||
| 144 | if lines is not None: | ||
| 145 | if len(lines) > 10: | ||
| 146 | break | ||
| 147 | else: | ||
| 148 | th -= 50 | ||
| 149 | # print('threshold:', th) | ||
| 150 | else: | ||
| 151 | th -= 100 | ||
| 152 | # print('threshold:', th) | ||
| 153 | continue | ||
| 154 | |||
| 155 | sum_theta = 0 | ||
| 156 | num_theta = 0 | ||
| 157 | if lines is not None: | ||
| 158 | for i in range(len(lines)): | ||
| 159 | for rho, theta in lines[i]: | ||
| 160 | # control the angle of line between -30 to +30 | ||
| 161 | if theta > 1 and theta < 2.1: | ||
| 162 | sum_theta += theta | ||
| 163 | num_theta += 1 | ||
| 164 | # Average all angles | ||
| 165 | if num_theta == 0: | ||
| 166 | average = np.pi/2 | ||
| 167 | else: | ||
| 168 | average = sum_theta / num_theta | ||
| 169 | |||
| 170 | return DegreeTrans(average) - 90 | ||
| 171 | |||
| 172 | def ADC(image, fine_degree=False): | ||
| 173 | ''' | ||
| 174 | return param rotate: Corrected image | ||
| 175 | return param angle_degree: rotation offset in degrees | ||
| 176 | ''' | ||
| 177 | |||
| 178 | # Return a wide angle index | ||
| 179 | img = np.copy(image) | ||
| 180 | angle_index = predict(img) | ||
| 181 | img_rot = np.rot90(img, -angle_index) | ||
| 182 | |||
| 183 | # if fine_degree then the image will be corrected more accurately based on character line features. | ||
| 184 | if fine_degree: | ||
| 185 | degree = CalcDegree(img_rot) | ||
| 186 | angle_degree = (angle_index * 90 - degree) % 360 | ||
| 187 | rotate = rotateImage(img_rot, degree) | ||
| 188 | return rotate, angle_degree | ||
| 189 | |||
| 190 | return img_rot, int(angle_index*90) |
This diff is collapsed.
ocr_engine/turnsole/ocr_engine/CRNN/text_rec.py
0 → 100644
| 1 | import cv2 | ||
| 2 | import time | ||
| 3 | import numpy as np | ||
| 4 | from .alphabets import alphabet | ||
| 5 | import tritonclient.grpc as grpcclient | ||
| 6 | |||
| 7 | |||
| 8 | def sort_poly(p): | ||
| 9 | # Find the minimum coordinate using (Xi+Yi) | ||
| 10 | min_axis = np.argmin(np.sum(p, axis=1)) | ||
| 11 | # Sort the box coordinates | ||
| 12 | p = p[[min_axis, (min_axis + 1) % 4, (min_axis + 2) % 4, (min_axis + 3) % 4]] | ||
| 13 | if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]): | ||
| 14 | return p | ||
| 15 | else: | ||
| 16 | return p[[0, 3, 2, 1]] | ||
| 17 | |||
| 18 | def client_init(url="localhost:8001", | ||
| 19 | ssl=False, private_key=None, root_certificates=None, certificate_chain=None, | ||
| 20 | verbose=False): | ||
| 21 | triton_client = grpcclient.InferenceServerClient( | ||
| 22 | url=url, | ||
| 23 | verbose=verbose, | ||
| 24 | ssl=ssl, | ||
| 25 | root_certificates=root_certificates, | ||
| 26 | private_key=private_key, | ||
| 27 | certificate_chain=certificate_chain) | ||
| 28 | return triton_client | ||
| 29 | |||
| 30 | class textRecServer: | ||
| 31 | """_summary_ | ||
| 32 | """ | ||
| 33 | def __init__(self): | ||
| 34 | super().__init__() | ||
| 35 | self.charactersS = ' ' + alphabet | ||
| 36 | self.batchsize = 8 | ||
| 37 | |||
| 38 | self.input_name = 'INPUT__0' | ||
| 39 | self.output_name = 'OUTPUT__0' | ||
| 40 | self.model_name = 'text_rec_torch' | ||
| 41 | self.np_type = np.float32 | ||
| 42 | self.quant_type = "FP32" | ||
| 43 | self.compression_algorithm = None | ||
| 44 | self.outputs = [] | ||
| 45 | self.outputs.append(grpcclient.InferRequestedOutput(self.output_name)) | ||
| 46 | |||
| 47 | def preprocess_one_image(self, image): | ||
| 48 | _, w, _ = image.shape | ||
| 49 | image = self._transform(image, w) | ||
| 50 | return image | ||
| 51 | |||
| 52 | def predict_batch(self, im, boxes): | ||
| 53 | """Summary | ||
| 54 | |||
| 55 | Args: | ||
| 56 | im (TYPE): RGB | ||
| 57 | boxes (TYPE): Description | ||
| 58 | |||
| 59 | Returns: | ||
| 60 | TYPE: Description | ||
| 61 | """ | ||
| 62 | |||
| 63 | triton_client = client_init("localhost:8001") | ||
| 64 | count_boxes = len(boxes) | ||
| 65 | boxes = sorted(boxes, | ||
| 66 | key=lambda box: int(32.0 * (np.linalg.norm(box[0] - box[1])) / (np.linalg.norm(box[3] - box[0]))), | ||
| 67 | reverse=True) | ||
| 68 | |||
| 69 | results = {} | ||
| 70 | labels = [] | ||
| 71 | rectime = 0.0 | ||
| 72 | if len(boxes) != 0: | ||
| 73 | for i in range(len(boxes) // self.batchsize + int(len(boxes) % self.batchsize != 0)): | ||
| 74 | box = boxes[min(len(boxes)-1, i * self.batchsize)] | ||
| 75 | w, h = [int(np.linalg.norm(box[0] - box[1])), int(np.linalg.norm(box[3] - box[0]))] | ||
| 76 | width = max(32, min(int(32.0 * w / h), 960)) | ||
| 77 | if width < 32: | ||
| 78 | continue | ||
| 79 | slices = [] | ||
| 80 | for index, box in enumerate(boxes[i * self.batchsize:(i + 1) * self.batchsize]): | ||
| 81 | _box = [n for a in box for n in a] | ||
| 82 | if i * self.batchsize + index < count_boxes: | ||
| 83 | results[i * self.batchsize + index] = [list(map(int, _box))] | ||
| 84 | w, h = [int(np.linalg.norm(box[0] - box[1])), int(np.linalg.norm(box[3] - box[0]))] | ||
| 85 | pts1 = np.float32(box) | ||
| 86 | pts2 = np.float32([[0, 0], [w, 0], [w, h], [0, h]]) | ||
| 87 | |||
| 88 | # Pre-processing optimisation: crop the bounding rectangle first, then warp only the slice | ||
| 89 | xmin, ymin, _w, _h = cv2.boundingRect(pts1) | ||
| 90 | xmax, ymax = xmin+_w, ymin+_h | ||
| 91 | xmin, ymin = max(0, xmin), max(0, ymin) | ||
| 92 | im_slice = im[int(ymin):int(ymax), int(xmin):int(xmax), :] | ||
| 93 | pts1[:, 0] -= xmin | ||
| 94 | pts1[:, 1] -= ymin | ||
| 95 | | ||
| 96 | M = cv2.getPerspectiveTransform(pts1, pts2) | ||
| 97 | im_crop = cv2.warpPerspective(im_slice, M, (w, h)) | ||
| 98 | im_crop = self._transform(im_crop, width) | ||
| 99 | slices.append(im_crop) | ||
| 100 | start_rec = time.time() | ||
| 101 | slices = self.np_type(slices) | ||
| 102 | slices = slices.transpose(0, 3, 1, 2) | ||
| 103 | slices = slices/127.5-1. | ||
| 104 | inputs = [] | ||
| 105 | inputs.append(grpcclient.InferInput(self.input_name, list(slices.shape), self.quant_type)) | ||
| 106 | inputs[0].set_data_from_numpy(slices) | ||
| 107 | |||
| 108 | # inference | ||
| 109 | preds = triton_client.infer( | ||
| 110 | model_name=self.model_name, | ||
| 111 | inputs=inputs, | ||
| 112 | outputs=self.outputs, | ||
| 113 | compression_algorithm=self.compression_algorithm | ||
| 114 | ) | ||
| 115 | preds = preds.as_numpy(self.output_name).copy() | ||
| 116 | preds = preds.transpose(1, 0) | ||
| 117 | tmp_labels = self.decode(preds) | ||
| 118 | rectime += (time.time() - start_rec) | ||
| 119 | labels.extend(tmp_labels) | ||
| 120 | |||
| 121 | for index, label in enumerate(labels[:count_boxes]): | ||
| 122 | label = label.replace(' ', '').replace('¥', '¥')  # strip spaces; normalize fullwidth yen | ||
| 123 | if label == '': | ||
| 124 | del results[index] | ||
| 125 | continue | ||
| 126 | results[index].append(label) | ||
| 127 | # re-order results top-to-bottom | ||
| 128 | results = list(results.values()) | ||
| 129 | results = sorted(results, key=lambda x: x[0][1], reverse=False)  # sort by y0, ascending | ||
| 130 | keys = [str(i) for i in range(len(results))] | ||
| 131 | results = dict(zip(keys, results)) | ||
| 132 | else: | ||
| 133 | results = dict() | ||
| 134 | rectime = -1 | ||
| 135 | return results, rectime | ||
| 136 | |||
| 137 | def decode(self, preds): | ||
| 138 | res = [] | ||
| 139 | for t in preds: | ||
| 140 | length = len(t) | ||
| 141 | char_list = [] | ||
| 142 | for i in range(length): | ||
| 143 | if t[i] != 0 and (not (i > 0 and t[i-1] == t[i])): | ||
| 144 | char_list.append(self.charactersS[t[i]]) | ||
| 145 | res.append(u''.join(char_list)) | ||
| 146 | return res | ||
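`decode` is a greedy CTC decoder: index 0 is the blank class, and consecutive repeats of a non-blank index are collapsed before lookup in `charactersS`. A standalone sketch of the same collapse rule, with a hypothetical three-letter alphabet:

```python
def ctc_greedy_decode(indices, characters):
    # Drop blanks (index 0) and merge consecutive repeats, mirroring decode()
    return ''.join(characters[i] for k, i in enumerate(indices)
                   if i != 0 and not (k > 0 and indices[k - 1] == i))

assert ctc_greedy_decode([1, 1, 0, 2, 2, 0, 1], ' abc') == 'aba'
```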
| 147 | |||
| 148 | def _transform(self, im, width): | ||
| 149 | height = 32 | ||
| 150 | |||
| 151 | ori_h, ori_w = im.shape[:2] | ||
| 152 | ratio1 = width * 1.0 / ori_w | ||
| 153 | ratio2 = height * 1.0 / ori_h | ||
| 154 | if ratio1 < ratio2: | ||
| 155 | ratio = ratio1 | ||
| 156 | else: | ||
| 157 | ratio = ratio2 | ||
| 158 | new_w, new_h = int(ori_w * ratio), int(ori_h * ratio) | ||
| 159 | if new_w < 4: | ||
| 160 | new_w = 4 | ||
| 161 | im = cv2.resize(im, (new_w, new_h)) | ||
| 162 | img = np.ones((height, width, 3), dtype=np.uint8) * 230  # light-gray padding canvas | ||
| 163 | img[:im.shape[0], :im.shape[1], :] = im | ||
| 164 | return img |
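End to end, the recognizer is designed to consume DBNet boxes. A usage sketch, assuming the package exposes `text_detector` and `text_recognizer` instances and that the corresponding Triton models are being served on localhost:8001:

```python
import cv2
from turnsole.ocr_engine import text_detector, text_recognizer  # assumed exports

image = cv2.imread('sample.jpg')[..., ::-1]  # BGR -> RGB, as predict_batch expects
boxes = text_detector.predict(image)
results, rectime = text_recognizer.predict_batch(image, boxes)
for key, (quad, label) in results.items():
    print(key, label, quad)
```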
| 1 | from . import text_detector | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | # @Author : Lyu Kui | ||
| 3 | # @Email : 9428.al@gmail.com | ||
| 4 | # @Create Date : 2022-06-01 19:00:18 | ||
| 5 | # @Last Modified : 2022-07-15 11:41:25 | ||
| 6 | # @Description : | ||
| 7 | |||
| 8 | import os | ||
| 9 | import cv2 | ||
| 10 | import time | ||
| 11 | import pyclipper | ||
| 12 | import numpy as np | ||
| 13 | # import tensorflow as tf | ||
| 14 | from shapely.geometry import Polygon | ||
| 15 | |||
| 16 | # import grpc | ||
| 17 | # from tensorflow_serving.apis import predict_pb2 | ||
| 18 | # from tensorflow_serving.apis import prediction_service_pb2_grpc | ||
| 19 | |||
| 20 | import tritonclient.grpc as grpcclient | ||
| 21 | |||
| 22 | |||
| 23 | def resize_with_padding(src, limit_max=1024): | ||
| 24 | '''Limit the longer side to at most limit_max, scale the shorter side proportionally, and zero-pad to a square''' | ||
| 25 | img = src.copy() | ||
| 26 | |||
| 27 | h, w, _ = img.shape | ||
| 28 | max_side = max(h, w) | ||
| 29 | ratio = limit_max / max_side if max_side > limit_max else 1 | ||
| 30 | h, w = int(h * ratio), int(w * ratio) | ||
| 31 | proc = cv2.resize(img, (w, h)) | ||
| 32 | |||
| 33 | canvas = np.zeros((limit_max, limit_max, 3), dtype=np.float32) | ||
| 34 | canvas[0:h, 0:w, :] = proc | ||
| 35 | return canvas, ratio | ||
| 36 | |||
| 37 | def rectangle_boxes_zoom(boxes, offset=1): | ||
| 38 | '''Expand the rectangle boxes outward by `offset` pixels on each side | ||
| 39 | Input: | ||
| 40 | boxes: with shape (-1, 4, 2) | ||
| 41 | offset: how many pixels to expand by; values under 5 are recommended | ||
| 42 | Output: | ||
| 43 | boxes: expanded | ||
| 44 | ''' | ||
| 45 | boxes = np.array(boxes) | ||
| 46 | boxes += [[[-offset,-offset], [offset,-offset], [offset,offset], [-offset,offset]]] | ||
| 47 | return boxes | ||
| 48 | |||
| 49 | def polygons_from_probmap(preds, ratio): | ||
| 50 | # binarize the probability map | ||
| 51 | prob_map_pred = np.array(preds, dtype=np.uint8)[0, :, :, 0] | ||
| 52 | # inputs: binary image, contour retrieval mode, contour approximation method | ||
| 53 | # outputs: contours, hierarchy | ||
| 54 | contours, hierarchy = cv2.findContours(prob_map_pred, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) | ||
| 55 | |||
| 56 | boxes = [] | ||
| 57 | for contour in contours: | ||
| 58 | if len(contour) < 4: | ||
| 59 | continue | ||
| 60 | |||
| 61 | # Vatti clipping | ||
| 62 | polygon = Polygon(np.array(contour).reshape((-1, 2))).buffer(0) | ||
| 63 | polygon = polygon.convex_hull if polygon.geom_type == 'MultiPolygon' else polygon  # Note: intentional, not a bug | ||
| 64 | |||
| 65 | if polygon.area < 10: | ||
| 66 | continue | ||
| 67 | |||
| 68 | distance = polygon.area * 1.5 / polygon.length | ||
| 69 | offset = pyclipper.PyclipperOffset() | ||
| 70 | offset.AddPath(list(polygon.exterior.coords), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) | ||
| 71 | expanded = np.array(offset.Execute(distance)[0])  # Note: intentional, not a bug | ||
| 72 | |||
| 73 | # Convert polygon to rectangle | ||
| 74 | rect = cv2.minAreaRect(expanded) | ||
| 75 | box = cv2.boxPoints(rect) | ||
| 76 | # make clock-wise order | ||
| 77 | box = np.roll(box, 4-box.sum(axis=1).argmin(), 0) | ||
| 78 | box = np.array(box/ratio, dtype=np.int32) | ||
| 79 | boxes.append(box) | ||
| 80 | |||
| 81 | return boxes | ||
| 82 | |||
| 83 | def predict(image): | ||
| 84 | |||
| 85 | image_resized, ratio = resize_with_padding(image, limit_max=1280) | ||
| 86 | input_data = np.expand_dims(image_resized/255., axis=0) | ||
| 87 | |||
| 88 | # options = [('grpc.max_send_message_length', 1000 * 1024 * 1024), | ||
| 89 | # ('grpc.max_receive_message_length', 1000 * 1024 * 1024)] | ||
| 90 | # channel = grpc.insecure_channel('localhost:8500', options=options) | ||
| 91 | # stub = prediction_service_pb2_grpc.PredictionServiceStub(channel) | ||
| 92 | |||
| 93 | # request = predict_pb2.PredictRequest() | ||
| 94 | # request.model_spec.name = 'dbnet_model' | ||
| 95 | # request.model_spec.signature_name = 'serving_default' | ||
| 96 | # request.inputs['input_1'].CopyFrom(tf.make_tensor_proto(inputs)) | ||
| 97 | |||
| 98 | # result = stub.Predict(request, 100.0) # 100 secs timeout | ||
| 99 | |||
| 100 | # preds = tf.make_ndarray(result.outputs['tf.math.greater']) | ||
| 101 | |||
| 102 | triton_client = grpcclient.InferenceServerClient("localhost:8001") | ||
| 103 | |||
| 104 | # Initialize the data | ||
| 105 | inputs = [grpcclient.InferInput('input_1', input_data.shape, "FP32")] | ||
| 106 | inputs[0].set_data_from_numpy(input_data) | ||
| 107 | outputs = [grpcclient.InferRequestedOutput("tf.math.greater")] | ||
| 108 | |||
| 109 | # Inference | ||
| 110 | results = triton_client.infer( | ||
| 111 | model_name="dbnet_model", | ||
| 112 | inputs=inputs, | ||
| 113 | outputs=outputs | ||
| 114 | ) | ||
| 115 | # Get the output arrays from the results | ||
| 116 | preds = results.as_numpy("tf.math.greater") | ||
| 117 | |||
| 118 | boxes = polygons_from_probmap(preds, ratio) | ||
| 119 | #boxes = rectangle_boxes_zoom(boxes, offset=0) | ||
| 120 | |||
| 121 | return boxes |
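A quick way to sanity-check the detector is to draw the returned quadrilaterals back onto the input. A sketch, assuming the dbnet_model is being served:

```python
import cv2
from turnsole.ocr_engine import text_detector  # assumed export

image = cv2.imread('doc.jpg')
boxes = text_detector.predict(image)  # list of (4, 2) int32 arrays
vis = image.copy()
cv2.polylines(vis, [b.reshape(-1, 1, 2) for b in boxes],
              isClosed=True, color=(0, 255, 0), thickness=2)
cv2.imwrite('doc_boxes.jpg', vis)
```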
ocr_engine/turnsole/ocr_engine/__init__.py
0 → 100644
| 1 | # import grpc | ||
| 2 | import turnsole | ||
| 3 | import numpy as np | ||
| 4 | # import tensorflow as tf | ||
| 5 | # from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc | ||
| 6 | |||
| 7 | import tritonclient.grpc as grpcclient | ||
| 8 | |||
| 9 | |||
| 10 | class ObjectDetection(): | ||
| 11 | |||
| 12 | """通用文件检测算法 | ||
| 13 | 输入图片输出检测结果 | ||
| 14 | |||
| 15 | API 文档请参阅: | ||
| 16 | """ | ||
| 17 | |||
| 18 | def __init__(self, confidence_threshold=0.5): | ||
| 19 | """初始化检测对象 | ||
| 20 | |||
| 21 | Args: | ||
| 22 | confidence_threshold (float, optional): 目标检测模型的分类置信度 | ||
| 23 | """ | ||
| 24 | |||
| 25 | self.label2index = { | ||
| 26 | 'id_card_info': 0, | ||
| 27 | 'id_card_guohui': 1, | ||
| 28 | 'lssfz_front': 2, | ||
| 29 | 'lssfz_back': 3, | ||
| 30 | 'jzz_front': 4, | ||
| 31 | 'jzz_back': 5, | ||
| 32 | 'txz_front': 6, | ||
| 33 | 'txz_back': 7, | ||
| 34 | 'bank_card': 8, | ||
| 35 | 'vehicle_license_front': 9, | ||
| 36 | 'vehicle_license_back': 10, | ||
| 37 | 'driving_license_front': 11, | ||
| 38 | 'driving_license_back': 12, | ||
| 39 | 'vrc_page_12': 13, | ||
| 40 | 'vrc_page_34': 14, | ||
| 41 | } | ||
| 42 | self.index2label = list(self.label2index.keys()) | ||
| 43 | |||
| 44 | # def resize_and_pad_to_384(self, image, jitter=True): | ||
| 45 | # """长边在 256-384 之间随机取一个数,四边 pad 到 384 | ||
| 46 | |||
| 47 | # Args: | ||
| 48 | # image (TYPE): An image represented as a numpy ndarray. | ||
| 49 | # """ | ||
| 50 | # image_shape = tf.cast(tf.shape(image)[:2], dtype=tf.float32) | ||
| 51 | # max_side = tf.random.uniform( | ||
| 52 | # (), 256, 384, dtype=tf.float32) if jitter else 384. | ||
| 53 | # ratio = max_side / tf.reduce_max(image_shape) | ||
| 54 | # image_shape = tf.cast(ratio * image_shape, dtype=tf.int32) | ||
| 55 | # image = tf.image.resize(image, image_shape) | ||
| 56 | # image = tf.image.pad_to_bounding_box(image, 0, 0, 384, 384) | ||
| 57 | # return image, ratio | ||
| 58 | |||
| 59 | def process(self, image): | ||
| 60 | """Processes an image and returns a list of the detected object location and classes data. | ||
| 61 | |||
| 62 | Args: | ||
| 63 | image (TYPE): An image represented as a numpy ndarray. | ||
| 64 | """ | ||
| 65 | h, w, _ = image.shape | ||
| 66 | # image, ratio = self.resize_and_pad_to_384(image, jitter=False) | ||
| 67 | image, ratio = turnsole.resize_with_pad(image, target_height=384, target_width=384) | ||
| 68 | input_data = np.expand_dims(image/255., axis=0) | ||
| 69 | |||
| 70 | # options = [('grpc.max_send_message_length', 1000 * 1024 * 1024), | ||
| 71 | # ('grpc.max_receive_message_length', 1000 * 1024 * 1024)] | ||
| 72 | # channel = grpc.insecure_channel('localhost:8500', options=options) | ||
| 73 | # stub = prediction_service_pb2_grpc.PredictionServiceStub(channel) | ||
| 74 | |||
| 75 | # request = predict_pb2.PredictRequest() | ||
| 76 | # request.model_spec.name = 'object_detection' | ||
| 77 | # request.model_spec.signature_name = 'serving_default' | ||
| 78 | # request.inputs['image'].CopyFrom(tf.make_tensor_proto(inputs, dtype='float32')) | ||
| 79 | # # 100 secs timeout | ||
| 80 | # result = stub.Predict(request, 100.0) | ||
| 81 | |||
| 82 | # # saved_model_cli show --dir saved_model/ --all # 查看 saved model 的输入输出 | ||
| 83 | # boxes = tf.make_ndarray(result.outputs['decode_predictions']) | ||
| 84 | # scores = tf.make_ndarray(result.outputs['decode_predictions_1']) | ||
| 85 | # classes = tf.make_ndarray(result.outputs['decode_predictions_2']) | ||
| 86 | # valid_detections = tf.make_ndarray( | ||
| 87 | # result.outputs['decode_predictions_3']) | ||
| 88 | |||
| 89 | triton_client = grpcclient.InferenceServerClient("localhost:8001") | ||
| 90 | |||
| 91 | # Initialize the data | ||
| 92 | inputs = [grpcclient.InferInput('image', input_data.shape, "FP32")] | ||
| 93 | inputs[0].set_data_from_numpy(input_data.astype('float32')) | ||
| 94 | outputs = [ | ||
| 95 | grpcclient.InferRequestedOutput("decode_predictions"), | ||
| 96 | grpcclient.InferRequestedOutput("decode_predictions_1"), | ||
| 97 | grpcclient.InferRequestedOutput("decode_predictions_2"), | ||
| 98 | grpcclient.InferRequestedOutput("decode_predictions_3") | ||
| 99 | ] | ||
| 100 | |||
| 101 | # Inference | ||
| 102 | results = triton_client.infer( | ||
| 103 | model_name="object_detection", | ||
| 104 | inputs=inputs, | ||
| 105 | outputs=outputs | ||
| 106 | ) | ||
| 107 | # Get the output arrays from the results | ||
| 108 | boxes = results.as_numpy("decode_predictions") | ||
| 109 | scores = results.as_numpy("decode_predictions_1") | ||
| 110 | classes = results.as_numpy("decode_predictions_2") | ||
| 111 | valid_detections = results.as_numpy("decode_predictions_3") | ||
| 112 | |||
| 113 | boxes = boxes[0][:valid_detections[0]] | ||
| 114 | scores = scores[0][:valid_detections[0]] | ||
| 115 | classes = classes[0][:valid_detections[0]] | ||
| 116 | |||
| 117 | object_list = [] | ||
| 118 | for box, score, class_index in zip(boxes, scores, classes): | ||
| 119 | xmin, ymin, xmax, ymax = box / ratio | ||
| 120 | xmin = max(0, int(xmin)) | ||
| 121 | ymin = max(0, int(ymin)) | ||
| 122 | xmax = min(w, int(xmax)) | ||
| 123 | ymax = min(h, int(ymax)) | ||
| 124 | class_label = self.index2label[int(class_index)] | ||
| 125 | item = { | ||
| 126 | "label": class_label, | ||
| 127 | "confidence": float(score), | ||
| 128 | "location": { | ||
| 129 | "xmin": xmin, | ||
| 130 | "ymin": ymin, | ||
| 131 | "xmax": xmax, | ||
| 132 | "ymax": ymax | ||
| 133 | } | ||
| 134 | } | ||
| 135 | object_list.append(item) | ||
| 136 | |||
| 137 | return object_list | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
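A usage sketch for the class above, assuming the package exposes an `object_detector = ObjectDetection()` instance and the object_detection model is being served:

```python
import cv2
from turnsole.ocr_engine import object_detector  # assumed export

image = cv2.imread('scan.jpg')
for obj in object_detector.process(image):
    loc = obj['location']
    print(obj['label'], round(obj['confidence'], 2),
          (loc['xmin'], loc['ymin'], loc['xmax'], loc['ymax']))
```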
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | # @Author : lk | ||
| 3 | # @Email : 9428.al@gmail.com | ||
| 4 | # @Create Date : 2022-06-28 14:38:57 | ||
| 5 | # @Last Modified : 2022-09-06 14:37:47 | ||
| 6 | # @Description : | ||
| 7 | |||
| 8 | from .utils import SignatureDetection | ||
| 9 | |||
| 10 | signature_detector = SignatureDetection() | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | # @Author : lk | ||
| 3 | # @Email : 9428.al@gmail.com | ||
| 4 | # @Create Date : 2022-02-08 14:10:00 | ||
| 5 | # @Last Modified : 2022-09-06 14:45:10 | ||
| 6 | # @Description : | ||
| 7 | |||
| 8 | import turnsole | ||
| 9 | import numpy as np | ||
| 10 | # import tensorflow as tf | ||
| 11 | |||
| 12 | # import grpc | ||
| 13 | # from tensorflow_serving.apis import predict_pb2 | ||
| 14 | # from tensorflow_serving.apis import prediction_service_pb2_grpc | ||
| 15 | |||
| 16 | import tritonclient.grpc as grpcclient | ||
| 17 | |||
| 18 | |||
| 19 | # def resize_and_pad_to_1024(image, jitter=True): | ||
| 20 | # # pick a random long side in [512, 1024] and pad all four sides to 1024 | ||
| 21 | # image_shape = tf.cast(tf.shape(image)[:2], dtype=tf.float32) | ||
| 22 | # max_side = tf.random.uniform((), 512, 1024, dtype=tf.float32) if jitter else 1024. | ||
| 23 | # ratio = max_side / tf.reduce_max(image_shape) | ||
| 24 | # image_shape = tf.cast(ratio * image_shape, dtype=tf.int32) | ||
| 25 | # image = tf.image.resize(image, image_shape) | ||
| 26 | # image = tf.image.pad_to_bounding_box(image, 0, 0, 1024, 1024) | ||
| 27 | # return image, ratio | ||
| 28 | |||
| 29 | class SignatureDetection(): | ||
| 30 | |||
| 31 | """签字盖章检测算法 | ||
| 32 | 输入图片输出检测结果 | ||
| 33 | |||
| 34 | API 文档请参阅: | ||
| 35 | """ | ||
| 36 | |||
| 37 | def __init__(self, confidence_threshold=0.5): | ||
| 38 | """初始化检测对象 | ||
| 39 | |||
| 40 | Args: | ||
| 41 | confidence_threshold (float, optional): 目标检测模型的分类置信度 | ||
| 42 | """ | ||
| 43 | |||
| 44 | self.label2index = { | ||
| 45 | 'circle': 0, | ||
| 46 | 'ellipse': 1, | ||
| 47 | 'rectangle': 2, | ||
| 48 | 'signature': 3, | ||
| 49 | 'qr_code': 4, | ||
| 50 | 'bar_code': 5 | ||
| 51 | } | ||
| 52 | self.index2label = { | ||
| 53 | 0: 'circle', | ||
| 54 | 1: 'ellipse', | ||
| 55 | 2: 'rectangle', | ||
| 56 | 3: 'signature', | ||
| 57 | 4: 'qr_code', | ||
| 58 | 5: 'bar_code' | ||
| 59 | } | ||
| 60 | |||
| 61 | |||
| 62 | def process(self, image): | ||
| 63 | """Processes an image and returns a list of the detected signature location and classes data. | ||
| 64 | |||
| 65 | Args: | ||
| 66 | image (TYPE): An image represented as a numpy ndarray. | ||
| 67 | """ | ||
| 68 | h, w, _ = image.shape | ||
| 69 | |||
| 70 | # image, ratio = resize_and_pad_to_1024(image, jitter=False) | ||
| 71 | image, ratio = turnsole.resize_with_pad(image, target_height=1024, target_width=1024) | ||
| 72 | input_data = np.expand_dims(np.float32(image/255.), axis=0) | ||
| 73 | |||
| 74 | # options = [('grpc.max_send_message_length', 1000 * 1024 * 1024), | ||
| 75 | # ('grpc.max_receive_message_length', 1000 * 1024 * 1024)] | ||
| 76 | # channel = grpc.insecure_channel('localhost:8500', options=options) | ||
| 77 | # stub = prediction_service_pb2_grpc.PredictionServiceStub(channel) | ||
| 78 | |||
| 79 | # request = predict_pb2.PredictRequest() | ||
| 80 | # request.model_spec.name = 'signature_model' | ||
| 81 | # request.model_spec.signature_name = 'serving_default' | ||
| 82 | # request.inputs['image'].CopyFrom(tf.make_tensor_proto(inputs, dtype='float32')) | ||
| 83 | # result = stub.Predict(request, 100.0) # 100 secs timeout | ||
| 84 | |||
| 85 | # # saved_model_cli show --dir saved_model/ --all # 查看 saved model 的输入输出 | ||
| 86 | # boxes = tf.make_ndarray(result.outputs['decode_predictions']) | ||
| 87 | # scores = tf.make_ndarray(result.outputs['decode_predictions_1']) | ||
| 88 | # classes = tf.make_ndarray(result.outputs['decode_predictions_2']) | ||
| 89 | # valid_detections = tf.make_ndarray(result.outputs['decode_predictions_3']) | ||
| 90 | |||
| 91 | triton_client = grpcclient.InferenceServerClient("localhost:8001") | ||
| 92 | |||
| 93 | # Initialize the data | ||
| 94 | inputs = [grpcclient.InferInput('image', input_data.shape, "FP32")] | ||
| 95 | inputs[0].set_data_from_numpy(input_data) | ||
| 96 | outputs = [ | ||
| 97 | grpcclient.InferRequestedOutput("decode_predictions"), | ||
| 98 | grpcclient.InferRequestedOutput("decode_predictions_1"), | ||
| 99 | grpcclient.InferRequestedOutput("decode_predictions_2"), | ||
| 100 | grpcclient.InferRequestedOutput("decode_predictions_3") | ||
| 101 | ] | ||
| 102 | |||
| 103 | # Inference | ||
| 104 | results = triton_client.infer( | ||
| 105 | model_name="signature_model", | ||
| 106 | inputs=inputs, | ||
| 107 | outputs=outputs | ||
| 108 | ) | ||
| 109 | # Get the output arrays from the results | ||
| 110 | boxes = results.as_numpy("decode_predictions") | ||
| 111 | scores = results.as_numpy("decode_predictions_1") | ||
| 112 | classes = results.as_numpy("decode_predictions_2") | ||
| 113 | valid_detections = results.as_numpy("decode_predictions_3") | ||
| 114 | |||
| 115 | boxes = boxes[0][:valid_detections[0]] | ||
| 116 | scores = scores[0][:valid_detections[0]] | ||
| 117 | classes = classes[0][:valid_detections[0]] | ||
| 118 | |||
| 119 | signature_list = [] | ||
| 120 | for box, score, class_index in zip(boxes, scores, classes): | ||
| 121 | xmin, ymin, xmax, ymax = box / ratio | ||
| 122 | class_label = self.index2label[int(class_index)]  # cast: the model returns float class indices | ||
| 123 | item = { | ||
| 124 | "label": class_label, | ||
| 125 | "confidence": float(score), | ||
| 126 | "location": { | ||
| 127 | "xmin": max(0, int(xmin)), | ||
| 128 | "ymin": max(0, int(ymin)), | ||
| 129 | "xmax": min(w, int(xmax)), | ||
| 130 | "ymax": min(h, int(ymax)) | ||
| 131 | } | ||
| 132 | } | ||
| 133 | signature_list.append(item) | ||
| 134 | |||
| 135 | return signature_list |
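Note that `process()` does not itself apply the `confidence_threshold` passed to the constructor, so callers filter on their side. A sketch using the `signature_detector` instance exported by the package `__init__` earlier in this diff, assuming the signature_model is being served:

```python
import cv2
from turnsole.ocr_engine import signature_detector  # export shown in __init__ above

image = cv2.imread('contract.jpg')
# confidence_threshold is not applied inside process(); filter here instead
hits = [s for s in signature_detector.process(image) if s['confidence'] >= 0.5]
for s in hits:
    print(s['label'], s['location'])
```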
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | # @Author : Lyu Kui | ||
| 3 | # @Email : 9428.al@gmail.com | ||
| 4 | # @Create Date : 2022-06-16 11:01:36 | ||
| 5 | # @Last Modified : 2022-07-15 10:57:06 | ||
| 6 | # @Description : | ||
| 7 | |||
| 8 | from .read_data import base64_to_bgr | ||
| 9 | from .read_data import bytes_to_bgr | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | # @Author : Lyu Kui | ||
| 3 | # @Email : 9428.al@gmail.com | ||
| 4 | # @Create Date : 2022-06-16 10:59:50 | ||
| 5 | # @Last Modified : 2022-08-03 14:59:15 | ||
| 6 | # @Description : | ||
| 7 | |||
| 8 | import cv2 | ||
| 9 | import base64 | ||
| 10 | import numpy as np | ||
| 11 | # import tensorflow as tf  # only needed for the commented-out decoder below | ||
| 12 | |||
| 13 | |||
| 14 | def base64_to_bgr(img64): | ||
| 15 | """把 base64 转换成图片 | ||
| 16 | 单通道的灰度图或四通道的透明图都将自动转换成三通道的 BGR 图 | ||
| 17 | |||
| 18 | Args: | ||
| 19 | img64 (TYPE): Description | ||
| 20 | |||
| 21 | Returns: | ||
| 22 | TYPE: image is a 3-D uint8 Tensor of shape [height, width, channels] where channels is BGR | ||
| 23 | """ | ||
| 24 | image_bytes = base64.b64decode(img64) | ||
| 25 | img_array = np.frombuffer(image_bytes, np.uint8) | ||
| 26 | image = cv2.imdecode(img_array, cv2.IMREAD_COLOR) | ||
| 27 | return image | ||
| 28 | |||
| 29 | def bytes_to_bgr(buffer: bytes): | ||
| 30 | """Read a byte stream as a OpenCV image | ||
| 31 | |||
| 32 | Args: | ||
| 33 | buffer (TYPE): bytes of a decoded image | ||
| 34 | """ | ||
| 35 | img_array = np.frombuffer(buffer, np.uint8) | ||
| 36 | image = cv2.imdecode(img_array, cv2.IMREAD_COLOR) | ||
| 37 | |||
| 38 | # image = tf.io.decode_image(buffer, channels=3) | ||
| 39 | # image = np.array(image)[...,::-1] | ||
| 40 | return image | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
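A round-trip sketch for the helpers above, assuming `base64_to_bgr` is re-exported at the package top level:

```python
import base64
import cv2
import numpy as np
from turnsole import base64_to_bgr  # assumed top-level export

image = np.full((8, 8, 3), 255, dtype=np.uint8)
ok, buf = cv2.imencode('.png', image)
img64 = base64.b64encode(buf.tobytes())
restored = base64_to_bgr(img64)
assert restored.shape == (8, 8, 3)
```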
ocr_engine/turnsole/paths.py
0 → 100644
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | # @Author : Lyu Kui | ||
| 3 | # @Email : 9428.al@gmail.com | ||
| 4 | # @Created Date : 2021-03-04 17:50:09 | ||
| 5 | # @Last Modified : 2021-03-10 14:03:02 | ||
| 6 | # @Description : | ||
| 7 | |||
| 8 | import os | ||
| 9 | |||
| 10 | image_types = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff") | ||
| 11 | |||
| 12 | |||
| 13 | def list_images(basePath, contains=None): | ||
| 14 | # return the set of files that are valid | ||
| 15 | return list_files(basePath, validExts=image_types, contains=contains) | ||
| 16 | |||
| 17 | def list_files(basePath, validExts=None, contains=None): | ||
| 18 | # loop over the directory structure | ||
| 19 | for (rootDir, dirNames, filenames) in os.walk(basePath): | ||
| 20 | # loop over the filenames in the current directory | ||
| 21 | for filename in filenames: | ||
| 22 | # if the contains string is not none and the filename does not contain | ||
| 23 | # the supplied string, then ignore the file | ||
| 24 | if contains is not None and filename.find(contains) == -1: | ||
| 25 | continue | ||
| 26 | |||
| 27 | # determine the file extension of the current file | ||
| 28 | ext = filename[filename.rfind("."):].lower() | ||
| 29 | |||
| 30 | # check to see if the file is an image and should be processed | ||
| 31 | if validExts is None or ext.endswith(validExts): | ||
| 32 | # construct the path to the image and yield it | ||
| 33 | imagePath = os.path.join(rootDir, filename) | ||
| 34 | yield imagePath | ||
| 35 | |||
| 36 | def get_filename(filePath): | ||
| 37 | basename = os.path.basename(filePath) | ||
| 38 | fname, fextension = os.path.splitext(basename) | ||
| 39 | return fname | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
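`list_files` is a generator and `list_images` materializes it, restricted to image extensions; a sketch, assuming the module is importable as `turnsole.paths`:

```python
from turnsole.paths import list_images, get_filename  # assumed import path

for image_path in list_images('./dataset'):
    print(get_filename(image_path), '->', image_path)
```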
ocr_engine/turnsole/pdf_tools.py
0 → 100644
| 1 | import cv2 | ||
| 2 | import fitz | ||
| 3 | import numpy as np | ||
| 4 | |||
| 5 | def pdf_to_images(pdf_path: str): | ||
| 6 | """PDF 转 OpenCV Image | ||
| 7 | |||
| 8 | Args: | ||
| 9 | pdf_path (str): Description | ||
| 10 | |||
| 11 | Returns: | ||
| 12 | TYPE: Description | ||
| 13 | """ | ||
| 14 | images = [] | ||
| 15 | doc = fitz.open(pdf_path) | ||
| 16 | # producer = doc.metadata.get('producer') | ||
| 17 | |||
| 18 | for pno in range(doc.page_count): | ||
| 19 | page = doc.load_page(pno) | ||
| 20 | |||
| 21 | all_texts = page.get_text().replace('\n', '').strip() | ||
| 22 | # empirically filter out known watermark text | ||
| 23 | all_texts = all_texts.replace('Click to buy NOW!PDF-XChangewww.docu-track.com', '')  # str.replace, not str.strip: strip would remove characters, not the phrase | ||
| 24 | blocks = page.get_text("dict")["blocks"] | ||
| 25 | imgblocks = [b for b in blocks if b["type"] == 1] | ||
| 26 | |||
| 27 | page_images = [] | ||
| 28 | # if the page has no text at all but does contain image blocks | ||
| 29 | if len(all_texts) == 0 and len(imgblocks) != 0: | ||
| 30 | # # these producers emit fragmented images; if so, stitch the fragments back together | ||
| 31 | # if producer in ['Microsoft: Print To PDF', | ||
| 32 | # 'GPL Ghostscript 8.71', | ||
| 33 | # 'doPDF Ver 7.3 Build 398 (Windows 7 Business Edition (SP 1) - Version: 6.1.7601 (x64))', | ||
| 34 | # '福昕阅读器PDF打印机 版本 11.0.114.4386']: | ||
| 35 | patches = [] | ||
| 36 | for imgblock in imgblocks: | ||
| 37 | contents = imgblock["image"] | ||
| 38 | img_array = np.frombuffer(contents, dtype=np.uint8) | ||
| 39 | image = cv2.imdecode(img_array, cv2.IMREAD_COLOR) | ||
| 40 | patches.append(image) | ||
| 41 | try: | ||
| 42 | try: | ||
| 43 | # stitch vertically first, fall back to horizontal | ||
| 44 | image = np.concatenate(patches, axis=0) | ||
| 45 | except ValueError: | ||
| 46 | image = np.concatenate(patches, axis=1) | ||
| 47 | page_images.append(image) | ||
| 48 | except ValueError: | ||
| 49 | # when two patches will not stitch, treat them as two separate images; with more than two it is ambiguous | ||
| 50 | if len(patches) == 2: | ||
| 51 | page_images = patches | ||
| 52 | else: | ||
| 53 | pix = page.get_pixmap(dpi=350) | ||
| 54 | contents = pix.tobytes(output="png") | ||
| 55 | img_array = np.frombuffer(contents, dtype=np.uint8) | ||
| 56 | image = cv2.imdecode(img_array, cv2.IMREAD_COLOR) | ||
| 57 | page_images.append(image) | ||
| 58 | # else: | ||
| 59 | # for imgblock in imgblocks: | ||
| 60 | # contents = imgblock["image"] | ||
| 61 | # img_array = np.frombuffer(contents, dtype=np.uint8) | ||
| 62 | # image = cv2.imdecode(img_array, cv2.IMREAD_COLOR) | ||
| 63 | # page_images.append(image) | ||
| 64 | else: | ||
| 65 | pix = page.get_pixmap(dpi=350) | ||
| 66 | contents = pix.tobytes(output="png") | ||
| 67 | img_array = np.frombuffer(contents, dtype=np.uint8) | ||
| 68 | image = cv2.imdecode(img_array, cv2.IMREAD_COLOR) | ||
| 69 | page_images.append(image) | ||
| 70 | images.append(page_images) | ||
| 71 | return images | ||
| 72 |
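Because the function returns one list of images per page, consumers iterate two levels deep. A sketch, assuming the module path `turnsole.pdf_tools`:

```python
import cv2
from turnsole.pdf_tools import pdf_to_images  # assumed import path

pages = pdf_to_images('report.pdf')
for pno, page_images in enumerate(pages):
    for k, image in enumerate(page_images):
        cv2.imwrite(f'page{pno:03d}_{k}.png', image)
```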
ocr_engine/turnsole/video/__init__.py
0 → 100644
ocr_engine/turnsole/video/count_frames.py
0 → 100644
| 1 | # import the necessary packages | ||
| 2 | # from ..convenience import is_cv3 | ||
| 3 | import cv2 | ||
| 4 | |||
| 5 | def count_frames(path, override=False): | ||
| 6 | # grab a pointer to the video file and initialize the total | ||
| 7 | # number of frames read | ||
| 8 | video = cv2.VideoCapture(path) | ||
| 9 | total = 0 | ||
| 10 | |||
| 11 | # if the override flag is passed in, revert to the manual | ||
| 12 | # method of counting frames | ||
| 13 | if override: | ||
| 14 | total = count_frames_manual(video) | ||
| 15 | |||
| 16 | # otherwise, let's try the fast way first | ||
| 17 | else: | ||
| 18 | # let's try to determine the number of frames in a video | ||
| 19 | # via video properties; this method can be very buggy | ||
| 20 | # and might throw an error based on your OpenCV version | ||
| 21 | # or may fail entirely depending on which video codecs | ||
| 22 | # you have installed | ||
| 23 | try: | ||
| 24 | # # check if we are using OpenCV 3 | ||
| 25 | # if is_cv3(): | ||
| 26 | # total = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) | ||
| 27 | |||
| 28 | # # otherwise, we are using OpenCV 2.4 | ||
| 29 | # else: | ||
| 30 | # total = int(video.get(cv2.cv.CV_CAP_PROP_FRAME_COUNT)) | ||
| 31 | |||
| 32 | total = int(video.get(cv2.CAP_PROP_FRAME_COUNT))  # OpenCV 3+ constant; cv2.cv was removed | ||
| 33 | |||
| 34 | # uh-oh, we got an error -- revert to counting manually | ||
| 35 | except: | ||
| 36 | total = count_frames_manual(video) | ||
| 37 | |||
| 38 | # release the video file pointer | ||
| 39 | video.release() | ||
| 40 | |||
| 41 | # return the total number of frames in the video | ||
| 42 | return total | ||
| 43 | |||
| 44 | def count_frames_manual(video): | ||
| 45 | # initialize the total number of frames read | ||
| 46 | total = 0 | ||
| 47 | |||
| 48 | # loop over the frames of the video | ||
| 49 | while True: | ||
| 50 | # grab the current frame | ||
| 51 | (grabbed, frame) = video.read() | ||
| 52 | |||
| 53 | # check to see if we have reached the end of the | ||
| 54 | # video | ||
| 55 | if not grabbed: | ||
| 56 | break | ||
| 57 | |||
| 58 | # increment the total number of frames read | ||
| 59 | total += 1 | ||
| 60 | |||
| 61 | # return the total number of frames in the video file | ||
| 62 | return total | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
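Usage is a single call; `override=True` forces the slow but reliable frame-by-frame count. A sketch, assuming the module path `turnsole.video.count_frames`:

```python
from turnsole.video.count_frames import count_frames  # assumed import path

fast = count_frames('clip.mp4')                  # property-based count
exact = count_frames('clip.mp4', override=True)  # manual fallback
print(fast, exact)
```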
ocr_engine/turnsole/video/filevideostream.py
0 → 100644
| 1 | # import the necessary packages | ||
| 2 | from threading import Thread | ||
| 3 | import sys | ||
| 4 | import cv2 | ||
| 5 | import time | ||
| 6 | |||
| 7 | # import the Queue class from Python 3 | ||
| 8 | if sys.version_info >= (3, 0): | ||
| 9 | from queue import Queue | ||
| 10 | |||
| 11 | # otherwise, import the Queue class for Python 2.7 | ||
| 12 | else: | ||
| 13 | from Queue import Queue | ||
| 14 | |||
| 15 | |||
| 16 | class FileVideoStream: | ||
| 17 | def __init__(self, path, transform=None, queue_size=128): | ||
| 18 | # initialize the file video stream along with the boolean | ||
| 19 | # used to indicate if the thread should be stopped or not | ||
| 20 | self.stream = cv2.VideoCapture(path) | ||
| 21 | self.stopped = False | ||
| 22 | self.transform = transform | ||
| 23 | |||
| 24 | # initialize the queue used to store frames read from | ||
| 25 | # the video file | ||
| 26 | self.Q = Queue(maxsize=queue_size) | ||
| 27 | # initialize the reader thread | ||
| 28 | self.thread = Thread(target=self.update, args=()) | ||
| 29 | self.thread.daemon = True | ||
| 30 | |||
| 31 | def start(self): | ||
| 32 | # start a thread to read frames from the file video stream | ||
| 33 | self.thread.start() | ||
| 34 | return self | ||
| 35 | |||
| 36 | def update(self): | ||
| 37 | # keep looping infinitely | ||
| 38 | while True: | ||
| 39 | # if the thread indicator variable is set, stop the | ||
| 40 | # thread | ||
| 41 | if self.stopped: | ||
| 42 | break | ||
| 43 | |||
| 44 | # otherwise, ensure the queue has room in it | ||
| 45 | if not self.Q.full(): | ||
| 46 | # read the next frame from the file | ||
| 47 | (grabbed, frame) = self.stream.read() | ||
| 48 | |||
| 49 | # if the `grabbed` boolean is `False`, then we have | ||
| 50 | # reached the end of the video file | ||
| 51 | if not grabbed: | ||
| 52 | self.stopped = True | ||
| 53 | break | ||
| 54 | |||
| 55 | # if there are transforms to be done, might as well | ||
| 56 | # do them on producer thread before handing back to | ||
| 57 | # consumer thread. ie. Usually the producer is so far | ||
| 58 | # ahead of consumer that we have time to spare. | ||
| 59 | # | ||
| 60 | # Python is not parallel but the transform operations | ||
| 61 | # are usually OpenCV native so release the GIL. | ||
| 62 | # | ||
| 63 | # Really just trying to avoid spinning up additional | ||
| 64 | # native threads and overheads of additional | ||
| 65 | # producer/consumer queues since this one was generally | ||
| 66 | # idle grabbing frames. | ||
| 67 | if self.transform: | ||
| 68 | frame = self.transform(frame) | ||
| 69 | |||
| 70 | # add the frame to the queue | ||
| 71 | self.Q.put(frame) | ||
| 72 | else: | ||
| 73 | time.sleep(0.1)  # Rest for 100ms, we have a full queue | ||
| 74 | |||
| 75 | self.stream.release() | ||
| 76 | |||
| 77 | def read(self): | ||
| 78 | # return next frame in the queue | ||
| 79 | return self.Q.get() | ||
| 80 | |||
| 81 | # Insufficient to have consumer use while(more()) which does | ||
| 82 | # not take into account if the producer has reached end of | ||
| 83 | # file stream. | ||
| 84 | def running(self): | ||
| 85 | return self.more() or not self.stopped | ||
| 86 | |||
| 87 | def more(self): | ||
| 88 | # return True if there are still frames in the queue. If stream is not stopped, try to wait a moment | ||
| 89 | tries = 0 | ||
| 90 | while self.Q.qsize() == 0 and not self.stopped and tries < 5: | ||
| 91 | time.sleep(0.1) | ||
| 92 | tries += 1 | ||
| 93 | |||
| 94 | return self.Q.qsize() > 0 | ||
| 95 | |||
| 96 | def stop(self): | ||
| 97 | # indicate that the thread should be stopped | ||
| 98 | self.stopped = True | ||
| 99 | # wait until stream resources are released (producer thread might be still grabbing frame) | ||
| 100 | self.thread.join() |
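A typical consumption loop polls `more()` so the consumer drains the queue after the producer thread finishes; a sketch pairing it with the `FPS` counter from the next file (import paths assumed):

```python
import cv2
from turnsole.video.filevideostream import FileVideoStream  # assumed path
from turnsole.video.fps import FPS                          # assumed path

fvs = FileVideoStream('clip.mp4',
                      transform=lambda f: cv2.resize(f, (640, 360))).start()
fps = FPS().start()
while fvs.more():
    frame = fvs.read()
    fps.update()
fps.stop()
fvs.stop()
print(f'approx. {fps.fps():.1f} FPS')
```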
ocr_engine/turnsole/video/fps.py
0 → 100644
| 1 | # import the necessary packages | ||
| 2 | import datetime | ||
| 3 | |||
| 4 | class FPS: | ||
| 5 | def __init__(self): | ||
| 6 | # store the start time, end time, and total number of frames | ||
| 7 | # that were examined between the start and end intervals | ||
| 8 | self._start = None | ||
| 9 | self._end = None | ||
| 10 | self._numFrames = 0 | ||
| 11 | |||
| 12 | def start(self): | ||
| 13 | # start the timer | ||
| 14 | self._start = datetime.datetime.now() | ||
| 15 | return self | ||
| 16 | |||
| 17 | def stop(self): | ||
| 18 | # stop the timer | ||
| 19 | self._end = datetime.datetime.now() | ||
| 20 | |||
| 21 | def update(self): | ||
| 22 | # increment the total number of frames examined during the | ||
| 23 | # start and end intervals | ||
| 24 | self._numFrames += 1 | ||
| 25 | |||
| 26 | def elapsed(self): | ||
| 27 | # return the total number of seconds between the start and | ||
| 28 | # end interval | ||
| 29 | return (self._end - self._start).total_seconds() | ||
| 30 | |||
| 31 | def fps(self): | ||
| 32 | # compute the (approximate) frames per second | ||
| 33 | return self._numFrames / self.elapsed() | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
ocr_engine/turnsole/video/pivideostream.py
0 → 100644
| 1 | # import the necessary packages | ||
| 2 | from picamera.array import PiRGBArray | ||
| 3 | from picamera import PiCamera | ||
| 4 | from threading import Thread | ||
| 5 | import cv2 | ||
| 6 | |||
| 7 | class PiVideoStream: | ||
| 8 | def __init__(self, resolution=(320, 240), framerate=32, **kwargs): | ||
| 9 | # initialize the camera | ||
| 10 | self.camera = PiCamera() | ||
| 11 | |||
| 12 | # set camera parameters | ||
| 13 | self.camera.resolution = resolution | ||
| 14 | self.camera.framerate = framerate | ||
| 15 | |||
| 16 | # set optional camera parameters (refer to PiCamera docs) | ||
| 17 | for (arg, value) in kwargs.items(): | ||
| 18 | setattr(self.camera, arg, value) | ||
| 19 | |||
| 20 | # initialize the stream | ||
| 21 | self.rawCapture = PiRGBArray(self.camera, size=resolution) | ||
| 22 | self.stream = self.camera.capture_continuous(self.rawCapture, | ||
| 23 | format="bgr", use_video_port=True) | ||
| 24 | |||
| 25 | # initialize the frame and the variable used to indicate | ||
| 26 | # if the thread should be stopped | ||
| 27 | self.frame = None | ||
| 28 | self.stopped = False | ||
| 29 | |||
| 30 | def start(self): | ||
| 31 | # start the thread to read frames from the video stream | ||
| 32 | t = Thread(target=self.update, args=()) | ||
| 33 | t.daemon = True | ||
| 34 | t.start() | ||
| 35 | return self | ||
| 36 | |||
| 37 | def update(self): | ||
| 38 | # keep looping infinitely until the thread is stopped | ||
| 39 | for f in self.stream: | ||
| 40 | # grab the frame from the stream and clear the stream in | ||
| 41 | # preparation for the next frame | ||
| 42 | self.frame = f.array | ||
| 43 | self.rawCapture.truncate(0) | ||
| 44 | |||
| 45 | # if the thread indicator variable is set, stop the thread | ||
| 46 | # and release camera resources | ||
| 47 | if self.stopped: | ||
| 48 | self.stream.close() | ||
| 49 | self.rawCapture.close() | ||
| 50 | self.camera.close() | ||
| 51 | return | ||
| 52 | |||
| 53 | def read(self): | ||
| 54 | # return the frame most recently read | ||
| 55 | return self.frame | ||
| 56 | |||
| 57 | def stop(self): | ||
| 58 | # indicate that the thread should be stopped | ||
| 59 | self.stopped = True |
ocr_engine/turnsole/video/videostream.py
0 → 100644
| 1 | # import the necessary packages | ||
| 2 | from .webcamvideostream import WebcamVideoStream | ||
| 3 | |||
| 4 | class VideoStream: | ||
| 5 | def __init__(self, src=0, usePiCamera=False, resolution=(320, 240), | ||
| 6 | framerate=32, **kwargs): | ||
| 7 | # check to see if the picamera module should be used | ||
| 8 | if usePiCamera: | ||
| 9 | # only import the picamera packages if we are | ||
| 10 | # explicitly told to do so -- this avoids requiring | ||
| 11 | # `picamera[array]` on desktops or laptops that | ||
| 12 | # only use the webcam stream | ||
| 13 | from .pivideostream import PiVideoStream | ||
| 14 | |||
| 15 | # initialize the picamera stream and allow the camera | ||
| 16 | # sensor to warmup | ||
| 17 | self.stream = PiVideoStream(resolution=resolution, | ||
| 18 | framerate=framerate, **kwargs) | ||
| 19 | |||
| 20 | # otherwise, we are using OpenCV so initialize the webcam | ||
| 21 | # stream | ||
| 22 | else: | ||
| 23 | self.stream = WebcamVideoStream(src=src) | ||
| 24 | |||
| 25 | def start(self): | ||
| 26 | # start the threaded video stream | ||
| 27 | return self.stream.start() | ||
| 28 | |||
| 29 | def update(self): | ||
| 30 | # grab the next frame from the stream | ||
| 31 | self.stream.update() | ||
| 32 | |||
| 33 | def read(self): | ||
| 34 | # return the current frame | ||
| 35 | return self.stream.read() | ||
| 36 | |||
| 37 | def stop(self): | ||
| 38 | # stop the thread and release any resources | ||
| 39 | self.stream.stop() |
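The wrapper picks its backend at construction time; a webcam sketch (the first `read()` can return None until the grab thread produces a frame), import path assumed:

```python
import time
from turnsole.video.videostream import VideoStream  # assumed import path

vs = VideoStream(src=0).start()
time.sleep(2.0)  # give the camera sensor time to warm up
frame = vs.read()
print(None if frame is None else frame.shape)
vs.stop()
```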
| 1 | # import the necessary packages | ||
| 2 | from threading import Thread | ||
| 3 | import cv2 | ||
| 4 | |||
| 5 | class WebcamVideoStream: | ||
| 6 | def __init__(self, src=0, name="WebcamVideoStream"): | ||
| 7 | # initialize the video camera stream and read the first frame | ||
| 8 | # from the stream | ||
| 9 | self.stream = cv2.VideoCapture(src) | ||
| 10 | (self.grabbed, self.frame) = self.stream.read() | ||
| 11 | |||
| 12 | # initialize the thread name | ||
| 13 | self.name = name | ||
| 14 | |||
| 15 | # initialize the variable used to indicate if the thread should | ||
| 16 | # be stopped | ||
| 17 | self.stopped = False | ||
| 18 | |||
| 19 | def start(self): | ||
| 20 | # start the thread to read frames from the video stream | ||
| 21 | t = Thread(target=self.update, name=self.name, args=()) | ||
| 22 | t.daemon = True | ||
| 23 | t.start() | ||
| 24 | return self | ||
| 25 | |||
| 26 | def update(self): | ||
| 27 | # keep looping infinitely until the thread is stopped | ||
| 28 | while True: | ||
| 29 | # if the thread indicator variable is set, stop the thread | ||
| 30 | if self.stopped: | ||
| 31 | return | ||
| 32 | |||
| 33 | # otherwise, read the next frame from the stream | ||
| 34 | (self.grabbed, self.frame) = self.stream.read() | ||
| 35 | |||
| 36 | def read(self): | ||
| 37 | # return the frame most recently read | ||
| 38 | return self.frame | ||
| 39 | |||
| 40 | def stop(self): | ||
| 41 | # indicate that the thread should be stopped | ||
| 42 | self.stopped = True |