# get_pr.py
import os
import cv2
import numpy as np
from sklearn.metrics import precision_score, recall_score, confusion_matrix
def iou(box, boxes):
    """Compute the intersection-over-union of one box against many boxes.

    All boxes use corner coordinates ``(x1, y1, x2, y2)``.

    Args:
        box: Sequence or array of four numbers describing the reference box.
        boxes: Array-like of shape ``(N, 4)``. A single flat box of four
            numbers or an empty input is also accepted.

    Returns:
        A float array of length ``N`` with the IoU between ``box`` and each
        row of ``boxes``. Entries whose union area is not positive are 0.
    """
    # Normalise to an (N, 4) float array so empty input yields an empty result
    # instead of an indexing error.
    boxes = np.asarray(boxes, dtype=float).reshape(-1, 4)
    x1, y1, x2, y2 = box
    x1s, y1s, x2s, y2s = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]

    area1 = abs(x2 - x1) * abs(y2 - y1)
    areas = (x2s - x1s) * (y2s - y1s)

    # Clamp the overlap width and height separately. Clamping only their
    # product is wrong: for boxes disjoint on both axes the two negative
    # extents multiply to a positive "intersection".
    inter_w = np.maximum(0, np.minimum(x2, x2s) - np.maximum(x1, x1s))
    inter_h = np.maximum(0, np.minimum(y2, y2s) - np.maximum(y1, y1s))
    inner = inter_w * inter_h

    union = area1 + areas - inner
    # Leave the result at 0 wherever the union is not positive to avoid
    # dividing by zero for degenerate boxes.
    result = np.zeros_like(union, dtype=float)
    np.divide(inner, union, out=result, where=union > 0)
    return result
def _label_line_to_box(line, img_w, img_h):
    """Convert one normalised ``cls cx cy w h`` label line to pixel corners."""
    # Split on any whitespace and read only fields 1..4, so repeated spaces or
    # extra trailing fields (e.g. a confidence score) do not break parsing.
    fields = line.split()
    cx, cy, bw, bh = (float(value) for value in fields[1:5])
    cx, cy, bw, bh = int(cx * img_w), int(cy * img_h), int(bw * img_w), int(bh * img_h)
    return [cx - bw // 2, cy - bh // 2, cx + bw // 2, cy + bh // 2]


def _read_label_lines(label_file):
    """Return the non-blank lines of a label file, closing the file afterwards."""
    with open(label_file) as handle:
        return [line for line in handle if line.strip()]


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return float(numerator) / float(denominator) if denominator else 0.0


def get_evaluate_score(true_image_path, true_label_path, predict_label_path, threshold):
    """Score image-level tamper detection against ground-truth label files.

    Every file in ``true_label_path`` is one sample. A sample with no label
    lines has target 0 ("authentic"); otherwise the target is 1 ("tampered").
    The predicted class is 1 when:

    * the sample is authentic and a prediction file with the same name exists, or
    * the sample is tampered and at least one ground-truth box has an IoU above
      ``threshold`` with some predicted box.

    The confusion matrix and per-class precision/recall are printed and also
    drawn into ``pr_result.jpg`` in the current working directory.

    Args:
        true_image_path: Directory holding ``<name>.jpg`` for each label file.
        true_label_path: Directory of ground-truth label files.
        predict_label_path: Directory of predicted label files.
        threshold: IoU value a ground-truth box must exceed to count as found.

    Returns:
        A tuple ``(precision, recall, confusion_matrix)`` for the tampered
        class, where the matrix is 2x2 with rows = truth, columns = prediction.

    Raises:
        FileNotFoundError: If an image needed for box scaling cannot be read.
    """
    true_labels = os.listdir(true_label_path)
    # A set makes the per-sample membership test O(1).
    predict_labels = set(os.listdir(predict_label_path))
    targets, predicts = [], []

    for label in true_labels:
        true_lines = _read_label_lines(os.path.join(true_label_path, label))

        if not true_lines:
            targets.append(0)
            # Only the existence of a prediction file is checked here; its
            # contents are not inspected.
            predicts.append(1 if label in predict_labels else 0)
            continue

        targets.append(1)
        if label not in predict_labels:
            predicts.append(0)
            continue

        # The image is only needed to scale normalised boxes to pixels, so it
        # is loaded lazily for samples that actually compare boxes.
        image_file = os.path.join(true_image_path, os.path.splitext(label)[0] + '.jpg')
        img = cv2.imread(image_file)
        if img is None:
            raise FileNotFoundError(f'cannot read image: {image_file}')
        h, w = img.shape[:2]

        predicted_boxes = [
            _label_line_to_box(line, w, h)
            for line in _read_label_lines(os.path.join(predict_label_path, label))
        ]
        if not predicted_boxes:
            # An empty prediction file means nothing was detected.
            predicts.append(0)
            continue

        predicted_boxes = np.array(predicted_boxes)
        detected = any(
            max(iou(np.array(_label_line_to_box(line, w, h)), predicted_boxes)) > threshold
            for line in true_lines
        )
        predicts.append(1 if detected else 0)

    p = precision_score(targets, predicts)
    r = recall_score(targets, predicts)
    # Pin the label order so the matrix is always 2x2, even when only one
    # class occurs in the data.
    conf = confusion_matrix(targets, predicts, labels=[0, 1])

    # Per-class metrics; a zero denominator yields 0.0 instead of nan.
    authentic_precision = _safe_ratio(conf[0, 0], conf[0, 0] + conf[1, 0])
    authentic_recall = _safe_ratio(conf[0, 0], conf[0, 0] + conf[0, 1])
    tampered_precision = _safe_ratio(conf[1, 1], conf[0, 1] + conf[1, 1])
    tampered_recall = _safe_ratio(conf[1, 1], conf[1, 0] + conf[1, 1])

    # White 8-bit canvas for the rendered summary.
    bg_mask = np.full((500, 500, 3), 255, dtype=np.uint8)
    summary_rows = [
        (' authentic tampered ', 50),
        (f'authentic {conf[0, 0]} {conf[0, 1]}', 80),
        (f'tampered {conf[1, 0]} {conf[1, 1]}', 110),
        (f'authentic precision:{round(authentic_precision, 3)}', 170),
        (f' recall:{round(authentic_recall, 3)}', 200),
        (f'tampered precision:{round(tampered_precision, 3)}', 230),
        (f' recall:{round(tampered_recall, 3)}', 260),
    ]
    for text, y in summary_rows:
        cv2.putText(bg_mask, text, (20, y), cv2.FONT_ITALIC, 0.7, (0, 0, 255), 1)
    cv2.imwrite('pr_result.jpg', bg_mask)

    print('precision:', p)
    print('recall:', r)
    print(conf)
    print(f' 预 测 ')
    print(f' authentic tampered ')
    print(f'真 authentic \t\t{conf[0, 0]} \t\t{conf[0, 1]}')
    print(f'实 tampered \t\t{conf[1, 0]} \t\t\t{conf[1, 1]}')
    print(f'authentic precision:{authentic_precision}\trecall:{authentic_recall}')
    print(f'tampered precision:{tampered_precision}\trecall:{tampered_recall}')
    return p, r, conf
if __name__ == '__main__':
    # Hard-coded inputs for a single evaluation run; edit the paths and the
    # IoU threshold here before launching the script.
    evaluation_kwargs = {
        'true_image_path': '/data/situ_invoice_bill_data/qfs_train_val_data/test_data/only_human_ps/all/images',
        'true_label_path': '/data/situ_invoice_bill_data/qfs_train_val_data/test_data/only_human_ps/all/labels',
        'predict_label_path': '/home/situ/qfs/invoice_tamper/09_project/project/tamper_det/runs/detect/exp2/labels',
        'threshold': 0.1,
    }
    get_evaluate_score(**evaluation_kwargs)