64510cdb by 乔峰昇

update readme

1 parent 5e7dd86a
## 五大银行OCR+关键字段信息提取
python bank_ocr_inference.py
其中，函数 `extract_bank_info()` 是流水信息提取的总入口，其参数为 `bill_ocr()` 返回的全部 OCR 识别结果 results。
## yolov5推理
python inference.py
## OCR+yolov5整体Pipeline
python pipeline.py
\ No newline at end of file
......@@ -18,15 +18,18 @@ def path_to_file(file_path):
return f
# 流水OCR接口
def bill_ocr(image):
    """Send an image to the bill OCR service and return its recognition results.

    The image is encoded with ``image_to_base64`` and uploaded as the ``file``
    multipart field of a POST request to the ``gen_ocr`` endpoint.

    Args:
        image: Image to recognise; passed unchanged to ``image_to_base64``.

    Returns:
        The value stored under the ``'ocr_results'`` key of the JSON response.

    Raises:
        KeyError: If the JSON response has no ``'ocr_results'`` key.
    """
    f = image_to_base64(image)
    # A single request is issued. The earlier code also posted to
    # 139.196.149.46 first, but that response was overwritten and never used.
    resp = requests.post(url=r'http://192.168.10.11:9001/gen_ocr', files={'file': f})
    results = resp.json()
    return results['ocr_results']
# 提取民生银行信息
def extract_minsheng_info(ocr_results):
name_prefix = '客户姓名:'
account_prefix = '客户账号:'
results = []
......@@ -71,7 +74,7 @@ def extract_minsheng_info(ocr_results):
results.append([value[1], value[0]])
return results
# 提取工商银行信息
def extract_gongshang_info(ocr_results):
name_prefix = '户名:'
account_prefix = '卡号:'
......@@ -117,7 +120,7 @@ def extract_gongshang_info(ocr_results):
results.append([value[1], value[0]])
return results
# 提取中国银行信息
def extract_zhongguo_info(ocr_results):
name_prefix = '客户姓名:'
account_prefix = '借记卡号:'
......@@ -163,7 +166,7 @@ def extract_zhongguo_info(ocr_results):
results.append([value[1], value[0]])
return results
# 提取建设银行信息
def extract_jianshe_info(ocr_results):
name_prefixes = ['客户名称:', '户名:']
account_prefixes = ['卡号/账号:', '卡号:']
......@@ -215,7 +218,7 @@ def extract_jianshe_info(ocr_results):
break
return results
# 提取农业银行信息(比较复杂,目前训练的版式都支持)
def extract_nongye_info(ocr_results):
name_prefixes = ['客户名:', '户名:']
account_prefixes = ['账号:']
......@@ -318,7 +321,7 @@ def extract_nongye_info(ocr_results):
break
return results
# 提取银行流水信息总接口
def extract_bank_info(ocr_results):
results = []
for value in ocr_results.values():
......@@ -341,6 +344,7 @@ def extract_bank_info(ocr_results):
if __name__ == '__main__':
path = '/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/minsheng/authentic/images/val'
save_path='/data/situ_invoice_bill_data/new_data/results'
bank='minsheng'
......@@ -362,3 +366,4 @@ if __name__ == '__main__':
# cv2.waitKey(0)
cv2.imwrite(os.path.join(save_path,j),img)
print('spend:{} ocr:{} extract:{}'.format(et2 - st, et1 - st, et2 - et1))
#
\ No newline at end of file
......
......@@ -54,6 +54,26 @@ def gen_result_dict(boxes, label_list=[], std=False):
return result
def keep_resize_padding(image, size=640, pad_value=114):
    """Pad an image to a square with a constant border, then resize it.

    The shorter dimension is padded on both sides so the original content
    stays centred; when the difference is odd, the extra pixel goes to the
    right (tall images) or bottom (wide images). The square result is then
    resized to ``size`` x ``size`` with ``cv2.resize``.

    Args:
        image: Array whose ``shape`` unpacks to (height, width, channels).
        size: Side length passed to ``cv2.resize``. Defaults to 640.
        pad_value: Constant used to fill the padding. Defaults to 114.

    Returns:
        The padded image as returned by ``cv2.resize``.
    """
    h, w, c = image.shape
    if h >= w:
        # Tall (or square) image: add columns on the left and right.
        pad1 = (h - w) // 2
        pad2 = h - w - pad1
        # Use the image's own channel count rather than a hard-coded 3 so
        # the padding always stacks cleanly with the input.
        p1 = np.full((h, pad1, c), pad_value, dtype=np.uint8)
        p2 = np.full((h, pad2, c), pad_value, dtype=np.uint8)
        new_image = np.hstack((p1, image, p2))
    else:
        # Wide image: add rows on the top and bottom.
        pad1 = (w - h) // 2
        pad2 = w - h - pad1
        p1 = np.full((pad1, w, c), pad_value, dtype=np.uint8)
        p2 = np.full((pad2, w, c), pad_value, dtype=np.uint8)
        new_image = np.vstack((p1, image, p2))
    new_image = cv2.resize(new_image, (size, size))
    return new_image
class Yolov5:
def __init__(self, cfg=None):
self.cfg = cfg
......@@ -66,7 +86,16 @@ class Yolov5:
imgsz = check_img_size(self.cfg.imgsz, s=stride) # check image size
# Dataloader
bs = 1 # batch_size
im = letterbox(image, imgsz, stride=stride, auto=True)[0] # padded resize
# im = letterbox(image, imgsz, stride=stride, auto=True)[0] # padded resize
# hh, ww, cc = im.shape
# tlen1 = (640 - hh) // 2
# tlen2 = 640 - hh - tlen1
# t1 = np.zeros((tlen1, ww, cc))
# t2 = np.zeros((tlen2, ww, cc))
# im = np.vstack((t1, im, t2))
im = keep_resize_padding(image)
# print(im.shape)
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im) # contiguous
# Run inference
......@@ -74,13 +103,15 @@ class Yolov5:
im = torch.from_numpy(im).to(self.model.device)
im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32
im /= 255 # 0 - 255 to 0.0 - 1.0
if len(im.shape) == 3:
im = im[None] # expand for batch dim
# Inference
pred = self.model(im, augment=False, visualize=False)
# print(pred[0].shape)
# exit(0)
# NMS
pred = non_max_suppression(pred, self.cfg.conf_thres, self.cfg.iou_thres, None, False, max_det=self.cfg.max_det)
det = pred[0]
# if len(det):
det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], image0.shape).round()
......@@ -95,13 +126,14 @@ class Yolov5:
if __name__ == "__main__":
img = cv2.imread(
'/home/situ/qfs/invoice_tamper/09_project/project/tamper_det/data/images/img_1.png')
'/data/situ_invoice_bill_data/qfs_train_val_data/train_data/authentic/gongshang/images/val/_1594890232.0110397page_11_img_0_name_au_gongshang.jpg')
detector = Yolov5(config)
result = detector.detect(img)
for i in result['result']:
position=list(i.values())[2:]
position = list(i.values())[2:]
print(position)
cv2.rectangle(img,(position[0],position[1]),(position[0]+position[2],position[1]+position[3]),(0,0,255))
cv2.imshow('w',img)
cv2.rectangle(img, (position[0], position[1]), (position[0] + position[2], position[1] + position[3]),
(0, 0, 255))
cv2.imshow('w', img)
cv2.waitKey(0)
print(result)
......
......@@ -2,7 +2,7 @@ from easydict import EasyDict as edict
config = edict(
# weights='/home/situ/qfs/invoice_tamper/09_project/project/yolov5_inference/runs/exp2/weights/best.pt', # model path or triton URL
weights='runs/train/exp/weights/best.pt', # model path or triton URL
weights='runs/train/exp/weights/best.onnx', # model path or triton URL
data='data/VOC.yaml', # dataset.yaml path
imgsz=(640, 640), # inference size (height, width)
conf_thres=0.2, # confidence threshold
......
import time
import cv2
from bank_ocr_inference import bill_ocr, extract_bank_info
from inference import Yolov5
from models.yolov5_config import config
......@@ -18,10 +16,9 @@ def enlarge_position(box):
def tamper_detect(image):
st = time.time()
ocr_results = bill_ocr(image)
et1=time.time()
et1 = time.time()
info_results = extract_bank_info(ocr_results)
et2=time.time()
print(info_results)
et2 = time.time()
tamper_results = []
if len(info_results) != 0:
for info_result in info_results:
......@@ -29,21 +26,21 @@ def tamper_detect(image):
x1, y1, x2, y2 = enlarge_position(box)
# x1, y1, x2, y2 = box
info_image = image[y1:y2, x1:x2, :]
cv2.imshow('info_image',info_image)
cv2.imshow('info_image', info_image)
results = detector.detect(info_image)
print(results)
if len(results['result'])!=0:
if len(results['result']) != 0:
for res in results['result']:
left = int(res['left'])
top = int(res['top'])
width = int(res['width'])
height = int(res['height'])
absolute_position = [x1 + left, y1 + top, x1 + left + width, y1 + top + height]
tamper_results.append(absolute_position)
tamper_results .append(absolute_position)
print(tamper_results)
et3 = time.time()
print(f'all:{et3-st} ocr:{et1-st} extract:{et2-et1} yolo:{et3-et2}')
print(f'all:{et3 - st} ocr:{et1 - st} extract:{et2 - et1} yolo:{et3 - et2}')
for i in tamper_results:
cv2.rectangle(image, tuple(i[:2]), tuple(i[2:]), (0, 0, 255), 2)
cv2.imshow('info', image)
......@@ -53,5 +50,5 @@ def tamper_detect(image):
if __name__ == '__main__':
detector = Yolov5(config)
image = cv2.imread(
"/home/situ/下载/_1597378020.731796page_33_img_0.jpg")
"/data/situ_invoice_bill_data/new_data/qfs_bank_bill_person_ps/gongshang/tampered/images/val/ps3/CH-B006369332_page_67_img_0.jpg")
tamper_detect(image)
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!