5e7dd86a by 乔峰昇

add pipeline inference

1 parent 7c864e59
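
This commit adds a bank-bill inference pipeline: a remote OCR call, per-bank extraction of the name/account fields, and a YOLOv5 detector run on the extracted regions. For orientation, a minimal sketch of the flow, assuming the files below are importable as modules (the image path here is a placeholder, not a path from the repo):

import cv2

from bank_ocr_inference import bill_ocr, extract_bank_info

image = cv2.imread('example_bill.jpg')   # placeholder input path
ocr_results = bill_ocr(image)            # remote OCR service
fields = extract_bank_info(ocr_results)  # bank-specific name/account extraction
for text, pos in fields:
    # pos holds four corner points; indices 0/1 and 4/5 are used below as the
    # top-left and bottom-right corners when drawing
    cv2.rectangle(image, (pos[0], pos[1]), (pos[4], pos[5]), (0, 0, 255), 2)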
import os
import time

import cv2
import requests
import tqdm


def image_to_base64(image):
    # Despite the name, this returns the PNG-encoded bytes of the image
    # (there is no base64 step); the OCR service accepts them as a file upload.
    image = cv2.imencode('.png', image)[1]
    return image.tobytes()


def path_to_file(file_path):
    # Open an image file for upload as-is.
    f = open(file_path, 'rb')
    return f


def bill_ocr(image):
    # Send the encoded image to the remote OCR service and return its per-line results.
    f = image_to_base64(image)
    resp = requests.post(url=r'http://139.196.149.46:9001/gen_ocr', files={'file': f})
    results = resp.json()
    ocr_results = results['ocr_results']
    return ocr_results


def _merge_with_right_neighbor(ocr_results, value, prefix):
    # The label line contains only the prefix (e.g. '客户姓名:'): find the OCR line
    # whose top-left corner lies on roughly the same row (within half the label
    # height) and is horizontally closest to the label's top-right corner, then
    # merge the two texts and boxes into one entry.
    tmp_value, max_dis = [], 999999
    top_right_x = value[0][2]
    top_right_y = value[0][3]
    half_height = abs(value[0][3] - value[0][5]) / 2
    for tmp in ocr_results.values():
        if tmp[1] == prefix:
            continue
        if abs(tmp[0][1] - top_right_y) < half_height and abs(tmp[0][0] - top_right_x) < max_dis:
            tmp_value = tmp
            max_dis = abs(tmp[0][0] - top_right_x)
    new_position = [value[0][0], value[0][1],
                    tmp_value[0][2], tmp_value[0][3], tmp_value[0][4], tmp_value[0][5],
                    value[0][6], value[0][7]]
    return [value[1] + tmp_value[1], new_position]


def _extract_by_prefixes(ocr_results, name_prefixes, account_prefixes):
    # Shared routine for the banks below: keep every OCR line that carries a
    # name/account prefix; if the line is the bare prefix, merge it with the
    # value printed to its right.
    results = []
    for value in ocr_results.values():
        for name_prefix in name_prefixes:
            if name_prefix in value[1]:
                if name_prefix == value[1]:
                    results.append(_merge_with_right_neighbor(ocr_results, value, name_prefix))
                else:
                    results.append([value[1], value[0]])
                break
        for account_prefix in account_prefixes:
            if account_prefix in value[1]:
                if account_prefix == value[1]:
                    results.append(_merge_with_right_neighbor(ocr_results, value, account_prefix))
                else:
                    results.append([value[1], value[0]])
                break
    return results


def extract_minsheng_info(ocr_results):
    return _extract_by_prefixes(ocr_results, ['客户姓名:'], ['客户账号:'])


def extract_gongshang_info(ocr_results):
    return _extract_by_prefixes(ocr_results, ['户名:'], ['卡号:'])


def extract_zhongguo_info(ocr_results):
    return _extract_by_prefixes(ocr_results, ['客户姓名:'], ['借记卡号:'])


def extract_jianshe_info(ocr_results):
    return _extract_by_prefixes(ocr_results, ['客户名称:', '户名:'], ['卡号/账号:', '卡号:'])


def extract_nongye_info(ocr_results):
    name_prefixes = ['客户名:', '户名:']
    account_prefixes = ['账号:']
    results = []
    is_account = True
    for value in ocr_results.values():
        for name_prefix in name_prefixes:
            # Plain name label with no account keyword on the same OCR line.
            if name_prefix in value[1] and account_prefixes[0][:-1] not in value[1]:
                if name_prefix == value[1]:
                    results.append(_merge_with_right_neighbor(ocr_results, value, name_prefix))
                else:
                    results.append([value[1], value[0]])
                break
            # Name and account labels share one OCR line.
            if name_prefix in value[1] and account_prefixes[0][:-1] in value[1] and len(value[1].split(":")[0]) <= 5:
                is_account = False
                if len(value[1]) == 5:
                    # Bare combined label: collect every line on the same row,
                    # keyed by horizontal distance to the label's top-right corner.
                    tmp_value = []
                    top_right_x = value[0][2]
                    top_right_y = value[0][3]
                    tmp_info = {}
                    for tmp in ocr_results.values():
                        if tmp[1] != value[1]:
                            if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2:
                                tmp_info[abs(tmp[0][0] - top_right_x)] = tmp
                    tmp_info_id = sorted(tmp_info.keys())
                    if not tmp_info[tmp_info_id[0]][1].isdigit() and len(tmp_info[tmp_info_id[0]][1]) > 19:
                        tmp_value = tmp_info[tmp_info_id[0]]
                        new_position = [value[0][0], value[0][1],
                                        tmp_value[0][2], tmp_value[0][3], tmp_value[0][4], tmp_value[0][5],
                                        value[0][6], value[0][7]]
                        results.append([value[1] + tmp_value[1], new_position])
                    if tmp_info[tmp_info_id[0]][1].isdigit():
                        tmp_value = tmp_info[tmp_info_id[1]]
                        new_position = [value[0][0], value[0][1],
                                        tmp_value[0][2], tmp_value[0][3], tmp_value[0][4], tmp_value[0][5],
                                        value[0][6], value[0][7]]
                        results.append([value[1] + tmp_value[1], new_position])
                    break
                elif len(value[1]) < 25:
                    # Short combined line: take the nearest line on the same row as the value.
                    tmp_info = {}
                    top_right_x = value[0][2]
                    top_right_y = value[0][3]
                    for tmp in ocr_results.values():
                        if tmp[1] != value[1]:
                            if abs(tmp[0][1] - top_right_y) < abs(value[0][3] - value[0][5]) / 2:
                                tmp_info[abs(tmp[0][0] - top_right_x)] = tmp
                    tmp_info_id = sorted(tmp_info.keys())
                    tmp_value = tmp_info[tmp_info_id[0]]
                    new_position = [value[0][0], value[0][1],
                                    tmp_value[0][2], tmp_value[0][3], tmp_value[0][4], tmp_value[0][5],
                                    value[0][6], value[0][7]]
                    results.append([value[1] + tmp_value[1], new_position])
                    break
                else:
                    results.append([value[1], value[0]])
                    break
        if is_account:
            for account_prefix in account_prefixes:
                if account_prefix in value[1]:
                    if account_prefix == value[1]:
                        results.append(_merge_with_right_neighbor(ocr_results, value, account_prefix))
                    else:
                        results.append([value[1], value[0]])
                    break
        else:
            break
    return results


def extract_bank_info(ocr_results):
    # Pick the bank-specific extractor from keywords found in the OCR text;
    # fall back to the gongshang layout if no keyword matches.
    results = []
    for value in ocr_results.values():
        if '建设' in value[1]:
            results = extract_jianshe_info(ocr_results)
            break
        elif '民生' in value[1]:
            results = extract_minsheng_info(ocr_results)
            break
        elif '农业' in value[1]:
            results = extract_nongye_info(ocr_results)
            break
        elif '中国银行' in value[1]:
            results = extract_zhongguo_info(ocr_results)
            break
    if len(results) == 0:
        results = extract_gongshang_info(ocr_results)

    return results


if __name__ == '__main__':
    path = '/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/minsheng/authentic/images/val'
    save_path = '/data/situ_invoice_bill_data/new_data/results'
    bank = 'minsheng'
    if not os.path.exists(os.path.join(save_path, bank)):
        os.makedirs(os.path.join(save_path, bank))
    save_path = os.path.join(save_path, bank)
    for j in tqdm.tqdm(os.listdir(path)):
        img = cv2.imread(os.path.join(path, j))
        st = time.time()
        ocr_result = bill_ocr(img)
        et1 = time.time()
        result = extract_bank_info(ocr_result)
        et2 = time.time()
        # Draw each extracted field using its top-left and bottom-right corners.
        for i in range(len(result)):
            cv2.rectangle(img, (result[i][1][0], result[i][1][1]), (result[i][1][4], result[i][1][5]), (0, 0, 255), 2)
        cv2.imwrite(os.path.join(save_path, j), img)
        print('spend:{} ocr:{} extract:{}'.format(et2 - st, et1 - st, et2 - et1))
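
Note on the return format (inferred from how the positions are indexed above and in the pipeline file below, not stated anywhere in the commit): each entry returned by extract_bank_info is [text, position], where position appears to hold the four corner points as [x1, y1, x2, y2, x3, y3, x4, y4], with indices 0/1 the top-left and 4/5 the bottom-right corner. A small helper built on that assumption:

def field_to_xyxy(field):
    # field = [text, position]; corner layout assumed as described above
    text, position = field
    return text, [position[0], position[1], position[4], position[5]]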
@@ -576,8 +576,8 @@ def run(

 def parse_opt():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
-    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolov5s.pt', help='model.pt path(s)')
+    parser.add_argument('--data', type=str, default=ROOT / 'data/VOC.yaml', help='dataset.yaml path')
+    parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'runs/train/exp/weights/best.pt', help='model.pt path(s)')
     parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')
     parser.add_argument('--batch-size', type=int, default=1, help='batch size')
     parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
......
@@ -95,7 +95,13 @@ class Yolov5:

 if __name__ == "__main__":
     img = cv2.imread(
-        '/home/situ/qfs/invoice_tamper/09_project/project/yolov5_inference/data/images/crop_img/_1594890230.8032346page_10_img_0_hname.jpg')
+        '/home/situ/qfs/invoice_tamper/09_project/project/tamper_det/data/images/img_1.png')
     detector = Yolov5(config)
     result = detector.detect(img)
+    for i in result['result']:
+        position=list(i.values())[2:]
+        print(position)
+        cv2.rectangle(img,(position[0],position[1]),(position[0]+position[2],position[1]+position[3]),(0,0,255))
+    cv2.imshow('w',img)
+    cv2.waitKey(0)
     print(result)
......
@@ -1,10 +1,11 @@
 from easydict import EasyDict as edict

 config = edict(
+    # weights='/home/situ/qfs/invoice_tamper/09_project/project/yolov5_inference/runs/exp2/weights/best.pt', # model path or triton URL
     weights='runs/train/exp/weights/best.pt', # model path or triton URL
     data='data/VOC.yaml', # dataset.yaml path
     imgsz=(640, 640), # inference size (height, width)
-    conf_thres=0.5, # confidence threshold
+    conf_thres=0.2, # confidence threshold
     iou_thres=0.45, # NMS IOU threshold
     max_det=1000, # maximum detections per image
     device='' # cuda device, i.e. 0 or 0,1,2,3 or cpu
......
import time

import cv2

from bank_ocr_inference import bill_ocr, extract_bank_info
from inference import Yolov5
from models.yolov5_config import config


def enlarge_position(box):
    # Pad the OCR field box before cropping: a third of the height vertically
    # and an eighth of the width horizontally, clamped to the image origin.
    x1, y1, x2, y2 = box
    w, h = abs(x2 - x1), abs(y2 - y1)
    y1, y2 = max(y1 - h // 3, 0), y2 + h // 3
    x1, x2 = max(x1 - w // 8, 0), x2 + w // 8
    return [x1, y1, x2, y2]


def tamper_detect(image):
    st = time.time()
    ocr_results = bill_ocr(image)
    et1 = time.time()
    info_results = extract_bank_info(ocr_results)
    et2 = time.time()
    print(info_results)
    tamper_results = []
    if len(info_results) != 0:
        for info_result in info_results:
            # Crop an enlarged region around each extracted field and run the
            # YOLOv5 tamper detector on it.
            box = [info_result[1][0], info_result[1][1], info_result[1][4], info_result[1][5]]
            x1, y1, x2, y2 = enlarge_position(box)
            info_image = image[y1:y2, x1:x2, :]
            cv2.imshow('info_image', info_image)
            results = detector.detect(info_image)
            print(results)
            if len(results['result']) != 0:
                for res in results['result']:
                    left = int(res['left'])
                    top = int(res['top'])
                    width = int(res['width'])
                    height = int(res['height'])
                    # Map the detection back to source-image coordinates.
                    absolute_position = [x1 + left, y1 + top, x1 + left + width, y1 + top + height]
                    tamper_results.append(absolute_position)
    print(tamper_results)
    et3 = time.time()

    print(f'all:{et3 - st} ocr:{et1 - st} extract:{et2 - et1} yolo:{et3 - et2}')
    for i in tamper_results:
        cv2.rectangle(image, tuple(i[:2]), tuple(i[2:]), (0, 0, 255), 2)
    cv2.imshow('info', image)
    cv2.waitKey(0)


if __name__ == '__main__':
    # `detector` is created at module scope here and used inside tamper_detect.
    detector = Yolov5(config)
    image = cv2.imread(
        "/home/situ/下载/_1597378020.731796page_33_img_0.jpg")
    tamper_detect(image)
......
@@ -10,9 +10,9 @@ def get_source_image_det(crop_position, predict_positions):
     result = []
     x1, y1, x2, y2 = crop_position
     for p in predict_positions:
-        px1, py1, px2, py2,score = p
+        px1, py1, px2, py2, score = p
         w, h = px2 - px1, py2 - py1
-        result.append([x1 + px1, y1 + py1, x1 + px1 + w, y1 + py1 + h,score])
+        result.append([x1 + px1, y1 + py1, x1 + px1 + w, y1 + py1 + h, score])
     return result


@@ -22,9 +22,9 @@ def decode_label(image, label_path):
     result = []
     for d in data:
         d = [float(i) for i in d.strip().split(' ')]
-        cls, cx, cy, cw, ch,score = d
+        cls, cx, cy, cw, ch, score = d
         cx, cy, cw, ch = cx * w, cy * h, cw * w, ch * h
-        result.append([int(cx - cw // 2), int(cy - ch // 2), int(cx + cw // 2), int(cy + ch // 2),score])
+        result.append([int(cx - cw // 2), int(cy - ch // 2), int(cx + cw // 2), int(cy + ch // 2), score])
     return result


@@ -38,28 +38,28 @@ if __name__ == '__main__':
     data = pd.read_csv(crop_csv_path)
     img_name = data.loc[:, 'img_name'].tolist()
     crop_position1 = data.loc[:, 'name_crop_coord'].tolist()
-    crop_position2 = data.loc[:,'number_crop_coord'].tolist()
-    cc='/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/gongshang/tampered/images/val/ps3'
+    crop_position2 = data.loc[:, 'number_crop_coord'].tolist()
+    cc = '/data/situ_invoice_bill_data/new_data/qfs_bank_bill_data/gongshang/tampered/images/val/ps3'
     for im in os.listdir(cc):
         print(im)
-        img = cv2.imread(os.path.join(cc,im))
-        img_=img.copy()
+        img = cv2.imread(os.path.join(cc, im))
+        img_ = img.copy()
         id = img_name.index(im)
-        name_crop_position=[int(i) for i in crop_position1[id].split(',')]
-        number_crop_position=[int(i) for i in crop_position2[id].split(',')]
-        nx1,ny1,nx2,ny2=name_crop_position
-        nux1,nuy1,nux2,nuy2=number_crop_position
-        if im[:-4]+'_hname.txt' in predict_labels:
+        name_crop_position = [int(i) for i in crop_position1[id].split(',')]
+        number_crop_position = [int(i) for i in crop_position2[id].split(',')]
+        nx1, ny1, nx2, ny2 = name_crop_position
+        nux1, nuy1, nux2, nuy2 = number_crop_position
+        if im[:-4] + '_hname.txt' in predict_labels:

             h, w, c = img[ny1:ny2, nx1:nx2, :].shape
-            data = open(os.path.join(predict_label_path,im[:-4]+'_hname.txt')).readlines()
+            data = open(os.path.join(predict_label_path, im[:-4] + '_hname.txt')).readlines()
             for d in data:
-                cls,cx,cy,cw,ch,score = [float(i) for i in d.strip().split(' ')]
-                cx,cy,cw,ch=int(cx*w),int(cy*h),int(cw*w),int(ch*h)
-                cx1,cy1=cx-cw//2,cy-ch//2
-                x1,y1,x2,y2=nx1+cx1,ny1+cy1,nx1+cx1+cw,ny1+cy1+ch
-                cv2.rectangle(img,(x1,y1),(x2,y2),(0,0,255),2)
-                cv2.putText(img,f'tampered:{score}',(x1,y1-5),cv2.FONT_HERSHEY_SIMPLEX,0.5,(0,0,255),1)
+                cls, cx, cy, cw, ch, score = [float(i) for i in d.strip().split(' ')]
+                cx, cy, cw, ch = int(cx * w), int(cy * h), int(cw * w), int(ch * h)
+                cx1, cy1 = cx - cw // 2, cy - ch // 2
+                x1, y1, x2, y2 = nx1 + cx1, ny1 + cy1, nx1 + cx1 + cw, ny1 + cy1 + ch
+                cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)
+                cv2.putText(img, f'tampered:{score}', (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
         if im[:-4] + '_hnumber.txt' in predict_labels:
             h, w, c = img[nuy1:nuy2, nux1:nux2, :].shape
             data = open(os.path.join(predict_label_path, im[:-4] + '_hname.txt')).readlines()
@@ -70,5 +70,5 @@ if __name__ == '__main__':
                 x1, y1, x2, y2 = nux1 + cx1, nuy1 + cy1, nux1 + cx1 + cw, nuy1 + cy1 + ch
                 cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255), 2)
                 cv2.putText(img, f'tampered:{score}', (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
-        result = np.vstack((img_,img))
-        cv2.imwrite(f'z/{im}',result)
+        result = np.vstack((img_, img))
+        cv2.imwrite(f'z/{im}', result)
......