# class_filter.py (1.08 KB)
import os
import pickle
import numpy as np


def start_filter(config):
    """Classify pre-extracted features with a pickled model and write the results.

    Config keys read:
      - ``config['MODEL']['CLS_CLASS']``: path to the pickled classifier. It
        must expose ``predict``. It is presumably an XGBoost model, per the
        variable name; confirm.
      - ``config['VIDEO']['CLASS_FEATURE_DIR']`` and
        ``config['CLASS']['DATA_NAME']``: directory and file name of the
        ``np.load``-able feature file.
      - ``config['VIDEO']['FRAME_LIST_DIR']`` and
        ``config['CLASS']['RESULT_FILE']``: directory and file name of the
        text output file. An existing file is overwritten.

    Each loaded entry is indexed as ``(features, label, name)``. The label at
    index 1 is not used here. The result file receives one line per entry, in
    input order, formatted as ``"<name> <prediction>"``. An empty feature file
    produces an empty result file.

    Security: ``pickle.load`` and ``np.load(allow_pickle=True)`` can execute
    arbitrary code while loading. Only point the config at trusted files.
    """
    cls_class_path = config['MODEL']['CLS_CLASS']
    feature_save_dir = config['VIDEO']['CLASS_FEATURE_DIR']
    frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
    result_file_name = config['CLASS']['RESULT_FILE']
    feature_name = config['CLASS']['DATA_NAME']

    # Close the model file deterministically instead of leaking the handle.
    # NOTE: unpickling runs arbitrary code; trusted model files only.
    with open(cls_class_path, "rb") as model_file:
        xgboost_model = pickle.load(model_file)

    feature_path = os.path.join(feature_save_dir, feature_name)
    val_annotation_pairs = np.load(feature_path, allow_pickle=True)

    # Only the features (index 0) and names (index 2) are needed.
    names = [pair[2] for pair in val_annotation_pairs]
    if names:
        X_val = np.array([pair[0] for pair in val_annotation_pairs])
        y_pred = xgboost_model.predict(X_val)
    else:
        # Nothing to classify, so avoid calling predict() on an empty array.
        y_pred = []

    result_file_path = os.path.join(frame_list_dir, result_file_name)
    # Open the output only after prediction succeeds. A failure above then
    # leaves any previous result file intact, and `with` always closes it.
    with open(result_file_path, 'w') as result_file:
        for i, name in enumerate(names):
            # Index (not zip) so a length mismatch still raises IndexError.
            result_file.write(name + ' ' + str(y_pred[i]) + '\n')