first commit
Showing 46 changed files with 3340 additions and 0 deletions
__init__.py
0 → 100644
File mode changed
__pycache__/argue_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/audio_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/bg_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/class_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/emotion_1_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/emotion_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/fighting_2_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/fighting_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/flow_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/load_util.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/media_util.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/meeting_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/person_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/pose_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/troops_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/video_1_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/video_filter.cpython-36.pyc
0 → 100644
No preview for this file type
audio_filter.py
0 → 100644
| 1 | import os | ||
| 2 | import csv | ||
| 3 | import pickle | ||
| 4 | import numpy as np | ||
| 5 | from sklearn.externals import joblib | ||
| 6 | |||
| 7 | |||
| 8 | def start_filter(config): | ||
| 9 | cls_audio_path = config['MODEL']['CLS_AUDIO'] | ||
| 10 | feature_save_dir = config['VIDEO']['IS10_FEATURE_NP_DIR'] | ||
| 11 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
| 12 | result_file_name = config['AUDIO']['RESULT_FILE'] | ||
| 13 | feature_name = config['AUDIO']['DATA_NAME'] | ||
| 14 | |||
| 15 | svm_clf = joblib.load(cls_audio_path) | ||
| 16 | |||
| 17 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
| 18 | result_file = open(result_file_path, 'w') | ||
| 19 | |||
| 20 | feature_path = os.path.join(feature_save_dir, feature_name) | ||
| 21 | val_annotation_pairs = np.load(feature_path, allow_pickle=True, encoding='latin1') | ||
| 22 | |||
| 23 | for pair in val_annotation_pairs: | ||
| 24 | |||
| 25 | v = pair[0] | ||
| 26 | n = pair[2] | ||
| 27 | |||
| 28 | feature_np = np.reshape(v, (1, -1)) | ||
| 29 | res = svm_clf.predict_proba(feature_np) | ||
| 30 | proba = np.squeeze(res) | ||
| 31 | |||
| 32 | # class_pre = svm_clf.predict(feature_np) | ||
| 33 | |||
| 34 | result_file.write(str(n)[:-4] + ' ')  # clip name with the 4-character extension stripped | ||
| 35 | result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n') | ||
| 36 | |||
| 37 | result_file.close() | ||
| 38 | |||
| 39 | |||
| 40 | |||
| 41 | |||
| 42 | |||
| 43 | def start_filter_xgboost(config): | ||
| 44 | cls_class_path = config['MODEL']['CLS_AUDIO'] | ||
| 45 | feature_save_dir = config['VIDEO']['IS10_FEATURE_NP_DIR'] | ||
| 46 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
| 47 | result_file_name = config['AUDIO']['RESULT_FILE'] | ||
| 48 | feature_name = config['AUDIO']['DATA_NAME'] | ||
| 49 | |||
| 50 | xgboost_model = pickle.load(open(cls_class_path, "rb")) | ||
| 51 | |||
| 52 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
| 53 | result_file = open(result_file_path, 'w') | ||
| 54 | |||
| 55 | feature_path = os.path.join(feature_save_dir, feature_name) | ||
| 56 | val_annotation_pairs = np.load(feature_path, allow_pickle=True, encoding='latin1') | ||
| 57 | |||
| 58 | X_val = [] | ||
| 59 | Y_names = [] | ||
| 60 | for pair in val_annotation_pairs: | ||
| 61 | v, n = pair[0], pair[2]  # same pair layout as start_filter above: feature vector and clip name | ||
| 62 | X_val.append(v) | ||
| 63 | Y_names.append(n) | ||
| 64 | |||
| 65 | X_val = np.array(X_val) | ||
| 66 | y_pred = xgboost_model.predict_proba(X_val) | ||
| 67 | |||
| 68 | for i, Y_name in enumerate(Y_names): | ||
| 69 | result_file.write(Y_name + ' ') | ||
| 70 | result_file.write(str(y_pred[i][0]) + ',' + str(y_pred[i][1]) + ',' + str(y_pred[i][2]) + '\n') | ||
| 71 | |||
| 72 | result_file.close() | ||
| 73 |
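Each filter above writes one line per clip in the form "<clip name> <p0>,<p1>,<p2>". For reference, a minimal standalone sketch (not part of this commit) that parses such a result file back into a dict, assuming exactly that format:

def load_filter_results(result_file_path):
    # parse lines of the form "<clip_name> <p0>,<p1>,<p2>" into {clip_name: [p0, p1, p2]}
    results = {}
    with open(result_file_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            name, probs = line.split(' ', 1)
            results[name] = [float(p) for p in probs.split(',')]
    return results

# example: probabilities written by start_filter() in audio_filter.py
# scores = load_filter_results('/home/jwq/Desktop/tmp/file_list/audio.txt')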
bg_filter.py
0 → 100644
| 1 | import os | ||
| 2 | import cv2 | ||
| 3 | import numpy as np | ||
| 4 | import pickle | ||
| 5 | |||
| 6 | def start_filter(config): | ||
| 7 | cls_class_path = config['MODEL']['CLS_BG'] | ||
| 8 | feature_save_dir = config['VIDEO']['FACE_FEATURE_DIR'] | ||
| 9 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
| 10 | result_file_name = config['BG']['RESULT_FILE'] | ||
| 11 | feature_name = config['BG']['DATA_NAME'] | ||
| 12 | |||
| 13 | xgboost_model = pickle.load(open(cls_class_path, "rb")) | ||
| 14 | |||
| 15 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
| 16 | result_file = open(result_file_path, 'w') | ||
| 17 | |||
| 18 | feature_path = os.path.join(feature_save_dir, feature_name) | ||
| 19 | val_annotation_pairs = np.load(feature_path, allow_pickle=True) | ||
| 20 | |||
| 21 | X_val = [] | ||
| 22 | Y_val = [] | ||
| 23 | Y_names = [] | ||
| 24 | for j in range(len(val_annotation_pairs)): | ||
| 25 | pair = val_annotation_pairs[j] | ||
| 26 | X_val.append(np.squeeze(pair[0])) | ||
| 27 | Y_val.append(pair[1]) | ||
| 28 | Y_names.append(pair[2]) | ||
| 29 | |||
| 30 | X_val = np.array(X_val) | ||
| 31 | y_pred = xgboost_model.predict_proba(X_val) | ||
| 32 | |||
| 33 | for i, Y_name in enumerate(Y_names): | ||
| 34 | result_file.write(Y_name + ' ') | ||
| 35 | result_file.write(str(y_pred[i][0]) + ',' + str(y_pred[i][1]) + ',' + str(y_pred[i][2]) + '\n') | ||
| 36 | |||
| 37 | result_file.close() | ||
| 38 | |||
| 39 | |||
| 40 | |||
| 41 | |||
| 42 |
class_filter.py
0 → 100644
| 1 | import os | ||
| 2 | import pickle | ||
| 3 | import numpy as np | ||
| 4 | |||
| 5 | |||
| 6 | def start_filter(config): | ||
| 7 | |||
| 8 | cls_class_path = config['MODEL']['CLS_CLASS'] | ||
| 9 | feature_save_dir = config['VIDEO']['CLASS_FEATURE_DIR'] | ||
| 10 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
| 11 | result_file_name = config['CLASS']['RESULT_FILE'] | ||
| 12 | feature_name = config['CLASS']['DATA_NAME'] | ||
| 13 | |||
| 14 | xgboost_model = pickle.load(open(cls_class_path, "rb")) | ||
| 15 | |||
| 16 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
| 17 | result_file = open(result_file_path, 'w') | ||
| 18 | |||
| 19 | feature_path = os.path.join(feature_save_dir, feature_name) | ||
| 20 | val_annotation_pairs = np.load(feature_path, allow_pickle=True) | ||
| 21 | |||
| 22 | X_val = [] | ||
| 23 | Y_val = [] | ||
| 24 | Y_names = [] | ||
| 25 | for j in range(len(val_annotation_pairs)): | ||
| 26 | pair = val_annotation_pairs[j] | ||
| 27 | X_val.append(pair[0]) | ||
| 28 | Y_val.append(pair[1]) | ||
| 29 | Y_names.append(pair[2]) | ||
| 30 | |||
| 31 | X_val = np.array(X_val) | ||
| 32 | y_pred = xgboost_model.predict(X_val) | ||
| 33 | |||
| 34 | for i, Y_name in enumerate(Y_names): | ||
| 35 | result_file.write(Y_name + ' ') | ||
| 36 | result_file.write(str(y_pred[i]) + '\n') | ||
| 37 | |||
| 38 | result_file.close() |
config.yaml
0 → 100644
| 1 | MODEL: | ||
| 2 | CLS_FIGHTING_2: '/home/jwq/models/cls_fighting_2/cls_fighting_2_v0.0.1.pth' | ||
| 3 | CLS_EMOTION: '/home/jwq/models/cls_emotion/v0.1.0.m' | ||
| 4 | FEATURE_EMOTION: '/home/jwq/models/feature_emotion/FerPlus3.h5' | ||
| 5 | CLS_AUDIO: '/home/jwq/models/cls_audio/v0.0.1.m' | ||
| 6 | CLS_CLASS: '/home/jwq/models/cls_class/v_0.0.1_xgb.pkl' | ||
| 7 | CLS_VIDEO: '/home/jwq/models/cls_video/v0.4.1.pth' | ||
| 8 | CLS_POSE: '/home/jwq/models/cls_pose/v0.0.1.pth' | ||
| 9 | CLS_FLOW: '/home/jwq/models/cls_flow/v0.1.1.pth' | ||
| 10 | CLS_BG: '/home/jwq/models/cls_bg/v0.1.1.pkl' | ||
| 11 | CLS_PERSON: '/home/jwq/models/cls_person/v0.1.1.pkl' | ||
| 12 | |||
| 13 | THRESHOLD: | ||
| 14 | FACES_THRESHOLD: 0.6 | ||
| 15 | |||
| 16 | FILTER: | ||
| 17 | |||
| 18 | |||
| 19 | VIDEO: | ||
| 20 | VIDEO_DIR: '/home/jwq/Desktop/VGAF_EmotiW/Val' | ||
| 21 | LABEL_PATH: '/home/jwq/Desktop/VGAF_EmotiW/Val_labels.txt' | ||
| 22 | VIDEO_SAVE_DIR: '/home/jwq/Desktop/tmp/video' | ||
| 23 | AUDIO_SAVE_DIR: '/home/jwq/npys/' | ||
| 24 | FRAME_SAVE_DIR: '/home/jwq/Desktop/tmp/frame' | ||
| 25 | # FRAME_SAVE_DIR: '/home/jwq/Desktop/VGAF_EmotiW_class/train_frame' | ||
| 26 | FLOW_SAVE_DIR: '/home/jwq/Desktop/tmp/flow' | ||
| 27 | POSE_FRAME_SAVE_DIR: '/home/jwq/Desktop/tmp/pose_frame' | ||
| 28 | FRAME_LIST_DIR: '/home/jwq/Desktop/tmp/file_list' | ||
| 29 | IS10_FEATURE_NP_DIR: '/home/jwq/npys' | ||
| 30 | IS10_FEATURE_CSV_DIR: '/home/jwq/Desktop/tmp/is10' | ||
| 31 | # FACE_FEATURE_DIR: '/home/jwq/Desktop/tmp/face_feature_retina' | ||
| 32 | # FACE_FEATURE_DIR: '/data2/retinaface/random_face_frame_features/' | ||
| 33 | FACE_FEATURE_DIR: '/data1/segment/' | ||
| 34 | # FACE_FEATURE_DIR: '/home/jwq/npys/' | ||
| 35 | FACE_IMAGE_DIR: '/data2/retinaface/train/' | ||
| 36 | CLASS_FEATURE_DIR: '/home/jwq/Desktop/tmp/class' | ||
| 37 | PREFIX: 'img_{:05d}.jpg' | ||
| 38 | FLOW_PREFIX: 'flow_{}_{:05d}.jpg' | ||
| 39 | THREAD_NUM: 10 | ||
| 40 | FPS: 5 | ||
| 41 | |||
| 42 | VIDEO_FILTER: | ||
| 43 | TEST_SEGMENT: 8 | ||
| 44 | TEST_CROP: 1 | ||
| 45 | BATCH_SIZE: 1 | ||
| 46 | INPUT_SIZE: 224 | ||
| 47 | MODALITY: 'RGB' | ||
| 48 | ARCH: 'resnet50' | ||
| 49 | RESULT_FILE: 'video_filter.txt' | ||
| 50 | |||
| 51 | VIDEO_1_FILTER: | ||
| 52 | TEST_SEGMENT: 8 | ||
| 53 | TEST_CROP: 1 | ||
| 54 | BATCH_SIZE: 1 | ||
| 55 | INPUT_SIZE: 224 | ||
| 56 | MODALITY: 'RGB' | ||
| 57 | ARCH: 'resnet34' | ||
| 58 | RESULT_FILE: 'video_1_filter.txt' | ||
| 59 | |||
| 60 | EMOTION: | ||
| 61 | INTERVAL: 1 | ||
| 62 | INPUT_SIZE: 224 | ||
| 63 | RESULT_FILE: 'emotion_filter.txt' | ||
| 64 | |||
| 65 | EMOTION_1: | ||
| 66 | RESULT_FILE: 'emotion_1_filter.txt' | ||
| 67 | DATA_NAME: 'val.npy' | ||
| 68 | |||
| 69 | ARGUE: | ||
| 70 | DIMENSION: 1582 | ||
| 71 | RESULT_FILE: 'argue_filter.txt' | ||
| 72 | |||
| 73 | FIGHTING: | ||
| 74 | TEST_SEGMENT: 8 | ||
| 75 | TEST_CROP: 1 | ||
| 76 | BATCH_SIZE: 1 | ||
| 77 | INPUT_SIZE: 224 | ||
| 78 | MODALITY: 'RGB' | ||
| 79 | ARCH: 'resnet50' | ||
| 80 | RESULT_FILE: 'fighting_filter.txt' | ||
| 81 | |||
| 82 | FIGHTING_2: | ||
| 83 | TEST_SEGMENT: 8 | ||
| 84 | TEST_CROP: 1 | ||
| 85 | BATCH_SIZE: 1 | ||
| 86 | INPUT_SIZE: 224 | ||
| 87 | MODALITY: 'RGB' | ||
| 88 | ARCH: 'resnet50' | ||
| 89 | RESULT_FILE: 'fighting_2_filter.txt' | ||
| 90 | |||
| 91 | MEETING: | ||
| 92 | TEST_SEGMENT: 8 | ||
| 93 | TEST_CROP: 1 | ||
| 94 | BATCH_SIZE: 1 | ||
| 95 | INPUT_SIZE: 224 | ||
| 96 | MODALITY: 'RGB' | ||
| 97 | ARCH: 'resnet50' | ||
| 98 | RESULT_FILE: 'meeting_filter.txt' | ||
| 99 | |||
| 100 | TROOPS: | ||
| 101 | TEST_SEGMENT: 8 | ||
| 102 | TEST_CROP: 1 | ||
| 103 | BATCH_SIZE: 1 | ||
| 104 | INPUT_SIZE: 224 | ||
| 105 | MODALITY: 'RGB' | ||
| 106 | ARCH: 'resnet50' | ||
| 107 | RESULT_FILE: 'troops_filter.txt' | ||
| 108 | |||
| 109 | FLOW: | ||
| 110 | TEST_SEGMENT: 8 | ||
| 111 | TEST_CROP: 1 | ||
| 112 | BATCH_SIZE: 1 | ||
| 113 | INPUT_SIZE: 224 | ||
| 114 | MODALITY: 'Flow' | ||
| 115 | ARCH: 'resnet50' | ||
| 116 | RESULT_FILE: 'flow_filter.txt' | ||
| 117 | |||
| 118 | |||
| 119 | FINAL: | ||
| 120 | RESULT_FILE: 'final.txt' | ||
| 121 | ERROR_FILE: 'error.txt' | ||
| 122 | SIM_FILE: 'image_sim.txt' | ||
| 123 | |||
| 124 | AUDIO: | ||
| 125 | RESULT_FILE: 'audio.txt' | ||
| 126 | OPENSMILE_DIR: '/home/jwq/Downloads/opensmile-2.3.0' | ||
| 127 | DATA_NAME: 'val.npy' | ||
| 128 | |||
| 129 | CLASS: | ||
| 130 | RESULT_FILE: 'class.txt' | ||
| 131 | DATA_NAME: 'val _reannotation.npy' | ||
| 132 | |||
| 133 | POSE: | ||
| 134 | TEST_SEGMENT: 8 | ||
| 135 | TEST_CROP: 1 | ||
| 136 | BATCH_SIZE: 1 | ||
| 137 | INPUT_SIZE: 224 | ||
| 138 | MODALITY: 'RGB' | ||
| 139 | ARCH: 'resnet50' | ||
| 140 | RESULT_FILE: 'pose_filter.txt' | ||
| 141 | |||
| 142 | BG: | ||
| 143 | RESULT_FILE: 'bg_filter.txt' | ||
| 144 | DATA_NAME: 'bg_val_feature.npy' | ||
| 145 | |||
| 146 | PERSON: | ||
| 147 | RESULT_FILE: 'person_filter.txt' | ||
| 148 | DATA_NAME: 'person_val_feature.npy' | ||
| 149 | |||
| 150 |
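For context, every filter in this commit receives this file as a nested dict (e.g. config['MODEL']['CLS_AUDIO']). A minimal sketch of loading it with PyYAML, mirroring load_config in load_util.py below:

import yaml

with open('config.yaml', 'r') as cf:
    config = yaml.load(cf, Loader=yaml.FullLoader)  # same loader as load_util.load_config

print(config['MODEL']['CLS_AUDIO'])    # '/home/jwq/models/cls_audio/v0.0.1.m'
print(config['AUDIO']['RESULT_FILE'])  # 'audio.txt'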
emotion_filter.py
0 → 100644
| 1 | import os | ||
| 2 | import cv2 | ||
| 3 | import numpy as np | ||
| 4 | from keras.models import Model | ||
| 5 | from keras.models import load_model | ||
| 6 | from sklearn.externals import joblib | ||
| 7 | from tensorflow.keras.preprocessing.image import img_to_array | ||
| 8 | |||
| 9 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' | ||
| 10 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' | ||
| 11 | |||
| 12 | |||
| 13 | class FeatureExtractor(object): | ||
| 14 | def __init__(self, input_size=224, out_put_layer='avg_pool', model_path='FerPlus3.h5'): | ||
| 15 | self.model = load_model(model_path) | ||
| 16 | self.input_size = input_size | ||
| 17 | self.model_inter = Model(inputs=self.model.input, outputs=self.model.get_layer(out_put_layer).output) | ||
| 18 | |||
| 19 | def inference(self, image): | ||
| 20 | image = cv2.resize(image, (self.input_size, self.input_size)) | ||
| 21 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | ||
| 22 | image = image.astype("float") / 255.0 | ||
| 23 | image = img_to_array(image) | ||
| 24 | image = np.expand_dims(image, axis=0) | ||
| 25 | feature = self.model_inter.predict(image)[0] | ||
| 26 | return feature | ||
| 27 | |||
| 28 | |||
| 29 | def features2feature(pics_features): | ||
| 30 | |||
| 31 | pics_features = np.array(pics_features) | ||
| 32 | fea_mean = pics_features.mean(axis=0) | ||
| 33 | fea_max = np.amax(pics_features, axis=0) | ||
| 34 | fea_min = np.amin(pics_features, axis=0) | ||
| 35 | fea_std = pics_features.std(axis=0) | ||
| 36 | |||
| 37 | return np.concatenate((fea_mean, fea_max, fea_min, fea_std), axis=1).reshape(1, -1) | ||
| 38 | |||
| 39 | |||
| 40 | def start_filter(config): | ||
| 41 | |||
| 42 | cls_emotion_path = config['MODEL']['CLS_EMOTION'] | ||
| 43 | face_feature_dir = config['VIDEO']['FACE_FEATURE_DIR'] | ||
| 44 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
| 45 | result_file_name = config['EMOTION']['RESULT_FILE'] | ||
| 46 | |||
| 47 | svm_clf = joblib.load(cls_emotion_path) | ||
| 48 | |||
| 49 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
| 50 | result_file = open(result_file_path, 'w') | ||
| 51 | |||
| 52 | face_feature_names = os.listdir(face_feature_dir) | ||
| 53 | for face_feature in face_feature_names: | ||
| 54 | face_feature_path = os.path.join(face_feature_dir, face_feature) | ||
| 55 | |||
| 56 | features_np = np.load(face_feature_path, allow_pickle=True) | ||
| 57 | |||
| 58 | feature = features2feature(features_np) | ||
| 59 | res = svm_clf.predict_proba(feature) | ||
| 60 | proba = np.squeeze(res) | ||
| 61 | # class_pre = svm_clf.predict(feature) | ||
| 62 | |||
| 63 | result_file.write(face_feature[:-4] + ' ') | ||
| 64 | result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n') | ||
| 65 | |||
| 66 | result_file.close() | ||
| 67 | |||
| 68 | |||
| 69 | |||
| 70 | |||
| 71 |
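features2feature (and get_vid_fea in media_util.py) pool per-frame face features into one fixed-length video descriptor by concatenating mean, max, min and std statistics. A minimal standalone sketch of that pooling, assuming each row of the input is one frame-level feature vector:

import numpy as np

def pool_statistics(frame_features):
    # frame_features: (num_frames, feature_dim) per-frame embeddings
    f = np.asarray(frame_features, dtype=np.float32)
    stats = (f.mean(axis=0), f.max(axis=0), f.min(axis=0), f.std(axis=0))
    return np.concatenate(stats)  # shape: (4 * feature_dim,)

descriptor = pool_statistics(np.random.rand(12, 1536))
print(descriptor.shape)  # (6144,)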
fighting_2_filter.py
0 → 100644
| 1 | import os | ||
| 2 | import torch.optim | ||
| 3 | import numpy as np | ||
| 4 | import torch.optim | ||
| 5 | import torch.nn.parallel | ||
| 6 | from ops.models import TSN | ||
| 7 | from ops.transforms import * | ||
| 8 | from ops.dataset import TSNDataSet | ||
| 9 | from torch.nn import functional as F | ||
| 10 | |||
| 11 | |||
| 12 | def gen_file_list(frame_save_dir, frame_list_dir): | ||
| 13 | |||
| 14 | val_path = os.path.join(frame_list_dir, 'val.txt') | ||
| 15 | video_names = os.listdir(frame_save_dir) | ||
| 16 | ucf101_rgb_val_file = open(val_path, 'w') | ||
| 17 | |||
| 18 | for video_name in video_names: | ||
| 19 | images_dir = os.path.join(frame_save_dir, video_name) | ||
| 20 | ucf101_rgb_val_file.write(video_name) | ||
| 21 | ucf101_rgb_val_file.write(' ') | ||
| 22 | ucf101_rgb_val_file.write(str(len(os.listdir(images_dir)))) | ||
| 23 | ucf101_rgb_val_file.write('\n') | ||
| 24 | |||
| 25 | ucf101_rgb_val_file.close() | ||
| 26 | |||
| 27 | return val_path | ||
| 28 | |||
| 29 | |||
| 30 | def start_filter(config): | ||
| 31 | arch = config['FIGHTING_2']['ARCH'] | ||
| 32 | prefix = config['VIDEO']['PREFIX'] | ||
| 33 | modality = config['FIGHTING_2']['MODALITY'] | ||
| 34 | test_crop = config['FIGHTING_2']['TEST_CROP'] | ||
| 35 | batch_size = config['FIGHTING_2']['BATCH_SIZE'] | ||
| 36 | weights_path = config['MODEL']['CLS_FIGHTING_2'] | ||
| 37 | test_segment = config['FIGHTING_2']['TEST_SEGMENT'] | ||
| 38 | frame_save_dir = config['VIDEO']['FRAME_SAVE_DIR'] | ||
| 39 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
| 40 | result_file_name = config['FIGHTING_2']['RESULT_FILE'] | ||
| 41 | |||
| 42 | workers = 8 | ||
| 43 | num_class = 2 | ||
| 44 | shift_div = 8 | ||
| 45 | img_feature_dim = 256 | ||
| 46 | |||
| 47 | softmax = False | ||
| 48 | is_shift = True | ||
| 49 | full_res = False | ||
| 50 | non_local = False | ||
| 51 | dense_sample = False | ||
| 52 | twice_sample = False | ||
| 53 | |||
| 54 | val_list = gen_file_list(frame_save_dir, frame_list_dir) | ||
| 55 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
| 56 | |||
| 57 | pretrain = 'imagenet' | ||
| 58 | shift_place = 'blockres' | ||
| 59 | crop_fusion_type = 'avg' | ||
| 60 | |||
| 61 | net = TSN(num_class, test_segment if is_shift else 1, modality, | ||
| 62 | base_model=arch, | ||
| 63 | consensus_type=crop_fusion_type, | ||
| 64 | img_feature_dim=img_feature_dim, | ||
| 65 | pretrain=pretrain, | ||
| 66 | is_shift=is_shift, shift_div=shift_div, shift_place=shift_place, | ||
| 67 | non_local=non_local, | ||
| 68 | ) | ||
| 69 | |||
| 70 | checkpoint = torch.load(weights_path) | ||
| 71 | checkpoint = checkpoint['state_dict'] | ||
| 72 | |||
| 73 | base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())} | ||
| 74 | replace_dict = {'base_model.classifier.weight': 'new_fc.weight', | ||
| 75 | 'base_model.classifier.bias': 'new_fc.bias', | ||
| 76 | } | ||
| 77 | for k, v in replace_dict.items(): | ||
| 78 | if k in base_dict: | ||
| 79 | base_dict[v] = base_dict.pop(k) | ||
| 80 | |||
| 81 | net.load_state_dict(base_dict) | ||
| 82 | |||
| 83 | input_size = net.scale_size if full_res else net.input_size | ||
| 84 | |||
| 85 | if test_crop == 1: | ||
| 86 | cropping = torchvision.transforms.Compose([ | ||
| 87 | GroupScale(net.scale_size), | ||
| 88 | GroupCenterCrop(input_size), | ||
| 89 | ]) | ||
| 90 | elif test_crop == 3:  # 3 full-resolution crops, no flip | ||
| 91 | cropping = torchvision.transforms.Compose([ | ||
| 92 | GroupFullResSample(input_size, net.scale_size, flip=False) | ||
| 93 | ]) | ||
| 94 | elif test_crop == 5: # do not flip, so only 5 crops | ||
| 95 | cropping = torchvision.transforms.Compose([ | ||
| 96 | GroupOverSample(input_size, net.scale_size, flip=False) | ||
| 97 | ]) | ||
| 98 | elif test_crop == 10: | ||
| 99 | cropping = torchvision.transforms.Compose([ | ||
| 100 | GroupOverSample(input_size, net.scale_size) | ||
| 101 | ]) | ||
| 102 | else: | ||
| 103 | raise ValueError("Only 1, 3, 5, 10 crops are supported, but got {}".format(test_crop)) | ||
| 104 | |||
| 105 | data_loader = torch.utils.data.DataLoader( | ||
| 106 | TSNDataSet(frame_save_dir, val_list, num_segments=test_segment, | ||
| 107 | new_length=1 if modality == "RGB" else 5, | ||
| 108 | modality=modality, | ||
| 109 | image_tmpl=prefix, | ||
| 110 | test_mode=True, | ||
| 111 | remove_missing=False, | ||
| 112 | transform=torchvision.transforms.Compose([ | ||
| 113 | cropping, | ||
| 114 | Stack(roll=(arch in ['BNInception', 'InceptionV3'])), | ||
| 115 | ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])), | ||
| 116 | GroupNormalize(net.input_mean, net.input_std), | ||
| 117 | ]), dense_sample=dense_sample, twice_sample=twice_sample), | ||
| 118 | batch_size=batch_size, shuffle=False, | ||
| 119 | num_workers=workers, pin_memory=True, | ||
| 120 | ) | ||
| 121 | |||
| 122 | net = torch.nn.DataParallel(net.cuda()) | ||
| 123 | net.eval() | ||
| 124 | data_gen = enumerate(data_loader) | ||
| 125 | max_num = len(data_loader.dataset) | ||
| 126 | |||
| 127 | result_file = open(result_file_path, 'w') | ||
| 128 | |||
| 129 | for i, data_pair in data_gen: | ||
| 130 | directory, data = data_pair | ||
| 131 | with torch.no_grad(): | ||
| 132 | if i >= max_num: | ||
| 133 | break | ||
| 134 | num_crop = test_crop | ||
| 135 | if dense_sample: | ||
| 136 | num_crop *= 10 # 10 clips for testing when using dense sample | ||
| 137 | |||
| 138 | if twice_sample: | ||
| 139 | num_crop *= 2 | ||
| 140 | |||
| 141 | if modality == 'RGB': | ||
| 142 | length = 3 | ||
| 143 | elif modality == 'Flow': | ||
| 144 | length = 10 | ||
| 145 | elif modality == 'RGBDiff': | ||
| 146 | length = 18 | ||
| 147 | else: | ||
| 148 | raise ValueError("Unknown modality " + modality) | ||
| 149 | |||
| 150 | data_in = data.view(-1, length, data.size(2), data.size(3)) | ||
| 151 | if is_shift: | ||
| 152 | data_in = data_in.view(batch_size * num_crop, test_segment, length, data_in.size(2), data_in.size(3)) | ||
| 153 | rst, feature = net(data_in) | ||
| 154 | rst = rst.reshape(batch_size, num_crop, -1).mean(1) | ||
| 155 | |||
| 156 | if softmax: | ||
| 157 | # take the softmax to normalize the output to probability | ||
| 158 | rst = F.softmax(rst, dim=1) | ||
| 159 | |||
| 160 | rst = rst.data.cpu().numpy().copy() | ||
| 161 | |||
| 162 | if net.module.is_shift: | ||
| 163 | rst = rst.reshape(batch_size, num_class) | ||
| 164 | else: | ||
| 165 | rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class)) | ||
| 166 | |||
| 167 | proba = np.squeeze(rst) | ||
| 168 | print(proba) | ||
| 169 | proba = np.exp(proba)/sum(np.exp(proba)) | ||
| 170 | result_file.write(str(directory[0]) + ' ') | ||
| 171 | result_file.write(str(proba[0]) + ',' + str(proba[1]) + '\n') | ||
| 172 | |||
| 173 | result_file.close() | ||
| 174 | print('fighting filter end') | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
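The np.exp(proba) / sum(np.exp(proba)) normalization at the end of this filter (and of flow_filter.py) can overflow for large logits. A numerically stable softmax, shown here only as a possible drop-in sketch:

import numpy as np

def stable_softmax(logits):
    # subtract the max logit before exponentiating to avoid overflow
    z = np.asarray(logits, dtype=np.float64)
    e = np.exp(z - z.max())
    return e / e.sum()

print(stable_softmax([2.0, 1.0]))       # same values as exp(x) / sum(exp(x))
print(stable_softmax([1000.0, 999.0]))  # no overflow warning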
flow_filter.py
0 → 100644
| 1 | import os | ||
| 2 | import torch.optim | ||
| 3 | import numpy as np | ||
| 4 | import torch.optim | ||
| 5 | import torch.nn.parallel | ||
| 6 | from ops.models import TSN | ||
| 7 | from ops.transforms import * | ||
| 8 | from ops.dataset import TSNDataSet | ||
| 9 | from torch.nn import functional as F | ||
| 10 | |||
| 11 | |||
| 12 | def gen_file_list(frame_save_dir, frame_list_dir): | ||
| 13 | |||
| 14 | val_path = os.path.join(frame_list_dir, 'flow_val.txt') | ||
| 15 | video_names = os.listdir(frame_save_dir) | ||
| 16 | ucf101_rgb_val_file = open(val_path, 'w') | ||
| 17 | |||
| 18 | for video_name in video_names: | ||
| 19 | images_dir = os.path.join(frame_save_dir, video_name) | ||
| 20 | ucf101_rgb_val_file.write(video_name) | ||
| 21 | ucf101_rgb_val_file.write(' ') | ||
| 22 | ori_list = os.listdir(images_dir) | ||
| 23 | select_list = [element for element in ori_list if 'x' in element] | ||
| 24 | ucf101_rgb_val_file.write(str(len(select_list))) | ||
| 25 | ucf101_rgb_val_file.write('\n') | ||
| 26 | |||
| 27 | ucf101_rgb_val_file.close() | ||
| 28 | |||
| 29 | return val_path | ||
| 30 | |||
| 31 | |||
| 32 | def start_filter(config): | ||
| 33 | arch = config['FLOW']['ARCH'] | ||
| 34 | prefix = config['VIDEO']['FLOW_PREFIX'] | ||
| 35 | modality = config['FLOW']['MODALITY'] | ||
| 36 | test_crop = config['FLOW']['TEST_CROP'] | ||
| 37 | batch_size = config['FLOW']['BATCH_SIZE'] | ||
| 38 | weights_path = config['MODEL']['CLS_FLOW'] | ||
| 39 | test_segment = config['FLOW']['TEST_SEGMENT'] | ||
| 40 | frame_save_dir = config['VIDEO']['FLOW_SAVE_DIR'] | ||
| 41 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
| 42 | result_file_name = config['FLOW']['RESULT_FILE'] | ||
| 43 | |||
| 44 | workers = 8 | ||
| 45 | num_class = 3 | ||
| 46 | shift_div = 8 | ||
| 47 | img_feature_dim = 256 | ||
| 48 | |||
| 49 | softmax = False | ||
| 50 | is_shift = True | ||
| 51 | full_res = False | ||
| 52 | non_local = False | ||
| 53 | dense_sample = False | ||
| 54 | twice_sample = False | ||
| 55 | |||
| 56 | val_list = gen_file_list(frame_save_dir, frame_list_dir) | ||
| 57 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
| 58 | |||
| 59 | pretrain = 'imagenet' | ||
| 60 | shift_place = 'blockres' | ||
| 61 | crop_fusion_type = 'avg' | ||
| 62 | |||
| 63 | net = TSN(num_class, test_segment if is_shift else 1, modality, | ||
| 64 | base_model=arch, | ||
| 65 | consensus_type=crop_fusion_type, | ||
| 66 | img_feature_dim=img_feature_dim, | ||
| 67 | pretrain=pretrain, | ||
| 68 | is_shift=is_shift, shift_div=shift_div, shift_place=shift_place, | ||
| 69 | non_local=non_local, | ||
| 70 | ) | ||
| 71 | |||
| 72 | checkpoint = torch.load(weights_path) | ||
| 73 | checkpoint = checkpoint['state_dict'] | ||
| 74 | |||
| 75 | base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())} | ||
| 76 | replace_dict = {'base_model.classifier.weight': 'new_fc.weight', | ||
| 77 | 'base_model.classifier.bias': 'new_fc.bias', | ||
| 78 | } | ||
| 79 | for k, v in replace_dict.items(): | ||
| 80 | if k in base_dict: | ||
| 81 | base_dict[v] = base_dict.pop(k) | ||
| 82 | |||
| 83 | net.load_state_dict(base_dict) | ||
| 84 | |||
| 85 | input_size = net.scale_size if full_res else net.input_size | ||
| 86 | |||
| 87 | if test_crop == 1: | ||
| 88 | cropping = torchvision.transforms.Compose([ | ||
| 89 | GroupScale(net.scale_size), | ||
| 90 | GroupCenterCrop(input_size), | ||
| 91 | ]) | ||
| 92 | elif test_crop == 3:  # 3 full-resolution crops, no flip | ||
| 93 | cropping = torchvision.transforms.Compose([ | ||
| 94 | GroupFullResSample(input_size, net.scale_size, flip=False) | ||
| 95 | ]) | ||
| 96 | elif test_crop == 5: # do not flip, so only 5 crops | ||
| 97 | cropping = torchvision.transforms.Compose([ | ||
| 98 | GroupOverSample(input_size, net.scale_size, flip=False) | ||
| 99 | ]) | ||
| 100 | elif test_crop == 10: | ||
| 101 | cropping = torchvision.transforms.Compose([ | ||
| 102 | GroupOverSample(input_size, net.scale_size) | ||
| 103 | ]) | ||
| 104 | else: | ||
| 105 | raise ValueError("Only 1, 3, 5, 10 crops are supported, but got {}".format(test_crop)) | ||
| 106 | |||
| 107 | data_loader = torch.utils.data.DataLoader( | ||
| 108 | TSNDataSet(frame_save_dir, val_list, num_segments=test_segment, | ||
| 109 | new_length=1 if modality == "RGB" else 5, | ||
| 110 | modality=modality, | ||
| 111 | image_tmpl=prefix, | ||
| 112 | test_mode=True, | ||
| 113 | remove_missing=False, | ||
| 114 | transform=torchvision.transforms.Compose([ | ||
| 115 | cropping, | ||
| 116 | Stack(roll=(arch in ['BNInception', 'InceptionV3'])), | ||
| 117 | ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])), | ||
| 118 | GroupNormalize(net.input_mean, net.input_std), | ||
| 119 | ]), dense_sample=dense_sample, twice_sample=twice_sample), | ||
| 120 | batch_size=batch_size, shuffle=False, | ||
| 121 | num_workers=workers, pin_memory=True, | ||
| 122 | ) | ||
| 123 | |||
| 124 | net = torch.nn.DataParallel(net.cuda()) | ||
| 125 | net.eval() | ||
| 126 | data_gen = enumerate(data_loader) | ||
| 127 | max_num = len(data_loader.dataset) | ||
| 128 | |||
| 129 | result_file = open(result_file_path, 'w') | ||
| 130 | |||
| 131 | for i, data_pair in data_gen: | ||
| 132 | directory, data = data_pair | ||
| 133 | with torch.no_grad(): | ||
| 134 | if i >= max_num: | ||
| 135 | break | ||
| 136 | num_crop = test_crop | ||
| 137 | if dense_sample: | ||
| 138 | num_crop *= 10 # 10 clips for testing when using dense sample | ||
| 139 | |||
| 140 | if twice_sample: | ||
| 141 | num_crop *= 2 | ||
| 142 | |||
| 143 | if modality == 'RGB': | ||
| 144 | length = 3 | ||
| 145 | elif modality == 'Flow': | ||
| 146 | length = 10 | ||
| 147 | elif modality == 'RGBDiff': | ||
| 148 | length = 18 | ||
| 149 | else: | ||
| 150 | raise ValueError("Unknown modality " + modality) | ||
| 151 | |||
| 152 | data_in = data.view(-1, length, data.size(2), data.size(3)) | ||
| 153 | if is_shift: | ||
| 154 | data_in = data_in.view(batch_size * num_crop, test_segment, length, data_in.size(2), data_in.size(3)) | ||
| 155 | rst, feature = net(data_in) | ||
| 156 | rst = rst.reshape(batch_size, num_crop, -1).mean(1) | ||
| 157 | |||
| 158 | if softmax: | ||
| 159 | # take the softmax to normalize the output to probability | ||
| 160 | rst = F.softmax(rst, dim=1) | ||
| 161 | |||
| 162 | rst = rst.data.cpu().numpy().copy() | ||
| 163 | |||
| 164 | if net.module.is_shift: | ||
| 165 | rst = rst.reshape(batch_size, num_class) | ||
| 166 | else: | ||
| 167 | rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class)) | ||
| 168 | |||
| 169 | proba = np.squeeze(rst) | ||
| 170 | proba = np.exp(proba)/sum(np.exp(proba)) | ||
| 171 | result_file.write(str(directory[0]) + ' ') | ||
| 172 | result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n') | ||
| 173 | |||
| 174 | result_file.close() | ||
| 175 | print('flow filter end') | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
load_util.py
0 → 100644
| 1 | import os | ||
| 2 | import cv2 | ||
| 3 | import yaml | ||
| 4 | import tensorflow as tf | ||
| 5 | |||
| 6 | |||
| 7 | def load_config(config_path): | ||
| 8 | with open(config_path, 'r') as cf: | ||
| 9 | config_obj = yaml.load(cf, Loader=yaml.FullLoader) | ||
| 10 | print(config_obj) | ||
| 11 | return config_obj | ||
| 12 | |||
| 13 | |||
| 14 | def load_argue_model(config): | ||
| 15 | |||
| 16 | cls_argue_path = config['MODEL']['CLS_ARGUE'] | ||
| 17 | with tf.Graph().as_default(): | ||
| 18 | |||
| 19 | if os.path.isfile(cls_argue_path): | ||
| 20 | print('Model filename: %s' % cls_argue_path) | ||
| 21 | with tf.gfile.GFile(cls_argue_path, 'rb') as f: | ||
| 22 | graph_def = tf.GraphDef() | ||
| 23 | graph_def.ParseFromString(f.read()) | ||
| 24 | tf.import_graph_def(graph_def, name='') | ||
| 25 | |||
| 26 | x = tf.get_default_graph().get_tensor_by_name("x_batch:0") | ||
| 27 | output = tf.get_default_graph().get_tensor_by_name("output/BiasAdd:0") | ||
| 28 | |||
| 29 | config = tf.ConfigProto() | ||
| 30 | config.gpu_options.allow_growth = False | ||
| 31 | sess = tf.Session(config=config) | ||
| 32 | |||
| 33 | return x, output, sess |
media_util.py
0 → 100644
| 1 | import os | ||
| 2 | import cv2 | ||
| 3 | import random | ||
| 4 | import shutil | ||
| 5 | import subprocess | ||
| 6 | import numpy as np | ||
| 7 | import torch.optim | ||
| 8 | from tqdm import tqdm | ||
| 9 | import torch.nn.parallel | ||
| 10 | from ops.models import TSN | ||
| 11 | from ops.transforms import * | ||
| 12 | from functools import partial | ||
| 13 | from mtcnn.mtcnn import MTCNN | ||
| 14 | from keras.models import Model | ||
| 15 | from multiprocessing import Pool | ||
| 16 | from keras.models import load_model | ||
| 17 | from sklearn.externals import joblib | ||
| 18 | from tensorflow.keras.preprocessing.image import img_to_array | ||
| 19 | |||
| 20 | |||
| 21 | |||
| 22 | |||
| 23 | from ops.dataset import TSNDataSet | ||
| 24 | from torch.nn import functional as F | ||
| 25 | |||
| 26 | os.environ["CUDA_VISIBLE_DEVICES"] = '1' | ||
| 27 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' | ||
| 28 | |||
| 29 | class FeatureExtractor(object): | ||
| 30 | def __init__(self, input_size=224, out_put_layer='global_average_pooling2d_1', model_path='nceptionResNetV2-final.h5'): | ||
| 31 | self.model = load_model(model_path) | ||
| 32 | self.input_size = input_size | ||
| 33 | self.model_inter = Model(inputs=self.model.input, outputs=self.model.get_layer(out_put_layer).output) | ||
| 34 | |||
| 35 | def inference(self, image): | ||
| 36 | image = cv2.resize(image, (self.input_size, self.input_size)) | ||
| 37 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | ||
| 38 | image = image.astype("float") / 255.0 | ||
| 39 | image = img_to_array(image) | ||
| 40 | image = np.expand_dims(image, axis=0) | ||
| 41 | feature = self.model_inter.predict(image)[0] | ||
| 42 | return feature | ||
| 43 | |||
| 44 | |||
| 45 | def extract_wav(config): | ||
| 46 | video_dir = config['VIDEO']['VIDEO_DIR'] | ||
| 47 | video_save_dir = config['VIDEO']['VIDEO_SAVE_DIR'] | ||
| 48 | audio_save_dir = config['VIDEO']['AUDIO_SAVE_DIR'] | ||
| 49 | |||
| 50 | assert os.path.exists(video_dir) | ||
| 51 | video_names = os.listdir(video_dir) | ||
| 52 | for video_index, video_name in enumerate(video_names): | ||
| 53 | file_name = video_name.split('.')[0] | ||
| 54 | video_path = os.path.join(video_dir, video_name) | ||
| 55 | |||
| 56 | assert os.path.exists(audio_save_dir) | ||
| 57 | assert os.path.exists(video_save_dir) | ||
| 58 | |||
| 59 | audio_name = file_name + '.wav' | ||
| 60 | audio_save_path = os.path.join(audio_save_dir, audio_name) | ||
| 61 | video_save_path = os.path.join(video_save_dir, video_name) | ||
| 62 | |||
| 63 | command = 'ffmpeg -i {} -f wav -ar 16000 {}'.format(video_path, audio_save_path) | ||
| 64 | os.popen(command) | ||
| 65 | shutil.copyfile(video_path, video_save_path) | ||
| 66 | |||
| 67 | |||
| 68 | def video2frame(file_name, class_path, dst_class_path): | ||
| 69 | if '.mp4' not in file_name: | ||
| 70 | return | ||
| 71 | name, ext = os.path.splitext(file_name) | ||
| 72 | dst_directory_path = os.path.join(dst_class_path, name) | ||
| 73 | |||
| 74 | video_file_path = os.path.join(class_path, file_name) | ||
| 75 | try: | ||
| 76 | if os.path.exists(dst_directory_path): | ||
| 77 | if not os.path.exists(os.path.join(dst_directory_path, 'img_00001.jpg')): | ||
| 78 | subprocess.call('rm -r \"{}\"'.format(dst_directory_path), shell=True) | ||
| 79 | print('remove {}'.format(dst_directory_path)) | ||
| 80 | os.mkdir(dst_directory_path) | ||
| 81 | else: | ||
| 82 | print('*** convert has been done: {}'.format(dst_directory_path)) | ||
| 83 | return | ||
| 84 | else: | ||
| 85 | os.mkdir(dst_directory_path) | ||
| 86 | except: | ||
| 87 | print(dst_directory_path) | ||
| 88 | return | ||
| 89 | cmd = 'ffmpeg -i \"{}\" -threads 1 -vf scale=-1:331 -q:v 0 \"{}/img_%05d.jpg\"'.format(video_file_path, | ||
| 90 | dst_directory_path) | ||
| 91 | # print(cmd) | ||
| 92 | subprocess.call(cmd, shell=True, | ||
| 93 | stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) | ||
| 94 | |||
| 95 | |||
| 96 | def extract_frame(config): | ||
| 97 | video_save_dir = config['VIDEO']['VIDEO_SAVE_DIR'] | ||
| 98 | frame_save_dir = config['VIDEO']['FRAME_SAVE_DIR'] | ||
| 99 | n_thread = config['VIDEO']['THREAD_NUM'] | ||
| 100 | |||
| 101 | assert os.path.exists(video_save_dir) | ||
| 102 | video_names = os.listdir(video_save_dir) | ||
| 103 | |||
| 104 | if not os.path.exists(frame_save_dir): | ||
| 105 | os.mkdir(frame_save_dir) | ||
| 106 | |||
| 107 | p = Pool(n_thread) | ||
| 108 | worker = partial(video2frame, class_path=video_save_dir, dst_class_path=frame_save_dir) | ||
| 109 | for _ in tqdm(p.imap_unordered(worker, video_names), total=len(video_names)): | ||
| 110 | pass | ||
| 111 | |||
| 112 | p.close() | ||
| 113 | p.join() | ||
| 114 | |||
| 115 | |||
| 116 | def extract_frame_pose(config): | ||
| 117 | video_save_dir = config['VIDEO']['VIDEO_SAVE_DIR'] | ||
| 118 | frame_save_dir = config['VIDEO']['POSE_FRAME_SAVE_DIR'] | ||
| 119 | n_thread = config['VIDEO']['THREAD_NUM'] | ||
| 120 | |||
| 121 | assert os.path.exists(video_save_dir) | ||
| 122 | video_names = os.listdir(video_save_dir) | ||
| 123 | |||
| 124 | if not os.path.exists(frame_save_dir): | ||
| 125 | os.mkdir(frame_save_dir) | ||
| 126 | |||
| 127 | p = Pool(n_thread) | ||
| 128 | worker = partial(video2frame, class_path=video_save_dir, dst_class_path=frame_save_dir) | ||
| 129 | for _ in tqdm(p.imap_unordered(worker, video_names), total=len(video_names)): | ||
| 130 | pass | ||
| 131 | |||
| 132 | p.close() | ||
| 133 | p.join() | ||
| 134 | |||
| 135 | |||
| 136 | def extract_is10(config): | ||
| 137 | open_smile_dir = config['AUDIO']['OPENSMILE_DIR'] | ||
| 138 | audio_save_dir = config['VIDEO']['AUDIO_SAVE_DIR'] | ||
| 139 | is10_save_dir = config['VIDEO']['IS10_FEATURE_CSV_DIR'] | ||
| 140 | |||
| 141 | assert os.path.exists(audio_save_dir) | ||
| 142 | audio_names = os.listdir(audio_save_dir) | ||
| 143 | |||
| 144 | if not os.path.exists(is10_save_dir): | ||
| 145 | os.mkdir(is10_save_dir) | ||
| 146 | |||
| 147 | for audio_name in audio_names: | ||
| 148 | audio_save_path = os.path.join(audio_save_dir, audio_name) | ||
| 149 | csv_name = audio_name[:-4] + '.csv' | ||
| 150 | csv_path = os.path.join(is10_save_dir, csv_name) | ||
| 151 | |||
| 152 | smile_config = '{}/config/IS10_paraling.conf'.format(open_smile_dir)  # do not shadow the config dict argument | ||
| 153 | command = '{}/SMILExtract -C {} -I {} -O {}'.format(open_smile_dir, smile_config, audio_save_path, csv_path) | ||
| 154 | os.popen(command) | ||
| 155 | |||
| 156 | |||
| 157 | def extract_face_feature(config): | ||
| 158 | feature_emotion_path = config['MODEL']['FEATURE_EMOTION'] | ||
| 159 | frame_save_dir = config['VIDEO']['FRAME_SAVE_DIR'] | ||
| 160 | face_feature_dir = config['VIDEO']['FACE_FEATURE_DIR'] | ||
| 161 | interval = config['EMOTION']['INTERVAL'] | ||
| 162 | input_size = config['EMOTION']['INPUT_SIZE'] | ||
| 163 | prefix = config['VIDEO']['PREFIX'] | ||
| 164 | |||
| 165 | feature_extractor = FeatureExtractor( | ||
| 166 | input_size=input_size, out_put_layer='global_average_pooling2d_1', model_path=feature_emotion_path) | ||
| 167 | mtcnn_detector = MTCNN() | ||
| 168 | |||
| 169 | video_names = os.listdir(frame_save_dir) | ||
| 170 | for video_index, video_name in enumerate(video_names): | ||
| 171 | print('{}/{}'.format(video_index, len(video_names))) | ||
| 172 | video_dir = os.path.join(frame_save_dir, video_name) | ||
| 173 | frame_names = os.listdir(video_dir) | ||
| 174 | end = 0 | ||
| 175 | features = [] | ||
| 176 | |||
| 177 | while end < len(frame_names): | ||
| 178 | |||
| 179 | if end % interval == 0: | ||
| 180 | frame_name = prefix.format(end + 1) | ||
| 181 | frame_path = os.path.join(video_dir, frame_name) | ||
| 182 | |||
| 183 | frame = cv2.imread(frame_path) | ||
| 184 | img_h, img_w, img_c = frame.shape | ||
| 185 | detect_faces = mtcnn_detector.detect_faces(frame) | ||
| 186 | for i, e in enumerate(detect_faces): | ||
| 187 | x1, y1, w, h = e['box'] | ||
| 188 | x1 = x1 if x1 > 0 else 0 | ||
| 189 | y1 = y1 if y1 > 0 else 0 | ||
| 190 | x1 = x1 if x1 < img_w else img_w | ||
| 191 | y1 = y1 if y1 < img_h else img_h | ||
| 192 | |||
| 193 | face = frame[y1:y1 + h, x1:x1 + w, :] | ||
| 194 | if face.size == 0:  # skip empty face crops | ||
| 195 | continue | ||
| 196 | features.append(feature_extractor.inference(face)[0]) | ||
| 197 | # top_5 = {} | ||
| 198 | # for i, e in enumerate(detect_faces): | ||
| 199 | # x1, y1, w, h = e['box'] | ||
| 200 | # x1 = x1 if x1 > 0 else 0 | ||
| 201 | # y1 = y1 if y1 > 0 else 0 | ||
| 202 | # x1 = x1 if x1 < img_w else img_w | ||
| 203 | # y1 = y1 if y1 < img_h else img_h | ||
| 204 | # | ||
| 205 | # top_5[w*h] = [x1, y1, w, h] | ||
| 206 | # | ||
| 207 | # top_5 = sorted(top_5.items(), key=lambda d:d[0], reverse=True) | ||
| 208 | # j = 0 | ||
| 209 | # for v in top_5: | ||
| 210 | # if j > 5: | ||
| 211 | # break | ||
| 212 | # x1, y1, w, h = v[1] | ||
| 213 | # face = frame[y1:y1+h, x1:x1+w, :] | ||
| 214 | # if face is []: | ||
| 215 | # continue | ||
| 216 | # features.append(feature_extractor.inference(face)[0]) | ||
| 217 | |||
| 218 | end += 1 | ||
| 219 | if len(features) == 0: | ||
| 220 | continue | ||
| 221 | features_np = np.array(features) | ||
| 222 | face_feature_path = os.path.join(face_feature_dir, video_name + '.npy') | ||
| 223 | np.save(face_feature_path, features_np) | ||
| 224 | |||
| 225 | |||
| 226 | def extract_random_face_feature(config): | ||
| 227 | feature_emotion_path = config['MODEL']['FEATURE_EMOTION'] | ||
| 228 | face_save_dir = config['VIDEO']['FACE_IMAGE_DIR'] | ||
| 229 | face_feature_dir = config['VIDEO']['FACE_FEATURE_DIR'] | ||
| 230 | input_size = config['EMOTION']['INPUT_SIZE'] | ||
| 231 | |||
| 232 | feature_extractor = FeatureExtractor( | ||
| 233 | input_size=input_size, out_put_layer='avg_pool', model_path=feature_emotion_path) | ||
| 234 | |||
| 235 | video_dirs = [] | ||
| 236 | class_names = os.listdir(face_save_dir) | ||
| 237 | for class_name in class_names: | ||
| 238 | class_dir = os.path.join(face_save_dir, class_name) | ||
| 239 | video_names = os.listdir(class_dir) | ||
| 240 | for video_name in video_names: | ||
| 241 | video_dir = os.path.join(class_dir, video_name) | ||
| 242 | video_dirs.append(video_dir) | ||
| 243 | |||
| 244 | for video_dir_index, video_dir in enumerate(video_dirs): | ||
| 245 | print('{}/{}'.format(video_dir_index, len(video_dirs))) | ||
| 246 | class_name, video_name = video_dir.split('/')[-2], video_dir.split('/')[-1] | ||
| 247 | |||
| 248 | video_file_name = video_name.split('.')[0] | ||
| 249 | save_class_dir = os.path.join(face_feature_dir, class_name) | ||
| 250 | face_feature_path = os.path.join(save_class_dir, video_file_name + '.npy') | ||
| 251 | if os.path.exists(face_feature_path): | ||
| 252 | print('file already exists') | ||
| 253 | continue | ||
| 254 | |||
| 255 | image_names = os.listdir(video_dir) | ||
| 256 | image_dirs = [] | ||
| 257 | for image_name in image_names: | ||
| 258 | image_dir = os.path.join(video_dir, image_name) | ||
| 259 | image_dirs.append(image_dir) | ||
| 260 | |||
| 261 | features = [] | ||
| 262 | for image_dir_index, image_dir in enumerate(image_dirs): | ||
| 263 | sub_face_names = os.listdir(image_dir) | ||
| 264 | sub_face_num = len(sub_face_names) | ||
| 265 | for face_index in range(sub_face_num): | ||
| 266 | face_path = os.path.join(image_dir, sub_face_names[face_index]) | ||
| 267 | face_image = cv2.imread(face_path) | ||
| 268 | features.append(feature_extractor.inference(face_image)[0]) | ||
| 269 | |||
| 270 | face_num = len(features) | ||
| 271 | random_1 = random.sample(range(face_num), int(0.8 * face_num)) | ||
| 272 | features_random_1 = [features[c] for c in random_1] | ||
| 273 | |||
| 274 | random_2 = random.sample(range(face_num), int(0.6 * face_num)) | ||
| 275 | features_random_2 = [features[d] for d in random_2] | ||
| 276 | |||
| 277 | random_3 = random.sample(range(face_num), int(0.4 * face_num)) | ||
| 278 | features_random_3 = [features[e] for e in random_3] | ||
| 279 | |||
| 280 | if len(features) == 0: | ||
| 281 | continue | ||
| 282 | |||
| 283 | if os.path.exists(save_class_dir) is False: | ||
| 284 | os.mkdir(save_class_dir) | ||
| 285 | |||
| 286 | features_np = np.array(features) | ||
| 287 | face_feature_path = os.path.join(save_class_dir, video_file_name + '.npy') | ||
| 288 | np.save(face_feature_path, features_np) | ||
| 289 | |||
| 290 | features_np_random_1 = np.array(features_random_1) | ||
| 291 | face_feature_1_path = os.path.join(save_class_dir, video_file_name + '_1.npy') | ||
| 292 | np.save(face_feature_1_path, features_np_random_1) | ||
| 293 | |||
| 294 | features_np_random_2 = np.array(features_random_2) | ||
| 295 | face_feature_2_path = os.path.join(save_class_dir, video_file_name + '_2.npy') | ||
| 296 | np.save(face_feature_2_path, features_np_random_2) | ||
| 297 | |||
| 298 | features_np_random_3 = np.array(features_random_3) | ||
| 299 | face_feature_3_path = os.path.join(save_class_dir, video_file_name + '_3.npy') | ||
| 300 | np.save(face_feature_3_path, features_np_random_3) | ||
| 301 | |||
| 302 | |||
| 303 | def get_vid_fea(pics_features): | ||
| 304 | pics_features = np.array(pics_features) | ||
| 305 | fea_mean = pics_features.mean(axis=0) | ||
| 306 | fea_max = np.amax(pics_features, axis=0) | ||
| 307 | fea_min = np.amin(pics_features, axis=0) | ||
| 308 | fea_std = pics_features.std(axis=0) | ||
| 309 | |||
| 310 | feature_concate = np.concatenate((fea_mean, fea_max, fea_min, fea_std), axis=1) | ||
| 311 | return np.squeeze(feature_concate) | ||
| 312 | |||
| 313 | |||
| 314 | def extract_random_face_and_frame_feature_(): | ||
| 315 | face_feature_dir = r'/data2/3_log-ResNet50/train_mirror/' | ||
| 316 | new_face_feature_dir = r'/data2/retinaface/random_face_frame_features_train_mirror/' | ||
| 317 | |||
| 318 | video_dirs = [] | ||
| 319 | class_names = os.listdir(face_feature_dir) | ||
| 320 | for class_name in class_names: | ||
| 321 | class_dir = os.path.join(face_feature_dir, class_name) | ||
| 322 | video_names = os.listdir(class_dir) | ||
| 323 | for video_name in video_names: | ||
| 324 | video_dir = os.path.join(class_dir, video_name) | ||
| 325 | video_dirs.append(video_dir) | ||
| 326 | |||
| 327 | for video_dir in video_dirs: | ||
| 328 | video_name = video_dir.split('/')[-1] | ||
| 329 | frame_names = os.listdir(video_dir) | ||
| 330 | feature = [] | ||
| 331 | for frame_name in frame_names: | ||
| 332 | feature_dir = os.path.join(video_dir, frame_name) | ||
| 333 | face_features_names = os.listdir(feature_dir) | ||
| 334 | for face_features_name in face_features_names: | ||
| 335 | face_features_path = os.path.join(feature_dir, face_features_name) | ||
| 336 | feature_np = np.load(face_features_path, allow_pickle=True) | ||
| 337 | feature.append(feature_np) | ||
| 338 | |||
| 339 | feature_num = len(feature) | ||
| 340 | if feature_num < 4: | ||
| 341 | continue | ||
| 342 | |||
| 343 | random_1 = random.sample(range(feature_num), int(0.9 * feature_num)) | ||
| 344 | features_random_1 = [feature[c] for c in random_1] | ||
| 345 | |||
| 346 | random_2 = random.sample(range(feature_num), int(0.7 * feature_num)) | ||
| 347 | features_random_2 = [feature[d] for d in random_2] | ||
| 348 | |||
| 349 | random_3 = random.sample(range(feature_num), int(0.5 * feature_num)) | ||
| 350 | features_random_3 = [feature[e] for e in random_3] | ||
| 351 | |||
| 352 | video_file_name = video_name.split('.')[0] | ||
| 353 | |||
| 354 | features_np = get_vid_fea(feature) | ||
| 355 | face_feature_path = os.path.join(new_face_feature_dir, video_file_name + '.npy') | ||
| 356 | np.save(face_feature_path, features_np) | ||
| 357 | |||
| 358 | features_np_random_1 = get_vid_fea(features_random_1) | ||
| 359 | face_feature_1_path = os.path.join(new_face_feature_dir, video_file_name + '_1.npy') | ||
| 360 | np.save(face_feature_1_path, features_np_random_1) | ||
| 361 | |||
| 362 | features_np_random_2 = get_vid_fea(features_random_2) | ||
| 363 | face_feature_2_path = os.path.join(new_face_feature_dir, video_file_name + '_2.npy') | ||
| 364 | np.save(face_feature_2_path, features_np_random_2) | ||
| 365 | |||
| 366 | features_np_random_3 = get_vid_fea(features_random_3) | ||
| 367 | face_feature_3_path = os.path.join(new_face_feature_dir, video_file_name + '_3.npy') | ||
| 368 | np.save(face_feature_3_path, features_np_random_3) | ||
| 369 | |||
| 370 | |||
| 371 | def extract_random_face_and_frame_feature(config): | ||
| 372 | feature_emotion_path = config['MODEL']['FEATURE_EMOTION'] | ||
| 373 | input_size = config['EMOTION']['INPUT_SIZE'] | ||
| 374 | face_dir = r'/data2/retinaface/train/' | ||
| 375 | new_face_feature_dir = r'/data2/3_log-ResNet50/train_mirror/' | ||
| 376 | |||
| 377 | feature_extractor = FeatureExtractor( | ||
| 378 | input_size=input_size, out_put_layer='avg_pool', model_path=feature_emotion_path) | ||
| 379 | |||
| 380 | sub_face_paths = [] | ||
| 381 | class_names = os.listdir(face_dir) | ||
| 382 | for class_name in class_names: | ||
| 383 | class_dir = os.path.join(face_dir, class_name) | ||
| 384 | video_names = os.listdir(class_dir) | ||
| 385 | for video_name in video_names: | ||
| 386 | video_dir = os.path.join(class_dir, video_name) | ||
| 387 | frame_names = os.listdir(video_dir) | ||
| 388 | for frame_name in frame_names: | ||
| 389 | frame_dir = os.path.join(video_dir, frame_name) | ||
| 390 | sub_face_names = os.listdir(frame_dir) | ||
| 391 | for sub_face_name in sub_face_names: | ||
| 392 | sub_face_path = os.path.join(frame_dir, sub_face_name) | ||
| 393 | sub_face_paths.append(sub_face_path) | ||
| 394 | |||
| 395 | for face_index, sub_face_path in enumerate(sub_face_paths): | ||
| 396 | print('{}/{}'.format(face_index+1, len(sub_face_paths))) | ||
| 397 | class_name, video_name, frame_name, sub_face_name = sub_face_path.split('/')[-4]\ | ||
| 398 | , sub_face_path.split('/')[-3], sub_face_path.split('/')[-2], sub_face_path.split('/')[-1] | ||
| 399 | |||
| 400 | class_dir = os.path.join(new_face_feature_dir, class_name) | ||
| 401 | video_dir = os.path.join(class_dir, video_name) | ||
| 402 | frame_dir = os.path.join(video_dir, frame_name) | ||
| 403 | sub_face_name = sub_face_name.split('.')[0] + '.npy' | ||
| 404 | face_feature_save_path = os.path.join(frame_dir, sub_face_name) | ||
| 405 | if os.path.exists(face_feature_save_path): | ||
| 406 | print('file exists') | ||
| 407 | continue | ||
| 408 | |||
| 409 | face_image = cv2.imread(sub_face_path) | ||
| 410 | mirror_face_image = cv2.flip(face_image, 0)  # flipCode 0 flips vertically; use 1 for a left-right mirror | ||
| 411 | feature = feature_extractor.inference(mirror_face_image)[0] | ||
| 412 | |||
| 413 | |||
| 414 | if os.path.exists(class_dir) is False: | ||
| 415 | os.mkdir(class_dir) | ||
| 416 | |||
| 417 | if os.path.exists(video_dir) is False: | ||
| 418 | os.mkdir(video_dir) | ||
| 419 | |||
| 420 | if os.path.exists(frame_dir) is False: | ||
| 421 | os.mkdir(frame_dir) | ||
| 422 | |||
| 423 | np.save(face_feature_save_path, feature) | ||
| 424 | |||
| 425 | |||
| 426 | def gen_file_list(frame_save_dir, frame_list_dir): | ||
| 427 | |||
| 428 | val_path = os.path.join(frame_list_dir, 'train.txt') | ||
| 429 | video_names = os.listdir(frame_save_dir) | ||
| 430 | ucf101_rgb_val_file = open(val_path, 'w') | ||
| 431 | |||
| 432 | for video_name in video_names: | ||
| 433 | images_dir = os.path.join(frame_save_dir, video_name) | ||
| 434 | ucf101_rgb_val_file.write(video_name) | ||
| 435 | ucf101_rgb_val_file.write(' ') | ||
| 436 | ucf101_rgb_val_file.write(str(len(os.listdir(images_dir)))) | ||
| 437 | ucf101_rgb_val_file.write('\n') | ||
| 438 | |||
| 439 | ucf101_rgb_val_file.close() | ||
| 440 | |||
| 441 | return val_path | ||
| 442 | |||
| 443 | |||
| 444 | |||
| 445 | def extract_video_features(config): | ||
| 446 | arch = config['FIGHTING']['ARCH'] | ||
| 447 | prefix = config['VIDEO']['PREFIX'] | ||
| 448 | modality = config['VIDEO_FILTER']['MODALITY'] | ||
| 449 | test_crop = config['VIDEO_FILTER']['TEST_CROP'] | ||
| 450 | batch_size = config['VIDEO_FILTER']['BATCH_SIZE'] | ||
| 451 | weights_path = config['MODEL']['CLS_VIDEO'] | ||
| 452 | test_segment = config['VIDEO_FILTER']['TEST_SEGMENT'] | ||
| 453 | frame_save_dir = config['VIDEO']['FRAME_SAVE_DIR'] | ||
| 454 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
| 455 | feature_save_dir = r'/home/jwq/Desktop/tmp/video2np/train/' | ||
| 456 | |||
| 457 | workers = 8 | ||
| 458 | num_class = 3 | ||
| 459 | shift_div = 8 | ||
| 460 | img_feature_dim = 256 | ||
| 461 | |||
| 462 | softmax = False | ||
| 463 | is_shift = True | ||
| 464 | full_res = False | ||
| 465 | non_local = False | ||
| 466 | dense_sample = False | ||
| 467 | twice_sample = False | ||
| 468 | |||
| 469 | val_list = gen_file_list(frame_save_dir, frame_list_dir) | ||
| 470 | |||
| 471 | pretrain = 'imagenet' | ||
| 472 | shift_place = 'blockres' | ||
| 473 | crop_fusion_type = 'avg' | ||
| 474 | |||
| 475 | net = TSN(num_class, test_segment if is_shift else 1, modality, | ||
| 476 | base_model=arch, | ||
| 477 | consensus_type=crop_fusion_type, | ||
| 478 | img_feature_dim=img_feature_dim, | ||
| 479 | pretrain=pretrain, | ||
| 480 | is_shift=is_shift, shift_div=shift_div, shift_place=shift_place, | ||
| 481 | non_local=non_local, | ||
| 482 | ) | ||
| 483 | |||
| 484 | checkpoint = torch.load(weights_path) | ||
| 485 | checkpoint = checkpoint['state_dict'] | ||
| 486 | |||
| 487 | base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())} | ||
| 488 | replace_dict = {'base_model.classifier.weight': 'new_fc.weight', | ||
| 489 | 'base_model.classifier.bias': 'new_fc.bias', | ||
| 490 | } | ||
| 491 | for k, v in replace_dict.items(): | ||
| 492 | if k in base_dict: | ||
| 493 | base_dict[v] = base_dict.pop(k) | ||
| 494 | |||
| 495 | net.load_state_dict(base_dict) | ||
| 496 | |||
| 497 | input_size = net.scale_size if full_res else net.input_size | ||
| 498 | |||
| 499 | if test_crop == 1: | ||
| 500 | cropping = torchvision.transforms.Compose([ | ||
| 501 | GroupScale(net.scale_size), | ||
| 502 | GroupCenterCrop(input_size), | ||
| 503 | ]) | ||
| 504 | elif test_crop == 3: # do not flip, so only 5 crops | ||
| 505 | cropping = torchvision.transforms.Compose([ | ||
| 506 | GroupFullResSample(input_size, net.scale_size, flip=False) | ||
| 507 | ]) | ||
| 508 | elif test_crop == 5: # do not flip, so only 5 crops | ||
| 509 | cropping = torchvision.transforms.Compose([ | ||
| 510 | GroupOverSample(input_size, net.scale_size, flip=False) | ||
| 511 | ]) | ||
| 512 | elif test_crop == 10: | ||
| 513 | cropping = torchvision.transforms.Compose([ | ||
| 514 | GroupOverSample(input_size, net.scale_size) | ||
| 515 | ]) | ||
| 516 | else: | ||
| 517 | raise ValueError("Only 1, 5, 10 crops are supported while we got {}".format(test_crop)) | ||
| 518 | |||
| 519 | data_loader = torch.utils.data.DataLoader( | ||
| 520 | TSNDataSet(frame_save_dir, val_list, num_segments=test_segment, | ||
| 521 | new_length=1 if modality == "RGB" else 5, | ||
| 522 | modality=modality, | ||
| 523 | image_tmpl=prefix, | ||
| 524 | test_mode=True, | ||
| 525 | remove_missing=False, | ||
| 526 | transform=torchvision.transforms.Compose([ | ||
| 527 | cropping, | ||
| 528 | Stack(roll=(arch in ['BNInception', 'InceptionV3'])), | ||
| 529 | ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])), | ||
| 530 | GroupNormalize(net.input_mean, net.input_std), | ||
| 531 | ]), dense_sample=dense_sample, twice_sample=twice_sample), | ||
| 532 | batch_size=batch_size, shuffle=False, | ||
| 533 | num_workers=workers, pin_memory=True, | ||
| 534 | ) | ||
| 535 | |||
| 536 | net = torch.nn.DataParallel(net.cuda()) | ||
| 537 | net.eval() | ||
| 538 | data_gen = enumerate(data_loader) | ||
| 539 | max_num = len(data_loader.dataset) | ||
| 540 | |||
| 541 | for i, data_pair in data_gen: | ||
| 542 | directory, data = data_pair | ||
| 543 | with torch.no_grad(): | ||
| 544 | if i >= max_num: | ||
| 545 | break | ||
| 546 | num_crop = test_crop | ||
| 547 | if dense_sample: | ||
| 548 | num_crop *= 10 # 10 clips for testing when using dense sample | ||
| 549 | |||
| 550 | if twice_sample: | ||
| 551 | num_crop *= 2 | ||
| 552 | |||
| 553 | if modality == 'RGB': | ||
| 554 | length = 3 | ||
| 555 | elif modality == 'Flow': | ||
| 556 | length = 10 | ||
| 557 | elif modality == 'RGBDiff': | ||
| 558 | length = 18 | ||
| 559 | else: | ||
| 560 | raise ValueError("Unknown modality " + modality) | ||
| 561 | |||
| 562 | data_in = data.view(-1, length, data.size(2), data.size(3)) | ||
| 563 | if is_shift: | ||
| 564 | data_in = data_in.view(batch_size * num_crop, test_segment, length, data_in.size(2), data_in.size(3)) | ||
| 565 | rst, feature = net(data_in) | ||
| 566 | |||
| 567 | feature = np.squeeze(feature.cpu()) | ||
| 568 | print(feature.shape) | ||
| 569 | feature_name = str(directory[0]) + '.npy' | ||
| 570 | feature_save_path = os.path.join(feature_save_dir, feature_name) | ||
| 571 | np.save(feature_save_path, feature) | ||
| 572 | |||
| 573 | |||
| 574 | if __name__ == '__main__': | ||
| 575 | extract_random_face_and_frame_feature_() |
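extract_wav and extract_is10 above launch ffmpeg/SMILExtract through os.popen, which returns immediately and neither waits for the command nor reports failures. A hedged alternative sketch (not part of this commit) using subprocess.run for the same 16 kHz WAV extraction:

import subprocess

def extract_wav_blocking(video_path, audio_save_path):
    # same ffmpeg invocation as extract_wav, but blocks until done and raises on failure
    cmd = ['ffmpeg', '-i', video_path, '-f', 'wav', '-ar', '16000', audio_save_path]
    subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)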
ops/__init__.py
0 → 100755
| 1 | from ops.basic_ops import * | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
ops/__pycache__/__init__.cpython-36.pyc
0 → 100644
No preview for this file type
ops/__pycache__/basic_ops.cpython-36.pyc
0 → 100644
No preview for this file type
ops/__pycache__/dataset.cpython-36.pyc
0 → 100644
No preview for this file type
ops/__pycache__/models.cpython-36.pyc
0 → 100644
No preview for this file type
ops/__pycache__/transforms.cpython-36.pyc
0 → 100644
No preview for this file type
ops/basic_ops.py
0 → 100755
| 1 | import torch | ||
| 2 | |||
| 3 | |||
| 4 | class Identity(torch.nn.Module): | ||
| 5 | def forward(self, input): | ||
| 6 | return input | ||
| 7 | |||
| 8 | |||
| 9 | class SegmentConsensus(torch.nn.Module): | ||
| 10 | |||
| 11 | def __init__(self, consensus_type, dim=1): | ||
| 12 | super(SegmentConsensus, self).__init__() | ||
| 13 | self.consensus_type = consensus_type | ||
| 14 | self.dim = dim | ||
| 15 | self.shape = None | ||
| 16 | |||
| 17 | def forward(self, input_tensor): | ||
| 18 | self.shape = input_tensor.size() | ||
| 19 | if self.consensus_type == 'avg': | ||
| 20 | output = input_tensor.mean(dim=self.dim, keepdim=True) | ||
| 21 | elif self.consensus_type == 'identity': | ||
| 22 | output = input_tensor | ||
| 23 | else: | ||
| 24 | output = None | ||
| 25 | |||
| 26 | return output | ||
| 27 | |||
| 28 | |||
| 29 | class ConsensusModule(torch.nn.Module): | ||
| 30 | |||
| 31 | def __init__(self, consensus_type, dim=1): | ||
| 32 | super(ConsensusModule, self).__init__() | ||
| 33 | self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity' | ||
| 34 | self.dim = dim | ||
| 35 | |||
| 36 | def forward(self, input): | ||
| 37 | return SegmentConsensus(self.consensus_type, self.dim)(input) |
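A quick sketch of what the 'avg' consensus defined above does to per-segment scores; the tensor shapes are illustrative only, and it assumes the ops package is importable from the working directory.

    import torch
    from ops.basic_ops import ConsensusModule

    consensus = ConsensusModule('avg', dim=1)
    scores = torch.randn(4, 8, 3)        # (batch, num_segments, num_class), illustrative shapes
    video_scores = consensus(scores)     # mean over the segment dimension, keepdim=True
    print(video_scores.shape)            # torch.Size([4, 1, 3])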
ops/dataset.py
0 → 100755
| 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" | ||
| 2 | # arXiv:1811.08383 | ||
| 3 | # Ji Lin*, Chuang Gan, Song Han | ||
| 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu | ||
| 5 | |||
| 6 | import torch.utils.data as data | ||
| 7 | |||
| 8 | from PIL import Image | ||
| 9 | import os | ||
| 10 | import numpy as np | ||
| 11 | from numpy.random import randint | ||
| 12 | |||
| 13 | |||
| 14 | class VideoRecord(object): | ||
| 15 | def __init__(self, row): | ||
| 16 | self._data = row | ||
| 17 | |||
| 18 | @property | ||
| 19 | def path(self): | ||
| 20 | return self._data[0] | ||
| 21 | |||
| 22 | @property | ||
| 23 | def num_frames(self): | ||
| 24 | return int(self._data[1]) | ||
| 25 | |||
| 26 | |||
| 27 | class TSNDataSet(data.Dataset): | ||
| 28 | def __init__(self, root_path, list_file, | ||
| 29 | num_segments=3, new_length=1, modality='RGB', | ||
| 30 | image_tmpl='img_{:05d}.jpg', transform=None, | ||
| 31 | random_shift=True, test_mode=False, | ||
| 32 | remove_missing=False, dense_sample=False, twice_sample=False): | ||
| 33 | |||
| 34 | self.root_path = root_path | ||
| 35 | self.list_file = list_file | ||
| 36 | self.num_segments = num_segments | ||
| 37 | self.new_length = new_length | ||
| 38 | self.modality = modality | ||
| 39 | self.image_tmpl = image_tmpl | ||
| 40 | self.transform = transform | ||
| 41 | self.random_shift = random_shift | ||
| 42 | self.test_mode = test_mode | ||
| 43 | self.remove_missing = remove_missing | ||
| 44 | self.dense_sample = dense_sample # using dense sample as I3D | ||
| 45 | self.twice_sample = twice_sample # twice sample for more validation | ||
| 46 | if self.dense_sample: | ||
| 47 | print('=> Using dense sample for the dataset...') | ||
| 48 | if self.twice_sample: | ||
| 49 | print('=> Using twice sample for the dataset...') | ||
| 50 | |||
| 51 | if self.modality == 'RGBDiff': | ||
| 52 | self.new_length += 1 # Diff needs one more image to calculate diff | ||
| 53 | |||
| 54 | self._parse_list() | ||
| 55 | |||
| 56 | def _load_image(self, directory, idx): | ||
| 57 | if self.modality == 'RGB' or self.modality == 'RGBDiff': | ||
| 58 | try: | ||
| 59 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')] | ||
| 60 | except Exception: | ||
| 61 | print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx))) | ||
| 62 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')] | ||
| 63 | elif self.modality == 'Flow': | ||
| 64 | if self.image_tmpl == 'flow_{}_{:05d}.jpg': # ucf | ||
| 65 | x_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('x', idx))).convert( | ||
| 66 | 'L') | ||
| 67 | y_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('y', idx))).convert( | ||
| 68 | 'L') | ||
| 69 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg': # something v1 flow | ||
| 70 | x_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)), self.image_tmpl. | ||
| 71 | format(int(directory), 'x', idx))).convert('L') | ||
| 72 | y_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)), self.image_tmpl. | ||
| 73 | format(int(directory), 'y', idx))).convert('L') | ||
| 74 | else: | ||
| 75 | try: | ||
| 76 | # idx_skip = 1 + (idx-1)*5 | ||
| 77 | flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert( | ||
| 78 | 'RGB') | ||
| 79 | except Exception: | ||
| 80 | print('error loading flow file:', | ||
| 81 | os.path.join(self.root_path, directory, self.image_tmpl.format(idx))) | ||
| 82 | flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB') | ||
| 83 | # the input flow file is an RGB image whose channels are (flow_x, flow_y, blank) | ||
| 84 | flow_x, flow_y, _ = flow.split() | ||
| 85 | x_img = flow_x.convert('L') | ||
| 86 | y_img = flow_y.convert('L') | ||
| 87 | |||
| 88 | return [x_img, y_img] | ||
| 89 | |||
| 90 | def _parse_list(self): | ||
| 91 | # keep only videos with at least 3 frames | ||
| 92 | tmp = [x.strip().split(' ') for x in open(self.list_file)] | ||
| 93 | if not self.test_mode or self.remove_missing: | ||
| 94 | tmp = [item for item in tmp if int(item[1]) >= 3] | ||
| 95 | self.video_list = [VideoRecord(item) for item in tmp] | ||
| 96 | |||
| 97 | if self.image_tmpl == '{:06d}-{}_{:05d}.jpg': | ||
| 98 | for v in self.video_list: | ||
| 99 | v._data[1] = int(v._data[1]) // 2  # halve the listed frame count; keep it an integer | ||
| 100 | print('video number:%d' % (len(self.video_list))) | ||
| 101 | |||
| 102 | def _sample_indices(self, record): | ||
| 103 | """ | ||
| 104 | |||
| 105 | :param record: VideoRecord | ||
| 106 | :return: list | ||
| 107 | """ | ||
| 108 | if self.dense_sample: # i3d dense sample | ||
| 109 | sample_pos = max(1, 1 + record.num_frames - 64) | ||
| 110 | t_stride = 64 // self.num_segments | ||
| 111 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) | ||
| 112 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)] | ||
| 113 | return np.array(offsets) + 1 | ||
| 114 | else: # normal sample | ||
| 115 | average_duration = (record.num_frames - self.new_length + 1) // self.num_segments | ||
| 116 | if average_duration > 0: | ||
| 117 | offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, | ||
| 118 | size=self.num_segments) | ||
| 119 | elif record.num_frames > self.num_segments: | ||
| 120 | offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments)) | ||
| 121 | else: | ||
| 122 | offsets = np.zeros((self.num_segments,)) | ||
| 123 | return offsets + 1 | ||
| 124 | |||
| 125 | def _get_val_indices(self, record): | ||
| 126 | if self.dense_sample: # i3d dense sample | ||
| 127 | sample_pos = max(1, 1 + record.num_frames - 64) | ||
| 128 | t_stride = 64 // self.num_segments | ||
| 129 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) | ||
| 130 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)] | ||
| 131 | return np.array(offsets) + 1 | ||
| 132 | else: | ||
| 133 | if record.num_frames > self.num_segments + self.new_length - 1: | ||
| 134 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) | ||
| 135 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) | ||
| 136 | else: | ||
| 137 | offsets = np.zeros((self.num_segments,)) | ||
| 138 | return offsets + 1 | ||
| 139 | |||
| 140 | def _get_test_indices(self, record): | ||
| 141 | if self.dense_sample: | ||
| 142 | sample_pos = max(1, 1 + record.num_frames - 64) | ||
| 143 | t_stride = 64 // self.num_segments | ||
| 144 | start_list = np.linspace(0, sample_pos - 1, num=10, dtype=int) | ||
| 145 | offsets = [] | ||
| 146 | for start_idx in start_list.tolist(): | ||
| 147 | offsets += [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)] | ||
| 148 | return np.array(offsets) + 1 | ||
| 149 | elif self.twice_sample: | ||
| 150 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) | ||
| 151 | |||
| 152 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)] + | ||
| 153 | [int(tick * x) for x in range(self.num_segments)]) | ||
| 154 | |||
| 155 | return offsets + 1 | ||
| 156 | else: | ||
| 157 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) | ||
| 158 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) | ||
| 159 | return offsets + 1 | ||
| 160 | |||
| 161 | def __getitem__(self, index): | ||
| 162 | record = self.video_list[index] | ||
| 163 | # check this is a legit video folder | ||
| 164 | |||
| 165 | if self.image_tmpl == 'flow_{}_{:05d}.jpg': | ||
| 166 | file_name = self.image_tmpl.format('x', 1) | ||
| 167 | full_path = os.path.join(self.root_path, record.path, file_name) | ||
| 168 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg': | ||
| 169 | file_name = self.image_tmpl.format(int(record.path), 'x', 1) | ||
| 170 | full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name) | ||
| 171 | else: | ||
| 172 | file_name = self.image_tmpl.format(1) | ||
| 173 | full_path = os.path.join(self.root_path, record.path, file_name) | ||
| 174 | |||
| 175 | while not os.path.exists(full_path): | ||
| 176 | print('################## Not Found:', os.path.join(self.root_path, record.path, file_name)) | ||
| 177 | index = np.random.randint(len(self.video_list)) | ||
| 178 | record = self.video_list[index] | ||
| 179 | if self.image_tmpl == 'flow_{}_{:05d}.jpg': | ||
| 180 | file_name = self.image_tmpl.format('x', 1) | ||
| 181 | full_path = os.path.join(self.root_path, record.path, file_name) | ||
| 182 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg': | ||
| 183 | file_name = self.image_tmpl.format(int(record.path), 'x', 1) | ||
| 184 | full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name) | ||
| 185 | else: | ||
| 186 | file_name = self.image_tmpl.format(1) | ||
| 187 | full_path = os.path.join(self.root_path, record.path, file_name) | ||
| 188 | |||
| 189 | if not self.test_mode: | ||
| 190 | segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record) | ||
| 191 | else: | ||
| 192 | segment_indices = self._get_test_indices(record) | ||
| 193 | return self.get(record, segment_indices) | ||
| 194 | |||
| 195 | def get(self, record, indices): | ||
| 196 | |||
| 197 | images = list() | ||
| 198 | for seg_ind in indices: | ||
| 199 | p = int(seg_ind) | ||
| 200 | for i in range(self.new_length): | ||
| 201 | seg_imgs = self._load_image(record.path, p) | ||
| 202 | images.extend(seg_imgs) | ||
| 203 | if p < record.num_frames: | ||
| 204 | p += 1 | ||
| 205 | |||
| 206 | process_data = self.transform(images) | ||
| 207 | return record.path, process_data | ||
| 208 | |||
| 209 | def __len__(self): | ||
| 210 | return len(self.video_list) |
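To make the sampling arithmetic in _get_test_indices concrete, here is the default branch (neither dense nor twice sample) computed standalone; the frame and segment counts are made-up values.

    import numpy as np

    num_frames, num_segments, new_length = 120, 8, 1   # illustrative values
    tick = (num_frames - new_length + 1) / float(num_segments)
    offsets = np.array([int(tick / 2.0 + tick * x) for x in range(num_segments)]) + 1
    print(offsets)   # [  8  23  38  53  68  83  98 113], evenly spaced 1-based frame indices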
ops/dataset_config.py
0 → 100755
| 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" | ||
| 2 | # arXiv:1811.08383 | ||
| 3 | # Ji Lin*, Chuang Gan, Song Han | ||
| 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu | ||
| 5 | |||
| 6 | import os | ||
| 7 | |||
| 8 | ROOT_DATASET = '/data1/action_1_images/' # '/data/jilin/' | ||
| 9 | |||
| 10 | |||
| 11 | def return_ucf101(modality): | ||
| 12 | filename_categories = 'labels/classInd.txt' | ||
| 13 | if modality == 'RGB': | ||
| 14 | root_data = ROOT_DATASET + 'images' | ||
| 15 | filename_imglist_train = 'file_list/ucf101_rgb_train_split_1.txt' | ||
| 16 | filename_imglist_val = 'file_list/ucf101_rgb_val_split_1.txt' | ||
| 17 | prefix = 'img_{:05d}.jpg' | ||
| 18 | elif modality == 'Flow': | ||
| 19 | root_data = ROOT_DATASET + 'UCF101/jpg' | ||
| 20 | filename_imglist_train = 'UCF101/file_list/ucf101_flow_train_split_1.txt' | ||
| 21 | filename_imglist_val = 'UCF101/file_list/ucf101_flow_val_split_1.txt' | ||
| 22 | prefix = 'flow_{}_{:05d}.jpg' | ||
| 23 | else: | ||
| 24 | raise NotImplementedError('no such modality:' + modality) | ||
| 25 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix | ||
| 26 | |||
| 27 | |||
| 28 | def return_hmdb51(modality): | ||
| 29 | filename_categories = 51 | ||
| 30 | if modality == 'RGB': | ||
| 31 | root_data = ROOT_DATASET + 'HMDB51/images' | ||
| 32 | filename_imglist_train = 'HMDB51/splits/hmdb51_rgb_train_split_1.txt' | ||
| 33 | filename_imglist_val = 'HMDB51/splits/hmdb51_rgb_val_split_1.txt' | ||
| 34 | prefix = 'img_{:05d}.jpg' | ||
| 35 | elif modality == 'Flow': | ||
| 36 | root_data = ROOT_DATASET + 'HMDB51/images' | ||
| 37 | filename_imglist_train = 'HMDB51/splits/hmdb51_flow_train_split_1.txt' | ||
| 38 | filename_imglist_val = 'HMDB51/splits/hmdb51_flow_val_split_1.txt' | ||
| 39 | prefix = 'flow_{}_{:05d}.jpg' | ||
| 40 | else: | ||
| 41 | raise NotImplementedError('no such modality:' + modality) | ||
| 42 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix | ||
| 43 | |||
| 44 | |||
| 45 | def return_something(modality): | ||
| 46 | filename_categories = 'something/v1/category.txt' | ||
| 47 | if modality == 'RGB': | ||
| 48 | root_data = ROOT_DATASET + 'something/v1/20bn-something-something-v1' | ||
| 49 | filename_imglist_train = 'something/v1/train_videofolder.txt' | ||
| 50 | filename_imglist_val = 'something/v1/val_videofolder.txt' | ||
| 51 | prefix = '{:05d}.jpg' | ||
| 52 | elif modality == 'Flow': | ||
| 53 | root_data = ROOT_DATASET + 'something/v1/20bn-something-something-v1-flow' | ||
| 54 | filename_imglist_train = 'something/v1/train_videofolder_flow.txt' | ||
| 55 | filename_imglist_val = 'something/v1/val_videofolder_flow.txt' | ||
| 56 | prefix = '{:06d}-{}_{:05d}.jpg' | ||
| 57 | else: | ||
| 58 | print('no such modality:'+modality) | ||
| 59 | raise NotImplementedError | ||
| 60 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix | ||
| 61 | |||
| 62 | |||
| 63 | def return_somethingv2(modality): | ||
| 64 | filename_categories = 'something/v2/category.txt' | ||
| 65 | if modality == 'RGB': | ||
| 66 | root_data = ROOT_DATASET + 'something/v2/20bn-something-something-v2-frames' | ||
| 67 | filename_imglist_train = 'something/v2/train_videofolder.txt' | ||
| 68 | filename_imglist_val = 'something/v2/val_videofolder.txt' | ||
| 69 | prefix = '{:06d}.jpg' | ||
| 70 | elif modality == 'Flow': | ||
| 71 | root_data = ROOT_DATASET + 'something/v2/20bn-something-something-v2-flow' | ||
| 72 | filename_imglist_train = 'something/v2/train_videofolder_flow.txt' | ||
| 73 | filename_imglist_val = 'something/v2/val_videofolder_flow.txt' | ||
| 74 | prefix = '{:06d}.jpg' | ||
| 75 | else: | ||
| 76 | raise NotImplementedError('no such modality:'+modality) | ||
| 77 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix | ||
| 78 | |||
| 79 | |||
| 80 | def return_jester(modality): | ||
| 81 | filename_categories = 'jester/category.txt' | ||
| 82 | if modality == 'RGB': | ||
| 83 | prefix = '{:05d}.jpg' | ||
| 84 | root_data = ROOT_DATASET + 'jester/20bn-jester-v1' | ||
| 85 | filename_imglist_train = 'jester/train_videofolder.txt' | ||
| 86 | filename_imglist_val = 'jester/val_videofolder.txt' | ||
| 87 | else: | ||
| 88 | raise NotImplementedError('no such modality:'+modality) | ||
| 89 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix | ||
| 90 | |||
| 91 | |||
| 92 | def return_kinetics(modality): | ||
| 93 | filename_categories = 400 | ||
| 94 | if modality == 'RGB': | ||
| 95 | root_data = ROOT_DATASET + 'kinetics/images' | ||
| 96 | filename_imglist_train = 'kinetics/labels/train_videofolder.txt' | ||
| 97 | filename_imglist_val = 'kinetics/labels/val_videofolder.txt' | ||
| 98 | prefix = 'img_{:05d}.jpg' | ||
| 99 | else: | ||
| 100 | raise NotImplementedError('no such modality:' + modality) | ||
| 101 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix | ||
| 102 | |||
| 103 | |||
| 104 | def return_dataset(dataset, modality): | ||
| 105 | dict_single = {'jester': return_jester, 'something': return_something, 'somethingv2': return_somethingv2, | ||
| 106 | 'ucf101': return_ucf101, 'hmdb51': return_hmdb51, | ||
| 107 | 'kinetics': return_kinetics} | ||
| 108 | if dataset in dict_single: | ||
| 109 | file_categories, file_imglist_train, file_imglist_val, root_data, prefix = dict_single[dataset](modality) | ||
| 110 | else: | ||
| 111 | raise ValueError('Unknown dataset '+dataset) | ||
| 112 | |||
| 113 | file_imglist_train = os.path.join(ROOT_DATASET, file_imglist_train) | ||
| 114 | file_imglist_val = os.path.join(ROOT_DATASET, file_imglist_val) | ||
| 115 | if isinstance(file_categories, str): | ||
| 116 | file_categories = os.path.join(ROOT_DATASET, file_categories) | ||
| 117 | with open(file_categories) as f: | ||
| 118 | lines = f.readlines() | ||
| 119 | categories = [item.rstrip() for item in lines] | ||
| 120 | else: # number of categories | ||
| 121 | categories = [None] * file_categories | ||
| 122 | n_class = len(categories) | ||
| 123 | print('{}: {} classes'.format(dataset, n_class)) | ||
| 124 | return n_class, file_imglist_train, file_imglist_val, root_data, prefix |
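A small usage sketch of return_dataset. For datasets whose category entry is an integer (hmdb51, kinetics) no files are opened, so this runs even before any data is in place, assuming the ops package is importable.

    from ops.dataset_config import return_dataset

    # hmdb51's category entry is the integer 51, so only paths are assembled here
    n_class, train_list, val_list, root_data, prefix = return_dataset('hmdb51', 'RGB')
    print(n_class)     # 51
    print(prefix)      # img_{:05d}.jpg
    print(root_data)   # ROOT_DATASET + 'HMDB51/images'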
ops/models.py
0 → 100755
| 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" | ||
| 2 | # arXiv:1811.08383 | ||
| 3 | # Ji Lin*, Chuang Gan, Song Han | ||
| 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu | ||
| 5 | |||
| 6 | from torch import nn | ||
| 7 | |||
| 8 | from ops.basic_ops import ConsensusModule | ||
| 9 | from ops.transforms import * | ||
| 10 | from torch.nn.init import normal_, constant_ | ||
| 11 | |||
| 12 | |||
| 13 | class TSN(nn.Module): | ||
| 14 | def __init__(self, num_class, num_segments, modality, | ||
| 15 | base_model='resnet101', new_length=None, | ||
| 16 | consensus_type='avg', before_softmax=True, | ||
| 17 | dropout=0.8, img_feature_dim=256, | ||
| 18 | crop_num=1, partial_bn=True, print_spec=True, pretrain='imagenet', | ||
| 19 | is_shift=True, shift_div=8, shift_place='blockres', fc_lr5=False, | ||
| 20 | temporal_pool=False, non_local=False): | ||
| 21 | super(TSN, self).__init__() | ||
| 22 | self.modality = modality | ||
| 23 | self.num_segments = num_segments | ||
| 24 | self.reshape = True | ||
| 25 | self.before_softmax = before_softmax | ||
| 26 | self.dropout = dropout | ||
| 27 | self.crop_num = crop_num | ||
| 28 | self.consensus_type = consensus_type | ||
| 29 | self.img_feature_dim = img_feature_dim # the dimension of the CNN feature to represent each frame | ||
| 30 | self.pretrain = pretrain | ||
| 31 | self.print_spec = print_spec  # stored so _prepare_base_model can check it later | ||
| 32 | self.is_shift = is_shift | ||
| 33 | self.shift_div = shift_div | ||
| 34 | self.shift_place = shift_place | ||
| 35 | self.base_model_name = base_model | ||
| 36 | self.fc_lr5 = fc_lr5 | ||
| 37 | self.temporal_pool = temporal_pool | ||
| 38 | self.non_local = non_local | ||
| 39 | |||
| 40 | if not before_softmax and consensus_type != 'avg': | ||
| 41 | raise ValueError("Only avg consensus can be used after Softmax") | ||
| 42 | |||
| 43 | if new_length is None: | ||
| 44 | self.new_length = 1 if modality == "RGB" else 5 | ||
| 45 | else: | ||
| 46 | self.new_length = new_length | ||
| 47 | if print_spec: | ||
| 48 | print((""" | ||
| 49 | Initializing TSN with base model: {}. | ||
| 50 | TSN Configurations: | ||
| 51 | input_modality: {} | ||
| 52 | num_segments: {} | ||
| 53 | new_length: {} | ||
| 54 | consensus_module: {} | ||
| 55 | dropout_ratio: {} | ||
| 56 | img_feature_dim: {} | ||
| 57 | """.format(base_model, self.modality, self.num_segments, self.new_length, consensus_type, self.dropout, self.img_feature_dim))) | ||
| 58 | |||
| 59 | self._prepare_base_model(base_model) | ||
| 60 | |||
| 61 | feature_dim = self._prepare_tsn(num_class) | ||
| 62 | |||
| 63 | if self.modality == 'Flow': | ||
| 64 | print("Converting the ImageNet model to a flow init model") | ||
| 65 | self.base_model = self._construct_flow_model(self.base_model) | ||
| 66 | print("Done. Flow model ready...") | ||
| 67 | elif self.modality == 'RGBDiff': | ||
| 68 | print("Converting the ImageNet model to RGB+Diff init model") | ||
| 69 | self.base_model = self._construct_diff_model(self.base_model) | ||
| 70 | print("Done. RGBDiff model ready.") | ||
| 71 | |||
| 72 | self.consensus = ConsensusModule(consensus_type) | ||
| 73 | |||
| 74 | if not self.before_softmax: | ||
| 75 | self.softmax = nn.Softmax() | ||
| 76 | |||
| 77 | self._enable_pbn = partial_bn | ||
| 78 | if partial_bn: | ||
| 79 | self.partialBN(True) | ||
| 80 | |||
| 81 | def _prepare_tsn(self, num_class): | ||
| 82 | feature_dim = getattr(self.base_model, self.base_model.last_layer_name).in_features | ||
| 83 | if self.dropout == 0: | ||
| 84 | setattr(self.base_model, self.base_model.last_layer_name, nn.Linear(feature_dim, num_class)) | ||
| 85 | self.new_fc = None | ||
| 86 | else: | ||
| 87 | setattr(self.base_model, self.base_model.last_layer_name, nn.Dropout(p=self.dropout)) | ||
| 88 | self.new_fc = nn.Linear(feature_dim, num_class) | ||
| 89 | |||
| 90 | std = 0.001 | ||
| 91 | if self.new_fc is None: | ||
| 92 | normal_(getattr(self.base_model, self.base_model.last_layer_name).weight, 0, std) | ||
| 93 | constant_(getattr(self.base_model, self.base_model.last_layer_name).bias, 0) | ||
| 94 | else: | ||
| 95 | if hasattr(self.new_fc, 'weight'): | ||
| 96 | normal_(self.new_fc.weight, 0, std) | ||
| 97 | constant_(self.new_fc.bias, 0) | ||
| 98 | return feature_dim | ||
| 99 | |||
| 100 | def _prepare_base_model(self, base_model): | ||
| 101 | print('=> base model: {}'.format(base_model)) | ||
| 102 | |||
| 103 | if 'resnet' in base_model: | ||
| 104 | self.base_model = getattr(torchvision.models, base_model)(True if self.pretrain == 'imagenet' else False) | ||
| 105 | if self.is_shift: | ||
| 106 | print('Adding temporal shift...') | ||
| 107 | from ops.temporal_shift import make_temporal_shift | ||
| 108 | make_temporal_shift(self.base_model, self.num_segments, | ||
| 109 | n_div=self.shift_div, place=self.shift_place, temporal_pool=self.temporal_pool) | ||
| 110 | |||
| 111 | if self.non_local: | ||
| 112 | print('Adding non-local module...') | ||
| 113 | from ops.non_local import make_non_local | ||
| 114 | make_non_local(self.base_model, self.num_segments) | ||
| 115 | |||
| 116 | self.base_model.last_layer_name = 'fc' | ||
| 117 | self.input_size = 224 | ||
| 118 | self.input_mean = [0.485, 0.456, 0.406] | ||
| 119 | self.input_std = [0.229, 0.224, 0.225] | ||
| 120 | |||
| 121 | self.base_model.avgpool = nn.AdaptiveAvgPool2d(1) | ||
| 122 | |||
| 123 | if self.modality == 'Flow': | ||
| 124 | self.input_mean = [0.5] | ||
| 125 | self.input_std = [np.mean(self.input_std)] | ||
| 126 | elif self.modality == 'RGBDiff': | ||
| 127 | self.input_mean = [0.485, 0.456, 0.406] + [0] * 3 * self.new_length | ||
| 128 | self.input_std = self.input_std + [np.mean(self.input_std) * 2] * 3 * self.new_length | ||
| 129 | |||
| 130 | elif base_model == 'mobilenetv2': | ||
| 131 | from archs.mobilenet_v2 import mobilenet_v2, InvertedResidual | ||
| 132 | self.base_model = mobilenet_v2(True if self.pretrain == 'imagenet' else False) | ||
| 133 | |||
| 134 | self.base_model.last_layer_name = 'classifier' | ||
| 135 | self.input_size = 224 | ||
| 136 | self.input_mean = [0.485, 0.456, 0.406] | ||
| 137 | self.input_std = [0.229, 0.224, 0.225] | ||
| 138 | |||
| 139 | self.base_model.avgpool = nn.AdaptiveAvgPool2d(1) | ||
| 140 | if self.is_shift: | ||
| 141 | from ops.temporal_shift import TemporalShift | ||
| 142 | for m in self.base_model.modules(): | ||
| 143 | if isinstance(m, InvertedResidual) and len(m.conv) == 8 and m.use_res_connect: | ||
| 144 | if self.print_spec: | ||
| 145 | print('Adding temporal shift... {}'.format(m.use_res_connect)) | ||
| 146 | m.conv[0] = TemporalShift(m.conv[0], n_segment=self.num_segments, n_div=self.shift_div) | ||
| 147 | if self.modality == 'Flow': | ||
| 148 | self.input_mean = [0.5] | ||
| 149 | self.input_std = [np.mean(self.input_std)] | ||
| 150 | elif self.modality == 'RGBDiff': | ||
| 151 | self.input_mean = [0.485, 0.456, 0.406] + [0] * 3 * self.new_length | ||
| 152 | self.input_std = self.input_std + [np.mean(self.input_std) * 2] * 3 * self.new_length | ||
| 153 | |||
| 154 | elif base_model == 'BNInception': | ||
| 155 | from archs.bn_inception import bninception | ||
| 156 | self.base_model = bninception(pretrained=self.pretrain) | ||
| 157 | self.input_size = self.base_model.input_size | ||
| 158 | self.input_mean = self.base_model.mean | ||
| 159 | self.input_std = self.base_model.std | ||
| 160 | self.base_model.last_layer_name = 'fc' | ||
| 161 | if self.modality == 'Flow': | ||
| 162 | self.input_mean = [128] | ||
| 163 | elif self.modality == 'RGBDiff': | ||
| 164 | self.input_mean = self.input_mean * (1 + self.new_length) | ||
| 165 | if self.is_shift: | ||
| 166 | print('Adding temporal shift...') | ||
| 167 | self.base_model.build_temporal_ops( | ||
| 168 | self.num_segments, is_temporal_shift=self.shift_place, shift_div=self.shift_div) | ||
| 169 | else: | ||
| 170 | raise ValueError('Unknown base model: {}'.format(base_model)) | ||
| 171 | |||
| 172 | def train(self, mode=True): | ||
| 173 | """ | ||
| 174 | Override the default train() to freeze the BN parameters | ||
| 175 | :return: | ||
| 176 | """ | ||
| 177 | super(TSN, self).train(mode) | ||
| 178 | count = 0 | ||
| 179 | if self._enable_pbn and mode: | ||
| 180 | print("Freezing BatchNorm2D except the first one.") | ||
| 181 | for m in self.base_model.modules(): | ||
| 182 | if isinstance(m, nn.BatchNorm2d): | ||
| 183 | count += 1 | ||
| 184 | if count >= (2 if self._enable_pbn else 1): | ||
| 185 | m.eval() | ||
| 186 | # shutdown update in frozen mode | ||
| 187 | m.weight.requires_grad = False | ||
| 188 | m.bias.requires_grad = False | ||
| 189 | |||
| 190 | def partialBN(self, enable): | ||
| 191 | self._enable_pbn = enable | ||
| 192 | |||
| 193 | def get_optim_policies(self): | ||
| 194 | first_conv_weight = [] | ||
| 195 | first_conv_bias = [] | ||
| 196 | normal_weight = [] | ||
| 197 | normal_bias = [] | ||
| 198 | lr5_weight = [] | ||
| 199 | lr10_bias = [] | ||
| 200 | bn = [] | ||
| 201 | custom_ops = [] | ||
| 202 | |||
| 203 | conv_cnt = 0 | ||
| 204 | bn_cnt = 0 | ||
| 205 | for m in self.modules(): | ||
| 206 | if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv3d): | ||
| 207 | ps = list(m.parameters()) | ||
| 208 | conv_cnt += 1 | ||
| 209 | if conv_cnt == 1: | ||
| 210 | first_conv_weight.append(ps[0]) | ||
| 211 | if len(ps) == 2: | ||
| 212 | first_conv_bias.append(ps[1]) | ||
| 213 | else: | ||
| 214 | normal_weight.append(ps[0]) | ||
| 215 | if len(ps) == 2: | ||
| 216 | normal_bias.append(ps[1]) | ||
| 217 | elif isinstance(m, torch.nn.Linear): | ||
| 218 | ps = list(m.parameters()) | ||
| 219 | if self.fc_lr5: | ||
| 220 | lr5_weight.append(ps[0]) | ||
| 221 | else: | ||
| 222 | normal_weight.append(ps[0]) | ||
| 223 | if len(ps) == 2: | ||
| 224 | if self.fc_lr5: | ||
| 225 | lr10_bias.append(ps[1]) | ||
| 226 | else: | ||
| 227 | normal_bias.append(ps[1]) | ||
| 228 | |||
| 229 | elif isinstance(m, torch.nn.BatchNorm2d): | ||
| 230 | bn_cnt += 1 | ||
| 231 | # later BN's are frozen | ||
| 232 | if not self._enable_pbn or bn_cnt == 1: | ||
| 233 | bn.extend(list(m.parameters())) | ||
| 234 | elif isinstance(m, torch.nn.BatchNorm3d): | ||
| 235 | bn_cnt += 1 | ||
| 236 | # later BN's are frozen | ||
| 237 | if not self._enable_pbn or bn_cnt == 1: | ||
| 238 | bn.extend(list(m.parameters())) | ||
| 239 | elif len(m._modules) == 0: | ||
| 240 | if len(list(m.parameters())) > 0: | ||
| 241 | raise ValueError("New atomic module type: {}. Need to give it a learning policy".format(type(m))) | ||
| 242 | |||
| 243 | return [ | ||
| 244 | {'params': first_conv_weight, 'lr_mult': 5 if self.modality == 'Flow' else 1, 'decay_mult': 1, | ||
| 245 | 'name': "first_conv_weight"}, | ||
| 246 | {'params': first_conv_bias, 'lr_mult': 10 if self.modality == 'Flow' else 2, 'decay_mult': 0, | ||
| 247 | 'name': "first_conv_bias"}, | ||
| 248 | {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1, | ||
| 249 | 'name': "normal_weight"}, | ||
| 250 | {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0, | ||
| 251 | 'name': "normal_bias"}, | ||
| 252 | {'params': bn, 'lr_mult': 1, 'decay_mult': 0, | ||
| 253 | 'name': "BN scale/shift"}, | ||
| 254 | {'params': custom_ops, 'lr_mult': 1, 'decay_mult': 1, | ||
| 255 | 'name': "custom_ops"}, | ||
| 256 | # for fc | ||
| 257 | {'params': lr5_weight, 'lr_mult': 5, 'decay_mult': 1, | ||
| 258 | 'name': "lr5_weight"}, | ||
| 259 | {'params': lr10_bias, 'lr_mult': 10, 'decay_mult': 0, | ||
| 260 | 'name': "lr10_bias"}, | ||
| 261 | ] | ||
| 262 | |||
| 263 | def forward(self, input, no_reshape=False): | ||
| 264 | if not no_reshape: | ||
| 265 | sample_len = (3 if self.modality == "RGB" else 2) * self.new_length | ||
| 266 | |||
| 267 | if self.modality == 'RGBDiff': | ||
| 268 | sample_len = 3 * self.new_length | ||
| 269 | input = self._get_diff(input) | ||
| 270 | |||
| 271 | base_out = self.base_model(input.view((-1, sample_len) + input.size()[-2:])) | ||
| 272 | else: | ||
| 273 | base_out = self.base_model(input) | ||
| 274 | |||
| 275 | if self.dropout > 0: | ||
| 276 | feature = base_out.view(base_out.size(0), -1) | ||
| 277 | base_out = self.new_fc(base_out) | ||
| 278 | |||
| 279 | if not self.before_softmax: | ||
| 280 | base_out = self.softmax(base_out) | ||
| 281 | |||
| 282 | if self.reshape: | ||
| 283 | if self.is_shift and self.temporal_pool: | ||
| 284 | base_out = base_out.view((-1, self.num_segments // 2) + base_out.size()[1:]) | ||
| 285 | else: | ||
| 286 | base_out = base_out.view((-1, self.num_segments) + base_out.size()[1:]) | ||
| 287 | output = self.consensus(base_out) | ||
| 288 | return output.squeeze(1), feature  # feature is only defined when dropout > 0 (the default here is 0.8) | ||
| 289 | |||
| 290 | def _get_diff(self, input, keep_rgb=False): | ||
| 291 | input_c = 3 if self.modality in ["RGB", "RGBDiff"] else 2 | ||
| 292 | input_view = input.view((-1, self.num_segments, self.new_length + 1, input_c,) + input.size()[2:]) | ||
| 293 | if keep_rgb: | ||
| 294 | new_data = input_view.clone() | ||
| 295 | else: | ||
| 296 | new_data = input_view[:, :, 1:, :, :, :].clone() | ||
| 297 | |||
| 298 | for x in reversed(list(range(1, self.new_length + 1))): | ||
| 299 | if keep_rgb: | ||
| 300 | new_data[:, :, x, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :] | ||
| 301 | else: | ||
| 302 | new_data[:, :, x - 1, :, :, :] = input_view[:, :, x, :, :, :] - input_view[:, :, x - 1, :, :, :] | ||
| 303 | |||
| 304 | return new_data | ||
| 305 | |||
| 306 | def _construct_flow_model(self, base_model): | ||
| 307 | # modify the convolution layers | ||
| 308 | # Torch models are usually defined in a hierarchical way. | ||
| 309 | # nn.modules.children() returns all submodules in a DFS manner | ||
| 310 | modules = list(self.base_model.modules()) | ||
| 311 | first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0] | ||
| 312 | conv_layer = modules[first_conv_idx] | ||
| 313 | container = modules[first_conv_idx - 1] | ||
| 314 | |||
| 315 | # modify parameters, assume the first blob contains the convolution kernels | ||
| 316 | params = [x.clone() for x in conv_layer.parameters()] | ||
| 317 | kernel_size = params[0].size() | ||
| 318 | new_kernel_size = kernel_size[:1] + (2 * self.new_length, ) + kernel_size[2:] | ||
| 319 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous() | ||
| 320 | |||
| 321 | new_conv = nn.Conv2d(2 * self.new_length, conv_layer.out_channels, | ||
| 322 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding, | ||
| 323 | bias=True if len(params) == 2 else False) | ||
| 324 | new_conv.weight.data = new_kernels | ||
| 325 | if len(params) == 2: | ||
| 326 | new_conv.bias.data = params[1].data # add bias if necessary | ||
| 327 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name | ||
| 328 | |||
| 329 | # replace the first convolution layer | ||
| 330 | setattr(container, layer_name, new_conv) | ||
| 331 | |||
| 332 | if self.base_model_name == 'BNInception': | ||
| 333 | import torch.utils.model_zoo as model_zoo | ||
| 334 | sd = model_zoo.load_url('https://www.dropbox.com/s/35ftw2t4mxxgjae/BNInceptionFlow-ef652051.pth.tar?dl=1') | ||
| 335 | base_model.load_state_dict(sd) | ||
| 336 | print('=> Loading pretrained Flow weight done...') | ||
| 337 | else: | ||
| 338 | print('#' * 30, 'Warning! No Flow pretrained model is found') | ||
| 339 | return base_model | ||
| 340 | |||
| 341 | def _construct_diff_model(self, base_model, keep_rgb=False): | ||
| 342 | # modify the convolution layers | ||
| 343 | # Torch models are usually defined in a hierarchical way. | ||
| 344 | # nn.modules.children() returns all submodules in a DFS manner | ||
| 345 | modules = list(self.base_model.modules()) | ||
| 346 | first_conv_idx = list(filter(lambda x: isinstance(modules[x], nn.Conv2d), list(range(len(modules)))))[0] | ||
| 347 | conv_layer = modules[first_conv_idx] | ||
| 348 | container = modules[first_conv_idx - 1] | ||
| 349 | |||
| 350 | # modify parameters, assume the first blob contains the convolution kernels | ||
| 351 | params = [x.clone() for x in conv_layer.parameters()] | ||
| 352 | kernel_size = params[0].size() | ||
| 353 | if not keep_rgb: | ||
| 354 | new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:] | ||
| 355 | new_kernels = params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous() | ||
| 356 | else: | ||
| 357 | new_kernel_size = kernel_size[:1] + (3 * self.new_length,) + kernel_size[2:] | ||
| 358 | new_kernels = torch.cat((params[0].data, params[0].data.mean(dim=1, keepdim=True).expand(new_kernel_size).contiguous()), | ||
| 359 | 1) | ||
| 360 | new_kernel_size = kernel_size[:1] + (3 + 3 * self.new_length,) + kernel_size[2:] | ||
| 361 | |||
| 362 | new_conv = nn.Conv2d(new_kernel_size[1], conv_layer.out_channels, | ||
| 363 | conv_layer.kernel_size, conv_layer.stride, conv_layer.padding, | ||
| 364 | bias=True if len(params) == 2 else False) | ||
| 365 | new_conv.weight.data = new_kernels | ||
| 366 | if len(params) == 2: | ||
| 367 | new_conv.bias.data = params[1].data # add bias if necessary | ||
| 368 | layer_name = list(container.state_dict().keys())[0][:-7] # remove .weight suffix to get the layer name | ||
| 369 | |||
| 370 | # replace the first convolution layer | ||
| 371 | setattr(container, layer_name, new_conv) | ||
| 372 | return base_model | ||
| 373 | |||
| 374 | @property | ||
| 375 | def crop_size(self): | ||
| 376 | return self.input_size | ||
| 377 | |||
| 378 | @property | ||
| 379 | def scale_size(self): | ||
| 380 | return self.input_size * 256 // 224 | ||
| 381 | |||
| 382 | def get_augmentation(self, flip=True): | ||
| 383 | if self.modality == 'RGB': | ||
| 384 | if flip: | ||
| 385 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66]), | ||
| 386 | GroupRandomHorizontalFlip(is_flow=False)]) | ||
| 387 | else: | ||
| 388 | print('#' * 20, 'NO FLIP!!!') | ||
| 389 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75, .66])]) | ||
| 390 | elif self.modality == 'Flow': | ||
| 391 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), | ||
| 392 | GroupRandomHorizontalFlip(is_flow=True)]) | ||
| 393 | elif self.modality == 'RGBDiff': | ||
| 394 | return torchvision.transforms.Compose([GroupMultiScaleCrop(self.input_size, [1, .875, .75]), | ||
| 395 | GroupRandomHorizontalFlip(is_flow=False)]) |
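A minimal sketch of building the TSN above and running a dummy forward pass. It assumes a torchvision version that still accepts the boolean pretrained flag; pretrain='' skips the ImageNet download. Note that this fork's forward returns both the consensus scores and the backbone features.

    import torch
    from ops.models import TSN

    net = TSN(num_class=3, num_segments=8, modality='RGB',
              base_model='resnet50', consensus_type='avg',
              is_shift=True, shift_div=8, shift_place='blockres',
              pretrain='')            # '' -> no ImageNet weights are downloaded
    net.eval()

    dummy = torch.randn(2, 8 * 3, 224, 224)    # (batch, segments * channels, H, W)
    with torch.no_grad():
        scores, feature = net(dummy)           # consensus scores and per-segment backbone features
    print(scores.shape)    # torch.Size([2, 3])
    print(feature.shape)   # torch.Size([16, 2048])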
ops/non_local.py
0 → 100644
| 1 | # Non-local block using embedded gaussian | ||
| 2 | # Code from | ||
| 3 | # https://github.com/AlexHex7/Non-local_pytorch/blob/master/Non-Local_pytorch_0.3.1/lib/non_local_embedded_gaussian.py | ||
| 4 | import torch | ||
| 5 | from torch import nn | ||
| 6 | from torch.nn import functional as F | ||
| 7 | |||
| 8 | |||
| 9 | class _NonLocalBlockND(nn.Module): | ||
| 10 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): | ||
| 11 | super(_NonLocalBlockND, self).__init__() | ||
| 12 | |||
| 13 | assert dimension in [1, 2, 3] | ||
| 14 | |||
| 15 | self.dimension = dimension | ||
| 16 | self.sub_sample = sub_sample | ||
| 17 | |||
| 18 | self.in_channels = in_channels | ||
| 19 | self.inter_channels = inter_channels | ||
| 20 | |||
| 21 | if self.inter_channels is None: | ||
| 22 | self.inter_channels = in_channels // 2 | ||
| 23 | if self.inter_channels == 0: | ||
| 24 | self.inter_channels = 1 | ||
| 25 | |||
| 26 | if dimension == 3: | ||
| 27 | conv_nd = nn.Conv3d | ||
| 28 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) | ||
| 29 | bn = nn.BatchNorm3d | ||
| 30 | elif dimension == 2: | ||
| 31 | conv_nd = nn.Conv2d | ||
| 32 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) | ||
| 33 | bn = nn.BatchNorm2d | ||
| 34 | else: | ||
| 35 | conv_nd = nn.Conv1d | ||
| 36 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) | ||
| 37 | bn = nn.BatchNorm1d | ||
| 38 | |||
| 39 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, | ||
| 40 | kernel_size=1, stride=1, padding=0) | ||
| 41 | |||
| 42 | if bn_layer: | ||
| 43 | self.W = nn.Sequential( | ||
| 44 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, | ||
| 45 | kernel_size=1, stride=1, padding=0), | ||
| 46 | bn(self.in_channels) | ||
| 47 | ) | ||
| 48 | nn.init.constant_(self.W[1].weight, 0) | ||
| 49 | nn.init.constant_(self.W[1].bias, 0) | ||
| 50 | else: | ||
| 51 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, | ||
| 52 | kernel_size=1, stride=1, padding=0) | ||
| 53 | nn.init.constant_(self.W.weight, 0) | ||
| 54 | nn.init.constant_(self.W.bias, 0) | ||
| 55 | |||
| 56 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, | ||
| 57 | kernel_size=1, stride=1, padding=0) | ||
| 58 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, | ||
| 59 | kernel_size=1, stride=1, padding=0) | ||
| 60 | |||
| 61 | if sub_sample: | ||
| 62 | self.g = nn.Sequential(self.g, max_pool_layer) | ||
| 63 | self.phi = nn.Sequential(self.phi, max_pool_layer) | ||
| 64 | |||
| 65 | def forward(self, x): | ||
| 66 | ''' | ||
| 67 | :param x: (b, c, t, h, w) | ||
| 68 | :return: | ||
| 69 | ''' | ||
| 70 | |||
| 71 | batch_size = x.size(0) | ||
| 72 | |||
| 73 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) | ||
| 74 | g_x = g_x.permute(0, 2, 1) | ||
| 75 | |||
| 76 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) | ||
| 77 | theta_x = theta_x.permute(0, 2, 1) | ||
| 78 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) | ||
| 79 | f = torch.matmul(theta_x, phi_x) | ||
| 80 | f_div_C = F.softmax(f, dim=-1) | ||
| 81 | |||
| 82 | y = torch.matmul(f_div_C, g_x) | ||
| 83 | y = y.permute(0, 2, 1).contiguous() | ||
| 84 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) | ||
| 85 | W_y = self.W(y) | ||
| 86 | z = W_y + x | ||
| 87 | |||
| 88 | return z | ||
| 89 | |||
| 90 | |||
| 91 | class NONLocalBlock1D(_NonLocalBlockND): | ||
| 92 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): | ||
| 93 | super(NONLocalBlock1D, self).__init__(in_channels, | ||
| 94 | inter_channels=inter_channels, | ||
| 95 | dimension=1, sub_sample=sub_sample, | ||
| 96 | bn_layer=bn_layer) | ||
| 97 | |||
| 98 | |||
| 99 | class NONLocalBlock2D(_NonLocalBlockND): | ||
| 100 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): | ||
| 101 | super(NONLocalBlock2D, self).__init__(in_channels, | ||
| 102 | inter_channels=inter_channels, | ||
| 103 | dimension=2, sub_sample=sub_sample, | ||
| 104 | bn_layer=bn_layer) | ||
| 105 | |||
| 106 | |||
| 107 | class NONLocalBlock3D(_NonLocalBlockND): | ||
| 108 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): | ||
| 109 | super(NONLocalBlock3D, self).__init__(in_channels, | ||
| 110 | inter_channels=inter_channels, | ||
| 111 | dimension=3, sub_sample=sub_sample, | ||
| 112 | bn_layer=bn_layer) | ||
| 113 | |||
| 114 | |||
| 115 | class NL3DWrapper(nn.Module): | ||
| 116 | def __init__(self, block, n_segment): | ||
| 117 | super(NL3DWrapper, self).__init__() | ||
| 118 | self.block = block | ||
| 119 | self.nl = NONLocalBlock3D(block.bn3.num_features) | ||
| 120 | self.n_segment = n_segment | ||
| 121 | |||
| 122 | def forward(self, x): | ||
| 123 | x = self.block(x) | ||
| 124 | |||
| 125 | nt, c, h, w = x.size() | ||
| 126 | x = x.view(nt // self.n_segment, self.n_segment, c, h, w).transpose(1, 2) # n, c, t, h, w | ||
| 127 | x = self.nl(x) | ||
| 128 | x = x.transpose(1, 2).contiguous().view(nt, c, h, w) | ||
| 129 | return x | ||
| 130 | |||
| 131 | |||
| 132 | def make_non_local(net, n_segment): | ||
| 133 | import torchvision | ||
| 134 | import archs | ||
| 135 | if isinstance(net, torchvision.models.ResNet): | ||
| 136 | net.layer2 = nn.Sequential( | ||
| 137 | NL3DWrapper(net.layer2[0], n_segment), | ||
| 138 | net.layer2[1], | ||
| 139 | NL3DWrapper(net.layer2[2], n_segment), | ||
| 140 | net.layer2[3], | ||
| 141 | ) | ||
| 142 | net.layer3 = nn.Sequential( | ||
| 143 | NL3DWrapper(net.layer3[0], n_segment), | ||
| 144 | net.layer3[1], | ||
| 145 | NL3DWrapper(net.layer3[2], n_segment), | ||
| 146 | net.layer3[3], | ||
| 147 | NL3DWrapper(net.layer3[4], n_segment), | ||
| 148 | net.layer3[5], | ||
| 149 | ) | ||
| 150 | else: | ||
| 151 | raise NotImplementedError | ||
| 152 | |||
| 153 | |||
| 154 | if __name__ == '__main__': | ||
| 155 | from torch.autograd import Variable | ||
| 156 | import torch | ||
| 157 | |||
| 158 | sub_sample = True | ||
| 159 | bn_layer = True | ||
| 160 | |||
| 161 | img = Variable(torch.zeros(2, 3, 20)) | ||
| 162 | net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer) | ||
| 163 | out = net(img) | ||
| 164 | print(out.size()) | ||
| 165 | |||
| 166 | img = Variable(torch.zeros(2, 3, 20, 20)) | ||
| 167 | net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer) | ||
| 168 | out = net(img) | ||
| 169 | print(out.size()) | ||
| 170 | |||
| 171 | img = Variable(torch.randn(2, 3, 10, 20, 20)) | ||
| 172 | net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer) | ||
| 173 | out = net(img) | ||
| 174 | print(out.size()) | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
ops/temporal_shift.py
0 → 100755
| 1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" | ||
| 2 | # arXiv:1811.08383 | ||
| 3 | # Ji Lin*, Chuang Gan, Song Han | ||
| 4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu | ||
| 5 | |||
| 6 | import torch | ||
| 7 | import torch.nn as nn | ||
| 8 | import torch.nn.functional as F | ||
| 9 | |||
| 10 | |||
| 11 | class TemporalShift(nn.Module): | ||
| 12 | def __init__(self, net, n_segment=3, n_div=8, inplace=False): | ||
| 13 | super(TemporalShift, self).__init__() | ||
| 14 | self.net = net | ||
| 15 | self.n_segment = n_segment | ||
| 16 | self.fold_div = n_div | ||
| 17 | self.inplace = inplace | ||
| 18 | if inplace: | ||
| 19 | print('=> Using in-place shift...') | ||
| 20 | print('=> Using fold div: {}'.format(self.fold_div)) | ||
| 21 | |||
| 22 | def forward(self, x): | ||
| 23 | x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace) | ||
| 24 | return self.net(x) | ||
| 25 | |||
| 26 | @staticmethod | ||
| 27 | def shift(x, n_segment, fold_div=3, inplace=False): | ||
| 28 | nt, c, h, w = x.size() | ||
| 29 | n_batch = nt // n_segment | ||
| 30 | x = x.view(n_batch, n_segment, c, h, w) | ||
| 31 | |||
| 32 | fold = c // fold_div | ||
| 33 | if inplace: | ||
| 34 | # The in-place version runs into ordering errors under parallel execution; | ||
| 35 | # a dedicated CUDA kernel may be needed, so it is disabled here. | ||
| 36 | raise NotImplementedError | ||
| 37 | # out = InplaceShift.apply(x, fold) | ||
| 38 | else: | ||
| 39 | out = torch.zeros_like(x) | ||
| 40 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift left | ||
| 41 | out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right | ||
| 42 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift | ||
| 43 | |||
| 44 | return out.view(nt, c, h, w) | ||
| 45 | |||
| 46 | |||
| 47 | class InplaceShift(torch.autograd.Function): | ||
| 48 | # Special thanks to @raoyongming for help with this function | ||
| 49 | @staticmethod | ||
| 50 | def forward(ctx, input, fold): | ||
| 51 | # not support higher order gradient | ||
| 52 | # input = input.detach_() | ||
| 53 | ctx.fold_ = fold | ||
| 54 | n, t, c, h, w = input.size() | ||
| 55 | buffer = input.data.new(n, t, fold, h, w).zero_() | ||
| 56 | buffer[:, :-1] = input.data[:, 1:, :fold] | ||
| 57 | input.data[:, :, :fold] = buffer | ||
| 58 | buffer.zero_() | ||
| 59 | buffer[:, 1:] = input.data[:, :-1, fold: 2 * fold] | ||
| 60 | input.data[:, :, fold: 2 * fold] = buffer | ||
| 61 | return input | ||
| 62 | |||
| 63 | @staticmethod | ||
| 64 | def backward(ctx, grad_output): | ||
| 65 | # grad_output = grad_output.detach_() | ||
| 66 | fold = ctx.fold_ | ||
| 67 | n, t, c, h, w = grad_output.size() | ||
| 68 | buffer = grad_output.data.new(n, t, fold, h, w).zero_() | ||
| 69 | buffer[:, 1:] = grad_output.data[:, :-1, :fold] | ||
| 70 | grad_output.data[:, :, :fold] = buffer | ||
| 71 | buffer.zero_() | ||
| 72 | buffer[:, :-1] = grad_output.data[:, 1:, fold: 2 * fold] | ||
| 73 | grad_output.data[:, :, fold: 2 * fold] = buffer | ||
| 74 | return grad_output, None | ||
| 75 | |||
| 76 | |||
| 77 | class TemporalPool(nn.Module): | ||
| 78 | def __init__(self, net, n_segment): | ||
| 79 | super(TemporalPool, self).__init__() | ||
| 80 | self.net = net | ||
| 81 | self.n_segment = n_segment | ||
| 82 | |||
| 83 | def forward(self, x): | ||
| 84 | x = self.temporal_pool(x, n_segment=self.n_segment) | ||
| 85 | return self.net(x) | ||
| 86 | |||
| 87 | @staticmethod | ||
| 88 | def temporal_pool(x, n_segment): | ||
| 89 | nt, c, h, w = x.size() | ||
| 90 | n_batch = nt // n_segment | ||
| 91 | x = x.view(n_batch, n_segment, c, h, w).transpose(1, 2) # n, c, t, h, w | ||
| 92 | x = F.max_pool3d(x, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0)) | ||
| 93 | x = x.transpose(1, 2).contiguous().view(nt // 2, c, h, w) | ||
| 94 | return x | ||
| 95 | |||
| 96 | |||
| 97 | def make_temporal_shift(net, n_segment, n_div=8, place='blockres', temporal_pool=False): | ||
| 98 | if temporal_pool: | ||
| 99 | n_segment_list = [n_segment, n_segment // 2, n_segment // 2, n_segment // 2] | ||
| 100 | else: | ||
| 101 | n_segment_list = [n_segment] * 4 | ||
| 102 | assert n_segment_list[-1] > 0 | ||
| 103 | print('=> n_segment per stage: {}'.format(n_segment_list)) | ||
| 104 | |||
| 105 | import torchvision | ||
| 106 | if isinstance(net, torchvision.models.ResNet): | ||
| 107 | if place == 'block': | ||
| 108 | def make_block_temporal(stage, this_segment): | ||
| 109 | blocks = list(stage.children()) | ||
| 110 | print('=> Processing stage with {} blocks'.format(len(blocks))) | ||
| 111 | for i, b in enumerate(blocks): | ||
| 112 | blocks[i] = TemporalShift(b, n_segment=this_segment, n_div=n_div) | ||
| 113 | return nn.Sequential(*(blocks)) | ||
| 114 | |||
| 115 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0]) | ||
| 116 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1]) | ||
| 117 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2]) | ||
| 118 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3]) | ||
| 119 | |||
| 120 | elif 'blockres' in place: | ||
| 121 | n_round = 1 | ||
| 122 | if len(list(net.layer3.children())) >= 23: | ||
| 123 | n_round = 2 | ||
| 124 | print('=> Using n_round {} to insert temporal shift'.format(n_round)) | ||
| 125 | |||
| 126 | def make_block_temporal(stage, this_segment): | ||
| 127 | blocks = list(stage.children()) | ||
| 128 | print('=> Processing stage with {} blocks residual'.format(len(blocks))) | ||
| 129 | for i, b in enumerate(blocks): | ||
| 130 | if i % n_round == 0: | ||
| 131 | blocks[i].conv1 = TemporalShift(b.conv1, n_segment=this_segment, n_div=n_div) | ||
| 132 | return nn.Sequential(*blocks) | ||
| 133 | |||
| 134 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0]) | ||
| 135 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1]) | ||
| 136 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2]) | ||
| 137 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3]) | ||
| 138 | else: | ||
| 139 | raise NotImplementedError(place) | ||
| 140 | |||
| 141 | |||
| 142 | def make_temporal_pool(net, n_segment): | ||
| 143 | import torchvision | ||
| 144 | if isinstance(net, torchvision.models.ResNet): | ||
| 145 | print('=> Injecting temporal pooling') | ||
| 146 | net.layer2 = TemporalPool(net.layer2, n_segment) | ||
| 147 | else: | ||
| 148 | raise NotImplementedError | ||
| 149 | |||
| 150 | |||
| 151 | if __name__ == '__main__': | ||
| 152 | # test in-place shift vs. vanilla shift (note: the in-place path above raises NotImplementedError, so this self-test will not run as-is) | ||
| 153 | tsm1 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=False) | ||
| 154 | tsm2 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=True) | ||
| 155 | |||
| 156 | print('=> Testing CPU...') | ||
| 157 | # test forward | ||
| 158 | with torch.no_grad(): | ||
| 159 | for i in range(10): | ||
| 160 | x = torch.rand(2 * 8, 3, 224, 224) | ||
| 161 | y1 = tsm1(x) | ||
| 162 | y2 = tsm2(x) | ||
| 163 | assert torch.norm(y1 - y2).item() < 1e-5 | ||
| 164 | |||
| 165 | # test backward | ||
| 166 | with torch.enable_grad(): | ||
| 167 | for i in range(10): | ||
| 168 | x1 = torch.rand(2 * 8, 3, 224, 224) | ||
| 169 | x1.requires_grad_() | ||
| 170 | x2 = x1.clone() | ||
| 171 | y1 = tsm1(x1) | ||
| 172 | y2 = tsm2(x2) | ||
| 173 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0] | ||
| 174 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0] | ||
| 175 | assert torch.norm(grad1 - grad2).item() < 1e-5 | ||
| 176 | |||
| 177 | print('=> Testing GPU...') | ||
| 178 | tsm1.cuda() | ||
| 179 | tsm2.cuda() | ||
| 180 | # test forward | ||
| 181 | with torch.no_grad(): | ||
| 182 | for i in range(10): | ||
| 183 | x = torch.rand(2 * 8, 3, 224, 224).cuda() | ||
| 184 | y1 = tsm1(x) | ||
| 185 | y2 = tsm2(x) | ||
| 186 | assert torch.norm(y1 - y2).item() < 1e-5 | ||
| 187 | |||
| 188 | # test backward | ||
| 189 | with torch.enable_grad(): | ||
| 190 | for i in range(10): | ||
| 191 | x1 = torch.rand(2 * 8, 3, 224, 224).cuda() | ||
| 192 | x1.requires_grad_() | ||
| 193 | x2 = x1.clone() | ||
| 194 | y1 = tsm1(x1) | ||
| 195 | y2 = tsm2(x2) | ||
| 196 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0] | ||
| 197 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0] | ||
| 198 | assert torch.norm(grad1 - grad2).item() < 1e-5 | ||
| 199 | print('Test passed.') | ||
| 200 | |||
| 201 | |||
| 202 | |||
| 203 |
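To visualise what the shift above actually does, here is a tiny sketch with a 1x1 spatial map so the channel movement is easy to read; the n_segment and fold_div values are illustrative.

    import torch
    from ops.temporal_shift import TemporalShift

    # One video of 4 segments, 8 channels, 1x1 spatial map
    x = torch.arange(4 * 8, dtype=torch.float32).view(4, 8, 1, 1)
    shifted = TemporalShift.shift(x, n_segment=4, fold_div=4)

    # With fold_div=4: the first 2 channels move one segment earlier (shift left),
    # the next 2 move one segment later (shift right), and the last 4 stay put.
    print(shifted.view(4, 8))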
ops/transforms.py
0 → 100755
| 1 | import torchvision | ||
| 2 | import random | ||
| 3 | from PIL import Image, ImageOps | ||
| 4 | import numpy as np | ||
| 5 | import numbers | ||
| 6 | import math | ||
| 7 | import torch | ||
| 8 | |||
| 9 | |||
| 10 | class GroupRandomCrop(object): | ||
| 11 | def __init__(self, size): | ||
| 12 | if isinstance(size, numbers.Number): | ||
| 13 | self.size = (int(size), int(size)) | ||
| 14 | else: | ||
| 15 | self.size = size | ||
| 16 | |||
| 17 | def __call__(self, img_group): | ||
| 18 | |||
| 19 | w, h = img_group[0].size | ||
| 20 | th, tw = self.size | ||
| 21 | |||
| 22 | out_images = list() | ||
| 23 | |||
| 24 | x1 = random.randint(0, w - tw) | ||
| 25 | y1 = random.randint(0, h - th) | ||
| 26 | |||
| 27 | for img in img_group: | ||
| 28 | assert(img.size[0] == w and img.size[1] == h) | ||
| 29 | if w == tw and h == th: | ||
| 30 | out_images.append(img) | ||
| 31 | else: | ||
| 32 | out_images.append(img.crop((x1, y1, x1 + tw, y1 + th))) | ||
| 33 | |||
| 34 | return out_images | ||
| 35 | |||
| 36 | |||
| 37 | class GroupCenterCrop(object): | ||
| 38 | def __init__(self, size): | ||
| 39 | self.worker = torchvision.transforms.CenterCrop(size) | ||
| 40 | |||
| 41 | def __call__(self, img_group): | ||
| 42 | return [self.worker(img) for img in img_group] | ||
| 43 | |||
| 44 | |||
| 45 | class GroupRandomHorizontalFlip(object): | ||
| 46 | """Randomly horizontally flips the given PIL.Image with a probability of 0.5 | ||
| 47 | """ | ||
| 48 | def __init__(self, is_flow=False): | ||
| 49 | self.is_flow = is_flow | ||
| 50 | |||
| 51 | def __call__(self, img_group, is_flow=False): | ||
| 52 | v = random.random() | ||
| 53 | if v < 0.5: | ||
| 54 | ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] | ||
| 55 | if self.is_flow: | ||
| 56 | for i in range(0, len(ret), 2): | ||
| 57 | ret[i] = ImageOps.invert(ret[i]) # invert flow pixel values when flipping | ||
| 58 | return ret | ||
| 59 | else: | ||
| 60 | return img_group | ||
| 61 | |||
| 62 | |||
| 63 | class GroupNormalize(object): | ||
| 64 | def __init__(self, mean, std): | ||
| 65 | self.mean = mean | ||
| 66 | self.std = std | ||
| 67 | |||
| 68 | def __call__(self, tensor): | ||
| 69 | rep_mean = self.mean * (tensor.size()[0]//len(self.mean)) | ||
| 70 | rep_std = self.std * (tensor.size()[0]//len(self.std)) | ||
| 71 | |||
| 72 | # TODO: make efficient | ||
| 73 | for t, m, s in zip(tensor, rep_mean, rep_std): | ||
| 74 | t.sub_(m).div_(s) | ||
| 75 | return tensor | ||
| 76 | |||
| 77 | |||
| 78 | class GroupScale(object): | ||
| 79 | """ Rescales the input PIL.Image to the given 'size'. | ||
| 80 | 'size' will be the size of the smaller edge. | ||
| 81 | For example, if height > width, then image will be | ||
| 82 | rescaled to (size * height / width, size) | ||
| 83 | size: size of the smaller edge | ||
| 84 | interpolation: Default: PIL.Image.BILINEAR | ||
| 85 | """ | ||
| 86 | |||
| 87 | def __init__(self, size, interpolation=Image.BILINEAR): | ||
| 88 | self.worker = torchvision.transforms.Resize(size, interpolation) | ||
| 89 | |||
| 90 | def __call__(self, img_group): | ||
| 91 | return [self.worker(img) for img in img_group] | ||
| 92 | |||
| 93 | |||
| 94 | class GroupOverSample(object): | ||
| 95 | def __init__(self, crop_size, scale_size=None, flip=True): | ||
| 96 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) | ||
| 97 | |||
| 98 | if scale_size is not None: | ||
| 99 | self.scale_worker = GroupScale(scale_size) | ||
| 100 | else: | ||
| 101 | self.scale_worker = None | ||
| 102 | self.flip = flip | ||
| 103 | |||
| 104 | def __call__(self, img_group): | ||
| 105 | |||
| 106 | if self.scale_worker is not None: | ||
| 107 | img_group = self.scale_worker(img_group) | ||
| 108 | |||
| 109 | image_w, image_h = img_group[0].size | ||
| 110 | crop_w, crop_h = self.crop_size | ||
| 111 | |||
| 112 | offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h) | ||
| 113 | oversample_group = list() | ||
| 114 | for o_w, o_h in offsets: | ||
| 115 | normal_group = list() | ||
| 116 | flip_group = list() | ||
| 117 | for i, img in enumerate(img_group): | ||
| 118 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) | ||
| 119 | normal_group.append(crop) | ||
| 120 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) | ||
| 121 | |||
| 122 | if img.mode == 'L' and i % 2 == 0: | ||
| 123 | flip_group.append(ImageOps.invert(flip_crop)) | ||
| 124 | else: | ||
| 125 | flip_group.append(flip_crop) | ||
| 126 | |||
| 127 | oversample_group.extend(normal_group) | ||
| 128 | if self.flip: | ||
| 129 | oversample_group.extend(flip_group) | ||
| 130 | return oversample_group | ||
| 131 | |||
| 132 | |||
| 133 | class GroupFullResSample(object): | ||
| 134 | def __init__(self, crop_size, scale_size=None, flip=True): | ||
| 135 | self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size) | ||
| 136 | |||
| 137 | if scale_size is not None: | ||
| 138 | self.scale_worker = GroupScale(scale_size) | ||
| 139 | else: | ||
| 140 | self.scale_worker = None | ||
| 141 | self.flip = flip | ||
| 142 | |||
| 143 | def __call__(self, img_group): | ||
| 144 | |||
| 145 | if self.scale_worker is not None: | ||
| 146 | img_group = self.scale_worker(img_group) | ||
| 147 | |||
| 148 | image_w, image_h = img_group[0].size | ||
| 149 | crop_w, crop_h = self.crop_size | ||
| 150 | |||
| 151 | w_step = (image_w - crop_w) // 4 | ||
| 152 | h_step = (image_h - crop_h) // 4 | ||
| 153 | |||
| 154 | offsets = list() | ||
| 155 | offsets.append((0 * w_step, 2 * h_step)) # left | ||
| 156 | offsets.append((4 * w_step, 2 * h_step)) # right | ||
| 157 | offsets.append((2 * w_step, 2 * h_step)) # center | ||
| 158 | |||
| 159 | oversample_group = list() | ||
| 160 | for o_w, o_h in offsets: | ||
| 161 | normal_group = list() | ||
| 162 | flip_group = list() | ||
| 163 | for i, img in enumerate(img_group): | ||
| 164 | crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h)) | ||
| 165 | normal_group.append(crop) | ||
| 166 | if self.flip: | ||
| 167 | flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT) | ||
| 168 | |||
| 169 | if img.mode == 'L' and i % 2 == 0: | ||
| 170 | flip_group.append(ImageOps.invert(flip_crop)) | ||
| 171 | else: | ||
| 172 | flip_group.append(flip_crop) | ||
| 173 | |||
| 174 | oversample_group.extend(normal_group) | ||
| 175 | oversample_group.extend(flip_group) | ||
| 176 | return oversample_group | ||
| 177 | |||
| 178 | |||
| 179 | class GroupMultiScaleCrop(object): | ||
| 180 | |||
| 181 | def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True): | ||
| 182 | self.scales = scales if scales is not None else [1, .875, .75, .66] | ||
| 183 | self.max_distort = max_distort | ||
| 184 | self.fix_crop = fix_crop | ||
| 185 | self.more_fix_crop = more_fix_crop | ||
| 186 | self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size] | ||
| 187 | self.interpolation = Image.BILINEAR | ||
| 188 | |||
| 189 | def __call__(self, img_group): | ||
| 190 | |||
| 191 | im_size = img_group[0].size | ||
| 192 | |||
| 193 | crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size) | ||
| 194 | crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)) for img in img_group] | ||
| 195 | ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation) | ||
| 196 | for img in crop_img_group] | ||
| 197 | return ret_img_group | ||
| 198 | |||
| 199 | def _sample_crop_size(self, im_size): | ||
| 200 | image_w, image_h = im_size[0], im_size[1] | ||
| 201 | |||
| 202 | # find a crop size | ||
| 203 | base_size = min(image_w, image_h) | ||
| 204 | crop_sizes = [int(base_size * x) for x in self.scales] | ||
| 205 | crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes] | ||
| 206 | crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes] | ||
| 207 | |||
| 208 | pairs = [] | ||
| 209 | for i, h in enumerate(crop_h): | ||
| 210 | for j, w in enumerate(crop_w): | ||
| 211 | if abs(i - j) <= self.max_distort: | ||
| 212 | pairs.append((w, h)) | ||
| 213 | |||
| 214 | crop_pair = random.choice(pairs) | ||
| 215 | if not self.fix_crop: | ||
| 216 | w_offset = random.randint(0, image_w - crop_pair[0]) | ||
| 217 | h_offset = random.randint(0, image_h - crop_pair[1]) | ||
| 218 | else: | ||
| 219 | w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1]) | ||
| 220 | |||
| 221 | return crop_pair[0], crop_pair[1], w_offset, h_offset | ||
| 222 | |||
| 223 | def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h): | ||
| 224 | offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h) | ||
| 225 | return random.choice(offsets) | ||
| 226 | |||
| 227 | @staticmethod | ||
| 228 | def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h): | ||
| 229 | w_step = (image_w - crop_w) // 4 | ||
| 230 | h_step = (image_h - crop_h) // 4 | ||
| 231 | |||
| 232 | ret = list() | ||
| 233 | ret.append((0, 0)) # upper left | ||
| 234 | ret.append((4 * w_step, 0)) # upper right | ||
| 235 | ret.append((0, 4 * h_step)) # lower left | ||
| 236 | ret.append((4 * w_step, 4 * h_step)) # lower right | ||
| 237 | ret.append((2 * w_step, 2 * h_step)) # center | ||
| 238 | |||
| 239 | if more_fix_crop: | ||
| 240 | ret.append((0, 2 * h_step)) # center left | ||
| 241 | ret.append((4 * w_step, 2 * h_step)) # center right | ||
| 242 | ret.append((2 * w_step, 4 * h_step)) # lower center | ||
| 243 | ret.append((2 * w_step, 0 * h_step)) # upper center | ||
| 244 | |||
| 245 | ret.append((1 * w_step, 1 * h_step)) # upper left quarter | ||
| 246 | ret.append((3 * w_step, 1 * h_step)) # upper right quarter | ||
| 247 | ret.append((1 * w_step, 3 * h_step)) # lower left quarter | ||
| 248 | ret.append((3 * w_step, 3 * h_step)) # lower right quarter | ||
| 249 | |||
| 250 | return ret | ||
| 251 | |||
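# fill_fix_offset enumerates crop positions on a grid of quarter steps: 5 positions (corners plus center)
# by default, and 13 when more_fix_crop also adds the edge midpoints and quarter points. A quick check
# (the image and crop sizes below are illustrative):
print(len(GroupMultiScaleCrop.fill_fix_offset(False, 340, 256, 224, 224)))  # 5
print(len(GroupMultiScaleCrop.fill_fix_offset(True, 340, 256, 224, 224)))   # 13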
| 252 | |||
| 253 | class GroupRandomSizedCrop(object): | ||
| 254 | """Randomly crops the given PIL.Image to a random size of (0.08 to 1.0) of the original size | ||
| 255 | and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio. | ||
| 256 | This is popularly used to train the Inception networks | ||
| 257 | size: size of the smaller edge | ||
| 258 | interpolation: Default: PIL.Image.BILINEAR | ||
| 259 | """ | ||
| 260 | def __init__(self, size, interpolation=Image.BILINEAR): | ||
| 261 | self.size = size | ||
| 262 | self.interpolation = interpolation | ||
| 263 | |||
| 264 | def __call__(self, img_group): | ||
| 265 | for attempt in range(10): | ||
| 266 | area = img_group[0].size[0] * img_group[0].size[1] | ||
| 267 | target_area = random.uniform(0.08, 1.0) * area | ||
| 268 | aspect_ratio = random.uniform(3. / 4, 4. / 3) | ||
| 269 | |||
| 270 | w = int(round(math.sqrt(target_area * aspect_ratio))) | ||
| 271 | h = int(round(math.sqrt(target_area / aspect_ratio))) | ||
| 272 | |||
| 273 | if random.random() < 0.5: | ||
| 274 | w, h = h, w | ||
| 275 | |||
| 276 | if w <= img_group[0].size[0] and h <= img_group[0].size[1]: | ||
| 277 | x1 = random.randint(0, img_group[0].size[0] - w) | ||
| 278 | y1 = random.randint(0, img_group[0].size[1] - h) | ||
| 279 | found = True | ||
| 280 | break | ||
| 281 | else: | ||
| 282 | found = False | ||
| 283 | x1 = 0 | ||
| 284 | y1 = 0 | ||
| 285 | |||
| 286 | if found: | ||
| 287 | out_group = list() | ||
| 288 | for img in img_group: | ||
| 289 | img = img.crop((x1, y1, x1 + w, y1 + h)) | ||
| 290 | assert(img.size == (w, h)) | ||
| 291 | out_group.append(img.resize((self.size, self.size), self.interpolation)) | ||
| 292 | return out_group | ||
| 293 | else: | ||
| 294 | # Fallback | ||
| 295 | scale = GroupScale(self.size, interpolation=self.interpolation) | ||
| 296 | crop = GroupRandomCrop(self.size) | ||
| 297 | return crop(scale(img_group)) | ||
| 298 | |||
| 299 | |||
| 300 | class Stack(object): | ||
| 301 | |||
| 302 | def __init__(self, roll=False): | ||
| 303 | self.roll = roll | ||
| 304 | |||
| 305 | def __call__(self, img_group): | ||
| 306 | if img_group[0].mode == 'L': | ||
| 307 | return np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2) | ||
| 308 | elif img_group[0].mode == 'RGB': | ||
| 309 | if self.roll: | ||
| 310 | return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2) | ||
| 311 | else: | ||
| 312 | return np.concatenate(img_group, axis=2) | ||
| 313 | |||
| 314 | |||
| 315 | class ToTorchFormatTensor(object): | ||
| 316 | """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] | ||
| 317 | to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ | ||
| 318 | def __init__(self, div=True): | ||
| 319 | self.div = div | ||
| 320 | |||
| 321 | def __call__(self, pic): | ||
| 322 | if isinstance(pic, np.ndarray): | ||
| 323 | # handle numpy array | ||
| 324 | img = torch.from_numpy(pic).permute(2, 0, 1).contiguous() | ||
| 325 | else: | ||
| 326 | # handle PIL Image | ||
| 327 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) | ||
| 328 | img = img.view(pic.size[1], pic.size[0], len(pic.mode)) | ||
| 329 | # put it from HWC to CHW format | ||
| 330 | # yikes, this transpose takes 80% of the loading time/CPU | ||
| 331 | img = img.transpose(0, 1).transpose(0, 2).contiguous() | ||
| 332 | return img.float().div(255) if self.div else img.float() | ||
| 333 | |||
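# Stack and ToTorchFormatTensor are meant to be chained: Stack concatenates the frame group along the
# channel axis, and ToTorchFormatTensor reorders to CHW and optionally rescales to [0, 1]. A minimal
# sketch with dummy RGB frames:
frames = [Image.new('RGB', (224, 224)) for _ in range(3)]
stacked = Stack()(frames)                # numpy array of shape (224, 224, 9)
tensor = ToTorchFormatTensor()(stacked)  # float tensor of shape (9, 224, 224), values in [0, 1]
print(stacked.shape, tuple(tensor.shape))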
| 334 | |||
| 335 | class IdentityTransform(object): | ||
| 336 | |||
| 337 | def __call__(self, data): | ||
| 338 | return data | ||
| 339 | |||
| 340 | |||
| 341 | if __name__ == "__main__": | ||
| 342 | trans = torchvision.transforms.Compose([ | ||
| 343 | GroupScale(256), | ||
| 344 | GroupRandomCrop(224), | ||
| 345 | Stack(), | ||
| 346 | ToTorchFormatTensor(), | ||
| 347 | GroupNormalize( | ||
| 348 | mean=[.485, .456, .406], | ||
| 349 | std=[.229, .224, .225] | ||
| 350 | )] | ||
| 351 | ) | ||
| 352 | |||
| 353 | im = Image.open('../tensorflow-model-zoo.torch/lena_299.png') | ||
| 354 | |||
| 355 | color_group = [im] * 3 | ||
| 356 | rst = trans(color_group) | ||
| 357 | |||
| 358 | gray_group = [im.convert('L')] * 9 | ||
| 359 | gray_rst = trans(gray_group) | ||
| 360 | |||
| 361 | trans2 = torchvision.transforms.Compose([ | ||
| 362 | GroupRandomSizedCrop(256), | ||
| 363 | Stack(), | ||
| 364 | ToTorchFormatTensor(), | ||
| 365 | GroupNormalize( | ||
| 366 | mean=[.485, .456, .406], | ||
| 367 | std=[.229, .224, .225]) | ||
| 368 | ]) | ||
| 369 | print(trans2(color_group)) | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
ops/utils.py
0 → 100755
| 1 | import numpy as np | ||
| 2 | |||
| 3 | |||
| 4 | def softmax(scores): | ||
| 5 | es = np.exp(scores - scores.max(axis=-1)[..., None]) | ||
| 6 | return es / es.sum(axis=-1)[..., None] | ||
| 7 | |||
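# Subtracting the row-wise maximum keeps the exponentials from overflowing without changing the result,
# so each output row still sums to 1. A quick check (the scores below are arbitrary):
scores = np.array([[2.0, 1.0, 0.1], [1000.0, 999.0, 998.0]])
print(softmax(scores).sum(axis=-1))  # [1. 1.], no overflow even for very large scores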
| 8 | |||
| 9 | class AverageMeter(object): | ||
| 10 | """Computes and stores the average and current value""" | ||
| 11 | |||
| 12 | def __init__(self): | ||
| 13 | self.reset() | ||
| 14 | |||
| 15 | def reset(self): | ||
| 16 | self.val = 0 | ||
| 17 | self.avg = 0 | ||
| 18 | self.sum = 0 | ||
| 19 | self.count = 0 | ||
| 20 | |||
| 21 | def update(self, val, n=1): | ||
| 22 | self.val = val | ||
| 23 | self.sum += val * n | ||
| 24 | self.count += n | ||
| 25 | self.avg = self.sum / self.count | ||
| 26 | |||
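# A typical use is a batch-size-weighted running mean, e.g. a loss averaged over samples rather than
# over batches; a minimal sketch:
meter = AverageMeter()
meter.update(0.9, n=32)  # batch of 32 with mean loss 0.9
meter.update(0.5, n=16)  # batch of 16 with mean loss 0.5
print(meter.val, round(meter.avg, 3))  # 0.5 0.767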
| 27 | |||
| 28 | def accuracy(output, target, topk=(1,)): | ||
| 29 | """Computes the top-k accuracy (in percent) for the specified values of k""" | ||
| 30 | maxk = max(topk) | ||
| 31 | batch_size = target.size(0) | ||
| 32 | |||
| 33 | _, pred = output.topk(maxk, 1, True, True) | ||
| 34 | pred = pred.t() | ||
| 35 | correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||
| 36 | |||
| 37 | res = [] | ||
| 38 | for k in topk: | ||
| 39 | correct_k = correct[:k].view(-1).float().sum(0) | ||
| 40 | res.append(correct_k.mul_(100.0 / batch_size)) | ||
| 41 | return res | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
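# accuracy expects class scores of shape (batch, num_classes) and integer targets, and returns
# percentages. A hedged sketch with random tensors (shapes and class count are illustrative):
import torch

logits = torch.randn(8, 3)            # 8 samples, 3 classes
targets = torch.randint(0, 3, (8,))
top1, = accuracy(logits, targets, topk=(1,))  # accuracy as defined in this file
print(top1.item())                    # top-1 accuracy of this random batch, in percent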
person_filter.py
0 → 100644
| 1 | import os | ||
| 2 | import cv2 | ||
| 3 | import numpy as np | ||
| 4 | import pickle | ||
| 5 | |||
| 6 | def start_filter(config): | ||
| 7 | cls_class_path = config['MODEL']['CLS_PERSON'] | ||
| 8 | feature_save_dir = config['VIDEO']['FACE_FEATURE_DIR'] | ||
| 9 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
| 10 | result_file_name = config['PERSON']['RESULT_FILE'] | ||
| 11 | feature_name = config['PERSON']['DATA_NAME'] | ||
| 12 | |||
| 13 | xgboost_model = pickle.load(open(cls_class_path, "rb")) | ||
| 14 | |||
| 15 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
| 16 | result_file = open(result_file_path, 'w') | ||
| 17 | |||
| 18 | feature_path = os.path.join(feature_save_dir, feature_name) | ||
| 19 | val_annotation_pairs = np.load(feature_path, allow_pickle=True, encoding='latin1') | ||
| 20 | |||
| 21 | X_val = [] | ||
| 22 | Y_val = [] | ||
| 23 | Y_names = [] | ||
| 24 | for j in range(len(val_annotation_pairs)): | ||
| 25 | pair = val_annotation_pairs[j] | ||
| 26 | X_val.append(np.squeeze(pair[0])) | ||
| 27 | Y_val.append(pair[1]) | ||
| 28 | Y_names.append(pair[2]) | ||
| 29 | |||
| 30 | X_val = np.array(X_val) | ||
| 31 | y_pred = xgboost_model.predict_proba(X_val) | ||
| 32 | |||
| 33 | for i, Y_name in enumerate(Y_names): | ||
| 34 | result_file.write(Y_name + ' ') | ||
| 35 | result_file.write(str(y_pred[i][0]) + ',' + str(y_pred[i][1]) + ',' + str(y_pred[i][2]) + '\n') | ||
| 36 | |||
| 37 | result_file.close() | ||
| 38 | |||
| 39 | |||
| 40 | |||
| 41 | |||
| 42 |
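# The feature file loaded above is assumed to be a NumPy object array of (feature_vector, label, clip_name)
# triples, matching how pair[0], pair[1] and pair[2] are consumed. A hedged sketch of producing such a file
# with dummy data (the 512-dim features, labels and file name are placeholders):
import numpy as np

pairs = [(np.random.rand(1, 512), 1, 'clip_0001'),
         (np.random.rand(1, 512), 2, 'clip_0002')]
np.save('face_feature_val.npy', np.array(pairs, dtype=object))  # read back with np.load(..., allow_pickle=True)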
pose_filter.py
0 → 100644
| 1 | import os | ||
| 2 | import torch.optim | ||
| 3 | import numpy as np | ||
| 4 | import torchvision | ||
| 5 | import torch.nn.parallel | ||
| 6 | from ops.models import TSN | ||
| 7 | from ops.transforms import * | ||
| 8 | from ops.dataset import TSNDataSet | ||
| 9 | from torch.nn import functional as F | ||
| 10 | |||
| 11 | |||
| 12 | def gen_file_list(frame_save_dir, frame_list_dir): | ||
| 13 | |||
| 14 | val_path = os.path.join(frame_list_dir, 'val.txt') | ||
| 15 | video_names = os.listdir(frame_save_dir) | ||
| 16 | ucf101_rgb_val_file = open(val_path, 'w') | ||
| 17 | |||
| 18 | for video_name in video_names: | ||
| 19 | images_dir = os.path.join(frame_save_dir, video_name) | ||
| 20 | ucf101_rgb_val_file.write(video_name) | ||
| 21 | ucf101_rgb_val_file.write(' ') | ||
| 22 | ucf101_rgb_val_file.write(str(len(os.listdir(images_dir)))) | ||
| 23 | ucf101_rgb_val_file.write('\n') | ||
| 24 | |||
| 25 | ucf101_rgb_val_file.close() | ||
| 26 | |||
| 27 | return val_path | ||
| 28 | |||
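# Each line of the generated val.txt is simply '<video_name> <num_frames>', one entry per frame directory,
# which is the list format handed to TSNDataSet below. A hedged usage sketch (the directories are placeholders):
val_path = gen_file_list('/data/pose_frames', '/data/frame_lists')
for line in open(val_path):
    name, num_frames = line.split()  # e.g. 'video_0001 120'
    print(name, int(num_frames))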
| 29 | |||
| 30 | def start_filter(config): | ||
| 31 | arch = config['FIGHTING']['ARCH'] | ||
| 32 | prefix = config['VIDEO']['PREFIX'] | ||
| 33 | modality = config['POSE']['MODALITY'] | ||
| 34 | test_crop = config['POSE']['TEST_CROP'] | ||
| 35 | batch_size = config['POSE']['BATCH_SIZE'] | ||
| 36 | weights_path = config['MODEL']['CLS_POSE'] | ||
| 37 | test_segment = config['POSE']['TEST_SEGMENT'] | ||
| 38 | frame_save_dir = config['VIDEO']['POSE_FRAME_SAVE_DIR'] | ||
| 39 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
| 40 | result_file_name = config['POSE']['RESULT_FILE'] | ||
| 41 | |||
| 42 | workers = 8 | ||
| 43 | num_class = 3 | ||
| 44 | shift_div = 8 | ||
| 45 | img_feature_dim = 256 | ||
| 46 | |||
| 47 | softmax = False | ||
| 48 | is_shift = True | ||
| 49 | full_res = False | ||
| 50 | non_local = False | ||
| 51 | dense_sample = False | ||
| 52 | twice_sample = False | ||
| 53 | |||
| 54 | val_list = gen_file_list(frame_save_dir, frame_list_dir) | ||
| 55 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
| 56 | |||
| 57 | pretrain = 'imagenet' | ||
| 58 | shift_place = 'blockres' | ||
| 59 | crop_fusion_type = 'avg' | ||
| 60 | |||
| 61 | net = TSN(num_class, test_segment if is_shift else 1, modality, | ||
| 62 | base_model=arch, | ||
| 63 | consensus_type=crop_fusion_type, | ||
| 64 | img_feature_dim=img_feature_dim, | ||
| 65 | pretrain=pretrain, | ||
| 66 | is_shift=is_shift, shift_div=shift_div, shift_place=shift_place, | ||
| 67 | non_local=non_local, | ||
| 68 | ) | ||
| 69 | |||
| 70 | checkpoint = torch.load(weights_path) | ||
| 71 | checkpoint = checkpoint['state_dict'] | ||
| 72 | |||
| 73 | base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())} | ||
| 74 | replace_dict = {'base_model.classifier.weight': 'new_fc.weight', | ||
| 75 | 'base_model.classifier.bias': 'new_fc.bias', | ||
| 76 | } | ||
| 77 | for k, v in replace_dict.items(): | ||
| 78 | if k in base_dict: | ||
| 79 | base_dict[v] = base_dict.pop(k) | ||
| 80 | |||
| 81 | net.load_state_dict(base_dict) | ||
| 82 | |||
| 83 | input_size = net.scale_size if full_res else net.input_size | ||
| 84 | |||
| 85 | if test_crop == 1: | ||
| 86 | cropping = torchvision.transforms.Compose([ | ||
| 87 | GroupScale(net.scale_size), | ||
| 88 | GroupCenterCrop(input_size), | ||
| 89 | ]) | ||
| 90 | elif test_crop == 3: # do not flip, so only 3 crops | ||
| 91 | cropping = torchvision.transforms.Compose([ | ||
| 92 | GroupFullResSample(input_size, net.scale_size, flip=False) | ||
| 93 | ]) | ||
| 94 | elif test_crop == 5: # do not flip, so only 5 crops | ||
| 95 | cropping = torchvision.transforms.Compose([ | ||
| 96 | GroupOverSample(input_size, net.scale_size, flip=False) | ||
| 97 | ]) | ||
| 98 | elif test_crop == 10: | ||
| 99 | cropping = torchvision.transforms.Compose([ | ||
| 100 | GroupOverSample(input_size, net.scale_size) | ||
| 101 | ]) | ||
| 102 | else: | ||
| 103 | raise ValueError("Only 1, 3, 5, 10 crops are supported, but got {}".format(test_crop)) | ||
| 104 | |||
| 105 | data_loader = torch.utils.data.DataLoader( | ||
| 106 | TSNDataSet(frame_save_dir, val_list, num_segments=test_segment, | ||
| 107 | new_length=1 if modality == "RGB" else 5, | ||
| 108 | modality=modality, | ||
| 109 | image_tmpl=prefix, | ||
| 110 | test_mode=True, | ||
| 111 | remove_missing=False, | ||
| 112 | transform=torchvision.transforms.Compose([ | ||
| 113 | cropping, | ||
| 114 | Stack(roll=(arch in ['BNInception', 'InceptionV3'])), | ||
| 115 | ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])), | ||
| 116 | GroupNormalize(net.input_mean, net.input_std), | ||
| 117 | ]), dense_sample=dense_sample, twice_sample=twice_sample), | ||
| 118 | batch_size=batch_size, shuffle=False, | ||
| 119 | num_workers=workers, pin_memory=True, | ||
| 120 | ) | ||
| 121 | |||
| 122 | net = torch.nn.DataParallel(net.cuda()) | ||
| 123 | net.eval() | ||
| 124 | data_gen = enumerate(data_loader) | ||
| 125 | max_num = len(data_loader.dataset) | ||
| 126 | |||
| 127 | result_file = open(result_file_path, 'w') | ||
| 128 | |||
| 129 | for i, data_pair in data_gen: | ||
| 130 | directory, data = data_pair | ||
| 131 | with torch.no_grad(): | ||
| 132 | if i >= max_num: | ||
| 133 | break | ||
| 134 | num_crop = test_crop | ||
| 135 | if dense_sample: | ||
| 136 | num_crop *= 10 # 10 clips for testing when using dense sample | ||
| 137 | |||
| 138 | if twice_sample: | ||
| 139 | num_crop *= 2 | ||
| 140 | |||
| 141 | if modality == 'RGB': | ||
| 142 | length = 3 | ||
| 143 | elif modality == 'Flow': | ||
| 144 | length = 10 | ||
| 145 | elif modality == 'RGBDiff': | ||
| 146 | length = 18 | ||
| 147 | else: | ||
| 148 | raise ValueError("Unknown modality " + modality) | ||
| 149 | |||
| 150 | data_in = data.view(-1, length, data.size(2), data.size(3)) | ||
| 151 | if is_shift: | ||
| 152 | data_in = data_in.view(batch_size * num_crop, test_segment, length, data_in.size(2), data_in.size(3)) | ||
| 153 | rst, feature = net(data_in) | ||
| 154 | rst = rst.reshape(batch_size, num_crop, -1).mean(1) | ||
| 155 | |||
| 156 | if softmax: | ||
| 157 | # take the softmax to normalize the output to probability | ||
| 158 | rst = F.softmax(rst, dim=1) | ||
| 159 | |||
| 160 | rst = rst.data.cpu().numpy().copy() | ||
| 161 | |||
| 162 | if net.module.is_shift: | ||
| 163 | rst = rst.reshape(batch_size, num_class) | ||
| 164 | else: | ||
| 165 | rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class)) | ||
| 166 | |||
| 167 | proba = np.squeeze(rst) | ||
| 168 | proba = np.exp(proba - proba.max()) / np.exp(proba - proba.max()).sum() # numerically stable softmax | ||
| 169 | result_file.write(str(directory[0]) + ' ') | ||
| 170 | result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n') | ||
| 171 | |||
| 172 | result_file.close() | ||
| 173 | print('pose filter end') | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
test.py
0 → 100644
| 1 | import os | ||
| 2 | import cv2 | ||
| 3 | import load_util | ||
| 4 | import media_util | ||
| 5 | import numpy as np | ||
| 6 | from sklearn.metrics import confusion_matrix | ||
| 7 | import fighting_filter, fighting_2_filter, emotion_filter, argue_filter, audio_filter, class_filter | ||
| 8 | import video_filter, pose_filter, flow_filter | ||
| 9 | |||
| 10 | |||
| 11 | |||
| 12 | def accuracy_cal(config): | ||
| 13 | |||
| 14 | label_file_path = config['VIDEO']['LABEL_PATH'] | ||
| 15 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
| 16 | final_file_name = config['AUDIO']['RESULT_FILE'] | ||
| 17 | |||
| 18 | final_file_path = os.path.join(frame_list_dir, final_file_name) | ||
| 19 | final_file_lines = open(final_file_path).readlines() | ||
| 20 | label_file_lines = open(label_file_path).readlines() | ||
| 21 | |||
| 22 | |||
| 23 | final_pairs = {line.strip().split(' ')[0]: line.strip().split(' ')[1] for line in final_file_lines} | ||
| 24 | |||
| 25 | lines_num = len(label_file_lines) - 1 | ||
| 26 | hit = 0 | ||
| 27 | for i, label_line in enumerate(label_file_lines): | ||
| 28 | if i == 0: | ||
| 29 | continue | ||
| 30 | file, label = label_line.strip().split(' ') | ||
| 31 | final_pre = final_pairs[file] | ||
| 32 | final_pre_class = np.argmax(np.array(final_pre.split(','), dtype=np.float32)) + 1 # compare probabilities as floats, not strings | ||
| 33 | print(final_pre_class, label) | ||
| 34 | if final_pre_class == int(label): | ||
| 35 | hit += 1 | ||
| 36 | |||
| 37 | return hit/lines_num | ||
| 38 | |||
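# accuracy_cal assumes two whitespace-separated text formats: the label file starts with a header line and
# then holds '<clip_id> <label>' rows with labels in {1, 2, 3}, while each result line is
# '<clip_id> <p1>,<p2>,<p3>' as written by the filters above. Illustrative contents (all values are placeholders):
#
#   labels.txt (first line is a header and is skipped):
#     file label
#     clip_0001 2
#     clip_0002 1
#
#   audio result file:
#     clip_0001 0.10,0.70,0.20
#     clip_0002 0.85,0.05,0.10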
| 39 | |||
| 40 | def main(): | ||
| 41 | config_path = r'config.yaml' | ||
| 42 | config = load_util.load_config(config_path) | ||
| 43 | |||
| 44 | media_util.extract_wav(config) | ||
| 45 | media_util.extract_frame(config) | ||
| 46 | media_util.extract_frame_pose(config) | ||
| 47 | media_util.extract_is10(config) | ||
| 48 | media_util.extract_random_face_feature(config) | ||
| 49 | media_util.extract_mirror(config) | ||
| 50 | |||
| 51 | fighting_2_filter.start_filter(config) | ||
| 52 | emotion_filter.start_filter(config) | ||
| 53 | |||
| 54 | audio_filter.start_filter(config) | ||
| 55 | |||
| 56 | class_filter.start_filter(config) | ||
| 57 | video_filter.start_filter(config) | ||
| 58 | pose_filter.start_filter(config) | ||
| 59 | |||
| 60 | flow_filter.start_filter(config) | ||
| 61 | |||
| 62 | acc = accuracy_cal(config) | ||
| 63 | print(acc) | ||
| 64 | |||
| 65 | |||
| 66 | if __name__ == '__main__': | ||
| 67 | main() | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |
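# The config keys accessed across these filters imply a nested config.yaml roughly like the mapping below.
# This is a hedged, partial sketch: every value is a placeholder, and only keys used in this commit are listed.
config_sketch = {
    'MODEL': {'CLS_AUDIO': 'models/audio_svm.pkl', 'CLS_PERSON': 'models/person_xgb.pkl',
              'CLS_POSE': 'models/pose_tsm.pth', 'CLS_VIDEO': 'models/video_tsm.pth'},
    'VIDEO': {'FRAME_LIST_DIR': 'lists/', 'FRAME_SAVE_DIR': 'frames/',
              'POSE_FRAME_SAVE_DIR': 'pose_frames/', 'FACE_FEATURE_DIR': 'features/face/',
              'IS10_FEATURE_NP_DIR': 'features/is10/', 'LABEL_PATH': 'labels.txt',
              'PREFIX': 'img_{:05d}.jpg'},
    'AUDIO': {'RESULT_FILE': 'audio_result.txt', 'DATA_NAME': 'is10_val.npy'},
    'PERSON': {'RESULT_FILE': 'person_result.txt', 'DATA_NAME': 'face_val.npy'},
    'POSE': {'MODALITY': 'RGB', 'TEST_CROP': 1, 'TEST_SEGMENT': 8, 'BATCH_SIZE': 1,
             'RESULT_FILE': 'pose_result.txt'},
    'VIDEO_FILTER': {'MODALITY': 'RGB', 'TEST_CROP': 1, 'TEST_SEGMENT': 8, 'BATCH_SIZE': 1,
                     'RESULT_FILE': 'video_result.txt'},
    'FIGHTING': {'ARCH': 'resnet50'},
}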
video_filter.py
0 → 100644
| 1 | import os | ||
| 2 | import torch.optim | ||
| 3 | import numpy as np | ||
| 4 | import torch.nn.parallel | ||
| 5 | from ops.models import TSN | ||
| 6 | from ops.transforms import * | ||
| 7 | from ops.dataset import TSNDataSet | ||
| 8 | from torch.nn import functional as F | ||
| 9 | |||
| 10 | |||
| 11 | def gen_file_list(frame_save_dir, frame_list_dir): | ||
| 12 | |||
| 13 | val_path = os.path.join(frame_list_dir, 'val.txt') | ||
| 14 | video_names = os.listdir(frame_save_dir) | ||
| 15 | ucf101_rgb_val_file = open(val_path, 'w') | ||
| 16 | |||
| 17 | for video_name in video_names: | ||
| 18 | images_dir = os.path.join(frame_save_dir, video_name) | ||
| 19 | ucf101_rgb_val_file.write(video_name) | ||
| 20 | ucf101_rgb_val_file.write(' ') | ||
| 21 | ucf101_rgb_val_file.write(str(len(os.listdir(images_dir)))) | ||
| 22 | ucf101_rgb_val_file.write('\n') | ||
| 23 | |||
| 24 | ucf101_rgb_val_file.close() | ||
| 25 | |||
| 26 | return val_path | ||
| 27 | |||
| 28 | |||
| 29 | def start_filter(config): | ||
| 30 | arch = config['FIGHTING']['ARCH'] | ||
| 31 | prefix = config['VIDEO']['PREFIX'] | ||
| 32 | modality = config['VIDEO_FILTER']['MODALITY'] | ||
| 33 | test_crop = config['VIDEO_FILTER']['TEST_CROP'] | ||
| 34 | batch_size = config['VIDEO_FILTER']['BATCH_SIZE'] | ||
| 35 | weights_path = config['MODEL']['CLS_VIDEO'] | ||
| 36 | test_segment = config['VIDEO_FILTER']['TEST_SEGMENT'] | ||
| 37 | frame_save_dir = config['VIDEO']['FRAME_SAVE_DIR'] | ||
| 38 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
| 39 | result_file_name = config['VIDEO_FILTER']['RESULT_FILE'] | ||
| 40 | |||
| 41 | workers = 8 | ||
| 42 | num_class = 3 | ||
| 43 | shift_div = 8 | ||
| 44 | img_feature_dim = 256 | ||
| 45 | |||
| 46 | softmax = False | ||
| 47 | is_shift = True | ||
| 48 | full_res = False | ||
| 49 | non_local = False | ||
| 50 | dense_sample = False | ||
| 51 | twice_sample = False | ||
| 52 | |||
| 53 | val_list = gen_file_list(frame_save_dir, frame_list_dir) | ||
| 54 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
| 55 | |||
| 56 | pretrain = 'imagenet' | ||
| 57 | shift_place = 'blockres' | ||
| 58 | crop_fusion_type = 'avg' | ||
| 59 | |||
| 60 | net = TSN(num_class, test_segment if is_shift else 1, modality, | ||
| 61 | base_model=arch, | ||
| 62 | consensus_type=crop_fusion_type, | ||
| 63 | img_feature_dim=img_feature_dim, | ||
| 64 | pretrain=pretrain, | ||
| 65 | is_shift=is_shift, shift_div=shift_div, shift_place=shift_place, | ||
| 66 | non_local=non_local, | ||
| 67 | ) | ||
| 68 | |||
| 69 | checkpoint = torch.load(weights_path) | ||
| 70 | checkpoint = checkpoint['state_dict'] | ||
| 71 | |||
| 72 | base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())} | ||
| 73 | replace_dict = {'base_model.classifier.weight': 'new_fc.weight', | ||
| 74 | 'base_model.classifier.bias': 'new_fc.bias', | ||
| 75 | } | ||
| 76 | for k, v in replace_dict.items(): | ||
| 77 | if k in base_dict: | ||
| 78 | base_dict[v] = base_dict.pop(k) | ||
| 79 | |||
| 80 | net.load_state_dict(base_dict) | ||
| 81 | |||
| 82 | input_size = net.scale_size if full_res else net.input_size | ||
| 83 | |||
| 84 | if test_crop == 1: | ||
| 85 | cropping = torchvision.transforms.Compose([ | ||
| 86 | GroupScale(net.scale_size), | ||
| 87 | GroupCenterCrop(input_size), | ||
| 88 | ]) | ||
| 89 | elif test_crop == 3: # do not flip, so only 3 crops | ||
| 90 | cropping = torchvision.transforms.Compose([ | ||
| 91 | GroupFullResSample(input_size, net.scale_size, flip=False) | ||
| 92 | ]) | ||
| 93 | elif test_crop == 5: # do not flip, so only 5 crops | ||
| 94 | cropping = torchvision.transforms.Compose([ | ||
| 95 | GroupOverSample(input_size, net.scale_size, flip=False) | ||
| 96 | ]) | ||
| 97 | elif test_crop == 10: | ||
| 98 | cropping = torchvision.transforms.Compose([ | ||
| 99 | GroupOverSample(input_size, net.scale_size) | ||
| 100 | ]) | ||
| 101 | else: | ||
| 102 | raise ValueError("Only 1, 3, 5, 10 crops are supported, but got {}".format(test_crop)) | ||
| 103 | |||
| 104 | data_loader = torch.utils.data.DataLoader( | ||
| 105 | TSNDataSet(frame_save_dir, val_list, num_segments=test_segment, | ||
| 106 | new_length=1 if modality == "RGB" else 5, | ||
| 107 | modality=modality, | ||
| 108 | image_tmpl=prefix, | ||
| 109 | test_mode=True, | ||
| 110 | remove_missing=False, | ||
| 111 | transform=torchvision.transforms.Compose([ | ||
| 112 | cropping, | ||
| 113 | Stack(roll=(arch in ['BNInception', 'InceptionV3'])), | ||
| 114 | ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])), | ||
| 115 | GroupNormalize(net.input_mean, net.input_std), | ||
| 116 | ]), dense_sample=dense_sample, twice_sample=twice_sample), | ||
| 117 | batch_size=batch_size, shuffle=False, | ||
| 118 | num_workers=workers, pin_memory=True, | ||
| 119 | ) | ||
| 120 | |||
| 121 | net = torch.nn.DataParallel(net.cuda()) | ||
| 122 | net.eval() | ||
| 123 | data_gen = enumerate(data_loader) | ||
| 124 | max_num = len(data_loader.dataset) | ||
| 125 | |||
| 126 | result_file = open(result_file_path, 'w') | ||
| 127 | |||
| 128 | for i, data_pair in data_gen: | ||
| 129 | directory, data = data_pair | ||
| 130 | with torch.no_grad(): | ||
| 131 | if i >= max_num: | ||
| 132 | break | ||
| 133 | num_crop = test_crop | ||
| 134 | if dense_sample: | ||
| 135 | num_crop *= 10 # 10 clips for testing when using dense sample | ||
| 136 | |||
| 137 | if twice_sample: | ||
| 138 | num_crop *= 2 | ||
| 139 | |||
| 140 | if modality == 'RGB': | ||
| 141 | length = 3 | ||
| 142 | elif modality == 'Flow': | ||
| 143 | length = 10 | ||
| 144 | elif modality == 'RGBDiff': | ||
| 145 | length = 18 | ||
| 146 | else: | ||
| 147 | raise ValueError("Unknown modality " + modality) | ||
| 148 | |||
| 149 | data_in = data.view(-1, length, data.size(2), data.size(3)) | ||
| 150 | if is_shift: | ||
| 151 | data_in = data_in.view(batch_size * num_crop, test_segment, length, data_in.size(2), data_in.size(3)) | ||
| 152 | |||
| 153 | rst, feature = net(data_in) | ||
| 154 | rst = rst.reshape(batch_size, num_crop, -1).mean(1) | ||
| 155 | |||
| 156 | if softmax: | ||
| 157 | # take the softmax to normalize the output to probability | ||
| 158 | rst = F.softmax(rst, dim=1) | ||
| 159 | |||
| 160 | rst = rst.data.cpu().numpy().copy() | ||
| 161 | |||
| 162 | if net.module.is_shift: | ||
| 163 | rst = rst.reshape(batch_size, num_class) | ||
| 164 | else: | ||
| 165 | rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class)) | ||
| 166 | |||
| 167 | proba = np.squeeze(rst) | ||
| 168 | proba = np.exp(proba - proba.max()) / np.exp(proba - proba.max()).sum() # numerically stable softmax | ||
| 169 | result_file.write(str(directory[0]) + ' ') | ||
| 170 | result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n') | ||
| 171 | |||
| 172 | result_file.close() | ||
| 173 | print('video filter end') | ||
| ... | \ No newline at end of file | ... | \ No newline at end of file |