8bb3575a by mengliyu

init project

0 parents
No preview for this file type
No preview for this file type
This file is too large to display.
This file is too large to display.
1 # -*- coding: utf-8 -*-
2 ###
3 import os, sys
4 import cv2
5 from PIL import Image
6 import time
7 import numpy as np
8 import tensorflow as tf
9 from tensorflow.keras import backend as K
10 from SDTR import alphabets
11
# Recognizable character set, copied from the bundled alphabet module.
characters = alphabets.alphabet[:]
# Fixed input height (pixels) the CRNN model expects; widths are variable.
height = 32
# Number of crops fed to the network per session.run call.
batchsize = 16
# Output classes = characters + 1, presumably for the CTC blank label
# (decode() treats class 0 as blank) — TODO confirm against the model.
nclass = len(characters) + 1
16
17
def init():
    """Load the frozen CRNN graph and create the global TF session.

    Sets module globals:
        inputs: float32 placeholder, shape [batch, 32, width, 1] (NHWC).
        outs:   tensors imported from the frozen graph
                ('embedding2/Reshape_1:0').
        sess:   tf.Session bound to the loaded graph.

    Picks the GPU model when a GPU is visible (capping this process at
    ~30% of GPU memory, growing on demand), otherwise the CPU model.
    """
    global inputs, outs, sess
    if tf.test.gpu_device_name():
        modelPath = './SDTR/crnn_bl2_gpu.pb'
        config = tf.ConfigProto()
        # Share the GPU with other processes: hard cap + grow-on-demand.
        config.gpu_options.per_process_gpu_memory_fraction = 0.3
        config.gpu_options.allow_growth = True
    else:
        modelPath = './SDTR/crnn_bl2.pb'
        config = tf.ConfigProto()
    print(modelPath)
    # NOTE: the original code opened an extra tf.Session(config=config)
    # here that was never used or closed (resource leak); removed.
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(modelPath, 'rb') as f:
            graph_def.ParseFromString(f.read())

        # Rebind the frozen graph's input to a fresh placeholder so the
        # batch size and width stay dynamic.
        inputs = tf.placeholder(tf.float32, [None, 32, None, 1], name='X')
        outs = tf.import_graph_def(
            graph_def,
            input_map={'the_input:0': inputs},
            return_elements=['embedding2/Reshape_1:0'])
    sess = tf.Session(graph=graph, config=config)
42
43
# Eagerly load the model at import time so the first predict() is fast.
init()
45
46
def predict(im, boxes):
    """Recognize text inside quadrilateral regions of a grayscale image.

    Args:
        im: grayscale image (2-D numpy array, OpenCV convention).
        boxes: N x 4 x 2 numpy array of quadrilaterals; each box lists its
            corners (x, y) in top-left, top-right, bottom-right,
            bottom-left order.

    Returns:
        (results, rectime): `results` maps the original box index to
        [flattened 8-element box array, recognized string]; `rectime` is
        the accumulated network inference time in milliseconds.
    """
    count_boxes = len(boxes)
    # Sort by width-after-rescaling-height-to-32, widest first; each
    # batch's padded width is taken from the widest box it starts with.
    boxes_max = sorted(
        boxes,
        key=lambda box: int(32.0 * (np.linalg.norm(box[0] - box[1])) / (np.linalg.norm(box[3] - box[0]))),
        reverse=True)

    # Pad the box list by repeating the last box until it is a multiple of
    # batchsize; padding entries are excluded from the output below.
    if len(boxes) % batchsize != 0:
        add_box = np.expand_dims(boxes[-1], axis=0)
        extend_num = batchsize - len(boxes) % batchsize
        for _ in range(extend_num):
            boxes = np.concatenate((boxes, add_box), axis=0)

    results = {}
    labels = []
    rectime = 0.0

    # BUGFIX: original used `len(boxes) is not 0` — identity comparison
    # with an int literal is implementation-dependent; use != instead.
    if len(boxes) != 0:
        for i in range(int(len(boxes) / batchsize)):
            slices = []
            box = boxes_max[i * batchsize]
            w, h = [int(np.linalg.norm(box[0] - box[1])), int(np.linalg.norm(box[3] - box[0]))]
            width = int(32.0 * w / h)
            # NOTE(review): skipping a narrow batch makes `labels` shorter
            # than the box list, shifting the text attached to later
            # indices — confirm this is intended.
            if width < 24:
                continue
            for index, box in enumerate(boxes[i * batchsize:(i + 1) * batchsize]):
                _box = [n for a in box for n in a]
                # Only real (non-padding) boxes get an output slot.
                if i * batchsize + index < count_boxes:
                    results[i * batchsize + index] = [np.array(_box)]
                w, h = [int(np.linalg.norm(box[0] - box[1])), int(np.linalg.norm(box[3] - box[0]))]
                # Perspective-rectify the quadrilateral into a w x h crop.
                pts1 = np.float32(box)
                pts2 = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
                M = cv2.getPerspectiveTransform(pts1, pts2)
                im_crop = cv2.warpPerspective(im, M, (w, h))
                im_crop = resize_img(im_crop, width)
                slices.append(im_crop)
            slices = np.array(slices)
            recstart = time.time()
            preds = sess.run(outs, feed_dict={inputs: slices})
            recend = time.time()
            # import_graph_def returned a one-element list of tensors.
            preds = preds[0]
            rectime += (recend - recstart) * 1000
            rec_labels = decode(preds)
            labels.extend(rec_labels)
    # Attach decoded text to each original box; drop spaces and normalize
    # the full-width yen sign.
    for index, label in enumerate(labels[:count_boxes]):
        results[index].append(label.replace(' ', '').replace('¥', '¥'))
    return results, rectime
100
101
def resize_img(im, width):
    """Letterbox a grayscale crop to (height x width) and normalize it.

    Scales `im` preserving aspect ratio so it fits inside the
    height x width box, pads the remainder with white (255), then maps
    pixel values from [0, 255] to [-1, 1].

    Returns a (height, width, 1) array ready for the network input.
    """
    src_h, src_w = im.shape
    # Uniform scale factor that fits the crop inside the target box.
    scale = min(width * 1.0 / src_w, height * 1.0 / src_h)
    new_w, new_h = int(src_w * scale), int(src_h * scale)
    im = cv2.resize(im, (new_w, new_h))
    # Center the resized crop and pad the borders with white.
    pad_w = width - new_w
    pad_h = height - new_h
    top = pad_h // 2
    left = pad_w // 2
    bottom = pad_h - top
    right = pad_w - left
    padded = cv2.copyMakeBorder(im, top, bottom, left, right,
                                cv2.BORDER_CONSTANT, value=255)
    # Normalize: [0, 255] -> [0, 1] -> [-1, 1].
    normed = (padded / 255.0 - 0.5) / 0.5
    return normed.reshape((height, width, 1))
123
124
def decode(preds):
    """Greedy CTC decode of network output into text strings.

    Args:
        preds: array of shape (batch, time, nclass) with per-timestep
            class scores.

    Returns:
        One decoded string per batch item: the best class is taken at
        each timestep, repeats are collapsed, and blanks (class 0) are
        dropped, per standard CTC decoding.
    """
    table = characters + u' '
    decoded = []
    for best_path in preds.argmax(axis=2):
        prev = 0
        chars = []
        for label in best_path:
            # Keep a label only if it is not blank and not a repeat.
            if label != 0 and label != prev:
                chars.append(table[label - 1])
            prev = label
        decoded.append(u''.join(chars))
    return decoded
1 # -*- coding: utf-8 -*-
2 ###
3 import cv2
4 import numpy as np
5 from SDTR import sdtr
6
7
if __name__ == '__main__':
    # Benchmark: recognize one full-image box repeatedly and report the
    # average per-call recognition time.
    test_img_path = './rec_test.png'
    test_img = cv2.imread(test_img_path)
    # The recognizer expects a grayscale image.
    test_gray = cv2.cvtColor(test_img, cv2.COLOR_BGR2GRAY)
    h, w = test_gray.shape
    # One quadrilateral covering the whole image, corners in
    # top-left, top-right, bottom-right, bottom-left order.
    box = [np.array([[0, 0], [w, 0], [w, h], [0, h]])]
    all_time = 0
    rangetimes = 1001
    for i in range(rangetimes):
        results, rectime = sdtr.predict(test_gray, box)
        print('{:.5f}ms'.format(rectime))
        print(results)
        # Skip the first iteration (warm-up) when accumulating the average.
        if i != 0:
            all_time += rectime
    print('avgtime:{:.5f}ms'.format(all_time / (rangetimes - 1)))
1 SDTR_v1.0:
2 功能:对文字切片进行识别, 返回识别结果
3 使用方法:
4 from SDTR import sdtr
5 sdtr.predict(im, boxes)
6 输入参数:
7 im:opencv下的灰度图
8 boxes:numpy矩阵,大小为N*4*2, N为box个数,每个box包含四个坐标(x,y),且坐标顺序必须为:左上,右上,右下,左下
9 输出:
10 一个字典,包含key:[numpy矩阵box,string识别结果]
11 eg. {0: [array([ 0, 0, 626, 0, 626, 87, 0, 87]), '陆万壹仟叁佰圆整']}
12 环境:
13 tensorflow,最好是1.14版本,未在其他版本上测试
14
15
16 v1.0 2019.09.19
17 为CRNN基础
18 GPU版本用了CuDNNLSTM, 相比普通LSTM能减少1/2到2/3的时间
19
性能:
购车发票上的平均全对率为 93.25%
22
23 对box缩放到32的高度,一张32*230的图
24 GPU时间:15ms
25 CPU时间:210ms
26
27 具体性能和box数量和box缩放到32高度时的宽度有关
rec_test.png

65.3 KB

Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!