init project
0 parents
Showing
9 changed files
with
185 additions
and
0 deletions
SDTR/__pycache__/alphabets.cpython-36.pyc
0 → 100644
No preview for this file type
SDTR/__pycache__/sdtr.cpython-36.pyc
0 → 100644
No preview for this file type
SDTR/alphabets.py
0 → 100644
This diff is collapsed.
Click to expand it.
SDTR/crnn_bl2.pb
0 → 100644
This file is too large to display.
SDTR/crnn_bl2_gpu.pb
0 → 100644
This file is too large to display.
SDTR/sdtr.py
0 → 100644
1 | # -*- coding: utf-8 -*- | ||
2 | ### | ||
3 | import os, sys | ||
4 | import cv2 | ||
5 | from PIL import Image | ||
6 | import time | ||
7 | import numpy as np | ||
8 | import tensorflow as tf | ||
9 | from tensorflow.keras import backend as K | ||
10 | from SDTR import alphabets | ||
11 | |||
12 | characters = alphabets.alphabet[:] | ||
13 | height = 32 | ||
14 | batchsize = 16 | ||
15 | nclass = len(characters) + 1 | ||
16 | |||
17 | |||
18 | def init(): | ||
19 | global inputs, outs, sess | ||
20 | if tf.test.gpu_device_name(): | ||
21 | modelPath = './SDTR/crnn_bl2_gpu.pb' | ||
22 | config = tf.ConfigProto() | ||
23 | config.gpu_options.per_process_gpu_memory_fraction = 0.3 | ||
24 | config.gpu_options.allow_growth = True | ||
25 | else: | ||
26 | modelPath = './SDTR/crnn_bl2.pb' | ||
27 | config = tf.ConfigProto() | ||
28 | print(modelPath) | ||
29 | session = tf.Session(config=config) | ||
30 | graph = tf.Graph() | ||
31 | with graph.as_default(): | ||
32 | graph_def = tf.GraphDef() | ||
33 | with tf.gfile.GFile(modelPath, 'rb') as f: | ||
34 | graph_def.ParseFromString(f.read()) | ||
35 | |||
36 | inputs = tf.placeholder(tf.float32, [None, 32, None, 1], name='X') | ||
37 | outs = tf.import_graph_def( | ||
38 | graph_def, | ||
39 | input_map={'the_input:0': inputs}, | ||
40 | return_elements=['embedding2/Reshape_1:0']) | ||
41 | sess = tf.Session(graph=graph, config=config) | ||
42 | |||
43 | |||
44 | init() | ||
45 | |||
46 | |||
47 | def predict(im, boxes): | ||
48 | # global inputs, outs, sess | ||
49 | count_boxes = len(boxes) | ||
50 | boxes_max = sorted(boxes, | ||
51 | key=lambda box: int(32.0 * (np.linalg.norm(box[0] - box[1])) / (np.linalg.norm(box[3] - box[0]))), | ||
52 | reverse=True) | ||
53 | |||
54 | if len(boxes) % batchsize != 0: | ||
55 | add_box = np.expand_dims(boxes[-1], axis=0) | ||
56 | extend_num = batchsize - len(boxes) % batchsize | ||
57 | for i in range(extend_num): | ||
58 | boxes = np.concatenate((boxes, add_box), axis=0) | ||
59 | |||
60 | results = {} | ||
61 | labels = [] | ||
62 | rectime = 0.0 | ||
63 | |||
64 | if len(boxes) is not 0: | ||
65 | for i in range(int(len(boxes) / batchsize)): | ||
66 | slices = [] | ||
67 | box = boxes_max[i * batchsize] | ||
68 | w, h = [int(np.linalg.norm(box[0] - box[1])), int(np.linalg.norm(box[3] - box[0]))] | ||
69 | width = int(32.0 * w / h) | ||
70 | # print(width) | ||
71 | if width < 24: | ||
72 | continue | ||
73 | for index, box in enumerate(boxes[i * batchsize:(i + 1) * batchsize]): | ||
74 | _box = [n for a in box for n in a] | ||
75 | if i * batchsize + index < count_boxes: | ||
76 | results[i * batchsize + index] = [np.array(_box)] | ||
77 | w, h = [int(np.linalg.norm(box[0] - box[1])), int(np.linalg.norm(box[3] - box[0]))] | ||
78 | # print(w) | ||
79 | pts1 = np.float32(box) | ||
80 | pts2 = np.float32([[0, 0], [w, 0], [w, h], [0, h]]) | ||
81 | M = cv2.getPerspectiveTransform(pts1, pts2) | ||
82 | im_crop = cv2.warpPerspective(im, M, (w, h)) | ||
83 | im_crop = resize_img(im_crop, width) | ||
84 | slices.append(im_crop) | ||
85 | slices = np.array(slices) | ||
86 | # print(slices.shape) | ||
87 | recstart = time.time() | ||
88 | preds = sess.run(outs, feed_dict={inputs: slices}) | ||
89 | # preds=model.predict(slices) | ||
90 | recend = time.time() | ||
91 | preds = preds[0] | ||
92 | # print(preds) | ||
93 | rectime += (recend - recstart) * 1000 | ||
94 | # preds=preds[:,2:,:] | ||
95 | rec_labels = decode(preds) | ||
96 | labels.extend(rec_labels) | ||
97 | for index, label in enumerate(labels[:count_boxes]): | ||
98 | results[index].append(label.replace(' ', '').replace('¥', '¥')) | ||
99 | return results, rectime | ||
100 | |||
101 | |||
102 | def resize_img(im, width): | ||
103 | ori_h, ori_w = im.shape | ||
104 | ratio1 = width * 1.0 / ori_w | ||
105 | ratio2 = height * 1.0 / ori_h | ||
106 | if ratio1 < ratio2: | ||
107 | ratio = ratio1 | ||
108 | else: | ||
109 | ratio = ratio2 | ||
110 | new_w, new_h = int(ori_w * ratio), int(ori_h * ratio) | ||
111 | im = cv2.resize(im, (new_w, new_h)) | ||
112 | delta_w = width - new_w | ||
113 | delta_h = height - new_h | ||
114 | top = delta_h // 2 | ||
115 | bottom = delta_h - top | ||
116 | left = delta_w // 2 | ||
117 | right = delta_w - left | ||
118 | img = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=255) | ||
119 | img = img / 255.0 | ||
120 | img = (img - 0.5) / 0.5 | ||
121 | X = img.reshape((height, width, 1)) | ||
122 | return X | ||
123 | |||
124 | |||
125 | def decode(preds): | ||
126 | labels = [] | ||
127 | charactersS = characters + u' ' | ||
128 | tops = preds.argmax(axis=2) | ||
129 | for t in tops: | ||
130 | length = len(t) | ||
131 | char_list = [] | ||
132 | for i in range(length): | ||
133 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): | ||
134 | char_list.append(charactersS[t[i] - 1]) | ||
135 | labels.append(u''.join(char_list)) | ||
136 | return labels |
demo.py
0 → 100644
1 | # -*- coding: utf-8 -*- | ||
2 | ### | ||
3 | import cv2 | ||
4 | import numpy as np | ||
5 | from SDTR import sdtr | ||
6 | |||
7 | |||
8 | if __name__ == '__main__': | ||
9 | test_img_path = './rec_test.png' | ||
10 | test_img = cv2.imread(test_img_path) | ||
11 | test_gray = cv2.cvtColor(test_img, cv2.COLOR_BGR2GRAY) | ||
12 | h, w = test_gray.shape | ||
13 | box = [np.array([[0, 0], [w, 0], [w, h], [0, h]])] | ||
14 | all_time = 0 | ||
15 | rangetimes = 1001 | ||
16 | for i in range(rangetimes): | ||
17 | results, rectime = sdtr.predict(test_gray, box) | ||
18 | print('{:.5f}ms'.format(rectime)) | ||
19 | print(results) | ||
20 | if i != 0: | ||
21 | all_time += rectime | ||
22 | print('avgtime:{:.5f}ms'.format(all_time / (rangetimes - 1))) |
readme.txt
0 → 100644
1 | SDTR_v1.0: | ||
2 | 功能:对文字切片进行识别, 返回识别结果 | ||
3 | 使用方法: | ||
4 | from SDTR import sdtr | ||
5 | sdtr.predict(im, boxes) | ||
6 | 输入参数: | ||
7 | im:opencv下的灰度图 | ||
8 | boxes:numpy矩阵,大小为N*4*2, N为box个数,每个box包含四个坐标(x,y),且坐标顺序必须为:左上,右上,右下,左下 | ||
9 | 输出: | ||
10 | 一个字典,包含key:[numpy矩阵box,string识别结果] | ||
11 | eg. {0: [array([ 0, 0, 626, 0, 626, 87, 0, 87]), '陆万壹仟叁佰圆整']} | ||
12 | 环境: | ||
13 | tensorflow,最好是1.14版本,未在其他版本上测试 | ||
14 | |||
15 | |||
16 | v1.0 2019.09.19 | ||
17 | 为CRNN基础 | ||
18 | GPU版本用了CuDNNLSTM, 相比普通LSTM能减少1/2到2/3的时间 | ||
19 | |||
20 | 性能: | ||
21 | 购车发票上平均全对率为93.25 | ||
22 | |||
23 | 对box缩放到32的高度,一张32*230的图 | ||
24 | GPU时间:15ms | ||
25 | CPU时间:210ms | ||
26 | |||
27 | 具体性能和box数量和box缩放到32高度时的宽度有关 |
rec_test.png
0 → 100644

65.3 KB
-
Please register or sign in to post a comment