sdtr.py
# -*- coding: utf-8 -*-
# SDTR text recognizer: runs a frozen CRNN graph (TF 1.x) over perspective-
# rectified text-line crops and greedily decodes the CTC output.
import time

import cv2
import numpy as np
import tensorflow as tf

from SDTR import alphabets

characters = alphabets.alphabet[:]
height = 32                    # input height expected by the CRNN
batchsize = 16
nclass = len(characters) + 1   # +1 for the CTC blank

def init():
    """Load the frozen CRNN graph and create the shared session."""
    global inputs, outs, sess
    # TF 1.x graph-mode API; under TF 2.x these calls live in tf.compat.v1.
    if tf.test.gpu_device_name():
        modelPath = './SDTR/crnn_bl2_gpu.pb'
        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 0.3
        config.gpu_options.allow_growth = True
    else:
        modelPath = './SDTR/crnn_bl2.pb'
        config = tf.ConfigProto()
    print(modelPath)
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(modelPath, 'rb') as f:
            graph_def.ParseFromString(f.read())
        # Batched grayscale crops: (batch, 32, width, 1), normalized to [-1, 1].
        inputs = tf.placeholder(tf.float32, [None, 32, None, 1], name='X')
        outs = tf.import_graph_def(
            graph_def,
            input_map={'the_input:0': inputs},
            return_elements=['embedding2/Reshape_1:0'])
    sess = tf.Session(graph=graph, config=config)

init()  # load the model once at import time
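
# The graph is fed through its 'the_input:0' tensor and read back from
# 'embedding2/Reshape_1:0'; judging by decode() below, that output appears to
# be a (batch, timesteps, nclass) array of per-timestep character scores.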

def predict(im, boxes):
    """Recognize the text inside each quadrilateral box of image `im`."""
    count_boxes = len(boxes)
    # Sort a copy of the boxes by their scaled width at height 32, widest
    # first; each batch is resized to a width taken from this sorted list.
    boxes_max = sorted(
        boxes,
        key=lambda box: int(32.0 * np.linalg.norm(box[0] - box[1]) / np.linalg.norm(box[3] - box[0])),
        reverse=True)
    # Pad the box list with copies of the last box to fill whole batches.
    if len(boxes) % batchsize != 0:
        add_box = np.expand_dims(boxes[-1], axis=0)
        extend_num = batchsize - len(boxes) % batchsize
        for i in range(extend_num):
            boxes = np.concatenate((boxes, add_box), axis=0)
    results = {}
    labels = []
    rectime = 0.0
    if len(boxes) != 0:
        for i in range(len(boxes) // batchsize):
            slices = []
            box = boxes_max[i * batchsize]
            w, h = [int(np.linalg.norm(box[0] - box[1])), int(np.linalg.norm(box[3] - box[0]))]
            width = int(32.0 * w / h)
            # Skip batches whose target width is too narrow to recognize.
            if width < 24:
                continue
            for index, box in enumerate(boxes[i * batchsize:(i + 1) * batchsize]):
                _box = [n for a in box for n in a]
                if i * batchsize + index < count_boxes:
                    results[i * batchsize + index] = [np.array(_box)]
                w, h = [int(np.linalg.norm(box[0] - box[1])), int(np.linalg.norm(box[3] - box[0]))]
                # Rectify the quadrilateral into an axis-aligned w x h crop.
                pts1 = np.float32(box)
                pts2 = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
                M = cv2.getPerspectiveTransform(pts1, pts2)
                im_crop = cv2.warpPerspective(im, M, (w, h))
                im_crop = resize_img(im_crop, width)
                slices.append(im_crop)
            slices = np.array(slices)
            recstart = time.time()
            preds = sess.run(outs, feed_dict={inputs: slices})
            recend = time.time()
            preds = preds[0]
            rectime += (recend - recstart) * 1000
            labels.extend(decode(preds))
    for index, label in enumerate(labels[:count_boxes]):
        # Strip spaces; normalize the yen sign (assumed fullwidth -> halfwidth,
        # since the original character pair here was mis-encoded).
        results[index].append(label.replace(' ', '').replace(u'￥', u'¥'))
    return results, rectime

def resize_img(im, width):
    """Fit a grayscale crop into (height, width), pad with white, normalize."""
    ori_h, ori_w = im.shape
    ratio1 = width * 1.0 / ori_w
    ratio2 = height * 1.0 / ori_h
    ratio = min(ratio1, ratio2)
    new_w, new_h = int(ori_w * ratio), int(ori_h * ratio)
    im = cv2.resize(im, (new_w, new_h))
    # Center the resized crop on the target canvas, padding with white.
    delta_w = width - new_w
    delta_h = height - new_h
    top = delta_h // 2
    bottom = delta_h - top
    left = delta_w // 2
    right = delta_w - left
    img = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=255)
    # Map pixel values from [0, 255] to [-1, 1].
    img = img / 255.0
    img = (img - 0.5) / 0.5
    return img.reshape((height, width, 1))

def decode(preds):
    """Greedy CTC decoding: argmax per timestep, drop blanks and repeats."""
    labels = []
    charactersS = characters + u' '
    tops = preds.argmax(axis=2)  # (batch, timesteps) class indices
    for t in tops:
        char_list = []
        for i in range(len(t)):
            # Index 0 is the CTC blank; also collapse consecutive repeats.
            if t[i] != 0 and not (i > 0 and t[i - 1] == t[i]):
                char_list.append(charactersS[t[i] - 1])
        labels.append(u''.join(char_list))
    return labels
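
# A minimal usage sketch, not part of the original module: the file name and
# box coordinates below are hypothetical placeholders. `predict` expects a
# grayscale image and an (N, 4, 2) array of quadrilaterals ordered top-left,
# top-right, bottom-right, bottom-left; it returns {index: [box, text]} plus
# the recognition time in milliseconds.
if __name__ == '__main__':
    demo_im = cv2.imread('demo.jpg', cv2.IMREAD_GRAYSCALE)  # hypothetical image
    demo_boxes = np.array([
        [[10, 10], [210, 10], [210, 42], [10, 42]],  # one hypothetical text line
    ], dtype=np.float32)
    recognized, ms = predict(demo_im, demo_boxes)
    for idx, (quad, text) in recognized.items():
        print(idx, text)
    print('recognition time: %.1f ms' % ms)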