d479b4ec by 乔峰昇

ocr_yolo triton-inference-server

0 parents
model_repository/
.idea/
OCR_Engine @ 3dddc11a
Subproject commit 3dddc11a8a1d369ca4fbd0b69e4e21e6af81cc4c
## OCR+yolov5 triton-inference-server服务
1.使用docker启动triton服务
sudo docker run --gpus="device=0" --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v /home/situ/qfs/triton_inference_server/demo/model_repository:/models nvcr.io/nvidia/tritonserver:21.10-py3 tritonserver --model-repository=/models
2.分别启动OCR和yolov5的web服务
cd OCR_Engine/api
python ocr_engine_server.py
cd yolov5_onnx_demo/api
python yolov5_onnx_server.py
3.pipeline测试
python triton_pipeline.py
import base64
import os
import time
import cv2
import numpy as np
import requests
import tqdm
def image_to_base64(image):
image = cv2.imencode('.png', image)[1]
return image
def path_to_file(file_path):
f = open(file_path, 'rb')
return f
# 流水OCR接口
def bill_ocr(image):
f = image_to_base64(image)
resp = requests.post(url=r'http://192.168.10.11:9001/gen_ocr', files={'file': f})
results = resp.json()
ocr_results = results['ocr_results']
return ocr_results
# 提取民生银行信息
def extract_minsheng_info(ocr_results):
name_prefix = '客户姓名:'
account_prefix = '客户账号:'
results = []
for value in ocr_results.values():
if name_prefix in value[1]:
if name_prefix == value[1]:
tmp_value, max_dis = [], 999999
top_right_x = value[0][2]
top_right_y = value[0][3]
for tmp in ocr_results.values():
if tmp[1] != name_prefix:
if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
tmp[0][0] - top_right_x) < max_dis:
tmp_value = tmp
max_dis = abs(tmp[0][0] - top_right_x)
else:
continue
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
else:
results.append([value[1], value[0]])
if account_prefix in value[1]:
if account_prefix == value[1]:
tmp_value, max_dis = [], 999999
top_right_x = value[0][2]
top_right_y = value[0][3]
for tmp in ocr_results.values():
if tmp[1] != account_prefix:
if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
tmp[0][0] - top_right_x) < max_dis:
tmp_value = tmp
max_dis = abs(tmp[0][0] - top_right_x)
else:
continue
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
else:
results.append([value[1], value[0]])
return results
# 提取工商银行信息
def extract_gongshang_info(ocr_results):
name_prefix = '户名:'
account_prefix = '卡号:'
results = []
for value in ocr_results.values():
if name_prefix in value[1]:
if name_prefix == value[1]:
tmp_value, max_dis = [], 999999
top_right_x = value[0][2]
top_right_y = value[0][3]
for tmp in ocr_results.values():
if tmp[1] != name_prefix:
if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
tmp[0][0] - top_right_x) < max_dis:
tmp_value = tmp
max_dis = abs(tmp[0][0] - top_right_x)
else:
continue
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
else:
results.append([value[1], value[0]])
if account_prefix in value[1]:
if account_prefix == value[1]:
tmp_value, max_dis = [], 999999
top_right_x = value[0][2]
top_right_y = value[0][3]
for tmp in ocr_results.values():
if tmp[1] != account_prefix:
if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
tmp[0][0] - top_right_x) < max_dis:
tmp_value = tmp
max_dis = abs(tmp[0][0] - top_right_x)
else:
continue
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
else:
results.append([value[1], value[0]])
return results
# 提取中国银行信息
def extract_zhongguo_info(ocr_results):
name_prefix = '客户姓名:'
account_prefix = '借记卡号:'
results = []
for value in ocr_results.values():
if name_prefix in value[1]:
if name_prefix == value[1]:
tmp_value, max_dis = [], 999999
top_right_x = value[0][2]
top_right_y = value[0][3]
for tmp in ocr_results.values():
if tmp[1] != name_prefix:
if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
tmp[0][0] - top_right_x) < max_dis:
tmp_value = tmp
max_dis = abs(tmp[0][0] - top_right_x)
else:
continue
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
else:
results.append([value[1], value[0]])
if account_prefix in value[1]:
if account_prefix == value[1]:
tmp_value, max_dis = [], 999999
top_right_x = value[0][2]
top_right_y = value[0][3]
for tmp in ocr_results.values():
if tmp[1] != account_prefix:
if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
tmp[0][0] - top_right_x) < max_dis:
tmp_value = tmp
max_dis = abs(tmp[0][0] - top_right_x)
else:
continue
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
else:
results.append([value[1], value[0]])
return results
# 提取建设银行信息
def extract_jianshe_info(ocr_results):
name_prefixes = ['客户名称:', '户名:']
account_prefixes = ['卡号/账号:', '卡号:']
results = []
for value in ocr_results.values():
for name_prefix in name_prefixes:
if name_prefix in value[1]:
if name_prefix == value[1]:
tmp_value, max_dis = [], 999999
top_right_x = value[0][2]
top_right_y = value[0][3]
for tmp in ocr_results.values():
if tmp[1] != name_prefix:
if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
tmp[0][0] - top_right_x) < max_dis:
tmp_value = tmp
max_dis = abs(tmp[0][0] - top_right_x)
else:
continue
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
break
else:
results.append([value[1], value[0]])
break
for account_prefix in account_prefixes:
if account_prefix in value[1]:
if account_prefix == value[1]:
tmp_value, max_dis = [], 999999
top_right_x = value[0][2]
top_right_y = value[0][3]
for tmp in ocr_results.values():
if tmp[1] != account_prefix:
if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
tmp[0][0] - top_right_x) < max_dis:
tmp_value = tmp
max_dis = abs(tmp[0][0] - top_right_x)
else:
continue
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
break
else:
results.append([value[1], value[0]])
break
return results
# 提取农业银行信息(比较复杂,目前训练的版式都支持)
def extract_nongye_info(ocr_results):
name_prefixes = ['客户名:', '户名:']
account_prefixes = ['账号:']
results = []
is_account = True
for value in ocr_results.values():
for name_prefix in name_prefixes:
if name_prefix in value[1] and account_prefixes[0][:-1] not in value[1]:
if name_prefix == value[1]:
tmp_value, max_dis = [], 999999
top_right_x = value[0][2]
top_right_y = value[0][3]
for tmp in ocr_results.values():
if tmp[1] != name_prefix:
if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
tmp[0][0] - top_right_x) < max_dis:
tmp_value = tmp
max_dis = abs(tmp[0][0] - top_right_x)
else:
continue
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
break
else:
results.append([value[1], value[0]])
break
if name_prefix in value[1] and account_prefixes[0][:-1] in value[1] and len(value[1].split(":")[0]) <= 5:
is_account = False
if len(value[1]) == 5:
tmp_value, max_dis = [], 999999
top_right_x = value[0][2]
top_right_y = value[0][3]
tmp_info = {}
for tmp in ocr_results.values():
if tmp[1] != value[1]:
if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2:
tmp_info[abs(tmp[0][0] - top_right_x)] = tmp
else:
continue
tmp_info_id = sorted(tmp_info.keys())
if not tmp_info[tmp_info_id[0]][1].isdigit() and len(tmp_info[tmp_info_id[0]][1]) > 19:
tmp_value = tmp_info[tmp_info_id[0]]
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
if tmp_info[tmp_info_id[0]][1].isdigit():
tmp_value = tmp_info[tmp_info_id[1]]
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
break
elif len(value[1]) < 25:
tmp_info = {}
top_right_x = value[0][2]
top_right_y = value[0][3]
for tmp in ocr_results.values():
if tmp[1] != value[1]:
if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2:
tmp_info[abs(tmp[0][0] - top_right_x)] = tmp
else:
continue
tmp_info_id = sorted(tmp_info.keys())
tmp_value = tmp_info[tmp_info_id[0]]
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
break
else:
results.append([value[1], value[0]])
break
if is_account:
for account_prefix in account_prefixes:
if account_prefix in value[1]:
if account_prefix == value[1]:
tmp_value, max_dis = [], 999999
top_right_x = value[0][2]
top_right_y = value[0][3]
for tmp in ocr_results.values():
if tmp[1] != account_prefix:
if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
tmp[0][0] - top_right_x) < max_dis:
tmp_value = tmp
max_dis = abs(tmp[0][0] - top_right_x)
else:
continue
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
break
else:
results.append([value[1], value[0]])
break
else:
break
return results
# 提取银行流水信息总接口
def extract_bank_info(ocr_results):
results = []
for value in ocr_results.values():
if value[1].__contains__('建设'):
results = extract_jianshe_info(ocr_results)
break
elif value[1].__contains__('民生'):
results = extract_minsheng_info(ocr_results)
break
elif value[1].__contains__('农业'):
results = extract_nongye_info(ocr_results)
break
elif value[1].__contains__('中国银行'):
results = extract_zhongguo_info(ocr_results)
break
elif value[1].__contains__('邮政'):
results = extract_youchu_info(ocr_results)
if len(results) == 0:
results = extract_gongshang_info(ocr_results)
return results
def extract_youchu_info(ocr_results):
name_prefixes = ['户名:']
account_prefixes = ['账号:', '卡号:']
results = []
for value in ocr_results.values():
for name_prefix in name_prefixes:
if name_prefix in value[1]:
if name_prefix == value[1]:
tmp_value, max_dis = [], 999999
top_right_x = value[0][2]
top_right_y = value[0][3]
for tmp in ocr_results.values():
if tmp[1] != name_prefix:
if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
tmp[0][0] - top_right_x) < max_dis:
tmp_value = tmp
max_dis = abs(tmp[0][0] - top_right_x)
else:
continue
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
break
else:
results.append([value[1], value[0]])
break
for account_prefix in account_prefixes:
if account_prefix in value[1]:
if account_prefix == value[1]:
tmp_value, max_dis = [], 999999
top_right_x = value[0][2]
top_right_y = value[0][3]
for tmp in ocr_results.values():
if tmp[1] != account_prefix:
if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2 and abs(
tmp[0][0] - top_right_x) < max_dis:
tmp_value = tmp
max_dis = abs(tmp[0][0] - top_right_x)
else:
continue
new_position = [value[0][0], value[0][1], tmp_value[0][2], tmp_value[0][3], tmp_value[0][4],
tmp_value[0][5],
value[0][6], value[0][7]]
results.append([value[1] + tmp_value[1], new_position])
break
else:
results.append([value[1], value[0]])
break
return results
if __name__ == '__main__':
img = cv2.imread('/home/situ/下载/邮储对账单/飞书20221020-155202.jpg')
ocr_results = bill_ocr(img)
results = extract_youchu_info(ocr_results)
print(results)
# path = '/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/minsheng/authentic/images/val'
# save_path='/data/situ_invoice_bill_data/new_data/results'
# bank='minsheng'
# if not os.path.exists(os.path.join(save_path,bank)):
# os.makedirs(os.path.join(save_path,bank))
# save_path=os.path.join(save_path,bank)
# for j in tqdm.tqdm(os.listdir(path)):
# # if True:
# img=cv2.imread(os.path.join(path,j))
# # img = cv2.imread('/data/situ_invoice_bill_data/new_data/results/nongye/6/_1597382769.6449914page_23_img_0.jpg')
# st = time.time()
# ocr_result = bill_ocr(img)
# et1 = time.time()
# result = extract_bank_info(ocr_result)
# et2 = time.time()
# for i in range(len(result)):
# cv2.rectangle(img, (result[i][1][0], result[i][1][1]), (result[i][1][4], result[i][1][5]), (0, 0, 255), 2)
# # cv2.imshow('img',img)
# # cv2.waitKey(0)
# cv2.imwrite(os.path.join(save_path,j),img)
# print('spend:{} ocr:{} extract:{}'.format(et2 - st, et1 - st, et2 - et1))
import base64
import json
from bank_ocr_inference import *
def enlarge_position(box):
x1, y1, x2, y2 = box
w, h = abs(x2 - x1), abs(y2 - y1)
y1, y2 = max(y1 - h // 3, 0), y2 + h // 3
x1, x2 = max(x1 - w // 8, 0), x2 + w // 8
return [x1, y1, x2, y2]
def path_base64(file_path):
f = open(file_path, 'rb')
file64 = base64.b64encode(f.read()) # image 64 bytes 类型
file64 = file64.decode('utf-8')
return file64
def bgr_base64(image):
_, img64 = cv2.imencode('.jpg', image)
img64 = base64.b64encode(img64)
return img64.decode('utf-8')
def base64_bgr(img64):
str_img64 = base64.b64decode(img64)
image = np.frombuffer(str_img64, np.uint8)
image = cv2.imdecode(image, cv2.IMREAD_COLOR)
return image
def tamper_detect_(image):
img64 = bgr_base64(image)
resp = requests.post(url=r'http://192.168.10.11:8009/tamper_det', data=json.dumps({'img': img64}))
results = resp.json()
return results
if __name__ == '__main__':
image = cv2.imread(
'/data/situ_invoice_bill_data/银行流水样本/普通打印-部分格线-竖版-农业银行-8列/_1594626974.367834page_20_img_0.jpg')
st = time.time()
ocr_results = bill_ocr(image)
et1 = time.time()
info_results = extract_bank_info(ocr_results)
et2 = time.time()
tamper_results = []
if len(info_results) != 0:
for info_result in info_results:
box = [info_result[1][0], info_result[1][1], info_result[1][4], info_result[1][5]]
x1, y1, x2, y2 = enlarge_position(box)
# x1, y1, x2, y2 = box
info_image = image[y1:y2, x1:x2, :]
results = tamper_detect_(info_image)
print(results)
if len(results['results']) != 0:
for res in results['results']:
cx = int(res[0])
cy = int(res[1])
width = int(res[2])
height = int(res[3])
left = cx - width // 2
top = cy - height // 2
absolute_position = [x1 + left, y1 + top, x1 + left + width, y1 + top + height]
# absolute_position = [x1+left, y1+top, x2, y2]
tamper_results.append(absolute_position)
et3 = time.time()
print(tamper_results)
print(f'all time:{et3 - st} ocr time:{et1 - st} extract info time:{et2 - et1} yolo time:{et3 - et2}')
for i in tamper_results:
cv2.rectangle(image, tuple(i[:2]), tuple(i[2:]), (0, 0, 255), 2)
cv2.imshow('info', image)
cv2.waitKey(0)
import base64
import cv2
import numpy as np
from sanic import Sanic
from sanic.response import json
from yolov5_onnx_demo.model.yolov5_infer import *
def base64_to_bgr(bs64):
img_data = base64.b64decode(bs64)
img_arr = np.fromstring(img_data, np.uint8)
img_np = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)
return img_np
app = Sanic('tamper_det')
@app.post('/tamper_det')
def hello(request):
d = request.json
print(d['img'])
img = base64_to_bgr(d['img'])
result = grpc_detect(img)
return json({'results': result})
if __name__ == '__main__':
app.run(host='192.168.10.11', port=8009,workers=10)
import base64
import requests
import json
from yolov5_onnx_demo.model.yolov5_infer import *
def path_base64(file_path):
f = open(file_path, 'rb')
file64 = base64.b64encode(f.read()) # image 64 bytes 类型
file64 = file64.decode('utf-8')
return file64
res = requests.post('http://192.168.10.11:8009/tamper_det', data=json.dumps(
{'img': path_base64('/data/situ_invoice_bill_data/qfs_train_val_data/train_data/machine/minsheng/images/train/_1597386625.07514page_20_img_0_machine_name_full_splicing.jpg')}))
results = res.json()
img = cv2.imread(
'/data/situ_invoice_bill_data/qfs_train_val_data/train_data/machine/minsheng/images/train/_1597386625.07514page_20_img_0_machine_name_full_splicing.jpg')
print(res)
plot_label(img,results['keys'])
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
def keep_resize_padding(image):
'''
注意由于输入需要固定640*640的大小,而官方的推理为了加速采用了最小缩放比的方式进行
导致输入的尺寸不固定,重写resize方法,添加padding到640*640
'''
h, w, c = image.shape
if h >= w:
pad1 = (h - w) // 2
pad2 = h - w - pad1
p1 = np.ones((h, pad1, 3)) * 114.0
p2 = np.ones((h, pad2, 3)) * 114.0
p1, p2 = p1.astype(np.uint8), p2.astype(np.uint8)
new_image = np.hstack((p1, image, p2))
padding_info = [pad1, pad2, 0]
else:
pad1 = (w - h) // 2
pad2 = w - h - pad1
p1 = np.ones((pad1, w, 3)) * 114.0
p2 = np.ones((pad2, w, 3)) * 114.0
p1, p2 = p1.astype(np.uint8), p2.astype(np.uint8)
new_image = np.vstack((p1, image, p2))
padding_info = [pad1, pad2, 1]
new_image = cv2.resize(new_image, (640, 640))
return new_image, padding_info
# remove padding
def extract_authentic_bboxes(image, padding_info, bboxes):
'''
反算坐标信息
'''
pad1, pad2, pad_type = padding_info
h, w, c = image.shape
bboxes = np.array(bboxes)
max_slide = max(h, w)
scale = max_slide / 640
bboxes[:, :4] = bboxes[:, :4] * scale
if pad_type == 0:
bboxes[:, 0] = bboxes[:, 0] - pad1
else:
bboxes[:, 1] = bboxes[:, 1] - pad1
return bboxes.tolist()
# NMS
def py_nms_cpu(
prediction,
conf_thres=0.25,
iou_thres=0.45,
):
"""Non-Maximum Suppression (NMS) on inference results to reject overlapping detections
Returns:
list of detections, on (n,6) tensor per image [xyxy, conf, cls]
"""
xc = prediction[..., 4] > conf_thres # candidates
prediction = prediction[xc]
# MNS
x1 = prediction[..., 0] - prediction[..., 2] / 2
y1 = prediction[..., 1] - prediction[..., 3] / 2
x2 = prediction[..., 0] + prediction[..., 2] / 2
y2 = prediction[..., 1] + prediction[..., 3] / 2
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
score = prediction[..., 5]
order = np.argsort(score)
keep = []
while order.size > 0:
i = order[0]
keep.append(i)
xx1 = np.maximum(x1[i], x1[order[1:]])
yy1 = np.maximum(y1[i], y1[order[1:]])
xx2 = np.minimum(x2[i], x2[order[1:]])
yy2 = np.minimum(y2[i], y2[order[1:]])
ww, hh = np.maximum(0, xx2 - xx1 + 1), np.maximum(0, yy2 - yy1 + 1)
inter = ww * hh
over = inter / (areas[i] + areas[order[1:]] - inter)
idx = np.where(over < iou_thres)[0]
order = order[idx + 1]
return prediction[keep]
def client_init(url='localhost:8001',
ssl=False,
private_key=None,
root_certificates=None,
certificate_chain=None,
verbose=False):
triton_client = grpcclient.InferenceServerClient(
url=url,
verbose=verbose, # 详细输出 默认是False
ssl=ssl,
root_certificates=root_certificates,
private_key=private_key,
certificate_chain=certificate_chain,
)
return triton_client
triton_client = client_init('localhost:8001')
compression_algorithm = None
input_name = 'images'
output_name = 'output0'
model_name = 'yolov5'
def grpc_detect(img):
image, padding_info = keep_resize_padding(img)
image = image.transpose((2, 0, 1))[::-1]
image = image.astype(np.float32)
image = image / 255.0
if len(image.shape) == 3:
image = image[None]
outputs, inputs = [], []
# 动态输入
input_shape = image.shape
inputs.append(grpcclient.InferInput(input_name, input_shape, 'FP32'))
outputs.append(grpcclient.InferRequestedOutput(output_name))
inputs[0].set_data_from_numpy(image.astype(np.float32))
pred = triton_client.infer(
model_name=model_name,
inputs=inputs, outputs=outputs,
compression_algorithm=compression_algorithm
)
pred = pred.as_numpy(output_name).copy()
result_bboxes = py_nms_cpu(pred)
result_bboxes = extract_authentic_bboxes(img, padding_info, result_bboxes)
return result_bboxes
def plot_label(img, result_bboxes):
print(result_bboxes)
for bbox in result_bboxes:
x, y, w, h, conf, cls = bbox
cv2.rectangle(img, (int(x - w // 2), int(y - h // 2)), (int(x + w // 2), int(y + h // 2)), (0, 0, 255), 2)
cv2.imshow('im', img)
cv2.waitKey(0)
if __name__ == '__main__':
img = cv2.imread(
'/data/situ_invoice_bill_data/qfs_train_val_data/train_data/authentic/gongshang/images/val/_1594890232.0110397page_11_img_0_name_au_gongshang.jpg')
result_bboxes = grpc_detect(img)
plot_label(result_bboxes)
Styling with Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!