# detect.py
# Runs mmdet detectors over the configured images, optionally fuses their boxes, and saves or shows the results.
import glob
import shutil
import os
from config.config import load_config
import mmcv
from ensemble_boxes import *
import cv2
# Initialise an output folder.
def init_dir(file_dir):
    """Create the directory ``file_dir`` if it does not exist yet.

    Missing parent directories are created as well, and a directory that
    appears between the existence check and the creation is tolerated.

    Args:
        file_dir: Path of the directory to create.
    """
    if not os.path.exists(file_dir):
        # makedirs (unlike mkdir) copes with nested paths such as "out/run1";
        # exist_ok covers another process creating the directory meanwhile.
        os.makedirs(file_dir, exist_ok=True)
# Normalise the configured input into a list of image paths.
def init_input(cfg):
    """Resolve ``cfg['image_files']`` into a list of image file paths.

    Supported values of ``cfg['image_files']``:

    * a list of paths whose first entry ends in ``.jpg``/``.png`` (an empty
      list is accepted too); the list is returned unchanged;
    * a directory: every entry of the directory is returned, joined with the
      directory path (entries are not filtered by extension);
    * a single ``.jpg``/``.png`` file;
    * a ``.txt`` file with one image per line; only the text before the first
      space of each line is used, converted to an absolute path with forward
      slashes; blank lines are skipped;
    * a glob pattern containing ``*.jpg`` or ``*.png``.

    File extensions are compared case-insensitively.

    Args:
        cfg: Mapping with an ``'image_files'`` entry.

    Returns:
        The list of image paths, or ``None`` (after printing an
        ``error input`` message) when the value is not recognised.
    """
    image_exts = ('.jpg', '.png')
    image_files = cfg['image_files']

    def _ext(path):
        return os.path.splitext(path)[-1].lower()

    resolved = None
    if isinstance(image_files, list):
        # Must be tested before any os.path call: those raise TypeError
        # when handed a list.
        if not image_files or _ext(image_files[0]) in image_exts:
            resolved = image_files
    elif os.path.isdir(image_files):
        resolved = [os.path.join(image_files, name) for name in os.listdir(image_files)]
    elif os.path.exists(image_files):
        ext = _ext(image_files)
        if ext in image_exts:
            resolved = [image_files]
        elif ext == '.txt':
            with open(image_files) as list_file:
                resolved = [
                    os.path.abspath(line.strip('\n').split(' ')[0]).replace('\\', '/')
                    for line in list_file
                    if line.strip()
                ]
    elif '*.jpg' in image_files or '*.png' in image_files:
        resolved = glob.glob(image_files)

    if resolved is None:
        print('error input: ', image_files)
        return None
    print(resolved)
    return resolved
# Post-process the raw mmdet inference output.
def mmdet_out(out, iou=0.5):
    """Flatten per-class detections into ``[class, score, b0, b1, b2, b3]`` rows.

    Args:
        out: Sequence indexed by class id; each entry is an iterable of
            detections whose elements 0-3 are copied as the box and whose
            element 4 is the score (presumably ``x1, y1, x2, y2, score`` as
            produced by mmdet 2.x -- confirm against the mmdet version in
            use). A ``(bbox_results, segm_results)`` tuple, as returned by
            mask models, is accepted as well; only the bbox part is used.
        iou: Minimum score (element 4) a detection needs to be kept.
            NOTE: despite its name this is a confidence threshold, not an
            IoU; the name is kept for backward compatibility.

    Returns:
        A list of ``[class_id, score, b0, b1, b2, b3]`` lists, ordered by
        class id and then by position within the class.
    """
    if isinstance(out, tuple):
        # Mask models return (bbox_results, segm_results); iterating the tuple
        # itself would treat the two halves as classes 0 and 1.
        out = out[0]
    return [
        [class_id, det[4], det[0], det[1], det[2], det[3]]
        for class_id, detections in enumerate(out)
        for det in detections
        if not float(det[4]) < iou
    ]
# Normalise a box.
def box_normalize(box, size):
    """Scale a box by ``size`` in place and clamp every element to [0, 1].

    Elements 0 and 2 are divided by ``size[0]``, elements 1 and 3 by
    ``size[1]`` (the caller in this file passes ``(width, height)``).

    Args:
        box: Mutable sequence with at least four numeric elements; it is
            modified in place.
        size: Two-element sequence of divisors.

    Returns:
        The same ``box`` object after scaling and clamping.
    """
    for axis in range(4):
        box[axis] = box[axis] / size[axis % 2]
    # Clamp after all divisions so out-of-image coordinates land on the border.
    for position, value in enumerate(box):
        if value > 1:
            box[position] = 1
        elif value < 0:
            box[position] = 0
    return box
# De-normalise a box.
def box_re_std(box, size):
    """Scale a normalised box back by ``size`` and return it as a list.

    Elements 0 and 2 are multiplied by ``size[0]``, elements 1 and 3 by
    ``size[1]``. ``box`` is modified in place before being converted.

    Args:
        box: Object with item assignment and a ``tolist()`` method (for
            example a NumPy array) holding at least four numbers.
        size: Two-element sequence of multipliers.

    Returns:
        ``box.tolist()`` after scaling.
    """
    for axis in range(4):
        box[axis] = box[axis] * size[axis % 2]
    return box.tolist()
# Fuse the boxes predicted by several models.
def boxes_fusion(cfg, boxes_list):
    """Combine per-model detections according to ``cfg['type']``.

    Two strategies can be enabled, separately or together:

    * ``'class_fusion'``: keep every box whose class id (element 0) is listed
      in ``cfg['class_list']``.
    * ``'weighted_boxes_fusion'``: run ``weighted_boxes_fusion`` over all
      models (plus the class-fusion result as one extra source with weight 1,
      if there is one) and keep fused boxes scoring above ``cfg['score']``.

    Neither ``cfg`` nor ``boxes_list`` is modified, so the same config can be
    reused for every image.

    Args:
        cfg: Fusion settings. Reads ``type`` and, depending on it,
            ``class_list``, ``weight_list``, ``size``, ``iou``,
            ``skip_box_thr`` and ``score``.
        boxes_list: One list per model of ``[class, score, b0, b1, b2, b3]``
            rows.

    Returns:
        A flat list of ``[class, score, b0, b1, b2, b3]`` rows.

    Raises:
        AssertionError: If class fusion is enabled and ``cfg['class_list']``
            does not have one entry per model.
    """
    fusion_types = cfg['type']
    rs = []
    if 'class_fusion' in fusion_types:
        assert len(cfg['class_list']) == len(boxes_list)
        # NOTE(review): the length check suggests one class per model, yet the
        # membership test below is against the whole list -- confirm intent.
        rs.extend(
            box
            for boxes in boxes_list
            for box in boxes
            if box[0] in cfg['class_list']
        )
    if 'weighted_boxes_fusion' in fusion_types:
        # Work on copies: appending to cfg['weight_list'] / boxes_list directly
        # would grow them on every call and desynchronise weights and sources.
        sources = list(boxes_list)
        weights = cfg['weight_list']
        if rs:
            sources.append(list(rs))
            if weights is not None:
                weights = list(weights) + [1]
        scores_list = [[box[1] for box in boxes] for boxes in sources]
        labels_list = [[int(box[0]) for box in boxes] for boxes in sources]
        # box_normalize works in place, so hand it a fresh list per box.
        norm_boxes = [
            [box_normalize(list(box[2:]), cfg['size']) for box in boxes]
            for boxes in sources
        ]
        boxes, scores, labels = weighted_boxes_fusion(
            norm_boxes, scores_list, labels_list, weights=weights,
            iou_thr=cfg['iou'], skip_box_thr=cfg['skip_box_thr'])
        # NOTE(review): class-fusion rows already in rs stay in the result next
        # to the fused rows, as before -- confirm duplicates are intended.
        for label, score, box in zip(labels, scores, boxes):
            if score > cfg['score']:
                rs.append([label, score] + box_re_std(box, cfg['size']))
    return rs
# Run every configured mmdet model on one image.
def mmdetect(models, fusion, img):
    """Detect objects in ``img`` with each configured model and fuse the results.

    Args:
        models: Model settings. Reads ``class_txt``, ``config_files``,
            ``checkpoint_files`` (parallel to ``config_files``) and ``cuda``
            (passed to ``init_detector`` as ``device``).
        fusion: Fusion settings forwarded to ``boxes_fusion``; a falsy value
            (``None`` or an empty dict) disables fusion. When fusion is
            enabled, ``fusion['size']`` is set to the image's
            ``(width, height)``.
        img: Image path or an already loaded image array.

    Returns:
        With fusion enabled, the flat row list from ``boxes_fusion``;
        otherwise one ``mmdet_out`` row list per model.
    """
    if os.path.exists(models['class_txt']):
        shutil.copy(models['class_txt'], 'data/mmdet_classes.txt')
    # Imported only after the class list has been copied -- presumably the
    # mmdet build in use reads data/mmdet_classes.txt on import; confirm.
    from mmdet.apis import init_detector, inference_detector
    if isinstance(img, str):
        img = mmcv.imread(img)
    if fusion:
        # shape[:2] (rather than shape[:-1]) also works for 2-D grayscale arrays.
        fusion['size'] = img.shape[:2][::-1]
    boxes_list = []
    # NOTE(review): every call rebuilds each detector from its checkpoint;
    # consider loading the models once if this runs over many images.
    for i, config in enumerate(models['config_files']):
        checkpoint = os.path.abspath(models['checkpoint_files'][i])
        model = init_detector(config, checkpoint, device=models['cuda'])
        boxes_list.append(mmdet_out(inference_detector(model, img)))
    if fusion:
        boxes_list = boxes_fusion(fusion, boxes_list)
    print(boxes_list)
    return boxes_list
def run():
    """Run detection over all configured images and save or show the results.

    Reads the configuration via ``load_config()``, creates the output folders,
    then for every input image runs ``mmdetect``, optionally writes one label
    file per image (one `` class score b0 b1 b2 b3`` line per box), draws the
    boxes on the image, and optionally shows and/or saves the annotated image.
    Images that cannot be read are reported and skipped.
    """
    cfg = load_config()  # load the config file
    # initialise the output folders
    init_dir(cfg['out_dir'])
    if cfg['save_image']:
        image_dir = os.path.join(cfg['out_dir'], 'image')
        init_dir(image_dir)
    if cfg['save_txt']:
        txt_dir = os.path.join(cfg['out_dir'], 'label')
        init_dir(txt_dir)
    image_files = init_input(cfg['data'])
    if not image_files:
        # init_input returns None for unrecognised input; nothing to do.
        return
    # detect
    for image_file in image_files:
        image = cv2.imread(image_file)
        if image is None:
            # cv2.imread signals an unreadable file with None, not an exception.
            print('error image: ', image_file)
            continue
        boxes = mmdetect(cfg['model'], cfg['fusion'], image_file)
        file_name = os.path.basename(image_file)
        if cfg['save_txt']:
            # Swap the real extension so .png inputs also get a .txt label file.
            out_txt = os.path.join(txt_dir, os.path.splitext(file_name)[0] + '.txt')
            with open(out_txt, 'w') as f:
                for box in boxes:
                    f.write(' {} {} {} {} {} {}'.format(box[0], box[1], box[2], box[3], box[4], box[5]))
                    f.write('\n')
        for box in boxes:
            top_left = (int(box[2]), int(box[3]))
            bottom_right = (int(box[4]), int(box[5]))
            cv2.rectangle(image, top_left, bottom_right, (255, 0, 255))
            cv2.putText(image, '{} {:.2f}'.format(int(box[0]), box[1]), (top_left[0], top_left[1] + 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 255), thickness=2)
        if cfg['show']:
            cv2.imshow('image', image)
            # imshow only paints once the GUI event loop runs; 1 ms keeps it non-blocking.
            cv2.waitKey(1)
        if cfg['save_image']:
            image_path = os.path.join(image_dir, file_name)
            cv2.imwrite(image_path, image)
# Script entry point: run the detection pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == '__main__':
    run()