"""Bank-statement OCR client and per-bank field extraction.

The OCR service result is consumed here as a mapping whose values are
``[box, text]`` pairs: ``box`` is indexed 0..7 (eight numbers) and ``text``
is a string.  The code treats ``box[2], box[3]`` as the top-right corner;
the remaining indices are presumably the other corners in clockwise order
starting at the top-left -- TODO confirm against the OCR service.

Every ``extract_*`` function returns a list of ``[text, box]`` pairs (note
the swapped order compared with the OCR input).
"""
import base64
import os
import time

import numpy as np

# OpenCV, requests and tqdm are only needed by the OCR client and the demo
# script at the bottom.  The extract_* helpers are pure Python, so a missing
# package must not prevent importing (and unit-testing) them.
try:
    import cv2
except ImportError:
    cv2 = None
try:
    import requests
except ImportError:
    requests = None
try:
    import tqdm
except ImportError:
    tqdm = None

# Default endpoint of the statement OCR service.
OCR_URL = r'http://192.168.10.11:9001/gen_ocr'

# Upper bound (exclusive) on the horizontal distance between a label and the
# box that is accepted as its value.
_MAX_DISTANCE = 999999


def _require(module, package):
    """Return *module*, or raise ImportError if the optional package is absent."""
    if module is None:
        raise ImportError('{} is required for this operation but is not installed'.format(package))
    return module


def image_to_base64(image):
    """Encode *image* as PNG and return the encoded buffer.

    Despite the name, no base64 encoding takes place: the return value is the
    second element of ``cv2.imencode('.png', image)``.
    """
    return _require(cv2, 'opencv-python').imencode('.png', image)[1]


def path_to_file(file_path):
    """Open *file_path* for binary reading; the caller must close the handle."""
    return open(file_path, 'rb')


# Statement OCR endpoint
def bill_ocr(image, url=OCR_URL, timeout=None):
    """Send *image* to the OCR service and return its ``ocr_results`` payload.

    :param image: image accepted by ``cv2.imencode``.
    :param url: OCR endpoint; defaults to the address the module always used.
    :param timeout: forwarded to ``requests.post``; ``None`` (the default, and
        the previous behaviour) waits indefinitely.
    :raises requests.HTTPError: when the service answers with an error status.
    """
    payload = image_to_base64(image)
    resp = _require(requests, 'requests').post(url=url, files={'file': payload}, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of tripping over the response body below.
    resp.raise_for_status()
    return resp.json()['ocr_results']


def _first_matching_prefix(text, prefixes):
    """Return the first prefix (in list order) contained in *text*, else None."""
    for prefix in prefixes:
        if prefix in text:
            return prefix
    return None


def _on_same_row(label, other):
    """Tell whether *other* starts on the text row of *label*.

    True when ``other``'s y at index 1 is closer to the label's top-right y
    than half of the label's height (|box[3] - box[5]|).
    """
    label_box = label[0]
    return abs(other[0][1] - label_box[3]) < abs(label_box[3] - label_box[5]) / 2


def _merge(label, neighbour):
    """Join a label with its value box into one ``[text, box]`` result.

    The merged box keeps the label's points at indices 0-1 and 6-7 and takes
    indices 2-5 from the neighbour.
    """
    label_box, neighbour_box = label[0], neighbour[0]
    position = [label_box[0], label_box[1],
                neighbour_box[2], neighbour_box[3],
                neighbour_box[4], neighbour_box[5],
                label_box[6], label_box[7]]
    return [label[1] + neighbour[1], position]


def _nearest_on_row(label, ocr_results):
    """Return the same-row box whose x (index 0) is closest to the label's top-right x.

    Boxes whose text equals the label text are ignored.  On equal distances
    the first box in iteration order wins.  Returns None when nothing on the
    row qualifies.
    """
    best, best_distance = None, _MAX_DISTANCE
    for other in ocr_results.values():
        if other[1] == label[1] or not _on_same_row(label, other):
            continue
        distance = abs(other[0][0] - label[0][2])
        if distance < best_distance:
            best, best_distance = other, distance
    return best


def _row_neighbours_by_distance(label, ocr_results):
    """Return the same-row boxes ordered by horizontal distance to the label.

    Boxes are keyed by distance, so of several boxes at the same distance only
    the last one in iteration order is kept.
    """
    by_distance = {}
    for other in ocr_results.values():
        if other[1] != label[1] and _on_same_row(label, other):
            by_distance[abs(other[0][0] - label[0][2])] = other
    return [by_distance[distance] for distance in sorted(by_distance)]


def _resolve_field(entry, prefix, ocr_results):
    """Turn an OCR entry that contains *prefix* into a ``[text, box]`` result.

    * The text is longer than the prefix: the entry is returned as is.
    * The text is exactly the prefix: it is merged with the nearest box on
      its row.  If there is no such box, None is returned (this used to raise
      IndexError).
    """
    if entry[1] != prefix:
        return [entry[1], entry[0]]
    neighbour = _nearest_on_row(entry, ocr_results)
    if neighbour is None:
        return None
    return _merge(entry, neighbour)


def _extract_labelled_fields(ocr_results, name_prefixes, account_prefixes):
    """Collect name and account fields using the given label prefixes.

    Each OCR entry is checked first against the name prefixes and then,
    independently, against the account prefixes; within a group only the
    first matching prefix is used.
    """
    results = []
    for entry in ocr_results.values():
        for prefixes in (name_prefixes, account_prefixes):
            prefix = _first_matching_prefix(entry[1], prefixes)
            if prefix is None:
                continue
            field = _resolve_field(entry, prefix, ocr_results)
            if field is not None:
                results.append(field)
    return results


# Extract China Minsheng Bank fields
def extract_minsheng_info(ocr_results):
    """Return the name/account fields labelled '客户姓名:' and '客户账号:'."""
    return _extract_labelled_fields(ocr_results, ['客户姓名:'], ['客户账号:'])


# Extract ICBC fields
def extract_gongshang_info(ocr_results):
    """Return the name/account fields labelled '户名:' and '卡号:'."""
    return _extract_labelled_fields(ocr_results, ['户名:'], ['卡号:'])


# Extract Bank of China fields
def extract_zhongguo_info(ocr_results):
    """Return the name/account fields labelled '客户姓名:' and '借记卡号:'."""
    return _extract_labelled_fields(ocr_results, ['客户姓名:'], ['借记卡号:'])


# Extract China Construction Bank fields
def extract_jianshe_info(ocr_results):
    """Return the name/account fields for the two label variants of this bank."""
    return _extract_labelled_fields(ocr_results, ['客户名称:', '户名:'], ['卡号/账号:', '卡号:'])


def _nongye_combined_row(value, ocr_results):
    """Resolve an entry whose text carries both a name prefix and '账号'.

    Returns a list with zero or one ``[text, box]`` result:

    * 25 characters or more: the entry itself.
    * exactly 5 characters: if the nearest row neighbour is all digits, the
      second-nearest neighbour is merged; if it is not all digits and longer
      than 19 characters, the nearest one is merged; otherwise nothing.
    * any other length: merged with the nearest row neighbour.

    Missing neighbours no longer raise IndexError: the five-character case
    yields nothing and the other case falls back to the entry itself.
    """
    text = value[1]
    if len(text) >= 25:
        return [[text, value[0]]]

    neighbours = _row_neighbours_by_distance(value, ocr_results)
    if len(text) != 5:
        if not neighbours:
            return [[text, value[0]]]
        return [_merge(value, neighbours[0])]

    if not neighbours:
        return []
    nearest_text = neighbours[0][1]
    if nearest_text.isdigit():
        # The nearest box is taken for the account number, so the value to
        # attach is the next box along the row.
        if len(neighbours) < 2:
            return []
        return [_merge(value, neighbours[1])]
    if len(nearest_text) > 19:
        return [_merge(value, neighbours[0])]
    return []


# Extract Agricultural Bank of China fields (fairly complex; every layout
# seen in training so far is supported)
def extract_nongye_info(ocr_results):
    """Return the name/account fields of an Agricultural Bank statement.

    Besides the plain '客户名:' / '户名:' and '账号:' labels this handles rows
    where one text box carries both a name prefix and '账号'.

    NOTE(review): the indentation of the original function was lost.  Two
    points were reconstructed and should be confirmed: (1) scanning stops
    entirely once a combined name/account row has been handled, and (2) the
    five-character branch always leaves the prefix loop, whether or not it
    produced a result.
    """
    name_prefixes = ['客户名:', '户名:']
    account_prefixes = ['账号:']
    account_keyword = account_prefixes[0][:-1]
    results = []
    for value in ocr_results.values():
        text = value[1]
        combined_row_handled = False
        for name_prefix in name_prefixes:
            if name_prefix not in text:
                continue
            if account_keyword not in text:
                field = _resolve_field(value, name_prefix, ocr_results)
                if field is not None:
                    results.append(field)
                break
            if len(text.split(":")[0]) <= 5:
                results.extend(_nongye_combined_row(value, ocr_results))
                combined_row_handled = True
                break
        if combined_row_handled:
            break
        account_prefix = _first_matching_prefix(text, account_prefixes)
        if account_prefix is not None:
            field = _resolve_field(value, account_prefix, ocr_results)
            if field is not None:
                results.append(field)
    return results


def extract_youchu_info(ocr_results):
    """Return the name/account fields labelled '户名:' and '账号:' / '卡号:'."""
    return _extract_labelled_fields(ocr_results, ['户名:'], ['账号:', '卡号:'])


# Keyword -> extractor, in the priority order applied to every OCR text.
_BANK_EXTRACTORS = (
    ('建设', extract_jianshe_info),
    ('民生', extract_minsheng_info),
    ('农业', extract_nongye_info),
    ('中国银行', extract_zhongguo_info),
    ('中国邮政储蓄', extract_youchu_info),
)


def _select_extractor(ocr_results):
    """Return the extractor for the first OCR text naming a known bank, else None."""
    for value in ocr_results.values():
        for keyword, extractor in _BANK_EXTRACTORS:
            if keyword in value[1]:
                return extractor
    return None


# Entry point for bank-statement field extraction
def extract_bank_info(ocr_results):
    """Detect the issuing bank from the OCR texts and extract its fields.

    The first text that mentions a known bank decides which extractor runs
    (the postal-savings branch used to keep scanning, so a later text naming
    another bank could overwrite its result).  When no bank is recognised, or
    the chosen extractor finds nothing, the ICBC extractor is used.
    """
    extractor = _select_extractor(ocr_results)
    results = extractor(ocr_results) if extractor is not None else []
    if not results:
        results = extract_gongshang_info(ocr_results)
    return results


if __name__ == '__main__':
    img = _require(cv2, 'opencv-python').imread('/home/situ/下载/邮储对账单/飞书20221020-155202.jpg')
    ocr_results = bill_ocr(img)
    results = extract_youchu_info(ocr_results)
    print(results)

    # Batch evaluation over a folder, kept for reference:
    # path = '/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/minsheng/authentic/images/val'
    # save_path='/data/situ_invoice_bill_data/new_data/results'
    # bank='minsheng'
    # if not os.path.exists(os.path.join(save_path,bank)):
    #     os.makedirs(os.path.join(save_path,bank))
    # save_path=os.path.join(save_path,bank)
    # for j in tqdm.tqdm(os.listdir(path)):
    #     # if True:
    #     img=cv2.imread(os.path.join(path,j))
    #     # img = cv2.imread('/data/situ_invoice_bill_data/new_data/results/nongye/6/_1597382769.6449914page_23_img_0.jpg')
    #     st = time.time()
    #     ocr_result = bill_ocr(img)
    #     et1 = time.time()
    #     result = extract_bank_info(ocr_result)
    #     et2 = time.time()
    #     for i in range(len(result)):
    #         cv2.rectangle(img, (result[i][1][0], result[i][1][1]), (result[i][1][4], result[i][1][5]), (0, 0, 255), 2)
    #     # cv2.imshow('img',img)
    #     # cv2.waitKey(0)
    #     cv2.imwrite(os.path.join(save_path,j),img)
    #     print('spend:{} ocr:{} extract:{}'.format(et2 - st, et1 - st, et2 - et1))