# detect.py (5.05 KB)
import glob
import shutil
import os
from config.config import load_config
import mmcv
from ensemble_boxes import *
import cv2


# Create an output directory (and any missing parents).
def init_dir(file_dir):
    """Ensure ``file_dir`` exists as a directory.

    ``os.makedirs(..., exist_ok=True)`` is used instead of an exists-check
    followed by ``os.mkdir`` for two reasons. Nested paths such as
    ``out/image`` work even when ``out`` is missing. A directory created
    concurrently between the check and the mkdir no longer raises.

    Args:
        file_dir: Path of the directory to create.

    Raises:
        FileExistsError: if ``file_dir`` exists but is not a directory.
    """
    os.makedirs(file_dir, exist_ok=True)


# Normalize the configured input into a list of image paths.
def init_input(cfg):
    """Resolve ``cfg['image_files']`` into a list of image file paths.

    Accepted forms (extensions are matched case-insensitively):

    * a list whose first entry ends in ``.jpg``/``.png`` (or an empty list)
      -> returned as-is;
    * a directory -> every directory entry joined with the directory path;
    * an existing ``.jpg``/``.png`` file -> a one-element list;
    * an existing ``.txt`` file -> the first whitespace-separated field of
      every non-blank line, made absolute, with ``\\`` replaced by ``/``;
    * a string containing ``*.jpg`` or ``*.png`` -> the ``glob.glob`` matches.

    Args:
        cfg: Mapping with an ``'image_files'`` entry.

    Returns:
        The resolved list of paths (also printed). On unrecognised input, an
        error message is printed and ``None`` is returned.
    """
    image_exts = ('.jpg', '.png')
    image_files = cfg['image_files']

    # Lists must be handled first: os.path.isdir()/exists() raise TypeError on a list.
    if isinstance(image_files, list):
        # The extension comparison includes the dot, which is what splitext returns.
        is_images = not image_files or os.path.splitext(image_files[0])[-1].lower() in image_exts
        resolved = image_files if is_images else None
    elif os.path.isdir(image_files):
        resolved = [os.path.join(image_files, name) for name in os.listdir(image_files)]
    elif os.path.exists(image_files):
        ext = os.path.splitext(image_files)[-1].lower()
        if ext in image_exts:
            resolved = [image_files]
        elif ext == '.txt':
            # split() tolerates tabs, repeated spaces and '\r\n'; blank lines are skipped.
            with open(image_files) as list_file:
                resolved = [os.path.abspath(line.split()[0]).replace('\\', '/')
                            for line in list_file if line.strip()]
        else:
            resolved = None
    elif '*.jpg' in image_files or '*.png' in image_files:
        resolved = glob.glob(image_files)
    else:
        resolved = None

    if resolved is None:
        print('error input: ', image_files)
        return None
    print(resolved)
    return resolved


# Convert raw mmdetection output into flat [class, score, x1, y1, x2, y2] rows.
def mmdet_out(out, iou=0.5):
    """Flatten per-class detection results into a list of detection rows.

    ``out`` is indexed by class id. Each entry is an iterable of detections
    laid out as ``(x1, y1, x2, y2, score)``. Every detection whose score is
    not below ``iou`` becomes ``[class_id, score, x1, y1, x2, y2]``.

    NOTE(review): despite its name, ``iou`` is a *score* threshold. It is
    compared with the fifth column, which downstream code treats as the
    detection score.

    Args:
        out: Per-class detections (the bbox result of ``inference_detector``).
        iou: Minimum score for a detection to be kept.

    Returns:
        List of ``[class_id, score, x1, y1, x2, y2]`` lists, ordered by class
        and then by input order.
    """
    return [
        [class_id, det[4], det[0], det[1], det[2], det[3]]
        for class_id, dets in enumerate(out)
        for det in dets
        # `not (<)` rather than `>=` so NaN scores are kept exactly as before
        if not float(det[4]) < iou
    ]


# Normalize a pixel-space box into the [0, 1] range.
def box_normalize(box, size):
    """Scale ``box`` by the image size and clamp every value into ``[0, 1]``.

    The box is modified in place and also returned. X-coordinates (indices
    0 and 2) are divided by ``size[0]`` and y-coordinates (indices 1 and 3)
    by ``size[1]``. Afterwards every element of ``box`` is clipped: values
    above 1 become ``1`` and values below 0 become ``0``.

    Args:
        box: Mutable sequence ``[x1, y1, x2, y2, ...]`` in pixels.
        size: ``(width, height)`` of the image.

    Returns:
        The same ``box`` object, now normalized.
    """
    width, height = size[0], size[1]
    for idx, divisor in ((0, width), (1, height), (2, width), (3, height)):
        box[idx] = box[idx] / divisor
    for idx, value in enumerate(box):
        # max/min return their first argument on ties, so in-range values keep their type
        box[idx] = min(max(value, 0), 1)
    return box


# Scale a normalized box back to pixel coordinates.
def box_re_std(box, size):
    """Undo the normalization of ``box`` and return it as a plain list.

    X-coordinates (indices 0 and 2) are multiplied by ``size[0]`` and
    y-coordinates (indices 1 and 3) by ``size[1]``, in place. The result is
    returned via ``box.tolist()``, so ``box`` must provide ``tolist`` (for
    example, the rows that ``boxes_fusion`` receives from
    ``weighted_boxes_fusion``).

    Args:
        box: Array-like ``[x1, y1, x2, y2]`` in normalized coordinates.
        size: ``(width, height)`` of the image.

    Returns:
        A Python list with the pixel-space coordinates.
    """
    width, height = size[0], size[1]
    for idx, factor in enumerate((width, height, width, height)):
        box[idx] = box[idx] * factor
    return box.tolist()


# Fuse detections from several models.
def boxes_fusion(cfg, boxes_list):
    """Combine per-model detections according to ``cfg['type']``.

    ``cfg['type']`` is tested with ``in`` for two modes:

    * ``'class_fusion'``: from each model ``i``, keep the boxes whose label is
      allowed by ``cfg['class_list'][i]`` (a per-model collection of labels).
      For backward compatibility, a non-collection entry means "allowed if
      the label appears anywhere in ``cfg['class_list']``", which is the
      old flat-list behaviour.
    * ``'weighted_boxes_fusion'``: run ``weighted_boxes_fusion`` over every
      model's boxes. The class-fusion selection, if any, is added as one
      extra "model" with weight 1. Fused boxes scoring above ``cfg['score']``
      are kept.

    Neither ``boxes_list`` nor ``cfg`` is modified. Previously
    ``cfg['weight_list']`` grew by one entry per call, so from the second
    image on the weights no longer matched the number of models.

    Args:
        cfg: Fusion settings. Reads ``type``, ``class_list``, ``weight_list``,
            ``iou``, ``skip_box_thr``, ``score`` and ``size`` (the image's
            ``(width, height)``).
        boxes_list: One list per model of ``[label, score, x1, y1, x2, y2]`` rows.

    Returns:
        ``[label, score, x1, y1, x2, y2]`` rows. These are the fused boxes
        when weighted boxes fusion is enabled, otherwise the class-fusion
        selection. The result is empty when neither mode is enabled.

    Raises:
        ValueError: if class fusion is enabled and ``cfg['class_list']`` does
            not have exactly one entry per model.
    """
    fusion_type = cfg['type']

    class_boxes = []
    if 'class_fusion' in fusion_type:
        class_list = cfg['class_list']
        if len(class_list) != len(boxes_list):
            raise ValueError('class_list needs one entry per model: got {} entries for {} models'
                             .format(len(class_list), len(boxes_list)))
        for allowed, model_boxes in zip(class_list, boxes_list):
            if not isinstance(allowed, (list, tuple, set, frozenset)):
                allowed = class_list  # legacy flat list of labels
            class_boxes.extend(box for box in model_boxes if box[0] in allowed)

    if 'weighted_boxes_fusion' not in fusion_type:
        return class_boxes

    # Copies, so the caller's list and the shared config stay untouched across images.
    model_boxes_list = list(boxes_list)
    weights = cfg['weight_list']
    weights = None if weights is None else list(weights)
    if class_boxes:
        model_boxes_list.append(class_boxes)
        if weights is not None:
            weights.append(1)

    scores_list = [[box[1] for box in boxes] for boxes in model_boxes_list]
    labels_list = [[int(box[0]) for box in boxes] for boxes in model_boxes_list]
    # box[2:] copies the coordinates, so box_normalize's in-place edits don't touch the inputs.
    norm_boxes_list = [[box_normalize(box[2:], cfg['size']) for box in boxes]
                       for boxes in model_boxes_list]
    boxes, scores, labels = weighted_boxes_fusion(norm_boxes_list, scores_list, labels_list,
                                                  weights=weights, iou_thr=cfg['iou'],
                                                  skip_box_thr=cfg['skip_box_thr'])
    # Only fused boxes are returned. Previously the raw class-fusion boxes
    # were returned as well, duplicating detections already fed into WBF.
    return [[label, score] + box_re_std(box, cfg['size'])
            for box, score, label in zip(boxes, scores, labels)
            if score > cfg['score']]


# Loaded detectors keyed by (config, absolute checkpoint path, device), so each
# model is built once per process instead of once per image. Trade-off: all
# configured models stay resident in (GPU) memory at the same time.
_DETECTOR_CACHE = {}


# Run every configured mmdetection model on one image, then optionally fuse.
def mmdetect(models, fusion, img):
    """Detect objects in ``img`` with each configured model.

    Args:
        models: Mapping with ``class_txt``, ``config_files``,
            ``checkpoint_files`` (indexed in step with ``config_files``) and
            ``cuda`` (passed as ``device`` to ``init_detector``).
        fusion: Fusion settings for ``boxes_fusion``. When falsy (e.g. ``{}``
            or ``None``), no fusion is done. When truthy, its ``'size'`` entry
            is set to the image's ``(width, height)``.
        img: Image path or an already-loaded image array.

    Returns:
        A flat list of ``[label, score, x1, y1, x2, y2]`` rows: the fused
        result when ``fusion`` is set, otherwise all models' rows
        concatenated.
    """
    if os.path.exists(models['class_txt']):
        # NOTE(review): presumably the model configs read their class names from
        # data/mmdet_classes.txt when the detector is built — confirm.
        shutil.copy(models['class_txt'], 'data/mmdet_classes.txt')

    from mmdet.apis import init_detector, inference_detector

    if isinstance(img, str):
        img = mmcv.imread(img)

    boxes_list = []
    for i, config in enumerate(models['config_files']):
        checkpoint = os.path.abspath(models['checkpoint_files'][i])
        # Only string configs are cached; other config objects may not be usable as keys.
        key = (config, checkpoint, models['cuda']) if isinstance(config, str) else None
        model = _DETECTOR_CACHE.get(key) if key is not None else None
        if model is None:
            model = init_detector(config, checkpoint, device=models['cuda'])
            if key is not None:
                _DETECTOR_CACHE[key] = model
        out = inference_detector(model, img)
        if isinstance(out, tuple):
            # mmdetection 2.x models with a mask head return (bbox_result, segm_result).
            out = out[0]
        boxes_list.append(mmdet_out(out))

    if fusion:
        # Set only when fusing. Assigning before the check made an empty fusion
        # config truthy, which then crashed in boxes_fusion. shape[1]/shape[0]
        # also works for 2-D (grayscale) arrays.
        fusion['size'] = (img.shape[1], img.shape[0])
        boxes_list = boxes_fusion(fusion, boxes_list)
    else:
        # Flatten per-model lists so callers always get the same row format.
        boxes_list = [box for model_boxes in boxes_list for box in model_boxes]
    print(boxes_list)
    return boxes_list


def _save_txt(out_txt, boxes):
    """Write one ``label score x1 y1 x2 y2`` line per box to ``out_txt``."""
    with open(out_txt, 'w') as f:
        for box in boxes:
            f.write('{} {} {} {} {} {}\n'.format(box[0], box[1], box[2], box[3], box[4], box[5]))


def _draw_boxes(image, boxes):
    """Draw every box and its ``label score`` caption onto ``image`` in place."""
    color = (255, 0, 255)
    for box in boxes:
        x1, y1, x2, y2 = (int(v) for v in box[2:6])
        cv2.rectangle(image, (x1, y1), (x2, y2), color)
        cv2.putText(image, '{} {:.2f}'.format(int(box[0]), box[1]), (x1, y1 + 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, color, thickness=2)


def run():
    """Detect objects in every configured input image and save/show the results.

    Reads settings via ``load_config()``. For each image it runs ``mmdetect``
    and then:

    * writes ``<out_dir>/label/<name>.txt`` when ``save_txt`` is set;
    * draws the boxes;
    * shows the image when ``show`` is set, waiting for a key press;
    * writes ``<out_dir>/image/<name>`` when ``save_image`` is set.
    """
    cfg = load_config()  # load the config file

    # Prepare output directories.
    init_dir(cfg['out_dir'])
    image_dir = txt_dir = None
    if cfg['save_image']:
        image_dir = os.path.join(cfg['out_dir'], 'image')
        init_dir(image_dir)
    if cfg['save_txt']:
        txt_dir = os.path.join(cfg['out_dir'], 'label')
        init_dir(txt_dir)

    image_files = init_input(cfg['data'])
    if not image_files:
        return  # init_input already reported the problem (or there is nothing to do)

    # Detection loop.
    for image_file in image_files:
        image = cv2.imread(image_file)
        if image is None:
            print('failed to read image: ', image_file)
            continue
        boxes = mmdetect(cfg['model'], cfg['fusion'], image_file)
        # basename handles OS-joined paths too (directory input uses os.path.join).
        file_name = os.path.basename(image_file)

        if txt_dir is not None:
            # splitext works for .png as well as .jpg inputs.
            _save_txt(os.path.join(txt_dir, os.path.splitext(file_name)[0] + '.txt'), boxes)

        _draw_boxes(image, boxes)

        if cfg['show']:
            cv2.imshow('image', image)
            cv2.waitKey(0)  # imshow renders nothing without a waitKey; press a key to continue

        if image_dir is not None:
            cv2.imwrite(os.path.join(image_dir, file_name), image)


# Script entry point: only runs when executed directly, not on import.
if __name__ == "__main__":
    run()