# audio_filter.py (2.16 KB)
import os
import csv
import pickle
import numpy as np
from sklearn.externals import joblib


def start_filter(config):
    """Score pre-extracted audio features with the audio classifier and write results.

    The classifier is loaded with ``joblib`` from ``config['MODEL']['CLS_AUDIO']``.
    The feature records are loaded with ``np.load`` from
    ``config['VIDEO']['IS10_FEATURE_NP_DIR'] / config['AUDIO']['DATA_NAME']``.
    For every record, ``pair[0]`` is flattened into a single row and passed to
    ``predict_proba``. One line per record is then written to
    ``config['VIDEO']['FRAME_LIST_DIR'] / config['AUDIO']['RESULT_FILE']`` as::

        <name> <p0>,<p1>,<p2>

    Args:
        config: Nested mapping indexed as ``config[section][key]`` that provides
            the MODEL / VIDEO / AUDIO entries read below.
    """
    cls_audio_path = config['MODEL']['CLS_AUDIO']
    feature_save_dir = config['VIDEO']['IS10_FEATURE_NP_DIR']
    frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
    result_file_name = config['AUDIO']['RESULT_FILE']
    feature_name = config['AUDIO']['DATA_NAME']

    svm_clf = joblib.load(cls_audio_path)

    feature_path = os.path.join(feature_save_dir, feature_name)
    # allow_pickle/latin1: the records are a pickled object array, presumably
    # produced under Python 2. NOTE(review): only load trusted files here.
    val_annotation_pairs = np.load(feature_path, allow_pickle=True, encoding='latin1')

    result_file_path = os.path.join(frame_list_dir, result_file_name)
    # The context manager closes the result file even if prediction raises.
    with open(result_file_path, 'w') as result_file:
        for pair in val_annotation_pairs:
            # Records are indexed as pair[0] = feature vector and pair[2] = sample
            # name (presumably a file name) -- TODO confirm against the extractor.
            feature_np = np.reshape(pair[0], (1, -1))
            proba = np.squeeze(svm_clf.predict_proba(feature_np))

            # [:-4] drops the last four characters, presumably an extension such
            # as '.wav'.
            sample_name = str(pair[2])[:-4]
            # Exactly three class probabilities are written (a 3-class model is assumed).
            scores = ','.join(str(proba[k]) for k in range(3))
            result_file.write(sample_name + ' ' + scores + '\n')
    




def start_filter_xgboost(config):
    """Score pre-extracted audio features with a pickled XGBoost model and write results.

    The model is unpickled from ``config['MODEL']['CLS_AUDIO']``. The feature
    records are loaded with ``np.load`` from
    ``config['VIDEO']['IS10_FEATURE_NP_DIR'] / config['AUDIO']['DATA_NAME']``.
    Each record is treated as a mapping ``{sample_name: feature_vector}``. All
    vectors are flattened, stacked into one matrix and scored with a single
    ``predict_proba`` call. One line per sample is written to
    ``config['VIDEO']['FRAME_LIST_DIR'] / config['AUDIO']['RESULT_FILE']`` as::

        <name> <p0>,<p1>,<p2>

    If there are no samples, an empty result file is written.

    Args:
        config: Nested mapping indexed as ``config[section][key]`` that provides
            the MODEL / VIDEO / AUDIO entries read below.
    """
    cls_class_path = config['MODEL']['CLS_AUDIO']
    feature_save_dir = config['VIDEO']['IS10_FEATURE_NP_DIR']
    frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
    result_file_name = config['AUDIO']['RESULT_FILE']
    feature_name = config['AUDIO']['DATA_NAME']

    # NOTE(review): pickle.load can execute arbitrary code; the model path must
    # point to a trusted artifact. The handle is closed right after loading.
    with open(cls_class_path, "rb") as model_file:
        xgboost_model = pickle.load(model_file)

    feature_path = os.path.join(feature_save_dir, feature_name)
    val_annotation_pairs = np.load(feature_path, allow_pickle=True, encoding='latin1')

    # Iterate items() so each mapping entry yields one (name, vector) pair.
    # Unpacking `pair.items()` directly would bind whole (key, value) tuples.
    y_names = []
    rows = []
    for pair in val_annotation_pairs:
        for name, vector in pair.items():
            y_names.append(name)
            # Flatten each vector, consistent with the reshape in start_filter.
            rows.append(np.ravel(vector))

    # predict_proba cannot score an empty matrix, so skip it when there is no data.
    y_pred = xgboost_model.predict_proba(np.stack(rows)) if rows else []

    result_file_path = os.path.join(frame_list_dir, result_file_name)
    with open(result_file_path, 'w') as result_file:
        for name, proba in zip(y_names, y_pred):
            # Exactly three class probabilities are written (a 3-class model is assumed).
            scores = ','.join(str(proba[k]) for k in range(3))
            result_file.write(str(name) + ' ' + scores + '\n')