# -*- coding: utf-8 -*-
###
import os, sys
import cv2
from PIL import Image
import time
import numpy as np
# The graph loading below uses the TF1-style session API (ConfigProto, Session,
# GraphDef, placeholder); the compat.v1 shim keeps it working under TensorFlow 2.x.
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.keras import backend as K
from SDTR import alphabets
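
# SDTR recognition stage: given a page image and quadrilateral text boxes, each
# box is perspective-warped to a 32-px-high strip, batched through a frozen
# CRNN graph, and the CTC output is greedily decoded into a string.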

characters = alphabets.alphabet[:]
height = 32  # fixed input height expected by the CRNN
batchsize = 16
nclass = len(characters) + 1  # +1 for the CTC blank class


def init():
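    """Load the frozen CRNN graph and create the global inference session.

    Picks the GPU or CPU .pb file depending on device availability and exposes
    `inputs`, `outs` and `sess` as module-level globals used by predict().
    """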
    global inputs, outs, sess
    if tf.test.gpu_device_name():
        modelPath = './SDTR/crnn_bl2_gpu.pb'
        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 0.3
        config.gpu_options.allow_growth = True
    else:
        modelPath = './SDTR/crnn_bl2.pb'
        config = tf.ConfigProto()
    print(modelPath)
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(modelPath, 'rb') as f:
            graph_def.ParseFromString(f.read())

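        # Feed a float placeholder of shape (N, 32, W, 1) into the frozen
        # model's 'the_input' tensor and expose the recognition head
        # ('embedding2/Reshape_1', per-timestep class scores) as the output.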
        inputs = tf.placeholder(tf.float32, [None, 32, None, 1], name='X')
        outs = tf.import_graph_def(
            graph_def,
            input_map={'the_input:0': inputs},
            return_elements=['embedding2/Reshape_1:0'])
    sess = tf.Session(graph=graph, config=config)


init()


def predict(im, boxes):
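    """Recognize the text inside each quadrilateral box of image `im`.

    `boxes` is an array of shape (N, 4, 2) with corners ordered top-left,
    top-right, bottom-right, bottom-left. Returns (results, rectime):
    `results` maps the index of each processed box to [flattened corner
    coordinates, recognized string]; `rectime` is the model inference time
    in milliseconds.
    """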
    # global inputs, outs, sess
    count_boxes = len(boxes)
    boxes_max = sorted(
        boxes,
        key=lambda box: int(32.0 * np.linalg.norm(box[0] - box[1]) / np.linalg.norm(box[3] - box[0])),
        reverse=True)

    # Pad with copies of the last box so the total count is a multiple of batchsize.
    if len(boxes) % batchsize != 0:
        add_box = np.expand_dims(boxes[-1], axis=0)
        extend_num = batchsize - len(boxes) % batchsize
        for i in range(extend_num):
            boxes = np.concatenate((boxes, add_box), axis=0)

    results = {}
    labels = []
    rectime = 0.0

    if len(boxes) != 0:
        for i in range(int(len(boxes) / batchsize)):
            slices = []
            # Target strip width for this batch, derived from the width-sorted
            # list; crops are scaled/padded to (32, width) in resize_img.
            box = boxes_max[i * batchsize]
            w, h = [int(np.linalg.norm(box[0] - box[1])), int(np.linalg.norm(box[3] - box[0]))]
            width = int(32.0 * w / h)
            # print(width)
            # Skip batches whose target width is under 24 px (too narrow to read).
            if width < 24:
                continue
            for index, box in enumerate(boxes[i * batchsize:(i + 1) * batchsize]):
                _box = [n for a in box for n in a]
                if i * batchsize + index < count_boxes:
                    # Record the flattened corner coordinates for real (non-padded) boxes.
                    results[i * batchsize + index] = [np.array(_box)]
                w, h = [int(np.linalg.norm(box[0] - box[1])), int(np.linalg.norm(box[3] - box[0]))]
                # print(w)
                # Perspective-warp the quadrilateral to an axis-aligned w x h crop,
                # then resize/pad it to the batch strip size (32, width).
                pts1 = np.float32(box)
                pts2 = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
                M = cv2.getPerspectiveTransform(pts1, pts2)
                im_crop = cv2.warpPerspective(im, M, (w, h))
                im_crop = resize_img(im_crop, width)
                slices.append(im_crop)
            slices = np.array(slices)
            # print(slices.shape)
            recstart = time.time()
            preds = sess.run(outs, feed_dict={inputs: slices})
            # preds=model.predict(slices)
            recend = time.time()
            preds = preds[0]  # outs is a single-element list of output tensors
            # print(preds)
            rectime += (recend - recstart) * 1000
            # preds=preds[:,2:,:]
            rec_labels = decode(preds)
            labels.extend(rec_labels)
        for index, label in enumerate(labels[:count_boxes]):
            # Attach the recognized string (spaces stripped) to its box entry.
            results[index].append(label.replace(' ', '').replace('¥', '¥'))
    return results, rectime


def resize_img(im, width):
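    """Resize a crop to fit inside (height, width), pad with white to that
    size, and return a normalized array of shape (height, width, 1)."""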
    # Expect a single-channel crop; convert defensively if a colour crop is
    # passed so the (height, width, 1) reshape below works.
    if im.ndim == 3:
        im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    ori_h, ori_w = im.shape
    # Scale to fit inside (height, width) while preserving the aspect ratio.
    ratio = min(width * 1.0 / ori_w, height * 1.0 / ori_h)
    new_w, new_h = int(ori_w * ratio), int(ori_h * ratio)
    im = cv2.resize(im, (new_w, new_h))
    delta_w = width - new_w
    delta_h = height - new_h
    top = delta_h // 2
    bottom = delta_h - top
    left = delta_w // 2
    right = delta_w - left
    # Pad to the full (height, width) canvas with white; pixel values are then
    # scaled to [-1, 1] as the network expects.
    img = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=255)
    img = img / 255.0
    img = (img - 0.5) / 0.5
    X = img.reshape((height, width, 1))
    return X


def decode(preds):
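    """Greedily decode CTC output `preds` of shape (N, T, nclass) into strings."""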
    labels = []
    charactersS = characters + u' '
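    # Greedy CTC decoding: take the argmax class at each timestep, drop blanks
    # (class 0 here) and collapse consecutive repeats of the same class.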
    tops = preds.argmax(axis=2)
    for t in tops:
        length = len(t)
        char_list = []
        for i in range(length):
            if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                char_list.append(charactersS[t[i] - 1])
        labels.append(u''.join(char_list))
    return labels
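

# Minimal usage sketch, assuming a grayscale page image on disk and detector
# output in the (N, 4, 2) corner format described above; 'demo.jpg' and the
# box coordinates are placeholders, not part of the original module.
if __name__ == '__main__':
    page = cv2.imread('demo.jpg', cv2.IMREAD_GRAYSCALE)
    quads = np.array([[[10, 10], [200, 10], [200, 42], [10, 42]]], dtype=np.float32)
    results, rectime = predict(page, quads)
    for idx, (coords, text) in results.items():
        print(idx, text)
    print('recognition time: %.1f ms' % rectime)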