bank_ocr_inference.py 16.8 KB
import base64
import os
import time

import cv2
import numpy as np
import requests
import tqdm


def image_to_base64(image):
    """Encode *image* as PNG and return the encoded buffer.

    Despite its name this does NOT base64-encode anything: it returns the
    buffer produced by ``cv2.imencode('.png', ...)``. ``bill_ocr`` uploads
    that buffer directly as a multipart file. The name is kept for backward
    compatibility.

    Raises:
        ValueError: if OpenCV reports that the PNG encoding failed.
    """
    # cv2.imencode returns (success_flag, buffer); the flag must be checked,
    # otherwise a failed encode is silently sent to the OCR service.
    ok, buffer = cv2.imencode('.png', image)
    if not ok:
        raise ValueError('cv2.imencode failed to encode image as PNG')
    return buffer


def path_to_file(file_path):
    """Open *file_path* for binary reading and return the file object.

    The caller owns the returned handle and is responsible for closing it,
    e.g. ``with path_to_file(p) as fh: ...``.
    """
    return open(file_path, 'rb')


# Bank-statement OCR service client.
def bill_ocr(image, url=r'http://192.168.10.11:9001/gen_ocr', timeout=60):
    """Send *image* to the OCR service and return its ``ocr_results`` field.

    Args:
        image: image passed to ``image_to_base64`` (PNG-encoded with OpenCV).
        url: OCR endpoint; defaults to the originally hard-coded service.
        timeout: seconds to wait for the service. ``requests`` has no
            timeout by default, so without one a hung server would block
            the caller forever.

    Returns:
        The ``ocr_results`` value of the JSON response. The extractors in
        this module iterate its ``.values()`` and index each entry as
        ``[box, text]`` -- presumably a dict; confirm against the service.

    Raises:
        requests.HTTPError: if the service answers with an error status.
        KeyError: if the JSON response has no ``ocr_results`` key.
    """
    payload = image_to_base64(image)
    resp = requests.post(url=url, files={'file': payload}, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of trying to parse an error page as JSON.
    resp.raise_for_status()
    return resp.json()['ocr_results']


# 提取民生银行信息
def extract_minsheng_info(ocr_results):

    name_prefix = '客户姓名:'
    account_prefix = '客户账号:'
    results = []
    for value in ocr_results.values():
        if name_prefix in value[1]:
            if name_prefix == value[1]:
                tmp_value, max_dis = [], 999999
                top_right_x = value[0][2]
                top_right_y = value[0][3]
                for tmp in ocr_results.values():
                    if tmp[1] != name_prefix:
                        if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                tmp[0][0] - top_right_x) < max_dis:
                            tmp_value = tmp
                            max_dis = abs(tmp[0][0] - top_right_x)
                    else:
                        continue
                new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                tmp_value[0][5],
                                value[0][6], value[0][7]]
                results.append([value[1] + tmp_value[1], new_position])
            else:
                results.append([value[1], value[0]])
        if account_prefix in value[1]:
            if account_prefix == value[1]:
                tmp_value, max_dis = [], 999999
                top_right_x = value[0][2]
                top_right_y = value[0][3]
                for tmp in ocr_results.values():
                    if tmp[1] != account_prefix:
                        if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                tmp[0][0] - top_right_x) < max_dis:
                            tmp_value = tmp
                            max_dis = abs(tmp[0][0] - top_right_x)
                    else:
                        continue
                new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                tmp_value[0][5],
                                value[0][6], value[0][7]]
                results.append([value[1] + tmp_value[1], new_position])
            else:
                results.append([value[1], value[0]])
    return results

# 提取工商银行信息
def extract_gongshang_info(ocr_results):
    name_prefix = '户名:'
    account_prefix = '卡号:'
    results = []
    for value in ocr_results.values():
        if name_prefix in value[1]:
            if name_prefix == value[1]:
                tmp_value, max_dis = [], 999999
                top_right_x = value[0][2]
                top_right_y = value[0][3]
                for tmp in ocr_results.values():
                    if tmp[1] != name_prefix:
                        if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                tmp[0][0] - top_right_x) < max_dis:
                            tmp_value = tmp
                            max_dis = abs(tmp[0][0] - top_right_x)
                    else:
                        continue
                new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                tmp_value[0][5],
                                value[0][6], value[0][7]]
                results.append([value[1] + tmp_value[1], new_position])
            else:
                results.append([value[1], value[0]])
        if account_prefix in value[1]:
            if account_prefix == value[1]:
                tmp_value, max_dis = [], 999999
                top_right_x = value[0][2]
                top_right_y = value[0][3]
                for tmp in ocr_results.values():
                    if tmp[1] != account_prefix:
                        if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                tmp[0][0] - top_right_x) < max_dis:
                            tmp_value = tmp
                            max_dis = abs(tmp[0][0] - top_right_x)
                    else:
                        continue
                new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                tmp_value[0][5],
                                value[0][6], value[0][7]]
                results.append([value[1] + tmp_value[1], new_position])
            else:
                results.append([value[1], value[0]])
    return results

# 提取中国银行信息
def extract_zhongguo_info(ocr_results):
    name_prefix = '客户姓名:'
    account_prefix = '借记卡号:'
    results = []
    for value in ocr_results.values():
        if name_prefix in value[1]:
            if name_prefix == value[1]:
                tmp_value, max_dis = [], 999999
                top_right_x = value[0][2]
                top_right_y = value[0][3]
                for tmp in ocr_results.values():
                    if tmp[1] != name_prefix:
                        if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                tmp[0][0] - top_right_x) < max_dis:
                            tmp_value = tmp
                            max_dis = abs(tmp[0][0] - top_right_x)
                    else:
                        continue
                new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                tmp_value[0][5],
                                value[0][6], value[0][7]]
                results.append([value[1] + tmp_value[1], new_position])
            else:
                results.append([value[1], value[0]])
        if account_prefix in value[1]:
            if account_prefix == value[1]:
                tmp_value, max_dis = [], 999999
                top_right_x = value[0][2]
                top_right_y = value[0][3]
                for tmp in ocr_results.values():
                    if tmp[1] != account_prefix:
                        if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                tmp[0][0] - top_right_x) < max_dis:
                            tmp_value = tmp
                            max_dis = abs(tmp[0][0] - top_right_x)
                    else:
                        continue
                new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                tmp_value[0][5],
                                value[0][6], value[0][7]]
                results.append([value[1] + tmp_value[1], new_position])
            else:
                results.append([value[1], value[0]])
    return results

# 提取建设银行信息
def extract_jianshe_info(ocr_results):
    name_prefixes = ['客户名称:', '户名:']
    account_prefixes = ['卡号/账号:', '卡号:']
    results = []
    for value in ocr_results.values():
        for name_prefix in name_prefixes:
            if name_prefix in value[1]:
                if name_prefix == value[1]:
                    tmp_value, max_dis = [], 999999
                    top_right_x = value[0][2]
                    top_right_y = value[0][3]
                    for tmp in ocr_results.values():
                        if tmp[1] != name_prefix:
                            if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                    tmp[0][0] - top_right_x) < max_dis:
                                tmp_value = tmp
                                max_dis = abs(tmp[0][0] - top_right_x)
                        else:
                            continue
                    new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                    tmp_value[0][5],
                                    value[0][6], value[0][7]]
                    results.append([value[1] + tmp_value[1], new_position])
                    break
                else:
                    results.append([value[1], value[0]])
                    break
        for account_prefix in account_prefixes:
            if account_prefix in value[1]:
                if account_prefix == value[1]:
                    tmp_value, max_dis = [], 999999
                    top_right_x = value[0][2]
                    top_right_y = value[0][3]
                    for tmp in ocr_results.values():
                        if tmp[1] != account_prefix:
                            if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
                                    tmp[0][0] - top_right_x) < max_dis:
                                tmp_value = tmp
                                max_dis = abs(tmp[0][0] - top_right_x)
                        else:
                            continue
                    new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
                                    tmp_value[0][5],
                                    value[0][6], value[0][7]]
                    results.append([value[1] + tmp_value[1], new_position])
                    break
                else:
                    results.append([value[1], value[0]])
                    break
    return results

# Extract Agricultural Bank of China statement fields (the most complex
# layout; supports every layout trained so far).
def extract_nongye_info(ocr_results):
    """Extract name/account entries from an Agricultural Bank of China statement.

    Entries are scanned in ``ocr_results`` iteration order and three kinds of
    label are recognised:

    * name labels ('客户名:' / '户名:') whose text does not contain '账号':
      a bare label is merged with its nearest right-hand neighbour, a label
      that already carries text is kept as-is;
    * combined labels containing a name prefix AND '账号', with at most five
      characters before the first ':'. Once such an entry has been
      processed, the scan stops (no further entries, including that entry's
      own account check, are examined);
    * account labels ('账号:'), handled like name labels.

    Args:
        ocr_results: mapping whose values are ``[box, text]`` entries; the
            code reads box[2], box[3] as a label's top-right corner
            (presumably 8 coordinates clockwise from top-left -- confirm).

    Returns:
        list of ``[text, box]`` pairs. When a label has no suitable
        neighbour to merge with, the label itself is returned; the previous
        implementation raised IndexError in that situation.
    """
    name_prefixes = ['客户名:', '户名:']
    account_prefixes = ['账号:']
    account_word = account_prefixes[0][:-1]  # '账号' without the colon
    results = []
    is_account = True

    def same_row(cand, box):
        # Candidate's top-left y within half the label's right-edge height
        # of the label's top-right y.
        return abs(cand[0][1] - box[3]) < abs(box[3] - box[5]) / 2

    def merge(label, neighbor):
        # Label text + neighbour text; the box keeps the label's left corners
        # and takes the neighbour's right corners.
        box, nb = label[0], neighbor[0]
        return [label[1] + neighbor[1],
                [box[0], box[1], nb[2], nb[3], nb[4], nb[5], box[6], box[7]]]

    def merge_or_label(label, neighbor):
        # Fall back to the bare label when there is nothing to merge with.
        if neighbor is None:
            return [label[1], label[0]]
        return merge(label, neighbor)

    def nearest(label, exclude_text):
        # Closest same-row candidate by |cand.x0 - label.top_right_x|; the
        # first one seen wins ties, distances >= 999999 are never chosen.
        best, best_dis = None, 999999
        for cand in ocr_results.values():
            if cand[1] == exclude_text or not same_row(cand, label[0]):
                continue
            dis = abs(cand[0][0] - label[0][2])
            if dis < best_dis:
                best, best_dis = cand, dis
        return best

    def ranked_neighbors(label):
        # Same-row candidates sorted by horizontal distance. Keyed by
        # distance, so of two candidates at the same distance only the one
        # seen last is kept (same as the original dict-based lookup).
        by_dis = {}
        for cand in ocr_results.values():
            if cand[1] != label[1] and same_row(cand, label[0]):
                by_dis[abs(cand[0][0] - label[0][2])] = cand
        return [by_dis[d] for d in sorted(by_dis)]

    for value in ocr_results.values():
        text, box = value[1], value[0]
        for name_prefix in name_prefixes:
            if name_prefix in text and account_word not in text:
                if name_prefix == text:
                    results.append(merge_or_label(value, nearest(value, name_prefix)))
                else:
                    results.append([text, box])
                break
            if (name_prefix in text and account_word in text
                    and len(text.split(":")[0]) <= 5):
                # Combined name+account label: the scan ends after this entry.
                is_account = False
                if len(text) == 5:
                    ranked = ranked_neighbors(value)
                    if not ranked:
                        results.append([text, box])
                        break
                    first = ranked[0]
                    # NOTE(review): presumably a long non-numeric token is the
                    # name and number OCR'd together -- confirm with samples.
                    if not first[1].isdigit() and len(first[1]) > 19:
                        results.append(merge(value, first))
                    # Nearest token is all digits: merge with the next one.
                    if first[1].isdigit():
                        second = ranked[1] if len(ranked) > 1 else None
                        results.append(merge_or_label(value, second))
                        break
                elif len(text) < 25:
                    ranked = ranked_neighbors(value)
                    results.append(merge_or_label(value, ranked[0] if ranked else None))
                    break
                else:
                    results.append([text, box])
                    break
        if not is_account:
            break
        for account_prefix in account_prefixes:
            if account_prefix in text:
                if account_prefix == text:
                    results.append(merge_or_label(value, nearest(value, account_prefix)))
                else:
                    results.append([text, box])
                break
    return results

# Top-level entry point: detect the bank and extract statement fields.
def extract_bank_info(ocr_results):
    """Choose a bank-specific extractor from the OCR text and run it.

    The first OCR entry (in iteration order) whose text contains a bank
    keyword selects the extractor. For each entry the keywords are tried in
    this order: '建设' (China Construction Bank), '民生' (Minsheng),
    '农业' (Agricultural Bank) and '中国银行' (Bank of China). If no
    keyword is found, or the chosen extractor returns nothing, the ICBC
    extractor (``extract_gongshang_info``) is used as a fallback.

    Args:
        ocr_results: mapping whose values are ``[box, text]`` entries.

    Returns:
        list of ``[text, box]`` pairs produced by the selected extractor.
    """
    # Ordered dispatch table: order matters because one text could contain
    # several keywords.
    detectors = (
        ('建设', extract_jianshe_info),
        ('民生', extract_minsheng_info),
        ('农业', extract_nongye_info),
        ('中国银行', extract_zhongguo_info),
    )
    results = []
    for value in ocr_results.values():
        extractor = next((fn for keyword, fn in detectors if keyword in value[1]), None)
        if extractor is not None:
            results = extractor(ocr_results)
            break
    if not results:
        results = extract_gongshang_info(ocr_results)
    return results


if __name__ == '__main__':
    # Batch run: OCR + field extraction over one bank's validation images,
    # saving copies with every extracted field box drawn in (0, 0, 255),
    # i.e. red in OpenCV's BGR channel order.
    path = '/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/minsheng/authentic/images/val'
    save_path = '/data/situ_invoice_bill_data/new_data/results'
    bank = 'minsheng'
    save_path = os.path.join(save_path, bank)
    os.makedirs(save_path, exist_ok=True)
    for j in tqdm.tqdm(os.listdir(path)):
        img = cv2.imread(os.path.join(path, j))
        if img is None:
            # cv2.imread returns None for unreadable / non-image files.
            print('skip unreadable file: {}'.format(j))
            continue
        st = time.time()
        ocr_result = bill_ocr(img)
        et1 = time.time()
        result = extract_bank_info(ocr_result)
        et2 = time.time()
        for _, box in result:
            # Opposite corners: box[0:2] and box[4:6]; cv2.rectangle needs ints.
            cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[4]), int(box[5])), (0, 0, 255), 2)
        cv2.imwrite(os.path.join(save_path, j), img)
        print('spend:{}  ocr:{} extract:{}'.format(et2 - st, et1 - st, et2 - et1))
#