class_filter.py
1.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import os
import pickle
import numpy as np
def _split_annotation_pairs(pairs):
    """Split loaded annotation records into a feature list and a name list.

    Each record is indexed as ``record[0]`` (the feature value passed to the
    model) and ``record[2]`` (the name written to the result file).
    ``record[1]`` is not used by this module.
    NOTE(review): presumably ``record[1]`` is a ground-truth label -- confirm
    against whatever step produced the feature file.
    """
    features = [record[0] for record in pairs]
    names = [record[2] for record in pairs]
    return features, names


def start_filter(config):
    """Classify saved features with a pickled model and write one line per record.

    Reads these keys from ``config``:
      - ``config['MODEL']['CLS_CLASS']``: path to the pickled classifier. It
        must expose ``predict(X)`` returning one prediction per row of ``X``.
      - ``config['VIDEO']['CLASS_FEATURE_DIR']`` and
        ``config['CLASS']['DATA_NAME']``: directory and file name of the
        ``np.load``-able feature records.
      - ``config['VIDEO']['FRAME_LIST_DIR']`` and
        ``config['CLASS']['RESULT_FILE']``: directory and file name of the
        text file to (over)write.

    Each output line holds the record's name, a single space, then the
    model's prediction for that record, in the order the records were stored.
    If the feature file holds no records, an empty result file is written and
    the model is not called.

    Raises:
        OSError: if the model or feature file cannot be opened.
        ValueError: if the model returns a different number of predictions
            than there are records.
    """
    cls_class_path = config['MODEL']['CLS_CLASS']
    feature_save_dir = config['VIDEO']['CLASS_FEATURE_DIR']
    frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
    result_file_name = config['CLASS']['RESULT_FILE']
    feature_name = config['CLASS']['DATA_NAME']

    # SECURITY: pickle.load can execute arbitrary code. Only load model files
    # from a trusted location.
    with open(cls_class_path, "rb") as model_file:
        xgboost_model = pickle.load(model_file)

    feature_path = os.path.join(feature_save_dir, feature_name)
    # allow_pickle=True lets np.load unpickle object arrays; the same trust
    # caveat as for the model file applies here.
    val_annotation_pairs = np.load(feature_path, allow_pickle=True)
    features, names = _split_annotation_pairs(val_annotation_pairs)

    if names:
        predictions = xgboost_model.predict(np.array(features))
        if len(predictions) != len(names):
            raise ValueError(
                "model returned %d predictions for %d records"
                % (len(predictions), len(names))
            )
    else:
        # np.array([]) would be a zero-row 1-D array; don't hand it to the model.
        predictions = []

    # Open (and truncate) the result file only after everything above has
    # succeeded, so a failure never leaves an empty or partial file behind.
    result_file_path = os.path.join(frame_list_dir, result_file_name)
    with open(result_file_path, 'w') as result_file:
        result_file.writelines(
            str(name) + ' ' + str(pred) + '\n'
            for name, pred in zip(names, predictions)
        )