# get_pr.py (4.66 KB)
import os

import cv2
import numpy as np
from sklearn.metrics import precision_score, recall_score, confusion_matrix


def iou(box, boxes):
    """Compute the intersection-over-union between one box and many boxes.

    Boxes are corner coordinates ``[x1, y1, x2, y2]``; corner order is
    normalised, so inverted corners are tolerated.

    Args:
        box: Four numbers describing the reference box.
        boxes: Array-like of shape ``(N, 4)``. An empty sequence is accepted
            and produces an empty result.

    Returns:
        Float array of shape ``(N,)`` holding the IoU of ``box`` with each row
        of ``boxes``. Pairs whose union area is zero get an IoU of 0.
    """
    bx1, by1, bx2, by2 = np.asarray(box, dtype=float)
    bx1, bx2 = min(bx1, bx2), max(bx1, bx2)
    by1, by2 = min(by1, by2), max(by1, by2)

    boxes = np.asarray(boxes, dtype=float).reshape(-1, 4)
    x1s = np.minimum(boxes[:, 0], boxes[:, 2])
    x2s = np.maximum(boxes[:, 0], boxes[:, 2])
    y1s = np.minimum(boxes[:, 1], boxes[:, 3])
    y2s = np.maximum(boxes[:, 1], boxes[:, 3])

    area1 = (bx2 - bx1) * (by2 - by1)
    areas = (x2s - x1s) * (y2s - y1s)

    # Clamp width and height *separately*: when the boxes are disjoint along
    # both axes, both extents are negative and their product would otherwise
    # be a bogus positive intersection.
    inter_w = np.maximum(0.0, np.minimum(bx2, x2s) - np.maximum(bx1, x1s))
    inter_h = np.maximum(0.0, np.minimum(by2, y2s) - np.maximum(by1, y1s))
    inner = inter_w * inter_h

    union = area1 + areas - inner
    # Degenerate (zero-area) pairs would be 0/0 -> nan; report 0 instead.
    return np.divide(inner, union, out=np.zeros_like(union), where=union > 0)


def _read_label_lines(path):
    """Return the non-blank lines of a YOLO label file; a missing file yields []."""
    if not os.path.isfile(path):
        return []
    with open(path, encoding='utf-8') as f:
        return [line for line in f if line.strip()]


def _yolo_to_corners(line, img_w, img_h):
    """Convert one YOLO label line to integer pixel corners ``[x1, y1, x2, y2]``.

    The line is ``cls xc yc w h [...]`` with coordinates normalised to [0, 1].
    Only the first five fields are used, so prediction files that carry an
    extra confidence column are accepted too.
    """
    _, xc, yc, bw, bh = (float(v) for v in line.split()[:5])
    xc, yc = int(xc * img_w), int(yc * img_h)
    bw, bh = int(bw * img_w), int(bh * img_h)
    return [xc - bw // 2, yc - bh // 2, xc + bw // 2, yc + bh // 2]


def _image_hit(true_lines, pred_lines, img_w, img_h, threshold):
    """Return 1 if any ground-truth box has IoU > threshold with some predicted box, else 0."""
    if not pred_lines:
        return 0
    pred_boxes = np.array([_yolo_to_corners(pl, img_w, img_h) for pl in pred_lines])
    for tl in true_lines:
        gt_box = np.array(_yolo_to_corners(tl, img_w, img_h))
        if np.max(iou(gt_box, pred_boxes)) > threshold:
            return 1
    return 0


def _ratio(num, den):
    """Return ``num / den`` as a float, or 0.0 when ``den`` is zero."""
    return float(num) / den if den else 0.0


def _per_class_metrics(conf):
    """Return ``((authentic_p, authentic_r), (tampered_p, tampered_r))`` from a 2x2 matrix.

    Rows of ``conf`` are true classes and columns predicted classes
    (0 = authentic, 1 = tampered).
    """
    auth_p = _ratio(conf[0, 0], conf[0, 0] + conf[1, 0])
    auth_r = _ratio(conf[0, 0], conf[0, 0] + conf[0, 1])
    tamp_p = _ratio(conf[1, 1], conf[0, 1] + conf[1, 1])
    tamp_r = _ratio(conf[1, 1], conf[1, 0] + conf[1, 1])
    return (auth_p, auth_r), (tamp_p, tamp_r)


def _draw_report(conf, metrics, output_path):
    """Render the confusion matrix and per-class metrics on a white 500x500 image and save it."""
    (auth_p, auth_r), (tamp_p, tamp_r) = metrics
    rows = [
        ('              authentic     tampered ', 50),
        (f'authentic      {conf[0, 0]}           {conf[0, 1]}', 80),
        (f'tampered       {conf[1, 0]}           {conf[1, 1]}', 110),
        (f'authentic  precision:{round(auth_p, 3)}', 170),
        (f'           recall:{round(auth_r, 3)}', 200),
        (f'tampered  precision:{round(tamp_p, 3)}', 230),
        (f'           recall:{round(tamp_r, 3)}', 260),
    ]
    canvas = np.full((500, 500, 3), 255, dtype=np.uint8)
    for text, y in rows:
        cv2.putText(canvas, text, (20, y), cv2.FONT_ITALIC, 0.7, (0, 0, 255), 1)
    cv2.imwrite(output_path, canvas)


def _print_report(p, r, conf, metrics):
    """Print overall precision/recall, the confusion matrix and per-class metrics."""
    (auth_p, auth_r), (tamp_p, tamp_r) = metrics
    print('precision:', p)
    print('recall:', r)
    print(conf)
    # Header labels: "预测" = predicted (columns), "真实" = actual (rows).
    print('                      预    测        ')
    print('               authentic     tampered ')
    print(f'真  authentic \t\t{conf[0, 0]}  \t\t{conf[0, 1]}')
    print(f'实  tampered  \t\t{conf[1, 0]}  \t\t\t{conf[1, 1]}')
    print(f'authentic precision:{auth_p}\trecall:{auth_r}')
    print(f'tampered  precision:{tamp_p}\trecall:{tamp_r}')


def get_evaluate_score(true_image_path, true_label_path, predict_label_path, threshold,
                       output_image_path='pr_result.jpg'):
    """Evaluate tamper detection at image level against YOLO-format labels.

    An image is *tampered* (target 1) when its ground-truth label file holds
    at least one box, otherwise *authentic* (target 0). The prediction is
    tampered when:

    * for authentic images: the prediction file contains any box;
    * for tampered images: some ground-truth box overlaps some predicted box
      with IoU strictly greater than ``threshold``.

    Args:
        true_image_path: Directory of the ``.jpg`` images. They are needed to
            convert normalised coordinates to pixels.
        true_label_path: Directory of ground-truth ``.txt`` label files.
        predict_label_path: Directory of predicted ``.txt`` label files. A
            missing or empty file means "no detection".
        threshold: IoU threshold for counting a tampered image as detected.
        output_image_path: Where the rendered report image is written.

    Returns:
        Tuple ``(precision, recall, confusion_matrix)`` for the tampered
        class. The confusion matrix is always 2x2 (labels ``[0, 1]``).

    Raises:
        ValueError: If ``true_label_path`` contains no ``.txt`` label files.
        FileNotFoundError: If an image required for coordinate conversion
            cannot be read.
    """
    targets, predicts = [], []
    for label in sorted(os.listdir(true_label_path)):
        stem, ext = os.path.splitext(label)
        if ext != '.txt':
            continue
        true_lines = _read_label_lines(os.path.join(true_label_path, label))
        pred_lines = _read_label_lines(os.path.join(predict_label_path, label))

        if not true_lines:
            targets.append(0)
            predicts.append(1 if pred_lines else 0)
            continue

        targets.append(1)
        if not pred_lines:
            predicts.append(0)
            continue

        # The image is only needed here, to map normalised boxes to pixels.
        image_file = os.path.join(true_image_path, stem + '.jpg')
        img = cv2.imread(image_file)
        if img is None:
            raise FileNotFoundError(f'cannot read image for label {label}: {image_file}')
        img_h, img_w = img.shape[:2]
        predicts.append(_image_hit(true_lines, pred_lines, img_w, img_h, threshold))

    if not targets:
        raise ValueError(f'no .txt label files found in {true_label_path}')

    p = precision_score(targets, predicts, zero_division=0)
    r = recall_score(targets, predicts, zero_division=0)
    # Force a 2x2 matrix even if only one class appears in targets/predicts.
    conf = confusion_matrix(targets, predicts, labels=[0, 1])
    metrics = _per_class_metrics(conf)

    _draw_report(conf, metrics, output_image_path)
    _print_report(p, r, conf, metrics)
    return p, r, conf


if __name__ == '__main__':
    import argparse

    # Defaults reproduce the original hard-coded experiment paths; override
    # them on the command line instead of editing the script.
    parser = argparse.ArgumentParser(
        description='Image-level precision/recall of tamper detections against YOLO labels.')
    parser.add_argument(
        '--true-image-path',
        default='/data/situ_invoice_bill_data/qfs_train_val_data/test_data/only_human_ps/all/images')
    parser.add_argument(
        '--true-label-path',
        default='/data/situ_invoice_bill_data/qfs_train_val_data/test_data/only_human_ps/all/labels')
    parser.add_argument(
        '--predict-label-path',
        default='/home/situ/qfs/invoice_tamper/09_project/project/tamper_det/runs/detect/exp2/labels')
    parser.add_argument('--threshold', type=float, default=0.1,
                        help='IoU threshold for counting a tampered image as detected')
    args = parser.parse_args()
    get_evaluate_score(args.true_image_path, args.true_label_path,
                       args.predict_label_path, args.threshold)