first commit
Showing 46 changed files with 2001 additions and 0 deletions
__init__.py
0 → 100644
File mode changed
__pycache__/argue_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/audio_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/bg_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/class_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/emotion_1_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/emotion_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/fighting_2_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/fighting_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/flow_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/load_util.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/media_util.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/meeting_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/person_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/pose_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/troops_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/video_1_filter.cpython-36.pyc
0 → 100644
No preview for this file type
__pycache__/video_filter.cpython-36.pyc
0 → 100644
No preview for this file type
audio_filter.py
0 → 100644
1 | import os | ||
2 | import csv | ||
3 | import pickle | ||
4 | import numpy as np | ||
5 | from sklearn.externals import joblib | ||
6 | |||
7 | |||
8 | def start_filter(config): | ||
9 | cls_audio_path = config['MODEL']['CLS_AUDIO'] | ||
10 | feature_save_dir = config['VIDEO']['IS10_FEATURE_NP_DIR'] | ||
11 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
12 | result_file_name = config['AUDIO']['RESULT_FILE'] | ||
13 | feature_name = config['AUDIO']['DATA_NAME'] | ||
14 | |||
15 | svm_clf = joblib.load(cls_audio_path) | ||
16 | |||
17 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
18 | result_file = open(result_file_path, 'w') | ||
19 | |||
20 | feature_path = os.path.join(feature_save_dir, feature_name) | ||
21 | val_annotation_pairs = np.load(feature_path, allow_pickle=True, encoding='latin1') | ||
22 | |||
23 | for pair in val_annotation_pairs: | ||
24 | |||
25 | v = pair[0] | ||
26 | n = pair[2] | ||
27 | |||
28 | feature_np = np.reshape(v, (1, -1)) | ||
29 | res = svm_clf.predict_proba(feature_np) | ||
30 | proba = np.squeeze(res) | ||
31 | |||
32 | # class_pre = svm_clf.predict(feature_np) | ||
33 | |||
34 | result_file.write(str(pair[2])[:-4] + ' ') | ||
35 | result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n') | ||
36 | |||
37 | result_file.close() | ||
38 | |||
39 | |||
40 | |||
41 | |||
42 | |||
43 | def start_filter_xgboost(config): | ||
44 | cls_class_path = config['MODEL']['CLS_AUDIO'] | ||
45 | feature_save_dir = config['VIDEO']['IS10_FEATURE_NP_DIR'] | ||
46 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
47 | result_file_name = config['AUDIO']['RESULT_FILE'] | ||
48 | feature_name = config['AUDIO']['DATA_NAME'] | ||
49 | |||
50 | xgboost_model = pickle.load(open(cls_class_path, "rb")) | ||
51 | |||
52 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
53 | result_file = open(result_file_path, 'w') | ||
54 | |||
55 | feature_path = os.path.join(feature_save_dir, feature_name) | ||
56 | val_annotation_pairs = np.load(feature_path, allow_pickle=True, encoding='latin1') | ||
57 | |||
58 | X_val = [] | ||
59 | Y_names = [] | ||
60 | for pair in val_annotation_pairs: | ||
61 | v, n = pair[0], pair[2]  # feature vector and clip name (same layout as in start_filter above) | ||
62 | X_val.append(v) | ||
63 | Y_names.append(n) | ||
64 | |||
65 | X_val = np.array(X_val) | ||
66 | y_pred = xgboost_model.predict_proba(X_val) | ||
67 | |||
68 | for i, Y_name in enumerate(Y_names): | ||
69 | result_file.write(Y_name + ' ') | ||
70 | result_file.write(str(y_pred[i][0]) + ',' + str(y_pred[i][1]) + ',' + str(y_pred[i][2]) + '\n') | ||
71 | |||
72 | result_file.close() | ||
73 |
bg_filter.py
0 → 100644
1 | import os | ||
2 | import cv2 | ||
3 | import numpy as np | ||
4 | import pickle | ||
5 | |||
6 | def start_filter(config): | ||
7 | cls_class_path = config['MODEL']['CLS_BG'] | ||
8 | feature_save_dir = config['VIDEO']['FACE_FEATURE_DIR'] | ||
9 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
10 | result_file_name = config['BG']['RESULT_FILE'] | ||
11 | feature_name = config['BG']['DATA_NAME'] | ||
12 | |||
13 | xgboost_model = pickle.load(open(cls_class_path, "rb")) | ||
14 | |||
15 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
16 | result_file = open(result_file_path, 'w') | ||
17 | |||
18 | feature_path = os.path.join(feature_save_dir, feature_name) | ||
19 | val_annotation_pairs = np.load(feature_path, allow_pickle=True) | ||
20 | |||
21 | X_val = [] | ||
22 | Y_val = [] | ||
23 | Y_names = [] | ||
24 | for j in range(len(val_annotation_pairs)): | ||
25 | pair = val_annotation_pairs[j] | ||
26 | X_val.append(np.squeeze(pair[0])) | ||
27 | Y_val.append(pair[1]) | ||
28 | Y_names.append(pair[2]) | ||
29 | |||
30 | X_val = np.array(X_val) | ||
31 | y_pred = xgboost_model.predict_proba(X_val) | ||
32 | |||
33 | for i, Y_name in enumerate(Y_names): | ||
34 | result_file.write(Y_name + ' ') | ||
35 | result_file.write(str(y_pred[i][0]) + ',' + str(y_pred[i][1]) + ',' + str(y_pred[i][2]) + '\n') | ||
36 | |||
37 | result_file.close() | ||
38 | |||
39 | |||
40 | |||
41 | |||
42 |
class_filter.py
0 → 100644
1 | import os | ||
2 | import pickle | ||
3 | import numpy as np | ||
4 | |||
5 | |||
6 | def start_filter(config): | ||
7 | |||
8 | cls_class_path = config['MODEL']['CLS_CLASS'] | ||
9 | feature_save_dir = config['VIDEO']['CLASS_FEATURE_DIR'] | ||
10 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
11 | result_file_name = config['CLASS']['RESULT_FILE'] | ||
12 | feature_name = config['CLASS']['DATA_NAME'] | ||
13 | |||
14 | xgboost_model = pickle.load(open(cls_class_path, "rb")) | ||
15 | |||
16 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
17 | result_file = open(result_file_path, 'w') | ||
18 | |||
19 | feature_path = os.path.join(feature_save_dir, feature_name) | ||
20 | val_annotation_pairs = np.load(feature_path, allow_pickle=True) | ||
21 | |||
22 | X_val = [] | ||
23 | Y_val = [] | ||
24 | Y_names = [] | ||
25 | for j in range(len(val_annotation_pairs)): | ||
26 | pair = val_annotation_pairs[j] | ||
27 | X_val.append(pair[0]) | ||
28 | Y_val.append(pair[1]) | ||
29 | Y_names.append(pair[2]) | ||
30 | |||
31 | X_val = np.array(X_val) | ||
32 | y_pred = xgboost_model.predict(X_val) | ||
33 | |||
34 | for i, Y_name in enumerate(Y_names): | ||
35 | result_file.write(Y_name + ' ') | ||
36 | result_file.write(str(y_pred[i]) + '\n') | ||
37 | |||
38 | result_file.close() |
config.yaml
0 → 100644
1 | MODEL: | ||
2 | CLS_FIGHTING_2: '/home/jwq/models/cls_fighting_2/cls_fighting_2_v0.0.1.pth' | ||
3 | CLS_EMOTION: '/home/jwq/models/cls_emotion/v0.1.0.m' | ||
4 | FEATURE_EMOTION: '/home/jwq/models/feature_emotion/FerPlus3.h5' | ||
5 | CLS_AUDIO: '/home/jwq/models/cls_audio/v0.0.1.m' | ||
6 | CLS_CLASS: '/home/jwq/models/cls_class/v_0.0.1_xgb.pkl' | ||
7 | CLS_VIDEO: '/home/jwq/models/cls_video/v0.4.1.pth' | ||
8 | CLS_POSE: '/home/jwq/models/cls_pose/v0.0.1.pth' | ||
9 | CLS_FLOW: '/home/jwq/models/cls_flow/v0.1.1.pth' | ||
10 | CLS_BG: '/home/jwq/models/cls_bg/v0.1.1.pkl' | ||
11 | CLS_PERSON: '/home/jwq/models/cls_person/v0.1.1.pkl' | ||
12 | |||
13 | THRESHOLD: | ||
14 | FACES_THRESHOLD: 0.6 | ||
15 | |||
16 | FILTER: | ||
17 | |||
18 | |||
19 | VIDEO: | ||
20 | VIDEO_DIR: '/home/jwq/Desktop/VGAF_EmotiW/Val' | ||
21 | LABEL_PATH: '/home/jwq/Desktop/VGAF_EmotiW/Val_labels.txt' | ||
22 | VIDEO_SAVE_DIR: '/home/jwq/Desktop/tmp/video' | ||
23 | AUDIO_SAVE_DIR: '/home/jwq/npys/' | ||
24 | FRAME_SAVE_DIR: '/home/jwq/Desktop/tmp/frame' | ||
25 | # FRAME_SAVE_DIR: '/home/jwq/Desktop/VGAF_EmotiW_class/train_frame' | ||
26 | FLOW_SAVE_DIR: '/home/jwq/Desktop/tmp/flow' | ||
27 | POSE_FRAME_SAVE_DIR: '/home/jwq/Desktop/tmp/pose_frame' | ||
28 | FRAME_LIST_DIR: '/home/jwq/Desktop/tmp/file_list' | ||
29 | IS10_FEATURE_NP_DIR: '/home/jwq/npys' | ||
30 | IS10_FEATURE_CSV_DIR: '/home/jwq/Desktop/tmp/is10' | ||
31 | # FACE_FEATURE_DIR: '/home/jwq/Desktop/tmp/face_feature_retina' | ||
32 | # FACE_FEATURE_DIR: '/data2/retinaface/random_face_frame_features/' | ||
33 | FACE_FEATURE_DIR: '/data1/segment/' | ||
34 | # FACE_FEATURE_DIR: '/home/jwq/npys/' | ||
35 | FACE_IMAGE_DIR: '/data2/retinaface/train/' | ||
36 | CLASS_FEATURE_DIR: '/home/jwq/Desktop/tmp/class' | ||
37 | PREFIX: 'img_{:05d}.jpg' | ||
38 | FLOW_PREFIX: 'flow_{}_{:05d}.jpg' | ||
39 | THREAD_NUM: 10 | ||
40 | FPS: 5 | ||
41 | |||
42 | VIDEO_FILTER: | ||
43 | TEST_SEGMENT: 8 | ||
44 | TEST_CROP: 1 | ||
45 | BATCH_SIZE: 1 | ||
46 | INPUT_SIZE: 224 | ||
47 | MODALITY: 'RGB' | ||
48 | ARCH: 'resnet50' | ||
49 | RESULT_FILE: 'video_filter.txt' | ||
50 | |||
51 | VIDEO_1_FILTER: | ||
52 | TEST_SEGMENT: 8 | ||
53 | TEST_CROP: 1 | ||
54 | BATCH_SIZE: 1 | ||
55 | INPUT_SIZE: 224 | ||
56 | MODALITY: 'RGB' | ||
57 | ARCH: 'resnet34' | ||
58 | RESULT_FILE: 'video_1_filter.txt' | ||
59 | |||
60 | EMOTION: | ||
61 | INTERVAL: 1 | ||
62 | INPUT_SIZE: 224 | ||
63 | RESULT_FILE: 'emotion_filter.txt' | ||
64 | |||
65 | EMOTION_1: | ||
66 | RESULT_FILE: 'emotion_1_filter.txt' | ||
67 | DATA_NAME: 'val.npy' | ||
68 | |||
69 | ARGUE: | ||
70 | DIMENSION: 1582 | ||
71 | RESULT_FILE: 'argue_filter.txt' | ||
72 | |||
73 | FIGHTING: | ||
74 | TEST_SEGMENT: 8 | ||
75 | TEST_CROP: 1 | ||
76 | BATCH_SIZE: 1 | ||
77 | INPUT_SIZE: 224 | ||
78 | MODALITY: 'RGB' | ||
79 | ARCH: 'resnet50' | ||
80 | RESULT_FILE: 'fighting_filter.txt' | ||
81 | |||
82 | FIGHTING_2: | ||
83 | TEST_SEGMENT: 8 | ||
84 | TEST_CROP: 1 | ||
85 | BATCH_SIZE: 1 | ||
86 | INPUT_SIZE: 224 | ||
87 | MODALITY: 'RGB' | ||
88 | ARCH: 'resnet50' | ||
89 | RESULT_FILE: 'fighting_2_filter.txt' | ||
90 | |||
91 | MEETING: | ||
92 | TEST_SEGMENT: 8 | ||
93 | TEST_CROP: 1 | ||
94 | BATCH_SIZE: 1 | ||
95 | INPUT_SIZE: 224 | ||
96 | MODALITY: 'RGB' | ||
97 | ARCH: 'resnet50' | ||
98 | RESULT_FILE: 'meeting_filter.txt' | ||
99 | |||
100 | TROOPS: | ||
101 | TEST_SEGMENT: 8 | ||
102 | TEST_CROP: 1 | ||
103 | BATCH_SIZE: 1 | ||
104 | INPUT_SIZE: 224 | ||
105 | MODALITY: 'RGB' | ||
106 | ARCH: 'resnet50' | ||
107 | RESULT_FILE: 'troops_filter.txt' | ||
108 | |||
109 | FLOW: | ||
110 | TEST_SEGMENT: 8 | ||
111 | TEST_CROP: 1 | ||
112 | BATCH_SIZE: 1 | ||
113 | INPUT_SIZE: 224 | ||
114 | MODALITY: 'Flow' | ||
115 | ARCH: 'resnet50' | ||
116 | RESULT_FILE: 'flow_filter.txt' | ||
117 | |||
118 | |||
119 | FINAL: | ||
120 | RESULT_FILE: 'final.txt' | ||
121 | ERROR_FILE: 'error.txt' | ||
122 | SIM_FILE: 'image_sim.txt' | ||
123 | |||
124 | AUDIO: | ||
125 | RESULT_FILE: 'audio.txt' | ||
126 | OPENSMILE_DIR: '/home/jwq/Downloads/opensmile-2.3.0' | ||
127 | DATA_NAME: 'val.npy' | ||
128 | |||
129 | CLASS: | ||
130 | RESULT_FILE: 'class.txt' | ||
131 | DATA_NAME: 'val_reannotation.npy' | ||
132 | |||
133 | POSE: | ||
134 | TEST_SEGMENT: 8 | ||
135 | TEST_CROP: 1 | ||
136 | BATCH_SIZE: 1 | ||
137 | INPUT_SIZE: 224 | ||
138 | MODALITY: 'RGB' | ||
139 | ARCH: 'resnet50' | ||
140 | RESULT_FILE: 'pose_filter.txt' | ||
141 | |||
142 | BG: | ||
143 | RESULT_FILE: 'bg_filter.txt' | ||
144 | DATA_NAME: 'bg_val_feature.npy' | ||
145 | |||
146 | PERSON: | ||
147 | RESULT_FILE: 'person_filter.txt' | ||
148 | DATA_NAME: 'person_val_feature.npy' | ||
149 | |||
150 |
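For orientation, here is a minimal sketch of how the settings above are consumed by the filter modules in this commit (the 'config.yaml' path is illustrative, and only keys that actually appear in the YAML above are read):

import yaml

# Load the YAML shown above and pick out a few values that audio_filter.py reads.
with open('config.yaml', 'r') as cf:
    config = yaml.load(cf, Loader=yaml.FullLoader)

print(config['MODEL']['CLS_AUDIO'])       # SVM model path used by audio_filter.start_filter
print(config['AUDIO']['RESULT_FILE'])     # 'audio.txt', written under FRAME_LIST_DIR
print(config['VIDEO']['FRAME_LIST_DIR'])  # directory where all result files are collected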
emotion_filter.py
0 → 100644
1 | import os | ||
2 | import cv2 | ||
3 | import numpy as np | ||
4 | from keras.models import Model | ||
5 | from keras.models import load_model | ||
6 | from sklearn.externals import joblib | ||
7 | from tensorflow.keras.preprocessing.image import img_to_array | ||
8 | |||
9 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' | ||
10 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' | ||
11 | |||
12 | |||
13 | class FeatureExtractor(object): | ||
14 | def __init__(self, input_size=224, out_put_layer='avg_pool', model_path='FerPlus3.h5'): | ||
15 | self.model = load_model(model_path) | ||
16 | self.input_size = input_size | ||
17 | self.model_inter = Model(inputs=self.model.input, outputs=self.model.get_layer(out_put_layer).output) | ||
18 | |||
19 | def inference(self, image): | ||
20 | image = cv2.resize(image, (self.input_size, self.input_size)) | ||
21 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | ||
22 | image = image.astype("float") / 255.0 | ||
23 | image = img_to_array(image) | ||
24 | image = np.expand_dims(image, axis=0) | ||
25 | feature = self.model_inter.predict(image)[0] | ||
26 | return feature | ||
27 | |||
28 | |||
29 | def features2feature(pics_features): | ||
30 | |||
31 | pics_features = np.array(pics_features) | ||
32 | fea_mean = pics_features.mean(axis=0) | ||
33 | fea_max = np.amax(pics_features, axis=0) | ||
34 | fea_min = np.amin(pics_features, axis=0) | ||
35 | fea_std = pics_features.std(axis=0) | ||
36 | |||
37 | return np.concatenate((fea_mean, fea_max, fea_min, fea_std), axis=1).reshape(1, -1) | ||
38 | |||
39 | |||
40 | def start_filter(config): | ||
41 | |||
42 | cls_emotion_path = config['MODEL']['CLS_EMOTION'] | ||
43 | face_feature_dir = config['VIDEO']['FACE_FEATURE_DIR'] | ||
44 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
45 | result_file_name = config['EMOTION']['RESULT_FILE'] | ||
46 | |||
47 | svm_clf = joblib.load(cls_emotion_path) | ||
48 | |||
49 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
50 | result_file = open(result_file_path, 'w') | ||
51 | |||
52 | face_feature_names = os.listdir(face_feature_dir) | ||
53 | for face_feature in face_feature_names: | ||
54 | face_feature_path = os.path.join(face_feature_dir, face_feature) | ||
55 | |||
56 | features_np = np.load(face_feature_path, allow_pickle=True) | ||
57 | |||
58 | feature = features2feature(features_np) | ||
59 | res = svm_clf.predict_proba(feature) | ||
60 | proba = np.squeeze(res) | ||
61 | # class_pre = svm_clf.predict(feature) | ||
62 | |||
63 | result_file.write(face_feature[:-4] + ' ') | ||
64 | result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n') | ||
65 | |||
66 | result_file.close() | ||
67 | |||
68 | |||
69 | |||
70 | |||
71 |
fighting_2_filter.py
0 → 100644
1 | import os | ||
2 | import torch.optim | ||
3 | import numpy as np | ||
4 | import torch.optim | ||
5 | import torch.nn.parallel | ||
6 | from ops.models import TSN | ||
7 | from ops.transforms import * | ||
8 | from ops.dataset import TSNDataSet | ||
9 | from torch.nn import functional as F | ||
10 | |||
11 | |||
12 | def gen_file_list(frame_save_dir, frame_list_dir): | ||
13 | |||
14 | val_path = os.path.join(frame_list_dir, 'val.txt') | ||
15 | video_names = os.listdir(frame_save_dir) | ||
16 | ucf101_rgb_val_file = open(val_path, 'w') | ||
17 | |||
18 | for video_name in video_names: | ||
19 | images_dir = os.path.join(frame_save_dir, video_name) | ||
20 | ucf101_rgb_val_file.write(video_name) | ||
21 | ucf101_rgb_val_file.write(' ') | ||
22 | ucf101_rgb_val_file.write(str(len(os.listdir(images_dir)))) | ||
23 | ucf101_rgb_val_file.write('\n') | ||
24 | |||
25 | ucf101_rgb_val_file.close() | ||
26 | |||
27 | return val_path | ||
28 | |||
29 | |||
30 | def start_filter(config): | ||
31 | arch = config['FIGHTING_2']['ARCH'] | ||
32 | prefix = config['VIDEO']['PREFIX'] | ||
33 | modality = config['FIGHTING_2']['MODALITY'] | ||
34 | test_crop = config['FIGHTING_2']['TEST_CROP'] | ||
35 | batch_size = config['FIGHTING_2']['BATCH_SIZE'] | ||
36 | weights_path = config['MODEL']['CLS_FIGHTING_2'] | ||
37 | test_segment = config['FIGHTING_2']['TEST_SEGMENT'] | ||
38 | frame_save_dir = config['VIDEO']['FRAME_SAVE_DIR'] | ||
39 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
40 | result_file_name = config['FIGHTING_2']['RESULT_FILE'] | ||
41 | |||
42 | workers = 8 | ||
43 | num_class = 2 | ||
44 | shift_div = 8 | ||
45 | img_feature_dim = 256 | ||
46 | |||
47 | softmax = False | ||
48 | is_shift = True | ||
49 | full_res = False | ||
50 | non_local = False | ||
51 | dense_sample = False | ||
52 | twice_sample = False | ||
53 | |||
54 | val_list = gen_file_list(frame_save_dir, frame_list_dir) | ||
55 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
56 | |||
57 | pretrain = 'imagenet' | ||
58 | shift_place = 'blockres' | ||
59 | crop_fusion_type = 'avg' | ||
60 | |||
61 | net = TSN(num_class, test_segment if is_shift else 1, modality, | ||
62 | base_model=arch, | ||
63 | consensus_type=crop_fusion_type, | ||
64 | img_feature_dim=img_feature_dim, | ||
65 | pretrain=pretrain, | ||
66 | is_shift=is_shift, shift_div=shift_div, shift_place=shift_place, | ||
67 | non_local=non_local, | ||
68 | ) | ||
69 | |||
70 | checkpoint = torch.load(weights_path) | ||
71 | checkpoint = checkpoint['state_dict'] | ||
72 | |||
73 | base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())} | ||
74 | replace_dict = {'base_model.classifier.weight': 'new_fc.weight', | ||
75 | 'base_model.classifier.bias': 'new_fc.bias', | ||
76 | } | ||
77 | for k, v in replace_dict.items(): | ||
78 | if k in base_dict: | ||
79 | base_dict[v] = base_dict.pop(k) | ||
80 | |||
81 | net.load_state_dict(base_dict) | ||
82 | |||
83 | input_size = net.scale_size if full_res else net.input_size | ||
84 | |||
85 | if test_crop == 1: | ||
86 | cropping = torchvision.transforms.Compose([ | ||
87 | GroupScale(net.scale_size), | ||
88 | GroupCenterCrop(input_size), | ||
89 | ]) | ||
90 | elif test_crop == 3: # do not flip, so only 3 crops | ||
91 | cropping = torchvision.transforms.Compose([ | ||
92 | GroupFullResSample(input_size, net.scale_size, flip=False) | ||
93 | ]) | ||
94 | elif test_crop == 5: # do not flip, so only 5 crops | ||
95 | cropping = torchvision.transforms.Compose([ | ||
96 | GroupOverSample(input_size, net.scale_size, flip=False) | ||
97 | ]) | ||
98 | elif test_crop == 10: | ||
99 | cropping = torchvision.transforms.Compose([ | ||
100 | GroupOverSample(input_size, net.scale_size) | ||
101 | ]) | ||
102 | else: | ||
103 | raise ValueError("Only 1, 3, 5, 10 crops are supported, but got {}".format(test_crop)) | ||
104 | |||
105 | data_loader = torch.utils.data.DataLoader( | ||
106 | TSNDataSet(frame_save_dir, val_list, num_segments=test_segment, | ||
107 | new_length=1 if modality == "RGB" else 5, | ||
108 | modality=modality, | ||
109 | image_tmpl=prefix, | ||
110 | test_mode=True, | ||
111 | remove_missing=False, | ||
112 | transform=torchvision.transforms.Compose([ | ||
113 | cropping, | ||
114 | Stack(roll=(arch in ['BNInception', 'InceptionV3'])), | ||
115 | ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])), | ||
116 | GroupNormalize(net.input_mean, net.input_std), | ||
117 | ]), dense_sample=dense_sample, twice_sample=twice_sample), | ||
118 | batch_size=batch_size, shuffle=False, | ||
119 | num_workers=workers, pin_memory=True, | ||
120 | ) | ||
121 | |||
122 | net = torch.nn.DataParallel(net.cuda()) | ||
123 | net.eval() | ||
124 | data_gen = enumerate(data_loader) | ||
125 | max_num = len(data_loader.dataset) | ||
126 | |||
127 | result_file = open(result_file_path, 'w') | ||
128 | |||
129 | for i, data_pair in data_gen: | ||
130 | directory, data = data_pair | ||
131 | with torch.no_grad(): | ||
132 | if i >= max_num: | ||
133 | break | ||
134 | num_crop = test_crop | ||
135 | if dense_sample: | ||
136 | num_crop *= 10 # 10 clips for testing when using dense sample | ||
137 | |||
138 | if twice_sample: | ||
139 | num_crop *= 2 | ||
140 | |||
141 | if modality == 'RGB': | ||
142 | length = 3 | ||
143 | elif modality == 'Flow': | ||
144 | length = 10 | ||
145 | elif modality == 'RGBDiff': | ||
146 | length = 18 | ||
147 | else: | ||
148 | raise ValueError("Unknown modality " + modality) | ||
149 | |||
150 | data_in = data.view(-1, length, data.size(2), data.size(3)) | ||
151 | if is_shift: | ||
152 | data_in = data_in.view(batch_size * num_crop, test_segment, length, data_in.size(2), data_in.size(3)) | ||
153 | rst, feature = net(data_in) | ||
154 | rst = rst.reshape(batch_size, num_crop, -1).mean(1) | ||
155 | |||
156 | if softmax: | ||
157 | # take the softmax to normalize the output to probability | ||
158 | rst = F.softmax(rst, dim=1) | ||
159 | |||
160 | rst = rst.data.cpu().numpy().copy() | ||
161 | |||
162 | if net.module.is_shift: | ||
163 | rst = rst.reshape(batch_size, num_class) | ||
164 | else: | ||
165 | rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class)) | ||
166 | |||
167 | proba = np.squeeze(rst) | ||
168 | print(proba) | ||
169 | proba = np.exp(proba)/sum(np.exp(proba)) | ||
170 | result_file.write(str(directory[0]) + ' ') | ||
171 | result_file.write(str(proba[0]) + ',' + str(proba[1]) + '\n') | ||
172 | |||
173 | result_file.close() | ||
174 | print('fighting filter end') | ||
\ No newline at end of file
flow_filter.py
0 → 100644
1 | import os | ||
2 | import torch.optim | ||
3 | import numpy as np | ||
4 | import torch.optim | ||
5 | import torch.nn.parallel | ||
6 | from ops.models import TSN | ||
7 | from ops.transforms import * | ||
8 | from ops.dataset import TSNDataSet | ||
9 | from torch.nn import functional as F | ||
10 | |||
11 | |||
12 | def gen_file_list(frame_save_dir, frame_list_dir): | ||
13 | |||
14 | val_path = os.path.join(frame_list_dir, 'flow_val.txt') | ||
15 | video_names = os.listdir(frame_save_dir) | ||
16 | ucf101_rgb_val_file = open(val_path, 'w') | ||
17 | |||
18 | for video_name in video_names: | ||
19 | images_dir = os.path.join(frame_save_dir, video_name) | ||
20 | ucf101_rgb_val_file.write(video_name) | ||
21 | ucf101_rgb_val_file.write(' ') | ||
22 | ori_list = os.listdir(images_dir) | ||
23 | select_list = [element for element in ori_list if 'x' in element] | ||
24 | ucf101_rgb_val_file.write(str(len(select_list))) | ||
25 | ucf101_rgb_val_file.write('\n') | ||
26 | |||
27 | ucf101_rgb_val_file.close() | ||
28 | |||
29 | return val_path | ||
30 | |||
31 | |||
32 | def start_filter(config): | ||
33 | arch = config['FLOW']['ARCH'] | ||
34 | prefix = config['VIDEO']['FLOW_PREFIX'] | ||
35 | modality = config['FLOW']['MODALITY'] | ||
36 | test_crop = config['FLOW']['TEST_CROP'] | ||
37 | batch_size = config['FLOW']['BATCH_SIZE'] | ||
38 | weights_path = config['MODEL']['CLS_FLOW'] | ||
39 | test_segment = config['FLOW']['TEST_SEGMENT'] | ||
40 | frame_save_dir = config['VIDEO']['FLOW_SAVE_DIR'] | ||
41 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
42 | result_file_name = config['FLOW']['RESULT_FILE'] | ||
43 | |||
44 | workers = 8 | ||
45 | num_class = 3 | ||
46 | shift_div = 8 | ||
47 | img_feature_dim = 256 | ||
48 | |||
49 | softmax = False | ||
50 | is_shift = True | ||
51 | full_res = False | ||
52 | non_local = False | ||
53 | dense_sample = False | ||
54 | twice_sample = False | ||
55 | |||
56 | val_list = gen_file_list(frame_save_dir, frame_list_dir) | ||
57 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
58 | |||
59 | pretrain = 'imagenet' | ||
60 | shift_place = 'blockres' | ||
61 | crop_fusion_type = 'avg' | ||
62 | |||
63 | net = TSN(num_class, test_segment if is_shift else 1, modality, | ||
64 | base_model=arch, | ||
65 | consensus_type=crop_fusion_type, | ||
66 | img_feature_dim=img_feature_dim, | ||
67 | pretrain=pretrain, | ||
68 | is_shift=is_shift, shift_div=shift_div, shift_place=shift_place, | ||
69 | non_local=non_local, | ||
70 | ) | ||
71 | |||
72 | checkpoint = torch.load(weights_path) | ||
73 | checkpoint = checkpoint['state_dict'] | ||
74 | |||
75 | base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())} | ||
76 | replace_dict = {'base_model.classifier.weight': 'new_fc.weight', | ||
77 | 'base_model.classifier.bias': 'new_fc.bias', | ||
78 | } | ||
79 | for k, v in replace_dict.items(): | ||
80 | if k in base_dict: | ||
81 | base_dict[v] = base_dict.pop(k) | ||
82 | |||
83 | net.load_state_dict(base_dict) | ||
84 | |||
85 | input_size = net.scale_size if full_res else net.input_size | ||
86 | |||
87 | if test_crop == 1: | ||
88 | cropping = torchvision.transforms.Compose([ | ||
89 | GroupScale(net.scale_size), | ||
90 | GroupCenterCrop(input_size), | ||
91 | ]) | ||
92 | elif test_crop == 3: # do not flip, so only 3 crops | ||
93 | cropping = torchvision.transforms.Compose([ | ||
94 | GroupFullResSample(input_size, net.scale_size, flip=False) | ||
95 | ]) | ||
96 | elif test_crop == 5: # do not flip, so only 5 crops | ||
97 | cropping = torchvision.transforms.Compose([ | ||
98 | GroupOverSample(input_size, net.scale_size, flip=False) | ||
99 | ]) | ||
100 | elif test_crop == 10: | ||
101 | cropping = torchvision.transforms.Compose([ | ||
102 | GroupOverSample(input_size, net.scale_size) | ||
103 | ]) | ||
104 | else: | ||
105 | raise ValueError("Only 1, 3, 5, 10 crops are supported, but got {}".format(test_crop)) | ||
106 | |||
107 | data_loader = torch.utils.data.DataLoader( | ||
108 | TSNDataSet(frame_save_dir, val_list, num_segments=test_segment, | ||
109 | new_length=1 if modality == "RGB" else 5, | ||
110 | modality=modality, | ||
111 | image_tmpl=prefix, | ||
112 | test_mode=True, | ||
113 | remove_missing=False, | ||
114 | transform=torchvision.transforms.Compose([ | ||
115 | cropping, | ||
116 | Stack(roll=(arch in ['BNInception', 'InceptionV3'])), | ||
117 | ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])), | ||
118 | GroupNormalize(net.input_mean, net.input_std), | ||
119 | ]), dense_sample=dense_sample, twice_sample=twice_sample), | ||
120 | batch_size=batch_size, shuffle=False, | ||
121 | num_workers=workers, pin_memory=True, | ||
122 | ) | ||
123 | |||
124 | net = torch.nn.DataParallel(net.cuda()) | ||
125 | net.eval() | ||
126 | data_gen = enumerate(data_loader) | ||
127 | max_num = len(data_loader.dataset) | ||
128 | |||
129 | result_file = open(result_file_path, 'w') | ||
130 | |||
131 | for i, data_pair in data_gen: | ||
132 | directory, data = data_pair | ||
133 | with torch.no_grad(): | ||
134 | if i >= max_num: | ||
135 | break | ||
136 | num_crop = test_crop | ||
137 | if dense_sample: | ||
138 | num_crop *= 10 # 10 clips for testing when using dense sample | ||
139 | |||
140 | if twice_sample: | ||
141 | num_crop *= 2 | ||
142 | |||
143 | if modality == 'RGB': | ||
144 | length = 3 | ||
145 | elif modality == 'Flow': | ||
146 | length = 10 | ||
147 | elif modality == 'RGBDiff': | ||
148 | length = 18 | ||
149 | else: | ||
150 | raise ValueError("Unknown modality " + modality) | ||
151 | |||
152 | data_in = data.view(-1, length, data.size(2), data.size(3)) | ||
153 | if is_shift: | ||
154 | data_in = data_in.view(batch_size * num_crop, test_segment, length, data_in.size(2), data_in.size(3)) | ||
155 | rst, feature = net(data_in) | ||
156 | rst = rst.reshape(batch_size, num_crop, -1).mean(1) | ||
157 | |||
158 | if softmax: | ||
159 | # take the softmax to normalize the output to probability | ||
160 | rst = F.softmax(rst, dim=1) | ||
161 | |||
162 | rst = rst.data.cpu().numpy().copy() | ||
163 | |||
164 | if net.module.is_shift: | ||
165 | rst = rst.reshape(batch_size, num_class) | ||
166 | else: | ||
167 | rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class)) | ||
168 | |||
169 | proba = np.squeeze(rst) | ||
170 | proba = np.exp(proba)/sum(np.exp(proba)) | ||
171 | result_file.write(str(directory[0]) + ' ') | ||
172 | result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n') | ||
173 | |||
174 | result_file.close() | ||
175 | print('flow filter end') | ||
\ No newline at end of file
load_util.py
0 → 100644
1 | import os | ||
2 | import cv2 | ||
3 | import yaml | ||
4 | import tensorflow as tf | ||
5 | |||
6 | |||
7 | def load_config(config_path): | ||
8 | with open(config_path, 'r') as cf: | ||
9 | config_obj = yaml.load(cf, Loader=yaml.FullLoader) | ||
10 | print(config_obj) | ||
11 | return config_obj | ||
12 | |||
13 | |||
14 | def load_argue_model(config): | ||
15 | |||
16 | cls_argue_path = config['MODEL']['CLS_ARGUE'] | ||
17 | with tf.Graph().as_default(): | ||
18 | |||
19 | if os.path.isfile(cls_argue_path): | ||
20 | print('Model filename: %s' % cls_argue_path) | ||
21 | with tf.gfile.GFile(cls_argue_path, 'rb') as f: | ||
22 | graph_def = tf.GraphDef() | ||
23 | graph_def.ParseFromString(f.read()) | ||
24 | tf.import_graph_def(graph_def, name='') | ||
25 | |||
26 | x = tf.get_default_graph().get_tensor_by_name("x_batch:0") | ||
27 | output = tf.get_default_graph().get_tensor_by_name("output/BiasAdd:0") | ||
28 | |||
29 | sess_config = tf.ConfigProto()  # renamed so the YAML config argument is not shadowed | ||
30 | sess_config.gpu_options.allow_growth = False | ||
31 | sess = tf.Session(config=sess_config) | ||
32 | |||
33 | return x, output, sess |
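As a usage note, load_config returns the dictionary the filter modules expect; below is a hedged sketch of a driver (the module names follow the files in this commit, but this exact entry point is an assumption, not part of the repo):

# Hypothetical driver: parse the YAML config once, then run one filter end to end.
import load_util
import audio_filter

config = load_util.load_config('config.yaml')  # example path
audio_filter.start_filter(config)              # writes config['AUDIO']['RESULT_FILE'] under FRAME_LIST_DIR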
media_util.py
0 → 100644
This diff is collapsed.
ops/__init__.py
0 → 100755
1 | from ops.basic_ops import * | ||
\ No newline at end of file
ops/__pycache__/__init__.cpython-36.pyc
0 → 100644
No preview for this file type
ops/__pycache__/basic_ops.cpython-36.pyc
0 → 100644
No preview for this file type
ops/__pycache__/dataset.cpython-36.pyc
0 → 100644
No preview for this file type
ops/__pycache__/models.cpython-36.pyc
0 → 100644
No preview for this file type
ops/__pycache__/transforms.cpython-36.pyc
0 → 100644
No preview for this file type
ops/basic_ops.py
0 → 100755
1 | import torch | ||
2 | |||
3 | |||
4 | class Identity(torch.nn.Module): | ||
5 | def forward(self, input): | ||
6 | return input | ||
7 | |||
8 | |||
9 | class SegmentConsensus(torch.nn.Module): | ||
10 | |||
11 | def __init__(self, consensus_type, dim=1): | ||
12 | super(SegmentConsensus, self).__init__() | ||
13 | self.consensus_type = consensus_type | ||
14 | self.dim = dim | ||
15 | self.shape = None | ||
16 | |||
17 | def forward(self, input_tensor): | ||
18 | self.shape = input_tensor.size() | ||
19 | if self.consensus_type == 'avg': | ||
20 | output = input_tensor.mean(dim=self.dim, keepdim=True) | ||
21 | elif self.consensus_type == 'identity': | ||
22 | output = input_tensor | ||
23 | else: | ||
24 | output = None | ||
25 | |||
26 | return output | ||
27 | |||
28 | |||
29 | class ConsensusModule(torch.nn.Module): | ||
30 | |||
31 | def __init__(self, consensus_type, dim=1): | ||
32 | super(ConsensusModule, self).__init__() | ||
33 | self.consensus_type = consensus_type if consensus_type != 'rnn' else 'identity' | ||
34 | self.dim = dim | ||
35 | |||
36 | def forward(self, input): | ||
37 | return SegmentConsensus(self.consensus_type, self.dim)(input) |
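A short usage sketch for ConsensusModule (shapes and values are illustrative only): with per-segment class scores of shape (batch, segments, classes), the 'avg' consensus simply averages over the segment dimension.

import torch
from ops.basic_ops import ConsensusModule

scores = torch.randn(2, 8, 3)              # batch of 2 videos, 8 segments, 3 classes
consensus = ConsensusModule('avg', dim=1)  # 'rnn' would fall back to identity
out = consensus(scores)                    # shape (2, 1, 3): mean over the 8 segments
print(out.shape)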
ops/dataset.py
0 → 100755
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" | ||
2 | # arXiv:1811.08383 | ||
3 | # Ji Lin*, Chuang Gan, Song Han | ||
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu | ||
5 | |||
6 | import torch.utils.data as data | ||
7 | |||
8 | from PIL import Image | ||
9 | import os | ||
10 | import numpy as np | ||
11 | from numpy.random import randint | ||
12 | |||
13 | |||
14 | class VideoRecord(object): | ||
15 | def __init__(self, row): | ||
16 | self._data = row | ||
17 | |||
18 | @property | ||
19 | def path(self): | ||
20 | return self._data[0] | ||
21 | |||
22 | @property | ||
23 | def num_frames(self): | ||
24 | return int(self._data[1]) | ||
25 | |||
26 | |||
27 | class TSNDataSet(data.Dataset): | ||
28 | def __init__(self, root_path, list_file, | ||
29 | num_segments=3, new_length=1, modality='RGB', | ||
30 | image_tmpl='img_{:05d}.jpg', transform=None, | ||
31 | random_shift=True, test_mode=False, | ||
32 | remove_missing=False, dense_sample=False, twice_sample=False): | ||
33 | |||
34 | self.root_path = root_path | ||
35 | self.list_file = list_file | ||
36 | self.num_segments = num_segments | ||
37 | self.new_length = new_length | ||
38 | self.modality = modality | ||
39 | self.image_tmpl = image_tmpl | ||
40 | self.transform = transform | ||
41 | self.random_shift = random_shift | ||
42 | self.test_mode = test_mode | ||
43 | self.remove_missing = remove_missing | ||
44 | self.dense_sample = dense_sample # using dense sample as I3D | ||
45 | self.twice_sample = twice_sample # twice sample for more validation | ||
46 | if self.dense_sample: | ||
47 | print('=> Using dense sample for the dataset...') | ||
48 | if self.twice_sample: | ||
49 | print('=> Using twice sample for the dataset...') | ||
50 | |||
51 | if self.modality == 'RGBDiff': | ||
52 | self.new_length += 1 # Diff needs one more image to calculate diff | ||
53 | |||
54 | self._parse_list() | ||
55 | |||
56 | def _load_image(self, directory, idx): | ||
57 | if self.modality == 'RGB' or self.modality == 'RGBDiff': | ||
58 | try: | ||
59 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')] | ||
60 | except Exception: | ||
61 | print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx))) | ||
62 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')] | ||
63 | elif self.modality == 'Flow': | ||
64 | if self.image_tmpl == 'flow_{}_{:05d}.jpg': # ucf | ||
65 | x_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('x', idx))).convert( | ||
66 | 'L') | ||
67 | y_img = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format('y', idx))).convert( | ||
68 | 'L') | ||
69 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg': # something v1 flow | ||
70 | x_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)), self.image_tmpl. | ||
71 | format(int(directory), 'x', idx))).convert('L') | ||
72 | y_img = Image.open(os.path.join(self.root_path, '{:06d}'.format(int(directory)), self.image_tmpl. | ||
73 | format(int(directory), 'y', idx))).convert('L') | ||
74 | else: | ||
75 | try: | ||
76 | # idx_skip = 1 + (idx-1)*5 | ||
77 | flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert( | ||
78 | 'RGB') | ||
79 | except Exception: | ||
80 | print('error loading flow file:', | ||
81 | os.path.join(self.root_path, directory, self.image_tmpl.format(idx))) | ||
82 | flow = Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB') | ||
83 | # the input flow file is RGB image with (flow_x, flow_y, blank) for each channel | ||
84 | flow_x, flow_y, _ = flow.split() | ||
85 | x_img = flow_x.convert('L') | ||
86 | y_img = flow_y.convert('L') | ||
87 | |||
88 | return [x_img, y_img] | ||
89 | |||
90 | def _parse_list(self): | ||
91 | # check the frame number is large >3: | ||
92 | tmp = [x.strip().split(' ') for x in open(self.list_file)] | ||
93 | if not self.test_mode or self.remove_missing: | ||
94 | tmp = [item for item in tmp if int(item[1]) >= 3] | ||
95 | self.video_list = [VideoRecord(item) for item in tmp] | ||
96 | |||
97 | if self.image_tmpl == '{:06d}-{}_{:05d}.jpg': | ||
98 | for v in self.video_list: | ||
99 | v._data[1] = int(v._data[1]) / 2 | ||
100 | print('video number:%d' % (len(self.video_list))) | ||
101 | |||
102 | def _sample_indices(self, record): | ||
103 | """ | ||
104 | |||
105 | :param record: VideoRecord | ||
106 | :return: list | ||
107 | """ | ||
108 | if self.dense_sample: # i3d dense sample | ||
109 | sample_pos = max(1, 1 + record.num_frames - 64) | ||
110 | t_stride = 64 // self.num_segments | ||
111 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) | ||
112 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)] | ||
113 | return np.array(offsets) + 1 | ||
114 | else: # normal sample | ||
115 | average_duration = (record.num_frames - self.new_length + 1) // self.num_segments | ||
116 | if average_duration > 0: | ||
117 | offsets = np.multiply(list(range(self.num_segments)), average_duration) + randint(average_duration, | ||
118 | size=self.num_segments) | ||
119 | elif record.num_frames > self.num_segments: | ||
120 | offsets = np.sort(randint(record.num_frames - self.new_length + 1, size=self.num_segments)) | ||
121 | else: | ||
122 | offsets = np.zeros((self.num_segments,)) | ||
123 | return offsets + 1 | ||
124 | |||
125 | def _get_val_indices(self, record): | ||
126 | if self.dense_sample: # i3d dense sample | ||
127 | sample_pos = max(1, 1 + record.num_frames - 64) | ||
128 | t_stride = 64 // self.num_segments | ||
129 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) | ||
130 | offsets = [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)] | ||
131 | return np.array(offsets) + 1 | ||
132 | else: | ||
133 | if record.num_frames > self.num_segments + self.new_length - 1: | ||
134 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) | ||
135 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) | ||
136 | else: | ||
137 | offsets = np.zeros((self.num_segments,)) | ||
138 | return offsets + 1 | ||
139 | |||
140 | def _get_test_indices(self, record): | ||
141 | if self.dense_sample: | ||
142 | sample_pos = max(1, 1 + record.num_frames - 64) | ||
143 | t_stride = 64 // self.num_segments | ||
144 | start_list = np.linspace(0, sample_pos - 1, num=10, dtype=int) | ||
145 | offsets = [] | ||
146 | for start_idx in start_list.tolist(): | ||
147 | offsets += [(idx * t_stride + start_idx) % record.num_frames for idx in range(self.num_segments)] | ||
148 | return np.array(offsets) + 1 | ||
149 | elif self.twice_sample: | ||
150 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) | ||
151 | |||
152 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)] + | ||
153 | [int(tick * x) for x in range(self.num_segments)]) | ||
154 | |||
155 | return offsets + 1 | ||
156 | else: | ||
157 | tick = (record.num_frames - self.new_length + 1) / float(self.num_segments) | ||
158 | offsets = np.array([int(tick / 2.0 + tick * x) for x in range(self.num_segments)]) | ||
159 | return offsets + 1 | ||
160 | |||
161 | def __getitem__(self, index): | ||
162 | record = self.video_list[index] | ||
163 | # check this is a legit video folder | ||
164 | |||
165 | if self.image_tmpl == 'flow_{}_{:05d}.jpg': | ||
166 | file_name = self.image_tmpl.format('x', 1) | ||
167 | full_path = os.path.join(self.root_path, record.path, file_name) | ||
168 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg': | ||
169 | file_name = self.image_tmpl.format(int(record.path), 'x', 1) | ||
170 | full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name) | ||
171 | else: | ||
172 | file_name = self.image_tmpl.format(1) | ||
173 | full_path = os.path.join(self.root_path, record.path, file_name) | ||
174 | |||
175 | while not os.path.exists(full_path): | ||
176 | print('################## Not Found:', os.path.join(self.root_path, record.path, file_name)) | ||
177 | index = np.random.randint(len(self.video_list)) | ||
178 | record = self.video_list[index] | ||
179 | if self.image_tmpl == 'flow_{}_{:05d}.jpg': | ||
180 | file_name = self.image_tmpl.format('x', 1) | ||
181 | full_path = os.path.join(self.root_path, record.path, file_name) | ||
182 | elif self.image_tmpl == '{:06d}-{}_{:05d}.jpg': | ||
183 | file_name = self.image_tmpl.format(int(record.path), 'x', 1) | ||
184 | full_path = os.path.join(self.root_path, '{:06d}'.format(int(record.path)), file_name) | ||
185 | else: | ||
186 | file_name = self.image_tmpl.format(1) | ||
187 | full_path = os.path.join(self.root_path, record.path, file_name) | ||
188 | |||
189 | if not self.test_mode: | ||
190 | segment_indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record) | ||
191 | else: | ||
192 | segment_indices = self._get_test_indices(record) | ||
193 | return self.get(record, segment_indices) | ||
194 | |||
195 | def get(self, record, indices): | ||
196 | |||
197 | images = list() | ||
198 | for seg_ind in indices: | ||
199 | p = int(seg_ind) | ||
200 | for i in range(self.new_length): | ||
201 | seg_imgs = self._load_image(record.path, p) | ||
202 | images.extend(seg_imgs) | ||
203 | if p < record.num_frames: | ||
204 | p += 1 | ||
205 | |||
206 | process_data = self.transform(images) | ||
207 | return record.path, process_data | ||
208 | |||
209 | def __len__(self): | ||
210 | return len(self.video_list) |
ops/dataset_config.py
0 → 100755
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" | ||
2 | # arXiv:1811.08383 | ||
3 | # Ji Lin*, Chuang Gan, Song Han | ||
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu | ||
5 | |||
6 | import os | ||
7 | |||
8 | ROOT_DATASET = '/data1/action_1_images/' # '/data/jilin/' | ||
9 | |||
10 | |||
11 | def return_ucf101(modality): | ||
12 | filename_categories = 'labels/classInd.txt' | ||
13 | if modality == 'RGB': | ||
14 | root_data = ROOT_DATASET + 'images' | ||
15 | filename_imglist_train = 'file_list/ucf101_rgb_train_split_1.txt' | ||
16 | filename_imglist_val = 'file_list/ucf101_rgb_val_split_1.txt' | ||
17 | prefix = 'img_{:05d}.jpg' | ||
18 | elif modality == 'Flow': | ||
19 | root_data = ROOT_DATASET + 'UCF101/jpg' | ||
20 | filename_imglist_train = 'UCF101/file_list/ucf101_flow_train_split_1.txt' | ||
21 | filename_imglist_val = 'UCF101/file_list/ucf101_flow_val_split_1.txt' | ||
22 | prefix = 'flow_{}_{:05d}.jpg' | ||
23 | else: | ||
24 | raise NotImplementedError('no such modality:' + modality) | ||
25 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix | ||
26 | |||
27 | |||
28 | def return_hmdb51(modality): | ||
29 | filename_categories = 51 | ||
30 | if modality == 'RGB': | ||
31 | root_data = ROOT_DATASET + 'HMDB51/images' | ||
32 | filename_imglist_train = 'HMDB51/splits/hmdb51_rgb_train_split_1.txt' | ||
33 | filename_imglist_val = 'HMDB51/splits/hmdb51_rgb_val_split_1.txt' | ||
34 | prefix = 'img_{:05d}.jpg' | ||
35 | elif modality == 'Flow': | ||
36 | root_data = ROOT_DATASET + 'HMDB51/images' | ||
37 | filename_imglist_train = 'HMDB51/splits/hmdb51_flow_train_split_1.txt' | ||
38 | filename_imglist_val = 'HMDB51/splits/hmdb51_flow_val_split_1.txt' | ||
39 | prefix = 'flow_{}_{:05d}.jpg' | ||
40 | else: | ||
41 | raise NotImplementedError('no such modality:' + modality) | ||
42 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix | ||
43 | |||
44 | |||
45 | def return_something(modality): | ||
46 | filename_categories = 'something/v1/category.txt' | ||
47 | if modality == 'RGB': | ||
48 | root_data = ROOT_DATASET + 'something/v1/20bn-something-something-v1' | ||
49 | filename_imglist_train = 'something/v1/train_videofolder.txt' | ||
50 | filename_imglist_val = 'something/v1/val_videofolder.txt' | ||
51 | prefix = '{:05d}.jpg' | ||
52 | elif modality == 'Flow': | ||
53 | root_data = ROOT_DATASET + 'something/v1/20bn-something-something-v1-flow' | ||
54 | filename_imglist_train = 'something/v1/train_videofolder_flow.txt' | ||
55 | filename_imglist_val = 'something/v1/val_videofolder_flow.txt' | ||
56 | prefix = '{:06d}-{}_{:05d}.jpg' | ||
57 | else: | ||
58 | print('no such modality:'+modality) | ||
59 | raise NotImplementedError | ||
60 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix | ||
61 | |||
62 | |||
63 | def return_somethingv2(modality): | ||
64 | filename_categories = 'something/v2/category.txt' | ||
65 | if modality == 'RGB': | ||
66 | root_data = ROOT_DATASET + 'something/v2/20bn-something-something-v2-frames' | ||
67 | filename_imglist_train = 'something/v2/train_videofolder.txt' | ||
68 | filename_imglist_val = 'something/v2/val_videofolder.txt' | ||
69 | prefix = '{:06d}.jpg' | ||
70 | elif modality == 'Flow': | ||
71 | root_data = ROOT_DATASET + 'something/v2/20bn-something-something-v2-flow' | ||
72 | filename_imglist_train = 'something/v2/train_videofolder_flow.txt' | ||
73 | filename_imglist_val = 'something/v2/val_videofolder_flow.txt' | ||
74 | prefix = '{:06d}.jpg' | ||
75 | else: | ||
76 | raise NotImplementedError('no such modality:'+modality) | ||
77 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix | ||
78 | |||
79 | |||
80 | def return_jester(modality): | ||
81 | filename_categories = 'jester/category.txt' | ||
82 | if modality == 'RGB': | ||
83 | prefix = '{:05d}.jpg' | ||
84 | root_data = ROOT_DATASET + 'jester/20bn-jester-v1' | ||
85 | filename_imglist_train = 'jester/train_videofolder.txt' | ||
86 | filename_imglist_val = 'jester/val_videofolder.txt' | ||
87 | else: | ||
88 | raise NotImplementedError('no such modality:'+modality) | ||
89 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix | ||
90 | |||
91 | |||
92 | def return_kinetics(modality): | ||
93 | filename_categories = 400 | ||
94 | if modality == 'RGB': | ||
95 | root_data = ROOT_DATASET + 'kinetics/images' | ||
96 | filename_imglist_train = 'kinetics/labels/train_videofolder.txt' | ||
97 | filename_imglist_val = 'kinetics/labels/val_videofolder.txt' | ||
98 | prefix = 'img_{:05d}.jpg' | ||
99 | else: | ||
100 | raise NotImplementedError('no such modality:' + modality) | ||
101 | return filename_categories, filename_imglist_train, filename_imglist_val, root_data, prefix | ||
102 | |||
103 | |||
104 | def return_dataset(dataset, modality): | ||
105 | dict_single = {'jester': return_jester, 'something': return_something, 'somethingv2': return_somethingv2, | ||
106 | 'ucf101': return_ucf101, 'hmdb51': return_hmdb51, | ||
107 | 'kinetics': return_kinetics} | ||
108 | if dataset in dict_single: | ||
109 | file_categories, file_imglist_train, file_imglist_val, root_data, prefix = dict_single[dataset](modality) | ||
110 | else: | ||
111 | raise ValueError('Unknown dataset '+dataset) | ||
112 | |||
113 | file_imglist_train = os.path.join(ROOT_DATASET, file_imglist_train) | ||
114 | file_imglist_val = os.path.join(ROOT_DATASET, file_imglist_val) | ||
115 | if isinstance(file_categories, str): | ||
116 | file_categories = os.path.join(ROOT_DATASET, file_categories) | ||
117 | with open(file_categories) as f: | ||
118 | lines = f.readlines() | ||
119 | categories = [item.rstrip() for item in lines] | ||
120 | else: # number of categories | ||
121 | categories = [None] * file_categories | ||
122 | n_class = len(categories) | ||
123 | print('{}: {} classes'.format(dataset, n_class)) | ||
124 | return n_class, file_imglist_train, file_imglist_val, root_data, prefix |
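A quick illustration of how return_dataset is meant to be called (the 'ucf101'/'RGB' arguments are examples, and the returned paths are only usable if the list files exist under ROOT_DATASET):

from ops.dataset_config import return_dataset

# Resolves the class count, train/val list files, frame root and filename template.
n_class, train_list, val_list, root_data, prefix = return_dataset('ucf101', 'RGB')
print(n_class, prefix)  # prefix is 'img_{:05d}.jpg' for RGB UCF101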
ops/models.py
0 → 100755
This diff is collapsed.
ops/non_local.py
0 → 100644
1 | # Non-local block using embedded gaussian | ||
2 | # Code from | ||
3 | # https://github.com/AlexHex7/Non-local_pytorch/blob/master/Non-Local_pytorch_0.3.1/lib/non_local_embedded_gaussian.py | ||
4 | import torch | ||
5 | from torch import nn | ||
6 | from torch.nn import functional as F | ||
7 | |||
8 | |||
9 | class _NonLocalBlockND(nn.Module): | ||
10 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): | ||
11 | super(_NonLocalBlockND, self).__init__() | ||
12 | |||
13 | assert dimension in [1, 2, 3] | ||
14 | |||
15 | self.dimension = dimension | ||
16 | self.sub_sample = sub_sample | ||
17 | |||
18 | self.in_channels = in_channels | ||
19 | self.inter_channels = inter_channels | ||
20 | |||
21 | if self.inter_channels is None: | ||
22 | self.inter_channels = in_channels // 2 | ||
23 | if self.inter_channels == 0: | ||
24 | self.inter_channels = 1 | ||
25 | |||
26 | if dimension == 3: | ||
27 | conv_nd = nn.Conv3d | ||
28 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) | ||
29 | bn = nn.BatchNorm3d | ||
30 | elif dimension == 2: | ||
31 | conv_nd = nn.Conv2d | ||
32 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) | ||
33 | bn = nn.BatchNorm2d | ||
34 | else: | ||
35 | conv_nd = nn.Conv1d | ||
36 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) | ||
37 | bn = nn.BatchNorm1d | ||
38 | |||
39 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, | ||
40 | kernel_size=1, stride=1, padding=0) | ||
41 | |||
42 | if bn_layer: | ||
43 | self.W = nn.Sequential( | ||
44 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, | ||
45 | kernel_size=1, stride=1, padding=0), | ||
46 | bn(self.in_channels) | ||
47 | ) | ||
48 | nn.init.constant_(self.W[1].weight, 0) | ||
49 | nn.init.constant_(self.W[1].bias, 0) | ||
50 | else: | ||
51 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, | ||
52 | kernel_size=1, stride=1, padding=0) | ||
53 | nn.init.constant_(self.W.weight, 0) | ||
54 | nn.init.constant_(self.W.bias, 0) | ||
55 | |||
56 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, | ||
57 | kernel_size=1, stride=1, padding=0) | ||
58 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, | ||
59 | kernel_size=1, stride=1, padding=0) | ||
60 | |||
61 | if sub_sample: | ||
62 | self.g = nn.Sequential(self.g, max_pool_layer) | ||
63 | self.phi = nn.Sequential(self.phi, max_pool_layer) | ||
64 | |||
65 | def forward(self, x): | ||
66 | ''' | ||
67 | :param x: (b, c, t, h, w) | ||
68 | :return: | ||
69 | ''' | ||
70 | |||
71 | batch_size = x.size(0) | ||
72 | |||
73 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) | ||
74 | g_x = g_x.permute(0, 2, 1) | ||
75 | |||
76 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) | ||
77 | theta_x = theta_x.permute(0, 2, 1) | ||
78 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) | ||
79 | f = torch.matmul(theta_x, phi_x) | ||
80 | f_div_C = F.softmax(f, dim=-1) | ||
81 | |||
82 | y = torch.matmul(f_div_C, g_x) | ||
83 | y = y.permute(0, 2, 1).contiguous() | ||
84 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) | ||
85 | W_y = self.W(y) | ||
86 | z = W_y + x | ||
87 | |||
88 | return z | ||
89 | |||
90 | |||
91 | class NONLocalBlock1D(_NonLocalBlockND): | ||
92 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): | ||
93 | super(NONLocalBlock1D, self).__init__(in_channels, | ||
94 | inter_channels=inter_channels, | ||
95 | dimension=1, sub_sample=sub_sample, | ||
96 | bn_layer=bn_layer) | ||
97 | |||
98 | |||
99 | class NONLocalBlock2D(_NonLocalBlockND): | ||
100 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): | ||
101 | super(NONLocalBlock2D, self).__init__(in_channels, | ||
102 | inter_channels=inter_channels, | ||
103 | dimension=2, sub_sample=sub_sample, | ||
104 | bn_layer=bn_layer) | ||
105 | |||
106 | |||
107 | class NONLocalBlock3D(_NonLocalBlockND): | ||
108 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): | ||
109 | super(NONLocalBlock3D, self).__init__(in_channels, | ||
110 | inter_channels=inter_channels, | ||
111 | dimension=3, sub_sample=sub_sample, | ||
112 | bn_layer=bn_layer) | ||
113 | |||
114 | |||
115 | class NL3DWrapper(nn.Module): | ||
116 | def __init__(self, block, n_segment): | ||
117 | super(NL3DWrapper, self).__init__() | ||
118 | self.block = block | ||
119 | self.nl = NONLocalBlock3D(block.bn3.num_features) | ||
120 | self.n_segment = n_segment | ||
121 | |||
122 | def forward(self, x): | ||
123 | x = self.block(x) | ||
124 | |||
125 | nt, c, h, w = x.size() | ||
126 | x = x.view(nt // self.n_segment, self.n_segment, c, h, w).transpose(1, 2) # n, c, t, h, w | ||
127 | x = self.nl(x) | ||
128 | x = x.transpose(1, 2).contiguous().view(nt, c, h, w) | ||
129 | return x | ||
130 | |||
131 | |||
132 | def make_non_local(net, n_segment): | ||
133 | import torchvision | ||
134 | import archs | ||
135 | if isinstance(net, torchvision.models.ResNet): | ||
136 | net.layer2 = nn.Sequential( | ||
137 | NL3DWrapper(net.layer2[0], n_segment), | ||
138 | net.layer2[1], | ||
139 | NL3DWrapper(net.layer2[2], n_segment), | ||
140 | net.layer2[3], | ||
141 | ) | ||
142 | net.layer3 = nn.Sequential( | ||
143 | NL3DWrapper(net.layer3[0], n_segment), | ||
144 | net.layer3[1], | ||
145 | NL3DWrapper(net.layer3[2], n_segment), | ||
146 | net.layer3[3], | ||
147 | NL3DWrapper(net.layer3[4], n_segment), | ||
148 | net.layer3[5], | ||
149 | ) | ||
150 | else: | ||
151 | raise NotImplementedError | ||
152 | |||
153 | |||
154 | if __name__ == '__main__': | ||
155 | from torch.autograd import Variable | ||
156 | import torch | ||
157 | |||
158 | sub_sample = True | ||
159 | bn_layer = True | ||
160 | |||
161 | img = Variable(torch.zeros(2, 3, 20)) | ||
162 | net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer) | ||
163 | out = net(img) | ||
164 | print(out.size()) | ||
165 | |||
166 | img = Variable(torch.zeros(2, 3, 20, 20)) | ||
167 | net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer) | ||
168 | out = net(img) | ||
169 | print(out.size()) | ||
170 | |||
171 | img = Variable(torch.randn(2, 3, 10, 20, 20)) | ||
172 | net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer) | ||
173 | out = net(img) | ||
174 | print(out.size()) | ||
\ No newline at end of file
ops/temporal_shift.py
0 → 100755
1 | # Code for "TSM: Temporal Shift Module for Efficient Video Understanding" | ||
2 | # arXiv:1811.08383 | ||
3 | # Ji Lin*, Chuang Gan, Song Han | ||
4 | # {jilin, songhan}@mit.edu, ganchuang@csail.mit.edu | ||
5 | |||
6 | import torch | ||
7 | import torch.nn as nn | ||
8 | import torch.nn.functional as F | ||
9 | |||
10 | |||
11 | class TemporalShift(nn.Module): | ||
12 | def __init__(self, net, n_segment=3, n_div=8, inplace=False): | ||
13 | super(TemporalShift, self).__init__() | ||
14 | self.net = net | ||
15 | self.n_segment = n_segment | ||
16 | self.fold_div = n_div | ||
17 | self.inplace = inplace | ||
18 | if inplace: | ||
19 | print('=> Using in-place shift...') | ||
20 | print('=> Using fold div: {}'.format(self.fold_div)) | ||
21 | |||
22 | def forward(self, x): | ||
23 | x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace) | ||
24 | return self.net(x) | ||
25 | |||
26 | @staticmethod | ||
27 | def shift(x, n_segment, fold_div=3, inplace=False): | ||
28 | nt, c, h, w = x.size() | ||
29 | n_batch = nt // n_segment | ||
30 | x = x.view(n_batch, n_segment, c, h, w) | ||
31 | |||
32 | fold = c // fold_div | ||
33 | if inplace: | ||
34 | # Due to some out of order error when performing parallel computing. | ||
35 | # May need to write a CUDA kernel. | ||
36 | raise NotImplementedError | ||
37 | # out = InplaceShift.apply(x, fold) | ||
38 | else: | ||
39 | out = torch.zeros_like(x) | ||
40 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift left | ||
41 | out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right | ||
42 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift | ||
43 | |||
44 | return out.view(nt, c, h, w) | ||
45 | |||
46 | |||
47 | class InplaceShift(torch.autograd.Function): | ||
48 | # Special thanks to @raoyongming for the help to this function | ||
49 | @staticmethod | ||
50 | def forward(ctx, input, fold): | ||
51 | # not support higher order gradient | ||
52 | # input = input.detach_() | ||
53 | ctx.fold_ = fold | ||
54 | n, t, c, h, w = input.size() | ||
55 | buffer = input.data.new(n, t, fold, h, w).zero_() | ||
56 | buffer[:, :-1] = input.data[:, 1:, :fold] | ||
57 | input.data[:, :, :fold] = buffer | ||
58 | buffer.zero_() | ||
59 | buffer[:, 1:] = input.data[:, :-1, fold: 2 * fold] | ||
60 | input.data[:, :, fold: 2 * fold] = buffer | ||
61 | return input | ||
62 | |||
63 | @staticmethod | ||
64 | def backward(ctx, grad_output): | ||
65 | # grad_output = grad_output.detach_() | ||
66 | fold = ctx.fold_ | ||
67 | n, t, c, h, w = grad_output.size() | ||
68 | buffer = grad_output.data.new(n, t, fold, h, w).zero_() | ||
69 | buffer[:, 1:] = grad_output.data[:, :-1, :fold] | ||
70 | grad_output.data[:, :, :fold] = buffer | ||
71 | buffer.zero_() | ||
72 | buffer[:, :-1] = grad_output.data[:, 1:, fold: 2 * fold] | ||
73 | grad_output.data[:, :, fold: 2 * fold] = buffer | ||
74 | return grad_output, None | ||
75 | |||
76 | |||
77 | class TemporalPool(nn.Module): | ||
78 | def __init__(self, net, n_segment): | ||
79 | super(TemporalPool, self).__init__() | ||
80 | self.net = net | ||
81 | self.n_segment = n_segment | ||
82 | |||
83 | def forward(self, x): | ||
84 | x = self.temporal_pool(x, n_segment=self.n_segment) | ||
85 | return self.net(x) | ||
86 | |||
87 | @staticmethod | ||
88 | def temporal_pool(x, n_segment): | ||
89 | nt, c, h, w = x.size() | ||
90 | n_batch = nt // n_segment | ||
91 | x = x.view(n_batch, n_segment, c, h, w).transpose(1, 2) # n, c, t, h, w | ||
92 | x = F.max_pool3d(x, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0)) | ||
93 | x = x.transpose(1, 2).contiguous().view(nt // 2, c, h, w) | ||
94 | return x | ||
95 | |||
96 | |||
97 | def make_temporal_shift(net, n_segment, n_div=8, place='blockres', temporal_pool=False): | ||
98 | if temporal_pool: | ||
99 | n_segment_list = [n_segment, n_segment // 2, n_segment // 2, n_segment // 2] | ||
100 | else: | ||
101 | n_segment_list = [n_segment] * 4 | ||
102 | assert n_segment_list[-1] > 0 | ||
103 | print('=> n_segment per stage: {}'.format(n_segment_list)) | ||
104 | |||
105 | import torchvision | ||
106 | if isinstance(net, torchvision.models.ResNet): | ||
107 | if place == 'block': | ||
108 | def make_block_temporal(stage, this_segment): | ||
109 | blocks = list(stage.children()) | ||
110 | print('=> Processing stage with {} blocks'.format(len(blocks))) | ||
111 | for i, b in enumerate(blocks): | ||
112 | blocks[i] = TemporalShift(b, n_segment=this_segment, n_div=n_div) | ||
113 | return nn.Sequential(*(blocks)) | ||
114 | |||
115 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0]) | ||
116 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1]) | ||
117 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2]) | ||
118 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3]) | ||
119 | |||
120 | elif 'blockres' in place: | ||
121 | n_round = 1 | ||
122 | if len(list(net.layer3.children())) >= 23: | ||
123 | n_round = 2 | ||
124 | print('=> Using n_round {} to insert temporal shift'.format(n_round)) | ||
125 | |||
126 | def make_block_temporal(stage, this_segment): | ||
127 | blocks = list(stage.children()) | ||
128 | print('=> Processing stage with {} residual blocks'.format(len(blocks))) | ||
129 | for i, b in enumerate(blocks): | ||
130 | if i % n_round == 0: | ||
131 | blocks[i].conv1 = TemporalShift(b.conv1, n_segment=this_segment, n_div=n_div) | ||
132 | return nn.Sequential(*blocks) | ||
133 | |||
134 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0]) | ||
135 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1]) | ||
136 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2]) | ||
137 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3]) | ||
138 | else: | ||
139 | raise NotImplementedError(place) | ||
140 | |||
141 | |||
142 | def make_temporal_pool(net, n_segment): | ||
143 | import torchvision | ||
144 | if isinstance(net, torchvision.models.ResNet): | ||
145 | print('=> Injecting temporal pooling') | ||
146 | net.layer2 = TemporalPool(net.layer2, n_segment) | ||
147 | else: | ||
148 | raise NotImplementedError | ||
149 | |||
150 | |||
151 | if __name__ == '__main__': | ||
152 | # test in-place shift vs. vanilla shift (note: the in-place path currently raises NotImplementedError in shift()) | ||
153 | tsm1 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=False) | ||
154 | tsm2 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=True) | ||
155 | |||
156 | print('=> Testing CPU...') | ||
157 | # test forward | ||
158 | with torch.no_grad(): | ||
159 | for i in range(10): | ||
160 | x = torch.rand(2 * 8, 3, 224, 224) | ||
161 | y1 = tsm1(x) | ||
162 | y2 = tsm2(x) | ||
163 | assert torch.norm(y1 - y2).item() < 1e-5 | ||
164 | |||
165 | # test backward | ||
166 | with torch.enable_grad(): | ||
167 | for i in range(10): | ||
168 | x1 = torch.rand(2 * 8, 3, 224, 224) | ||
169 | x1.requires_grad_() | ||
170 | x2 = x1.clone() | ||
171 | y1 = tsm1(x1) | ||
172 | y2 = tsm2(x2) | ||
173 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0] | ||
174 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0] | ||
175 | assert torch.norm(grad1 - grad2).item() < 1e-5 | ||
176 | |||
177 | print('=> Testing GPU...') | ||
178 | tsm1.cuda() | ||
179 | tsm2.cuda() | ||
180 | # test forward | ||
181 | with torch.no_grad(): | ||
182 | for i in range(10): | ||
183 | x = torch.rand(2 * 8, 3, 224, 224).cuda() | ||
184 | y1 = tsm1(x) | ||
185 | y2 = tsm2(x) | ||
186 | assert torch.norm(y1 - y2).item() < 1e-5 | ||
187 | |||
188 | # test backward | ||
189 | with torch.enable_grad(): | ||
190 | for i in range(10): | ||
191 | x1 = torch.rand(2 * 8, 3, 224, 224).cuda() | ||
192 | x1.requires_grad_() | ||
193 | x2 = x1.clone() | ||
194 | y1 = tsm1(x1) | ||
195 | y2 = tsm2(x2) | ||
196 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0] | ||
197 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0] | ||
198 | assert torch.norm(grad1 - grad2).item() < 1e-5 | ||
199 | print('Test passed.') | ||
200 | |||
201 | |||
202 | |||
203 |
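A minimal usage sketch (not part of this commit) for the module above, assuming a stock torchvision ResNet backbone; the segment count and input size below are illustrative only.

import torch
import torchvision
from ops.temporal_shift import make_temporal_shift

# Insert TemporalShift in front of each residual block's first conv ('blockres' placement).
backbone = torchvision.models.resnet50()
make_temporal_shift(backbone, n_segment=8, n_div=8, place='blockres')

# Frames are flattened to (batch * n_segment, C, H, W); shift() reshapes internally.
frames = torch.randn(2 * 8, 3, 224, 224)
with torch.no_grad():
    out = backbone(frames)   # one logit row per frame: (16, 1000)
print(out.size())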
ops/transforms.py
0 → 100755
This diff is collapsed.
ops/utils.py
0 → 100755
1 | import numpy as np | ||
2 | |||
3 | |||
4 | def softmax(scores): | ||
5 | es = np.exp(scores - scores.max(axis=-1)[..., None]) | ||
6 | return es / es.sum(axis=-1)[..., None] | ||
7 | |||
8 | |||
9 | class AverageMeter(object): | ||
10 | """Computes and stores the average and current value""" | ||
11 | |||
12 | def __init__(self): | ||
13 | self.reset() | ||
14 | |||
15 | def reset(self): | ||
16 | self.val = 0 | ||
17 | self.avg = 0 | ||
18 | self.sum = 0 | ||
19 | self.count = 0 | ||
20 | |||
21 | def update(self, val, n=1): | ||
22 | self.val = val | ||
23 | self.sum += val * n | ||
24 | self.count += n | ||
25 | self.avg = self.sum / self.count | ||
26 | |||
27 | |||
28 | def accuracy(output, target, topk=(1,)): | ||
29 | """Computes the top-k accuracy for the specified values of k""" | ||
30 | maxk = max(topk) | ||
31 | batch_size = target.size(0) | ||
32 | |||
33 | _, pred = output.topk(maxk, 1, True, True) | ||
34 | pred = pred.t() | ||
35 | correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||
36 | |||
37 | res = [] | ||
38 | for k in topk: | ||
39 | correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape (not view): safe for non-contiguous slices | ||
40 | res.append(correct_k.mul_(100.0 / batch_size)) | ||
41 | return res | ||
... | \ No newline at end of file
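A small self-contained usage sketch (not part of this commit) for the helpers above; all values are made up.

import torch
import numpy as np
from ops.utils import softmax, AverageMeter, accuracy

print(softmax(np.array([[1.0, 2.0, 0.5]])))   # rows sum to 1

meter = AverageMeter()                         # running average of a per-batch metric
meter.update(0.8, n=32)
meter.update(0.6, n=32)
print(meter.avg)                               # 0.7

output = torch.randn(8, 5)                     # dummy logits for 8 samples, 5 classes
target = torch.randint(0, 5, (8,))
top1, top3 = accuracy(output, target, topk=(1, 3))
print(top1.item(), top3.item())                # percentages in [0, 100]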
person_filter.py
0 → 100644
1 | import os | ||
2 | import cv2 | ||
3 | import numpy as np | ||
4 | import pickle | ||
5 | |||
6 | def start_filter(config): | ||
7 | cls_class_path = config['MODEL']['CLS_PERSON'] | ||
8 | feature_save_dir = config['VIDEO']['FACE_FEATURE_DIR'] | ||
9 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
10 | result_file_name = config['PERSON']['RESULT_FILE'] | ||
11 | feature_name = config['PERSON']['DATA_NAME'] | ||
12 | |||
13 | xgboost_model = pickle.load(open(cls_class_path, "rb")) | ||
14 | |||
15 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
16 | result_file = open(result_file_path, 'w') | ||
17 | |||
18 | feature_path = os.path.join(feature_save_dir, feature_name) | ||
19 | val_annotation_pairs = np.load(feature_path, allow_pickle=True, encoding='latin1') | ||
20 | |||
21 | X_val = [] | ||
22 | Y_val = [] | ||
23 | Y_names = [] | ||
24 | for j in range(len(val_annotation_pairs)): | ||
25 | pair = val_annotation_pairs[j] | ||
26 | X_val.append(np.squeeze(pair[0])) | ||
27 | Y_val.append(pair[1]) | ||
28 | Y_names.append(pair[2]) | ||
29 | |||
30 | X_val = np.array(X_val) | ||
31 | y_pred = xgboost_model.predict_proba(X_val) | ||
32 | |||
33 | for i, Y_name in enumerate(Y_names): | ||
34 | result_file.write(Y_name + ' ') | ||
35 | result_file.write(str(y_pred[i][0]) + ',' + str(y_pred[i][1]) + ',' + str(y_pred[i][2]) + '\n') | ||
36 | |||
37 | result_file.close() | ||
38 | |||
39 | |||
40 | |||
41 | |||
42 |
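A hedged sketch (not part of this commit) of the data layout start_filter above appears to expect: each entry of the saved .npy is a (feature, label, clip_name) triple, and the result file holds one "<name> p0,p1,p2" line per clip. The feature size and names below are placeholders.

import numpy as np

pairs = np.array([
    (np.random.rand(1, 512), 1, 'clip_0001'),   # 512-d feature size is illustrative
    (np.random.rand(1, 512), 2, 'clip_0002'),
], dtype=object)
np.save('face_features_example.npy', pairs)

loaded = np.load('face_features_example.npy', allow_pickle=True)
X = np.array([np.squeeze(p[0]) for p in loaded])
names = [p[2] for p in loaded]

fake_proba = np.full((len(X), 3), 1.0 / 3.0)    # stand-in for xgboost_model.predict_proba(X)
for name, p in zip(names, fake_proba):
    print('{} {},{},{}'.format(name, p[0], p[1], p[2]))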
pose_filter.py
0 → 100644
1 | import os | ||
2 | import torch.optim | ||
3 | import numpy as np | ||
5 | import torch.nn.parallel | ||
6 | from ops.models import TSN | ||
7 | from ops.transforms import * | ||
8 | from ops.dataset import TSNDataSet | ||
9 | from torch.nn import functional as F | ||
10 | |||
11 | |||
12 | def gen_file_list(frame_save_dir, frame_list_dir): | ||
13 | |||
14 | val_path = os.path.join(frame_list_dir, 'val.txt') | ||
15 | video_names = os.listdir(frame_save_dir) | ||
16 | ucf101_rgb_val_file = open(val_path, 'w') | ||
17 | |||
18 | for video_name in video_names: | ||
19 | images_dir = os.path.join(frame_save_dir, video_name) | ||
20 | ucf101_rgb_val_file.write(video_name) | ||
21 | ucf101_rgb_val_file.write(' ') | ||
22 | ucf101_rgb_val_file.write(str(len(os.listdir(images_dir)))) | ||
23 | ucf101_rgb_val_file.write('\n') | ||
24 | |||
25 | ucf101_rgb_val_file.close() | ||
26 | |||
27 | return val_path | ||
28 | |||
29 | |||
30 | def start_filter(config): | ||
31 | arch = config['FIGHTING']['ARCH'] | ||
32 | prefix = config['VIDEO']['PREFIX'] | ||
33 | modality = config['POSE']['MODALITY'] | ||
34 | test_crop = config['POSE']['TEST_CROP'] | ||
35 | batch_size = config['POSE']['BATCH_SIZE'] | ||
36 | weights_path = config['MODEL']['CLS_POSE'] | ||
37 | test_segment = config['POSE']['TEST_SEGMENT'] | ||
38 | frame_save_dir = config['VIDEO']['POSE_FRAME_SAVE_DIR'] | ||
39 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
40 | result_file_name = config['POSE']['RESULT_FILE'] | ||
41 | |||
42 | workers = 8 | ||
43 | num_class = 3 | ||
44 | shift_div = 8 | ||
45 | img_feature_dim = 256 | ||
46 | |||
47 | softmax = False | ||
48 | is_shift = True | ||
49 | full_res = False | ||
50 | non_local = False | ||
51 | dense_sample = False | ||
52 | twice_sample = False | ||
53 | |||
54 | val_list = gen_file_list(frame_save_dir, frame_list_dir) | ||
55 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
56 | |||
57 | pretrain = 'imagenet' | ||
58 | shift_place = 'blockres' | ||
59 | crop_fusion_type = 'avg' | ||
60 | |||
61 | net = TSN(num_class, test_segment if is_shift else 1, modality, | ||
62 | base_model=arch, | ||
63 | consensus_type=crop_fusion_type, | ||
64 | img_feature_dim=img_feature_dim, | ||
65 | pretrain=pretrain, | ||
66 | is_shift=is_shift, shift_div=shift_div, shift_place=shift_place, | ||
67 | non_local=non_local, | ||
68 | ) | ||
69 | |||
70 | checkpoint = torch.load(weights_path) | ||
71 | checkpoint = checkpoint['state_dict'] | ||
72 | |||
73 | base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())} | ||
74 | replace_dict = {'base_model.classifier.weight': 'new_fc.weight', | ||
75 | 'base_model.classifier.bias': 'new_fc.bias', | ||
76 | } | ||
77 | for k, v in replace_dict.items(): | ||
78 | if k in base_dict: | ||
79 | base_dict[v] = base_dict.pop(k) | ||
80 | |||
81 | net.load_state_dict(base_dict) | ||
82 | |||
83 | input_size = net.scale_size if full_res else net.input_size | ||
84 | |||
85 | if test_crop == 1: | ||
86 | cropping = torchvision.transforms.Compose([ | ||
87 | GroupScale(net.scale_size), | ||
88 | GroupCenterCrop(input_size), | ||
89 | ]) | ||
90 | elif test_crop == 3: # do not flip, so only 3 crops | ||
91 | cropping = torchvision.transforms.Compose([ | ||
92 | GroupFullResSample(input_size, net.scale_size, flip=False) | ||
93 | ]) | ||
94 | elif test_crop == 5: # do not flip, so only 5 crops | ||
95 | cropping = torchvision.transforms.Compose([ | ||
96 | GroupOverSample(input_size, net.scale_size, flip=False) | ||
97 | ]) | ||
98 | elif test_crop == 10: | ||
99 | cropping = torchvision.transforms.Compose([ | ||
100 | GroupOverSample(input_size, net.scale_size) | ||
101 | ]) | ||
102 | else: | ||
103 | raise ValueError("Only 1, 3, 5, 10 crops are supported, but got {}".format(test_crop)) | ||
104 | |||
105 | data_loader = torch.utils.data.DataLoader( | ||
106 | TSNDataSet(frame_save_dir, val_list, num_segments=test_segment, | ||
107 | new_length=1 if modality == "RGB" else 5, | ||
108 | modality=modality, | ||
109 | image_tmpl=prefix, | ||
110 | test_mode=True, | ||
111 | remove_missing=False, | ||
112 | transform=torchvision.transforms.Compose([ | ||
113 | cropping, | ||
114 | Stack(roll=(arch in ['BNInception', 'InceptionV3'])), | ||
115 | ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])), | ||
116 | GroupNormalize(net.input_mean, net.input_std), | ||
117 | ]), dense_sample=dense_sample, twice_sample=twice_sample), | ||
118 | batch_size=batch_size, shuffle=False, | ||
119 | num_workers=workers, pin_memory=True, | ||
120 | ) | ||
121 | |||
122 | net = torch.nn.DataParallel(net.cuda()) | ||
123 | net.eval() | ||
124 | data_gen = enumerate(data_loader) | ||
125 | max_num = len(data_loader.dataset) | ||
126 | |||
127 | result_file = open(result_file_path, 'w') | ||
128 | |||
129 | for i, data_pair in data_gen: | ||
130 | directory, data = data_pair | ||
131 | with torch.no_grad(): | ||
132 | if i >= max_num: | ||
133 | break | ||
134 | num_crop = test_crop | ||
135 | if dense_sample: | ||
136 | num_crop *= 10 # 10 clips for testing when using dense sample | ||
137 | |||
138 | if twice_sample: | ||
139 | num_crop *= 2 | ||
140 | |||
141 | if modality == 'RGB': | ||
142 | length = 3 | ||
143 | elif modality == 'Flow': | ||
144 | length = 10 | ||
145 | elif modality == 'RGBDiff': | ||
146 | length = 18 | ||
147 | else: | ||
148 | raise ValueError("Unknown modality " + modality) | ||
149 | |||
150 | data_in = data.view(-1, length, data.size(2), data.size(3)) | ||
151 | if is_shift: | ||
152 | data_in = data_in.view(batch_size * num_crop, test_segment, length, data_in.size(2), data_in.size(3)) | ||
153 | rst, feature = net(data_in) | ||
154 | rst = rst.reshape(batch_size, num_crop, -1).mean(1) | ||
155 | |||
156 | if softmax: | ||
157 | # take the softmax to normalize the output to probability | ||
158 | rst = F.softmax(rst, dim=1) | ||
159 | |||
160 | rst = rst.data.cpu().numpy().copy() | ||
161 | |||
162 | if net.module.is_shift: | ||
163 | rst = rst.reshape(batch_size, num_class) | ||
164 | else: | ||
165 | rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class)) | ||
166 | |||
167 | proba = np.squeeze(rst) | ||
168 | proba = np.exp(proba)/sum(np.exp(proba)) | ||
169 | result_file.write(str(directory[0]) + ' ') | ||
170 | result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n') | ||
171 | |||
172 | result_file.close() | ||
173 | print('video filter end') | ||
... | \ No newline at end of file
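A hedged sketch (not part of this commit) of the nested config keys pose_filter.start_filter reads; every path and value below is a placeholder, not taken from the repo's config.yaml.

config = {
    'MODEL': {'CLS_POSE': '/path/to/pose_checkpoint.pth'},
    'FIGHTING': {'ARCH': 'resnet50'},
    'VIDEO': {
        'PREFIX': 'img_{:05d}.jpg',                     # frame file name template
        'POSE_FRAME_SAVE_DIR': '/path/to/pose_frames',  # one sub-directory per clip
        'FRAME_LIST_DIR': '/path/to/lists',             # val.txt and result files go here
    },
    'POSE': {
        'MODALITY': 'RGB',
        'TEST_CROP': 1,
        'BATCH_SIZE': 1,
        'TEST_SEGMENT': 8,
        'RESULT_FILE': 'pose_result.txt',
    },
}
# pose_filter.start_filter(config) then writes one "<clip> p0,p1,p2" line per clip.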
test.py
0 → 100644
1 | import os | ||
2 | import cv2 | ||
3 | import load_util | ||
4 | import media_util | ||
5 | import numpy as np | ||
6 | from sklearn.metrics import confusion_matrix | ||
7 | import fighting_filter, fighting_2_filter, emotion_filter, argue_filter, audio_filter, class_filter | ||
8 | import video_filter, pose_filter, flow_filter | ||
9 | |||
10 | |||
11 | |||
12 | def accuracy_cal(config): | ||
13 | |||
14 | label_file_path = config['VIDEO']['LABEL_PATH'] | ||
15 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
16 | final_file_name = config['AUDIO']['RESULT_FILE'] | ||
17 | |||
18 | final_file_path = os.path.join(frame_list_dir, final_file_name) | ||
19 | final_file_lines = open(final_file_path).readlines() | ||
20 | label_file_lines = open(label_file_path).readlines() | ||
21 | |||
22 | |||
23 | final_pairs = {line.strip().split(' ')[0]: line.strip().split(' ')[1] for line in final_file_lines} | ||
24 | |||
25 | lines_num = len(label_file_lines) - 1 | ||
26 | hit = 0 | ||
27 | for i, label_line in enumerate(label_file_lines): | ||
28 | if i == 0: | ||
29 | continue | ||
30 | file, label = label_line.strip().split(' ') | ||
31 | final_pre = final_pairs[file] | ||
32 | final_pre_class = np.argmax(np.array(final_pre.split(','), dtype=float)) + 1  # cast to float so argmax compares numbers, not strings | ||
33 | print(final_pre_class, label) | ||
34 | if final_pre_class == int(label): | ||
35 | hit += 1 | ||
36 | |||
37 | return hit/lines_num | ||
38 | |||
39 | |||
40 | def main(): | ||
41 | config_path = r'config.yaml' | ||
42 | config = load_util.load_config(config_path) | ||
43 | |||
44 | media_util.extract_wav(config) | ||
45 | media_util.extract_frame(config) | ||
46 | media_util.extract_frame_pose(config) | ||
47 | media_util.extract_is10(config) | ||
48 | media_util.extract_random_face_feature(config) | ||
49 | media_util.extract_mirror(config) | ||
50 | |||
51 | fighting_2_filter.start_filter(config) | ||
52 | emotion_filter.start_filter(config) | ||
53 | |||
54 | audio_filter.start_filter(config) | ||
55 | |||
56 | class_filter.start_filter(config) | ||
57 | video_filter.start_filter(config) | ||
58 | pose_filter.start_filter(config) | ||
59 | |||
60 | flow_filter.start_filter(config) | ||
61 | |||
62 | acc = accuracy_cal(config) | ||
63 | print(acc) | ||
64 | |||
65 | |||
66 | if __name__ == '__main__': | ||
67 | main() | ||
... | \ No newline at end of file
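accuracy_cal above assumes the label file starts with a header row followed by "<file> <label>" lines with labels in {1, 2, 3}, and that the result file holds "<file> p0,p1,p2" lines; below is a tiny self-contained check of that matching logic (all names and numbers are made up).

import numpy as np

result_lines = ['clip_0001 0.1,0.7,0.2', 'clip_0002 0.6,0.3,0.1']
label_lines = ['file label', 'clip_0001 2', 'clip_0002 3']   # first row is a header

preds = {line.split(' ')[0]: line.split(' ')[1] for line in result_lines}
hit = 0
for row in label_lines[1:]:
    name, label = row.split(' ')
    cls = np.argmax(np.array(preds[name].split(','), dtype=float)) + 1  # classes are 1-based
    hit += int(cls == int(label))
print(hit / (len(label_lines) - 1))   # 0.5 for this toy example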
video_filter.py
0 → 100644
1 | import os | ||
2 | import torch.optim | ||
3 | import numpy as np | ||
4 | import torch.nn.parallel | ||
5 | from ops.models import TSN | ||
6 | from ops.transforms import * | ||
7 | from ops.dataset import TSNDataSet | ||
8 | from torch.nn import functional as F | ||
9 | |||
10 | |||
11 | def gen_file_list(frame_save_dir, frame_list_dir): | ||
12 | |||
13 | val_path = os.path.join(frame_list_dir, 'val.txt') | ||
14 | video_names = os.listdir(frame_save_dir) | ||
15 | ucf101_rgb_val_file = open(val_path, 'w') | ||
16 | |||
17 | for video_name in video_names: | ||
18 | images_dir = os.path.join(frame_save_dir, video_name) | ||
19 | ucf101_rgb_val_file.write(video_name) | ||
20 | ucf101_rgb_val_file.write(' ') | ||
21 | ucf101_rgb_val_file.write(str(len(os.listdir(images_dir)))) | ||
22 | ucf101_rgb_val_file.write('\n') | ||
23 | |||
24 | ucf101_rgb_val_file.close() | ||
25 | |||
26 | return val_path | ||
27 | |||
28 | |||
29 | def start_filter(config): | ||
30 | arch = config['FIGHTING']['ARCH'] | ||
31 | prefix = config['VIDEO']['PREFIX'] | ||
32 | modality = config['VIDEO_FILTER']['MODALITY'] | ||
33 | test_crop = config['VIDEO_FILTER']['TEST_CROP'] | ||
34 | batch_size = config['VIDEO_FILTER']['BATCH_SIZE'] | ||
35 | weights_path = config['MODEL']['CLS_VIDEO'] | ||
36 | test_segment = config['VIDEO_FILTER']['TEST_SEGMENT'] | ||
37 | frame_save_dir = config['VIDEO']['FRAME_SAVE_DIR'] | ||
38 | frame_list_dir = config['VIDEO']['FRAME_LIST_DIR'] | ||
39 | result_file_name = config['VIDEO_FILTER']['RESULT_FILE'] | ||
40 | |||
41 | workers = 8 | ||
42 | num_class = 3 | ||
43 | shift_div = 8 | ||
44 | img_feature_dim = 256 | ||
45 | |||
46 | softmax = False | ||
47 | is_shift = True | ||
48 | full_res = False | ||
49 | non_local = False | ||
50 | dense_sample = False | ||
51 | twice_sample = False | ||
52 | |||
53 | val_list = gen_file_list(frame_save_dir, frame_list_dir) | ||
54 | result_file_path = os.path.join(frame_list_dir, result_file_name) | ||
55 | |||
56 | pretrain = 'imagenet' | ||
57 | shift_place = 'blockres' | ||
58 | crop_fusion_type = 'avg' | ||
59 | |||
60 | net = TSN(num_class, test_segment if is_shift else 1, modality, | ||
61 | base_model=arch, | ||
62 | consensus_type=crop_fusion_type, | ||
63 | img_feature_dim=img_feature_dim, | ||
64 | pretrain=pretrain, | ||
65 | is_shift=is_shift, shift_div=shift_div, shift_place=shift_place, | ||
66 | non_local=non_local, | ||
67 | ) | ||
68 | |||
69 | checkpoint = torch.load(weights_path) | ||
70 | checkpoint = checkpoint['state_dict'] | ||
71 | |||
72 | base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint.items())} | ||
73 | replace_dict = {'base_model.classifier.weight': 'new_fc.weight', | ||
74 | 'base_model.classifier.bias': 'new_fc.bias', | ||
75 | } | ||
76 | for k, v in replace_dict.items(): | ||
77 | if k in base_dict: | ||
78 | base_dict[v] = base_dict.pop(k) | ||
79 | |||
80 | net.load_state_dict(base_dict) | ||
81 | |||
82 | input_size = net.scale_size if full_res else net.input_size | ||
83 | |||
84 | if test_crop == 1: | ||
85 | cropping = torchvision.transforms.Compose([ | ||
86 | GroupScale(net.scale_size), | ||
87 | GroupCenterCrop(input_size), | ||
88 | ]) | ||
89 | elif test_crop == 3: # do not flip, so only 3 crops | ||
90 | cropping = torchvision.transforms.Compose([ | ||
91 | GroupFullResSample(input_size, net.scale_size, flip=False) | ||
92 | ]) | ||
93 | elif test_crop == 5: # do not flip, so only 5 crops | ||
94 | cropping = torchvision.transforms.Compose([ | ||
95 | GroupOverSample(input_size, net.scale_size, flip=False) | ||
96 | ]) | ||
97 | elif test_crop == 10: | ||
98 | cropping = torchvision.transforms.Compose([ | ||
99 | GroupOverSample(input_size, net.scale_size) | ||
100 | ]) | ||
101 | else: | ||
102 | raise ValueError("Only 1, 3, 5, 10 crops are supported, but got {}".format(test_crop)) | ||
103 | |||
104 | data_loader = torch.utils.data.DataLoader( | ||
105 | TSNDataSet(frame_save_dir, val_list, num_segments=test_segment, | ||
106 | new_length=1 if modality == "RGB" else 5, | ||
107 | modality=modality, | ||
108 | image_tmpl=prefix, | ||
109 | test_mode=True, | ||
110 | remove_missing=False, | ||
111 | transform=torchvision.transforms.Compose([ | ||
112 | cropping, | ||
113 | Stack(roll=(arch in ['BNInception', 'InceptionV3'])), | ||
114 | ToTorchFormatTensor(div=(arch not in ['BNInception', 'InceptionV3'])), | ||
115 | GroupNormalize(net.input_mean, net.input_std), | ||
116 | ]), dense_sample=dense_sample, twice_sample=twice_sample), | ||
117 | batch_size=batch_size, shuffle=False, | ||
118 | num_workers=workers, pin_memory=True, | ||
119 | ) | ||
120 | |||
121 | net = torch.nn.DataParallel(net.cuda()) | ||
122 | net.eval() | ||
123 | data_gen = enumerate(data_loader) | ||
124 | max_num = len(data_loader.dataset) | ||
125 | |||
126 | result_file = open(result_file_path, 'w') | ||
127 | |||
128 | for i, data_pair in data_gen: | ||
129 | directory, data = data_pair | ||
130 | with torch.no_grad(): | ||
131 | if i >= max_num: | ||
132 | break | ||
133 | num_crop = test_crop | ||
134 | if dense_sample: | ||
135 | num_crop *= 10 # 10 clips for testing when using dense sample | ||
136 | |||
137 | if twice_sample: | ||
138 | num_crop *= 2 | ||
139 | |||
140 | if modality == 'RGB': | ||
141 | length = 3 | ||
142 | elif modality == 'Flow': | ||
143 | length = 10 | ||
144 | elif modality == 'RGBDiff': | ||
145 | length = 18 | ||
146 | else: | ||
147 | raise ValueError("Unknown modality " + modality) | ||
148 | |||
149 | data_in = data.view(-1, length, data.size(2), data.size(3)) | ||
150 | if is_shift: | ||
151 | data_in = data_in.view(batch_size * num_crop, test_segment, length, data_in.size(2), data_in.size(3)) | ||
152 | |||
153 | rst, feature = net(data_in) | ||
154 | rst = rst.reshape(batch_size, num_crop, -1).mean(1) | ||
155 | |||
156 | if softmax: | ||
157 | # take the softmax to normalize the output to probability | ||
158 | rst = F.softmax(rst, dim=1) | ||
159 | |||
160 | rst = rst.data.cpu().numpy().copy() | ||
161 | |||
162 | if net.module.is_shift: | ||
163 | rst = rst.reshape(batch_size, num_class) | ||
164 | else: | ||
165 | rst = rst.reshape((batch_size, -1, num_class)).mean(axis=1).reshape((batch_size, num_class)) | ||
166 | |||
167 | proba = np.squeeze(rst) | ||
168 | proba = np.exp(proba)/sum(np.exp(proba)) | ||
169 | result_file.write(str(directory[0]) + ' ') | ||
170 | result_file.write(str(proba[0]) + ',' + str(proba[1]) + ',' + str(proba[2]) + '\n') | ||
171 | |||
172 | result_file.close() | ||
173 | print('video filter end') | ||
... | \ No newline at end of file