5e7dd86a by 乔峰昇

add pipeline inference

1 parent 7c864e59
import base64
import os
import time
import cv2
import numpy as np
import requests
import tqdm
def image_to_base64(image):
    """Encode ``image`` as PNG and return the encoded buffer.

    Despite its name, this function performs no base64 encoding: it returns
    the buffer produced by ``cv2.imencode('.png', image)``, which ``bill_ocr``
    uploads as the multipart ``file`` payload.

    Args:
        image: An image accepted by ``cv2.imencode`` (for example the array
            returned by ``cv2.imread``).

    Returns:
        The encoded PNG buffer (second element of the ``cv2.imencode`` result).

    Raises:
        ValueError: If OpenCV reports that the encoding did not succeed.
    """
    ok, png_buffer = cv2.imencode('.png', image)
    if not ok:
        # Previously the success flag was discarded and whatever buffer came
        # back was uploaded regardless.
        raise ValueError('cv2.imencode failed to encode the image as PNG')
    return png_buffer
def path_to_file(file_path):
    """Open ``file_path`` for binary reading and return the open file object.

    The handle is not closed here; the caller owns it and must close it.
    """
    return open(file_path, 'rb')
def bill_ocr(image, url=r'http://139.196.149.46:9001/gen_ocr', timeout=60):
    """Send ``image`` to the OCR HTTP service and return its ``ocr_results``.

    The image is PNG-encoded with ``image_to_base64`` and uploaded as the
    multipart field ``file``.

    Args:
        image: Image to recognise (anything ``image_to_base64`` accepts).
        url: Endpoint of the OCR service. Defaults to the address that was
            previously hard-coded.
        timeout: Seconds to wait for the service before ``requests`` gives up.
            Without a timeout a stalled service would block the caller forever.

    Returns:
        The value stored under ``'ocr_results'`` in the JSON response. The
        extractors in this module iterate it with ``.values()`` and read each
        entry as ``[eight_box_coordinates, text]``.

    Raises:
        requests.RequestException: On connection errors, timeouts or a
            non-2xx HTTP status.
        KeyError: If the response JSON has no ``'ocr_results'`` key.
    """
    payload = image_to_base64(image)
    resp = requests.post(url=url, files={'file': payload}, timeout=timeout)
    # Surface HTTP errors directly instead of failing later while decoding an
    # error page as JSON.
    resp.raise_for_status()
    return resp.json()['ocr_results']
def _nearest_box_on_same_line(label, ocr_results, skip_text):
    """Return the OCR entry closest to ``label``'s right edge on the same line.

    An entry counts as "on the same line" when the y of its first point
    (index 1) lies within half the label height of the y stored at index 3 of
    the label box. Among those, the entry whose first x (index 0) is closest
    to the label's x at index 2 wins; on ties the first one encountered wins.
    Entries whose text equals ``skip_text`` are ignored.

    Returns ``None`` when no entry qualifies.
    """
    label_box = label[0]
    anchor_x, anchor_y = label_box[2], label_box[3]
    tolerance = abs(label_box[3] - label_box[5]) / 2
    # 999999 is the original search bound: anything further away is ignored.
    nearest, nearest_dist = None, 999999
    for candidate in ocr_results.values():
        if candidate[1] == skip_text:
            continue
        gap = abs(candidate[0][0] - anchor_x)
        if abs(candidate[0][1] - anchor_y) < tolerance and gap < nearest_dist:
            nearest, nearest_dist = candidate, gap
    return nearest


def _merge_label_and_value(label, neighbour):
    """Join a label entry with the value entry found next to it.

    The merged box keeps indices 0, 1, 6, 7 from the label box and takes
    indices 2-5 from the neighbour box; the texts are concatenated.
    """
    label_box, value_box = label[0], neighbour[0]
    merged_box = [label_box[0], label_box[1],
                  value_box[2], value_box[3],
                  value_box[4], value_box[5],
                  label_box[6], label_box[7]]
    return [label[1] + neighbour[1], merged_box]


def _match_prefixed_field(entry, ocr_results, prefixes):
    """Return ``[text, box]`` for the first prefix found in ``entry``, else ``None``.

    If the entry text is exactly the prefix (a bare label), it is merged with
    its nearest same-line neighbour. If the text already contains more than
    the prefix, the entry is returned unchanged.
    """
    text = entry[1]
    for prefix in prefixes:
        if prefix not in text:
            continue
        if prefix == text:
            neighbour = _nearest_box_on_same_line(entry, ocr_results, prefix)
            if neighbour is not None:
                return _merge_label_and_value(entry, neighbour)
            # No value box on the label's line: the old code crashed with
            # IndexError here; return the bare label instead.
        return [text, entry[0]]
    return None


def _extract_prefixed_fields(ocr_results, name_prefixes, account_prefixes):
    """Collect name and account fields from ``ocr_results``.

    Every OCR entry is checked first against ``name_prefixes`` and then
    against ``account_prefixes``; each list contributes at most one result per
    entry (the first prefix that matches).
    """
    results = []
    for entry in ocr_results.values():
        for prefixes in (name_prefixes, account_prefixes):
            field = _match_prefixed_field(entry, ocr_results, prefixes)
            if field is not None:
                results.append(field)
    return results


def extract_minsheng_info(ocr_results):
    """Extract the '客户姓名:' and '客户账号:' fields.

    ``extract_bank_info`` uses this extractor when an OCR text contains '民生'.

    Args:
        ocr_results: Mapping whose values are ``[box, text]`` entries, where
            ``box`` holds eight numbers (presumably the four corner points of
            the text box, clockwise from top-left - TODO confirm against the
            OCR service).

    Returns:
        A list of ``[text, box]`` items, one per matched field.
    """
    return _extract_prefixed_fields(ocr_results, ['客户姓名:'], ['客户账号:'])


def extract_gongshang_info(ocr_results):
    """Extract the '户名:' and '卡号:' fields.

    ``extract_bank_info`` falls back to this extractor when no other one
    produced a result. Arguments and return value are as for
    ``extract_minsheng_info``.
    """
    return _extract_prefixed_fields(ocr_results, ['户名:'], ['卡号:'])


def extract_zhongguo_info(ocr_results):
    """Extract the '客户姓名:' and '借记卡号:' fields.

    ``extract_bank_info`` uses this extractor when an OCR text contains
    '中国银行'. Arguments and return value are as for
    ``extract_minsheng_info``.
    """
    return _extract_prefixed_fields(ocr_results, ['客户姓名:'], ['借记卡号:'])


def extract_jianshe_info(ocr_results):
    """Extract the name ('客户名称:' / '户名:') and account ('卡号/账号:' / '卡号:') fields.

    ``extract_bank_info`` uses this extractor when an OCR text contains '建设'.
    For each entry only the first matching prefix of each list is used.
    Arguments and return value are as for ``extract_minsheng_info``.
    """
    return _extract_prefixed_fields(
        ocr_results,
        ['客户名称:', '户名:'],
        ['卡号/账号:', '卡号:'],
    )
def extract_nongye_info(ocr_results):
# Collect the account-holder name / account number fields for the layout that
# extract_bank_info selects when an OCR text contains '农业'.
#
# Every OCR entry is read as entry[0] = eight box coordinates and entry[1] =
# recognised text (presumably four corner points, clockwise from top-left -
# TODO confirm against the OCR service).
# Returns a list of [text, eight_coordinates] items.
#
# NOTE(review): this copy of the file has lost its indentation, so the nesting
# of the branches below (in particular the `break` statements and the final
# `else: break`) cannot be read off the text. Confirm against the original
# file before relying on any control-flow detail.
name_prefixes = ['客户名:', '户名:']
account_prefixes = ['账号:']
results = []
# Cleared once a combined name+account label is seen; gates the
# account-prefix pass further down.
is_account = True
for value in ocr_results.values():
for name_prefix in name_prefixes:
# Case 1: the text contains a name prefix and does not mention '账号'
# (account_prefixes[0] without its trailing ':').
if name_prefix in value[1] and account_prefixes[0][:-1] not in value[1]:
# Bare label: look for the entry that sits next to it on the same line.
if name_prefix == value[1]:
tmp_value, max_dis = [], 999999
top_right_x = value[0][2]
top_right_y = value[0][3]
# Same-line test: the candidate's y at index 1 must be within half the label
# height of the label's y at index 3; the candidate with the smallest
# horizontal gap to the label's x at index 2 is kept.
for tmp in ocr_results.values():
if tmp[1] != name_prefix:
if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
tmp[0][0] - top_right_x) < max_dis:
tmp_value = tmp
max_dis = abs(tmp[0][0] - top_right_x)
else:
continue
# NOTE(review): if no candidate matched, tmp_value is still [] and the
# indexing below raises IndexError.
# Merged box: indices 0, 1, 6, 7 from the label, indices 2-5 from the neighbour.
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
break
else:
# The text already carries more than the prefix: keep the entry as it is.
results.append([value[1], value[0]])
break
# Case 2: the text contains a name prefix AND '账号', with at most five
# characters before the first ':'.
if name_prefix in value[1] and account_prefixes[0][:-1] in value[1] and len(value[1].split(":")[0]) <= 5:
is_account = False
# Exactly five characters - presumably a label with no value attached
# (TODO confirm).
if len(value[1]) == 5:
# max_dis is assigned here but not used in this branch.
tmp_value, max_dis = [], 999999
top_right_x = value[0][2]
top_right_y = value[0][3]
# Gather every same-line entry keyed by its horizontal gap; entries with an
# equal gap overwrite each other.
tmp_info = {}
for tmp in ocr_results.values():
if tmp[1] != value[1]:
if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2:
tmp_info[abs(tmp[0][0] - top_right_x)] = tmp
else:
continue
# Gaps in ascending order: index 0 is the nearest entry.
tmp_info_id = sorted(tmp_info.keys())
# NOTE(review): tmp_info_id[0] raises IndexError when nothing shares the line.
# Nearest text is not purely digits and is longer than 19 characters: merge with it.
if not tmp_info[tmp_info_id[0]][1].isdigit() and len(tmp_info[tmp_info_id[0]][1]) > 19:
tmp_value = tmp_info[tmp_info_id[0]]
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
# Nearest text is purely digits: merge with the second-nearest entry instead
# (IndexError if only one entry shares the line).
if tmp_info[tmp_info_id[0]][1].isdigit():
tmp_value = tmp_info[tmp_info_id[1]]
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
break
# Shorter than 25 characters: merge with the nearest same-line entry.
elif len(value[1]) < 25:
tmp_info = {}
top_right_x = value[0][2]
top_right_y = value[0][3]
for tmp in ocr_results.values():
if tmp[1] != value[1]:
if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2:
tmp_info[abs(tmp[0][0] - top_right_x)] = tmp
else:
continue
tmp_info_id = sorted(tmp_info.keys())
# NOTE(review): IndexError here as well when no entry shares the line.
tmp_value = tmp_info[tmp_info_id[0]]
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
break
else:
# 25 characters or more: keep the entry as it is.
results.append([value[1], value[0]])
break
# Account pass: same label/neighbour handling as Case 1, using account_prefixes;
# only runs while no combined name+account label has been seen.
if is_account:
for account_prefix in account_prefixes:
if account_prefix in value[1]:
if account_prefix == value[1]:
tmp_value, max_dis = [], 999999
top_right_x = value[0][2]
top_right_y = value[0][3]
for tmp in ocr_results.values():
if tmp[1] != account_prefix:
if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
tmp[0][0] - top_right_x) < max_dis:
tmp_value = tmp
max_dis = abs(tmp[0][0] - top_right_x)
else:
continue
# NOTE(review): same IndexError risk as above when tmp_value is still [].
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
break
else:
results.append([value[1], value[0]])
break
# NOTE(review): which `if`/`for` this `else` belongs to is not recoverable from
# this copy - confirm before editing.
else:
break
return results
def extract_bank_info(ocr_results):
    """Pick a bank-specific extractor from the OCR text and run it.

    The OCR entries are scanned in order. The first entry whose text contains
    one of the bank keywords decides which extractor runs (keywords are tried
    in the order listed below). If no keyword is found, or the chosen
    extractor returns nothing, ``extract_gongshang_info`` is used instead.

    Args:
        ocr_results: Mapping whose values are ``[box, text]`` entries.

    Returns:
        The list produced by the selected extractor.
    """
    # (keyword, extractor) pairs, checked in this order for every entry.
    extractors = (
        ('建设', extract_jianshe_info),
        ('民生', extract_minsheng_info),
        ('农业', extract_nongye_info),
        ('中国银行', extract_zhongguo_info),
    )
    fields = []
    for entry in ocr_results.values():
        text = entry[1]
        extractor = next((fn for keyword, fn in extractors if keyword in text), None)
        if extractor is not None:
            fields = extractor(ocr_results)
            break
    if len(fields) == 0:
        fields = extract_gongshang_info(ocr_results)
    return fields
if __name__ == '__main__':
    # Batch demo: run OCR + field extraction on every file in `path` and save
    # a copy of each image with the extracted field boxes drawn in red.
    path = '/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/minsheng/authentic/images/val'
    save_path = '/data/situ_invoice_bill_data/new_data/results'
    bank = 'minsheng'
    save_path = os.path.join(save_path, bank)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_path, exist_ok=True)
    for j in tqdm.tqdm(os.listdir(path)):
        img = cv2.imread(os.path.join(path, j))
        if img is None:
            # cv2.imread returns None for files it cannot decode; skip them
            # instead of crashing inside the OCR call.
            print('skip unreadable image: {}'.format(j))
            continue
        st = time.time()
        ocr_result = bill_ocr(img)
        et1 = time.time()
        result = extract_bank_info(ocr_result)
        et2 = time.time()
        for _, position in result:
            # Draw from the point at indices 0/1 to the point at indices 4/5.
            cv2.rectangle(img, (position[0], position[1]), (position[4], position[5]), (0, 0, 255), 2)
        cv2.imwrite(os.path.join(save_path, j), img)
        print('spend:{} ocr:{} extract:{}'.format(et2 - st, et1 - st, et2 - et1))
......@@ -576,8 +576,8 @@ def run(
def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model.pt path(s)')
parser.add_argument('--data', type=str, default=ROOT / 'data/VOC.yaml', help='dataset.yaml path')
parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'runs/train/exp/weights/best.pt', help='model.pt path(s)')
parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
......
......@@ -95,7 +95,13 @@ class Yolov5:
if __name__ == "__main__":
img = cv2.imread(
'/home/situ/qfs/invoice_tamper/09_project/project/yolov5_inference/data/images/crop_img/_1594890230.8032346page_10_img_0_hname.jpg')
'/home/situ/qfs/invoice_tamper/09_project/project/tamper_det/data/images/img_1.png')
detector = Yolov5(config)
result = detector.detect(img)
for i in result['result']:
position=list(i.values())[2:]
print(position)
cv2.rectangle(img,(position[0],position[1]),(position[0]+position[2],position[1]+position[3]),(0,0,255))
cv2.imshow('w',img)
cv2.waitKey(0)
print(result)
......
from easydict import EasyDict as edict
config = edict(
# weights='/home/situ/qfs/invoice_tamper/09_project/project/yolov5_inference/runs/exp2/weights/best.pt', # model path or triton URL
weights='runs/train/exp/weights/best.pt', # model path or triton URL
data='data/VOC.yaml', # dataset.yaml path
imgsz=(640, 640), # inference size (height, width)
conf_thres=0.5, # confidence threshold
conf_thres=0.2, # confidence threshold
iou_thres=0.45, # NMS IOU threshold
max_det=1000, # maximum detections per image
device='' # cuda device, i.e. 0 or 0,1,2,3 or cpu
......
import time
import cv2
from bank_ocr_inference import bill_ocr, extract_bank_info
from inference import Yolov5
from models.yolov5_config import config
def enlarge_position(box):
    """Pad a box by 1/8 of its width on each side and 1/3 of its height above and below.

    Args:
        box: ``[x1, y1, x2, y2]``.

    Returns:
        ``[x1, y1, x2, y2]`` after padding. The first corner is clamped at 0;
        the second corner is not clamped.
    """
    left, top, right, bottom = box
    pad_x = abs(right - left) // 8
    pad_y = abs(bottom - top) // 3
    return [max(left - pad_x, 0), max(top - pad_y, 0), right + pad_x, bottom + pad_y]
def tamper_detect(image, show=True):
    """Run the OCR -> field extraction -> tamper detection pipeline on one image.

    For every field returned by ``extract_bank_info`` the enlarged field
    region is cropped and passed to ``detector.detect``; each detection is
    mapped back to coordinates of the full image and drawn on ``image`` (the
    array is modified in place).

    NOTE: ``detector`` is read as a module-level global. It is only created in
    the ``__main__`` section, so code importing this function must define it
    first.

    Args:
        image: Image array as returned by ``cv2.imread``.
        show: When true (the default, matching the previous behaviour) the
            crops and the annotated image are shown with ``cv2.imshow`` and
            the call blocks in ``cv2.waitKey(0)``. Pass ``False`` to run
            without any window.

    Returns:
        A list of ``[x1, y1, x2, y2]`` boxes in full-image coordinates.
        Previously nothing was returned and the boxes were only printed.
    """
    st = time.time()
    ocr_results = bill_ocr(image)
    et1 = time.time()
    info_results = extract_bank_info(ocr_results)
    et2 = time.time()
    print(info_results)
    tamper_results = []
    for _, position in info_results:
        # Field box spans the point at indices 0/1 to the point at indices 4/5.
        x1, y1, x2, y2 = enlarge_position([position[0], position[1], position[4], position[5]])
        info_image = image[y1:y2, x1:x2, :]
        if info_image.size == 0:
            # A degenerate box yields an empty crop; there is nothing to detect.
            continue
        if show:
            cv2.imshow('info_image', info_image)
        results = detector.detect(info_image)
        print(results)
        for res in results['result']:
            left = int(res['left'])
            top = int(res['top'])
            width = int(res['width'])
            height = int(res['height'])
            # Detections are relative to the crop; shift them by the crop origin.
            tamper_results.append([x1 + left, y1 + top, x1 + left + width, y1 + top + height])
    print(tamper_results)
    et3 = time.time()
    print(f'all:{et3-st} ocr:{et1-st} extract:{et2-et1} yolo:{et3-et2}')
    for i in tamper_results:
        cv2.rectangle(image, tuple(i[:2]), tuple(i[2:]), (0, 0, 255), 2)
    if show:
        cv2.imshow('info', image)
        cv2.waitKey(0)
    return tamper_results
if __name__ == '__main__':
    # `detector` has to be a module-level name: tamper_detect() reads it as a global.
    detector = Yolov5(config)
    image_path = "/home/situ/下载/_1597378020.731796page_33_img_0.jpg"
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread returns None instead of raising when the file is missing
        # or cannot be decoded; fail here with a clear message.
        raise FileNotFoundError(f'cannot read image: {image_path}')
    tamper_detect(image)
......
......@@ -10,9 +10,9 @@ def get_source_image_det(crop_position, predict_positions):
result = []
x1, y1, x2, y2 = crop_position
for p in predict_positions:
px1, py1, px2, py2,score = p
px1, py1, px2, py2, score = p
w, h = px2 - px1, py2 - py1
result.append([x1 + px1, y1 + py1, x1 + px1 + w, y1 + py1 + h,score])
result.append([x1 + px1, y1 + py1, x1 + px1 + w, y1 + py1 + h, score])
return result
......@@ -22,9 +22,9 @@ def decode_label(image, label_path):
result = []
for d in data:
d = [float(i) for i in d.strip().split(' ')]
cls, cx, cy, cw, ch,score = d
cls, cx, cy, cw, ch, score = d
cx, cy, cw, ch = cx * w, cy * h, cw * w, ch * h
result.append([int(cx - cw // 2), int(cy - ch // 2), int(cx + cw // 2), int(cy + ch // 2),score])
result.append([int(cx - cw // 2), int(cy - ch // 2), int(cx + cw // 2), int(cy + ch // 2), score])
return result
......@@ -38,28 +38,28 @@ if __name__ == '__main__':
data = pd.read_csv(crop_csv_path)
img_name = data.loc[:, 'img_name'].tolist()
crop_position1 = data.loc[:, 'name_crop_coord'].tolist()
crop_position2 = data.loc[:,'number_crop_coord'].tolist()
cc='/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/gongshang/tampered/images/val/ps3'
crop_position2 = data.loc[:, 'number_crop_coord'].tolist()
cc = '/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/gongshang/tampered/images/val/ps3'
for im in os.listdir(cc):
print(im)
img = cv2.imread(os.path.join(cc,im))
img_=img.copy()
img = cv2.imread(os.path.join(cc, im))
img_ = img.copy()
id = img_name.index(im)
name_crop_position=[int(i) for i in crop_position1[id].split(',')]
number_crop_position=[int(i) for i in crop_position2[id].split(',')]
nx1,ny1,nx2,ny2=name_crop_position
nux1,nuy1,nux2,nuy2=number_crop_position
if im[:-4]+'_hname.txt' in predict_labels:
name_crop_position = [int(i) for i in crop_position1[id].split(',')]
number_crop_position = [int(i) for i in crop_position2[id].split(',')]
nx1, ny1, nx2, ny2 = name_crop_position
nux1, nuy1, nux2, nuy2 = number_crop_position
if im[:-4] + '_hname.txt' in predict_labels:
h, w, c = img[ny1:ny2, nx1:nx2, :].shape
data = open(os.path.join(predict_label_path,im[:-4]+'_hname.txt')).readlines()
data = open(os.path.join(predict_label_path, im[:-4] + '_hname.txt')).readlines()
for d in data:
cls,cx,cy,cw,ch,score = [float(i) for i in d.strip().split(' ')]
cx,cy,cw,ch=int(cx*w),int(cy*h),int(cw*w),int(ch*h)
cx1,cy1=cx-cw//2,cy-ch//2
x1,y1,x2,y2=nx1+cx1,ny1+cy1,nx1+cx1+cw,ny1+cy1+ch
cv2.rectangle(img,(x1,y1),(x2,y2),(0,0,255),2)
cv2.putText(img,f'tampered:{score}',(x1,y1-5),cv2.FONT_HERSHEY_SIMPLEX,0.5,(0,0,255),1)
cls, cx, cy, cw, ch, score = [float(i) for i in d.strip().split(' ')]
cx, cy, cw, ch = int(cx * w), int(cy * h), int(cw * w), int(ch * h)
cx1, cy1 = cx - cw // 2, cy - ch // 2
x1, y1, x2, y2 = nx1 + cx1, ny1 + cy1, nx1 + cx1 + cw, ny1 + cy1 + ch
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)
cv2.putText(img, f'tampered:{score}', (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
if im[:-4] + '_hnumber.txt' in predict_labels:
h, w, c = img[nuy1:nuy2, nux1:nux2, :].shape
data = open(os.path.join(predict_label_path, im[:-4] + '_hname.txt')).readlines()
......@@ -70,5 +70,5 @@ if __name__ == '__main__':
x1, y1, x2, y2 = nux1 + cx1, nuy1 + cy1, nux1 + cx1 + cw, nuy1 + cy1 + ch
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)
cv2.putText(img, f'tampered:{score}', (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
result = np.vstack((img_,img))
cv2.imwrite(f'z/{im}',result)
result = np.vstack((img_, img))
cv2.imwrite(f'z/{im}', result)
......
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!