init project
0 parents
Showing
9 changed files
with
185 additions
and
0 deletions
SDTR/__pycache__/alphabets.cpython-36.pyc
0 → 100644
No preview for this file type
SDTR/__pycache__/sdtr.cpython-36.pyc
0 → 100644
No preview for this file type
SDTR/alphabets.py
0 → 100644
This diff is collapsed.
Click to expand it.
SDTR/crnn_bl2.pb
0 → 100644
This file is too large to display.
SDTR/crnn_bl2_gpu.pb
0 → 100644
This file is too large to display.
SDTR/sdtr.py
0 → 100644
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | ### | ||
| 3 | import os, sys | ||
| 4 | import cv2 | ||
| 5 | from PIL import Image | ||
| 6 | import time | ||
| 7 | import numpy as np | ||
| 8 | import tensorflow as tf | ||
| 9 | from tensorflow.keras import backend as K | ||
| 10 | from SDTR import alphabets | ||
| 11 | |||
| 12 | characters = alphabets.alphabet[:] | ||
| 13 | height = 32 | ||
| 14 | batchsize = 16 | ||
| 15 | nclass = len(characters) + 1 | ||
| 16 | |||
| 17 | |||
| 18 | def init(): | ||
| 19 | global inputs, outs, sess | ||
| 20 | if tf.test.gpu_device_name(): | ||
| 21 | modelPath = './SDTR/crnn_bl2_gpu.pb' | ||
| 22 | config = tf.ConfigProto() | ||
| 23 | config.gpu_options.per_process_gpu_memory_fraction = 0.3 | ||
| 24 | config.gpu_options.allow_growth = True | ||
| 25 | else: | ||
| 26 | modelPath = './SDTR/crnn_bl2.pb' | ||
| 27 | config = tf.ConfigProto() | ||
| 28 | print(modelPath) | ||
| 29 | session = tf.Session(config=config) | ||
| 30 | graph = tf.Graph() | ||
| 31 | with graph.as_default(): | ||
| 32 | graph_def = tf.GraphDef() | ||
| 33 | with tf.gfile.GFile(modelPath, 'rb') as f: | ||
| 34 | graph_def.ParseFromString(f.read()) | ||
| 35 | |||
| 36 | inputs = tf.placeholder(tf.float32, [None, 32, None, 1], name='X') | ||
| 37 | outs = tf.import_graph_def( | ||
| 38 | graph_def, | ||
| 39 | input_map={'the_input:0': inputs}, | ||
| 40 | return_elements=['embedding2/Reshape_1:0']) | ||
| 41 | sess = tf.Session(graph=graph, config=config) | ||
| 42 | |||
| 43 | |||
| 44 | init() | ||
| 45 | |||
| 46 | |||
| 47 | def predict(im, boxes): | ||
| 48 | # global inputs, outs, sess | ||
| 49 | count_boxes = len(boxes) | ||
| 50 | boxes_max = sorted(boxes, | ||
| 51 | key=lambda box: int(32.0 * (np.linalg.norm(box[0] - box[1])) / (np.linalg.norm(box[3] - box[0]))), | ||
| 52 | reverse=True) | ||
| 53 | |||
| 54 | if len(boxes) % batchsize != 0: | ||
| 55 | add_box = np.expand_dims(boxes[-1], axis=0) | ||
| 56 | extend_num = batchsize - len(boxes) % batchsize | ||
| 57 | for i in range(extend_num): | ||
| 58 | boxes = np.concatenate((boxes, add_box), axis=0) | ||
| 59 | |||
| 60 | results = {} | ||
| 61 | labels = [] | ||
| 62 | rectime = 0.0 | ||
| 63 | |||
| 64 | if len(boxes) is not 0: | ||
| 65 | for i in range(int(len(boxes) / batchsize)): | ||
| 66 | slices = [] | ||
| 67 | box = boxes_max[i * batchsize] | ||
| 68 | w, h = [int(np.linalg.norm(box[0] - box[1])), int(np.linalg.norm(box[3] - box[0]))] | ||
| 69 | width = int(32.0 * w / h) | ||
| 70 | # print(width) | ||
| 71 | if width < 24: | ||
| 72 | continue | ||
| 73 | for index, box in enumerate(boxes[i * batchsize:(i + 1) * batchsize]): | ||
| 74 | _box = [n for a in box for n in a] | ||
| 75 | if i * batchsize + index < count_boxes: | ||
| 76 | results[i * batchsize + index] = [np.array(_box)] | ||
| 77 | w, h = [int(np.linalg.norm(box[0] - box[1])), int(np.linalg.norm(box[3] - box[0]))] | ||
| 78 | # print(w) | ||
| 79 | pts1 = np.float32(box) | ||
| 80 | pts2 = np.float32([[0, 0], [w, 0], [w, h], [0, h]]) | ||
| 81 | M = cv2.getPerspectiveTransform(pts1, pts2) | ||
| 82 | im_crop = cv2.warpPerspective(im, M, (w, h)) | ||
| 83 | im_crop = resize_img(im_crop, width) | ||
| 84 | slices.append(im_crop) | ||
| 85 | slices = np.array(slices) | ||
| 86 | # print(slices.shape) | ||
| 87 | recstart = time.time() | ||
| 88 | preds = sess.run(outs, feed_dict={inputs: slices}) | ||
| 89 | # preds=model.predict(slices) | ||
| 90 | recend = time.time() | ||
| 91 | preds = preds[0] | ||
| 92 | # print(preds) | ||
| 93 | rectime += (recend - recstart) * 1000 | ||
| 94 | # preds=preds[:,2:,:] | ||
| 95 | rec_labels = decode(preds) | ||
| 96 | labels.extend(rec_labels) | ||
| 97 | for index, label in enumerate(labels[:count_boxes]): | ||
| 98 | results[index].append(label.replace(' ', '').replace('¥', '¥')) | ||
| 99 | return results, rectime | ||
| 100 | |||
| 101 | |||
| 102 | def resize_img(im, width): | ||
| 103 | ori_h, ori_w = im.shape | ||
| 104 | ratio1 = width * 1.0 / ori_w | ||
| 105 | ratio2 = height * 1.0 / ori_h | ||
| 106 | if ratio1 < ratio2: | ||
| 107 | ratio = ratio1 | ||
| 108 | else: | ||
| 109 | ratio = ratio2 | ||
| 110 | new_w, new_h = int(ori_w * ratio), int(ori_h * ratio) | ||
| 111 | im = cv2.resize(im, (new_w, new_h)) | ||
| 112 | delta_w = width - new_w | ||
| 113 | delta_h = height - new_h | ||
| 114 | top = delta_h // 2 | ||
| 115 | bottom = delta_h - top | ||
| 116 | left = delta_w // 2 | ||
| 117 | right = delta_w - left | ||
| 118 | img = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=255) | ||
| 119 | img = img / 255.0 | ||
| 120 | img = (img - 0.5) / 0.5 | ||
| 121 | X = img.reshape((height, width, 1)) | ||
| 122 | return X | ||
| 123 | |||
| 124 | |||
| 125 | def decode(preds): | ||
| 126 | labels = [] | ||
| 127 | charactersS = characters + u' ' | ||
| 128 | tops = preds.argmax(axis=2) | ||
| 129 | for t in tops: | ||
| 130 | length = len(t) | ||
| 131 | char_list = [] | ||
| 132 | for i in range(length): | ||
| 133 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): | ||
| 134 | char_list.append(charactersS[t[i] - 1]) | ||
| 135 | labels.append(u''.join(char_list)) | ||
| 136 | return labels |
demo.py
0 → 100644
| 1 | # -*- coding: utf-8 -*- | ||
| 2 | ### | ||
| 3 | import cv2 | ||
| 4 | import numpy as np | ||
| 5 | from SDTR import sdtr | ||
| 6 | |||
| 7 | |||
| 8 | if __name__ == '__main__': | ||
| 9 | test_img_path = './rec_test.png' | ||
| 10 | test_img = cv2.imread(test_img_path) | ||
| 11 | test_gray = cv2.cvtColor(test_img, cv2.COLOR_BGR2GRAY) | ||
| 12 | h, w = test_gray.shape | ||
| 13 | box = [np.array([[0, 0], [w, 0], [w, h], [0, h]])] | ||
| 14 | all_time = 0 | ||
| 15 | rangetimes = 1001 | ||
| 16 | for i in range(rangetimes): | ||
| 17 | results, rectime = sdtr.predict(test_gray, box) | ||
| 18 | print('{:.5f}ms'.format(rectime)) | ||
| 19 | print(results) | ||
| 20 | if i != 0: | ||
| 21 | all_time += rectime | ||
| 22 | print('avgtime:{:.5f}ms'.format(all_time / (rangetimes - 1))) |
readme.txt
0 → 100644
| 1 | SDTR_v1.0: | ||
| 2 | 功能:对文字切片进行识别, 返回识别结果 | ||
| 3 | 使用方法: | ||
| 4 | from SDTR import sdtr | ||
| 5 | sdtr.predict(im, boxes) | ||
| 6 | 输入参数: | ||
| 7 | im:opencv下的灰度图 | ||
| 8 | boxes:numpy矩阵,大小为N*4*2, N为box个数,每个box包含四个坐标(x,y),且坐标顺序必须为:左上,右上,右下,左下 | ||
| 9 | 输出: | ||
| 10 | 一个字典,包含key:[numpy矩阵box,string识别结果] | ||
| 11 | eg. {0: [array([ 0, 0, 626, 0, 626, 87, 0, 87]), '陆万壹仟叁佰圆整']} | ||
| 12 | 环境: | ||
| 13 | tensorflow,最好是1.14版本,未在其他版本上测试 | ||
| 14 | |||
| 15 | |||
| 16 | v1.0 2019.09.19 | ||
| 17 | 为CRNN基础 | ||
| 18 | GPU版本用了CuDNNLSTM, 相比普通LSTM能减少1/2到2/3的时间 | ||
| 19 | |||
| 20 | 性能: | ||
| 21 | 购车发票上平均全对率为93.25 | ||
| 22 | |||
| 23 | 对box缩放到32的高度,一张32*230的图 | ||
| 24 | GPU时间:15ms | ||
| 25 | CPU时间:210ms | ||
| 26 | |||
| 27 | 具体性能和box数量和box缩放到32高度时的宽度有关 |
rec_test.png
0 → 100644
65.3 KB
-
Please register or sign in to post a comment