first commit
Showing 46 changed files with 2001 additions and 0 deletions
__init__.py
0 → 100644
File mode changed
__pycache__/argue_filter.cpython-36.pyc 0 → 100644
__pycache__/audio_filter.cpython-36.pyc 0 → 100644
__pycache__/bg_filter.cpython-36.pyc 0 → 100644
__pycache__/class_filter.cpython-36.pyc 0 → 100644
__pycache__/emotion_1_filter.cpython-36.pyc 0 → 100644
__pycache__/emotion_filter.cpython-36.pyc 0 → 100644
__pycache__/fighting_2_filter.cpython-36.pyc 0 → 100644
__pycache__/fighting_filter.cpython-36.pyc 0 → 100644
__pycache__/flow_filter.cpython-36.pyc 0 → 100644
__pycache__/load_util.cpython-36.pyc 0 → 100644
__pycache__/media_util.cpython-36.pyc 0 → 100644
__pycache__/meeting_filter.cpython-36.pyc 0 → 100644
__pycache__/person_filter.cpython-36.pyc 0 → 100644
__pycache__/pose_filter.cpython-36.pyc 0 → 100644
__pycache__/troops_filter.cpython-36.pyc 0 → 100644
__pycache__/video_1_filter.cpython-36.pyc 0 → 100644
__pycache__/video_filter.cpython-36.pyc 0 → 100644
(compiled .pyc files: binary, no preview)
audio_filter.py
0 → 100644
import os
import csv
import pickle
import numpy as np
from sklearn.externals import joblib


def start_filter(config):
    cls_audio_path = config['MODEL']['CLS_AUDIO']
    feature_save_dir = config['VIDEO']['IS10_FEATURE_NP_DIR']
    frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
    result_file_name = config['AUDIO']['RESULT_FILE']
    feature_name = config['AUDIO']['DATA_NAME']

    svm_clf = joblib.load(cls_audio_path)

    result_file_path = os.path.join(frame_list_dir, result_file_name)
    result_file = open(result_file_path, 'w')

    feature_path = os.path.join(feature_save_dir, feature_name)
    val_annotation_pairs = np.load(feature_path, allow_pickle=True, encoding='latin1')

    for pair in val_annotation_pairs:

        v = pair[0]
        n = pair[2]

        feature_np = np.reshape(v, (1, -1))
        res = svm_clf.predict_proba(feature_np)
        proba = np.squeeze(res)

        # class_pre = svm_clf.predict(feature_np)

        result_file.write(str(pair[2])[:-4] + ' ')
        result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n')

    result_file.close()


def start_filter_xgboost(config):
    cls_class_path = config['MODEL']['CLS_AUDIO']
    feature_save_dir = config['VIDEO']['IS10_FEATURE_NP_DIR']
    frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
    result_file_name = config['AUDIO']['RESULT_FILE']
    feature_name = config['AUDIO']['DATA_NAME']

    xgboost_model = pickle.load(open(cls_class_path, "rb"))

    result_file_path = os.path.join(frame_list_dir, result_file_name)
    result_file = open(result_file_path, 'w')

    feature_path = os.path.join(feature_save_dir, feature_name)
    val_annotation_pairs = np.load(feature_path, allow_pickle=True, encoding='latin1')

    X_val = []
    Y_names = []
    for pair in val_annotation_pairs:
        n, v = pair.items()
        X_val.append(v)
        Y_names.append(n)

    X_val = np.array(X_val)
    y_pred = xgboost_model.predict_proba(X_val)

    for i, Y_name in enumerate(Y_names):
        result_file.write(Y_name + ' ')
        result_file.write(str(y_pred[i][0]) + ',' + str(y_pred[i][1]) + ',' + str(y_pred[i][2]) + '\n')

    result_file.close()
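start_filter() above writes one line per clip: the clip name (extension stripped) followed by three comma-separated class probabilities. A minimal sketch of reading such a result file back, assuming that exact layout (the 'audio.txt' path is illustrative):

# Hedged sketch: parse a result file written by start_filter(); path is illustrative.
probs = {}
with open('audio.txt') as f:
    for line in f:
        name, values = line.strip().split(' ')
        probs[name] = [float(x) for x in values.split(',')]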
bg_filter.py
0 → 100644
import os
import cv2
import numpy as np
import pickle


def start_filter(config):
    cls_class_path = config['MODEL']['CLS_BG']
    feature_save_dir = config['VIDEO']['FACE_FEATURE_DIR']
    frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
    result_file_name = config['BG']['RESULT_FILE']
    feature_name = config['BG']['DATA_NAME']

    xgboost_model = pickle.load(open(cls_class_path, "rb"))

    result_file_path = os.path.join(frame_list_dir, result_file_name)
    result_file = open(result_file_path, 'w')

    feature_path = os.path.join(feature_save_dir, feature_name)
    val_annotation_pairs = np.load(feature_path, allow_pickle=True)

    X_val = []
    Y_val = []
    Y_names = []
    for j in range(len(val_annotation_pairs)):
        pair = val_annotation_pairs[j]
        X_val.append(np.squeeze(pair[0]))
        Y_val.append(pair[1])
        Y_names.append(pair[2])

    X_val = np.array(X_val)
    y_pred = xgboost_model.predict_proba(X_val)

    for i, Y_name in enumerate(Y_names):
        result_file.write(Y_name + ' ')
        result_file.write(str(y_pred[i][0]) + ',' + str(y_pred[i][1]) + ',' + str(y_pred[i][2]) + '\n')

    result_file.close()
class_filter.py
0 → 100644
import os
import pickle
import numpy as np


def start_filter(config):

    cls_class_path = config['MODEL']['CLS_CLASS']
    feature_save_dir = config['VIDEO']['CLASS_FEATURE_DIR']
    frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
    result_file_name = config['CLASS']['RESULT_FILE']
    feature_name = config['CLASS']['DATA_NAME']

    xgboost_model = pickle.load(open(cls_class_path, "rb"))

    result_file_path = os.path.join(frame_list_dir, result_file_name)
    result_file = open(result_file_path, 'w')

    feature_path = os.path.join(feature_save_dir, feature_name)
    val_annotation_pairs = np.load(feature_path, allow_pickle=True)

    X_val = []
    Y_val = []
    Y_names = []
    for j in range(len(val_annotation_pairs)):
        pair = val_annotation_pairs[j]
        X_val.append(pair[0])
        Y_val.append(pair[1])
        Y_names.append(pair[2])

    X_val = np.array(X_val)
    y_pred = xgboost_model.predict(X_val)

    for i, Y_name in enumerate(Y_names):
        result_file.write(Y_name + ' ')
        result_file.write(str(y_pred[i]) + '\n')

    result_file.close()
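bg_filter.py and class_filter.py both index each loaded entry as pair[0]/pair[1]/pair[2] (feature, label, name). A minimal sketch of a feature file in that shape; the feature dimension, file name, and sample names are assumptions for illustration, not part of this commit:

import numpy as np

# Hypothetical (feature, label, name) triples matching the pair[0]/pair[1]/pair[2] indexing above.
pairs = np.array([
    (np.random.rand(128).astype(np.float32), 1, 'video_00001'),  # 128-dim feature is assumed
    (np.random.rand(128).astype(np.float32), 0, 'video_00002'),
], dtype=object)
np.save('class_feature_example.npy', pairs)                       # illustrative file name
loaded = np.load('class_feature_example.npy', allow_pickle=True)
print(loaded[0][2])  # 'video_00001'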
config.yaml
0 → 100644
MODEL:
  CLS_FIGHTING_2: '/home/jwq/models/cls_fighting_2/cls_fighting_2_v0.0.1.pth'
  CLS_EMOTION: '/home/jwq/models/cls_emotion/v0.1.0.m'
  FEATURE_EMOTION: '/home/jwq/models/feature_emotion/FerPlus3.h5'
  CLS_AUDIO: '/home/jwq/models/cls_audio/v0.0.1.m'
  CLS_CLASS: '/home/jwq/models/cls_class/v_0.0.1_xgb.pkl'
  CLS_VIDEO: '/home/jwq/models/cls_video/v0.4.1.pth'
  CLS_POSE: '/home/jwq/models/cls_pose/v0.0.1.pth'
  CLS_FLOW: '/home/jwq/models/cls_flow/v0.1.1.pth'
  CLS_BG: '/home/jwq/models/cls_bg/v0.1.1.pkl'
  CLS_PERSON: '/home/jwq/models/cls_person/v0.1.1.pkl'

THRESHOLD:
  FACES_THRESHOLD: 0.6

FILTER:

VIDEO:
  VIDEO_DIR: '/home/jwq/Desktop/VGAF_EmotiW/Val'
  LABEL_PATH: '/home/jwq/Desktop/VGAF_EmotiW/Val_labels.txt'
  VIDEO_SAVE_DIR: '/home/jwq/Desktop/tmp/video'
  AUDIO_SAVE_DIR: '/home/jwq/npys/'
  FRAME_SAVE_DIR: '/home/jwq/Desktop/tmp/frame'
  # FRAME_SAVE_DIR: '/home/jwq/Desktop/VGAF_EmotiW_class/train_frame'
  FLOW_SAVE_DIR: '/home/jwq/Desktop/tmp/flow'
  POSE_FRAME_SAVE_DIR: '/home/jwq/Desktop/tmp/pose_frame'
  FRAME_LIST_DIR: '/home/jwq/Desktop/tmp/file_list'
  IS10_FEATURE_NP_DIR: '/home/jwq/npys'
  IS10_FEATURE_CSV_DIR: '/home/jwq/Desktop/tmp/is10'
  # FACE_FEATURE_DIR: '/home/jwq/Desktop/tmp/face_feature_retina'
  # FACE_FEATURE_DIR: '/data2/retinaface/random_face_frame_features/'
  FACE_FEATURE_DIR: '/data1/segment/'
  # FACE_FEATURE_DIR: '/home/jwq/npys/'
  FACE_IMAGE_DIR: '/data2/retinaface/train/'
  CLASS_FEATURE_DIR: '/home/jwq/Desktop/tmp/class'
  PREFIX: 'img_{:05d}.jpg'
  FLOW_PREFIX: 'flow_{}_{:05d}.jpg'
  THREAD_NUM: 10
  FPS: 5

VIDEO_FILTER:
  TEST_SEGMENT: 8
  TEST_CROP: 1
  BATCH_SIZE: 1
  INPUT_SIZE: 224
  MODALITY: 'RGB'
  ARCH: 'resnet50'
  RESULT_FILE: 'video_filter.txt'

VIDEO_1_FILTER:
  TEST_SEGMENT: 8
  TEST_CROP: 1
  BATCH_SIZE: 1
  INPUT_SIZE: 224
  MODALITY: 'RGB'
  ARCH: 'resnet34'
  RESULT_FILE: 'video_1_filter.txt'

EMOTION:
  INTERVAL: 1
  INPUT_SIZE: 224
  RESULT_FILE: 'emotion_filter.txt'

EMOTION_1:
  RESULT_FILE: 'emotion_1_filter.txt'
  DATA_NAME: 'val.npy'

ARGUE:
  DIMENSION: 1582
  RESULT_FILE: 'argue_filter.txt'

FIGHTING:
  TEST_SEGMENT: 8
  TEST_CROP: 1
  BATCH_SIZE: 1
  INPUT_SIZE: 224
  MODALITY: 'RGB'
  ARCH: 'resnet50'
  RESULT_FILE: 'fighting_filter.txt'

FIGHTING_2:
  TEST_SEGMENT: 8
  TEST_CROP: 1
  BATCH_SIZE: 1
  INPUT_SIZE: 224
  MODALITY: 'RGB'
  ARCH: 'resnet50'
  RESULT_FILE: 'fighting_2_filter.txt'

MEETING:
  TEST_SEGMENT: 8
  TEST_CROP: 1
  BATCH_SIZE: 1
  INPUT_SIZE: 224
  MODALITY: 'RGB'
  ARCH: 'resnet50'
  RESULT_FILE: 'meeting_filter.txt'

TROOPS:
  TEST_SEGMENT: 8
  TEST_CROP: 1
  BATCH_SIZE: 1
  INPUT_SIZE: 224
  MODALITY: 'RGB'
  ARCH: 'resnet50'
  RESULT_FILE: 'troops_filter.txt'

FLOW:
  TEST_SEGMENT: 8
  TEST_CROP: 1
  BATCH_SIZE: 1
  INPUT_SIZE: 224
  MODALITY: 'Flow'
  ARCH: 'resnet50'
  RESULT_FILE: 'flow_filter.txt'

FINAL:
  RESULT_FILE: 'final.txt'
  ERROR_FILE: 'error.txt'
  SIM_FILE: 'image_sim.txt'

AUDIO:
  RESULT_FILE: 'audio.txt'
  OPENSMILE_DIR: '/home/jwq/Downloads/opensmile-2.3.0'
  DATA_NAME: 'val.npy'

CLASS:
  RESULT_FILE: 'class.txt'
  DATA_NAME: 'val _reannotation.npy'

POSE:
  TEST_SEGMENT: 8
  TEST_CROP: 1
  BATCH_SIZE: 1
  INPUT_SIZE: 224
  MODALITY: 'RGB'
  ARCH: 'resnet50'
  RESULT_FILE: 'pose_filter.txt'

BG:
  RESULT_FILE: 'bg_filter.txt'
  DATA_NAME: 'bg_val_feature.npy'

PERSON:
  RESULT_FILE: 'person_filter.txt'
  DATA_NAME: 'person_val_feature.npy'
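Every filter in this commit takes the parsed dict produced from this YAML. A minimal sketch of wiring it together; the driver script itself and the call order are assumptions for illustration, only load_config and the start_filter entry points come from this commit:

# Hypothetical driver: load config.yaml via load_util.load_config and run two of the filters.
import audio_filter
import class_filter
from load_util import load_config

config = load_config('config.yaml')   # assumes config.yaml sits next to the scripts
audio_filter.start_filter(config)     # writes AUDIO.RESULT_FILE under VIDEO.FRAME_LIST_DIR
class_filter.start_filter(config)     # writes CLASS.RESULT_FILE under VIDEO.FRAME_LIST_DIR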
emotion_filter.py
0 → 100644
import os
import cv2
import numpy as np
from keras.models import Model
from keras.models import load_model
from sklearn.externals import joblib
from tensorflow.keras.preprocessing.image import img_to_array

os.environ["CUDA_VISIBLE_DEVICES"] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'


class FeatureExtractor(object):
    def __init__(self, input_size=224, out_put_layer='avg_pool', model_path='FerPlus3.h5'):
        self.model = load_model(model_path)
        self.input_size = input_size
        self.model_inter = Model(inputs=self.model.input, outputs=self.model.get_layer(out_put_layer).output)

    def inference(self, image):
        image = cv2.resize(image, (self.input_size, self.input_size))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image.astype("float") / 255.0
        image = img_to_array(image)
        image = np.expand_dims(image, axis=0)
        feature = self.model_inter.predict(image)[0]
        return feature


def features2feature(pics_features):

    pics_features = np.array(pics_features)
    fea_mean = pics_features.mean(axis=0)
    fea_max = np.amax(pics_features, axis=0)
    fea_min = np.amin(pics_features, axis=0)
    fea_std = pics_features.std(axis=0)

    return np.concatenate((fea_mean, fea_max, fea_min, fea_std), axis=1).reshape(1, -1)


def start_filter(config):

    cls_emotion_path = config['MODEL']['CLS_EMOTION']
    face_feature_dir = config['VIDEO']['FACE_FEATURE_DIR']
    frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
    result_file_name = config['EMOTION']['RESULT_FILE']

    svm_clf = joblib.load(cls_emotion_path)

    result_file_path = os.path.join(frame_list_dir, result_file_name)
    result_file = open(result_file_path, 'w')

    face_feature_names = os.listdir(face_feature_dir)
    for face_feature in face_feature_names:
        face_feature_path = os.path.join(face_feature_dir, face_feature)

        features_np = np.load(face_feature_path, allow_pickle=True)

        feature = features2feature(features_np)
        res = svm_clf.predict_proba(feature)
        proba = np.squeeze(res)
        # class_pre = svm_clf.predict(feature)

        result_file.write(face_feature[:-4] + ' ')
        result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n')

    result_file.close()
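features2feature pools a stack of per-face features into a single mean/max/min/std vector; the axis=1 concatenation implies each per-face feature is stored as a (1, d) row. A small numpy check of that pooling, with illustrative dimensions:

import numpy as np

# Hedged sketch of the statistics pooling above: 3 faces, each a (1, 4) row (dims are illustrative).
pics_features = [np.arange(4).reshape(1, 4) + i for i in range(3)]
pooled = features2feature(pics_features)
print(pooled.shape)  # (1, 16): mean | max | min | std concatenated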
fighting_2_filter.py
0 → 100644
import os
import torch.optim
import numpy as np
import torch.optim
import torch.nn.parallel
from ops.models import TSN
from ops.transforms import *
from ops.dataset import TSNDataSet
from torch.nn import functional as F


def gen_file_list(frame_save_dir, frame_list_dir):

    val_path = os.path.join(frame_list_dir, 'val.txt')
    video_names = os.listdir(frame_save_dir)
    ucf101_rgb_val_file = open(val_path, 'w')

    for video_name in video_names:
        images_dir = os.path.join(frame_save_dir, video_name)
        ucf101_rgb_val_file.write(video_name)
        ucf101_rgb_val_file.write(' ')
        ucf101_rgb_val_file.write(str(len(os.listdir(images_dir))))
        ucf101_rgb_val_file.write('\n')

    ucf101_rgb_val_file.close()

    return val_path


def start_filter(config):
    arch = config['FIGHTING_2']['ARCH']
    prefix = config['VIDEO']['PREFIX']
    modality = config['FIGHTING_2']['MODALITY']
    test_crop = config['FIGHTING_2']['TEST_CROP']
    batch_size = config['FIGHTING_2']['BATCH_SIZE']
    weights_path = config['MODEL']['CLS_FIGHTING_2']
    test_segment = config['FIGHTING_2']['TEST_SEGMENT']
    frame_save_dir = config['VIDEO']['FRAME_SAVE_DIR']
    frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
    result_file_name = config['FIGHTING_2']['RESULT_FILE']

    workers = 8
    num_class = 2
    shift_div = 8
    img_feature_dim = 256

    softmax = False
    is_shift = True
    full_res = False
    non_local = False
    dense_sample = False
    twice_sample = False

    val_list = gen_file_list(frame_save_dir, frame_list_dir)
    result_file_path = os.path.join(frame_list_dir, result_file_name)

    pretrain = 'imagenet'
    shift_place = 'blockres'
    crop_fusion_type = 'avg'

    net = TSN(num_class, test_segment if is_shift else 1, modality,
              base_model=arch,
              consensus_type=crop_fusion_type,
              img_feature_dim=img_feature_dim,
              pretrain=pretrain,
              is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
              non_local=non_local,
              )

    checkpoint = torch.load(weights_path)
    checkpoint = checkpoint['state_dict']

    base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
    replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                    'base_model.classifier.bias': 'new_fc.bias',
                    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)

    net.load_state_dict(base_dict)

    input_size = net.scale_size if full_res else net.input_size

    if test_crop == 1:
        cropping = torchvision.transforms.Compose([
            GroupScale(net.scale_size),
            GroupCenterCrop(input_size),
        ])
    elif test_crop == 3:  # do not flip, so only 5 crops
        cropping = torchvision.transforms.Compose([
            GroupFullResSample(input_size, net.scale_size, flip=False)
        ])
    elif test_crop == 5:  # do not flip, so only 5 crops
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, net.scale_size, flip=False)
        ])
    elif test_crop == 10:
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, net.scale_size)
        ])
    else:
        raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(test_crop))

    data_loader = torch.utils.data.DataLoader(
        TSNDataSet(frame_save_dir, val_list, num_segments=test_segment,
                   new_length=1 if modality == "RGB" else 5,
                   modality=modality,
                   image_tmpl=prefix,
                   test_mode=True,
                   remove_missing=False,
                   transform=torchvision.transforms.Compose([
                       cropping,
                       Stack(roll=(arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])),
                       GroupNormalize(net.input_mean, net.input_std),
                   ]), dense_sample=dense_sample, twice_sample=twice_sample),
        batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True,
    )

    net = torch.nn.DataParallel(net.cuda())
    net.eval()
    data_gen = enumerate(data_loader)
    max_num = len(data_loader.dataset)

    result_file = open(result_file_path, 'w')

    for i, data_pair in data_gen:
        directory, data = data_pair
        with torch.no_grad():
            if i >= max_num:
                break
            num_crop = test_crop
            if dense_sample:
                num_crop *= 10  # 10 clips for testing when using dense sample

            if twice_sample:
                num_crop *= 2

            if modality == 'RGB':
                length = 3
            elif modality == 'Flow':
                length = 10
            elif modality == 'RGBDiff':
                length = 18
            else:
                raise ValueError("Unknown modality " + modality)

            data_in = data.view(-1, length, data.size(2), data.size(3))
            if is_shift:
                data_in = data_in.view(batch_size * num_crop, test_segment, length, data_in.size(2), data_in.size(3))
            rst, feature = net(data_in)
            rst = rst.reshape(batch_size, num_crop, -1).mean(1)

            if softmax:
                # take the softmax to normalize the output to probability
                rst = F.softmax(rst, dim=1)

            rst = rst.data.cpu().numpy().copy()

            if net.module.is_shift:
                rst = rst.reshape(batch_size, num_class)
            else:
                rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))

            proba = np.squeeze(rst)
            print(proba)
            proba = np.exp(proba) / sum(np.exp(proba))
            result_file.write(str(directory[0]) + ' ')
            result_file.write(str(proba[0]) + ',' + str(proba[1]) + '\n')

    result_file.close()
    print('fighting filter end')
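The final normalization above, np.exp(proba) / sum(np.exp(proba)), is simply a softmax applied in numpy after averaging over crops. A small check that it matches torch.nn.functional.softmax, with illustrative two-class logits:

import numpy as np
import torch
from torch.nn import functional as F

logits = np.array([1.2, -0.3])                 # illustrative logits for the two classes
manual = np.exp(logits) / sum(np.exp(logits))  # the normalization used above
reference = F.softmax(torch.tensor(logits), dim=0).numpy()
assert np.allclose(manual, reference)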
flow_filter.py
0 → 100644
import os
import torch.optim
import numpy as np
import torch.optim
import torch.nn.parallel
from ops.models import TSN
from ops.transforms import *
from ops.dataset import TSNDataSet
from torch.nn import functional as F


def gen_file_list(frame_save_dir, frame_list_dir):

    val_path = os.path.join(frame_list_dir, 'flow_val.txt')
    video_names = os.listdir(frame_save_dir)
    ucf101_rgb_val_file = open(val_path, 'w')

    for video_name in video_names:
        images_dir = os.path.join(frame_save_dir, video_name)
        ucf101_rgb_val_file.write(video_name)
        ucf101_rgb_val_file.write(' ')
        ori_list = os.listdir(images_dir)
        select_list = [element for element in ori_list if 'x' in element]
        ucf101_rgb_val_file.write(str(len(select_list)))
        ucf101_rgb_val_file.write('\n')

    ucf101_rgb_val_file.close()

    return val_path


def start_filter(config):
    arch = config['FLOW']['ARCH']
    prefix = config['VIDEO']['FLOW_PREFIX']
    modality = config['FLOW']['MODALITY']
    test_crop = config['FLOW']['TEST_CROP']
    batch_size = config['FLOW']['BATCH_SIZE']
    weights_path = config['MODEL']['CLS_FLOW']
    test_segment = config['FLOW']['TEST_SEGMENT']
    frame_save_dir = config['VIDEO']['FLOW_SAVE_DIR']
    frame_list_dir = config['VIDEO']['FRAME_LIST_DIR']
    result_file_name = config['FLOW']['RESULT_FILE']

    workers = 8
    num_class = 3
    shift_div = 8
    img_feature_dim = 256

    softmax = False
    is_shift = True
    full_res = False
    non_local = False
    dense_sample = False
    twice_sample = False

    val_list = gen_file_list(frame_save_dir, frame_list_dir)
    result_file_path = os.path.join(frame_list_dir, result_file_name)

    pretrain = 'imagenet'
    shift_place = 'blockres'
    crop_fusion_type = 'avg'

    net = TSN(num_class, test_segment if is_shift else 1, modality,
              base_model=arch,
              consensus_type=crop_fusion_type,
              img_feature_dim=img_feature_dim,
              pretrain=pretrain,
              is_shift=is_shift, shift_div=shift_div, shift_place=shift_place,
              non_local=non_local,
              )

    checkpoint = torch.load(weights_path)
    checkpoint = checkpoint['state_dict']

    base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())}
    replace_dict = {'base_model.classifier.weight': 'new_fc.weight',
                    'base_model.classifier.bias': 'new_fc.bias',
                    }
    for k, v in replace_dict.items():
        if k in base_dict:
            base_dict[v] = base_dict.pop(k)

    net.load_state_dict(base_dict)

    input_size = net.scale_size if full_res else net.input_size

    if test_crop == 1:
        cropping = torchvision.transforms.Compose([
            GroupScale(net.scale_size),
            GroupCenterCrop(input_size),
        ])
    elif test_crop == 3:  # do not flip, so only 5 crops
        cropping = torchvision.transforms.Compose([
            GroupFullResSample(input_size, net.scale_size, flip=False)
        ])
    elif test_crop == 5:  # do not flip, so only 5 crops
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, net.scale_size, flip=False)
        ])
    elif test_crop == 10:
        cropping = torchvision.transforms.Compose([
            GroupOverSample(input_size, net.scale_size)
        ])
    else:
        raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(test_crop))

    data_loader = torch.utils.data.DataLoader(
        TSNDataSet(frame_save_dir, val_list, num_segments=test_segment,
                   new_length=1 if modality == "RGB" else 5,
                   modality=modality,
                   image_tmpl=prefix,
                   test_mode=True,
                   remove_missing=False,
                   transform=torchvision.transforms.Compose([
                       cropping,
                       Stack(roll=(arch in ['BNInception', 'InceptionV3'])),
                       ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])),
                       GroupNormalize(net.input_mean, net.input_std),
                   ]), dense_sample=dense_sample, twice_sample=twice_sample),
        batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True,
    )

    net = torch.nn.DataParallel(net.cuda())
    net.eval()
    data_gen = enumerate(data_loader)
    max_num = len(data_loader.dataset)

    result_file = open(result_file_path, 'w')

    for i, data_pair in data_gen:
        directory, data = data_pair
        with torch.no_grad():
            if i >= max_num:
                break
            num_crop = test_crop
            if dense_sample:
                num_crop *= 10  # 10 clips for testing when using dense sample

            if twice_sample:
                num_crop *= 2

            if modality == 'RGB':
                length = 3
            elif modality == 'Flow':
                length = 10
            elif modality == 'RGBDiff':
                length = 18
            else:
                raise ValueError("Unknown modality " + modality)

            data_in = data.view(-1, length, data.size(2), data.size(3))
            if is_shift:
                data_in = data_in.view(batch_size * num_crop, test_segment, length, data_in.size(2), data_in.size(3))
            rst, feature = net(data_in)
            rst = rst.reshape(batch_size, num_crop, -1).mean(1)

            if softmax:
                # take the softmax to normalize the output to probability
                rst = F.softmax(rst, dim=1)

            rst = rst.data.cpu().numpy().copy()

            if net.module.is_shift:
                rst = rst.reshape(batch_size, num_class)
            else:
                rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class))

            proba = np.squeeze(rst)
            proba = np.exp(proba) / sum(np.exp(proba))
            result_file.write(str(directory[0]) + ' ')
            result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n')

    result_file.close()
    print('flow filter end')
load_util.py
0 → 100644
import os
import cv2
import yaml
import tensorflow as tf


def load_config(config_path):
    with open(config_path, 'r') as cf:
        config_obj = yaml.load(cf, Loader=yaml.FullLoader)
        print(config_obj)
        return config_obj


def load_argue_model(config):

    cls_argue_path = config['MODEL']['CLS_ARGUE']
    with tf.Graph().as_default():

        if os.path.isfile(cls_argue_path):
            print('Model filename: %s' % cls_argue_path)
            with tf.gfile.GFile(cls_argue_path, 'rb') as f:
                graph_def = tf.GraphDef()
                graph_def.ParseFromString(f.read())
                tf.import_graph_def(graph_def, name='')

        x = tf.get_default_graph().get_tensor_by_name("x_batch:0")
        output = tf.get_default_graph().get_tensor_by_name("output/BiasAdd:0")

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = False
        sess = tf.Session(config=config)

        return x, output, sess
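A minimal sketch of running the frozen "argue" graph returned above. Note that load_argue_model expects a MODEL.CLS_ARGUE entry, which does not appear in the config.yaml shown earlier; the 1582-dimensional batch follows ARGUE.DIMENSION, and the zero batch is purely illustrative:

import numpy as np

config = load_config('config.yaml')            # assumes the YAML above sits next to the script
x, output, sess = load_argue_model(config)     # assumes MODEL.CLS_ARGUE points at a frozen graph
batch = np.zeros((1, 1582), dtype=np.float32)  # 1582 follows ARGUE.DIMENSION; values are illustrative
logits = sess.run(output, feed_dict={x: batch})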
media_util.py
0 → 100644
(large diff collapsed; content not shown)
ops/__init__.py
0 → 100755
from ops.basic_ops import *
ops/__pycache__/__init__.cpython-36.pyc 0 → 100644
ops/__pycache__/basic_ops.cpython-36.pyc 0 → 100644
ops/__pycache__/dataset.cpython-36.pyc 0 → 100644
ops/__pycache__/models.cpython-36.pyc 0 → 100644
ops/__pycache__/transforms.cpython-36.pyc 0 → 100644
(compiled .pyc files: binary, no preview)
ops/basic_ops.py
0 → 100755
import torch


class Identity(torch.nn.Module):
    def forward(self, input):
        return input


class SegmentConsensus(torch.nn.Module):

    def __init__(self, consensus_type, dim=1):
        super(SegmentConsensus, self).__init__()
        self.consensus_type = consensus_type
        self.dim = dim
        self.shape = None

    def forward(self, input_tensor):
        self.shape = input_tensor.size()
        if self.consensus_type == 'avg':
            output = input_tensor.mean(dim=self.dim, keepdim=True)
        elif self.consensus_type == 'identity':
            output = input_tensor
        else:
            output = None

        return output


class ConsensusModule(torch.nn.Module):

    def __init__(self, consensus_type, dim=1):
        super(ConsensusModule, self).__init__()
        self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity'
        self.dim = dim

    def forward(self, input):
        return SegmentConsensus(self.consensus_type, self.dim)(input)
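ConsensusModule with 'avg' simply averages per-segment scores along dim=1, which is how the TSN filters fuse segment predictions. A tiny check with illustrative shapes (batch of 2, 8 segments, 3 classes):

import torch

consensus = ConsensusModule('avg')  # averages over the segment dimension
scores = torch.randn(2, 8, 3)       # (batch, segments, classes), illustrative
out = consensus(scores)
print(out.shape)                    # torch.Size([2, 1, 3])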
ops/dataset.py
0 → 100755
# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu

import torch.utils.data as data

from PIL import Image
import os
import numpy as np
from numpy.random import randint


class VideoRecord(object):
    def __init__(self, row):
        self._data = row

    @property
    def path(self):
        return self._data[0]

    @property
    def num_frames(self):
        return int(self._data[1])


class TSNDataSet(data.Dataset):
    def __init__(self, root_path, list_file,
                 num_segments=3, new_length=1, modality='RGB',
                 image_tmpl='img_{:05d}.jpg', transform=None,
                 random_shift=True, test_mode=False,
                 remove_missing=False, dense_sample=False, twice_sample=False):

        self.root_path = root_path
        self.list_file = list_file
        self.num_segments = num_segments
        self.new_length = new_length
        self.modality = modality
        self.image_tmpl = image_tmpl
        self.transform = transform
        self.random_shift = random_shift
        self.test_mode = test_mode
        self.remove_missing = remove_missing
        self.dense_sample = dense_sample  # using dense sample as I3D
        self.twice_sample = twice_sample  # twice sample for more validation
        if self.dense_sample:
            print('=> Using dense sample for the dataset...')
        if self.twice_sample:
            print('=> Using twice sample for the dataset...')

        if self.modality == 'RGBDiff':
            self.new_length += 1  # Diff needs one more image to calculate diff

        self._parse_list()

    def _load_image(self, directory, idx):
        if self.modality == 'RGB' or self.modality == 'RGBDiff':
            try:
                return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')]
            except Exception:
                print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
                return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')]
        elif self.modality == 'Flow':
            if self.image_tmpl == 'flow_{}_{:05d}.jpg':  # ucf
                x_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('x', idx))).convert('L')
                y_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('y', idx))).convert('L')
            elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':  # something v1 flow
                x_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)), self.image_tmpl.
                                                format(int(directory), 'x', idx))).convert('L')
                y_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)), self.image_tmpl.
                                                format(int(directory), 'y', idx))).convert('L')
            else:
                try:
                    # idx_skip = 1 + (idx-1)*5
                    flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')
                except Exception:
                    print('error loading flow file:',
                          os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
                    flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')
                # the input flow file is RGB image with (flow_x, flow_y, blank) for each channel
                flow_x, flow_y, _ = flow.split()
                x_img = flow_x.convert('L')
                y_img = flow_y.convert('L')

            return [x_img, y_img]

    def _parse_list(self):
        # check the frame number is large >3:
        tmp = [x.strip().split(' ') for x in open(self.list_file)]
        if not self.test_mode or self.remove_missing:
            tmp = [item for item in tmp if int(item[1]) >= 3]
        self.video_list = [VideoRecord(item) for item in tmp]

        if self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
            for v in self.video_list:
                v._data[1] = int(v._data[1]) / 2
        print('video number:%d' % (len(self.video_list)))

    def _sample_indices(self, record):
        """

        :param record: VideoRecord
        :return: list
        """
        if self.dense_sample:  # i3d dense sample
            sample_pos = max(1, 1 + record.num_frames - 64)
            t_stride = 64 // self.num_segments
            start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
            offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
            return np.array(offsets) + 1
        else:  # normal sample
            average_duration = (record.num_frames - self.new_length + 1) // self.num_segments
            if average_duration > 0:
                offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration,
                                                                                                   size=self.num_segments)
            elif record.num_frames > self.num_segments:
                offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments))
            else:
                offsets = np.zeros((self.num_segments,))
            return offsets + 1

    def _get_val_indices(self, record):
        if self.dense_sample:  # i3d dense sample
            sample_pos = max(1, 1 + record.num_frames - 64)
            t_stride = 64 // self.num_segments
            start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
            offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
            return np.array(offsets) + 1
        else:
            if record.num_frames > self.num_segments + self.new_length - 1:
                tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
                offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])
            else:
                offsets = np.zeros((self.num_segments,))
            return offsets + 1

    def _get_test_indices(self, record):
        if self.dense_sample:
            sample_pos = max(1, 1 + record.num_frames - 64)
            t_stride = 64 // self.num_segments
            start_list = np.linspace(0, sample_pos - 1, num=10, dtype=int)
            offsets = []
            for start_idx in start_list.tolist():
                offsets += [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)]
            return np.array(offsets) + 1
        elif self.twice_sample:
            tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)

            offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)] +
                               [int(tick * x) for x in range(self.num_segments)])

            return offsets + 1
        else:
            tick = (record.num_frames - self.new_length + 1) / float(self.num_segments)
            offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)])
            return offsets + 1

    def __getitem__(self, index):
        record = self.video_list[index]
        # check this is a legit video folder

        if self.image_tmpl == 'flow_{}_{:05d}.jpg':
            file_name = self.image_tmpl.format('x', 1)
            full_path = os.path.join(self.root_path, record.path, file_name)
        elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
            file_name = self.image_tmpl.format(int(record.path), 'x', 1)
            full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name)
        else:
            file_name = self.image_tmpl.format(1)
            full_path = os.path.join(self.root_path, record.path, file_name)

        while not os.path.exists(full_path):
            print('################## Not Found:', os.path.join(self.root_path, record.path, file_name))
            index = np.random.randint(len(self.video_list))
            record = self.video_list[index]
            if self.image_tmpl == 'flow_{}_{:05d}.jpg':
                file_name = self.image_tmpl.format('x', 1)
                full_path = os.path.join(self.root_path, record.path, file_name)
            elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg':
                file_name = self.image_tmpl.format(int(record.path), 'x', 1)
                full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name)
            else:
                file_name = self.image_tmpl.format(1)
                full_path = os.path.join(self.root_path, record.path, file_name)

        if not self.test_mode:
            segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
        else:
            segment_indices = self._get_test_indices(record)
        return self.get(record, segment_indices)

    def get(self, record, indices):

        images = list()
        for seg_ind in indices:
            p = int(seg_ind)
            for i in range(self.new_length):
                seg_imgs = self._load_image(record.path, p)
                images.extend(seg_imgs)
                if p < record.num_frames:
                    p += 1

        process_data = self.transform(images)
        return record.path, process_data

    def __len__(self):
        return len(self.video_list)
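TSNDataSet reads a plain-text list in which each line is "<frame_folder> <num_frames>" (the format gen_file_list writes in the filters above) and loads frames named by image_tmpl inside root_path/<frame_folder>. A minimal sketch of the expected layout; the folder and list names are illustrative:

# Hedged sketch of the on-disk inputs TSNDataSet expects (names are illustrative):
#   /home/jwq/Desktop/tmp/frame/video_0001/img_00001.jpg
#   /home/jwq/Desktop/tmp/frame/video_0001/img_00002.jpg
#   ...
# plus a list file with one "<folder> <num_frames>" entry per line:
with open('/home/jwq/Desktop/tmp/file_list/val.txt') as f:
    for line in f:
        folder, num_frames = line.strip().split(' ')
        print(folder, int(num_frames))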
ops/dataset_config.py
0 → 100755
# Code for "TSM: Temporal Shift Module for Efficient Video Understanding"
# arXiv:1811.08383
# Ji Lin*, Chuang Gan, Song Han
# {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu

import os

ROOT_DATASET = '/data1/action_1_images/'  # '/data/jilin/'


def return_ucf101(modality):
    filename_categories = 'labels/classInd.txt'
    if modality == 'RGB':
        root_data = ROOT_DATASET + 'images'
        filename_imglist_train = 'file_list/ucf101_rgb_train_split_1.txt'
        filename_imglist_val = 'file_list/ucf101_rgb_val_split_1.txt'
        prefix = 'img_{:05d}.jpg'
    elif modality == 'Flow':
        root_data = ROOT_DATASET + 'UCF101/jpg'
        filename_imglist_train = 'UCF101/file_list/ucf101_flow_train_split_1.txt'
        filename_imglist_val = 'UCF101/file_list/ucf101_flow_val_split_1.txt'
        prefix = 'flow_{}_{:05d}.jpg'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix


def return_hmdb51(modality):
    filename_categories = 51
    if modality == 'RGB':
        root_data = ROOT_DATASET + 'HMDB51/images'
        filename_imglist_train = 'HMDB51/splits/hmdb51_rgb_train_split_1.txt'
        filename_imglist_val = 'HMDB51/splits/hmdb51_rgb_val_split_1.txt'
        prefix = 'img_{:05d}.jpg'
    elif modality == 'Flow':
        root_data = ROOT_DATASET + 'HMDB51/images'
        filename_imglist_train = 'HMDB51/splits/hmdb51_flow_train_split_1.txt'
        filename_imglist_val = 'HMDB51/splits/hmdb51_flow_val_split_1.txt'
        prefix = 'flow_{}_{:05d}.jpg'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix


def return_something(modality):
    filename_categories = 'something/v1/category.txt'
    if modality == 'RGB':
        root_data = ROOT_DATASET + 'something/v1/20bn-something-something-v1'
        filename_imglist_train = 'something/v1/train_videofolder.txt'
        filename_imglist_val = 'something/v1/val_videofolder.txt'
        prefix = '{:05d}.jpg'
    elif modality == 'Flow':
        root_data = ROOT_DATASET + 'something/v1/20bn-something-something-v1-flow'
        filename_imglist_train = 'something/v1/train_videofolder_flow.txt'
        filename_imglist_val = 'something/v1/val_videofolder_flow.txt'
        prefix = '{:06d}-{}_{:05d}.jpg'
    else:
        print('no such modality:' + modality)
        raise NotImplementedError
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix


def return_somethingv2(modality):
    filename_categories = 'something/v2/category.txt'
    if modality == 'RGB':
        root_data = ROOT_DATASET + 'something/v2/20bn-something-something-v2-frames'
        filename_imglist_train = 'something/v2/train_videofolder.txt'
        filename_imglist_val = 'something/v2/val_videofolder.txt'
        prefix = '{:06d}.jpg'
    elif modality == 'Flow':
        root_data = ROOT_DATASET + 'something/v2/20bn-something-something-v2-flow'
        filename_imglist_train = 'something/v2/train_videofolder_flow.txt'
        filename_imglist_val = 'something/v2/val_videofolder_flow.txt'
        prefix = '{:06d}.jpg'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix


def return_jester(modality):
    filename_categories = 'jester/category.txt'
    if modality == 'RGB':
        prefix = '{:05d}.jpg'
        root_data = ROOT_DATASET + 'jester/20bn-jester-v1'
        filename_imglist_train = 'jester/train_videofolder.txt'
        filename_imglist_val = 'jester/val_videofolder.txt'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix


def return_kinetics(modality):
    filename_categories = 400
    if modality == 'RGB':
        root_data = ROOT_DATASET + 'kinetics/images'
        filename_imglist_train = 'kinetics/labels/train_videofolder.txt'
        filename_imglist_val = 'kinetics/labels/val_videofolder.txt'
        prefix = 'img_{:05d}.jpg'
    else:
        raise NotImplementedError('no such modality:' + modality)
    return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix


def return_dataset(dataset, modality):
    dict_single = {'jester': return_jester, 'something': return_something, 'somethingv2': return_somethingv2,
                   'ucf101': return_ucf101, 'hmdb51': return_hmdb51,
                   'kinetics': return_kinetics}
    if dataset in dict_single:
        file_categories, file_imglist_train, file_imglist_val, root_data, prefix = dict_single[dataset](modality)
    else:
        raise ValueError('Unknown dataset ' + dataset)

    file_imglist_train = os.path.join(ROOT_DATASET, file_imglist_train)
    file_imglist_val = os.path.join(ROOT_DATASET, file_imglist_val)
    if isinstance(file_categories, str):
        file_categories = os.path.join(ROOT_DATASET, file_categories)
        with open(file_categories) as f:
            lines = f.readlines()
        categories = [item.rstrip() for item in lines]
    else:  # number of categories
        categories = [None] * file_categories
    n_class = len(categories)
    print('{}: {} classes'.format(dataset, n_class))
    return n_class, file_imglist_train, file_imglist_val, root_data, prefix
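return_dataset resolves a dataset name and modality into the class count, list files, frame root, and filename prefix. A minimal sketch, assuming the UCF101 list files and labels/classInd.txt actually exist under ROOT_DATASET (the 'ucf101'/RGB choice is illustrative):

# Hedged sketch: resolve dataset metadata for the RGB branch of ucf101.
n_class, train_list, val_list, root_data, prefix = return_dataset('ucf101', 'RGB')
print(n_class, root_data, prefix)  # class count, frame root, 'img_{:05d}.jpg'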
ops/models.py
0 → 100755
(large diff collapsed; content not shown)
ops/non_local.py
0 → 100644
# Non-local block using embedded gaussian
# Code from
# https://github.com/AlexHex7/Non-local_pytorch/blob/master/Non-Local_pytorch_0.3.1/lib/non_local_embedded_gaussian.py
import torch
from torch import nn
from torch.nn import functional as F


class _NonLocalBlockND(nn.Module):
    def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
        super(_NonLocalBlockND, self).__init__()

        assert dimension in [1, 2, 3]

        self.dimension = dimension
        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        if dimension == 3:
            conv_nd = nn.Conv3d
            max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
            bn = nn.BatchNorm3d
        elif dimension == 2:
            conv_nd = nn.Conv2d
            max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
            bn = nn.BatchNorm2d
        else:
            conv_nd = nn.Conv1d
            max_pool_layer = nn.MaxPool1d(kernel_size=(2))
            bn = nn.BatchNorm1d

        self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                         kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                bn(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)
        self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        if sub_sample:
            self.g = nn.Sequential(self.g, max_pool_layer)
            self.phi = nn.Sequential(self.phi, max_pool_layer)

    def forward(self, x):
        '''
        :param x: (b, c, t, h, w)
        :return:
        '''

        batch_size = x.size(0)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

        return z


class NONLocalBlock1D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock2D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock2D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=2, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NONLocalBlock3D(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock3D, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=3, sub_sample=sub_sample,
                                              bn_layer=bn_layer)


class NL3DWrapper(nn.Module):
    def __init__(self, block, n_segment):
        super(NL3DWrapper, self).__init__()
        self.block = block
        self.nl = NONLocalBlock3D(block.bn3.num_features)
        self.n_segment = n_segment

    def forward(self, x):
        x = self.block(x)

        nt, c, h, w = x.size()
        x = x.view(nt // self.n_segment, self.n_segment, c, h, w).transpose(1, 2)  # n, c, t, h, w
        x = self.nl(x)
        x = x.transpose(1, 2).contiguous().view(nt, c, h, w)
        return x


def make_non_local(net, n_segment):
    import torchvision
    import archs
    if isinstance(net, torchvision.models.ResNet):
        net.layer2 = nn.Sequential(
            NL3DWrapper(net.layer2[0], n_segment),
            net.layer2[1],
            NL3DWrapper(net.layer2[2], n_segment),
            net.layer2[3],
        )
        net.layer3 = nn.Sequential(
            NL3DWrapper(net.layer3[0], n_segment),
            net.layer3[1],
            NL3DWrapper(net.layer3[2], n_segment),
            net.layer3[3],
            NL3DWrapper(net.layer3[4], n_segment),
            net.layer3[5],
        )
    else:
        raise NotImplementedError


if __name__ == '__main__':
    from torch.autograd import Variable
    import torch

    sub_sample = True
    bn_layer = True

    img = Variable(torch.zeros(2, 3, 20))
    net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer)
    out = net(img)
    print(out.size())

    img = Variable(torch.zeros(2, 3, 20, 20))
    net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer)
    out = net(img)
    print(out.size())

    img = Variable(torch.randn(2, 3, 10, 20, 20))
    net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer)
    out = net(img)
    print(out.size())
ops/temporal_shift.py
0 → 100755
| 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" | ||
| 2 | # arXiv:1811.08383 | ||
| 3 | # Ji Lin*, Chuang Gan, Song Han | ||
| 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu | ||
| 5 | |||
| 6 | import torch | ||
| 7 | import torch.nn as nn | ||
| 8 | import torch.nn.functional as F | ||
| 9 | |||
| 10 | |||
| 11 | class TemporalShift(nn.Module): | ||
| 12 | def __init__(self, net, n_segment=3, n_div=8, inplace=False): | ||
| 13 | super(TemporalShift, self).__init__() | ||
| 14 | self.net = net | ||
| 15 | self.n_segment = n_segment | ||
| 16 | self.fold_div = n_div | ||
| 17 | self.inplace = inplace | ||
| 18 | if inplace: | ||
| 19 | print('=> Using in-place shift...') | ||
| 20 | print('=> Using fold div: {}'.format(self.fold_div)) | ||
| 21 | |||
| 22 | def forward(self, x): | ||
| 23 | x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace) | ||
| 24 | return self.net(x) | ||
| 25 | |||
| 26 | @staticmethod | ||
| 27 | def shift(x, n_segment, fold_div=3, inplace=False): | ||
| 28 | nt, c, h, w = x.size() | ||
| 29 | n_batch = nt // n_segment | ||
| 30 | x = x.view(n_batch, n_segment, c, h, w) | ||
| 31 | |||
| 32 | fold = c // fold_div | ||
| 33 | if inplace: | ||
| 34 | # In-place shift via the InplaceShift autograd Function defined below. | ||
| 35 | # Note: this path was originally disabled because of out-of-order errors | ||
| 36 | # under parallel execution; a dedicated CUDA kernel may be safer there. | ||
| 37 | out = InplaceShift.apply(x, fold) | ||
| 38 | else: | ||
| 39 | out = torch.zeros_like(x) | ||
| 40 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift left | ||
| 41 | out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right | ||
| 42 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift | ||
| 43 | |||
| 44 | return out.view(nt, c, h, w) | ||
| 45 | |||
| 46 | |||
| 47 | class InplaceShift(torch.autograd.Function): | ||
| 48 | # Special thanks to @raoyongming for the help to this function | ||
| 49 | @staticmethod | ||
| 50 | def forward(ctx, input, fold): | ||
| 51 | # not support higher order gradient | ||
| 52 | # input = input.detach_() | ||
| 53 | ctx.fold_ = fold | ||
| 54 | n, t, c, h, w = input.size() | ||
| 55 | buffer = input.data.new(n, t, fold, h, w).zero_() | ||
| 56 | buffer[:, :-1] = input.data[:, 1:, :fold] | ||
| 57 | input.data[:, :, :fold] = buffer | ||
| 58 | buffer.zero_() | ||
| 59 | buffer[:, 1:] = input.data[:, :-1, fold: 2 * fold] | ||
| 60 | input.data[:, :, fold: 2 * fold] = buffer | ||
| 61 | return input | ||
| 62 | |||
| 63 | @staticmethod | ||
| 64 | def backward(ctx, grad_output): | ||
| 65 | # grad_output = grad_output.detach_() | ||
| 66 | fold = ctx.fold_ | ||
| 67 | n, t, c, h, w = grad_output.size() | ||
| 68 | buffer = grad_output.data.new(n, t, fold, h, w).zero_() | ||
| 69 | buffer[:, 1:] = grad_output.data[:, :-1, :fold] | ||
| 70 | grad_output.data[:, :, :fold] = buffer | ||
| 71 | buffer.zero_() | ||
| 72 | buffer[:, :-1] = grad_output.data[:, 1:, fold: 2 * fold] | ||
| 73 | grad_output.data[:, :, fold: 2 * fold] = buffer | ||
| 74 | return grad_output, None | ||
| 75 | |||
| 76 | |||
| 77 | class TemporalPool(nn.Module): | ||
| 78 | def __init__(self, net, n_segment): | ||
| 79 | super(TemporalPool, self).__init__() | ||
| 80 | self.net = net | ||
| 81 | self.n_segment = n_segment | ||
| 82 | |||
| 83 | def forward(self, x): | ||
| 84 | x = self.temporal_pool(x, n_segment=self.n_segment) | ||
| 85 | return self.net(x) | ||
| 86 | |||
| 87 | @staticmethod | ||
| 88 | def temporal_pool(x, n_segment): | ||
| 89 | nt, c, h, w = x.size() | ||
| 90 | n_batch = nt // n_segment | ||
| 91 | x = x.view(n_batch, n_segment, c, h, w).transpose(1, 2) # n, c, t, h, w | ||
| 92 | x = F.max_pool3d(x, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0)) | ||
| 93 | x = x.transpose(1, 2).contiguous().view(nt // 2, c, h, w) | ||
| 94 | return x | ||
| 95 | |||
| 96 | |||
| 97 | def make_temporal_shift(net, n_segment, n_div=8, place='blockres', temporal_pool=False): | ||
| 98 | if temporal_pool: | ||
| 99 | n_segment_list = [n_segment, n_segment // 2, n_segment // 2, n_segment // 2] | ||
| 100 | else: | ||
| 101 | n_segment_list = [n_segment] * 4 | ||
| 102 | assert n_segment_list[-1] > 0 | ||
| 103 | print('=> n_segment per stage: {}'.format(n_segment_list)) | ||
| 104 | |||
| 105 | import torchvision | ||
| 106 | if isinstance(net, torchvision.models.ResNet): | ||
| 107 | if place == 'block': | ||
| 108 | def make_block_temporal(stage, this_segment): | ||
| 109 | blocks = list(stage.children()) | ||
| 110 | print('=> Processing stage with {} blocks'.format(len(blocks))) | ||
| 111 | for i, b in enumerate(blocks): | ||
| 112 | blocks[i] = TemporalShift(b, n_segment=this_segment, n_div=n_div) | ||
| 113 | return nn.Sequential(*(blocks)) | ||
| 114 | |||
| 115 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0]) | ||
| 116 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1]) | ||
| 117 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2]) | ||
| 118 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3]) | ||
| 119 | |||
| 120 | elif 'blockres' in place: | ||
| 121 | n_round = 1 | ||
| 122 | if len(list(net.layer3.children())) >= 23: | ||
| 123 | n_round = 2 | ||
| 124 | print('=> Using n_round {} to insert temporal shift'.format(n_round)) | ||
| 125 | |||
| 126 | def make_block_temporal(stage, this_segment): | ||
| 127 | blocks = list(stage.children()) | ||
| 128 | print('=> Processing stage with {} residual blocks'.format(len(blocks))) | ||
| 129 | for i, b in enumerate(blocks): | ||
| 130 | if i % n_round == 0: | ||
| 131 | blocks[i].conv1 = TemporalShift(b.conv1, n_segment=this_segment, n_div=n_div) | ||
| 132 | return nn.Sequential(*blocks) | ||
| 133 | |||
| 134 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0]) | ||
| 135 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1]) | ||
| 136 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2]) | ||
| 137 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3]) | ||
| 138 | else: | ||
| 139 | raise NotImplementedError(place) | ||
| 140 | |||
| 141 | |||
| 142 | def make_temporal_pool(net, n_segment): | ||
| 143 | import torchvision | ||
| 144 | if isinstance(net, torchvision.models.ResNet): | ||
| 145 | print('=> Injecting temporal pooling') | ||
| 146 | net.layer2 = TemporalPool(net.layer2, n_segment) | ||
| 147 | else: | ||
| 148 | raise NotImplementedError | ||
| 149 | |||
| 150 | |||
| 151 | if __name__ == '__main__': | ||
| 152 | # test inplace shift v.s. vanilla shift | ||
| 153 | tsm1 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=False) | ||
| 154 | tsm2 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=True) | ||
| 155 | |||
| 156 | print('=> Testing CPU...') | ||
| 157 | # test forward | ||
| 158 | with torch.no_grad(): | ||
| 159 | for i in range(10): | ||
| 160 | x = torch.rand(2 * 8, 3, 224, 224) | ||
| 161 | y1 = tsm1(x) | ||
| 162 | y2 = tsm2(x) | ||
| 163 | assert torch.norm(y1 - y2).item() < 1e-5 | ||
| 164 | |||
| 165 | # test backward | ||
| 166 | with torch.enable_grad(): | ||
| 167 | for i in range(10): | ||
| 168 | x1 = torch.rand(2 * 8, 3, 224, 224) | ||
| 169 | x1.requires_grad_() | ||
| 170 | x2 = x1.clone() | ||
| 171 | y1 = tsm1(x1) | ||
| 172 | y2 = tsm2(x2) | ||
| 173 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0] | ||
| 174 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0] | ||
| 175 | assert torch.norm(grad1 - grad2).item() < 1e-5 | ||
| 176 | |||
| 177 | print('=> Testing GPU...') | ||
| 178 | tsm1.cuda() | ||
| 179 | tsm2.cuda() | ||
| 180 | # test forward | ||
| 181 | with torch.no_grad(): | ||
| 182 | for i in range(10): | ||
| 183 | x = torch.rand(2 * 8, 3, 224, 224).cuda() | ||
| 184 | y1 = tsm1(x) | ||
| 185 | y2 = tsm2(x) | ||
| 186 | assert torch.norm(y1 - y2).item() < 1e-5 | ||
| 187 | |||
| 188 | # test backward | ||
| 189 | with torch.enable_grad(): | ||
| 190 | for i in range(10): | ||
| 191 | x1 = torch.rand(2 * 8, 3, 224, 224).cuda() | ||
| 192 | x1.requires_grad_() | ||
| 193 | x2 = x1.clone() | ||
| 194 | y1 = tsm1(x1) | ||
| 195 | y2 = tsm2(x2) | ||
| 196 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0] | ||
| 197 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0] | ||
| 198 | assert torch.norm(grad1 - grad2).item() < 1e-5 | ||
| 199 | print('Test passed.') | ||
| 200 | |||
| 201 | |||
| 202 | |||
| 203 |
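To make the shift pattern concrete, here is a small sketch (illustrative only; assumes this file is importable as ops.temporal_shift) that runs TemporalShift.shift on a tiny tensor whose values encode the frame index, so the printed matrix shows which channel folds look one frame ahead, one frame back, or stay put:

    import torch
    from ops.temporal_shift import TemporalShift   # assumed import path

    n_segment, c = 4, 8
    # one clip of 4 segments, 8 channels, 1x1 spatial; value at [t, ch] is the frame index t
    x = torch.arange(n_segment, dtype=torch.float32).view(1, n_segment, 1, 1, 1)
    x = x.expand(1, n_segment, c, 1, 1).reshape(n_segment, c, 1, 1)

    out = TemporalShift.shift(x, n_segment, fold_div=4)   # fold = c // 4 = 2 channels per direction
    print(out[:, :, 0, 0])
    # columns 0-1: shifted left (value t+1, last frame padded with 0)
    # columns 2-3: shifted right (value t-1, first frame padded with 0)
    # columns 4-7: unchanged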
ops/transforms.py
0 → 100755
(diff collapsed; file contents not shown)
ops/utils.py
0 → 100755
| 1 | import numpy as np | ||
| 2 | |||
| 3 | |||
| 4 | def softmax(scores): | ||
| 5 | es = np.exp(scores - scores.max(axis=-1)[..., None]) | ||
| 6 | return es / es.sum(axis=-1)[..., None] | ||
| 7 | |||
| 8 | |||
| 9 | class AverageMeter(object): | ||
| 10 | """Computes and stores the average and current value""" | ||
| 11 | |||
| 12 | def __init__(self): | ||
| 13 | self.reset() | ||
| 14 | |||
| 15 | def reset(self): | ||
| 16 | self.val = 0 | ||
| 17 | self.avg = 0 | ||
| 18 | self.sum = 0 | ||
| 19 | self.count = 0 | ||
| 20 | |||
| 21 | def update(self, val, n=1): | ||
| 22 | self.val = val | ||
| 23 | self.sum += val * n | ||
| 24 | self.count += n | ||
| 25 | self.avg = self.sum / self.count | ||
| 26 | |||
| 27 | |||
| 28 | def accuracy(output, target, topk=(1,)): | ||
| 29 | """Computes the precision@k for the specified values of k""" | ||
| 30 | maxk = max(topk) | ||
| 31 | batch_size = target.size(0) | ||
| 32 | |||
| 33 | _, pred = output.topk(maxk, 1, True, True) | ||
| 34 | pred = pred.t() | ||
| 35 | correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||
| 36 | |||
| 37 | res = [] | ||
| 38 | for k in topk: | ||
| 39 | correct_k = correct[:k].view(-1).float().sum(0) | ||
| 40 | res.append(correct_k.mul_(100.0 / batch_size)) | ||
| 41 | return res | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
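A short, self-contained sketch of how the two helpers above are typically used in an evaluation loop (the import path and the random tensors are placeholders for real model outputs):

    import torch
    from ops.utils import AverageMeter, accuracy   # assumed import path

    top1 = AverageMeter()
    for _ in range(3):
        logits = torch.randn(16, 3)                # fake batch: 16 samples, 3 classes
        target = torch.randint(0, 3, (16,))
        prec1, = accuracy(logits, target, topk=(1,))
        top1.update(prec1.item(), n=target.size(0))
    print('running top-1: {:.2f}%'.format(top1.avg))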
person_filter.py
0 → 100644
| 1 | import os | ||
| 2 | import cv2 | ||
| 3 | import numpy as np | ||
| 4 | import pickle | ||
| 5 | |||
| 6 | def start_filter(config): | ||
| 7 | cls_class_path = config['MODEL']['CLS_PERSON'] | ||
| 8 | feature_save_dir = config['VIDEO']['FACE_FEATURE_DIR'] | ||
| 9 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
| 10 | result_file_name = config['PERSON']['RESULT_FILE'] | ||
| 11 | feature_name = config['PERSON']['DATA_NAME'] | ||
| 12 | |||
| 13 | xgboost_model = pickle.load(open(cls_class_path, "rb")) | ||
| 14 | |||
| 15 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
| 16 | result_file = open(result_file_path, 'w') | ||
| 17 | |||
| 18 | feature_path = os.path.join(feature_save_dir, feature_name) | ||
| 19 | val_annotation_pairs = np.load(feature_path, allow_pickle=True, encoding='latin1') | ||
| 20 | |||
| 21 | X_val = [] | ||
| 22 | Y_val = [] | ||
| 23 | Y_names = [] | ||
| 24 | for j in range(len(val_annotation_pairs)): | ||
| 25 | pair = val_annotation_pairs[j] | ||
| 26 | X_val.append(np.squeeze(pair[0])) | ||
| 27 | Y_val.append(pair[1]) | ||
| 28 | Y_names.append(pair[2]) | ||
| 29 | |||
| 30 | X_val = np.array(X_val) | ||
| 31 | y_pred = xgboost_model.predict_proba(X_val) | ||
| 32 | |||
| 33 | for i, Y_name in enumerate(Y_names): | ||
| 34 | result_file.write(Y_name + ' ') | ||
| 35 | result_file.write(str(y_pred[i][0]) + ',' + str(y_pred[i][1]) + ',' + str(y_pred[i][2]) + '\n') | ||
| 36 | |||
| 37 | result_file.close() | ||
| 38 | |||
| 39 | |||
| 40 | |||
| 41 | |||
| 42 |
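For reference, a sketch of the feature file layout person_filter expects: an object array of (feature, label, name) triples saved with numpy, where the names become the keys written to the result file. The feature size, labels, and file name below are purely illustrative:

    import numpy as np

    pairs = []
    for i in range(5):
        feature = np.random.rand(1, 512).astype(np.float32)   # assumed feature dimension
        label = np.random.randint(1, 4)                        # three classes, 1..3
        name = 'video_{:04d}'.format(i)
        pairs.append((feature, label, name))

    # start_filter() loads this with np.load(..., allow_pickle=True), squeezes pair[0],
    # runs predict_proba, and writes "<name> p1,p2,p3" per line to the result file.
    np.save('person_val_pairs.npy', np.array(pairs, dtype=object))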
pose_filter.py
0 → 100644
| 1 | import os | ||
| 2 | import torch.optim | ||
| 3 | import numpy as np | ||
| 4 | import torchvision  # used explicitly for torchvision.transforms below | ||
| 5 | import torch.nn.parallel | ||
| 6 | from ops.models import TSN | ||
| 7 | from ops.transforms import * | ||
| 8 | from ops.dataset import TSNDataSet | ||
| 9 | from torch.nn import functional as F | ||
| 10 | |||
| 11 | |||
| 12 | def gen_file_list(frame_save_dir, frame_list_dir): | ||
| 13 | |||
| 14 | val_path = os.path.join(frame_list_dir, 'val.txt') | ||
| 15 | video_names = os.listdir(frame_save_dir) | ||
| 16 | ucf101_rgb_val_file = open(val_path, 'w') | ||
| 17 | |||
| 18 | for video_name in video_names: | ||
| 19 | images_dir = os.path.join(frame_save_dir, video_name) | ||
| 20 | ucf101_rgb_val_file.write(video_name) | ||
| 21 | ucf101_rgb_val_file.write(' ') | ||
| 22 | ucf101_rgb_val_file.write(str(len(os.listdir(images_dir)))) | ||
| 23 | ucf101_rgb_val_file.write('\n') | ||
| 24 | |||
| 25 | ucf101_rgb_val_file.close() | ||
| 26 | |||
| 27 | return val_path | ||
| 28 | |||
| 29 | |||
| 30 | def start_filter(config): | ||
| 31 | arch = config['FIGHTING']['ARCH'] | ||
| 32 | prefix = config['VIDEO']['PREFIX'] | ||
| 33 | modality = config['POSE']['MODALITY'] | ||
| 34 | test_crop = config['POSE']['TEST_CROP'] | ||
| 35 | batch_size = config['POSE']['BATCH_SIZE'] | ||
| 36 | weights_path = config['MODEL']['CLS_POSE'] | ||
| 37 | test_segment = config['POSE']['TEST_SEGMENT'] | ||
| 38 | frame_save_dir = config['VIDEO']['POSE_FRAME_SAVE_DIR'] | ||
| 39 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
| 40 | result_file_name = config['POSE']['RESULT_FILE'] | ||
| 41 | |||
| 42 | workers = 8 | ||
| 43 | num_class = 3 | ||
| 44 | shift_div = 8 | ||
| 45 | img_feature_dim = 256 | ||
| 46 | |||
| 47 | softmax = False | ||
| 48 | is_shift = True | ||
| 49 | full_res = False | ||
| 50 | non_local = False | ||
| 51 | dense_sample = False | ||
| 52 | twice_sample = False | ||
| 53 | |||
| 54 | val_list = gen_file_list(frame_save_dir, frame_list_dir) | ||
| 55 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
| 56 | |||
| 57 | pretrain = 'imagenet' | ||
| 58 | shift_place = 'blockres' | ||
| 59 | crop_fusion_type = 'avg' | ||
| 60 | |||
| 61 | net = TSN(num_class, test_segment if is_shift else 1, modality, | ||
| 62 | base_model=arch, | ||
| 63 | consensus_type=crop_fusion_type, | ||
| 64 | img_feature_dim=img_feature_dim, | ||
| 65 | pretrain=pretrain, | ||
| 66 | is_shift=is_shift, shift_div=shift_div, shift_place=shift_place, | ||
| 67 | non_local=non_local, | ||
| 68 | ) | ||
| 69 | |||
| 70 | checkpoint = torch.load(weights_path) | ||
| 71 | checkpoint = checkpoint['state_dict'] | ||
| 72 | |||
| 73 | base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())} | ||
| 74 | replace_dict = {'base_model.classifier.weight': 'new_fc.weight', | ||
| 75 | 'base_model.classifier.bias': 'new_fc.bias', | ||
| 76 | } | ||
| 77 | for k, v in replace_dict.items(): | ||
| 78 | if k in base_dict: | ||
| 79 | base_dict[v] = base_dict.pop(k) | ||
| 80 | |||
| 81 | net.load_state_dict(base_dict) | ||
| 82 | |||
| 83 | input_size = net.scale_size if full_res else net.input_size | ||
| 84 | |||
| 85 | if test_crop == 1: | ||
| 86 | cropping = torchvision.transforms.Compose([ | ||
| 87 | GroupScale(net.scale_size), | ||
| 88 | GroupCenterCrop(input_size), | ||
| 89 | ]) | ||
| 90 | elif test_crop == 3: # do not flip, so only 3 crops | ||
| 91 | cropping = torchvision.transforms.Compose([ | ||
| 92 | GroupFullResSample(input_size, net.scale_size, flip=False) | ||
| 93 | ]) | ||
| 94 | elif test_crop == 5: # do not flip, so only 5 crops | ||
| 95 | cropping = torchvision.transforms.Compose([ | ||
| 96 | GroupOverSample(input_size, net.scale_size, flip=False) | ||
| 97 | ]) | ||
| 98 | elif test_crop == 10: | ||
| 99 | cropping = torchvision.transforms.Compose([ | ||
| 100 | GroupOverSample(input_size, net.scale_size) | ||
| 101 | ]) | ||
| 102 | else: | ||
| 103 | raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(test_crop)) | ||
| 104 | |||
| 105 | data_loader = torch.utils.data.DataLoader( | ||
| 106 | TSNDataSet(frame_save_dir, val_list, num_segments=test_segment, | ||
| 107 | new_length=1 if modality == "RGB" else 5, | ||
| 108 | modality=modality, | ||
| 109 | image_tmpl=prefix, | ||
| 110 | test_mode=True, | ||
| 111 | remove_missing=False, | ||
| 112 | transform=torchvision.transforms.Compose([ | ||
| 113 | cropping, | ||
| 114 | Stack(roll=(arch in ['BNInception', 'InceptionV3'])), | ||
| 115 | ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])), | ||
| 116 | GroupNormalize(net.input_mean, net.input_std), | ||
| 117 | ]), dense_sample=dense_sample, twice_sample=twice_sample), | ||
| 118 | batch_size=batch_size, shuffle=False, | ||
| 119 | num_workers=workers, pin_memory=True, | ||
| 120 | ) | ||
| 121 | |||
| 122 | net = torch.nn.DataParallel(net.cuda()) | ||
| 123 | net.eval() | ||
| 124 | data_gen = enumerate(data_loader) | ||
| 125 | max_num = len(data_loader.dataset) | ||
| 126 | |||
| 127 | result_file = open(result_file_path, 'w') | ||
| 128 | |||
| 129 | for i, data_pair in data_gen: | ||
| 130 | directory, data = data_pair | ||
| 131 | with torch.no_grad(): | ||
| 132 | if i >= max_num: | ||
| 133 | break | ||
| 134 | num_crop = test_crop | ||
| 135 | if dense_sample: | ||
| 136 | num_crop *= 10 # 10 clips for testing when using dense sample | ||
| 137 | |||
| 138 | if twice_sample: | ||
| 139 | num_crop *= 2 | ||
| 140 | |||
| 141 | if modality == 'RGB': | ||
| 142 | length = 3 | ||
| 143 | elif modality == 'Flow': | ||
| 144 | length = 10 | ||
| 145 | elif modality == 'RGBDiff': | ||
| 146 | length = 18 | ||
| 147 | else: | ||
| 148 | raise ValueError("Unknown modality " + modality) | ||
| 149 | |||
| 150 | data_in = data.view(-1, length, data.size(2), data.size(3)) | ||
| 151 | if is_shift: | ||
| 152 | data_in = data_in.view(batch_size * num_crop, test_segment, length, data_in.size(2), data_in.size(3)) | ||
| 153 | rst, feature = net(data_in) | ||
| 154 | rst = rst.reshape(batch_size, num_crop, -1).mean(1) | ||
| 155 | |||
| 156 | if softmax: | ||
| 157 | # take the softmax to normalize the output to probability | ||
| 158 | rst = F.softmax(rst, dim=1) | ||
| 159 | |||
| 160 | rst = rst.data.cpu().numpy().copy() | ||
| 161 | |||
| 162 | if net.module.is_shift: | ||
| 163 | rst = rst.reshape(batch_size, num_class) | ||
| 164 | else: | ||
| 165 | rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class)) | ||
| 166 | |||
| 167 | proba = np.squeeze(rst) | ||
| 168 | proba = np.exp(proba)/sum(np.exp(proba)) | ||
| 169 | result_file.write(str(directory[0]) + ' ') | ||
| 170 | result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n') | ||
| 171 | |||
| 172 | result_file.close() | ||
| 173 | print('pose filter end') | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
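gen_file_list() above writes one "&lt;video_name&gt; &lt;frame_count&gt;" line per clip from the extracted frame directories. A small sketch of the expected layout (directory and file names are illustrative; importing pose_filter assumes the repo's torch/ops dependencies are installed):

    import os
    from pose_filter import gen_file_list   # assumed import

    frame_save_dir = 'data/pose_frames'      # illustrative paths
    frame_list_dir = 'data/frame_list'
    os.makedirs(frame_list_dir, exist_ok=True)
    for name, n_frames in [('video_0001', 48), ('video_0002', 36)]:
        d = os.path.join(frame_save_dir, name)
        os.makedirs(d, exist_ok=True)
        for k in range(1, n_frames + 1):
            # empty placeholder frames, just to show the naming scheme
            open(os.path.join(d, 'img_{:05d}.jpg'.format(k)), 'a').close()

    val_path = gen_file_list(frame_save_dir, frame_list_dir)
    print(open(val_path).read())   # one "<video_name> <frame_count>" line per clip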
test.py
0 → 100644
| 1 | import os | ||
| 2 | import cv2 | ||
| 3 | import load_util | ||
| 4 | import media_util | ||
| 5 | import numpy as np | ||
| 6 | from sklearn.metrics import confusion_matrix | ||
| 7 | import fighting_filter, fighting_2_filter, emotion_filter, argue_filter, audio_filter, class_filter | ||
| 8 | import video_filter, pose_filter, flow_filter | ||
| 9 | |||
| 10 | |||
| 11 | |||
| 12 | def accuracy_cal(config): | ||
| 13 | |||
| 14 | label_file_path = config['VIDEO']['LABEL_PATH'] | ||
| 15 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
| 16 | final_file_name = config['AUDIO']['RESULT_FILE'] | ||
| 17 | |||
| 18 | final_file_path = os.path.join(frame_list_dir, final_file_name) | ||
| 19 | final_file_lines = open(final_file_path).readlines() | ||
| 20 | label_file_lines = open(label_file_path).readlines() | ||
| 21 | |||
| 22 | |||
| 23 | final_pairs = {line.strip().split(' ')[0]: line.strip().split(' ')[1] for line in final_file_lines} | ||
| 24 | |||
| 25 | lines_num = len(label_file_lines) - 1 | ||
| 26 | hit = 0 | ||
| 27 | for i, label_line in enumerate(label_file_lines): | ||
| 28 | if i == 0: | ||
| 29 | continue | ||
| 30 | file, label = label_line.strip().split(' ') | ||
| 31 | final_pre = final_pairs[file] | ||
| 32 | final_pre_class = np.argmax(np.array(final_pre.split(','), dtype=float)) + 1 | ||
| 33 | print(final_pre_class, label) | ||
| 34 | if final_pre_class == int(label): | ||
| 35 | hit += 1 | ||
| 36 | |||
| 37 | return hit/lines_num | ||
| 38 | |||
| 39 | |||
| 40 | def main(): | ||
| 41 | config_path = r'config.yaml' | ||
| 42 | config = load_util.load_config(config_path) | ||
| 43 | |||
| 44 | media_util.extract_wav(config) | ||
| 45 | media_util.extract_frame(config) | ||
| 46 | media_util.extract_frame_pose(config) | ||
| 47 | media_util.extract_is10(config) | ||
| 48 | media_util.extract_random_face_feature(config) | ||
| 49 | media_util.extract_mirror(config) | ||
| 50 | |||
| 51 | fighting_2_filter.start_filter(config) | ||
| 52 | emotion_filter.start_filter(config) | ||
| 53 | |||
| 54 | audio_filter.start_filter(config) | ||
| 55 | |||
| 56 | class_filter.start_filter(config) | ||
| 57 | video_filter.start_filter(config) | ||
| 58 | pose_filter.start_filter(config) | ||
| 59 | |||
| 60 | flow_filter.start_filter(config) | ||
| 61 | |||
| 62 | acc = accuracy_cal(config) | ||
| 63 | print(acc) | ||
| 64 | |||
| 65 | |||
| 66 | if __name__ == '__main__': | ||
| 67 | main() | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
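The scripts above read many nested keys from config.yaml via load_util.load_config(). The fragment below gathers the keys used in this commit into one illustrative Python dict, shaped like the object the filters index; all paths and values are placeholders, not taken from the repo:

    # Illustrative only: the nested layout these filters expect from load_util.load_config().
    config = {
        'MODEL': {'CLS_AUDIO': 'models/audio_svm.m', 'CLS_PERSON': 'models/person_xgb.pkl',
                  'CLS_POSE': 'models/pose_tsm.pth', 'CLS_VIDEO': 'models/video_tsm.pth'},
        'VIDEO': {'PREFIX': 'img_{:05d}.jpg', 'LABEL_PATH': 'data/val_label.txt',
                  'FRAME_LIST_DIR': 'data/frame_list', 'FRAME_SAVE_DIR': 'data/frames',
                  'POSE_FRAME_SAVE_DIR': 'data/pose_frames',
                  'FACE_FEATURE_DIR': 'data/face_features',
                  'IS10_FEATURE_NP_DIR': 'data/is10_features'},
        'AUDIO': {'RESULT_FILE': 'audio_result.txt', 'DATA_NAME': 'audio_is10.npy'},
        'PERSON': {'RESULT_FILE': 'person_result.txt', 'DATA_NAME': 'person_face.npy'},
        'POSE': {'MODALITY': 'RGB', 'TEST_CROP': 1, 'BATCH_SIZE': 1,
                 'TEST_SEGMENT': 8, 'RESULT_FILE': 'pose_result.txt'},
        'VIDEO_FILTER': {'MODALITY': 'RGB', 'TEST_CROP': 1, 'BATCH_SIZE': 1,
                         'TEST_SEGMENT': 8, 'RESULT_FILE': 'video_result.txt'},
        'FIGHTING': {'ARCH': 'resnet50'},
    }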
video_filter.py
0 → 100644
| 1 | import os | ||
| 2 | import torch.optim | ||
| 3 | import numpy as np | ||
| 4 | import torch.nn.parallel | ||
| 5 | from ops.models import TSN | ||
| 6 | from ops.transforms import * | ||
| 7 | from ops.dataset import TSNDataSet | ||
| 8 | from torch.nn import functional as F | ||
| 9 | |||
| 10 | |||
| 11 | def gen_file_list(frame_save_dir, frame_list_dir): | ||
| 12 | |||
| 13 | val_path = os.path.join(frame_list_dir, 'val.txt') | ||
| 14 | video_names = os.listdir(frame_save_dir) | ||
| 15 | ucf101_rgb_val_file = open(val_path, 'w') | ||
| 16 | |||
| 17 | for video_name in video_names: | ||
| 18 | images_dir = os.path.join(frame_save_dir, video_name) | ||
| 19 | ucf101_rgb_val_file.write(video_name) | ||
| 20 | ucf101_rgb_val_file.write(' ') | ||
| 21 | ucf101_rgb_val_file.write(str(len(os.listdir(images_dir)))) | ||
| 22 | ucf101_rgb_val_file.write('\n') | ||
| 23 | |||
| 24 | ucf101_rgb_val_file.close() | ||
| 25 | |||
| 26 | return val_path | ||
| 27 | |||
| 28 | |||
| 29 | def start_filter(config): | ||
| 30 | arch = config['FIGHTING']['ARCH'] | ||
| 31 | prefix = config['VIDEO']['PREFIX'] | ||
| 32 | modality = config['VIDEO_FILTER']['MODALITY'] | ||
| 33 | test_crop = config['VIDEO_FILTER']['TEST_CROP'] | ||
| 34 | batch_size = config['VIDEO_FILTER']['BATCH_SIZE'] | ||
| 35 | weights_path = config['MODEL']['CLS_VIDEO'] | ||
| 36 | test_segment = config['VIDEO_FILTER']['TEST_SEGMENT'] | ||
| 37 | frame_save_dir = config['VIDEO']['FRAME_SAVE_DIR'] | ||
| 38 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
| 39 | result_file_name = config['VIDEO_FILTER']['RESULT_FILE'] | ||
| 40 | |||
| 41 | workers = 8 | ||
| 42 | num_class = 3 | ||
| 43 | shift_div = 8 | ||
| 44 | img_feature_dim = 256 | ||
| 45 | |||
| 46 | softmax = False | ||
| 47 | is_shift = True | ||
| 48 | full_res = False | ||
| 49 | non_local = False | ||
| 50 | dense_sample = False | ||
| 51 | twice_sample = False | ||
| 52 | |||
| 53 | val_list = gen_file_list(frame_save_dir, frame_list_dir) | ||
| 54 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
| 55 | |||
| 56 | pretrain = 'imagenet' | ||
| 57 | shift_place = 'blockres' | ||
| 58 | crop_fusion_type = 'avg' | ||
| 59 | |||
| 60 | net = TSN(num_class, test_segment if is_shift else 1, modality, | ||
| 61 | base_model=arch, | ||
| 62 | consensus_type=crop_fusion_type, | ||
| 63 | img_feature_dim=img_feature_dim, | ||
| 64 | pretrain=pretrain, | ||
| 65 | is_shift=is_shift, shift_div=shift_div, shift_place=shift_place, | ||
| 66 | non_local=non_local, | ||
| 67 | ) | ||
| 68 | |||
| 69 | checkpoint = torch.load(weights_path) | ||
| 70 | checkpoint = checkpoint['state_dict'] | ||
| 71 | |||
| 72 | base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())} | ||
| 73 | replace_dict = {'base_model.classifier.weight': 'new_fc.weight', | ||
| 74 | 'base_model.classifier.bias': 'new_fc.bias', | ||
| 75 | } | ||
| 76 | for k, v in replace_dict.items(): | ||
| 77 | if k in base_dict: | ||
| 78 | base_dict[v] = base_dict.pop(k) | ||
| 79 | |||
| 80 | net.load_state_dict(base_dict) | ||
| 81 | |||
| 82 | input_size = net.scale_size if full_res else net.input_size | ||
| 83 | |||
| 84 | if test_crop == 1: | ||
| 85 | cropping = torchvision.transforms.Compose([ | ||
| 86 | GroupScale(net.scale_size), | ||
| 87 | GroupCenterCrop(input_size), | ||
| 88 | ]) | ||
| 89 | elif test_crop == 3: # do not flip, so only 3 crops | ||
| 90 | cropping = torchvision.transforms.Compose([ | ||
| 91 | GroupFullResSample(input_size, net.scale_size, flip=False) | ||
| 92 | ]) | ||
| 93 | elif test_crop == 5: # do not flip, so only 5 crops | ||
| 94 | cropping = torchvision.transforms.Compose([ | ||
| 95 | GroupOverSample(input_size, net.scale_size, flip=False) | ||
| 96 | ]) | ||
| 97 | elif test_crop == 10: | ||
| 98 | cropping = torchvision.transforms.Compose([ | ||
| 99 | GroupOverSample(input_size, net.scale_size) | ||
| 100 | ]) | ||
| 101 | else: | ||
| 102 | raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(test_crop)) | ||
| 103 | |||
| 104 | data_loader = torch.utils.data.DataLoader( | ||
| 105 | TSNDataSet(frame_save_dir, val_list, num_segments=test_segment, | ||
| 106 | new_length=1 if modality == "RGB" else 5, | ||
| 107 | modality=modality, | ||
| 108 | image_tmpl=prefix, | ||
| 109 | test_mode=True, | ||
| 110 | remove_missing=False, | ||
| 111 | transform=torchvision.transforms.Compose([ | ||
| 112 | cropping, | ||
| 113 | Stack(roll=(arch in ['BNInception', 'InceptionV3'])), | ||
| 114 | ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])), | ||
| 115 | GroupNormalize(net.input_mean, net.input_std), | ||
| 116 | ]), dense_sample=dense_sample, twice_sample=twice_sample), | ||
| 117 | batch_size=batch_size, shuffle=False, | ||
| 118 | num_workers=workers, pin_memory=True, | ||
| 119 | ) | ||
| 120 | |||
| 121 | net = torch.nn.DataParallel(net.cuda()) | ||
| 122 | net.eval() | ||
| 123 | data_gen = enumerate(data_loader) | ||
| 124 | max_num = len(data_loader.dataset) | ||
| 125 | |||
| 126 | result_file = open(result_file_path, 'w') | ||
| 127 | |||
| 128 | for i, data_pair in data_gen: | ||
| 129 | directory, data = data_pair | ||
| 130 | with torch.no_grad(): | ||
| 131 | if i >= max_num: | ||
| 132 | break | ||
| 133 | num_crop = test_crop | ||
| 134 | if dense_sample: | ||
| 135 | num_crop *= 10 # 10 clips for testing when using dense sample | ||
| 136 | |||
| 137 | if twice_sample: | ||
| 138 | num_crop *= 2 | ||
| 139 | |||
| 140 | if modality == 'RGB': | ||
| 141 | length = 3 | ||
| 142 | elif modality == 'Flow': | ||
| 143 | length = 10 | ||
| 144 | elif modality == 'RGBDiff': | ||
| 145 | length = 18 | ||
| 146 | else: | ||
| 147 | raise ValueError("Unknown modality " + modality) | ||
| 148 | |||
| 149 | data_in = data.view(-1, length, data.size(2), data.size(3)) | ||
| 150 | if is_shift: | ||
| 151 | data_in = data_in.view(batch_size * num_crop, test_segment, length, data_in.size(2), data_in.size(3)) | ||
| 152 | |||
| 153 | rst, feature = net(data_in) | ||
| 154 | rst = rst.reshape(batch_size, num_crop, -1).mean(1) | ||
| 155 | |||
| 156 | if softmax: | ||
| 157 | # take the softmax to normalize the output to probability | ||
| 158 | rst = F.softmax(rst, dim=1) | ||
| 159 | |||
| 160 | rst = rst.data.cpu().numpy().copy() | ||
| 161 | |||
| 162 | if net.module.is_shift: | ||
| 163 | rst = rst.reshape(batch_size, num_class) | ||
| 164 | else: | ||
| 165 | rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class)) | ||
| 166 | |||
| 167 | proba = np.squeeze(rst) | ||
| 168 | proba = np.exp(proba)/sum(np.exp(proba)) | ||
| 169 | result_file.write(str(directory[0]) + ' ') | ||
| 170 | result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n') | ||
| 171 | |||
| 172 | result_file.close() | ||
| 173 | print('video filter end') | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
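Each filter writes one "&lt;video_name&gt; p1,p2,p3" line per clip. A small sketch (file name and path assumed) of reading such a result file back into per-class probabilities, mirroring what accuracy_cal does in test.py:

    import numpy as np

    def load_result_file(path):
        """Parse '<name> p1,p2,p3' lines into a dict of probability vectors."""
        probs = {}
        with open(path) as f:
            for line in f:
                name, scores = line.strip().split(' ')
                probs[name] = np.array(scores.split(','), dtype=float)
        return probs

    # results = load_result_file('data/frame_list/video_result.txt')   # assumed path
    # pred_class = np.argmax(results['some_video']) + 1                # 1-based, as in accuracy_cal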